echodrift committed on
Commit
60ba1ca
·
1 Parent(s): 9173ae7

feat(agent): add agent to answer question

Browse files
Files changed (5) hide show
  1. agent.py +30 -0
  2. app.py +23 -3
  3. requirements.txt +2 -1
  4. system_prompt.txt +5 -0
  5. tools.py +215 -0
agent.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langgraph.prebuilt import create_react_agent
2
+ from langchain_openai import ChatOpenAI
3
+ from langchain_core.messages import HumanMessage
4
+ import os
5
+ import dotenv
6
+ from tools import get_tools
7
+ import json
8
+ import pandas as pd
9
+ from tqdm import tqdm
10
+
11
+ dotenv.load_dotenv()
12
+
13
def post_process_answer(answer: str) -> str:
    """Extract the short answer that follows the ``FINAL ANSWER:`` marker.

    The text after the *last* occurrence of the marker is used, so a reply
    that mentions the marker while reasoning and then repeats it at the end
    yields only the closing answer instead of everything after the first
    mention.

    Args:
        answer: Full text of the model's reply.

    Returns:
        The text after the last ``FINAL ANSWER:`` marker, with surrounding
        whitespace removed.

    Raises:
        ValueError: If the marker does not occur in ``answer``.
    """
    marker = "FINAL ANSWER:"
    # rpartition returns an empty separator when the marker is absent.
    _, found, tail = answer.rpartition(marker)
    if not found:
        raise ValueError("The answer does not contain 'FINAL ANSWER:' keyword")
    return tail.strip()
19
+
20
# Chat model configured from the environment; os.getenv yields None for
# any variable that is not set.
llm = ChatOpenAI(
    api_key=os.getenv("API_KEY"),
    model=os.getenv("MODEL"),
)

# Tools the agent is allowed to call.
tools = get_tools()

# Read the prompt with an explicit encoding so the result does not depend
# on the platform's default locale. The path is relative to the current
# working directory.
with open("system_prompt.txt", encoding="utf-8") as f:
    system_prompt = f.read()

# Module-level agent object imported by app.py.
agent = create_react_agent(model=llm, tools=tools, prompt=system_prompt)
app.py CHANGED
@@ -3,6 +3,11 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
 
 
 
 
 
6
 
7
  # (Keep Constants as is)
8
  # --- Constants ---
@@ -13,11 +18,26 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
13
  class BasicAgent:
14
  def __init__(self):
15
  print("BasicAgent initialized.")
 
 
 
 
 
 
 
 
16
  def __call__(self, question: str) -> str:
17
  print(f"Agent received question (first 50 chars): {question[:50]}...")
18
- fixed_answer = "This is a default answer."
19
- print(f"Agent returning fixed answer: {fixed_answer}")
20
- return fixed_answer
 
 
 
 
 
 
 
21
 
22
  def run_and_submit_all( profile: gr.OAuthProfile | None):
23
  """
 
3
  import requests
4
  import inspect
5
  import pandas as pd
6
+ from agent import agent
7
+ from langchain_core.messages import HumanMessage
8
+ from dotenv import load_dotenv
9
+
10
+ load_dotenv()
11
 
12
  # (Keep Constants as is)
13
  # --- Constants ---
 
18
  class BasicAgent:
19
  def __init__(self):
20
  print("BasicAgent initialized.")
21
+
22
+ def post_process_answer(answer: str):
23
+ if "FINAL ANSWER:" not in answer:
24
+ raise ValueError("The answer does not contain 'FINAL ANSWER:' keyword")
25
+ key_answer_start_idx = answer.find("FINAL ANSWER:") + len("FINAL ANSWER:")
26
+ key_answer = answer[key_answer_start_idx:].strip()
27
+ return key_answer
28
+
29
  def __call__(self, question: str) -> str:
30
  print(f"Agent received question (first 50 chars): {question[:50]}...")
31
+ try:
32
+ result = agent.invoke({"messages": [HumanMessage(content=question)]})
33
+ answer = result["messages"][-1].content
34
+ key_answer = self.post_process_answer(answer)
35
+ print("Question:", question)
36
+ print("Answer:", key_answer)
37
+ except Exception as e:
38
+ print(e)
39
+ key_answer = str(e)
40
+ return key_answer
41
 
42
  def run_and_submit_all( profile: gr.OAuthProfile | None):
43
  """
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  gradio
2
- requests
 
 
1
  gradio
2
+ requests
3
+ gradio[oauth]
system_prompt.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ You are a helpful assistant tasked with answering questions using a set of tools.
2
+ Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
3
+ FINAL ANSWER: [YOUR FINAL ANSWER].
4
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use commas to write your number, and don't use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles or abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending on whether the element to be put in the list is a number or a string.
5
+ Your answer should only start with "FINAL ANSWER: ", then follows with the answer.
tools.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import requests
2
+ from langchain.tools import tool
3
+ from duckduckgo_search import DDGS
4
+ from bs4 import BeautifulSoup
5
+ import tempfile
6
+ from typing import Optional
7
+ import os
8
+ from urllib.parse import urlparse
9
+
10
+
11
@tool("search", return_direct=False)
def search(query: str) -> str:
    """Searches the internet using DuckDuckGo

    Args:
        query (str): Search query

    Returns:
        str: Search results
    """
    # NOTE(review): the docstring is kept verbatim; @tool presumably exposes
    # it to the model as the tool description — confirm before rewording.
    with DDGS() as client:
        hits = list(client.text(query, max_results=5))

    # Despite the ``-> str`` annotation, a non-empty result is handed back
    # as the list of items produced by ``client.text``; only the empty case
    # yields a plain string.
    if hits:
        return hits
    return "No results found."
24
+
25
+
26
@tool("process_content", return_direct=False)
def process_content(url: str) -> str:
    """Process content from a webpage

    Args:
        url (str): URL to get content

    Returns:
        str: Content in the webpage
    """
    try:
        # Without a timeout a stalled server would block the caller
        # indefinitely.
        response = requests.get(url, timeout=30)
        # Surface HTTP errors instead of returning the text of an error page.
        response.raise_for_status()
    except requests.RequestException as e:
        # Report the failure as text, like the other tools in this module.
        return f"Error fetching content: {str(e)}"

    soup = BeautifulSoup(response.content, "html.parser")
    return soup.get_text()
39
+
40
+
41
@tool("save_file")
def save_file(content: str, filename: Optional[str] = None) -> str:
    """
    Save content to a temporary file and return the path.
    Useful for processing files from the GAIA API.

    Args:
        content: The content to save to the file
        filename: Optional filename, will generate a random name if not provided

    Returns:
        Path to the saved file
    """
    # Keep only the final path component so a name containing separators or
    # an absolute path cannot place the file outside the temp directory.
    safe_name = os.path.basename(filename) if filename else ""

    if safe_name and safe_name not in {".", ".."}:
        filepath = os.path.join(tempfile.gettempdir(), safe_name)
        handle = open(filepath, "w", encoding="utf-8")
    else:
        # mkstemp returns an already-open descriptor; wrapping it in a file
        # object lets the ``with`` below close it instead of leaking it.
        fd, filepath = tempfile.mkstemp()
        handle = os.fdopen(fd, "w", encoding="utf-8")

    # Write content to the file
    with handle:
        handle.write(content)

    return f"File saved to {filepath}. You can read this file to process its contents."
66
+
67
+
68
@tool("download_file_from_url")
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """
    Download a file from a URL and save it to a temporary location.

    Args:
        url: The URL to download from
        filename: Optional filename, will generate one based on URL if not provided

    Returns:
        Path to the downloaded file
    """
    try:
        # Take the name from the caller, else from the URL path. Only the
        # final path component is kept so the file stays in the temp directory.
        candidate = filename or urlparse(url).path
        safe_name = os.path.basename(candidate)
        if not safe_name or safe_name in {".", ".."}:
            # Generate a random name if we couldn't extract one
            import uuid

            safe_name = f"downloaded_{uuid.uuid4().hex[:8]}"

        filepath = os.path.join(tempfile.gettempdir(), safe_name)

        # The context manager releases the streamed connection even when
        # writing fails; the timeout prevents an indefinite hang.
        with requests.get(url, stream=True, timeout=30) as response:
            response.raise_for_status()

            # Save the file
            with open(filepath, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)

        return f"File downloaded to {filepath}. You can now process this file."
    except Exception as e:
        # Report every failure as text rather than raising.
        return f"Error downloading file: {str(e)}"
107
+
108
+
109
@tool("extract_text_from_image")
def extract_text_from_image(image_path: str) -> str:
    """
    Extract text from an image using pytesseract (if available).

    Args:
        image_path: Path to the image file

    Returns:
        Extracted text or error message
    """
    try:
        # Imported here so a missing OCR dependency becomes an error message
        # from this tool instead of an import failure for the whole module.
        import pytesseract
        from PIL import Image

        # Open the image in a ``with`` block so the file handle is closed
        # once the text has been extracted.
        with Image.open(image_path) as image:
            text = pytesseract.image_to_string(image)

        return f"Extracted text from image:\n\n{text}"
    except ImportError:
        return "Error: pytesseract is not installed. Please install it with 'pip install pytesseract' and ensure Tesseract OCR is installed on your system."
    except Exception as e:
        return f"Error extracting text from image: {str(e)}"
136
+
137
+
138
@tool("analyze_csv_file")
def analyze_csv_file(file_path: str, query: str) -> str:
    """
    Analyze a CSV file using pandas and answer a question about it.

    Args:
        file_path: Path to the CSV file
        query: Question about the data

    Returns:
        Analysis result or error message
    """
    # NOTE(review): ``query`` is accepted but never read; the same overview
    # is produced whatever the question is.
    try:
        import pandas as pd

        frame = pd.read_csv(file_path)
        row_count, column_count = frame.shape

        # Assemble the report: size line, column names, then describe().
        report_parts = [
            f"CSV file loaded with {row_count} rows and {column_count} columns.\n",
            f"Columns: {', '.join(frame.columns)}\n\n",
            "Summary statistics:\n",
            str(frame.describe()),
        ]
        return "".join(report_parts)
    except ImportError:
        return "Error: pandas is not installed. Please install it with 'pip install pandas'."
    except Exception as e:
        return f"Error analyzing CSV file: {str(e)}"
169
+
170
+
171
@tool("analyze_excel_file")
def analyze_excel_file(file_path: str, query: str) -> str:
    """
    Analyze an Excel file using pandas and answer a question about it.

    Args:
        file_path: Path to the Excel file
        query: Question about the data

    Returns:
        Analysis result or error message
    """
    # NOTE(review): ``query`` is accepted but never read; the same overview
    # is produced whatever the question is.
    try:
        import pandas as pd

        # Read the Excel file
        df = pd.read_excel(file_path)

        result = (
            f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
        )
        # Header cells may be numbers or dates, so convert each label to
        # text before joining; str.join raises TypeError on non-strings.
        result += f"Columns: {', '.join(map(str, df.columns))}\n\n"

        # Add summary statistics
        result += "Summary statistics:\n"
        result += str(df.describe())

        return result
    except ImportError:
        return "Error: pandas and openpyxl are not installed. Please install them with 'pip install pandas openpyxl'."
    except Exception as e:
        return f"Error analyzing Excel file: {str(e)}"
204
+
205
+
206
def get_tools():
    """Return the list of tools that are currently enabled for the agent.

    Only ``search`` is enabled. The other tools defined in this module are
    listed below but left switched off.
    """
    enabled = [search]
    # Disabled for now:
    #   process_content
    #   save_file
    #   download_file_from_url
    #   extract_text_from_image
    #   analyze_csv_file
    #   analyze_excel_file
    return enabled