orbulat commited on
Commit
9206efd
·
verified ·
1 Parent(s): ff4f92a

Create agent.py

Browse files
Files changed (1) hide show
  1. agent.py +69 -0
agent.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from langgraph.graph import START, StateGraph, MessagesState
3
+ from langgraph.prebuilt import ToolNode, tools_condition
4
+ from langchain_openai import ChatOpenAI
5
+ from langchain_core.messages import SystemMessage, HumanMessage
6
+ from langchain_core.tools import tool
7
+ from langchain_community.tools.tavily_search import TavilySearchResults
8
+ from langchain_community.document_loaders import WikipediaLoader
9
+
10
+ # Load system prompt from system_prompt.txt
11
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
12
+ SYSTEM_PROMPT = f.read()
13
+
14
+ # Define tools
15
+ @tool
16
+ def wiki_search(query: str) -> str:
17
+ """Search Wikipedia for a query and return 2 results."""
18
+ docs = WikipediaLoader(query=query, load_max_docs=2).load()
19
+ return "\n\n---\n\n".join([doc.page_content for doc in docs])
20
+
21
+ @tool
22
+ def web_search(query: str) -> str:
23
+ """Search the web using Tavily API and return 3 results."""
24
+ docs = TavilySearchResults(max_results=3).invoke(query)
25
+ return "\n\n---\n\n".join([doc.page_content for doc in docs])
26
+
27
+ tools = [wiki_search, web_search]
28
+
29
+ # Build LangGraph agent
30
+ def build_graph():
31
+ llm = ChatOpenAI(
32
+ model="gpt-4o",
33
+ temperature=0,
34
+ api_key=os.getenv("OPENAI_API_KEY")
35
+ )
36
+ llm_with_tools = llm.bind_tools(tools)
37
+
38
+ def system_node(state: MessagesState):
39
+ return {"messages": [SystemMessage(content=SYSTEM_PROMPT)] + state["messages"]}
40
+
41
+ def assistant_node(state: MessagesState):
42
+ return {"messages": [llm_with_tools.invoke(state["messages"])]}
43
+
44
+ graph = StateGraph(MessagesState)
45
+ graph.add_node("system", system_node)
46
+ graph.add_node("assistant", assistant_node)
47
+ graph.add_node("tools", ToolNode(tools))
48
+
49
+ graph.add_edge(START, "system")
50
+ graph.add_edge("system", "assistant")
51
+ graph.add_conditional_edges("assistant", tools_condition)
52
+ graph.add_edge("tools", "assistant")
53
+
54
+ return graph.compile()
55
+
56
+ # Final class for submission agent
57
+ class BasicAgent:
58
+ def __init__(self):
59
+ print("Initializing LangGraph GAIA Agent...")
60
+ self.graph = build_graph()
61
+
62
+ def __call__(self, question: str) -> str:
63
+ try:
64
+ messages = [HumanMessage(content=question)]
65
+ result = self.graph.invoke({"messages": messages})
66
+ final_msg = result["messages"][-1]
67
+ return final_msg.content.strip()
68
+ except Exception as e:
69
+ return f"Agent error: {str(e)}"