Spaces:
Sleeping
Sleeping
Update agent.py
Browse files
agent.py
CHANGED
@@ -23,7 +23,7 @@ configure(api_key=os.getenv("GOOGLE_API_KEY"))
|
|
23 |
#logger = logging.getLogger(__name__)
|
24 |
|
25 |
# --- Model Configuration ---
|
26 |
-
GEMINI_MODEL_NAME = "gemini/gemini-
|
27 |
OPENAI_MODEL_NAME = "openai/gpt-4o"
|
28 |
GROQ_MODEL_NAME = "groq/llama3-70b-8192"
|
29 |
DEEPSEEK_MODEL_NAME = "deepseek/deepseek-chat"
|
@@ -123,6 +123,46 @@ class WikiContentFetcher(Tool):
|
|
123 |
except wiki.exceptions.PageError:
|
124 |
return f"'{page_title}' not found."
|
125 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
# --- Basic Agent Definition ---
|
127 |
class BasicAgent:
|
128 |
def __init__(self, provider="deepseek"):
|
@@ -137,6 +177,7 @@ class BasicAgent:
|
|
137 |
MathSolver(),
|
138 |
RiddleSolver(),
|
139 |
TextTransformer(),
|
|
|
140 |
]
|
141 |
self.agent = ToolCallingAgent(
|
142 |
model=model,
|
@@ -146,21 +187,24 @@ class BasicAgent:
|
|
146 |
)
|
147 |
self.agent.system_prompt = (
|
148 |
"""
|
149 |
-
You are a
|
150 |
-
|
151 |
-
If
|
152 |
-
If
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
If
|
158 |
-
If
|
159 |
-
|
160 |
-
|
161 |
-
|
|
|
|
|
162 |
|
163 |
-
|
|
|
164 |
"""
|
165 |
)
|
166 |
|
@@ -183,7 +227,40 @@ class BasicAgent:
|
|
183 |
final_str = result["final_answer"].strip()
|
184 |
else:
|
185 |
final_str = str(result).strip()
|
186 |
-
|
187 |
-
return final_str
|
188 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
23 |
#logger = logging.getLogger(__name__)
|
24 |
|
25 |
# --- Model Configuration ---
|
26 |
+
GEMINI_MODEL_NAME = "gemini/gemini-2.0-flash"
|
27 |
OPENAI_MODEL_NAME = "openai/gpt-4o"
|
28 |
GROQ_MODEL_NAME = "groq/llama3-70b-8192"
|
29 |
DEEPSEEK_MODEL_NAME = "deepseek/deepseek-chat"
|
|
|
123 |
except wiki.exceptions.PageError:
|
124 |
return f"'{page_title}' not found."
|
125 |
|
126 |
+
class FileAttachmentQueryTool(Tool):
|
127 |
+
name = "run_query_with_file"
|
128 |
+
description = """
|
129 |
+
Downloads a file mentioned in a user prompt, adds it to the context, and runs a query on it.
|
130 |
+
This assumes the file is 20MB or less.
|
131 |
+
"""
|
132 |
+
inputs = {
|
133 |
+
"task_id": {
|
134 |
+
"type": "string",
|
135 |
+
"description": "A unique identifier for the task related to this file, used to download it."
|
136 |
+
},
|
137 |
+
"mime_type": {
|
138 |
+
"type": "string",
|
139 |
+
"nullable": True,
|
140 |
+
"description": "The MIME type of the file, or the best guess if unknown."
|
141 |
+
},
|
142 |
+
"user_query": {
|
143 |
+
"type": "string",
|
144 |
+
"description": "The question to answer about the file."
|
145 |
+
}
|
146 |
+
}
|
147 |
+
output_type = "string"
|
148 |
+
|
149 |
+
def forward(self, task_id: str, mime_type: str | None, user_query: str) -> str:
|
150 |
+
file_url = f"https://agents-course-unit4-scoring.hf.space/files/{task_id}"
|
151 |
+
file_response = requests.get(file_url)
|
152 |
+
if file_response.status_code != 200:
|
153 |
+
return f"Failed to download file: {file_response.status_code} - {file_response.text}"
|
154 |
+
file_data = file_response.content
|
155 |
+
mime_type = mime_type or file_response.headers.get('Content-Type', 'application/octet-stream')
|
156 |
+
|
157 |
+
from google.generativeai import GenerativeModel
|
158 |
+
model = GenerativeModel(self.model_name)
|
159 |
+
response = model.generate_content([
|
160 |
+
types.Part.from_bytes(data=file_data, mime_type=mime_type),
|
161 |
+
user_query
|
162 |
+
])
|
163 |
+
|
164 |
+
return response.text
|
165 |
+
|
166 |
# --- Basic Agent Definition ---
|
167 |
class BasicAgent:
|
168 |
def __init__(self, provider="deepseek"):
|
|
|
177 |
MathSolver(),
|
178 |
RiddleSolver(),
|
179 |
TextTransformer(),
|
180 |
+
FileAttachmentQueryTool(model_name=GEMINI_MODEL_NAME),
|
181 |
]
|
182 |
self.agent = ToolCallingAgent(
|
183 |
model=model,
|
|
|
187 |
)
|
188 |
self.agent.system_prompt = (
|
189 |
"""
|
190 |
+
You are a GAIA benchmark AI assistant. Your sole purpose is to provide exact, minimal answers in the format 'FINAL ANSWER: [ANSWER]' with no additional text, explanations, or comments.
|
191 |
+
|
192 |
+
- If the answer is a number, use numerals (e.g., '42', not 'forty-two'), without commas or units (e.g., no '$', '%') unless explicitly requested.
|
193 |
+
- If the answer is a string, use no articles ('a', 'the'), no abbreviations (e.g., 'New York', not 'NY'), and write digits as text (e.g., 'one', not '1') unless specified.
|
194 |
+
- For comma-separated lists, apply the above rules to each element based on whether it's a number or string.
|
195 |
+
- Answer as literally as possible, making minimal assumptions and adhering to the question's narrowest interpretation.
|
196 |
+
- For videos, analyze the entire content but extract only the precise answer to the query, ignoring irrelevant details.
|
197 |
+
- For Wikipedia or search tools, distill results to the minimal correct answer, ignoring extraneous content.
|
198 |
+
- If proving something, compute step-by-step internally but output only the final result in the required format.
|
199 |
+
- If tool outputs are verbose, extract only the essential answer that satisfies the question.
|
200 |
+
- Under no circumstances include explanations, intermediate steps, or text outside the 'FINAL ANSWER: [ANSWER]' format.
|
201 |
+
|
202 |
+
Example:
|
203 |
+
Question: What is 2 + 2?
|
204 |
+
Response: FINAL ANSWER: 4
|
205 |
|
206 |
+
Your response must always be:
|
207 |
+
FINAL ANSWER: [ANSWER]
|
208 |
"""
|
209 |
)
|
210 |
|
|
|
227 |
final_str = result["final_answer"].strip()
|
228 |
else:
|
229 |
final_str = str(result).strip()
|
|
|
|
|
230 |
|
231 |
+
return f"FINAL ANSWER: {final_str}"
|
232 |
+
|
233 |
+
def evaluate_random_questions(self, csv_path: str = "gaia_qa.csv", sample_size: int = 3, show_steps: bool = True):
|
234 |
+
df = pd.read_csv(csv_path)
|
235 |
+
if not {"question", "answer"}.issubset(df.columns):
|
236 |
+
print("CSV must contain 'question' and 'answer' columns.")
|
237 |
+
print("Found columns:", df.columns.tolist())
|
238 |
+
return
|
239 |
+
samples = df.sample(n=sample_size)
|
240 |
+
for _, row in samples.iterrows():
|
241 |
+
question = row["question"].strip()
|
242 |
+
expected = f"FINAL ANSWER: {str(row['answer']).strip()}"
|
243 |
+
result = self(question).strip()
|
244 |
+
if show_steps:
|
245 |
+
print("---")
|
246 |
+
print("Question:", question)
|
247 |
+
print("Expected:", expected)
|
248 |
+
print("Agent:", result)
|
249 |
+
print("Correct:", expected == result)
|
250 |
+
else:
|
251 |
+
print(f"Q: {question}\nE: {expected}\nA: {result}\n✓: {expected == result}\n")
|
252 |
+
|
253 |
+
if __name__ == "__main__":
|
254 |
+
args = sys.argv[1:]
|
255 |
+
if not args or args[0] in {"-h", "--help"}:
|
256 |
+
print("Usage: python agent.py [question | dev]")
|
257 |
+
print(" - Provide a question to get a GAIA-style answer.")
|
258 |
+
print(" - Use 'dev' to evaluate 3 random GAIA questions from gaia_qa.csv.")
|
259 |
+
sys.exit(0)
|
260 |
|
261 |
+
q = " ".join(args)
|
262 |
+
agent = BasicAgent()
|
263 |
+
if q == "dev":
|
264 |
+
agent.evaluate_random_questions()
|
265 |
+
else:
|
266 |
+
print(agent(q))
|