wop commited on
Commit
2bfa474
·
verified ·
1 Parent(s): 5a84099

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -88
app.py CHANGED
@@ -9,6 +9,7 @@ import json
9
  _ = load_dotenv(find_dotenv())
10
  st.set_page_config(page_icon="💬", layout="wide", page_title="Groq Chat Bot...")
11
 
 
12
  def icon(emoji: str):
13
  """Shows an emoji as a Notion-style page icon."""
14
  st.write(
@@ -16,6 +17,7 @@ def icon(emoji: str):
16
  unsafe_allow_html=True,
17
  )
18
 
 
19
  icon("📣")
20
 
21
  st.subheader("Groq Chat Streamlit App", divider="rainbow", anchor=False)
@@ -24,6 +26,73 @@ client = Groq(
24
  api_key=os.environ['GROQ_API_KEY'],
25
  )
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  models = {
28
  "mixtral-8x7b-32768": {
29
  "name": "Mixtral-8x7b-Instruct-v0.1",
@@ -34,6 +103,7 @@ models = {
34
  "gemma-7b-it": {"name": "Gemma-7b-it", "tokens": 8192, "developer": "Google"},
35
  }
36
 
 
37
  col1, col2 = st.columns(2)
38
 
39
  with col1:
@@ -41,15 +111,10 @@ with col1:
41
  "Choose a model:",
42
  options=list(models.keys()),
43
  format_func=lambda x: models[x]["name"],
44
- index=0,
45
  )
46
 
47
- if "messages" not in st.session_state:
48
- st.session_state.messages = []
49
-
50
- if "selected_model" not in st.session_state:
51
- st.session_state.selected_model = None
52
-
53
  if st.session_state.selected_model != model_option:
54
  st.session_state.messages = []
55
  st.session_state.selected_model = model_option
@@ -57,110 +122,55 @@ if st.session_state.selected_model != model_option:
57
  max_tokens_range = models[model_option]["tokens"]
58
 
59
  with col2:
 
60
  max_tokens = st.slider(
61
  "Max Tokens:",
62
- min_value=512,
63
  max_value=max_tokens_range,
 
64
  value=min(32768, max_tokens_range),
65
  step=512,
66
  help=f"Adjust the maximum number of tokens (words) for the model's response. Max for selected model: {max_tokens_range}",
67
  )
68
 
 
69
  for message in st.session_state.messages:
70
  avatar = "🤖" if message["role"] == "assistant" else "🕺"
71
  with st.chat_message(message["role"], avatar=avatar):
72
  st.markdown(message["content"])
73
 
74
- def generate_chat_responses(user_prompt):
75
- """Fetches response from the Groq API using the run_conversation function."""
76
- response = run_conversation(user_prompt)
77
- yield response # Yield the response content
78
-
79
- def run_conversation(user_prompt):
80
- messages=[
81
- {
82
- "role": "system",
83
- "content": "You are a helpful assistant named ChattyBot."
84
- },
85
- {
86
- "role": "user",
87
- "content": user_prompt,
88
- }
89
- ]
90
- tools = [
91
- {
92
- "type": "function",
93
- "function": {
94
- "name": "time_date",
95
- "description": "The tool will return information about the time and date to the AI.",
96
- "parameters": {},
97
- },
98
- }
99
- ]
100
- try:
101
- response = client.chat.completions.create(
102
- model=model_option,
103
- messages=messages,
104
- tools=tools,
105
- tool_choice="auto",
106
- max_tokens=4096
107
- )
108
-
109
- response_message = response.choices[0].delta
110
- tool_calls = response_message.tool_calls
111
-
112
- if tool_calls:
113
- available_functions = {
114
- "time_date": get_tool_owner_info
115
- }
116
-
117
- messages.append(response_message)
118
-
119
- for tool_call in tool_calls:
120
- function_name = tool_call.function.name
121
- function_to_call = available_functions[function_name]
122
- function_args = json.loads(tool_call.function.arguments)
123
- function_response = function_to_call(**function_args)
124
- messages.append(
125
- {
126
- "tool_call_id": tool_call.id,
127
- "role": "tool",
128
- "name": function_name,
129
- "content": function_response,
130
- }
131
- )
132
-
133
- second_response = client.chat.completions.create(
134
- model=model_option,
135
- messages=messages
136
- )
137
 
138
- return second_response.choices[0].delta.content
139
- else:
140
- return response_message.content
141
- except Exception as e:
142
- st.error(e, icon="🚨")
143
- return None
144
 
145
- def get_tool_owner_info():
146
- owner_info = {
147
- "date_time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
148
- }
149
- return json.dumps(owner_info)
150
 
151
  if prompt := st.chat_input("Enter your prompt here..."):
152
  st.session_state.messages.append({"role": "user", "content": prompt})
153
 
154
- with st.chat_message("user", avatar=""):
155
  st.markdown(prompt)
156
 
 
157
  try:
158
- # Use generate_chat_responses with user prompt
159
- with st.chat_message("assistant", avatar=""):
160
- chat_responses_generator = generate_chat_responses(prompt)
 
 
 
 
 
 
 
 
 
 
161
  full_response = st.write_stream(chat_responses_generator)
162
  except Exception as e:
163
- st.error(e, icon="")
164
 
165
  # Append the full response to session_state.messages
166
  if isinstance(full_response, str):
@@ -172,4 +182,5 @@ if prompt := st.chat_input("Enter your prompt here..."):
172
  combined_response = "\n".join(str(item) for item in full_response)
173
  st.session_state.messages.append(
174
  {"role": "assistant", "content": combined_response}
175
- )
 
 
9
  _ = load_dotenv(find_dotenv())
10
  st.set_page_config(page_icon="💬", layout="wide", page_title="Groq Chat Bot...")
11
 
12
+
13
  def icon(emoji: str):
14
  """Shows an emoji as a Notion-style page icon."""
15
  st.write(
 
17
  unsafe_allow_html=True,
18
  )
19
 
20
+
21
  icon("📣")
22
 
23
  st.subheader("Groq Chat Streamlit App", divider="rainbow", anchor=False)
 
26
  api_key=os.environ['GROQ_API_KEY'],
27
  )
28
 
29
+ def get_tool_owner_info():
30
+ owner_info = {
31
+ "date_time": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
32
+ }
33
+ return json.dumps(owner_info)
34
+
35
+ def run_conversation(user_prompt, messages):
36
+ tools = [
37
+ {
38
+ "type": "function",
39
+ "function": {
40
+ "name": "time_date",
41
+ "description": "The tool will return information about the time and date to the AI.",
42
+ "parameters": {},
43
+ },
44
+ }
45
+ ]
46
+ response = client.chat.completions.create(
47
+ model=model_option,
48
+ messages=messages,
49
+ tools=tools,
50
+ tool_choice="auto",
51
+ max_tokens=max_tokens
52
+ )
53
+
54
+ response_message = response.choices[0].message
55
+ tool_calls = response_message.tool_calls
56
+
57
+ if tool_calls:
58
+ available_functions = {
59
+ "time_date": get_tool_owner_info
60
+ }
61
+
62
+ messages.append(response_message)
63
+
64
+ for tool_call in tool_calls:
65
+ function_name = tool_call.function.name
66
+ function_to_call = available_functions[function_name]
67
+ function_args = json.loads(tool_call.function.arguments)
68
+ function_response = function_to_call(**function_args)
69
+ messages.append(
70
+ {
71
+ "tool_call_id": tool_call.id,
72
+ "role": "tool",
73
+ "name": function_name,
74
+ "content": function_response,
75
+ }
76
+ )
77
+
78
+ second_response = client.chat.completions.create(
79
+ model=model_option,
80
+ messages=messages
81
+ )
82
+
83
+ return second_response.choices[0].message.content
84
+ else:
85
+ return response_message
86
+
87
+
88
+ # Initialize chat history and selected model
89
+ if "messages" not in st.session_state:
90
+ st.session_state.messages = []
91
+
92
+ if "selected_model" not in st.session_state:
93
+ st.session_state.selected_model = None
94
+
95
+ # Define model details
96
  models = {
97
  "mixtral-8x7b-32768": {
98
  "name": "Mixtral-8x7b-Instruct-v0.1",
 
103
  "gemma-7b-it": {"name": "Gemma-7b-it", "tokens": 8192, "developer": "Google"},
104
  }
105
 
106
+ # Layout for model selection and max_tokens slider
107
  col1, col2 = st.columns(2)
108
 
109
  with col1:
 
111
  "Choose a model:",
112
  options=list(models.keys()),
113
  format_func=lambda x: models[x]["name"],
114
+ index=0, # Default to the first model in the list
115
  )
116
 
117
+ # Detect model change and clear chat history if model has changed
 
 
 
 
 
118
  if st.session_state.selected_model != model_option:
119
  st.session_state.messages = []
120
  st.session_state.selected_model = model_option
 
122
  max_tokens_range = models[model_option]["tokens"]
123
 
124
  with col2:
125
+ # Adjust max_tokens slider dynamically based on the selected model
126
  max_tokens = st.slider(
127
  "Max Tokens:",
128
+ min_value=512, # Minimum value to allow some flexibility
129
  max_value=max_tokens_range,
130
+ # Default value or max allowed if less
131
  value=min(32768, max_tokens_range),
132
  step=512,
133
  help=f"Adjust the maximum number of tokens (words) for the model's response. Max for selected model: {max_tokens_range}",
134
  )
135
 
136
+ # Display chat messages from history on app rerun
137
  for message in st.session_state.messages:
138
  avatar = "🤖" if message["role"] == "assistant" else "🕺"
139
  with st.chat_message(message["role"], avatar=avatar):
140
  st.markdown(message["content"])
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
 
143
+ def generate_chat_responses(chat_completion) -> Generator[str, None, None]:
144
+ """Yield chat response content from the Groq API response."""
145
+ for chunk in chat_completion:
146
+ if chunk.choices[0].delta.content:
147
+ yield chunk.choices[0].delta.content
 
148
 
 
 
 
 
 
149
 
150
  if prompt := st.chat_input("Enter your prompt here..."):
151
  st.session_state.messages.append({"role": "user", "content": prompt})
152
 
153
+ with st.chat_message("user", avatar="🕺"):
154
  st.markdown(prompt)
155
 
156
+ # Fetch response from Groq API
157
  try:
158
+ chat_completion = client.chat.completions.create(
159
+ model=model_option,
160
+ messages=[
161
+ {"role": m["role"], "content": m["content"]}
162
+ for m in st.session_state.messages
163
+ ],
164
+ max_tokens=max_tokens,
165
+ stream=True,
166
+ )
167
+
168
+ # Use the generator function with st.write_stream
169
+ with st.chat_message("assistant", avatar="🤖"):
170
+ chat_responses_generator = generate_chat_responses(chat_completion)
171
  full_response = st.write_stream(chat_responses_generator)
172
  except Exception as e:
173
+ st.error(e, icon="🚨")
174
 
175
  # Append the full response to session_state.messages
176
  if isinstance(full_response, str):
 
182
  combined_response = "\n".join(str(item) for item in full_response)
183
  st.session_state.messages.append(
184
  {"role": "assistant", "content": combined_response}
185
+ )
186
+