hsila commited on
Commit
4b2b389
·
1 Parent(s): 344bea7

add initial agent

Browse files
Files changed (4) hide show
  1. .gitignore +5 -1
  2. agent.py +197 -0
  3. app.py +7 -5
  4. requirements.txt +11 -1
.gitignore CHANGED
@@ -1 +1,5 @@
1
- venv/
 
 
 
 
 
1
+ venv/
2
+ .env
3
+ __pycache__/
4
+ *.py[cod]
5
+ *$py.class
agent.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+
4
+ from langchain_core.messages import HumanMessage
5
+ from langchain_core.tools import tool
6
+ from langchain_google_genai import ChatGoogleGenerativeAI
7
+ from langchain_community.tools.tavily_search import TavilySearchResults
8
+ from langchain_community.document_loaders import WikipediaLoader
9
+
10
+ from langgraph.graph import StateGraph, START, END, MessagesState
11
+ from langgraph.prebuilt import ToolNode, tools_condition
12
+
13
+ load_dotenv()
14
+
15
+
16
+ @tool
17
+ def add(a: float, b: float) -> float:
18
+ """Add two numbers together.
19
+
20
+ Args:
21
+ a: First number
22
+ b: Second number
23
+ """
24
+ return a + b
25
+
26
+
27
+ @tool
28
+ def subtract(a: float, b: float) -> float:
29
+ """Subtract b from a.
30
+
31
+ Args:
32
+ a: Number to subtract from
33
+ b: Number to subtract
34
+ """
35
+ return a - b
36
+
37
+
38
+ @tool
39
+ def multiply(a: float, b: float) -> float:
40
+ """Multiply two numbers together.
41
+
42
+ Args:
43
+ a: First number
44
+ b: Second number
45
+ """
46
+ return a * b
47
+
48
+
49
+ @tool
50
+ def divide(a: float, b: float) -> float:
51
+ """Divide a by b.
52
+
53
+ Args:
54
+ a: Dividend
55
+ b: Divisor
56
+ """
57
+ if b == 0:
58
+ return "Error: Division by zero"
59
+ return a / b
60
+
61
+
62
+ @tool
63
+ def modulo(a: float, b: float) -> float:
64
+ """Return the remainder of a divided by b.
65
+
66
+ Args:
67
+ a: Dividend
68
+ b: Divisor
69
+ """
70
+ if b == 0:
71
+ return "Error: Division by zero"
72
+ return a % b
73
+
74
+
75
+ @tool
76
+ def power(a: float, b: float) -> float:
77
+ """Raise a to the power of b.
78
+
79
+ Args:
80
+ a: Base number
81
+ b: Exponent
82
+ """
83
+ return a ** b
84
+
85
+
86
+ @tool
87
+ def square_root(a: float) -> float:
88
+ """Calculate the square root of a number.
89
+
90
+ Args:
91
+ a: Number to calculate square root of
92
+ """
93
+ if a < 0:
94
+ return "Error: Cannot calculate square root of negative number"
95
+ return a ** 0.5
96
+
97
+
98
+ @tool
99
+ def web_search(query: str) -> str:
100
+ """Search the web for current information and facts.
101
+
102
+ Args:
103
+ query: Search query string
104
+ """
105
+ try:
106
+ search_tool = TavilySearchResults(max_results=3)
107
+ results = search_tool.invoke(query)
108
+
109
+ if not results:
110
+ return "No search results found."
111
+
112
+ formatted_results = []
113
+ for i, result in enumerate(results, 1):
114
+ title = result.get('title', 'No title')
115
+ content = result.get('content', 'No content')
116
+ url = result.get('url', 'No URL')
117
+ formatted_results.append(f"{i}. {title}\n{content}\nSource: {url}")
118
+
119
+ return "\n\n ==== \n\n".join(formatted_results)
120
+ except Exception as e:
121
+ return f"Error performing search: {str(e)}"
122
+
123
+
124
+ @tool
125
+ def wikipedia_search(query: str) -> str:
126
+ """Search Wikipedia for factual information.
127
+
128
+ Args:
129
+ query: Wikipedia search query
130
+ """
131
+ try:
132
+ loader = WikipediaLoader(query=query, load_max_docs=2)
133
+ docs = loader.load()
134
+
135
+ if not docs:
136
+ return "No Wikipedia results found."
137
+
138
+ formatted_docs = []
139
+ for i, doc in enumerate(docs, 1):
140
+ title = doc.metadata.get('title', 'No title')
141
+ content = doc.page_content
142
+ formatted_docs.append(f"{i}. {title}\n{content}")
143
+
144
+ return "\n\n ==== \n\n".join(formatted_docs)
145
+ except Exception as e:
146
+ return f"Error searching Wikipedia: {str(e)}"
147
+
148
+
149
+ tools = [
150
+ add, subtract, multiply, divide, modulo, power, square_root,
151
+ web_search, wikipedia_search
152
+ ]
153
+
154
+
155
+ def get_llm():
156
+ """Initialize the llm"""
157
+ return ChatGoogleGenerativeAI(
158
+ model="gemini-2.5-pro",
159
+ temperature=0,
160
+ api_key=os.getenv("GEMINI_API_KEY")
161
+ )
162
+
163
+
164
+ def call_model(state: MessagesState):
165
+ """Call the LLM with the current state.
166
+
167
+ Args:
168
+ state: Current state containing messages
169
+ """
170
+ llm = get_llm()
171
+ llm_with_tools = llm.bind_tools(tools)
172
+
173
+ messages = state['messages']
174
+ response = llm_with_tools.invoke(messages)
175
+
176
+ return {"messages": [response]}
177
+
178
+
179
+ def build_graph():
180
+ """Build and return the LangGraph workflow."""
181
+ workflow = StateGraph(MessagesState)
182
+
183
+ workflow.add_node("agent", call_model)
184
+ workflow.add_node("tools", ToolNode(tools))
185
+
186
+ workflow.add_edge(START, "agent")
187
+ workflow.add_conditional_edges("agent", tools_condition)
188
+ workflow.add_edge("tools", "agent")
189
+
190
+ return workflow.compile()
191
+
192
+
193
+ if __name__ == "__main__":
194
+ graph = build_graph()
195
+ test_message = [HumanMessage(content="What is 15 + 27?")]
196
+ result = graph.invoke({"messages": test_message})
197
+ print(f"Test result: {result['messages'][-1].content}")
app.py CHANGED
@@ -3,21 +3,23 @@ import gradio as gr
3
  import requests
4
  import inspect
5
  import pandas as pd
 
6
 
7
  # (Keep Constants as is)
8
  # --- Constants ---
9
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
10
 
11
  # --- Basic Agent Definition ---
12
- # ----- THIS IS WERE YOU CAN BUILD WHAT YOU WANT ------
 
13
  class BasicAgent:
14
  def __init__(self):
15
- print("BasicAgent initialized.")
16
  def __call__(self, question: str) -> str:
17
  print(f"Agent received question (first 50 chars): {question[:50]}...")
18
- fixed_answer = "This is a default answer."
19
- print(f"Agent returning fixed answer: {fixed_answer}")
20
- return fixed_answer
21
 
22
  def run_and_submit_all( profile: gr.OAuthProfile | None):
23
  """
 
3
  import requests
4
  import inspect
5
  import pandas as pd
6
+ from langchain_core.messages import HumanMessage
7
 
8
  # (Keep Constants as is)
9
  # --- Constants ---
10
  DEFAULT_API_URL = "https://agents-course-unit4-scoring.hf.space"
11
 
12
  # --- Basic Agent Definition ---
13
+ from agent import build_graph
14
+
15
  class BasicAgent:
16
  def __init__(self):
17
+ self.graph = build_graph()
18
  def __call__(self, question: str) -> str:
19
  print(f"Agent received question (first 50 chars): {question[:50]}...")
20
+ messages = [HumanMessage(content=question)]
21
+ result = self.graph.invoke({"messages": messages})
22
+ return result['messages'][-1].content
23
 
24
  def run_and_submit_all( profile: gr.OAuthProfile | None):
25
  """
requirements.txt CHANGED
@@ -1,2 +1,12 @@
1
  gradio
2
- requests
 
 
 
 
 
 
 
 
 
 
 
1
  gradio
2
+ requests
3
+ langgraph
4
+ langchain
5
+ langchain-community
6
+ langchain-tavily
7
+ langchain-core
8
+ langchain-google-genai
9
+ wikipedia
10
+ google-genai
11
+ python-dotenv
12
+ pandas