Commit
·
80e3b5e
1
Parent(s):
f3fa776
updated code base
Browse files
agent.py
CHANGED
@@ -1,19 +1,46 @@
|
|
1 |
import os
|
|
|
2 |
from langchain_groq import ChatGroq
|
3 |
from langchain.prompts import PromptTemplate
|
4 |
from langgraph.graph import START, StateGraph, MessagesState
|
5 |
from langgraph.prebuilt import ToolNode, tools_condition
|
6 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
7 |
from langchain_community.document_loaders import WikipediaLoader
|
8 |
-
from langchain_core.messages import HumanMessage
|
9 |
from langchain.tools import tool
|
10 |
from langchain_core.prompts import ChatPromptTemplate
|
11 |
from langchain_core.runnables import Runnable
|
12 |
from dotenv import load_dotenv
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
# Load environment variables from .env
|
15 |
load_dotenv()
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
# Initialize LLM
|
18 |
def initialize_llm():
|
19 |
"""Initializes the ChatGroq LLM."""
|
@@ -247,103 +274,145 @@ def standard_deviation(numbers: list) -> float:
|
|
247 |
variance = sum((x - mean_value) ** 2 for x in numbers) / len(numbers)
|
248 |
return variance ** 0.5
|
249 |
|
250 |
-
#
|
251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
"""
|
253 |
-
|
|
|
|
|
|
|
254 |
|
255 |
Returns:
|
256 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
257 |
"""
|
258 |
-
llm = initialize_llm()
|
259 |
-
search_tool = initialize_search_tool()
|
260 |
-
recommendation_chain = initialize_recommendation_chain(llm)
|
261 |
-
|
262 |
-
@tool
|
263 |
-
def weather_tool(location: str) -> str:
|
264 |
-
"""
|
265 |
-
Fetches the weather for a location.
|
266 |
-
|
267 |
-
Args:
|
268 |
-
location (str): The location to fetch weather for.
|
269 |
-
|
270 |
-
Returns:
|
271 |
-
str: The weather information.
|
272 |
-
"""
|
273 |
-
return get_weather(location, search_tool)
|
274 |
-
|
275 |
-
@tool
|
276 |
-
def web_search(query: str) -> str:
|
277 |
-
"""Search the web for a given query and return the summary.
|
278 |
-
Args:
|
279 |
-
query (str): The search query.
|
280 |
-
"""
|
281 |
-
|
282 |
-
search_tool = TavilySearchResults()
|
283 |
-
result = search_tool.run(query)
|
284 |
-
return result[0]['content']
|
285 |
-
|
286 |
-
@tool
|
287 |
-
def wiki_search(query : str) -> str:
|
288 |
-
"""Search Wikipedia for a given query and return the summary.
|
289 |
-
Args:
|
290 |
-
query (str): The search query.
|
291 |
-
"""
|
292 |
|
293 |
-
|
294 |
-
|
295 |
-
|
296 |
-
|
297 |
-
|
298 |
-
|
299 |
-
|
300 |
-
|
|
|
|
|
301 |
|
302 |
-
|
303 |
-
|
304 |
-
|
305 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
|
307 |
-
|
308 |
-
|
|
|
|
|
|
|
309 |
|
310 |
-
|
311 |
-
str: The recommendations.
|
312 |
-
"""
|
313 |
-
return get_recommendation(weather_condition, recommendation_chain)
|
314 |
|
315 |
-
|
316 |
-
add, subtract, multiply, divide, square, cube, power, factorial, mean, standard_deviation]
|
317 |
|
318 |
-
|
|
|
|
|
|
|
319 |
|
320 |
-
|
321 |
-
|
322 |
-
|
|
|
|
|
|
|
|
|
323 |
|
324 |
-
|
325 |
-
|
|
|
326 |
|
327 |
-
|
328 |
-
dict: The next state of the conversation.
|
329 |
-
"""
|
330 |
-
print("Entering assistant node...")
|
331 |
-
response = llm_with_tools.invoke(state["messages"])
|
332 |
-
print(f"Assistant says: {response.content}")
|
333 |
-
return {"messages": [response]}
|
334 |
|
|
|
335 |
builder = StateGraph(MessagesState)
|
|
|
336 |
builder.add_node("assistant", assistant)
|
337 |
-
builder.add_node("tools", ToolNode(tools))
|
338 |
-
|
|
|
|
|
339 |
builder.add_conditional_edges("assistant", tools_condition)
|
340 |
builder.add_edge("tools", "assistant")
|
|
|
341 |
return builder.compile()
|
342 |
|
|
|
343 |
if __name__ == "__main__":
|
|
|
344 |
graph = build_graph()
|
345 |
-
question = "How many albums were pulished by Mercedes Sosa?"
|
346 |
messages = [HumanMessage(content=question)]
|
347 |
result = graph.invoke({"messages": messages})
|
|
|
348 |
for msg in result["messages"]:
|
349 |
-
msg.pretty_print()
|
|
|
1 |
import os
|
2 |
+
import shutil
|
3 |
from langchain_groq import ChatGroq
|
4 |
from langchain.prompts import PromptTemplate
|
5 |
from langgraph.graph import START, StateGraph, MessagesState
|
6 |
from langgraph.prebuilt import ToolNode, tools_condition
|
7 |
from langchain_community.tools.tavily_search import TavilySearchResults
|
8 |
from langchain_community.document_loaders import WikipediaLoader
|
9 |
+
from langchain_core.messages import SystemMessage, HumanMessage
|
10 |
from langchain.tools import tool
|
11 |
from langchain_core.prompts import ChatPromptTemplate
|
12 |
from langchain_core.runnables import Runnable
|
13 |
from dotenv import load_dotenv
|
14 |
+
from langchain.vectorstores import Chroma
|
15 |
+
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
|
16 |
+
from langchain.text_splitter import CharacterTextSplitter
|
17 |
+
from langchain.tools.retriever import create_retriever_tool
|
18 |
+
from typing import TypedDict, Annotated, List
|
19 |
|
20 |
# Load environment variables from .env
|
21 |
load_dotenv()
|
22 |
|
23 |
+
# Custom Agent Prompt Template
|
24 |
+
Agent_prompt_template = '''You are a helpful assistant following the REACT methodology and tasked with answering questions using a set of tools.
|
25 |
+
Once a question is asked,you have to Report your thoughts, and finish your answer with the following template:
|
26 |
+
FINAL ANSWER: [YOUR FINAL ANSWER].
|
27 |
+
|
28 |
+
### **Instructions:**
|
29 |
+
- YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
|
30 |
+
- If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
|
31 |
+
- If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
|
32 |
+
- Provide the answer in clear and professional language.
|
33 |
+
- **Limit** the FINAL ANSWER to 1000 words and keep it to the point.
|
34 |
+
- **Do not** include any additional information or explanations in your FINAL ANSWER.
|
35 |
+
- **Do not** include any information that is not relevant to the question.
|
36 |
+
- **Validate** your answer before providing it.
|
37 |
+
- **Print** all the steps you take to arrive at your answer.
|
38 |
+
|
39 |
+
Your answer should only start with "FINAL ANSWER: ", then follows with the answer. '''
|
40 |
+
|
41 |
+
sys_msg = SystemMessage(content=Agent_prompt_template)
|
42 |
+
|
43 |
+
|
44 |
# Initialize LLM
|
45 |
def initialize_llm():
|
46 |
"""Initializes the ChatGroq LLM."""
|
|
|
274 |
variance = sum((x - mean_value) ** 2 for x in numbers) / len(numbers)
|
275 |
return variance ** 0.5
|
276 |
|
277 |
+
# --- Vector Store + Retriever ---
|
278 |
+
# State schema
|
279 |
+
class MessagesState(TypedDict):
|
280 |
+
messages: Annotated[List[HumanMessage], "Messages in the conversation"]
|
281 |
+
|
282 |
+
# === VECTOR STORE SETUP ===
|
283 |
+
PERSIST_DIR = "./chroma_store"
|
284 |
+
|
285 |
+
def initialize_chroma_store():
|
286 |
+
# Optional: clear existing store if desired
|
287 |
+
if os.path.exists(PERSIST_DIR):
|
288 |
+
shutil.rmtree(PERSIST_DIR)
|
289 |
+
os.makedirs(PERSIST_DIR)
|
290 |
+
|
291 |
+
# Initialize embeddings
|
292 |
+
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
|
293 |
+
|
294 |
+
# Load existing or empty vector store
|
295 |
+
vectorstore = Chroma(
|
296 |
+
embedding_function=embeddings,
|
297 |
+
persist_directory=PERSIST_DIR
|
298 |
+
)
|
299 |
+
return vectorstore
|
300 |
+
|
301 |
+
vector_store = initialize_chroma_store()
|
302 |
+
|
303 |
+
# Create retriever tool
|
304 |
+
retriever_tool = create_retriever_tool(
|
305 |
+
retriever=vector_store.as_retriever(),
|
306 |
+
name="Question Search",
|
307 |
+
description="A tool to retrieve similar questions from a vector store."
|
308 |
+
)
|
309 |
+
|
310 |
+
|
311 |
+
|
312 |
+
@tool
|
313 |
+
def weather_tool(location: str) -> str:
|
314 |
"""
|
315 |
+
Fetches the weather for a location.
|
316 |
+
|
317 |
+
Args:
|
318 |
+
location (str): The location to fetch weather for.
|
319 |
|
320 |
Returns:
|
321 |
+
str: The weather information.
|
322 |
+
"""
|
323 |
+
return get_weather(location, search_tool)
|
324 |
+
|
325 |
+
@tool
|
326 |
+
def web_search(query: str) -> str:
|
327 |
+
"""Search the web for a given query and return the summary.
|
328 |
+
Args:
|
329 |
+
query (str): The search query.
|
330 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
331 |
|
332 |
+
search_tool = TavilySearchResults()
|
333 |
+
result = search_tool.run(query)
|
334 |
+
return result[0]['content']
|
335 |
+
|
336 |
+
@tool
|
337 |
+
def wiki_search(query : str) -> str:
|
338 |
+
"""Search Wikipedia for a given query and return the summary.
|
339 |
+
Args:
|
340 |
+
query (str): The search query.
|
341 |
+
"""
|
342 |
|
343 |
+
search_docs = WikipediaLoader(query=query, load_max_docs=1).load()
|
344 |
+
formatted_search_docs = "\n\n----\n\n".join(
|
345 |
+
[
|
346 |
+
f'<Document Source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
|
347 |
+
for doc in search_docs
|
348 |
+
]
|
349 |
+
)
|
350 |
+
return formatted_search_docs
|
351 |
+
|
352 |
+
# @tool
|
353 |
+
# def recommendation_tool(weather_condition: str) -> str:
|
354 |
+
# """
|
355 |
+
# Provides recommendations based on weather conditions.
|
356 |
+
|
357 |
+
# Args:
|
358 |
+
# weather_condition (str): The weather condition.
|
359 |
+
|
360 |
+
# Returns:
|
361 |
+
# str: The recommendations.
|
362 |
+
# """
|
363 |
+
# return get_recommendation(weather_condition, recommendation_chain)
|
364 |
+
|
365 |
+
tools = [weather_tool, wiki_search, web_search,
|
366 |
+
add, subtract, multiply, divide, square, cube, power, factorial, mean, standard_deviation]
|
367 |
|
368 |
+
llm = ChatGroq(
|
369 |
+
temperature=0,
|
370 |
+
model_name="qwen-qwq-32b", # Updated to working model
|
371 |
+
groq_api_key=os.getenv("GROQ_API_KEY")
|
372 |
+
)
|
373 |
|
374 |
+
llm_with_tools = llm.bind_tools(tools)
|
|
|
|
|
|
|
375 |
|
376 |
+
# === NODES ===
|
|
|
377 |
|
378 |
+
def retriever(state: MessagesState):
|
379 |
+
"""Retrieve similar context and inject"""
|
380 |
+
query = state["messages"][0].content
|
381 |
+
similar_docs = vector_store.similarity_search(query)
|
382 |
|
383 |
+
if similar_docs:
|
384 |
+
ref_msg = HumanMessage(
|
385 |
+
content=f"Here is a similar question and answer for reference:\n\n{similar_docs[0].page_content}"
|
386 |
+
)
|
387 |
+
return {"messages": [sys_msg] + state["messages"] + [ref_msg]}
|
388 |
+
else:
|
389 |
+
return {"messages": [sys_msg] + state["messages"]}
|
390 |
|
391 |
+
def assistant(state: MessagesState):
|
392 |
+
"""Invoke LLM with tools"""
|
393 |
+
return {"messages": [llm_with_tools.invoke(state["messages"])]}
|
394 |
|
395 |
+
# === GRAPH BUILD ===
|
|
|
|
|
|
|
|
|
|
|
|
|
396 |
|
397 |
+
def build_graph():
|
398 |
builder = StateGraph(MessagesState)
|
399 |
+
builder.add_node("retriever", retriever)
|
400 |
builder.add_node("assistant", assistant)
|
401 |
+
builder.add_node("tools", ToolNode([retriever_tool] + tools))
|
402 |
+
|
403 |
+
builder.set_entry_point("retriever")
|
404 |
+
builder.add_edge("retriever", "assistant")
|
405 |
builder.add_conditional_edges("assistant", tools_condition)
|
406 |
builder.add_edge("tools", "assistant")
|
407 |
+
|
408 |
return builder.compile()
|
409 |
|
410 |
+
# === TEST ===
|
411 |
if __name__ == "__main__":
|
412 |
+
question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
|
413 |
graph = build_graph()
|
|
|
414 |
messages = [HumanMessage(content=question)]
|
415 |
result = graph.invoke({"messages": messages})
|
416 |
+
|
417 |
for msg in result["messages"]:
|
418 |
+
msg.pretty_print()
|