DrishtiSharma commited on
Commit
b0f546e
Β·
verified Β·
1 Parent(s): 54c8d53

Update test.py

Browse files
Files changed (1) hide show
  1. test.py +119 -154
test.py CHANGED
@@ -7,6 +7,8 @@ from moa.agent import MOAgent
7
  from moa.agent.moa import ResponseChunk
8
  from streamlit_ace import st_ace
9
  import copy
 
 
10
 
11
  # Default configuration
12
  default_config = {
@@ -24,106 +26,9 @@ layer_agent_config_def = {
24
  "system_prompt": "Respond with a thought and then your response to the question. {helper_response}",
25
  "model_name": "gemma-7b-it",
26
  "temperature": 0.7
27
- },
28
- "layer_agent_3": {
29
- "system_prompt": "You are an expert at logic and reasoning. Always take a logical approach to the answer. {helper_response}",
30
- "model_name": "llama3-8b-8192"
31
- },
32
- }
33
-
34
- # Recommended configuration
35
- rec_config = {
36
- "main_model": "llama3-70b-8192",
37
- "cycles": 2,
38
- "layer_agent_config": {}
39
- }
40
-
41
- layer_agent_config_rec = {
42
- "layer_agent_1": {
43
- "system_prompt": "Think through your response step by step. {helper_response}",
44
- "model_name": "llama3-8b-8192",
45
- "temperature": 0.1
46
- },
47
- "layer_agent_2": {
48
- "system_prompt": "Respond with a thought and then your response to the question. {helper_response}",
49
- "model_name": "llama3-8b-8192",
50
- "temperature": 0.2
51
- },
52
- "layer_agent_3": {
53
- "system_prompt": "You are an expert at logic and reasoning. Always take a logical approach to the answer. {helper_response}",
54
- "model_name": "llama3-8b-8192",
55
- "temperature": 0.4
56
- },
57
- "layer_agent_4": {
58
- "system_prompt": "You are an expert planner agent. Create a plan for how to answer the human's query. {helper_response}",
59
- "model_name": "mixtral-8x7b-32768",
60
- "temperature": 0.5
61
- },
62
  }
63
 
64
- # Unified streaming function to handle async and sync responses
65
- async def stream_or_async_response(messages: Union[Iterable[ResponseChunk], AsyncIterable[ResponseChunk]]):
66
- layer_outputs = {}
67
-
68
- async def process_message(message):
69
- if message['response_type'] == 'intermediate':
70
- layer = message['metadata']['layer']
71
- if layer not in layer_outputs:
72
- layer_outputs[layer] = []
73
- layer_outputs[layer].append(message['delta'])
74
- else:
75
- for layer, outputs in layer_outputs.items():
76
- st.write(f"Layer {layer}")
77
- cols = st.columns(len(outputs))
78
- for i, output in enumerate(outputs):
79
- with cols[i]:
80
- st.expander(label=f"Agent {i+1}", expanded=False).write(output)
81
-
82
- layer_outputs.clear()
83
- yield message['delta']
84
-
85
- if isinstance(messages, AsyncIterable):
86
- # Process asynchronous messages
87
- async for message in messages:
88
- await process_message(message)
89
- else:
90
- # Process synchronous messages
91
- for message in messages:
92
- await process_message(message)
93
-
94
- # Set up the MOAgent
95
- def set_moa_agent(
96
- main_model: str = default_config['main_model'],
97
- cycles: int = default_config['cycles'],
98
- layer_agent_config: dict[dict[str, any]] = copy.deepcopy(layer_agent_config_def),
99
- main_model_temperature: float = 0.1,
100
- override: bool = False
101
- ):
102
- if override or ("main_model" not in st.session_state):
103
- st.session_state.main_model = main_model
104
-
105
- if override or ("cycles" not in st.session_state):
106
- st.session_state.cycles = cycles
107
-
108
- if override or ("layer_agent_config" not in st.session_state):
109
- st.session_state.layer_agent_config = layer_agent_config
110
-
111
- if override or ("main_temp" not in st.session_state):
112
- st.session_state.main_temp = main_model_temperature
113
-
114
- cls_ly_conf = copy.deepcopy(st.session_state.layer_agent_config)
115
-
116
- if override or ("moa_agent" not in st.session_state):
117
- st.session_state.moa_agent = MOAgent.from_config(
118
- main_model=st.session_state.main_model,
119
- cycles=st.session_state.cycles,
120
- layer_agent_config=cls_ly_conf,
121
- temperature=st.session_state.main_temp
122
- )
123
-
124
- del cls_ly_conf
125
- del layer_agent_config
126
-
127
  # Streamlit app layout
128
  st.set_page_config(
129
  page_title="Karios Agents Powered by Groq",
@@ -133,65 +38,114 @@ st.set_page_config(
133
  },
134
  layout="wide"
135
  )
 
136
  valid_model_names = [
137
  'llama3-70b-8192',
138
  'llama3-8b-8192',
139
- 'gemma-7b-it',
140
- 'gemma2-9b-it',
141
- 'mixtral-8x7b-32768'
142
  ]
143
 
144
- st.markdown("<a href='https://groq.com'><img src='app/static/banner.png' width='500'></a>", unsafe_allow_html=True)
145
- st.write("---")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
  # Initialize session state
148
  if "messages" not in st.session_state:
149
  st.session_state.messages = []
150
 
151
- set_moa_agent()
 
152
 
153
- # Sidebar for configuration
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  with st.sidebar:
155
  st.title("MOA Configuration")
156
- with st.form("Agent Configuration", border=False):
157
- if st.form_submit_button("Use Recommended Config"):
158
- try:
159
- set_moa_agent(
160
- main_model=rec_config['main_model'],
161
- cycles=rec_config['cycles'],
162
- layer_agent_config=layer_agent_config_rec,
163
- override=True
164
- )
165
- st.session_state.messages = []
166
- st.success("Configuration updated successfully!")
167
- except Exception as e:
168
- st.error(f"Error updating configuration: {str(e)}")
169
-
170
- # Main model selection
171
  new_main_model = st.selectbox(
172
  "Select Main Model",
173
  options=valid_model_names,
174
  index=valid_model_names.index(st.session_state.main_model)
175
  )
176
-
177
- # Cycles input
178
  new_cycles = st.number_input(
179
  "Number of Layers",
180
  min_value=1,
181
  max_value=10,
182
  value=st.session_state.cycles
183
  )
184
-
185
- # Main Model Temperature
186
  main_temperature = st.number_input(
187
  label="Main Model Temperature",
188
- value=0.1,
189
  min_value=0.0,
190
  max_value=1.0,
191
  step=0.1
192
  )
193
-
194
- # Layer agent configuration
195
  new_layer_agent_config = st_ace(
196
  value=json.dumps(st.session_state.layer_agent_config, indent=2),
197
  language='json',
@@ -204,19 +158,22 @@ with st.sidebar:
204
  if st.form_submit_button("Update Configuration"):
205
  try:
206
  new_layer_config = json.loads(new_layer_agent_config)
207
- set_moa_agent(
 
 
 
 
208
  main_model=new_main_model,
209
  cycles=new_cycles,
210
  layer_agent_config=new_layer_config,
211
- main_model_temperature=main_temperature,
212
- override=True
213
  )
214
- st.session_state.messages = []
215
  st.success("Configuration updated successfully!")
 
216
  except Exception as e:
217
  st.error(f"Error updating configuration: {str(e)}")
218
 
219
- # Main app layout
220
  st.header("Mixture of Agents")
221
  st.write("This project oversees implementation of Mixture of Agents architecture powered by Groq LLMs.")
222
 
@@ -225,7 +182,6 @@ with st.expander("Current MOA Configuration", expanded=False):
225
  st.markdown(f"**Main Model**: `{st.session_state.main_model}`")
226
  st.markdown(f"**Main Model Temperature**: `{st.session_state.main_temp:.1f}`")
227
  st.markdown(f"**Layers**: `{st.session_state.cycles}`")
228
- st.markdown("**Layer Agents Config:**")
229
  st_ace(
230
  value=json.dumps(st.session_state.layer_agent_config, indent=2),
231
  language='json',
@@ -236,29 +192,38 @@ with st.expander("Current MOA Configuration", expanded=False):
236
  auto_update=True
237
  )
238
 
239
- # Chat interface
240
- for message in st.session_state.messages:
241
- with st.chat_message(message["role"]):
242
- st.markdown(message["content"])
 
 
243
 
244
  if query := st.chat_input("Ask a question"):
245
- async def handle_query():
246
- st.session_state.messages.append({"role": "user", "content": query})
247
- with st.chat_message("user"):
248
- st.write(query)
249
-
250
- moa_agent: MOAgent = st.session_state.moa_agent
251
-
252
- with st.chat_message("assistant"):
253
- message_placeholder = st.empty()
254
- messages = moa_agent.chat(query, output_format='json')
255
- async for response in stream_or_async_response(messages):
256
- message_placeholder.markdown(response)
257
-
258
- st.session_state.messages.append({"role": "assistant", "content": response})
259
-
260
- asyncio.run(handle_query())
261
-
 
 
 
 
 
 
 
262
 
263
  # Add acknowledgment at the bottom
264
  st.markdown("---")
 
7
  from moa.agent.moa import ResponseChunk
8
  from streamlit_ace import st_ace
9
  import copy
10
+ import pandas as pd
11
+ import time
12
 
13
  # Default configuration
14
  default_config = {
 
26
  "system_prompt": "Respond with a thought and then your response to the question. {helper_response}",
27
  "model_name": "gemma-7b-it",
28
  "temperature": 0.7
29
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  }
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  # Streamlit app layout
33
  st.set_page_config(
34
  page_title="Karios Agents Powered by Groq",
 
38
  },
39
  layout="wide"
40
  )
41
+
42
  valid_model_names = [
43
  'llama3-70b-8192',
44
  'llama3-8b-8192',
45
+ 'gemma-7b-it'
 
 
46
  ]
47
 
48
+ # Caching function
49
+ @st.cache_data
50
+ def cached_chat_response(query, model, config):
51
+ """Cache responses to minimize redundant processing."""
52
+ moa_agent = MOAgent.from_config(**config)
53
+ return moa_agent.chat(query, output_format='json')
54
+
55
+ # Functions to reset, export, and import configurations
56
+ def reset_session():
57
+ """Reset session state to clear messages and restart."""
58
+ st.session_state.messages = []
59
+ st.experimental_rerun()
60
+
61
+ def export_config():
62
+ """Allow the user to download the current configuration as a JSON file."""
63
+ config_data = {
64
+ "main_model": st.session_state.main_model,
65
+ "cycles": st.session_state.cycles,
66
+ "layer_agent_config": st.session_state.layer_agent_config,
67
+ "main_temp": st.session_state.main_temp,
68
+ }
69
+ st.download_button(
70
+ label="Download Config",
71
+ data=json.dumps(config_data, indent=4),
72
+ file_name="moa_config.json",
73
+ mime="application/json",
74
+ )
75
+
76
+ def import_config(uploaded_file):
77
+ """Upload and apply a configuration from a JSON file."""
78
+ try:
79
+ config_data = json.load(uploaded_file)
80
+ set_moa_agent(
81
+ main_model=config_data['main_model'],
82
+ cycles=config_data['cycles'],
83
+ layer_agent_config=config_data['layer_agent_config'],
84
+ main_model_temperature=config_data['main_temp'],
85
+ override=True
86
+ )
87
+ st.success("Configuration imported successfully!")
88
+ st.experimental_rerun()
89
+ except Exception as e:
90
+ st.error(f"Error importing configuration: {str(e)}")
91
 
92
  # Initialize session state
93
  if "messages" not in st.session_state:
94
  st.session_state.messages = []
95
 
96
+ if "main_model" not in st.session_state:
97
+ st.session_state.main_model = default_config['main_model']
98
 
99
+ if "cycles" not in st.session_state:
100
+ st.session_state.cycles = default_config['cycles']
101
+
102
+ if "layer_agent_config" not in st.session_state:
103
+ st.session_state.layer_agent_config = copy.deepcopy(layer_agent_config_def)
104
+
105
+ if "main_temp" not in st.session_state:
106
+ st.session_state.main_temp = 0.1
107
+
108
+ if "moa_agent" not in st.session_state:
109
+ st.session_state.moa_agent = MOAgent.from_config(
110
+ main_model=st.session_state.main_model,
111
+ cycles=st.session_state.cycles,
112
+ layer_agent_config=st.session_state.layer_agent_config,
113
+ temperature=st.session_state.main_temp
114
+ )
115
+
116
+ # Sidebar
117
  with st.sidebar:
118
  st.title("MOA Configuration")
119
+ if st.button("Reset Session"):
120
+ reset_session()
121
+
122
+ if st.button("Export Configuration"):
123
+ export_config()
124
+
125
+ uploaded_file = st.file_uploader("Upload Config", type=["json"])
126
+ if uploaded_file:
127
+ import_config(uploaded_file)
128
+
129
+ # Configuration forms
130
+ with st.form("Agent Configuration"):
 
 
 
131
  new_main_model = st.selectbox(
132
  "Select Main Model",
133
  options=valid_model_names,
134
  index=valid_model_names.index(st.session_state.main_model)
135
  )
 
 
136
  new_cycles = st.number_input(
137
  "Number of Layers",
138
  min_value=1,
139
  max_value=10,
140
  value=st.session_state.cycles
141
  )
 
 
142
  main_temperature = st.number_input(
143
  label="Main Model Temperature",
144
+ value=st.session_state.main_temp,
145
  min_value=0.0,
146
  max_value=1.0,
147
  step=0.1
148
  )
 
 
149
  new_layer_agent_config = st_ace(
150
  value=json.dumps(st.session_state.layer_agent_config, indent=2),
151
  language='json',
 
158
  if st.form_submit_button("Update Configuration"):
159
  try:
160
  new_layer_config = json.loads(new_layer_agent_config)
161
+ st.session_state.main_model = new_main_model
162
+ st.session_state.cycles = new_cycles
163
+ st.session_state.main_temp = main_temperature
164
+ st.session_state.layer_agent_config = new_layer_config
165
+ st.session_state.moa_agent = MOAgent.from_config(
166
  main_model=new_main_model,
167
  cycles=new_cycles,
168
  layer_agent_config=new_layer_config,
169
+ temperature=main_temperature
 
170
  )
 
171
  st.success("Configuration updated successfully!")
172
+ st.experimental_rerun()
173
  except Exception as e:
174
  st.error(f"Error updating configuration: {str(e)}")
175
 
176
+ # Main layout
177
  st.header("Mixture of Agents")
178
  st.write("This project oversees implementation of Mixture of Agents architecture powered by Groq LLMs.")
179
 
 
182
  st.markdown(f"**Main Model**: `{st.session_state.main_model}`")
183
  st.markdown(f"**Main Model Temperature**: `{st.session_state.main_temp:.1f}`")
184
  st.markdown(f"**Layers**: `{st.session_state.cycles}`")
 
185
  st_ace(
186
  value=json.dumps(st.session_state.layer_agent_config, indent=2),
187
  language='json',
 
192
  auto_update=True
193
  )
194
 
195
+ # Model comparison
196
+ selected_models = st.multiselect(
197
+ "Select Models for Comparison",
198
+ valid_model_names,
199
+ default=[st.session_state.main_model]
200
+ )
201
 
202
  if query := st.chat_input("Ask a question"):
203
+ st.write(f"Query: {query}")
204
+ results = {}
205
+ for model in selected_models:
206
+ try:
207
+ st.session_state.moa_agent.set_model(model)
208
+ results[model] = cached_chat_response(query, model, st.session_state.layer_agent_config)
209
+ except Exception as e:
210
+ st.error(f"Agent {model} failed: {e}")
211
+
212
+ for model, response in results.items():
213
+ st.subheader(f"Response from {model}")
214
+ st.write(response)
215
+
216
+ # Progress bar for layers
217
+ progress_bar = st.progress(0)
218
+ total_layers = st.session_state.cycles
219
+
220
+ async def process_message(message):
221
+ try:
222
+ # Simulate layer processing
223
+ current_layer = int(message['metadata']['layer'])
224
+ progress_bar.progress(int((current_layer / total_layers) * 100))
225
+ except Exception as e:
226
+ st.error(f"Agent failed: {e}")
227
 
228
  # Add acknowledgment at the bottom
229
  st.markdown("---")