CyberAssassin committed
Commit 5b7d586 · verified · 1 parent: 55c4837

Update agent.py

Files changed (1): agent.py +223 −214
agent.py CHANGED
@@ -1,214 +1,223 @@
- """LangGraph Agent"""
- import os
- from dotenv import load_dotenv
- from langgraph.graph import START, StateGraph, MessagesState
- from langgraph.prebuilt import tools_condition
- from langgraph.prebuilt import ToolNode
- from langchain_google_genai import ChatGoogleGenerativeAI
- from langchain_groq import ChatGroq
- from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
- from langchain_community.tools.tavily_search import TavilySearchResults
- from langchain_community.document_loaders import WikipediaLoader
- from langchain_community.document_loaders import ArxivLoader
- from langchain_community.vectorstores import SupabaseVectorStore
- from langchain_core.messages import SystemMessage, HumanMessage
- from langchain_core.tools import tool
- from langchain.tools.retriever import create_retriever_tool
- from supabase.client import Client, create_client
-
- load_dotenv()
-
- @tool
- def multiply(a: int, b: int) -> int:
-     """Multiply two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     return a * b
-
- @tool
- def add(a: int, b: int) -> int:
-     """Add two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     return a + b
-
- @tool
- def subtract(a: int, b: int) -> int:
-     """Subtract two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     return a - b
-
- @tool
- def divide(a: int, b: int) -> int:
-     """Divide two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     if b == 0:
-         raise ValueError("Cannot divide by zero.")
-     return a / b
-
- @tool
- def modulus(a: int, b: int) -> int:
-     """Get the modulus of two numbers.
-
-     Args:
-         a: first int
-         b: second int
-     """
-     return a % b
-
- @tool
- def wiki_search(query: str) -> str:
-     """Search Wikipedia for a query and return maximum 2 results.
-
-     Args:
-         query: The search query."""
-     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"wiki_results": formatted_search_docs}
-
- @tool
- def web_search(query: str) -> str:
-     """Search Tavily for a query and return maximum 3 results.
-
-     Args:
-         query: The search query."""
-     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"web_results": formatted_search_docs}
-
- @tool
- def arvix_search(query: str) -> str:
-     """Search Arxiv for a query and return maximum 3 result.
-
-     Args:
-         query: The search query."""
-     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
-     formatted_search_docs = "\n\n---\n\n".join(
-         [
-             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
-             for doc in search_docs
-         ])
-     return {"arvix_results": formatted_search_docs}
-
-
-
- # load the system prompt from the file
- with open("system_prompt.txt", "r", encoding="utf-8") as f:
-     system_prompt = f.read()
-
- # System message
- sys_msg = SystemMessage(content=system_prompt)
-
- # build a retriever
- embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
- supabase: Client = create_client(
-     os.environ.get("SUPABASE_URL"),
-     os.environ.get("SUPABASE_SERVICE_KEY"))
- vector_store = SupabaseVectorStore(
-     client=supabase,
-     embedding= embeddings,
-     table_name="documents",
-     query_name="match_documents_langchain",
- )
- create_retriever_tool = create_retriever_tool(
-     retriever=vector_store.as_retriever(),
-     name="Question Search",
-     description="A tool to retrieve similar questions from a vector store.",
- )
-
-
-
- tools = [
-     multiply,
-     add,
-     subtract,
-     divide,
-     modulus,
-     wiki_search,
-     web_search,
-     arvix_search,
- ]
-
- # Build graph function
- def build_graph(provider: str = "groq"):
-     """Build the graph"""
-     # Load environment variables from .env file
-     if provider == "google":
-         # Google Gemini
-         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
-     elif provider == "groq":
-         # Groq https://console.groq.com/docs/models
-         llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
-     elif provider == "huggingface":
-         # TODO: Add huggingface endpoint
-         llm = ChatHuggingFace(
-             llm=HuggingFaceEndpoint(
-                 url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
-                 temperature=0,
-             ),
-         )
-     else:
-         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
-     # Bind tools to LLM
-     llm_with_tools = llm.bind_tools(tools)
-
-     # Node
-     def assistant(state: MessagesState):
-         """Assistant node"""
-         return {"messages": [llm_with_tools.invoke(state["messages"])]}
-
-     def retriever(state: MessagesState):
-         """Retriever node"""
-         similar_question = vector_store.similarity_search(state["messages"][0].content)
-         example_msg = HumanMessage(
-             content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
-         )
-         return {"messages": [sys_msg] + state["messages"] + [example_msg]}
-
-     builder = StateGraph(MessagesState)
-     builder.add_node("retriever", retriever)
-     builder.add_node("assistant", assistant)
-     builder.add_node("tools", ToolNode(tools))
-     builder.add_edge(START, "retriever")
-     builder.add_edge("retriever", "assistant")
-     builder.add_conditional_edges(
-         "assistant",
-         tools_condition,
-     )
-     builder.add_edge("tools", "assistant")
-
-     # Compile graph
-     return builder.compile()
-
- # test
- if __name__ == "__main__":
-     question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
-     # Build the graph
-     graph = build_graph(provider="groq")
-     # Run the graph
-     messages = [HumanMessage(content=question)]
-     messages = graph.invoke({"messages": messages})
-     for m in messages["messages"]:
-         m.pretty_print()
+ """LangGraph Agent"""
+ import os
+ from dotenv import load_dotenv
+ from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import tools_condition
+ from langgraph.prebuilt import ToolNode
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_groq import ChatGroq
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader
+ from langchain_community.document_loaders import ArxivLoader
+ from langchain_community.vectorstores import SupabaseVectorStore
+ from langchain_core.messages import SystemMessage, HumanMessage
+ from langchain_core.tools import tool
+ from langchain.tools.retriever import create_retriever_tool
+ from supabase.client import Client, create_client
+
+ load_dotenv()
+
+ @tool
+ def multiply(a: int, b: int) -> int:
+     """Multiply two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a * b
+
+ @tool
+ def add(a: int, b: int) -> int:
+     """Add two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a + b
+
+ @tool
+ def subtract(a: int, b: int) -> int:
+     """Subtract two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a - b
+
+ @tool
+ def divide(a: int, b: int) -> int:
+     """Divide two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     if b == 0:
+         raise ValueError("Cannot divide by zero.")
+     return a / b
+
+ @tool
+ def modulus(a: int, b: int) -> int:
+     """Get the modulus of two numbers.
+
+     Args:
+         a: first int
+         b: second int
+     """
+     return a % b
+
+ @tool
+ def wiki_search(query: str) -> str:
+     """Search Wikipedia for a query and return maximum 2 results.
+
+     Args:
+         query: The search query."""
+     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+             for doc in search_docs
+         ])
+     return {"wiki_results": formatted_search_docs}
+
+ @tool
+ def web_search(query: str) -> str:
+     """Search Tavily for a query and return maximum 3 results.
+
+     Args:
+         query: The search query."""
+     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+             for doc in search_docs
+         ])
+     return {"web_results": formatted_search_docs}
+
+ @tool
+ def arvix_search(query: str) -> str:
+     """Search Arxiv for a query and return maximum 3 result.
+
+     Args:
+         query: The search query."""
+     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
+     formatted_search_docs = "\n\n---\n\n".join(
+         [
+             f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
+             for doc in search_docs
+         ])
+     return {"arvix_results": formatted_search_docs}
+
+
+
+ # load the system prompt from the file
+ with open("system_prompt.txt", "r", encoding="utf-8") as f:
+     system_prompt = f.read()
+
+ # System message
+ sys_msg = SystemMessage(content=system_prompt)
+
+ # build a retriever
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
+ supabase: Client = create_client(
+     os.environ.get("SUPABASE_URL"),
+     os.environ.get("SUPABASE_SERVICE_KEY"))
+ vector_store = SupabaseVectorStore(
+     client=supabase,
+     embedding=embeddings,
+     table_name="documents",
+     query_name="match_documents", # Updated function name
+     search_kwargs={
+         "k": 3,
+         "filter": {}
+     }
+ )
+ retriever_tool = create_retriever_tool(
+     retriever=vector_store.as_retriever(
+         search_kwargs={
+             "k": 3,
+             "filter": {}
+         }
+     ),
+     name="document_retriever",
+     description="Searches for similar documents",
+ )
+
+
+
+ tools = [
+     multiply,
+     add,
+     subtract,
+     divide,
+     modulus,
+     wiki_search,
+     web_search,
+     arvix_search,
+ ]
+
+ # Build graph function
+ def build_graph(provider: str = "groq"):
+     """Build the graph"""
+     # Load environment variables from .env file
+     if provider == "google":
+         # Google Gemini
+         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
+     elif provider == "groq":
+         # Groq https://console.groq.com/docs/models
+         llm = ChatGroq(model="qwen-qwq-32b", temperature=0) # optional : qwen-qwq-32b gemma2-9b-it
+     elif provider == "huggingface":
+         # TODO: Add huggingface endpoint
+         llm = ChatHuggingFace(
+             llm=HuggingFaceEndpoint(
+                 url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
+                 temperature=0,
+             ),
+         )
+     else:
+         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
+     # Bind tools to LLM
+     llm_with_tools = llm.bind_tools(tools)
+
+     # Node
+     def assistant(state: MessagesState):
+         """Assistant node"""
+         return {"messages": [llm_with_tools.invoke(state["messages"])]}
+
+     def retriever(state: MessagesState):
+         """Retriever node"""
+         similar_question = vector_store.similarity_search(state["messages"][0].content)
+         example_msg = HumanMessage(
+             content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
+         )
+         return {"messages": [sys_msg] + state["messages"] + [example_msg]}
+
+     builder = StateGraph(MessagesState)
+     builder.add_node("retriever", retriever)
+     builder.add_node("assistant", assistant)
+     builder.add_node("tools", ToolNode(tools))
+     builder.add_edge(START, "retriever")
+     builder.add_edge("retriever", "assistant")
+     builder.add_conditional_edges(
+         "assistant",
+         tools_condition,
+     )
+     builder.add_edge("tools", "assistant")
+
+     # Compile graph
+     return builder.compile()
+
+ # test
+ if __name__ == "__main__":
+     question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
+     # Build the graph
+     graph = build_graph(provider="groq")
+     # Run the graph
+     messages = [HumanMessage(content=question)]
+     messages = graph.invoke({"messages": messages})
+     for m in messages["messages"]:
+         m.pretty_print()
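
The substantive change in this commit is the retriever wiring: query_name now targets a match_documents RPC instead of match_documents_langchain, retrieval is capped at the top 3 matches, and the tool is assigned to retriever_tool rather than rebinding the name create_retriever_tool, which previously shadowed the imported helper. A minimal sketch of exercising the updated retriever on its own, assuming the Supabase project already exposes a pgvector-backed match_documents function over the documents table and that SUPABASE_URL and SUPABASE_SERVICE_KEY are set in the environment:

import os

from langchain_community.vectorstores import SupabaseVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from supabase.client import create_client

# Same embedding model as agent.py (dim=768), so query vectors match the stored ones.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")
supabase = create_client(os.environ["SUPABASE_URL"], os.environ["SUPABASE_SERVICE_KEY"])

store = SupabaseVectorStore(
    client=supabase,
    embedding=embeddings,
    table_name="documents",
    query_name="match_documents",  # must match the SQL function name in Supabase
)

# similarity_search is what the graph's retriever node calls; k=3 mirrors the
# cap the commit passes via search_kwargs.
for doc in store.similarity_search("principle of double effect", k=3):
    print(doc.page_content[:200])

Note that as_retriever(search_kwargs={...}) is the documented place to set k and filter; whether the SupabaseVectorStore constructor itself accepts a search_kwargs argument depends on the installed langchain_community version, so the copy passed to the constructor may be ignored or rejected.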