HarshitSundriyal commited on
Commit
caf960b
·
verified ·
1 Parent(s): b499e0b

Update agent.py

Browse files
Files changed (1) hide show
  1. agent.py +19 -26
agent.py CHANGED
@@ -354,51 +354,44 @@ def wiki_search(query : str) -> str:
354
  tools = [weather_tool, wiki_search, web_search,
355
  add, subtract, multiply, divide, square, cube, power, factorial, mean, standard_deviation]
356
 
 
 
357
  llm = ChatGroq(
358
  temperature=0,
359
- model_name="qwen-qwq-32b", # Updated to working model
360
  groq_api_key=os.getenv("GROQ_API_KEY")
361
  )
362
 
363
- llm_with_tools = llm.bind_tools(tools)
 
 
364
 
365
- # === NODES ===
366
 
367
- def retriever(state: MessagesState):
368
- """Retrieve similar context and inject"""
369
- query = state["messages"][0].content
370
- similar_docs = vector_store.similarity_search(query)
371
 
372
- if similar_docs:
373
- ref_msg = HumanMessage(
374
- content=f"Here is a similar question and answer for reference:\n\n{similar_docs[0].page_content}"
375
- )
376
- return {"messages": [sys_msg] + state["messages"] + [ref_msg]}
377
- else:
378
- return {"messages": [sys_msg] + state["messages"]}
379
 
380
- def assistant(state: MessagesState):
381
- """Invoke LLM with tools"""
382
- return {"messages": [llm_with_tools.invoke(state["messages"])]}
383
 
384
- # === GRAPH BUILD ===
385
 
386
  def build_graph():
387
- builder = StateGraph(MessagesState)
388
- builder.add_node("retriever", retriever)
389
  builder.add_node("assistant", assistant)
390
- builder.add_node("tools", ToolNode([retriever_tool] + tools))
391
 
392
- builder.set_entry_point("retriever")
393
- builder.add_edge("retriever", "assistant")
394
  builder.add_conditional_edges("assistant", tools_condition)
395
  builder.add_edge("tools", "assistant")
396
-
397
  return builder.compile()
398
 
399
- # === TEST ===
 
400
  if __name__ == "__main__":
401
- question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
402
  graph = build_graph()
403
  messages = [HumanMessage(content=question)]
404
  result = graph.invoke({"messages": messages})
 
354
  tools = [weather_tool, wiki_search, web_search,
355
  add, subtract, multiply, divide, square, cube, power, factorial, mean, standard_deviation]
356
 
357
+ # === LLM with Tools ===
358
+
359
  llm = ChatGroq(
360
  temperature=0,
361
+ model_name="qwen-qwq-32b",
362
  groq_api_key=os.getenv("GROQ_API_KEY")
363
  )
364
 
365
+ tools = [weather_tool, wiki_search, web_search,
366
+ add, subtract, multiply, divide, square, cube,
367
+ power, factorial, mean, standard_deviation]
368
 
369
+ llm_with_tools = llm.bind_tools(tools)
370
 
371
+ # === LangGraph State ===
 
 
 
372
 
373
+ class ToolAgentState(TypedDict):
374
+ messages: Annotated[List[HumanMessage], "Messages in the conversation"]
 
 
 
 
 
375
 
376
+ def assistant(state: ToolAgentState):
377
+ return {"messages": [llm_with_tools.invoke([sys_msg] + state["messages"])]}
 
378
 
379
+ # === Build Graph ===
380
 
381
  def build_graph():
382
+ builder = StateGraph(ToolAgentState)
 
383
  builder.add_node("assistant", assistant)
384
+ builder.add_node("tools", ToolNode(tools))
385
 
386
+ builder.set_entry_point("assistant")
 
387
  builder.add_conditional_edges("assistant", tools_condition)
388
  builder.add_edge("tools", "assistant")
 
389
  return builder.compile()
390
 
391
+ # === Run ===
392
+
393
  if __name__ == "__main__":
394
+ question = "When did India won a world cup in cricket before 2000?"
395
  graph = build_graph()
396
  messages = [HumanMessage(content=question)]
397
  result = graph.invoke({"messages": messages})