akshil-jain committed
Commit 765555e · verified · 1 Parent(s): bf1fa19

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +93 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,95 @@
- import altair as alt
- import numpy as np
- import pandas as pd
  import streamlit as st
-
- """
- # Welcome to Streamlit!
-
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
- forums](https://discuss.streamlit.io).
-
- In the meantime, below is an example of what you can do with just a few lines of code:
- """
-
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
- indices = np.linspace(0, 1, num_points)
- theta = 2 * np.pi * num_turns * indices
- radius = indices
-
- x = radius * np.cos(theta)
- y = radius * np.sin(theta)
-
- df = pd.DataFrame({
-     "x": x,
-     "y": y,
-     "idx": indices,
-     "rand": np.random.randn(num_points),
- })
-
- st.altair_chart(alt.Chart(df, height=700, width=700)
-     .mark_point(filled=True)
-     .encode(
-         x=alt.X("x", axis=None),
-         y=alt.Y("y", axis=None),
-         color=alt.Color("idx", legend=None, scale=alt.Scale()),
-         size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-     ))
+ from transformers import pipeline
+ import torch
+
+ # Set the title of the Streamlit app
+ st.set_page_config(page_title="Hugging Face Chat", page_icon="🤗")
+ st.title("🤗 Hugging Face Model Chat")
+
+ # Add a sidebar for model selection
+ with st.sidebar:
+     st.header("Model Selection")
+     # A dictionary of available models
+     model_options = {
+         "NVIDIA Nemotron 3 8B": "nvidia/nemotron-3-8b-chat-4k-sft",
+         "Meta Llama 3.1 8B": "meta-llama/Llama-3.1-8B-Instruct",
+         "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.1",
+         "Gemma 7B It": "google/gemma-7b-it",
+     }
+     selected_model_name = st.selectbox("Choose a model:", list(model_options.keys()))
+     model_id = model_options[selected_model_name]
+
+     st.markdown("---")
+     st.markdown("This app allows you to chat with different open-source Large Language Models from the Hugging Face Hub.")
+     st.markdown("Select a model from the dropdown and start chatting!")
+
+
+ # Caching the model loading to improve performance
+ @st.cache_resource
+ def load_model(model_id):
+     """Loads the selected model and tokenizer from Hugging Face."""
+     try:
+         # Use the "text-generation" pipeline for chat models
+         pipe = pipeline(
+             "text-generation",
+             model=model_id,
+             torch_dtype=torch.bfloat16,
+             device_map="auto"
+         )
+         return pipe
+     except Exception as e:
+         st.error(f"Error loading model: {e}")
+         return None
+
+ # Load the selected model
+ pipe = load_model(model_id)
+
+ # Initialize chat history in session state
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+
+ # Display prior chat messages
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ # Get user input
+ if prompt := st.chat_input("What would you like to ask?"):
+     # Add user message to chat history
+     st.session_state.messages.append({"role": "user", "content": prompt})
+     # Display user message
+     with st.chat_message("user"):
+         st.markdown(prompt)
+
+     # Generate a response from the model
+     if pipe:
+         with st.chat_message("assistant"):
+             with st.spinner("Thinking..."):
+                 # Prepare the prompt for the model
+                 # Note: Different models may have different prompt formats.
+                 # This is a generic approach.
+                 formatted_prompt = f"User: {prompt}\nAssistant:"
+
+                 # Generate the response
+                 response = pipe(
+                     formatted_prompt,
+                     max_new_tokens=512,
+                     do_sample=True,
+                     temperature=0.7,
+                     top_p=0.95,
+                     top_k=50
+                 )
+
+                 # Extract the generated text
+                 if response and len(response) > 0 and "generated_text" in response[0]:
+                     # The output often includes the prompt, so we clean it up.
+                     assistant_response = response[0]["generated_text"].split("Assistant:")[-1].strip()
+                 else:
+                     assistant_response = "Sorry, I couldn't generate a response."
+
+                 st.markdown(assistant_response)
+                 # Add assistant response to chat history
+                 st.session_state.messages.append({"role": "assistant", "content": assistant_response})
+     else:
+         st.error("Model not loaded. Cannot generate a response.")
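
A note on the prompt handling above: formatted_prompt wraps only the latest user message in a hand-rolled "User: ... Assistant:" frame, so the model never sees earlier turns, and instruct-tuned checkpoints such as Llama 3.1 and Mistral Instruct are trained against their own chat templates rather than this generic frame. A minimal sketch of an alternative, assuming the selected model's tokenizer ships a chat template (most instruct models on the Hub do):

    # Sketch: build the prompt from the full chat history using the model's own template.
    history = [{"role": m["role"], "content": m["content"]} for m in st.session_state.messages]
    formatted_prompt = pipe.tokenizer.apply_chat_template(
        history,
        tokenize=False,              # return a string rather than token IDs
        add_generation_prompt=True,  # append the assistant-turn marker
    )
    response = pipe(formatted_prompt, max_new_tokens=512, do_sample=True)
    # The pipeline echoes the prompt, so slice it off to keep only the new text.
    assistant_response = response[0]["generated_text"][len(formatted_prompt):].strip()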
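
Separately, @st.cache_resource keeps one cache entry per model_id, so switching models in the sidebar leaves every previously loaded pipeline resident in memory. If memory is tight, Streamlit's max_entries argument caps the cache; a sketch under that assumption:

    # Sketch: keep at most one pipeline cached; loading a new model evicts the old one.
    @st.cache_resource(max_entries=1)
    def load_model(model_id):
        return pipeline("text-generation", model=model_id,
                        torch_dtype=torch.bfloat16, device_map="auto")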
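
Finally, several of the listed checkpoints (the meta-llama and google/gemma repositories, for example) are gated on the Hugging Face Hub, so pipeline() can only download them after the license has been accepted and a token is supplied. A sketch, assuming the token is stored in Streamlit secrets under a hypothetical HF_TOKEN key:

    # Sketch: authenticate for gated checkpoints; HF_TOKEN is an assumed secrets key.
    pipe = pipeline(
        "text-generation",
        model=model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=st.secrets["HF_TOKEN"],
    )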
95