Create app.py
app.py ADDED
@@ -0,0 +1,137 @@
import streamlit as st
import openai
import requests

st.set_page_config(page_title="CodeLlama Playground - via DeepInfra", page_icon='🦙')

MODEL_IMAGES = {
    "meta-llama/Meta-Llama-3-8B-Instruct": "https://em-content.zobj.net/source/twitter/376/llama_1f999.png",  # Add the emoji for the Meta-Llama model
    "codellama/CodeLlama-34b-Instruct-hf": "https://em-content.zobj.net/source/twitter/376/llama_1f999.png",
    "mistralai/Mistral-7B-Instruct-v0.1": "https://em-content.zobj.net/source/twitter/376/tornado_1f32a-fe0f.png",
    "mistralai/Mixtral-8x7B-Instruct-v0.1": "https://em-content.zobj.net/source/twitter/376/tornado_1f32a-fe0f.png",
}

# Create a mapping from formatted model names to their original identifiers
def format_model_name(model_key):
    parts = model_key.split('/')
    model_name = parts[-1]  # Get the last part after '/'
    name_parts = model_name.split('-')

    # Custom formatting for specific models
    if "Meta-Llama-3-8B-Instruct" in model_key:
        return "Llama 3 8B-Instruct"
    else:
        # General formatting for other models
        formatted_name = ' '.join(name_parts[:-2]).title()  # Join them into a single string with title case
        return formatted_name

formatted_names_to_identifiers = {
    format_model_name(key): key for key in MODEL_IMAGES.keys()
}

# Debug to ensure names are formatted correctly
# st.write("Formatted Model Names to Identifiers:", formatted_names_to_identifiers)

selected_formatted_name = st.sidebar.radio(
    "Select LLM Model",
    list(formatted_names_to_identifiers.keys())
)

selected_model = formatted_names_to_identifiers[selected_formatted_name]

if MODEL_IMAGES[selected_model].startswith("http"):
    st.image(MODEL_IMAGES[selected_model], width=90)
else:
    st.write(f"Model Icon: {MODEL_IMAGES[selected_model]}", unsafe_allow_html=True)

# Display the selected model using the formatted name
model_display_name = selected_formatted_name  # Already formatted
# st.write(f"Model being used: `{model_display_name}`")

st.sidebar.markdown('---')

API_KEY = st.secrets["api_key"]  # Fallback key for requests within the free token limit

# DeepInfra exposes an OpenAI-compatible endpoint; note this uses the legacy
# openai<1.0 client interface (openai.ChatCompletion was removed in openai>=1.0)
openai.api_base = "https://api.deepinfra.com/v1/openai"
MODEL_CODELLAMA = selected_model

def get_response(api_key, model, user_input, max_tokens, top_p):
    openai.api_key = api_key
    try:
        if "meta-llama/Meta-Llama-3-8B-Instruct" in model:
            # Call the chat-completions REST endpoint directly for Meta-Llama
            chat_completion = requests.post(
                "https://api.deepinfra.com/v1/openai/chat/completions",
                headers={"Authorization": f"Bearer {api_key}"},
                json={
                    "model": model,
                    "messages": [{"role": "user", "content": user_input}],
                    "max_tokens": max_tokens,
                    "top_p": top_p
                }
            ).json()
            return chat_completion['choices'][0]['message']['content'], None
        else:
            # Existing setup for other models, via the openai client
            chat_completion = openai.ChatCompletion.create(
                model=model,
                messages=[{"role": "user", "content": user_input}],
                max_tokens=max_tokens,
                top_p=top_p
            )
            return chat_completion.choices[0].message.content, None
    except Exception as e:
        return None, str(e)


# Adjust the title based on the selected model
st.header(f"`{model_display_name}` Model")

with st.expander("About this app"):
    st.write("""
    This chatbot app lets you interact with various LLMs, including the newest models hosted on DeepInfra's OpenAI-compatible API.
    For more info, you can refer to [DeepInfra's documentation](https://deepinfra.com/docs/advanced/openai_api).

    💡 For decent answers, you'd want to increase the `Max Tokens` value from `100` to `500`.
    """)

if "api_key" not in st.session_state:
    st.session_state.api_key = ""

with st.sidebar:
    max_tokens = st.slider('Max Tokens', 10, 500, 100)
    top_p = st.slider('Top P', 0.0, 1.0, 0.5, 0.05)

    # Requests beyond the free 100-token budget need the user's own key
    if max_tokens > 100:
        user_provided_api_key = st.text_input("🔑 Your DeepInfra API Key", value=st.session_state.api_key, type='password')
        if user_provided_api_key:
            st.session_state.api_key = user_provided_api_key
        if not st.session_state.api_key:
            st.warning("⚠️ If you want to try this app with more than `100` tokens, you must provide your own DeepInfra API key. Get yours here → https://deepinfra.com/dash/api_keys")

if max_tokens <= 100 or st.session_state.api_key:
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    if prompt := st.chat_input():
        st.session_state.messages.append({"role": "user", "content": prompt})
        with st.chat_message("user"):
            st.write(prompt)  # Show the user's message immediately instead of waiting for a rerun
        with st.chat_message("assistant"):
            with st.spinner("Thinking..."):
                # Fall back to the app's secret key when the user hasn't supplied one
                response, error = get_response(st.session_state.api_key or API_KEY, MODEL_CODELLAMA, prompt, max_tokens, top_p)
                if error:
                    st.error(f"Error: {error}")
                else:
                    placeholder = st.empty()
                    placeholder.markdown(response)
                    message = {"role": "assistant", "content": response}
                    st.session_state.messages.append(message)

# Clear chat history function and button
def clear_chat_history():
    st.session_state.messages = [{"role": "assistant", "content": "How may I assist you today?"}]
st.sidebar.button('Clear Chat History', on_click=clear_chat_history)
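
To sanity-check a DeepInfra key before pasting it into the app, the same chat-completions endpoint used in the Meta-Llama branch of get_response can be exercised from a standalone script. The sketch below is illustrative, not part of the Space; the script name and the DEEPINFRA_API_KEY environment variable are assumptions.

# check_key.py — minimal sketch for exercising DeepInfra's OpenAI-compatible
# endpoint outside Streamlit (script name and env var are illustrative)
import os
import requests

api_key = os.environ["DEEPINFRA_API_KEY"]  # assumed variable name

resp = requests.post(
    "https://api.deepinfra.com/v1/openai/chat/completions",  # same URL app.py calls
    headers={"Authorization": f"Bearer {api_key}"},
    json={
        "model": "meta-llama/Meta-Llama-3-8B-Instruct",
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 100,
        "top_p": 0.5,
    },
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])

When running app.py locally rather than on Spaces, the fallback key read by st.secrets["api_key"] lives in .streamlit/secrets.toml, and the app starts with streamlit run app.py.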