DrishtiSharma committed on
Commit
54c8d53
·
verified ·
1 Parent(s): 20d4ebd

Create test.py

Browse files
Files changed (1) hide show
  1. test.py +268 -0
test.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ref: https://github.com/kram254/Mixture-of-Agents-running-on-Groq/tree/main
2
+ import streamlit as st
3
+ import json
4
+ import asyncio
5
+ from typing import Union, Iterable, AsyncIterable
6
+ from moa.agent import MOAgent
7
+ from moa.agent.moa import ResponseChunk
8
+ from streamlit_ace import st_ace
9
+ import copy
10
+
11
+ # Default configuration
12
+ default_config = {
13
+ "main_model": "llama3-70b-8192",
14
+ "cycles": 3,
15
+ "layer_agent_config": {}
16
+ }
17
+
18
+ layer_agent_config_def = {
19
+ "layer_agent_1": {
20
+ "system_prompt": "Think through your response step by step. {helper_response}",
21
+ "model_name": "llama3-8b-8192"
22
+ },
23
+ "layer_agent_2": {
24
+ "system_prompt": "Respond with a thought and then your response to the question. {helper_response}",
25
+ "model_name": "gemma-7b-it",
26
+ "temperature": 0.7
27
+ },
28
+ "layer_agent_3": {
29
+ "system_prompt": "You are an expert at logic and reasoning. Always take a logical approach to the answer. {helper_response}",
30
+ "model_name": "llama3-8b-8192"
31
+ },
32
+ }
33
+
34
+ # Recommended configuration
35
+ rec_config = {
36
+ "main_model": "llama3-70b-8192",
37
+ "cycles": 2,
38
+ "layer_agent_config": {}
39
+ }
40
+
41
+ layer_agent_config_rec = {
42
+ "layer_agent_1": {
43
+ "system_prompt": "Think through your response step by step. {helper_response}",
44
+ "model_name": "llama3-8b-8192",
45
+ "temperature": 0.1
46
+ },
47
+ "layer_agent_2": {
48
+ "system_prompt": "Respond with a thought and then your response to the question. {helper_response}",
49
+ "model_name": "llama3-8b-8192",
50
+ "temperature": 0.2
51
+ },
52
+ "layer_agent_3": {
53
+ "system_prompt": "You are an expert at logic and reasoning. Always take a logical approach to the answer. {helper_response}",
54
+ "model_name": "llama3-8b-8192",
55
+ "temperature": 0.4
56
+ },
57
+ "layer_agent_4": {
58
+ "system_prompt": "You are an expert planner agent. Create a plan for how to answer the human's query. {helper_response}",
59
+ "model_name": "mixtral-8x7b-32768",
60
+ "temperature": 0.5
61
+ },
62
+ }
63
+
64
+ # Unified streaming function to handle async and sync responses
65
+ async def stream_or_async_response(messages: Union[Iterable[ResponseChunk], AsyncIterable[ResponseChunk]]):
66
+ layer_outputs = {}
67
+
68
+ async def process_message(message):
69
+ if message['response_type'] == 'intermediate':
70
+ layer = message['metadata']['layer']
71
+ if layer not in layer_outputs:
72
+ layer_outputs[layer] = []
73
+ layer_outputs[layer].append(message['delta'])
74
+ else:
75
+ for layer, outputs in layer_outputs.items():
76
+ st.write(f"Layer {layer}")
77
+ cols = st.columns(len(outputs))
78
+ for i, output in enumerate(outputs):
79
+ with cols[i]:
80
+ st.expander(label=f"Agent {i+1}", expanded=False).write(output)
81
+
82
+ layer_outputs.clear()
83
+ yield message['delta']
84
+
85
+ if isinstance(messages, AsyncIterable):
86
+ # Process asynchronous messages
87
+ async for message in messages:
88
+ await process_message(message)
89
+ else:
90
+ # Process synchronous messages
91
+ for message in messages:
92
+ await process_message(message)
93
+
94
+ # Set up the MOAgent
95
+ def set_moa_agent(
96
+ main_model: str = default_config['main_model'],
97
+ cycles: int = default_config['cycles'],
98
+ layer_agent_config: dict[dict[str, any]] = copy.deepcopy(layer_agent_config_def),
99
+ main_model_temperature: float = 0.1,
100
+ override: bool = False
101
+ ):
102
+ if override or ("main_model" not in st.session_state):
103
+ st.session_state.main_model = main_model
104
+
105
+ if override or ("cycles" not in st.session_state):
106
+ st.session_state.cycles = cycles
107
+
108
+ if override or ("layer_agent_config" not in st.session_state):
109
+ st.session_state.layer_agent_config = layer_agent_config
110
+
111
+ if override or ("main_temp" not in st.session_state):
112
+ st.session_state.main_temp = main_model_temperature
113
+
114
+ cls_ly_conf = copy.deepcopy(st.session_state.layer_agent_config)
115
+
116
+ if override or ("moa_agent" not in st.session_state):
117
+ st.session_state.moa_agent = MOAgent.from_config(
118
+ main_model=st.session_state.main_model,
119
+ cycles=st.session_state.cycles,
120
+ layer_agent_config=cls_ly_conf,
121
+ temperature=st.session_state.main_temp
122
+ )
123
+
124
+ del cls_ly_conf
125
+ del layer_agent_config
126
+
127
+ # Streamlit app layout
128
+ st.set_page_config(
129
+ page_title="Karios Agents Powered by Groq",
130
+ page_icon='static/favicon.ico',
131
+ menu_items={
132
+ 'About': "## Groq Mixture-Of-Agents \n Powered by [Groq](https://groq.com)"
133
+ },
134
+ layout="wide"
135
+ )
136
+ valid_model_names = [
137
+ 'llama3-70b-8192',
138
+ 'llama3-8b-8192',
139
+ 'gemma-7b-it',
140
+ 'gemma2-9b-it',
141
+ 'mixtral-8x7b-32768'
142
+ ]
143
+
144
+ st.markdown("<a href='https://groq.com'><img src='app/static/banner.png' width='500'></a>", unsafe_allow_html=True)
145
+ st.write("---")
146
+
147
+ # Initialize session state
148
+ if "messages" not in st.session_state:
149
+ st.session_state.messages = []
150
+
151
+ set_moa_agent()
152
+
153
+ # Sidebar for configuration
154
+ with st.sidebar:
155
+ st.title("MOA Configuration")
156
+ with st.form("Agent Configuration", border=False):
157
+ if st.form_submit_button("Use Recommended Config"):
158
+ try:
159
+ set_moa_agent(
160
+ main_model=rec_config['main_model'],
161
+ cycles=rec_config['cycles'],
162
+ layer_agent_config=layer_agent_config_rec,
163
+ override=True
164
+ )
165
+ st.session_state.messages = []
166
+ st.success("Configuration updated successfully!")
167
+ except Exception as e:
168
+ st.error(f"Error updating configuration: {str(e)}")
169
+
170
+ # Main model selection
171
+ new_main_model = st.selectbox(
172
+ "Select Main Model",
173
+ options=valid_model_names,
174
+ index=valid_model_names.index(st.session_state.main_model)
175
+ )
176
+
177
+ # Cycles input
178
+ new_cycles = st.number_input(
179
+ "Number of Layers",
180
+ min_value=1,
181
+ max_value=10,
182
+ value=st.session_state.cycles
183
+ )
184
+
185
+ # Main Model Temperature
186
+ main_temperature = st.number_input(
187
+ label="Main Model Temperature",
188
+ value=0.1,
189
+ min_value=0.0,
190
+ max_value=1.0,
191
+ step=0.1
192
+ )
193
+
194
+ # Layer agent configuration
195
+ new_layer_agent_config = st_ace(
196
+ value=json.dumps(st.session_state.layer_agent_config, indent=2),
197
+ language='json',
198
+ placeholder="Layer Agent Configuration (JSON)",
199
+ show_gutter=False,
200
+ wrap=True,
201
+ auto_update=True
202
+ )
203
+
204
+ if st.form_submit_button("Update Configuration"):
205
+ try:
206
+ new_layer_config = json.loads(new_layer_agent_config)
207
+ set_moa_agent(
208
+ main_model=new_main_model,
209
+ cycles=new_cycles,
210
+ layer_agent_config=new_layer_config,
211
+ main_model_temperature=main_temperature,
212
+ override=True
213
+ )
214
+ st.session_state.messages = []
215
+ st.success("Configuration updated successfully!")
216
+ except Exception as e:
217
+ st.error(f"Error updating configuration: {str(e)}")
218
+
219
+ # Main app layout
220
+ st.header("Mixture of Agents")
221
+ st.write("This project oversees implementation of Mixture of Agents architecture powered by Groq LLMs.")
222
+
223
+ # Display current configuration
224
+ with st.expander("Current MOA Configuration", expanded=False):
225
+ st.markdown(f"**Main Model**: `{st.session_state.main_model}`")
226
+ st.markdown(f"**Main Model Temperature**: `{st.session_state.main_temp:.1f}`")
227
+ st.markdown(f"**Layers**: `{st.session_state.cycles}`")
228
+ st.markdown("**Layer Agents Config:**")
229
+ st_ace(
230
+ value=json.dumps(st.session_state.layer_agent_config, indent=2),
231
+ language='json',
232
+ placeholder="Layer Agent Configuration (JSON)",
233
+ show_gutter=False,
234
+ wrap=True,
235
+ readonly=True,
236
+ auto_update=True
237
+ )
238
+
239
+ # Chat interface
240
+ for message in st.session_state.messages:
241
+ with st.chat_message(message["role"]):
242
+ st.markdown(message["content"])
243
+
244
+ if query := st.chat_input("Ask a question"):
245
+ async def handle_query():
246
+ st.session_state.messages.append({"role": "user", "content": query})
247
+ with st.chat_message("user"):
248
+ st.write(query)
249
+
250
+ moa_agent: MOAgent = st.session_state.moa_agent
251
+
252
+ with st.chat_message("assistant"):
253
+ message_placeholder = st.empty()
254
+ messages = moa_agent.chat(query, output_format='json')
255
+ async for response in stream_or_async_response(messages):
256
+ message_placeholder.markdown(response)
257
+
258
+ st.session_state.messages.append({"role": "assistant", "content": response})
259
+
260
+ asyncio.run(handle_query())
261
+
262
+
263
+ # Add acknowledgment at the bottom
264
+ st.markdown("---")
265
+ st.markdown("""
266
+ ###
267
+ This app is based on [Emmanuel M. Ndaliro's work](https://github.com/kram254/Mixture-of-Agents-running-on-Groq/tree/main).
268
+ """)