hsila committed
Commit 32c39ee · Parent: 4b2b389

add system message, ruff format, change pro to flash

Files changed (1)
  1. agent.py +28 -17
agent.py CHANGED
@@ -1,17 +1,19 @@
 import os
 from dotenv import load_dotenv
 
-from langchain_core.messages import HumanMessage
+from langchain_core.messages import HumanMessage, SystemMessage
 from langchain_core.tools import tool
 from langchain_google_genai import ChatGoogleGenerativeAI
 from langchain_community.tools.tavily_search import TavilySearchResults
 from langchain_community.document_loaders import WikipediaLoader
 
-from langgraph.graph import StateGraph, START, END, MessagesState
+from langgraph.graph import StateGraph, START, MessagesState
 from langgraph.prebuilt import ToolNode, tools_condition
 
 load_dotenv()
 
+SYSTEM_PROMPT = """You are a general AI assistant. I will ask you a question. Report your thoughts, and output only your final answer, no prefixes, suffixes, or extra text. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string."""
+
 
 @tool
 def add(a: float, b: float) -> float:
@@ -80,7 +82,7 @@ def power(a: float, b: float) -> float:
         a: Base number
         b: Exponent
     """
-    return a ** b
+    return a**b
 
 
 @tool
@@ -92,7 +94,7 @@ def square_root(a: float) -> float:
     """
     if a < 0:
         return "Error: Cannot calculate square root of negative number"
-    return a ** 0.5
+    return a**0.5
 
 
 @tool
@@ -111,9 +113,9 @@ def web_search(query: str) -> str:
 
     formatted_results = []
     for i, result in enumerate(results, 1):
-        title = result.get('title', 'No title')
-        content = result.get('content', 'No content')
-        url = result.get('url', 'No URL')
+        title = result.get("title", "No title")
+        content = result.get("content", "No content")
+        url = result.get("url", "No URL")
         formatted_results.append(f"{i}. {title}\n{content}\nSource: {url}")
 
     return "\n\n ==== \n\n".join(formatted_results)
@@ -137,7 +139,7 @@ def wikipedia_search(query: str) -> str:
 
     formatted_docs = []
     for i, doc in enumerate(docs, 1):
-        title = doc.metadata.get('title', 'No title')
+        title = doc.metadata.get("title", "No title")
         content = doc.page_content
         formatted_docs.append(f"{i}. {title}\n{content}")
 
@@ -147,32 +149,41 @@ def wikipedia_search(query: str) -> str:
 
 
 tools = [
-    add, subtract, multiply, divide, modulo, power, square_root,
-    web_search, wikipedia_search
+    add,
+    subtract,
+    multiply,
+    divide,
+    modulo,
+    power,
+    square_root,
+    web_search,
+    wikipedia_search,
 ]
 
 
 def get_llm():
     """Initialize the llm"""
     return ChatGoogleGenerativeAI(
-        model="gemini-2.5-pro",
-        temperature=0,
-        api_key=os.getenv("GEMINI_API_KEY")
+        model="gemini-2.5-flash", temperature=0, api_key=os.getenv("GEMINI_API_KEY")
     )
 
 
 def call_model(state: MessagesState):
     """Call the LLM with the current state.
-
+
     Args:
         state: Current state containing messages
     """
     llm = get_llm()
     llm_with_tools = llm.bind_tools(tools)
-
-    messages = state['messages']
+
+    messages = state["messages"]
+
+    if not messages or not isinstance(messages[0], SystemMessage):
+        messages = [SystemMessage(content=SYSTEM_PROMPT)] + messages
+
     response = llm_with_tools.invoke(messages)
-
+
     return {"messages": [response]}
 
 
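
Note: this commit only touches agent.py's imports, tool list, model choice, and call_model; the graph wiring itself is not part of the diff. For context, a minimal sketch of how these pieces are typically assembled with the StateGraph, ToolNode, and tools_condition imports already present in the file. The node names ("assistant", "tools"), the `from agent import ...` layout, and the sample question are illustrative assumptions, not code from this commit.

# Sketch only: assumes agent.py (as committed above) is importable on the path.
from langchain_core.messages import HumanMessage
from langgraph.graph import StateGraph, START, MessagesState
from langgraph.prebuilt import ToolNode, tools_condition

from agent import call_model, tools  # definitions from the diff above

builder = StateGraph(MessagesState)
builder.add_node("assistant", call_model)   # LLM node; prepends SYSTEM_PROMPT if missing
builder.add_node("tools", ToolNode(tools))  # executes any requested tool calls
builder.add_edge(START, "assistant")
builder.add_conditional_edges("assistant", tools_condition)  # route to "tools" or end
builder.add_edge("tools", "assistant")      # feed tool results back to the LLM
graph = builder.compile()

result = graph.invoke({"messages": [HumanMessage(content="What is 2 to the power of 10?")]})
print(result["messages"][-1].content)

Because call_model now prepends SystemMessage(content=SYSTEM_PROMPT) whenever the first message is not already a SystemMessage, callers only pass the user question; the answer-format instructions are applied on every model call.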