Alfred828 commited on
Commit
ad1c6a3
·
verified ·
1 Parent(s): 8ff3d81

Create agents/common_agent.py

Browse files
Files changed (1) hide show
  1. agents/common_agent.py +114 -0
agents/common_agent.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Annotated, Any, Sequence, TypedDict
2
+
3
+ from langchain.tools import StructuredTool
4
+ from langchain_core.messages import AnyMessage, HumanMessage, SystemMessage
5
+ from langchain_core.messages.base import BaseMessage
6
+ from langchain_core.prompt_values import PromptValue
7
+ from langchain_core.runnables.base import Runnable
8
+ from langchain_openai import ChatOpenAI
9
+ from langgraph.graph import START, StateGraph
10
+ from langgraph.graph.message import add_messages
11
+ from langgraph.graph.state import CompiledStateGraph
12
+ from langgraph.prebuilt import ToolNode, tools_condition
13
+
14
+ from config import settings
15
+ from tools.encyclopedia import EncyclopediaRetriever
16
+ from tools.tool_collection_common import ToolsCollection
17
+
18
+ retriver = EncyclopediaRetriever(["gaia"], settings.PROJ_PATH)
19
+
20
+
21
+ class AgentState(TypedDict):
22
+ messages: Annotated[list[AnyMessage], add_messages]
23
+
24
+
25
+ class AgenticRAG:
26
+ def __init__(self):
27
+ chat = ChatOpenAI(model="gpt-4o", verbose=True)
28
+ self.tools: list[StructuredTool] = ToolsCollection.get_tools(
29
+ [
30
+ "search_tool",
31
+ "get_weather",
32
+ "hub_stats_tool",
33
+ "wikipedia_en_tool_agent",
34
+ "EncyclopediaRetriever",
35
+ ]
36
+ )
37
+ self.chat_with_tools: Runnable[
38
+ PromptValue
39
+ | str
40
+ | Sequence[
41
+ BaseMessage | list[str] | tuple[str, str] | str | dict[str, Any]
42
+ ],
43
+ BaseMessage,
44
+ ] = chat.bind_tools(self.tools)
45
+
46
+ self.agent = self.build_agent()
47
+
48
+ def retriver(self, state: AgentState):
49
+ question = state["messages"][-1].content
50
+ result = retriver.get_related_question(question)
51
+
52
+ return {
53
+ "messages": [
54
+ HumanMessage(
55
+ content=f"以下有多組可能的問題及最終答案,請檢查是否有該問題`{question}`的答案,含有該問題答案時,挑出答案並回傳"
56
+ + f"\n{result}"
57
+ )
58
+ ],
59
+ }
60
+
61
+ async def assistant(self, state: AgentState) -> dict[str, list[BaseMessage]]:
62
+ # print("\n=================", state["messages"], "=================\n")
63
+ result_message: BaseMessage = await self.chat_with_tools.ainvoke(
64
+ state["messages"]
65
+ )
66
+
67
+ return {
68
+ "messages": [result_message],
69
+ }
70
+
71
+ def build_agent(self) -> CompiledStateGraph:
72
+ builder = StateGraph(AgentState)
73
+
74
+ builder.add_node("retriver", self.retriver)
75
+ builder.add_node("assistant", self.assistant)
76
+ builder.add_node("tools", ToolNode(self.tools))
77
+
78
+ # Define edges: these determine how the control flow moves
79
+ builder.add_edge(START, "retriver")
80
+ builder.add_edge("retriver", "assistant")
81
+ builder.add_conditional_edges(
82
+ source="assistant",
83
+ path=tools_condition,
84
+ )
85
+ builder.add_edge("tools", "assistant")
86
+ agent: CompiledStateGraph = builder.compile()
87
+
88
+ return agent
89
+
90
+ async def ainvoke(self, message: str) -> dict[list[BaseMessage], str, Any]:
91
+ response = await self.agent.ainvoke(
92
+ {
93
+ "messages": [
94
+ SystemMessage(
95
+ content="""
96
+ 你是一個AI助理,專門回答問題,當根據現有資訊無法得出答案時,優先使用外部工具嘗試找出答案。
97
+
98
+ 當你使用外部工具時,以外部工具提供給你的答案為最準。
99
+ 雖然你有出眾的語言能力,回答時請精簡,不要解釋為什麼是這個答案,也不需要提供參考資訊,直接告訴我結果就好。
100
+ 以這個問題為例
101
+ 提問:How many studio albums were published by xxx between 1990 and 2009 (included)?
102
+ 原始回覆:Between 1990 and 2009 (included), xxx published n studio albums.
103
+ 希望的回覆:n
104
+ """
105
+ ),
106
+ HumanMessage(content=message),
107
+ ]
108
+ },
109
+ config={"callbacks": [settings.LANGFUSE_HANDLER]},
110
+ )
111
+
112
+ # print("🎩 Agent's Response:")
113
+ # print(response["messages"][-1].content)
114
+ return response["messages"][-1].content