pvanand commited on
Commit
e6ba032
·
verified ·
1 Parent(s): 0002a7f

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +49 -46
main.py CHANGED
@@ -147,53 +147,56 @@ async def chat(input_data: ChatInput):
147
  }
148
 
149
  input_message = HumanMessage(content=input_data.message)
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
- async def generate():
152
- async for event in agent.astream_events(
153
- {"messages": [input_message]},
154
- config,
155
- version="v2"
156
- ):
157
- kind = event["event"]
158
-
159
- if kind == "on_chat_model_stream":
160
- content = event["data"]["chunk"].content
161
- if content:
162
- yield f"{json.dumps({'type': 'token', 'content': content})}\n"
163
-
164
- if kind == "on_tool_start":
165
- print(f"Debug - Tool being called: {event['name']}") # Add this debug line
166
- tool_input = event['data'].get('input', '')
167
- yield f"{json.dumps({'type': 'tool_start', 'tool': event['name'], 'input': tool_input})}\n"
168
-
169
- elif kind == "on_tool_end":
170
- tool_output = event['data'].get('output', '').content
171
- #print(type(tool_output))
172
- #print(dir(tool_output))
173
- #print the keys
174
- pattern = r'data: (.*?)\ndata:'
175
- match = re.search(pattern, tool_output)
176
- print(tool_output)
177
-
178
- if match:
179
- tool_output_json = match.group(1).strip()
180
- try:
181
- tool_output = json.loads(tool_output_json)
182
- if "artifacts" in tool_output:
183
- for artifact in tool_output["artifacts"]:
184
- artifact_content = requests.get(f"{API_URL}/artifact/{artifact['artifact_id']}").content
185
- print(artifact_content)
186
- tool_output["artifacts"][artifact["artifact_id"]] = artifact_content
187
- except Exception as e:
188
- print(e)
189
- print("Error parsing tool output as json: ", tool_output)
190
- else:
191
- print("No match found in tool output")
192
- yield f"{json.dumps({'type': 'tool_end', 'tool': event['name'], 'output': tool_output})}\n"
193
- return EventSourceResponse(
194
- generate(),
195
- media_type="text/event-stream"
196
- )
197
 
198
  @app.get("/health")
199
  async def health_check():
 
147
  }
148
 
149
  input_message = HumanMessage(content=input_data.message)
150
+ try:
151
+ async def generate():
152
+ async for event in agent.astream_events(
153
+ {"messages": [input_message]},
154
+ config,
155
+ version="v2"
156
+ ):
157
+ kind = event["event"]
158
+
159
+ if kind == "on_chat_model_stream":
160
+ content = event["data"]["chunk"].content
161
+ if content:
162
+ yield f"{json.dumps({'type': 'token', 'content': content})}\n"
163
 
164
+ if kind == "on_tool_start":
165
+ print(f"Debug - Tool being called: {event['name']}") # Add this debug line
166
+ tool_input = event['data'].get('input', '')
167
+ yield f"{json.dumps({'type': 'tool_start', 'tool': event['name'], 'input': tool_input})}\n"
168
+
169
+ elif kind == "on_tool_end":
170
+ tool_output = event['data'].get('output', '').content
171
+ #print(type(tool_output))
172
+ #print(dir(tool_output))
173
+ #print the keys
174
+ pattern = r'data: (.*?)\ndata:'
175
+ match = re.search(pattern, tool_output)
176
+ print(tool_output)
177
+
178
+ if match:
179
+ tool_output_json = match.group(1).strip()
180
+ try:
181
+ tool_output = json.loads(tool_output_json)
182
+ if "artifacts" in tool_output:
183
+ for artifact in tool_output["artifacts"]:
184
+ artifact_content = requests.get(f"{API_URL}/artifact/{artifact['artifact_id']}").content
185
+ print(artifact_content)
186
+ tool_output["artifacts"][artifact["artifact_id"]] = artifact_content
187
+ except Exception as e:
188
+ print(e)
189
+ print("Error parsing tool output as json: ", tool_output)
190
+ else:
191
+ print("No match found in tool output")
192
+ yield f"{json.dumps({'type': 'tool_end', 'tool': event['name'], 'output': tool_output})}\n"
193
+ return EventSourceResponse(
194
+ generate(),
195
+ media_type="text/event-stream"
196
+ )
197
+ except Exception as e:
198
+ print(f"Error during event streaming: {str(e)}")
199
+ yield f"{json.dumps({'type': 'error', 'content': str(e)})}\n"
 
 
 
 
 
 
 
 
 
 
200
 
201
  @app.get("/health")
202
  async def health_check():