Update src/streamlit_app.py
src/streamlit_app.py
CHANGED: +93 -38
@@ -1,40 +1,95 @@
-import altair as alt
-import numpy as np
-import pandas as pd
 import streamlit as st
-
-"""
-# Welcome to Streamlit!
-
-Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
-If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
-forums](https://discuss.streamlit.io).
-
-In the meantime, below is an example of what you can do with just a few lines of code:
-"""
-
-num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
-num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
-
-indices = np.linspace(0, 1, num_points)
-theta = 2 * np.pi * num_turns * indices
-radius = indices
-
-x = radius * np.cos(theta)
-y = radius * np.sin(theta)
-
-df = pd.DataFrame({
-    "x": x,
-    "y": y,
-    "idx": indices,
-    "rand": np.random.randn(num_points),
-})
-
-st.altair_chart(alt.Chart(df, height=700, width=700)
-    .mark_point(filled=True)
-    .encode(
-        x=alt.X("x", axis=None),
-        y=alt.Y("y", axis=None),
-        color=alt.Color("idx", legend=None, scale=alt.Scale()),
-        size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
-    ))
+from transformers import pipeline
+import torch
+
+# Set the title of the Streamlit app
+st.set_page_config(page_title="Hugging Face Chat", page_icon="🤗")
+st.title("🤗 Hugging Face Model Chat")
+
+# Add a sidebar for model selection
+with st.sidebar:
+    st.header("Model Selection")
+    # A dictionary of available models
+    model_options = {
+        "NVIDIA Nemotron 3 8B": "nvidia/nemotron-3-8b-chat-4k-sft",
+        "Meta Llama 3.1 8B": "meta-llama/Llama-3.1-8B-Instruct",
+        "Mistral 7B Instruct": "mistralai/Mistral-7B-Instruct-v0.1",
+        "Gemma 7B It": "google/gemma-7b-it",
+    }
+    selected_model_name = st.selectbox("Choose a model:", list(model_options.keys()))
+    model_id = model_options[selected_model_name]
+
+    st.markdown("---")
+    st.markdown("This app allows you to chat with different open-source Large Language Models from the Hugging Face Hub.")
+    st.markdown("Select a model from the dropdown and start chatting!")
+
+
+# Caching the model loading to improve performance
+@st.cache_resource
+def load_model(model_id):
+    """Loads the selected model and tokenizer from Hugging Face."""
+    try:
+        # Use "text-generation" pipeline for chat models
+        pipe = pipeline(
+            "text-generation",
+            model=model_id,
+            torch_dtype=torch.bfloat16,
+            device_map="auto"
+        )
+        return pipe
+    except Exception as e:
+        st.error(f"Error loading model: {e}")
+        return None
+
+# Load the selected model
+pipe = load_model(model_id)
+
+# Initialize chat history in session state
+if "messages" not in st.session_state:
+    st.session_state.messages = []
+
+# Display prior chat messages
+for message in st.session_state.messages:
+    with st.chat_message(message["role"]):
+        st.markdown(message["content"])
+
+# Get user input
+if prompt := st.chat_input("What would you like to ask?"):
+    # Add user message to chat history
+    st.session_state.messages.append({"role": "user", "content": prompt})
+    # Display user message
+    with st.chat_message("user"):
+        st.markdown(prompt)
+
+    # Generate a response from the model
+    if pipe:
+        with st.chat_message("assistant"):
+            with st.spinner("Thinking..."):
+                # Prepare the prompt for the model
+                # Note: Different models may have different prompt formats.
+                # This is a generic approach.
+                formatted_prompt = f"User: {prompt}\nAssistant:"
+
+                # Generate the response
+                response = pipe(
+                    formatted_prompt,
+                    max_new_tokens=512,
+                    do_sample=True,
+                    temperature=0.7,
+                    top_p=0.95,
+                    top_k=50
+                )
+
+                # Extract the generated text
+                if response and len(response) > 0 and "generated_text" in response[0]:
+                    # The output often includes the prompt, so we clean it up.
+                    assistant_response = response[0]["generated_text"].split("Assistant:")[-1].strip()
+                else:
+                    assistant_response = "Sorry, I couldn't generate a response."
+
+                st.markdown(assistant_response)
+                # Add assistant response to chat history
+                st.session_state.messages.append({"role": "assistant", "content": assistant_response})
+    else:
+        st.error("Model not loaded. Cannot generate a response.")
+
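Deployment note: besides streamlit, the updated app imports transformers and torch, and device_map="auto" relies on accelerate, so the Space presumably needs those packages declared. A minimal requirements.txt sketch (package list inferred from the imports; the commit itself pins nothing):

streamlit
torch
transformers
accelerate

Some of the listed checkpoints (for example the meta-llama and google/gemma repos) are gated on the Hugging Face Hub, so loading them typically also requires accepting the model license and providing an access token.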