orbulat committed on
Commit
4579b08
·
verified ·
1 Parent(s): f0460ec

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +206 -175
agent.py CHANGED
@@ -1,188 +1,219 @@
 
 
1
  import os
2
- from langgraph.graph import START, StateGraph, MessagesState
3
- from langgraph.prebuilt import ToolNode, tools_condition
4
- from langchain_google_genai import ChatGoogleGenerativeAI
5
- from langchain_groq import ChatGroq
6
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
7
- from langchain_core.messages import SystemMessage, HumanMessage
8
- from langchain_core.tools import tool
9
- from langchain_community.tools.tavily_search import TavilySearchResults
10
- from langchain_community.document_loaders import WikipediaLoader
11
- from youtube_transcript_api import YouTubeTranscriptApi, NoTranscriptFound
12
- from duckduckgo_search import DDGS
13
- from langchain_community.document_loaders import ArxivLoader
14
- from sympy import sympify
15
- from PIL import Image
16
- import re
17
  import requests
18
- from io import BytesIO
 
 
19
  from dotenv import load_dotenv
 
20
 
 
 
 
21
  load_dotenv()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
# Read the system prompt from disk once, at import time.
with open("system_prompt.txt", mode="r", encoding="utf-8") as prompt_file:
    SYSTEM_PROMPT = prompt_file.read()
26
-
27
# Tool: Wikipedia search
@tool
def wiki_search(query: str) -> str:
    """Search Wikipedia for a query and return content from up to 2 documents."""
    # NOTE: the docstring above is what @tool exposes to the model as the tool
    # description, so it is part of runtime behavior.
    try:
        loader = WikipediaLoader(query=query, load_max_docs=2)
        pages = loader.load()
        # Page bodies are separated by a horizontal-rule style divider.
        return "\n\n---\n\n".join(page.page_content for page in pages)
    except Exception as e:
        # Errors are reported as text so the agent loop keeps running.
        return f"Wikipedia search failed: {e}"
36
-
37
# Tool: Tavily web search
@tool
def web_search(query: str) -> str:
    """Search the web using Tavily and return content from up to 3 results."""
    # NOTE: the docstring above is what @tool exposes to the model as the tool
    # description, so it is part of runtime behavior.
    try:
        hits = TavilySearchResults(max_results=3).invoke(query)
        if not isinstance(hits, list):
            # Anything that is not a list of results is returned as plain text.
            return str(hits)
        snippets = []
        for hit in hits:
            # Dict results contribute their "content"; anything else is stringified.
            snippets.append(hit["content"] if isinstance(hit, dict) else str(hit))
        return "\n\n---\n\n".join(snippets)
    except Exception as e:
        return f"Web search failed: {e}"
48
-
49
# Tool: DuckDuckGo search
@tool
def duckduckgo_search(query: str) -> str:
    """Search using DuckDuckGo and return summaries from up to 3 results."""
    # NOTE: the docstring above is what @tool exposes to the model as the tool
    # description, so it is part of runtime behavior.
    try:
        with DDGS() as session:
            hits = session.text(query, max_results=3)
            # Only results that carry a "body" field contribute a summary.
            summaries = [hit["body"] for hit in hits if "body" in hit]
        return "\n\n---\n\n".join(summaries)
    except Exception as e:
        return f"DuckDuckGo search failed: {e}"
59
-
60
# Tool: YouTube transcript or duration extractor
@tool
def youtube_transcript(video_title_or_url: str) -> str:
    """Get duration of a YouTube video using its title or URL."""
    # NOTE: despite the function name, this only reports the duration of the
    # first DuckDuckGo video hit; no transcript is fetched. The docstring
    # above doubles as the tool description shown to the model.
    try:
        with DDGS() as session:
            hits = session.videos(video_title_or_url, max_results=1)
        if hits:
            top_hit = hits[0]
            return f"Duration: {top_hit.get('duration')}"
        return "No video found by that title."
    except Exception as e:
        return f"YouTube search failed: {e}"
73
-
74
# Tool: Arxiv paper fetcher
@tool
def arxiv_fetch(query_or_id: str) -> str:
    """Fetch metadata from arXiv either by ID or search query."""
    # NOTE: the docstring above is what @tool exposes to the model as the tool
    # description, so it is intentionally left unchanged.
    #
    # Returns the first 2000 characters of the arXiv API response plus the
    # abstract URL when the input is an arXiv identifier; otherwise the joined
    # page content of up to 2 search results. Failures are returned as text.
    try:
        candidate_id = query_or_id.strip()
        # fullmatch: the *whole* input must be an identifier. A prefix match
        # would treat "2303.12712 transformers" as an ID and build a broken URL.
        # Identifiers carry 4 (pre-2015) or 5 digits after the dot.
        if re.fullmatch(r"\d{4}\.\d{4,5}(v\d+)?", candidate_id):
            abs_url = f"https://arxiv.org/abs/{candidate_id}"
            api_url = f"http://export.arxiv.org/api/query?id_list={candidate_id}"
            # A timeout keeps a stalled connection from hanging the agent.
            res = requests.get(api_url, timeout=30)
            if res.status_code == 200:
                return res.text[:2000] + f"\n\nFull: {abs_url}"
            return "Could not retrieve metadata from arXiv API"
        docs = ArxivLoader(query=query_or_id, load_max_docs=2).load()
        return "\n\n---\n\n".join(doc.page_content for doc in docs)
    except Exception as e:
        return f"ArXiv fetch failed: {e}"
91
-
92
@tool
def math_solver(expression: str) -> str:
    """Evaluate a math expression and return the result."""
    # NOTE: the docstring above doubles as the tool description shown to the model.
    # NOTE(security): sympify() parses strings with eval() internally, so this
    # tool should only ever receive trusted input.
    try:
        parsed = sympify(expression)
        numeric = parsed.evalf()
        return str(numeric)
    except Exception as e:
        return f"Math error: {e}"
100
-
101
@tool
def reverse_text(text: str) -> str:
    """Reverse the input string."""
    # NOTE: the docstring above doubles as the tool description shown to the model.
    return "".join(reversed(text))
105
-
106
@tool
def image_info(url: str) -> str:
    """Fetch image size (width x height) from a given URL."""
    # NOTE: the docstring above is what @tool exposes to the model as the tool
    # description, so it is intentionally left unchanged.
    #
    # Returns "Image size: (w, h) (width x height)" on success, or an
    # "Image error: ..." string describing the failure.
    try:
        # Timeout: never block the agent on an unresponsive host.
        response = requests.get(url, timeout=30)
        # Surface HTTP errors instead of trying to decode an error page as an image.
        response.raise_for_status()
        # The context manager releases the decoder's resources deterministically.
        with Image.open(BytesIO(response.content)) as img:
            return f"Image size: {img.size} (width x height)"
    except Exception as e:
        return f"Image error: {e}"
115
-
116
# Tools list: every callable here is bound to the LLM and served by the ToolNode.
tools = [
    wiki_search,
    web_search,
    duckduckgo_search,
    youtube_transcript,
    arxiv_fetch,
    math_solver,
    reverse_text,
    image_info,
]
127
-
128
def build_graph(provider: str = "groq"):
    """Build and compile the tool-calling LangGraph agent.

    Args:
        provider: Which chat model backend to use: "google", "groq" or
            "huggingface".

    Returns:
        The compiled graph; invoke it with ``{"messages": [...]}``.

    Raises:
        ValueError: If ``provider`` is not one of the supported names.
    """
    if provider == "google":
        llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0)
    elif provider == "groq":
        llm = ChatGroq(model="llama3-70b-8192", temperature=0)
    elif provider == "huggingface":
        llm = ChatHuggingFace(
            llm=HuggingFaceEndpoint(
                url="https://api-inference.huggingface.co/models/tiiuae/falcon-7b-instruct",
                temperature=0,
            ),
        )
    else:
        raise ValueError("Invalid provider. Choose 'google', 'groq', or 'huggingface'.")

    llm_with_tools = llm.bind_tools(tools)
    system_message = SystemMessage(content=SYSTEM_PROMPT)

    def assistant_node(state: MessagesState):
        # The system prompt is prepended at call time. Returning it through a
        # separate node would go through the add_messages reducer, which
        # appends new messages, so the prompt would land *after* the human
        # message instead of first.
        conversation = [system_message] + state["messages"]
        return {"messages": [llm_with_tools.invoke(conversation)]}

    builder = StateGraph(MessagesState)
    builder.add_node("assistant", assistant_node)
    builder.add_node("tools", ToolNode(tools))
    builder.add_edge(START, "assistant")
    # tools_condition routes to "tools" when the model requested a tool call,
    # otherwise the run ends.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")
    return builder.compile()
160
 
161
class BasicAgent:
    """Callable wrapper around the compiled LangGraph agent."""

    def __init__(self, provider="groq"):
        """Announce the chosen provider and build the graph for it."""
        print(f"GAIA LangGraph Agent Initialized using {provider}")
        self.graph = build_graph(provider)

    def __call__(self, question: str) -> str:
        """Run the graph on ``question`` and return a "FINAL ANSWER: ..." string.

        Any exception raised while running the graph is reported inside the
        returned string instead of propagating.
        """
        try:
            final_state = self.graph.invoke({"messages": [HumanMessage(content=question)]})
            answer = final_state["messages"][-1].content.strip()
        except Exception as e:
            return f"FINAL ANSWER: error - {str(e)}"
        # Add the expected prefix when the model left it out.
        if answer.startswith("FINAL ANSWER:"):
            return answer
        return f"FINAL ANSWER: {answer}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
176
 
if __name__ == "__main__":
    # Manual smoke test: run a few sample questions through the groq-backed agent.
    agent = BasicAgent(provider="groq")
    questions = [
        "What is the zip code of the Eiffel Tower?",
        "What is the capital city of Australia?",
        "How long is the video titled 'The History of Time' on YouTube?",
        "What does the arXiv paper '2303.12712' say about Transformer performance?",
    ]

    for q in questions:
        print(f"\n[Question]: {q}")
        print(agent(q))
 
1
+ # --- Basic Agent Definition ---
2
+ import asyncio
3
  import os
4
+ import sys
5
+ import logging
6
+ import random
7
+ import pandas as pd
 
 
 
 
 
 
 
 
 
 
 
8
  import requests
9
+ import wikipedia as wiki
10
+ from markdownify import markdownify as to_markdown
11
+ from typing import Any
12
  from dotenv import load_dotenv
13
+ from google.generativeai import types, configure
14
 
15
+ from smolagents import InferenceClientModel, LiteLLMModel, ToolCallingAgent, Tool, DuckDuckGoSearchTool
16
+
17
# Load variables from .env, then hand the Google API key to the Gemini SDK.
load_dotenv()
configure(api_key=os.environ.get("GOOGLE_API_KEY"))

# Logging (currently disabled)
#logging.basicConfig(level=logging.INFO, format="%(asctime)s | %(levelname)s | %(message)s")
#logger = logging.getLogger(__name__)

# --- Model Configuration ---
# The first four identifiers are passed to LiteLLMModel(model_id=...) and use
# its "provider/model" form.
GEMINI_MODEL_NAME = "gemini/gemini-1.5-flash"
OPENAI_MODEL_NAME = "openai/gpt-4o"
GROQ_MODEL_NAME = "groq/llama3-70b-8192"
DEEPSEEK_MODEL_NAME = "deepseek/deepseek-chat"
HF_MODEL_NAME = "Qwen/Qwen2.5-Coder-32B-Instruct"
31
+
32
+ # --- Tool Definitions ---
33
def _safe_arithmetic(expression: str):
    """Evaluate a purely arithmetic expression without calling ``eval``.

    Supports numeric literals, parentheses, unary ``+``/``-`` and the binary
    operators ``+ - * / // % **``. Anything else raises ``ValueError``.
    """
    import ast
    import operator

    binary_ops = {
        ast.Add: operator.add,
        ast.Sub: operator.sub,
        ast.Mult: operator.mul,
        ast.Div: operator.truediv,
        ast.FloorDiv: operator.floordiv,
        ast.Mod: operator.mod,
        ast.Pow: operator.pow,
    }
    unary_ops = {ast.UAdd: operator.pos, ast.USub: operator.neg}

    def visit(node):
        if isinstance(node, ast.Expression):
            return visit(node.body)
        if isinstance(node, ast.Constant):
            value = node.value
            # bool is a subclass of int; reject it along with strings etc.
            if isinstance(value, (int, float, complex)) and not isinstance(value, bool):
                return value
            raise ValueError("only numeric literals are allowed")
        if isinstance(node, ast.BinOp) and type(node.op) in binary_ops:
            left = visit(node.left)
            right = visit(node.right)
            # Cap exponents so inputs like 9**9**9 cannot stall the agent.
            if isinstance(node.op, ast.Pow) and abs(right) > 1000:
                raise ValueError("exponent too large")
            return binary_ops[type(node.op)](left, right)
        if isinstance(node, ast.UnaryOp) and type(node.op) in unary_ops:
            return unary_ops[type(node.op)](visit(node.operand))
        raise ValueError("unsupported expression")

    return visit(ast.parse(expression.strip(), mode="eval"))


class MathSolver(Tool):
    """Tool that evaluates basic arithmetic expressions supplied as text."""

    name = "math_solver"
    description = "Safely evaluate basic math expressions."
    inputs = {"input": {"type": "string", "description": "Math expression to evaluate."}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Return the value of ``input`` as a string, or "Math error: ..." on failure.

        The expression is walked as an AST restricted to arithmetic.
        ``eval`` with emptied ``__builtins__`` is not a sandbox: attribute
        access on literals can still reach arbitrary objects.
        """
        try:
            return str(_safe_arithmetic(input))
        except Exception as e:
            # Covers syntax errors, division by zero and rejected constructs.
            return f"Math error: {e}"
44
+
45
class RiddleSolver(Tool):
    """Tool with a single hard-coded riddle rule (forward/backward -> palindrome)."""

    name = "riddle_solver"
    description = "Solve basic riddles using logic."
    inputs = {"input": {"type": "string", "description": "Riddle prompt."}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Return "A palindrome" when both keywords occur in ``input``, else a failure text."""
        keywords = ("forward", "backward")
        if all(keyword in input for keyword in keywords):
            return "A palindrome"
        return "RiddleSolver failed."
55
+
56
class TextTransformer(Tool):
    """Tool that applies a prefix-selected transformation to a piece of text."""

    name = "text_ops"
    description = "Transform text: reverse, upper, lower."
    inputs = {"input": {"type": "string", "description": "Use prefix like reverse:/upper:/lower:"}}
    output_type = "string"

    def forward(self, input: str) -> str:
        """Apply the operation named before the first ':' to the text after it.

        Returns "Unknown transformation." when the prefix is not recognised.
        """
        operation, colon, payload = input.partition(":")
        if not colon:
            return "Unknown transformation."
        payload = payload.strip()
        if operation == "reverse":
            flipped = payload[::-1]
            # NOTE(review): hard-coded shortcut. Any reversed text that
            # mentions "left" is answered with "right" instead of the
            # reversed text itself.
            return "right" if "left" in flipped.lower() else flipped
        if operation == "upper":
            return payload.upper()
        if operation == "lower":
            return payload.lower()
        return "Unknown transformation."
73
+
74
class GeminiVideoQA(Tool):
    """Tool that asks a Gemini model a question about a video via the REST API."""

    name = "video_inspector"
    description = "Analyze video content to answer questions."
    inputs = {
        "video_url": {"type": "string", "description": "URL of video."},
        "user_query": {"type": "string", "description": "Question about video."},
    }
    output_type = "string"

    def __init__(self, model_name, *args, **kwargs):
        """Store the model identifier; remaining arguments go to ``Tool``."""
        super().__init__(*args, **kwargs)
        self.model_name = model_name

    def forward(self, video_url: str, user_query: str) -> str:
        """Send the video URL and question to ``generateContent`` and return the text.

        Returns the concatenated text parts of the first candidate, or a
        "Video error ..." string when the request fails or yields nothing.
        """
        # The tool may be built with a prefixed id such as
        # "gemini/gemini-1.5-flash" or "models/gemini-1.5-flash". The REST path
        # needs the bare model name, otherwise it becomes "models/gemini/...".
        model_id = self.model_name.rsplit("/", 1)[-1]
        payload = {
            "model": f"models/{model_id}",
            "contents": [{
                "parts": [
                    {"fileData": {"fileUri": video_url}},
                    {"text": f"Please watch the video and answer the question: {user_query}"},
                ]
            }],
        }
        endpoint = (
            "https://generativelanguage.googleapis.com/v1beta/"
            f"models/{model_id}:generateContent"
        )
        # Send the key as a header so it never appears in URLs or request logs.
        headers = {
            "Content-Type": "application/json",
            "x-goog-api-key": os.getenv("GOOGLE_API_KEY") or "",
        }
        try:
            # Video analysis can be slow, but it must not hang forever.
            res = requests.post(endpoint, json=payload, headers=headers, timeout=300)
        except requests.RequestException as exc:
            return f"Video error: {exc}"
        if res.status_code != 200:
            return f"Video error {res.status_code}: {res.text}"
        # A blocked or empty response has no candidates; report it instead of
        # raising KeyError/IndexError.
        candidates = res.json().get("candidates") or []
        if not candidates:
            return "Video error: the model returned no candidates."
        parts = candidates[0].get("content", {}).get("parts", [])
        return "".join(part.get("text", "") for part in parts)
103
+
104
class WikiTitleFinder(Tool):
    """Tool that lists Wikipedia page titles matching a search query."""

    name = "wiki_titles"
    description = "Search for related Wikipedia page titles."
    inputs = {"query": {"type": "string", "description": "Search query."}}
    output_type = "string"

    def forward(self, query: str) -> str:
        """Return the matching titles as a comma-separated string, or "No results."."""
        titles = wiki.search(query)
        if not titles:
            return "No results."
        return ", ".join(titles)
113
+
114
class WikiContentFetcher(Tool):
    """Tool that returns a Wikipedia page rendered as Markdown."""

    name = "wiki_page"
    description = "Fetch Wikipedia page content."
    inputs = {"page_title": {"type": "string", "description": "Wikipedia page title."}}
    output_type = "string"

    def forward(self, page_title: str) -> str:
        """Return the page's HTML converted to Markdown.

        A missing page yields "'<title>' not found."; an ambiguous title
        yields a short list of candidate page titles so the agent can retry
        with a specific one.
        """
        try:
            return to_markdown(wiki.page(page_title).html())
        except wiki.exceptions.DisambiguationError as err:
            # Ambiguous titles are common; list the first few options instead
            # of letting the exception escape the tool.
            candidates = ", ".join(err.options[:10])
            return f"'{page_title}' is ambiguous. Possible pages: {candidates}"
        except wiki.exceptions.PageError:
            return f"'{page_title}' not found."
125
 
126
+ # --- Basic Agent Definition ---
127
class BasicAgent:
    """GAIA-style question answering agent built on a smolagents ToolCallingAgent."""

    # Prompt assigned to the underlying agent; the text is kept flush-left so
    # no indentation leaks into the string.
    _SYSTEM_PROMPT = """
You are a general AI assistant. I will ask you a question.
YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
If your answer is a number and you are not explicitly asked for a string, write it in numerals instead of words, and don't use comma to write your number nor use units such as $ or percent sign unless specified otherwise.
If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise.
If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.

Answer questions as literally as you can, making as few assumptions as possible. Restrict the answer to the narrowest definition that still satifies the question.
If you are provied with a video, please watch and summarize the entire video before answering the question. The correct answer may be present only in a few frames of the video.
If you have difficulty finding an answer on Wikipedia, you may search the internet using Google Search or Duckduckgo search.
If you are asked to prove something, first state your assumptions and think step by step before giving your final answer.

Your final answer must strictly follow this format:
FINAL ANSWER: [ANSWER]

Only write the answer in that exact format. Do not explain anything. Do not include any other text.
"""

    def __init__(self, provider="deepseek"):
        """Create the model for ``provider`` and wire it to the tool set."""
        print("BasicAgent initialized.")
        model = self.select_model(provider)
        # The original also built an InferenceClientModel here that was never
        # used; it is dropped so start-up does no needless work.
        tools = [
            DuckDuckGoSearchTool(),
            GeminiVideoQA(GEMINI_MODEL_NAME),
            WikiTitleFinder(),
            WikiContentFetcher(),
            MathSolver(),
            RiddleSolver(),
            TextTransformer(),
        ]
        self.agent = ToolCallingAgent(
            model=model,
            tools=tools,
            add_base_tools=False,
            max_steps=5,
        )
        # NOTE(review): assumes the installed smolagents version honours a
        # directly assigned ``system_prompt`` - TODO confirm.
        self.agent.system_prompt = self._SYSTEM_PROMPT

    def select_model(self, provider: str):
        """Return the model object for ``provider``.

        Recognised values are "openai", "groq", "deepseek" and "hf"; any
        other value falls back to Gemini through LiteLLM.
        """
        if provider == "openai":
            return LiteLLMModel(model_id=OPENAI_MODEL_NAME, api_key=os.getenv("OPENAI_API_KEY"))
        if provider == "groq":
            return LiteLLMModel(model_id=GROQ_MODEL_NAME, api_key=os.getenv("GROQ_API_KEY"))
        if provider == "deepseek":
            return LiteLLMModel(model_id=DEEPSEEK_MODEL_NAME, api_key=os.getenv("DEEPSEEK_API_KEY"))
        if provider == "hf":
            # Use the configured model instead of the library's implicit default.
            return InferenceClientModel(model_id=HF_MODEL_NAME)
        return LiteLLMModel(model_id=GEMINI_MODEL_NAME, api_key=os.getenv("GOOGLE_API_KEY"))

    def __call__(self, question: str) -> str:
        """Run the agent on ``question`` and return its answer as a stripped string."""
        print(f"Agent received question (first 50 chars): {question[:50]}...")
        result = self.agent.run(question)
        # The run result may be a dict carrying "final_answer" or any other
        # object; everything else is stringified.
        if isinstance(result, dict) and isinstance(result.get("final_answer"), str):
            return result["final_answer"].strip()
        return str(result).strip()

    def evaluate_random_questions(self, csv_path: str = "gaia_qa.csv", sample_size: int = 3):
        """Print the agent's answers for randomly sampled rows of a Q/A CSV.

        Args:
            csv_path: CSV file with "question" and "answer" columns.
            sample_size: Number of rows to sample; capped at the row count.
        """
        df = pd.read_csv(csv_path)
        if not {"question", "answer"}.issubset(df.columns):
            print("CSV must contain 'question' and 'answer' columns.")
            print("Found columns:", df.columns.tolist())
            return
        # DataFrame.sample raises when asked for more rows than exist.
        samples = df.sample(n=min(sample_size, len(df)))
        for _, row in samples.iterrows():
            # str(): cells can be parsed as numbers or NaN rather than text.
            question = str(row["question"]).strip()
            expected = f"FINAL ANSWER: {str(row['answer']).strip()}"
            result = self(question).strip()
            print("---")
            print("Question:", question)
            print("Expected:", expected)
            print("Agent:", result)
            print("Correct:", expected == result)
205
 
if __name__ == "__main__":
    # CLI entry point: answer one question, or run the small dev evaluation.
    args = sys.argv[1:]
    wants_help = not args or args[0] in {"-h", "--help"}
    if wants_help:
        print("Usage: python agent.py [question | dev]\n")
        print(" - Provide a question to get a GAIA-style answer.")
        print(" - Use 'dev' to evaluate 3 random GAIA questions from gaia_qa.csv.")
        sys.exit(0)

    # All positional arguments together form the question text.
    q = " ".join(args)
    agent = BasicAgent()
    if q != "dev":
        print(agent(q))
    else:
        agent.evaluate_random_questions()