kylea committed
Commit afb4047
1 Parent(s): c42f7a4

Added tools for downloading files, reading files, and Wikipedia search.

app.py CHANGED
@@ -1,4 +1,5 @@
  import os
+ import time
  import gradio as gr
  import requests
  import inspect
@@ -21,7 +22,7 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
      """
      # --- Determine HF Space Runtime URL and Repo URL ---
      space_id = os.getenv("SPACE_ID") # Get the SPACE_ID for sending link to the code
-
+     space_id = space_id or "kylea/GAIA_Agent_Space"
      if profile:
          username= f"{profile.username}"
          print(f"User logged in: {username}")
@@ -71,20 +72,40 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
      for item in questions_data:
          task_id = item.get("task_id")
          question_text = item.get("question")
+         file_name = item.get("file_name")
          if not task_id or question_text is None:
              print(f"Skipping item with missing task_id or question: {item}")
              continue
          try:
              print(f"Running agent on task {task_id}: {question_text}")
-             submitted_answer = agent.graph.invoke({"messages": [HumanMessage(content=question_text)]})
-             print(submitted_answer)
+             # Retry up to 3 times with a 45-second delay between attempts
+             for attempt in range(3):
+                 try:
+                     # Call the agent's graph with the question text and task metadata
+                     submitted_answer = agent.graph.invoke(
+                         {
+                             "messages": [HumanMessage(content=question_text)],
+                             "task_id": task_id,
+                             "file_name": file_name,
+                         }
+                     )
+                     break  # Break if successful
+                 except Exception as e:
+                     print(f"Attempt {attempt + 1} failed: {e}")
+                     if attempt < 2:
+                         print("Retrying...")
+                         time.sleep(45)  # Wait 45 seconds before the next attempt
+                     else:
+                         raise  # All attempts failed; let the outer handler log the error
+             submitted_answer = submitted_answer['messages'][-1].content
+             if "FINAL ANSWER:" in submitted_answer:
+                 submitted_answer = submitted_answer.split("FINAL ANSWER:")[-1].strip()
+             print(f"Agent submitted answer for task {task_id}: {submitted_answer}")
              answers_payload.append({"task_id": task_id, "submitted_answer": submitted_answer})
              results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": submitted_answer})
          except Exception as e:
              print(f"Error running agent on task {task_id}: {e}")
              results_log.append({"Task ID": task_id, "Question": question_text, "Submitted Answer": f"AGENT ERROR: {e}"})

-         break
+         # break

      if not answers_payload:
          print("Agent did not produce any answers to submit.")
@@ -94,52 +115,52 @@ def run_and_submit_all( profile: gr.OAuthProfile | None):
      print(f"Results log: {results_log}")

      # 4. Prepare Submission
-     # submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
-     # status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
-     # print(status_update)
-
-     # # 5. Submit
-     # print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
-     # try:
-     # response = requests.post(submit_url, json=submission_data, timeout=60)
-     # response.raise_for_status()
-     # result_data = response.json()
-     # final_status = (
-     # f"Submission Successful!\n"
-     # f"User: {result_data.get('username')}\n"
-     # f"Overall Score: {result_data.get('score', 'N/A')}% "
-     # f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
-     # f"Message: {result_data.get('message', 'No message received.')}"
-     # )
-     # print("Submission successful.")
-     # results_df = pd.DataFrame(results_log)
-     # return final_status, results_df
-     # except requests.exceptions.HTTPError as e:
-     # error_detail = f"Server responded with status {e.response.status_code}."
-     # try:
-     # error_json = e.response.json()
-     # error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
-     # except requests.exceptions.JSONDecodeError:
-     # error_detail += f" Response: {e.response.text[:500]}"
-     # status_message = f"Submission Failed: {error_detail}"
-     # print(status_message)
-     # results_df = pd.DataFrame(results_log)
-     # return status_message, results_df
-     # except requests.exceptions.Timeout:
-     # status_message = "Submission Failed: The request timed out."
-     # print(status_message)
-     # results_df = pd.DataFrame(results_log)
-     # return status_message, results_df
-     # except requests.exceptions.RequestException as e:
-     # status_message = f"Submission Failed: Network error - {e}"
-     # print(status_message)
-     # results_df = pd.DataFrame(results_log)
-     # return status_message, results_df
-     # except Exception as e:
-     # status_message = f"An unexpected error occurred during submission: {e}"
-     # print(status_message)
-     # results_df = pd.DataFrame(results_log)
-     # return status_message, results_df
+     submission_data = {"username": username.strip(), "agent_code": agent_code, "answers": answers_payload}
+     status_update = f"Agent finished. Submitting {len(answers_payload)} answers for user '{username}'..."
+     print(status_update)
+
+     # 5. Submit
+     print(f"Submitting {len(answers_payload)} answers to: {submit_url}")
+     try:
+         response = requests.post(submit_url, json=submission_data, timeout=60)
+         response.raise_for_status()
+         result_data = response.json()
+         final_status = (
+             f"Submission Successful!\n"
+             f"User: {result_data.get('username')}\n"
+             f"Overall Score: {result_data.get('score', 'N/A')}% "
+             f"({result_data.get('correct_count', '?')}/{result_data.get('total_attempted', '?')} correct)\n"
+             f"Message: {result_data.get('message', 'No message received.')}"
+         )
+         print("Submission successful.")
+         results_df = pd.DataFrame(results_log)
+         return final_status, results_df
+     except requests.exceptions.HTTPError as e:
+         error_detail = f"Server responded with status {e.response.status_code}."
+         try:
+             error_json = e.response.json()
+             error_detail += f" Detail: {error_json.get('detail', e.response.text)}"
+         except requests.exceptions.JSONDecodeError:
+             error_detail += f" Response: {e.response.text[:500]}"
+         status_message = f"Submission Failed: {error_detail}"
+         print(status_message)
+         results_df = pd.DataFrame(results_log)
+         return status_message, results_df
+     except requests.exceptions.Timeout:
+         status_message = "Submission Failed: The request timed out."
+         print(status_message)
+         results_df = pd.DataFrame(results_log)
+         return status_message, results_df
+     except requests.exceptions.RequestException as e:
+         status_message = f"Submission Failed: Network error - {e}"
+         print(status_message)
+         results_df = pd.DataFrame(results_log)
+         return status_message, results_df
+     except Exception as e:
+         status_message = f"An unexpected error occurred during submission: {e}"
+         print(status_message)
+         results_df = pd.DataFrame(results_log)
+         return status_message, results_df


  # --- Build Gradio Interface using Blocks ---
downloads/.gitkeep ADDED
File without changes
requirements.txt CHANGED
@@ -1,7 +1,9 @@
  gradio
  requests
  python-dotenv
+ pandas
  langchain
  langchain-google-genai
  langchain-tavily
- langgraph
+ langgraph
+ openai-whisper
src/custom_tools/__init__.py ADDED
File without changes
src/custom_tools/downloads.py ADDED
@@ -0,0 +1,41 @@
+ import requests
+ from typing import Annotated
+
+ from langchain_core.tools import tool
+ from langgraph.prebuilt import InjectedState
+
+ from src.state import State
+
+ DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
+
+ @tool
+ def download_file(
+     task_id: str,
+     state: Annotated[State, InjectedState]
+ ) -> str:
+     """Download the file attached to a task, identified by its task id."""
+     file_name = state.file_name
+     task_id = state.task_id
+
+     if not file_name:
+         return "No file name in input, unable to download."
+     if not task_id:
+         return "No task id in input, unable to download."
+
+     base_url = DEFAULT_API_URL + "/files"
+     url = f"{base_url}/{task_id}" if task_id else None
+     if not url:
+         return "No URL provided."
+
+     try:
+         response = requests.get(url, stream=True)
+         response.raise_for_status()  # Raise an error for bad responses
+
+         local_file_path = f"downloads/{file_name}"
+         with open(local_file_path, "wb") as f:
+             for chunk in response.iter_content(chunk_size=8192):
+                 f.write(chunk)
+
+         return f"File downloaded successfully: {local_file_path}"
+     except requests.exceptions.RequestException as e:
+         return f"Error downloading file: {e}"
src/custom_tools/files.py ADDED
@@ -0,0 +1,71 @@
+ import os
+
+ from langchain.tools import tool
+
+ @tool
+ def read_python(file_path):
+     """
+     Reads a Python file and returns its content as a string.
+
+     Args:
+         file_path (str): The path to the Python file.
+
+     Returns:
+         str: The content of the Python file.
+     """
+     try:
+         if not os.path.exists(file_path):
+             return f"Error: File not found at {file_path}"
+         with open(file_path, "r", encoding="utf-8") as file:
+             content = file.read()
+         return content
+     except Exception as e:
+         return f"Error reading Python file: {str(e)}"
+
+ @tool
+ def read_excel(file_path):
+     """
+     Reads an Excel file and returns its content as a string.
+
+     Args:
+         file_path (str): The path to the Excel file.
+
+     Returns:
+         str: The content of the Excel file.
+     """
+     if not file_path.endswith(('.xls', '.xlsx')):
+         return "Error: File is not an Excel file."
+     try:
+         if not os.path.exists(file_path):
+             return f"Error: File not found at {file_path}"
+         import pandas as pd
+         df = pd.read_excel(file_path)
+         return df.to_string()
+     except Exception as e:
+         return f"Error reading Excel file: {str(e)}"
+
+
+ @tool
+ def transcribe_audio(file_path):
+     """
+     Transcribes an audio file and returns its content as a string.
+
+     Args:
+         file_path (str): The path to the audio file.
+
+     Returns:
+         str: The transcribed text from the audio file.
+     """
+     if not file_path.endswith(('.wav', '.mp3', '.m4a')):
+         return "Error: File is not an audio file."
+     try:
+         if not os.path.exists(file_path):
+             return f"Error: File not found at {file_path}"
+
+         import whisper
+         model = whisper.load_model("base")
+         result = model.transcribe(file_path, language="en")
+         text = result["text"]
+         return text
+     except Exception as e:
+         return f"Error transcribing audio file: {str(e)}"
src/custom_tools/wikipedia.py ADDED
@@ -0,0 +1,106 @@
+ import requests
+ from typing_extensions import Any, Annotated
+
+ from langchain_core.tools import tool
+ from langchain_core.messages import ToolMessage
+ from langgraph.types import Command
+ from langgraph.prebuilt import InjectedState
+ from langchain_core.tools.base import InjectedToolCallId
+
+ from src.state import State
+
+ @tool
+ def get_wiki_page_sections(
+     page_title: str,
+     tool_call_id: Annotated[str, InjectedToolCallId]
+ ) -> Command:
+     """Get the sections of a Wikipedia page.
+
+     This function retrieves the sections of a Wikipedia page.
+     It requires the page title as an input parameter.
+     """
+     page_title = page_title.replace(" ", "_")
+     payload = {
+         "action": "parse",
+         "page": page_title,
+         "prop": "sections",
+         "format": "json",
+     }
+
+     response = requests.get(
+         "https://en.wikipedia.org/w/api.php",
+         params=payload
+     )
+
+     if not response.status_code == 200:
+         return f"Error fetching sections for {page_title}: {response.text}"
+
+     data = response.json()
+     sections = data.get("parse", {}).get("sections", [])
+     sections_map = {}
+     for section in sections:
+         section_title = section.get("anchor").lower()
+         section_number = section.get("index")
+         if section_title and section_number:
+             sections_map[section_title] = section_number
+
+     sections_text = "The sections of the page are:\n"
+     for title in sections_map.keys():
+         sections_text += f"{title}\n"
+
+     return Command(
+         update={
+             # update the state keys
+             "wiki_sections": sections_map,
+             # update the message history
+             "messages": [
+                 ToolMessage(
+                     sections_text, tool_call_id=tool_call_id
+                 )
+             ],
+         }
+     )
+
+ @tool
+ def get_wiki_page_by_section(
+     page_title: str,
+     section: str,
+     state: Annotated[State, InjectedState]
+ ) -> str:
+     """Get the content of a specific section of a Wikipedia page.
+
+     This function retrieves the content of a specific section from a Wikipedia page.
+     It requires the page title and the section name as input parameters.
+     """
+     wiki_sections = state.wiki_sections
+     if not wiki_sections:
+         return f"Error: No sections found for {page_title}. Please run get_wiki_page_sections first."
+
+     page_title = page_title.replace(" ", "_")
+     section = section.replace(" ", "_").lower()
+
+     if section not in wiki_sections:
+         return f"Error: Section '{section}' not found in {page_title}. Please run get_wiki_page_sections first."
+
+     payload = {
+         "action": "parse",
+         "page": page_title,
+         "prop": "wikitext",
+         "section": wiki_sections[section],
+         "format": "json",
+     }
+
+     response = requests.get(
+         "https://en.wikipedia.org/w/api.php",
+         params=payload
+     )
+
+     if not response.status_code == 200:
+         return f"Error fetching section for {page_title}: {response.text}"
+
+     data = response.json()
+
+     return data.get("parse", {}).get("wikitext", "No content found.")
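
Both tools wrap the MediaWiki action API (action=parse, first with prop=sections, then with prop=wikitext for a single section index). A standalone sketch of that two-step flow with plain requests, outside the agent — the page title and section anchor are arbitrary examples, not taken from this repo:

    import requests

    API = "https://en.wikipedia.org/w/api.php"
    PAGE = "Python (programming language)".replace(" ", "_")  # example page

    # Step 1: list the sections and build an anchor -> index map, as get_wiki_page_sections does.
    resp = requests.get(API, params={"action": "parse", "page": PAGE,
                                     "prop": "sections", "format": "json"})
    sections = {s["anchor"].lower(): s["index"] for s in resp.json()["parse"]["sections"]}

    # Step 2: fetch the wikitext of one section by its index, as get_wiki_page_by_section does.
    resp = requests.get(API, params={"action": "parse", "page": PAGE, "prop": "wikitext",
                                     "section": sections["history"], "format": "json"})
    print(resp.json()["parse"]["wikitext"]["*"][:300])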
src/gaia_agent.py CHANGED
@@ -31,7 +31,7 @@ class GaiaAgent:
          )
          builder.add_edge("tools", "call_model")

-         graph = builder.compile(name="GAIA Agent", debug=True)
+         graph = builder.compile(name="GAIA Agent", debug=False)

          return graph

@@ -59,15 +59,25 @@ class GaiaAgent:
          # Format the system prompt. Customize this to change the agent's behavior.
          system_message = configuration.system_prompt

+         if state.file_name:
+             file_prompt = (
+                 f"\n\nThe task id is {state.task_id}.\n"
+                 f"Please use this to download the file."
+             )
+
+             system_message += file_prompt
+
          # Get the model's response
          response = cast(
              AIMessage,
              model.llm.invoke(
-                 [{"role": "system", "content": system_message}, *state.messages]
+                 [
+                     {"role": "system", "content": system_message},
+                     *state.messages,
+                 ]
              ),
          )
-
-         print(response.tool_calls)

          # Handle the case when it's the last step and the model still wants to use a tool
          if state.is_last_step and response.tool_calls:
src/prompts.py CHANGED
@@ -1,7 +1,23 @@
  SYSTEM_PROMPT = (
      "You are a helpful AI assistant.\n"
-     "Please answer the question to the best of your ability.\n"
-     "Use the tools provided to you to find the answer.\n"
-     "Do not ask for permission to use the tools.\n"
-     "If you think you should use a tool, do so.\n"
+     "Please answer the question to the best of your ability. "
+     "Use the tools provided to you to find the answer. "
+     "Do not ask for permission to use the tools. "
+     "If you think you should use a tool, do so. "
+     "If the user specifies a file to use, use the "
+     "download tool to download the file and then use it. "
+     "If you get a file not found error, please try to download the file. "
+     "Include adjectives in your answer if the example in the question shows this. "
+     "Finish your answer with the following template: "
+     "FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER "
+     "should be a number OR as few words as possible OR a "
+     "comma separated list of numbers and/or strings. If "
+     "you are asked for a number, don't use commas to write "
+     "your number nor units such as $ or percent signs "
+     "unless specified otherwise. If you are asked for a string, "
+     "don't use articles or abbreviations (e.g. for cities), "
+     "and write digits in plain text unless specified otherwise. "
+     "If you are asked for a comma separated list, apply the above rules "
+     "depending on whether the element to be put in the list is a number "
+     "or a string."
  )
src/state.py CHANGED
@@ -21,6 +21,8 @@ class InputState:
      messages: Annotated[Sequence[AnyMessage], add_messages] = field(
          default_factory=list
      )
+     task_id: str = field(default="")
+     file_name: str | None = field(default=None)
      """
      Messages tracking the primary execution state of the agent.

@@ -46,6 +48,7 @@ class State(InputState):
      """

      is_last_step: IsLastStep = field(default=False)
+     wiki_sections: dict[str, int] = field(default_factory=dict)
      """
      Indicates whether the current step is the last one before the graph raises an error.
src/tools.py CHANGED
@@ -3,6 +3,7 @@ from typing import Any, Callable, List, Optional, cast
  from langchain_tavily import TavilySearch  # type: ignore[import-not-found]

  from src.config import Configuration
+ from src.custom_tools import wikipedia, files, downloads

  def search(query: str) -> Optional[dict[str, Any]]:
      """Search for general web results.
@@ -15,5 +16,12 @@ def search(query: str) -> Optional[dict[str, Any]]:
      wrapped = TavilySearch(max_results=configuration.max_search_results)
      return cast(dict[str, Any], wrapped.invoke({"query": query}))

-
- TOOLS: List[Callable[..., Any]] = [search]
+ TOOLS: List[Callable[..., Any]] = [
+     search,
+     wikipedia.get_wiki_page_sections,
+     wikipedia.get_wiki_page_by_section,
+     downloads.download_file,
+     files.read_python,
+     files.read_excel,
+     files.transcribe_audio,
+ ]
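
How TOOLS is consumed is outside this commit, but given the "tools" node wired up in src/gaia_agent.py, the usual LangGraph pattern is to bind the same list to the chat model and to a prebuilt ToolNode. A sketch under that assumption — the model class and name are illustrative, not taken from this repo, and a GOOGLE_API_KEY would be required:

    from langchain_google_genai import ChatGoogleGenerativeAI
    from langgraph.prebuilt import ToolNode

    from src.tools import TOOLS

    # The model sees the tool schemas so it can emit tool calls...
    llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash").bind_tools(TOOLS)

    # ...and the ToolNode executes those calls, injecting graph state into
    # InjectedState arguments such as download_file's `state` parameter.
    tool_node = ToolNode(TOOLS)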