Spaces:
Sleeping
Sleeping
feat(agent): add agent to answer question
Browse files- agent.py +30 -0
- app.py +23 -3
- requirements.txt +2 -1
- system_prompt.txt +5 -0
- tools.py +215 -0
agent.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from langgraph.prebuilt import create_react_agent
|
2 |
+
from langchain_openai import ChatOpenAI
|
3 |
+
from langchain_core.messages import HumanMessage
|
4 |
+
import os
|
5 |
+
import dotenv
|
6 |
+
from tools import get_tools
|
7 |
+
import json
|
8 |
+
import pandas as pd
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
dotenv.load_dotenv()
|
12 |
+
|
13 |
+
def post_process_answer(answer: str):
    """Extract the text that follows the 'FINAL ANSWER:' marker.

    Args:
        answer: Raw model output expected to contain 'FINAL ANSWER:'.

    Returns:
        The stripped text after the first occurrence of the marker.

    Raises:
        ValueError: If the marker is missing from *answer*.
    """
    marker = "FINAL ANSWER:"
    _, found, tail = answer.partition(marker)
    if not found:
        raise ValueError("The answer does not contain 'FINAL ANSWER:' keyword")
    return tail.strip()
|
19 |
+
|
20 |
+
# Chat model used by the agent. Credentials and model name come from the
# environment (populated by dotenv.load_dotenv() above).
# NOTE(review): os.getenv returns None when API_KEY / MODEL are unset --
# confirm both are defined in the Space's secrets / .env file.
llm = ChatOpenAI(
    api_key=os.getenv("API_KEY"),
    model=os.getenv("MODEL")
)

# Tool belt exposed to the ReAct agent (see tools.get_tools).
tools = get_tools()

# System prompt instructing the model to end its reply with
# "FINAL ANSWER: ..." (parsed later by post_process_answer).
with open("system_prompt.txt") as f:
    system_prompt = f.read()

# Module-level ReAct agent: importing this module eagerly builds the agent,
# so any missing env var or missing prompt file fails at import time.
agent = create_react_agent(model=llm, tools=tools, prompt=system_prompt)
|
app.py
CHANGED
@@ -3,6 +3,11 @@ import gradio as gr
|
|
3 |
import requests
|
4 |
import inspect
|
5 |
import pandas as pd
|
|
|
|
|
|
|
|
|
|
|
6 |
|
7 |
# (Keep Constants as is)
|
8 |
# --- Constants ---
|
@@ -13,11 +18,26 @@ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
|
|
13 |
class BasicAgent:
|
14 |
def __init__(self):
|
15 |
print("BasicAgent initialized.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
def __call__(self, question: str) -> str:
|
17 |
print(f"Agent received question (first 50 chars): {question[:50]}...")
|
18 |
-
|
19 |
-
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
def run_and_submit_all( profile: gr.OAuthProfile | None):
|
23 |
"""
|
|
|
3 |
import requests
|
4 |
import inspect
|
5 |
import pandas as pd
|
6 |
+
from agent import agent
|
7 |
+
from langchain_core.messages import HumanMessage
|
8 |
+
from dotenv import load_dotenv
|
9 |
+
|
10 |
+
load_dotenv()
|
11 |
|
12 |
# (Keep Constants as is)
|
13 |
# --- Constants ---
|
|
|
18 |
class BasicAgent:
    """Thin callable wrapper around the module-level LangGraph ``agent``.

    Given a question string, runs the agent and returns the text after the
    'FINAL ANSWER:' marker, or the error message if anything goes wrong.
    """

    def __init__(self):
        print("BasicAgent initialized.")

    def post_process_answer(self, answer: str) -> str:
        """Extract the text after 'FINAL ANSWER:' from the raw model output.

        Args:
            answer: Raw model output expected to contain 'FINAL ANSWER:'.

        Returns:
            The stripped text after the marker.

        Raises:
            ValueError: If *answer* does not contain the marker.
        """
        # BUG FIX: this method was defined without `self` but invoked as
        # `self.post_process_answer(answer)`, which raised a TypeError on
        # every call and made __call__ return that error text as the answer.
        if "FINAL ANSWER:" not in answer:
            raise ValueError("The answer does not contain 'FINAL ANSWER:' keyword")
        key_answer_start_idx = answer.find("FINAL ANSWER:") + len("FINAL ANSWER:")
        return answer[key_answer_start_idx:].strip()

    def __call__(self, question: str) -> str:
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        try:
            result = agent.invoke({"messages": [HumanMessage(content=question)]})
            answer = result["messages"][-1].content
            key_answer = self.post_process_answer(answer)
            print("Question:", question)
            print("Answer:", key_answer)
        except Exception as e:
            # Best-effort: surface the error text as the "answer" rather than
            # crashing the Gradio submission loop.
            print(e)
            key_answer = str(e)
        return key_answer
|
41 |
|
42 |
def run_and_submit_all( profile: gr.OAuthProfile | None):
|
43 |
"""
|
requirements.txt
CHANGED
@@ -1,2 +1,3 @@
|
|
1 |
gradio
|
2 |
-
requests
|
|
|
|
1 |
gradio
|
2 |
+
requests
|
3 |
+
gradio[oauth]
|
system_prompt.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
You are a helpful assistant tasked with answering questions using a set of tools.
|
2 |
+
Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
|
3 |
+
FINAL ANSWER: [YOUR FINAL ANSWER].
|
4 |
+
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use commas when writing your number, nor units such as $ or percent signs, unless specified otherwise. If you are asked for a string, don't use articles or abbreviations (e.g. for cities), and write digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending on whether each element of the list is a number or a string.
|
5 |
+
Your answer should start with "FINAL ANSWER: ", followed by the answer.
|
tools.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from langchain.tools import tool
|
3 |
+
from duckduckgo_search import DDGS
|
4 |
+
from bs4 import BeautifulSoup
|
5 |
+
import tempfile
|
6 |
+
from typing import Optional
|
7 |
+
import os
|
8 |
+
from urllib.parse import urlparse
|
9 |
+
|
10 |
+
|
11 |
+
@tool("search", return_direct=False)
def search(query: str) -> str:
    """Searches the internet using DuckDuckGo

    Args:
        query (str): Search query

    Returns:
        str: Search results rendered as readable text (one hit per
             paragraph), or "No results found."
    """
    with DDGS() as ddgs:
        results = list(ddgs.text(query, max_results=5))
    if not results:
        return "No results found."
    # BUG FIX: the annotation and docstring promise a str, but the original
    # returned the raw list of result dicts; serialize each hit so the LLM
    # receives readable text.
    lines = [
        f"{r.get('title', '')} - {r.get('href', '')}\n{r.get('body', '')}"
        for r in results
    ]
    return "\n\n".join(lines)
|
24 |
+
|
25 |
+
|
26 |
+
@tool("process_content", return_direct=False)
def process_content(url: str) -> str:
    """Process content from a webpage

    Args:
        url (str): URL to get content

    Returns:
        str: Content in the webpage (all visible text, markup stripped)
    """
    # FIX: requests.get without a timeout can hang the agent forever on an
    # unresponsive host; 30s bounds the wait.
    response = requests.get(url, timeout=30)
    soup = BeautifulSoup(response.content, "html.parser")
    return soup.get_text()
|
39 |
+
|
40 |
+
|
41 |
+
@tool("save_file")
def save_file(content: str, filename: Optional[str] = None) -> str:
    """
    Save content to a temporary file and return the path.
    Useful for processing files from the GAIA API.

    Args:
        content: The content to save to the file
        filename: Optional filename, will generate a random name if not provided

    Returns:
        Path to the saved file
    """
    temp_dir = tempfile.gettempdir()
    if filename is None:
        # FIX: the original created a NamedTemporaryFile(delete=False) and
        # never closed it -- a leaked file descriptor, and reopening the same
        # path while the handle is open fails on Windows. mkstemp + close
        # gives us just the reserved path.
        fd, filepath = tempfile.mkstemp()
        os.close(fd)
    else:
        filepath = os.path.join(temp_dir, filename)

    # Write content to the file (explicit encoding for portability).
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(content)

    return f"File saved to {filepath}. You can read this file to process its contents."
|
66 |
+
|
67 |
+
|
68 |
+
@tool("download_file_from_url")
def download_file_from_url(url: str, filename: Optional[str] = None) -> str:
    """
    Download a file from a URL and save it to a temporary location.

    Args:
        url: The URL to download from
        filename: Optional filename, will generate one based on URL if not provided

    Returns:
        Path to the downloaded file (or an error message string)
    """
    try:
        # Parse URL to get filename if not provided
        if not filename:
            path = urlparse(url).path
            filename = os.path.basename(path)
            if not filename:
                # Generate a random name if we couldn't extract one
                import uuid

                filename = f"downloaded_{uuid.uuid4().hex[:8]}"

        # Create temporary file
        temp_dir = tempfile.gettempdir()
        filepath = os.path.join(temp_dir, filename)

        # Download the file. FIX: a timeout bounds the wait -- without one,
        # an unresponsive host would hang the agent indefinitely.
        response = requests.get(url, stream=True, timeout=30)
        response.raise_for_status()

        # Save the file in 8 KiB chunks to keep memory use flat.
        with open(filepath, "wb") as f:
            for chunk in response.iter_content(chunk_size=8192):
                f.write(chunk)

        return f"File downloaded to {filepath}. You can now process this file."
    except Exception as e:
        # Best-effort tool contract: report the failure as text for the LLM.
        return f"Error downloading file: {str(e)}"
|
107 |
+
|
108 |
+
|
109 |
+
@tool("extract_text_from_image")
def extract_text_from_image(image_path: str) -> str:
    """
    Extract text from an image using pytesseract (if available).

    Args:
        image_path: Path to the image file

    Returns:
        Extracted text or error message
    """
    # OCR dependencies are optional; report a helpful message when absent.
    try:
        import pytesseract
        from PIL import Image
    except ImportError:
        return "Error: pytesseract is not installed. Please install it with 'pip install pytesseract' and ensure Tesseract OCR is installed on your system."
    try:
        img = Image.open(image_path)
        extracted = pytesseract.image_to_string(img)
        return f"Extracted text from image:\n\n{extracted}"
    except Exception as e:
        return f"Error extracting text from image: {str(e)}"
|
136 |
+
|
137 |
+
|
138 |
+
@tool("analyze_csv_file")
def analyze_csv_file(file_path: str, query: str) -> str:
    """
    Analyze a CSV file using pandas and return a summary of it.

    Args:
        file_path: Path to the CSV file
        query: Question about the data. NOTE: currently unused -- the same
            summary is returned regardless; the calling LLM interprets it.
            (Kept in the signature for backward compatibility.)

    Returns:
        Analysis result or error message
    """
    try:
        import pandas as pd

        # Read the CSV file
        df = pd.read_csv(file_path)

        # Shape and column overview. str(c) guards against non-string
        # column labels, which would break ', '.join.
        result = f"CSV file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
        result += f"Columns: {', '.join(str(c) for c in df.columns)}\n\n"

        # Add summary statistics
        result += "Summary statistics:\n"
        result += str(df.describe())

        return result
    except ImportError:
        return "Error: pandas is not installed. Please install it with 'pip install pandas'."
    except Exception as e:
        return f"Error analyzing CSV file: {str(e)}"
|
169 |
+
|
170 |
+
|
171 |
+
@tool("analyze_excel_file")
def analyze_excel_file(file_path: str, query: str) -> str:
    """
    Analyze an Excel file using pandas and return a summary of it.

    Args:
        file_path: Path to the Excel file
        query: Question about the data. NOTE: currently unused -- the same
            summary is returned regardless; the calling LLM interprets it.
            (Kept in the signature for backward compatibility.)

    Returns:
        Analysis result or error message
    """
    try:
        import pandas as pd

        # Read the Excel file (first sheet only, pandas default)
        df = pd.read_excel(file_path)

        # Shape and column overview. str(c) guards against non-string
        # column labels, which would break ', '.join.
        result = (
            f"Excel file loaded with {len(df)} rows and {len(df.columns)} columns.\n"
        )
        result += f"Columns: {', '.join(str(c) for c in df.columns)}\n\n"

        # Add summary statistics
        result += "Summary statistics:\n"
        result += str(df.describe())

        return result
    except ImportError:
        return "Error: pandas and openpyxl are not installed. Please install them with 'pip install pandas openpyxl'."
    except Exception as e:
        return f"Error analyzing Excel file: {str(e)}"
|
204 |
+
|
205 |
+
|
206 |
+
def get_tools():
    """Return the list of tools exposed to the agent.

    Only ``search`` is currently enabled; the other tools defined in this
    module (process_content, save_file, download_file_from_url,
    extract_text_from_image, analyze_csv_file, analyze_excel_file) are
    deliberately left out of the returned list for now.
    """
    enabled_tools = [
        search,
    ]
    return enabled_tools
|