joaomorossini commited on
Commit
490ec28
·
1 Parent(s): 59c19d6

Create Gradio demo with Notion iframe

Browse files
Files changed (1) hide show
  1. agency_ai_demo/demo.py +588 -0
agency_ai_demo/demo.py ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import queue
3
+ import threading
4
+ from typing_extensions import override
5
+
6
+ from dotenv import load_dotenv
7
+ from openai import AzureOpenAI
8
+ from agency_swarm import Agent, Agency, set_openai_client
9
+ from agency_swarm.util.streaming import AgencyEventHandler
10
+ from agency_swarm.messages import MessageOutput
11
+ from openai.types.beta.threads import Message
12
+ from openai.types.beta.threads.runs import (
13
+ RunStep,
14
+ ToolCall,
15
+ FunctionToolCall,
16
+ CodeInterpreterToolCall,
17
+ FileSearchToolCall,
18
+ )
19
+ from agency_swarm.tools import FileSearch, CodeInterpreter
20
+
21
+ # Import our agents - using the same imports as in demo.ipynb
22
+ from agents.NotionProjectAgent import NotionProjectAgent
23
+ from agents.TechnicalProjectManager import TechnicalProjectManager
24
+ from agents.ResearchAndReportAgent import ResearchAndReportAgent
25
+
26
+
27
+ # Helper functions for file handling (from agency.py)
28
+ def get_file_purpose(file_name):
29
+ """Determine the purpose of the file based on its extension."""
30
+ if file_name.lower().endswith((".jpg", ".jpeg", ".png", ".gif", ".webp")):
31
+ return "vision"
32
+ return "assistants"
33
+
34
+
35
+ def get_tools(file_name):
36
+ """Determine the appropriate tools for the file based on its extension."""
37
+ tools = []
38
+
39
+ if file_name.lower().endswith(
40
+ (
41
+ ".py",
42
+ ".js",
43
+ ".html",
44
+ ".css",
45
+ ".ipynb",
46
+ ".r",
47
+ ".c",
48
+ ".cpp",
49
+ ".java",
50
+ ".json",
51
+ ".yaml",
52
+ ".yml",
53
+ ".csv",
54
+ ".tsv",
55
+ ".txt",
56
+ )
57
+ ):
58
+ tools.append({"type": "code_interpreter"})
59
+
60
+ if file_name.lower().endswith(
61
+ (
62
+ ".pdf",
63
+ ".docx",
64
+ ".doc",
65
+ ".pptx",
66
+ ".ppt",
67
+ ".xlsx",
68
+ ".xls",
69
+ ".csv",
70
+ ".tsv",
71
+ ".txt",
72
+ )
73
+ ):
74
+ tools.append({"type": "file_search"})
75
+
76
+ return tools
77
+
78
+
79
+ # Load environment variables
80
+ load_dotenv()
81
+
82
+
83
+ class NotionAgency(Agency):
84
+ """
85
+ Extension of the Agency class that includes a Notion database iframe
86
+ in the Gradio interface.
87
+ """
88
+
89
+ def demo_gradio(self, height=450, dark_mode=True, **kwargs):
90
+ """
91
+ Custom implementation of demo_gradio that includes a Notion iframe.
92
+ Inherits most functionality from the parent class but adds an iframe
93
+ at the top of the interface.
94
+ """
95
+ try:
96
+ import gradio as gr
97
+ except ImportError:
98
+ raise Exception("Please install gradio: pip install gradio")
99
+
100
+ # Use the specific Notion embed URL provided by the user
101
+ notion_embed_url = "https://morossini.notion.site/ebd/1a88235ee2ff801e8f93d8ab2e14de1d?v=1a88235ee2ff8063a16b000c5757dbd3"
102
+
103
+ # Create iframe with the exact attributes provided by the user
104
+ iframe_html = f"""
105
+ <iframe src="{notion_embed_url}" width="100%" height="600" frameborder="0" allowfullscreen></iframe>
106
+ <div style="text-align: center; margin-top: 5px; font-size: 12px; color: #888;">
107
+ If the Notion board doesn't appear, please ensure your Notion page is shared publicly with "Share to web" enabled.
108
+ </div>
109
+ """
110
+
111
+ js = """function () {
112
+ gradioURL = window.location.href
113
+ if (!gradioURL.endsWith('?__theme={theme}')) {
114
+ window.location.replace(gradioURL + '?__theme={theme}');
115
+ }
116
+ }"""
117
+
118
+ # Set dark mode by default
119
+ if dark_mode:
120
+ js = js.replace("{theme}", "dark")
121
+ else:
122
+ js = js.replace("{theme}", "light")
123
+
124
+ attachments = []
125
+ images = []
126
+ message_file_names = None
127
+ uploading_files = False
128
+ recipient_agent_names = [agent.name for agent in self.main_recipients]
129
+ recipient_agent = self.main_recipients[0]
130
+
131
+ with gr.Blocks(js=js) as demo:
132
+ chatbot_queue = queue.Queue()
133
+
134
+ # Add the iframe at the top, taking up space above the chatbot
135
+ with gr.Row():
136
+ iframe = gr.HTML(iframe_html)
137
+
138
+ # Original components from Agency.demo_gradio
139
+ chatbot = gr.Chatbot(height=height)
140
+ with gr.Row():
141
+ with gr.Column(scale=9):
142
+ dropdown = gr.Dropdown(
143
+ label="Recipient Agent",
144
+ choices=recipient_agent_names,
145
+ value=recipient_agent.name,
146
+ )
147
+ msg = gr.Textbox(label="Your Message", lines=4)
148
+ with gr.Column(scale=1):
149
+ file_upload = gr.Files(label="OpenAI Files", type="filepath")
150
+ button = gr.Button(value="Send", variant="primary")
151
+
152
+ def handle_dropdown_change(selected_option):
153
+ nonlocal recipient_agent
154
+ recipient_agent = self._get_agent_by_name(selected_option)
155
+
156
+ def handle_file_upload(file_list):
157
+ nonlocal attachments
158
+ nonlocal message_file_names
159
+ nonlocal uploading_files
160
+ nonlocal images
161
+ uploading_files = True
162
+ attachments = []
163
+ message_file_names = []
164
+ if file_list:
165
+ try:
166
+ for file_obj in file_list:
167
+ purpose = get_file_purpose(file_obj.name)
168
+
169
+ with open(file_obj.name, "rb") as f:
170
+ # Upload the file to OpenAI
171
+ file = self.main_thread.client.files.create(
172
+ file=f, purpose=purpose
173
+ )
174
+
175
+ if purpose == "vision":
176
+ images.append(
177
+ {
178
+ "type": "image_file",
179
+ "image_file": {"file_id": file.id},
180
+ }
181
+ )
182
+ else:
183
+ attachments.append(
184
+ {
185
+ "file_id": file.id,
186
+ "tools": get_tools(file.filename),
187
+ }
188
+ )
189
+
190
+ message_file_names.append(file.filename)
191
+ print(f"Uploaded file ID: {file.id}")
192
+ return attachments
193
+ except Exception as e:
194
+ print(f"Error: {e}")
195
+ return str(e)
196
+ finally:
197
+ uploading_files = False
198
+
199
+ uploading_files = False
200
+ return "No files uploaded"
201
+
202
+ def user(user_message, history):
203
+ if not user_message.strip():
204
+ return user_message, history
205
+
206
+ nonlocal message_file_names
207
+ nonlocal uploading_files
208
+ nonlocal images
209
+ nonlocal attachments
210
+ nonlocal recipient_agent
211
+
212
+ # Check if attachments contain file search or code interpreter types
213
+ def check_and_add_tools_in_attachments(attachments, recipient_agent):
214
+ for attachment in attachments:
215
+ for tool in attachment.get("tools", []):
216
+ if tool["type"] == "file_search":
217
+ if not any(
218
+ isinstance(t, FileSearch)
219
+ for t in recipient_agent.tools
220
+ ):
221
+ # Add FileSearch tool if it does not exist
222
+ recipient_agent.tools.append(FileSearch)
223
+ recipient_agent.client.beta.assistants.update(
224
+ recipient_agent.id,
225
+ tools=recipient_agent.get_oai_tools(),
226
+ )
227
+ print(
228
+ "Added FileSearch tool to recipient agent to analyze the file."
229
+ )
230
+ elif tool["type"] == "code_interpreter":
231
+ if not any(
232
+ isinstance(t, CodeInterpreter)
233
+ for t in recipient_agent.tools
234
+ ):
235
+ # Add CodeInterpreter tool if it does not exist
236
+ recipient_agent.tools.append(CodeInterpreter)
237
+ recipient_agent.client.beta.assistants.update(
238
+ recipient_agent.id,
239
+ tools=recipient_agent.get_oai_tools(),
240
+ )
241
+ print(
242
+ "Added CodeInterpreter tool to recipient agent to analyze the file."
243
+ )
244
+ return None
245
+
246
+ check_and_add_tools_in_attachments(attachments, recipient_agent)
247
+
248
+ if history is None:
249
+ history = []
250
+
251
+ original_user_message = user_message
252
+
253
+ # Append the user message with a placeholder for bot response
254
+ if recipient_agent:
255
+ user_message = (
256
+ f"👤 User 🗣️ @{recipient_agent.name}:\n" + user_message.strip()
257
+ )
258
+ else:
259
+ user_message = f"👤 User:" + user_message.strip()
260
+
261
+ nonlocal message_file_names
262
+ if message_file_names:
263
+ user_message += "\n\n📎 Files:\n" + "\n".join(message_file_names)
264
+
265
+ return original_user_message, history + [[user_message, None]]
266
+
267
+ class GradioEventHandler(AgencyEventHandler):
268
+ message_output = None
269
+
270
+ @classmethod
271
+ def change_recipient_agent(cls, recipient_agent_name):
272
+ nonlocal chatbot_queue
273
+ chatbot_queue.put("[change_recipient_agent]")
274
+ chatbot_queue.put(recipient_agent_name)
275
+
276
+ @override
277
+ def on_message_created(self, message: Message) -> None:
278
+ if message.role == "user":
279
+ full_content = ""
280
+ for content in message.content:
281
+ if content.type == "image_file":
282
+ full_content += (
283
+ f"🖼️ Image File: {content.image_file.file_id}\n"
284
+ )
285
+ continue
286
+
287
+ if content.type == "image_url":
288
+ full_content += f"\n{content.image_url.url}\n"
289
+ continue
290
+
291
+ if content.type == "text":
292
+ full_content += content.text.value + "\n"
293
+
294
+ self.message_output = MessageOutput(
295
+ "text",
296
+ self.agent_name,
297
+ self.recipient_agent_name,
298
+ full_content,
299
+ )
300
+
301
+ else:
302
+ self.message_output = MessageOutput(
303
+ "text", self.recipient_agent_name, self.agent_name, ""
304
+ )
305
+
306
+ chatbot_queue.put("[new_message]")
307
+ chatbot_queue.put(self.message_output.get_formatted_content())
308
+
309
+ @override
310
+ def on_text_delta(self, delta, snapshot):
311
+ chatbot_queue.put(delta.value)
312
+
313
+ @override
314
+ def on_tool_call_created(self, tool_call: ToolCall):
315
+ if isinstance(tool_call, dict):
316
+ if "type" not in tool_call:
317
+ tool_call["type"] = "function"
318
+
319
+ if tool_call["type"] == "function":
320
+ tool_call = FunctionToolCall(**tool_call)
321
+ elif tool_call["type"] == "code_interpreter":
322
+ tool_call = CodeInterpreterToolCall(**tool_call)
323
+ elif (
324
+ tool_call["type"] == "file_search"
325
+ or tool_call["type"] == "retrieval"
326
+ ):
327
+ tool_call = FileSearchToolCall(**tool_call)
328
+ else:
329
+ raise ValueError(
330
+ "Invalid tool call type: " + tool_call["type"]
331
+ )
332
+
333
+ # TODO: add support for code interpreter and retrieval tools
334
+ if tool_call.type == "function":
335
+ chatbot_queue.put("[new_message]")
336
+ self.message_output = MessageOutput(
337
+ "function",
338
+ self.recipient_agent_name,
339
+ self.agent_name,
340
+ str(tool_call.function),
341
+ )
342
+ chatbot_queue.put(
343
+ self.message_output.get_formatted_header() + "\n"
344
+ )
345
+
346
+ @override
347
+ def on_tool_call_done(self, snapshot: ToolCall):
348
+ if isinstance(snapshot, dict):
349
+ if "type" not in snapshot:
350
+ snapshot["type"] = "function"
351
+
352
+ if snapshot["type"] == "function":
353
+ snapshot = FunctionToolCall(**snapshot)
354
+ elif snapshot["type"] == "code_interpreter":
355
+ snapshot = CodeInterpreterToolCall(**snapshot)
356
+ elif snapshot["type"] == "file_search":
357
+ snapshot = FileSearchToolCall(**snapshot)
358
+ else:
359
+ raise ValueError(
360
+ "Invalid tool call type: " + snapshot["type"]
361
+ )
362
+
363
+ self.message_output = None
364
+
365
+ # TODO: add support for code interpreter and retrieval tools
366
+ if snapshot.type != "function":
367
+ return
368
+
369
+ chatbot_queue.put(str(snapshot.function))
370
+
371
+ if snapshot.function.name == "SendMessage":
372
+ try:
373
+ args = eval(snapshot.function.arguments)
374
+ recipient = args["recipient"]
375
+ self.message_output = MessageOutput(
376
+ "text",
377
+ self.recipient_agent_name,
378
+ recipient,
379
+ args["message"],
380
+ )
381
+
382
+ chatbot_queue.put("[new_message]")
383
+ chatbot_queue.put(
384
+ self.message_output.get_formatted_content()
385
+ )
386
+ except Exception as e:
387
+ pass
388
+
389
+ self.message_output = None
390
+
391
+ @override
392
+ def on_run_step_done(self, run_step: RunStep) -> None:
393
+ if run_step.type == "tool_calls":
394
+ for tool_call in run_step.step_details.tool_calls:
395
+ if tool_call.type != "function":
396
+ continue
397
+
398
+ if tool_call.function.name == "SendMessage":
399
+ continue
400
+
401
+ self.message_output = None
402
+ chatbot_queue.put("[new_message]")
403
+
404
+ self.message_output = MessageOutput(
405
+ "function_output",
406
+ tool_call.function.name,
407
+ self.recipient_agent_name,
408
+ tool_call.function.output,
409
+ )
410
+
411
+ chatbot_queue.put(
412
+ self.message_output.get_formatted_header() + "\n"
413
+ )
414
+ chatbot_queue.put(tool_call.function.output)
415
+
416
+ @override
417
+ @classmethod
418
+ def on_all_streams_end(cls):
419
+ cls.message_output = None
420
+ chatbot_queue.put("[end]")
421
+
422
+ def bot(original_message, history, dropdown):
423
+ nonlocal attachments
424
+ nonlocal message_file_names
425
+ nonlocal recipient_agent
426
+ nonlocal recipient_agent_names
427
+ nonlocal images
428
+ nonlocal uploading_files
429
+
430
+ if not original_message:
431
+ return (
432
+ "",
433
+ history,
434
+ gr.update(
435
+ value=recipient_agent.name,
436
+ choices=set([*recipient_agent_names, recipient_agent.name]),
437
+ ),
438
+ )
439
+
440
+ if uploading_files:
441
+ history.append([None, "Uploading files... Please wait."])
442
+ yield (
443
+ "",
444
+ history,
445
+ gr.update(
446
+ value=recipient_agent.name,
447
+ choices=set([*recipient_agent_names, recipient_agent.name]),
448
+ ),
449
+ )
450
+ return (
451
+ "",
452
+ history,
453
+ gr.update(
454
+ value=recipient_agent.name,
455
+ choices=set([*recipient_agent_names, recipient_agent.name]),
456
+ ),
457
+ )
458
+
459
+ print("Message files: ", message_file_names)
460
+ print("Images: ", images)
461
+
462
+ if images and len(images) > 0:
463
+ original_message = [
464
+ {
465
+ "type": "text",
466
+ "text": original_message,
467
+ },
468
+ *images,
469
+ ]
470
+
471
+ completion_thread = threading.Thread(
472
+ target=self.get_completion_stream,
473
+ args=(
474
+ original_message,
475
+ GradioEventHandler,
476
+ [],
477
+ recipient_agent,
478
+ "",
479
+ attachments,
480
+ None,
481
+ ),
482
+ )
483
+ completion_thread.start()
484
+
485
+ attachments = []
486
+ message_file_names = []
487
+ images = []
488
+ uploading_files = False
489
+
490
+ new_message = True
491
+ while True:
492
+ try:
493
+ bot_message = chatbot_queue.get(block=True)
494
+
495
+ if bot_message == "[end]":
496
+ completion_thread.join()
497
+ break
498
+
499
+ if bot_message == "[new_message]":
500
+ new_message = True
501
+ continue
502
+
503
+ if bot_message == "[change_recipient_agent]":
504
+ new_agent_name = chatbot_queue.get(block=True)
505
+ recipient_agent = self._get_agent_by_name(new_agent_name)
506
+ yield (
507
+ "",
508
+ history,
509
+ gr.update(
510
+ value=new_agent_name,
511
+ choices=set(
512
+ [*recipient_agent_names, recipient_agent.name]
513
+ ),
514
+ ),
515
+ )
516
+ continue
517
+
518
+ if new_message:
519
+ history.append([None, bot_message])
520
+ new_message = False
521
+ else:
522
+ history[-1][1] += bot_message
523
+
524
+ yield (
525
+ "",
526
+ history,
527
+ gr.update(
528
+ value=recipient_agent.name,
529
+ choices=set(
530
+ [*recipient_agent_names, recipient_agent.name]
531
+ ),
532
+ ),
533
+ )
534
+ except queue.Empty:
535
+ break
536
+
537
+ button.click(user, inputs=[msg, chatbot], outputs=[msg, chatbot]).then(
538
+ bot, [msg, chatbot, dropdown], [msg, chatbot, dropdown]
539
+ )
540
+ dropdown.change(handle_dropdown_change, dropdown)
541
+ file_upload.change(handle_file_upload, file_upload)
542
+ msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
543
+ bot, [msg, chatbot, dropdown], [msg, chatbot, dropdown]
544
+ )
545
+
546
+ # Enable queuing for streaming intermediate outputs
547
+ demo.queue(default_concurrency_limit=10)
548
+
549
+ # Launch the demo
550
+ demo.launch(**kwargs)
551
+ return demo
552
+
553
+
554
+ def main():
555
+ print("Setting up the demo...")
556
+
557
+ # Configure OpenAI client for agency-swarm
558
+ client = AzureOpenAI(
559
+ api_key=os.getenv("AZURE_OPENAI_API_KEY"),
560
+ api_version=os.getenv("OPENAI_API_VERSION"),
561
+ azure_endpoint=os.getenv("AZURE_OPENAI_ENDPOINT"),
562
+ timeout=5,
563
+ max_retries=5,
564
+ )
565
+
566
+ set_openai_client(client)
567
+
568
+ # Create our agents
569
+ technical_project_manager = TechnicalProjectManager()
570
+ research_and_report_agent = ResearchAndReportAgent()
571
+ notion_project_agent = NotionProjectAgent()
572
+
573
+ # Create the agency with our agents - using NotionAgency instead of Agency
574
+ agency = NotionAgency(
575
+ agency_chart=[
576
+ technical_project_manager,
577
+ [technical_project_manager, notion_project_agent],
578
+ [technical_project_manager, research_and_report_agent],
579
+ ],
580
+ shared_instructions="agency_manifesto.md",
581
+ )
582
+
583
+ # Launch the demo with Gradio's built-in deployment
584
+ return agency.demo_gradio(height=450)
585
+
586
+
587
+ if __name__ == "__main__":
588
+ main()