Neda1 committed
Commit e47210b · verified · 1 Parent(s): 2f69fab

Update agent.py

Files changed (1):
  agent.py +320 -97
agent.py CHANGED
@@ -1,3 +1,280 @@
  """LangGraph Agent"""
  import os
  from dotenv import load_dotenv
@@ -15,10 +292,8 @@ from langchain_core.messages import SystemMessage, HumanMessage
  from langchain_core.tools import tool
  from langchain.tools.retriever import create_retriever_tool
  from supabase.client import Client, create_client
- from langchain_core.documents import Document
- # load_dotenv()

- load_dotenv(".env")

  @tool
  def multiply(a: int, b: int) -> int:
@@ -124,32 +399,15 @@ sys_msg = SystemMessage(content=system_prompt)

  # build a retriever
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
- # supabase: Client = create_client(
- #     os.environ.get("SUPABASE_URL"),
- #     os.environ.get("SUPABASE_SERVICE_KEY"))
- supabase_url = os.getenv("SUPABASE_URL")
- supabase_key = os.getenv("SUPABASE_KEY")
-
- if not supabase_url or not supabase_key:
-     raise ValueError("SUPABASE_URL and SUPABASE_KEY must be set in environment variables.")
-
- supabase: Client = create_client(supabase_url, supabase_key)
- docs = [Document(page_content="This is a test about AI.")]
  vector_store = SupabaseVectorStore(
-     client=supabase,  # should be your `supabase` client instance
-     embedding=embeddings,
      table_name="documents",
      query_name="match_documents_langchain",
  )
-
- # Add documents
- vector_store.add_documents(docs)
-
- print("🔍 Testing similarity_search with: 'What is AI?'")
- results = vector_store.similarity_search("What is AI?")
- print(f"✅ Got {len(results)} results.")
- if results:
-     print("First result content:\n", results[0].page_content)
  create_retriever_tool = create_retriever_tool(
      retriever=vector_store.as_retriever(),
      name="Question Search",
@@ -170,7 +428,7 @@ tools = [
  ]

  # Build graph function
- def build_graph(provider: str = "groq"):
      """Build the graph"""
      # Load environment variables from .env file
      if provider == "google":
@@ -192,86 +450,51 @@ def build_graph(provider: str = "groq"):
      # Bind tools to LLM
      llm_with_tools = llm.bind_tools(tools)

      def assistant(state: MessagesState):
          """Assistant node"""
-         print("\n🧠 Final prompt to model:")
-         for m in state["messages"]:
-             print(f"{m.type.upper()}: {m.content[:300]}...\n")  # truncate for readability
-
-         response = llm_with_tools.invoke(state["messages"])
-
-         print("💬 Model response:", response.content[:500], "\n")
-         return {"messages": [response]}
-
-     # Node
-     # def assistant(state: MessagesState):
-     #     """Assistant node"""
-     #     return {"messages": [llm_with_tools.invoke(state["messages"])]}
-

      # def retriever(state: MessagesState):
-     #     """Retriever node"""
-     #     similar_question = vector_store.similarity_search(state["messages"][0].content)
-     #     example_msg = HumanMessage(
-     #         content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
-     #     )
-     #     return {"messages": [sys_msg] + state["messages"] + [example_msg]}
-
-     def retriever(state: MessagesState):
-         """Retriever node"""
-         messages = state.get("messages", [])
-         if not messages:
-             print("⚠️ No messages received in retriever node.")
-             return {"messages": []}
-
-         query = messages[0].content
-         print(f"\n🔍 Query to vector store: {query}")
-
-         try:
-             similar_question = vector_store.similarity_search(query)
-         except Exception as e:
-             print(f"❌ similarity_search failed: {e}")
-             return {"messages": messages}
-
-         if not similar_question:
-             print("⚠️ No similar questions found.")
-             return {"messages": messages}
-
-         print(f"✅ Found {len(similar_question)} similar question(s).")
-         print("📄 First retrieved doc:\n", similar_question[0].page_content)
-
-         example_msg = HumanMessage(
-             content=f"Here I provide a similar question and answer for reference:\n\n{similar_question[0].page_content}"
-         )
-         return {"messages": [sys_msg] + messages + [example_msg]}
-

      builder = StateGraph(MessagesState)
      builder.add_node("retriever", retriever)
-     builder.add_node("assistant", assistant)
-     builder.add_node("tools", ToolNode(tools))
-     builder.add_edge(START, "retriever")
-     builder.add_edge("retriever", "assistant")
-     builder.add_conditional_edges(
-         "assistant",
-         tools_condition,
-     )
-     builder.add_edge("tools", "assistant")

      # Compile graph
      return builder.compile()
-
- # test
- if __name__ == "__main__":
-     question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
-     # Build the graph
-     graph = build_graph(provider="groq")
-     # Run the graph
-     messages = [HumanMessage(content=question)]
-     messages = graph.invoke({"messages": messages})
-     for m in messages["messages"]:
-         m.pretty_print()
-
+ # … (new-file lines 1-277: the entire previous agent.py kept here, commented out; verbatim duplicate of the code shown above)
  """LangGraph Agent"""
  import os
  from dotenv import load_dotenv
 
  from langchain_core.tools import tool
  from langchain.tools.retriever import create_retriever_tool
  from supabase.client import Client, create_client

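+ # pull SUPABASE_URL / SUPABASE_SERVICE_KEY (and any model API keys) from a local .env file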
+ load_dotenv()

  @tool
  def multiply(a: int, b: int) -> int:
 

  # build a retriever
  embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2")  # dim=768
+ supabase: Client = create_client(
+     os.environ.get("SUPABASE_URL"),
+     os.environ.get("SUPABASE_SERVICE_KEY"))
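+ # NOTE: assumes both Supabase env vars are set (create_client gets no fallback here);
+ # the store below reads the "documents" table via the match_documents_langchain RPC
+ # using the 768-dim embeddings above.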
  vector_store = SupabaseVectorStore(
+     client=supabase,
+     embedding=embeddings,
      table_name="documents",
      query_name="match_documents_langchain",
  )
  create_retriever_tool = create_retriever_tool(
      retriever=vector_store.as_retriever(),
      name="Question Search",
 
  ]

  # Build graph function
+ def build_graph(provider: str = "google"):
      """Build the graph"""
      # Load environment variables from .env file
      if provider == "google":
 
      # Bind tools to LLM
      llm_with_tools = llm.bind_tools(tools)

+     # Node
      def assistant(state: MessagesState):
          """Assistant node"""
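+         # invoke the tool-bound LLM on the accumulated message history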
+         return {"messages": [llm_with_tools.invoke(state["messages"])]}

      # def retriever(state: MessagesState):
+     #     """Retriever node"""
+     #     similar_question = vector_store.similarity_search(state["messages"][0].content)
+     #     example_msg = HumanMessage(
+     #         content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
+     #     )
+     #     return {"messages": [sys_msg] + state["messages"] + [example_msg]}

+     from langchain_core.messages import AIMessage

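+     # Retriever node: answers straight from the vector store without calling the LLM;
+     # it returns the text after "Final answer :" in the closest stored match.
+     # NOTE: assumes similarity_search returns at least one document (k=1 indexing).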
+     def retriever(state: MessagesState):
+         query = state["messages"][-1].content
+         similar_doc = vector_store.similarity_search(query, k=1)[0]
+
+         content = similar_doc.page_content
+         if "Final answer :" in content:
+             answer = content.split("Final answer :")[-1].strip()
+         else:
+             answer = content.strip()
+
+         return {"messages": [AIMessage(content=answer)]}
+
+     # builder = StateGraph(MessagesState)
+     # builder.add_node("retriever", retriever)
+     # builder.add_node("assistant", assistant)
+     # builder.add_node("tools", ToolNode(tools))
+     # builder.add_edge(START, "retriever")
+     # builder.add_edge("retriever", "assistant")
+     # builder.add_conditional_edges(
+     #     "assistant",
+     #     tools_condition,
+     # )
+     # builder.add_edge("tools", "assistant")

      builder = StateGraph(MessagesState)
      builder.add_node("retriever", retriever)
+
+     # The retriever is both the entry point and the finish point
+     builder.set_entry_point("retriever")
+     builder.set_finish_point("retriever")

      # Compile graph
      return builder.compile()