Alfred828 commited on
Commit
72db888
·
verified ·
1 Parent(s): ad1c6a3

Create agents/wiki_agent.py

Browse files
Files changed (1) hide show
  1. agents/wiki_agent.py +111 -0
agents/wiki_agent.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated, Any, Sequence, TypedDict
2
+
3
+ from langchain.tools import StructuredTool
4
+ from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
5
+ from langchain_core.messages.base import BaseMessage
6
+ from langchain_core.prompt_values import PromptValue
7
+ from langchain_core.runnables.base import Runnable
8
+ from langchain_openai import ChatOpenAI
9
+ from langgraph.graph import END, START, StateGraph
10
+ from langgraph.graph.message import add_messages
11
+ from langgraph.graph.state import CompiledStateGraph
12
+ from langgraph.prebuilt import ToolNode, tools_condition
13
+ from pydantic import BaseModel, Field
14
+
15
+ from config import settings
16
+ from tools.tool_collection_wiki import ToolsCollection as WikiTool
17
+
18
+
19
+ class AgentState(TypedDict):
20
+ messages: Annotated[list[AnyMessage], add_messages]
21
+
22
+
23
+ class WikiAgent:
24
+ def __init__(self):
25
+ chat = ChatOpenAI(model="gpt-4o", verbose=True)
26
+ self.tools: list[StructuredTool] = WikiTool.get_tools(
27
+ [
28
+ "wikipedia_opensearch",
29
+ "get_page_title_excerpt_sections",
30
+ "get_page_section_content",
31
+ ]
32
+ )
33
+ self.chat_with_tools: Runnable[
34
+ PromptValue
35
+ | str
36
+ | Sequence[
37
+ BaseMessage | list[str] | tuple[str, str] | str | dict[str, Any]
38
+ ],
39
+ BaseMessage,
40
+ ] = chat.bind_tools(self.tools)
41
+
42
+ self.agent = self.build_agent()
43
+
44
+ async def assistant(self, state: AgentState):
45
+ result_message: BaseMessage = await self.chat_with_tools.ainvoke(
46
+ state["messages"]
47
+ )
48
+
49
+ return {
50
+ "messages": [result_message],
51
+ }
52
+
53
+ def build_agent(self) -> CompiledStateGraph:
54
+ builder = StateGraph(AgentState)
55
+
56
+ builder.add_node("assistant", self.assistant)
57
+ builder.add_node("tools", ToolNode(self.tools))
58
+
59
+ # Define edges: these determine how the control flow moves
60
+ builder.add_edge(START, "assistant")
61
+
62
+ builder.add_conditional_edges(source="assistant", path=tools_condition)
63
+
64
+ builder.add_edge("tools", "assistant")
65
+
66
+ agent: CompiledStateGraph = builder.compile()
67
+
68
+ return agent
69
+
70
+ async def ainvoke(self, message: str) -> dict[list[BaseMessage], str, Any]:
71
+ response = await self.agent.ainvoke(
72
+ {
73
+ "messages": [
74
+ SystemMessage(
75
+ content="""
76
+ 你是一個專門搜尋wikipedia的AI Agent,
77
+ 步驟一:使用 wikipedia_opensearch 工具找出與問題相關的頁面
78
+ 步驟二:使用 get_page_title_excerpt_sections 工具找出頁面的 excerpt 和 sections
79
+ 步驟三:根據步驟二的 excerpt 和 sections 結合用戶問題,判斷哪些 section 會有需要的答案,呼叫 get_page_section_content 工具取得這些 section 的所有內容。
80
+ 步驟四:總和前述步驟找出答案。
81
+ """
82
+ ),
83
+ HumanMessage(content=message),
84
+ ]
85
+ },
86
+ config={"callbacks": [settings.LANGFUSE_HANDLER]},
87
+ )
88
+
89
+ # print("🎩 Agent's Response:")
90
+ # print(response["messages"][-1].content)
91
+ return response["messages"][-1].content
92
+
93
+
94
+ class WikipediaEnToolAgentInput(BaseModel):
95
+ question: str = Field(description="The user question in natural language.")
96
+
97
+
98
+ def wikipedia_en_tool_agent(question: str) -> str:
99
+ """
100
+ Invokes the WikiAgent asynchronously to answer a user-provided question using Wikipedia.
101
+
102
+ Args:
103
+ question (str): The user question in natural language.
104
+
105
+ Returns:
106
+ str: The answer or result generated by the WikiAgent.
107
+ """
108
+
109
+ import asyncio
110
+
111
+ return asyncio.run(WikiAgent().ainvoke(question))