qdqd committed
Commit 5ee835a · verified · 1 Parent(s): 9ba394a

Create app.py

Files changed (1)
  1. app.py +325 -0
app.py ADDED
@@ -0,0 +1,325 @@
import argparse
import json
import mimetypes
import os
import re
import shutil
import threading
from concurrent.futures import ThreadPoolExecutor, as_completed
from datetime import datetime
from pathlib import Path
from typing import List, Optional

import datasets
import pandas as pd
from dotenv import load_dotenv
from huggingface_hub import login
import gradio as gr
from duckduckgo_search import DDGS

from scripts.reformulator import prepare_response
from scripts.run_agents import (
    get_single_file_description,
    get_zip_description,
)
from scripts.text_inspector_tool import TextInspectorTool
from scripts.text_web_browser import (
    ArchiveSearchTool,
    FinderTool,
    FindNextTool,
    PageDownTool,
    PageUpTool,
    VisitTool,
    SimpleTextBrowser,
)
from scripts.visual_qa import visualizer
from tqdm import tqdm

from smolagents import (
    CodeAgent,
    HfApiModel,
    LiteLLMModel,
    Model,
    Tool,
    ToolCallingAgent,
)
from smolagents.agent_types import AgentText, AgentImage, AgentAudio
from smolagents.gradio_ui import pull_messages_from_step, handle_agent_output_types

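# Python modules the CodeAgent is allowed to import from within the code it generates.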
AUTHORIZED_IMPORTS = [
    "requests",
    "zipfile",
    "os",
    "pandas",
    "numpy",
    "sympy",
    "json",
    "bs4",
    "pubchempy",
    "xml",
    "yahoo_finance",
    "Bio",
    "sklearn",
    "scipy",
    "pydub",
    "io",
    "PIL",
    "chess",
    "PyPDF2",
    "pptx",
    "torch",
    "datetime",
    "fractions",
    "csv",
]

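# Read credentials from .env and authenticate against the Hugging Face Hub.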
load_dotenv(override=True)
login(os.getenv("HF_TOKEN"))

append_answer_lock = threading.Lock()

SET = "validation"

custom_role_conversions = {"tool-call": "assistant", "tool-response": "user"}

### LOAD EVALUATION DATASET

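# GAIA benchmark, validation split; columns are renamed to the names the rest of the pipeline expects.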
eval_ds = datasets.load_dataset("gaia-benchmark/GAIA", "2023_all")[SET]
eval_ds = eval_ds.rename_columns({"Question": "question", "Final answer": "true_answer", "Level": "task"})

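# Attachments referenced by GAIA questions are expected as local copies under data/gaia/<split>/.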
def preprocess_file_paths(row):
    if len(row["file_name"]) > 0:
        row["file_name"] = f"data/gaia/{SET}/" + row["file_name"]
    return row

eval_ds = eval_ds.map(preprocess_file_paths)
eval_df = pd.DataFrame(eval_ds)
print("Loaded evaluation dataset:")
print(eval_df["task"].value_counts())

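# Text-browser settings: a desktop User-Agent and a 300 s request timeout; downloads go to ./downloads_folder.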
user_agent = "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/119.0.0.0 Safari/537.36 Edg/119.0.0.0"

BROWSER_CONFIG = {
    "viewport_size": 1024 * 5,
    "downloads_folder": "downloads_folder",
    "request_kwargs": {
        "headers": {"User-Agent": user_agent},
        "timeout": 300,
    },
}

os.makedirs(f"./{BROWSER_CONFIG['downloads_folder']}", exist_ok=True)

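# The model is reached through LiteLLM; the API key and base URL are read from the environment (.env).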
# Custom OpenAI configuration
model = LiteLLMModel(
    "openai/custom-gpt",
    custom_role_conversions=custom_role_conversions,
    api_key=os.getenv("OPENAI_API_KEY"),
    api_base=os.getenv("CUSTOM_OPENAI_API_BASE"),
    temperature=0.1,
    frequency_penalty=0.2,
)

text_limit = 20000
ti_tool = TextInspectorTool(model, text_limit)
browser = SimpleTextBrowser(**BROWSER_CONFIG)

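# Web search backed by duckduckgo_search (DDGS); results are returned as numbered title/URL/snippet blocks.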
class DuckDuckGoSearchTool(Tool):
    """Search tool using DuckDuckGo"""

    name = "web_search"
    description = "Search the web using DuckDuckGo (current information)"
    inputs = {"query": {"type": "string", "description": "The search query to run."}}
    output_type = "string"

    def __init__(self, max_results: int = 5):
        super().__init__()
        self.max_results = max_results

    def forward(self, query: str) -> str:
        """Return formatted search results (snippets from webpages)"""
        try:
            web_results = []
            with DDGS() as ddgs:
                for result in ddgs.text(query, max_results=self.max_results):
                    web_results.append({
                        'title': result['title'],
                        'url': result['href'],
                        'content': result['body'],
                    })

            formatted_results = []
            for idx, res in enumerate(web_results[:self.max_results], 1):
                formatted_results.append(
                    f"[{idx}] {res['title']}\n"
                    f"URL: {res['url']}\n"
                    f"Content: {res['content'][:500]}{'...' if len(res['content']) > 500 else ''}"
                )
            return "\n\n".join(formatted_results)
        except Exception as e:
            return f"Search error: {str(e)}"

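# Tools available to the agent: web search plus page navigation, find-in-page, archive lookup,
# and text inspection, all sharing a single SimpleTextBrowser instance.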
WEB_TOOLS = [
    DuckDuckGoSearchTool(max_results=5),
    VisitTool(browser),
    PageUpTool(browser),
    PageDownTool(browser),
    FinderTool(browser),
    FindNextTool(browser),
    ArchiveSearchTool(browser),
    TextInspectorTool(model, text_limit),
]

# Agent creation in a factory function
def create_agent():
    """Creates a fresh agent instance for each session"""
    return CodeAgent(
        model=model,
        tools=[visualizer] + WEB_TOOLS,
        max_steps=5,
        verbosity_level=2,
        additional_authorized_imports=AUTHORIZED_IMPORTS,
        planning_interval=4,
    )

document_inspection_tool = TextInspectorTool(model, 20000)

def stream_to_gradio(
    agent,
    task: str,
    reset_agent_memory: bool = False,
    additional_args: Optional[dict] = None,
):
    """Runs an agent with the given task and streams the messages from the agent as gradio ChatMessages."""
    for step_log in agent.run(task, stream=True, reset=reset_agent_memory, additional_args=additional_args):
        for message in pull_messages_from_step(
            step_log,
        ):
            yield message

    final_answer = step_log  # Last log is the run's final_answer
    final_answer = handle_agent_output_types(final_answer)

    if isinstance(final_answer, AgentText):
        yield gr.ChatMessage(
            role="assistant",
            content=f"**Final answer:**\n{final_answer.to_string()}\n",
        )
    elif isinstance(final_answer, AgentImage):
        yield gr.ChatMessage(
            role="assistant",
            content={"path": final_answer.to_string(), "mime_type": "image/png"},
        )
    elif isinstance(final_answer, AgentAudio):
        yield gr.ChatMessage(
            role="assistant",
            content={"path": final_answer.to_string(), "mime_type": "audio/wav"},
        )
    else:
        yield gr.ChatMessage(role="assistant", content=f"**Final answer:** {str(final_answer)}")

class GradioUI:
    """A one-line interface to launch your agent in Gradio"""

    def __init__(self, file_upload_folder: str | None = None):
        self.file_upload_folder = file_upload_folder
        if self.file_upload_folder is not None:
            if not os.path.exists(file_upload_folder):
                os.mkdir(file_upload_folder)

    def interact_with_agent(self, prompt, messages, session_state):
        if 'agent' not in session_state:
            session_state['agent'] = create_agent()

        messages.append(gr.ChatMessage(role="user", content=prompt))
        yield messages

        for msg in stream_to_gradio(session_state['agent'], task=prompt, reset_agent_memory=False):
            messages.append(msg)
            yield messages
        yield messages

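    # Check the MIME type against the allow list, sanitize the filename, and copy the upload
    # into file_upload_folder so the agent can read it.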
    def upload_file(
        self,
        file,
        file_uploads_log,
        allowed_file_types=[
            "application/pdf",
            "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            "text/plain",
        ],
    ):
        if file is None:
            return gr.Textbox("No file uploaded", visible=True), file_uploads_log

        try:
            mime_type, _ = mimetypes.guess_type(file.name)
        except Exception as e:
            return gr.Textbox(f"Error: {e}", visible=True), file_uploads_log

        if mime_type not in allowed_file_types:
            return gr.Textbox("File type disallowed", visible=True), file_uploads_log

        original_name = os.path.basename(file.name)
        sanitized_name = re.sub(r"[^\w\-.]", "_", original_name)

        type_to_ext = {}
        for ext, t in mimetypes.types_map.items():
            if t not in type_to_ext:
                type_to_ext[t] = ext

        sanitized_name = sanitized_name.split(".")[:-1]
        sanitized_name.append("" + type_to_ext[mime_type])
        sanitized_name = "".join(sanitized_name)

        file_path = os.path.join(self.file_upload_folder, os.path.basename(sanitized_name))
        shutil.copy(file.name, file_path)

        return gr.Textbox(f"File uploaded: {file_path}", visible=True), file_uploads_log + [file_path]

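    # Append the list of uploaded file paths to the user's question and clear the textbox.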
    def log_user_message(self, text_input, file_uploads_log):
        return (
            text_input
            + (
                f"\nYou have been provided with these files, which might be helpful or not: {file_uploads_log}"
                if len(file_uploads_log) > 0
                else ""
            ),
            "",
        )

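    # Build the Blocks UI: chatbot, optional file upload, and a textbox wired through
    # log_user_message into interact_with_agent.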
    def launch(self, **kwargs):
        with gr.Blocks(theme="ocean", fill_height=True) as demo:
            gr.Markdown("""# Open Deep Research - AI Agent Interface
Advanced question answering using DuckDuckGo search and custom AI models""")

            session_state = gr.State({})
            stored_messages = gr.State([])
            file_uploads_log = gr.State([])
            chatbot = gr.Chatbot(
                label="Research Agent",
                type="messages",
                avatar_images=(
                    None,
                    "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/smolagents/mascot_smol.png",
                ),
                resizeable=True,
                scale=1,
            )

            if self.file_upload_folder is not None:
                upload_file = gr.File(label="Upload a file")
                upload_status = gr.Textbox(label="Upload Status", interactive=False, visible=False)
                upload_file.change(
                    self.upload_file,
                    [upload_file, file_uploads_log],
                    [upload_status, file_uploads_log],
                )

            text_input = gr.Textbox(lines=1, label="Enter your question")
            text_input.submit(
                self.log_user_message,
                [text_input, file_uploads_log],
                [stored_messages, text_input],
            ).then(
                self.interact_with_agent,
                [stored_messages, chatbot, session_state],
                [chatbot],
            )

        demo.launch(debug=True, share=True, **kwargs)

if __name__ == "__main__":
    GradioUI().launch()