wop committed
Commit 820263c · verified · 1 Parent(s): d280987

Update app.py

Files changed (1)
  1. app.py +87 -88
app.py CHANGED
@@ -7,14 +7,14 @@ import datetime
 import json

 _ = load_dotenv(find_dotenv())
-st.set_page_config(page_icon="💬", layout="wide", page_title="...")
+st.set_page_config(page_icon="💬", layout="wide", page_title="...")

 def icon(emoji: str):
-    """Shows an emoji as a Notion-style page icon."""
-    st.write(
-        f'<span style="font-size: 78px; line-height: 1">{emoji}</span>',
-        unsafe_allow_html=True,
-    )
+    """Shows an emoji as a Notion-style page icon."""
+    st.write(
+        f'<span style="font-size: 78px; line-height: 1">{emoji}</span>',
+        unsafe_allow_html=True,
+    )


 icon("⚡")
@@ -22,118 +22,117 @@ icon("⚡")
 st.subheader("Chatbot", divider="rainbow", anchor=False)

 client = Groq(
-    api_key=os.environ['GROQ_API_KEY'],
+    api_key=os.environ['GROQ_API_KEY'],
 )

 # Read saved prompts from file
 with open("saved_prompts.txt", "r") as f:
-    saved_prompts = f.read().split("<|>")
+    saved_prompts = f.read().split("<|>")

 prompt_names = [p.split(" ", 1)[0] for p in saved_prompts]
 prompt_map = {name: prompt for name, prompt in zip(prompt_names, saved_prompts)}

 # Initialize chat history and selected model
 if "messages" not in st.session_state:
-    st.session_state.messages = []
+    st.session_state.messages = []

 if "selected_model" not in st.session_state:
-    st.session_state.selected_model = None
+    st.session_state.selected_model = None

 # Define model details
 models = {
-    "mixtral-8x7b-32768": {
-        "name": "Mixtral-8x7b-Instruct-v0.1",
-        "tokens": 32768,
-        "developer": "Mistral",
-    },
-    "gemma-7b-it": {"name": "Gemma-7b-it", "tokens": 8192, "developer": "Google"},
-    "llama2-70b-4096": {"name": "LLaMA2-70b-chat", "tokens": 4096, "developer": "Meta"},
-    "llama3-70b-8192": {"name": "LLaMA3-70b-8192", "tokens": 8192, "developer": "Meta"},
-    "llama3-8b-8192": {"name": "LLaMA3-8b-8192", "tokens": 8192, "developer": "Meta"},
+    "mixtral-8x7b-32768": {
+        "name": "Mixtral-8x7b-Instruct-v0.1",
+        "tokens": 32768,
+        "developer": "Mistral",
+    },
+    "gemma-7b-it": {"name": "Gemma-7b-it", "tokens": 8192, "developer": "Google"},
+    "llama2-70b-4096": {"name": "LLaMA2-70b-chat", "tokens": 4096, "developer": "Meta"},
+    "llama3-70b-8192": {"name": "LLaMA3-70b-8192", "tokens": 8192, "developer": "Meta"},
+    "llama3-8b-8192": {"name": "LLaMA3-8b-8192", "tokens": 8192, "developer": "Meta"},
 }

 # Layout for model selection and max_tokens slider
 col1, col2 = st.columns(2)

 with col1:
-    model_option = st.selectbox(
-        "Choose a model:",
-        options=list(models.keys()),
-        format_func=lambda x: models[x]["name"],
-        index=0,  # Default to the first model in the list
-    )
-    # Add prompt dropdown
-    prompt_option = st.selectbox("Choose a prompt:", options=prompt_names)
-
-    if not prompt_option:
-        prompt = ""
-    else:
-        prompt = prompt_map[prompt_option]
+    def update_prompt(selected_prompt):  # Callback function for dropdown
+        global prompt
+        prompt = prompt_map[selected_prompt]
+
+    prompt_option = st.selectbox(
+        "Choose a prompt:",
+        options=list(models.keys()),
+        format_func=lambda x: models[x]["name"],
+        index=0,  # Default to the first model in the list
+        on_change=update_prompt,  # Call update_prompt on selection change
+    )
+
+    # Chat input without value argument
+    if prompt := st.chat_input("Enter your prompt here..."):
+        st.session_state.messages.append({"role": "user", "content": prompt})
+
+        with st.chat_message("user", avatar="❓"):
+            st.markdown(prompt)

 # Detect model change and clear chat history if model has changed
 if st.session_state.selected_model != model_option:
-    st.session_state.messages = []
-    st.session_state.selected_model = model_option
+    st.session_state.messages = []
+    st.session_state.selected_model = model_option

 max_tokens_range = models[model_option]["tokens"]

 with col2:
-    # Adjust max_tokens slider dynamically based on the selected model
-    max_tokens = st.slider(
-        "Max Tokens:",
-        min_value=512,  # Minimum value to allow some flexibility
-        max_value=max_tokens_range,
-        # Default value or max allowed if less
-        value=min(32768, max_tokens_range),
-        step=512,
-        help=f"Adjust the maximum number of tokens (words) for the model's response. Max for selected model: {max_tokens_range}",
-    )
+    # Adjust max_tokens slider dynamically based on the selected model
+    max_tokens = st.slider(
+        "Max Tokens:",
+        min_value=512,  # Minimum value to allow some flexibility
+        max_value=max_tokens_range,
+        # Default value or max allowed if less
+        value=min(32768, max_tokens_range),
+        step=512,
+        help=f"Adjust the maximum number of tokens (words) for the model's response. Max for selected model: {max_tokens_range}",
+    )

 # Display chat messages from history on app rerun
 for message in st.session_state.messages:
-    avatar = "🧠" if message["role"] == "assistant" else "❓"
-    with st.chat_message(message["role"], avatar=avatar):
-        st.markdown(message["content"])
+    avatar = "🧠" if message["role"] == "assistant" else "❓"
+    with st.chat_message(message["role"], avatar=avatar):
+        st.markdown(message["content"])

 def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
-    """Yield chat response content from the Groq API response."""
-    for chunk in chat_completion:
-        if chunk.choices[0].delta.content:
-            yield chunk.choices[0].delta.content
-
-if prompt := st.chat_input("Enter your prompt here...", value=prompt):
-    st.session_state.messages.append({"role": "user", "content": prompt})
-
-    with st.chat_message("user", avatar="❓"):
-        st.markdown(prompt)
-
-    # Fetch response from Groq API
-    try:
-        chat_completion = client.chat.completions.create(
-            model=model_option,
-            messages=[
-                {"role": m["role"], "content": m["content"]}
-                for m in st.session_state.messages
-            ],
-            max_tokens=max_tokens,
-            stream=True,
-        )
-
-        # Use the generator function with st.write_stream
-        with st.chat_message("assistant", avatar="🧠"):
-            chat_responses_generator = generate_chat_responses(chat_completion)
-            full_response = st.write_stream(chat_responses_generator)
-    except Exception as e:
-        st.error(e, icon="🚨")
-
-    # Append the full response to session_state.messages
-    if isinstance(full_response, str):
-        st.session_state.messages.append(
-            {"role": "assistant", "content": full_response}
-        )
-    else:
-        # Handle the case where full_response is not a string
-        combined_response = "\n".join(str(item) for item in full_response)
-        st.session_state.messages.append(
-            {"role": "assistant", "content": combined_response}
-        )
+    """Yield chat response content from the Groq API response."""
+    for chunk in chat_completion:
+        if chunk.choices[0].delta.content:
+            yield chunk.choices[0].delta.content
+
+# Fetch response from Groq API
+try:
+    chat_completion = client.chat.completions.create(
+        model=model_option,
+        messages=[
+            {"role": m["role"], "content": m["content"]}
+            for m in st.session_state.messages
+        ],
+        max_tokens=max_tokens,
+        stream=True,
+    )
+
+    # Use the generator function with st.write_stream
+    with st.chat_message("assistant", avatar="🧠"):
+        chat_responses_generator = generate_chat_responses(chat_completion)
+        full_response = st.write_stream(chat_responses_generator)
+except Exception as e:
+    st.error(e, icon="🚨")
+
+# Append the full response to session_state.messages
+if isinstance(full_response, str):
+    st.session_state.messages.append(
+        {"role": "assistant", "content": full_response}
+    )
+else:
+    # Handle the case where full_response is not a string
+    combined_response = "\n".join(str(item) for item in full_response)
+    st.session_state.messages.append(
+        {"role": "assistant", "content": combined_response}
+    )
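
For reference on the new on_change wiring: Streamlit invokes on_change callbacks with no arguments unless args/kwargs are supplied, so a dropdown callback normally reads the current selection back out of st.session_state through the widget's key. A minimal sketch of that pattern, separate from app.py (the prompt_choice/prompt_text keys and the two-entry prompt_map here are illustrative only):

import streamlit as st

# Illustrative stand-in for the map built from saved_prompts.txt in app.py.
prompt_map = {
    "greet": "greet Say hello to the user politely.",
    "summarize": "summarize Condense the given text into three bullet points.",
}

def update_prompt():
    # on_change callbacks receive no arguments; the selected option is read
    # back from st.session_state via the widget's key.
    st.session_state.prompt_text = prompt_map[st.session_state.prompt_choice]

st.selectbox(
    "Choose a prompt:",
    options=list(prompt_map.keys()),
    key="prompt_choice",      # where Streamlit stores the current selection
    on_change=update_prompt,  # invoked (argument-free) when the selection changes
)

st.write(st.session_state.get("prompt_text", ""))

Related API note: at the time of writing, st.chat_input takes a placeholder plus keyword options such as key, max_chars, and on_submit, but no value parameter for pre-filling text, so a selected prompt has to reach the chat through session state or the message list rather than the input box itself.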
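
The layout of saved_prompts.txt is not shown in the diff, but the parsing in app.py implies one: prompts separated by "<|>", each starting with a one-word name. A small self-contained sketch of that assumption (the sample string is made up):

# Hypothetical file contents, mirroring the split("<|>") / split(" ", 1) parsing in app.py.
sample = "greet Say hello to the user politely.<|>summarize Condense the given text."

saved_prompts = sample.split("<|>")
prompt_names = [p.split(" ", 1)[0] for p in saved_prompts]
prompt_map = {name: prompt for name, prompt in zip(prompt_names, saved_prompts)}

print(prompt_names)         # ['greet', 'summarize']
print(prompt_map["greet"])  # 'greet Say hello to the user politely.'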