HarshitSundriyal committed on
Commit
80e3b5e
·
1 Parent(s): f3fa776

updated code base

Browse files
Files changed (1) hide show
  1. agent.py +145 -76
agent.py CHANGED
@@ -1,19 +1,46 @@
1
  import os
 
2
  from langchain_groq import ChatGroq
3
  from langchain.prompts import PromptTemplate
4
  from langgraph.graph import START, StateGraph, MessagesState
5
  from langgraph.prebuilt import ToolNode, tools_condition
6
  from langchain_community.tools.tavily_search import TavilySearchResults
7
  from langchain_community.document_loaders import WikipediaLoader
8
- from langchain_core.messages import HumanMessage
9
  from langchain.tools import tool
10
  from langchain_core.prompts import ChatPromptTemplate
11
  from langchain_core.runnables import Runnable
12
  from dotenv import load_dotenv
 
 
 
 
 
13
 
14
  # Load environment variables from .env
15
  load_dotenv()
16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  # Initialize LLM
18
  def initialize_llm():
19
  """Initializes the ChatGroq LLM."""
@@ -247,103 +274,145 @@ def standard_deviation(numbers: list) -> float:
247
  variance = sum((x - mean_value) ** 2 for x in numbers) / len(numbers)
248
  return variance ** 0.5
249
 
250
- # Build the LangGraph
251
- def build_graph():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  """
253
- Builds the LangGraph with the defined tools and assistant node.
 
 
 
254
 
255
  Returns:
256
- RunnableGraph: The compiled LangGraph.
 
 
 
 
 
 
 
 
257
  """
258
- llm = initialize_llm()
259
- search_tool = initialize_search_tool()
260
- recommendation_chain = initialize_recommendation_chain(llm)
261
-
262
- @tool
263
- def weather_tool(location: str) -> str:
264
- """
265
- Fetches the weather for a location.
266
-
267
- Args:
268
- location (str): The location to fetch weather for.
269
-
270
- Returns:
271
- str: The weather information.
272
- """
273
- return get_weather(location, search_tool)
274
-
275
- @tool
276
- def web_search(query: str) -> str:
277
- """Search the web for a given query and return the summary.
278
- Args:
279
- query (str): The search query.
280
- """
281
-
282
- search_tool = TavilySearchResults()
283
- result = search_tool.run(query)
284
- return result[0]['content']
285
-
286
- @tool
287
- def wiki_search(query : str) -> str:
288
- """Search Wikipedia for a given query and return the summary.
289
- Args:
290
- query (str): The search query.
291
- """
292
 
293
- search_docs = WikipediaLoader(query=query, load_max_docs=1).load()
294
- formatted_search_docs = "\n\n----\n\n".join(
295
- [
296
- f'<Document Source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
297
- for doc in search_docs
298
- ]
299
- )
300
- return formatted_search_docs
 
 
301
 
302
- @tool
303
- def recommendation_tool(weather_condition: str) -> str:
304
- """
305
- Provides recommendations based on weather conditions.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
306
 
307
- Args:
308
- weather_condition (str): The weather condition.
 
 
 
309
 
310
- Returns:
311
- str: The recommendations.
312
- """
313
- return get_recommendation(weather_condition, recommendation_chain)
314
 
315
- tools = [weather_tool, recommendation_tool, wiki_search, web_search,
316
- add, subtract, multiply, divide, square, cube, power, factorial, mean, standard_deviation]
317
 
318
- llm_with_tools = llm.bind_tools(tools)
 
 
 
319
 
320
- def assistant(state: MessagesState):
321
- """
322
- Assistant node in the LangGraph.
 
 
 
 
323
 
324
- Args:
325
- state (MessagesState): The current state of the conversation.
 
326
 
327
- Returns:
328
- dict: The next state of the conversation.
329
- """
330
- print("Entering assistant node...")
331
- response = llm_with_tools.invoke(state["messages"])
332
- print(f"Assistant says: {response.content}")
333
- return {"messages": [response]}
334
 
 
335
  builder = StateGraph(MessagesState)
 
336
  builder.add_node("assistant", assistant)
337
- builder.add_node("tools", ToolNode(tools))
338
- builder.set_entry_point("assistant")
 
 
339
  builder.add_conditional_edges("assistant", tools_condition)
340
  builder.add_edge("tools", "assistant")
 
341
  return builder.compile()
342
 
 
343
  if __name__ == "__main__":
 
344
  graph = build_graph()
345
- question = "How many albums were pulished by Mercedes Sosa?"
346
  messages = [HumanMessage(content=question)]
347
  result = graph.invoke({"messages": messages})
 
348
  for msg in result["messages"]:
349
- msg.pretty_print()
 
import os
import shutil
from typing import TypedDict, Annotated, List

from dotenv import load_dotenv
from langchain.prompts import PromptTemplate
from langchain.text_splitter import CharacterTextSplitter
from langchain.tools import tool
from langchain.tools.retriever import create_retriever_tool
from langchain.vectorstores import Chroma
from langchain_community.document_loaders import WikipediaLoader
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_core.messages import AnyMessage, SystemMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import Runnable
from langchain_groq import ChatGroq
from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
from langgraph.graph import START, StateGraph, MessagesState
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
 
20
  # Load environment variables from .env
21
  load_dotenv()
22
 
23
# Custom Agent Prompt Template
# NOTE: this string is sent verbatim to the model — do not reformat it.
Agent_prompt_template = '''You are a helpful assistant following the REACT methodology and tasked with answering questions using a set of tools.
Once a question is asked,you have to Report your thoughts, and finish your answer with the following template:
FINAL ANSWER: [YOUR FINAL ANSWER].

### **Instructions:**
- YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings.
- If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise.
- If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
- Provide the answer in clear and professional language.
- **Limit** the FINAL ANSWER to 1000 words and keep it to the point.
- **Do not** include any additional information or explanations in your FINAL ANSWER.
- **Do not** include any information that is not relevant to the question.
- **Validate** your answer before providing it.
- **Print** all the steps you take to arrive at your answer.

Your answer should only start with "FINAL ANSWER: ", then follows with the answer. '''

# System message prepended to the conversation by the retriever node below.
sys_msg = SystemMessage(content=Agent_prompt_template)
42
+
43
+
44
  # Initialize LLM
45
  def initialize_llm():
46
  """Initializes the ChatGroq LLM."""
 
274
  variance = sum((x - mean_value) ** 2 for x in numbers) / len(numbers)
275
  return variance ** 0.5
276
 
277
# --- Vector Store + Retriever ---
# State schema.
# NOTE: this deliberately shadows langgraph.graph.MessagesState imported above;
# StateGraph below is built against THIS class.
class MessagesState(TypedDict):
    """Conversation state flowing through the graph.

    The `add_messages` reducer makes node returns APPEND to the message
    history instead of replacing it. Without it, each node's
    `{"messages": [...]}` return would overwrite the whole history and the
    assistant <-> tools loop would lose context on every step.
    `AnyMessage` (not `HumanMessage`) is used because AI and tool messages
    flow through this channel as well.
    """
    messages: Annotated[List[AnyMessage], add_messages]
281
+
282
# === VECTOR STORE SETUP ===
PERSIST_DIR = "./chroma_store"


def initialize_chroma_store():
    """Create a fresh Chroma vector store persisted under PERSIST_DIR.

    Any store left on disk from a previous run is wiped first, so every
    process starts from an empty index.
    """
    # Start from scratch: drop and recreate the on-disk store directory.
    if os.path.exists(PERSIST_DIR):
        shutil.rmtree(PERSIST_DIR)
    os.makedirs(PERSIST_DIR)

    # Sentence-transformer embeddings used for similarity search.
    hf_embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")

    return Chroma(
        embedding_function=hf_embeddings,
        persist_directory=PERSIST_DIR,
    )


vector_store = initialize_chroma_store()

# Expose the store to the agent as a retrieval tool.
retriever_tool = create_retriever_tool(
    retriever=vector_store.as_retriever(),
    name="Question Search",
    description="A tool to retrieve similar questions from a vector store.",
)
309
+
310
+
311
+
312
@tool
def weather_tool(location: str) -> str:
    """
    Fetches the weather for a location.

    Args:
        location (str): The location to fetch weather for.

    Returns:
        str: The weather information.
    """
    # BUG FIX: `search_tool` was referenced here but never defined at module
    # scope in this version of the file (it was a local of the removed
    # build_graph), so invoking this tool raised NameError. Instantiate the
    # search tool locally, mirroring web_search below.
    search_tool = TavilySearchResults()
    return get_weather(location, search_tool)
324
+
325
@tool
def web_search(query: str) -> str:
    """Search the web for a given query and return the summary.
    Args:
        query (str): The search query.
    """

    search_tool = TavilySearchResults()
    result = search_tool.run(query)
    # Tavily returns a list of result dicts; it can be empty for obscure
    # queries, which previously raised IndexError on result[0].
    if not result:
        return "No web search results found."
    return result[0]['content']
335
+
336
@tool
def wiki_search(query : str) -> str:
    """Search Wikipedia for a given query and return the summary.
    Args:
        query (str): The search query.
    """
    # Load at most one matching Wikipedia article.
    docs = WikipediaLoader(query=query, load_max_docs=1).load()

    # Wrap each article in a <Document> tag carrying its source metadata,
    # separated by a horizontal rule.
    rendered = [
        f'<Document Source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}">\n{doc.page_content}\n</Document>'
        for doc in docs
    ]
    return "\n\n----\n\n".join(rendered)
351
+
352
# Retained from the previous revision; the recommendation chain it depends
# on is no longer initialized in this version.
# @tool
# def recommendation_tool(weather_condition: str) -> str:
#     """
#     Provides recommendations based on weather conditions.

#     Args:
#         weather_condition (str): The weather condition.

#     Returns:
#         str: The recommendations.
#     """
#     return get_recommendation(weather_condition, recommendation_chain)

# Tool belt handed to the LLM: search/weather tools plus the math helpers
# defined earlier in this file.
tools = [weather_tool, wiki_search, web_search,
         add, subtract, multiply, divide, square, cube, power, factorial, mean, standard_deviation]

# Deterministic (temperature=0) Groq-hosted model; key read from the
# environment loaded via load_dotenv() above.
llm = ChatGroq(
    temperature=0,
    model_name="qwen-qwq-32b",  # Updated to working model
    groq_api_key=os.getenv("GROQ_API_KEY")
)

# Same model, with the tool schemas bound so it can emit tool calls.
llm_with_tools = llm.bind_tools(tools)
 
 
 
375
 
376
# === NODES ===

def retriever(state: MessagesState):
    """Entry node: look up similar stored Q/A and inject it as context."""
    # The first message is taken to be the user's question.
    question = state["messages"][0].content
    matches = vector_store.similarity_search(question)

    # Always lead with the system prompt; append a reference example only
    # when the vector store had a hit.
    out = [sys_msg] + state["messages"]
    if matches:
        reference = HumanMessage(
            content=f"Here is a similar question and answer for reference:\n\n{matches[0].page_content}"
        )
        out = out + [reference]
    return {"messages": out}
390
 
391
def assistant(state: MessagesState):
    """LLM node: run the tool-bound model over the current message history."""
    response = llm_with_tools.invoke(state["messages"])
    return {"messages": [response]}
394
 
395
# === GRAPH BUILD ===

def build_graph():
    """Compile the agent graph: retriever -> assistant <-> tools loop."""
    builder = StateGraph(MessagesState)
    builder.add_node("retriever", retriever)
    builder.add_node("assistant", assistant)
    # The retriever tool is exposed alongside the regular tool belt.
    builder.add_node("tools", ToolNode([retriever_tool] + tools))

    builder.set_entry_point("retriever")
    builder.add_edge("retriever", "assistant")
    # tools_condition routes to "tools" when the assistant emitted tool
    # calls, otherwise to END.
    builder.add_conditional_edges("assistant", tools_condition)
    builder.add_edge("tools", "assistant")

    return builder.compile()
409
 
410
# === TEST ===
if __name__ == "__main__":
    # Smoke-test the compiled graph on a single hard-coded question.
    question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
    graph = build_graph()

    result = graph.invoke({"messages": [HumanMessage(content=question)]})

    # Dump the full conversation, including tool traffic.
    for msg in result["messages"]:
        msg.pretty_print()