Files changed (3) hide show
  1. agent.py +194 -0
  2. metadata.jsonl +0 -0
  3. system_prompt.txt +5 -0
agent.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from dotenv import load_dotenv
3
+ from langgraph.graph import START, StateGraph, MessagesState
4
+ from langgraph.prebuilt import tools_condition
5
+ from langgraph.prebuilt import ToolNode
6
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint, HuggingFaceEmbeddings
7
+ from langchain_community.tools import DuckDuckGoSearchResults
8
+ from langchain_community.document_loaders import WikipediaLoader
9
+ from langchain_community.document_loaders import ArxivLoader
10
+ from langchain_community.vectorstores import SupabaseVectorStore
11
+ from langchain_core.messages import SystemMessage, HumanMessage
12
+ from langchain_core.tools import tool
13
+ from langchain.tools.retriever import create_retriever_tool
14
+ from supabase.client import Client, create_client
15
+
16
+ load_dotenv()
17
+
18
+ @tool
19
+ def multiply(a: int, b: int) -> int:
20
+ """Multiply two numbers.
21
+
22
+ Args:
23
+ a: first int
24
+ b: second int
25
+ """
26
+ return a * b
27
+
28
+ @tool
29
+ def add(a: int, b: int) -> int:
30
+ """Add two numbers.
31
+
32
+ Args:
33
+ a: first int
34
+ b: second int
35
+ """
36
+ return a+b
37
+
38
+ @tool
39
+ def subtract(a: int, b:int) -> int:
40
+ """Subtract two numbers.
41
+
42
+ Args:
43
+ a: first int
44
+ b: second int
45
+ """
46
+ return a-b
47
+
48
+ @tool
49
+ def divide(a: int, b: int) -> int:
50
+ """Divide two numbers.
51
+
52
+ Args:
53
+ a: first int
54
+ b: second int
55
+ """
56
+ if b == 0:
57
+ raise ValueError("Cannot divide by zero.")
58
+ return a / b
59
+
60
+ @tool
61
+ def modulus(a: int, b:int) -> int:
62
+ """Get the modulus of two numbers.
63
+
64
+ Args:
65
+ a: first int
66
+ b: second int
67
+ """
68
+ return a%b
69
+
70
+ @tool
71
+ def wiki_search(query: str) -> str:
72
+ """Search Wikipedia for a query and return maximum 2 results.
73
+
74
+ Args:
75
+ query: The search query.
76
+ """
77
+ search_docs = WikipediaLoader(query=query, load_max_docs=3).load()
78
+ formatted_search_docs = "\n\n---\n\n".join(
79
+ [
80
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
81
+ for doc in search_docs
82
+ ])
83
+ return {"wiki_results": formatted_search_docs}
84
+
85
+
86
+ @tool
87
+ def web_search(query: str) -> str:
88
+ """Search Duck2DuckGo for a query and return maximum 3 results.
89
+
90
+ Args:
91
+ query: The search query."""
92
+ search_docs = DuckDuckGoSearchResults(max_results=4).invoke(query=query)
93
+ formatted_search_docs = "\n\n---\n\n".join(
94
+ [
95
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
96
+ for doc in search_docs
97
+ ])
98
+ return {"web_results": formatted_search_docs}
99
+
100
+ @tool
101
+ def arvix_search(query: str) -> str:
102
+ """Search Arxiv for a query and return maximum 3 result.
103
+
104
+ Args:
105
+ query: The search query."""
106
+ search_docs = ArxivLoader(query=query, load_max_docs=2).load()
107
+ formatted_search_docs = "\n\n---\n\n".join(
108
+ [
109
+ f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
110
+ for doc in search_docs
111
+ ])
112
+ return {"arvix_results": formatted_search_docs}
113
+
114
+
115
+ with open("system_prompt.txt","r",encoding="utf-8") as f:
116
+ system_prompt = f.read()
117
+
118
+ # System message
119
+ sys_msg = SystemMessage(content=system_prompt)
120
+
121
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-mpnet-base-v2") # dim=768
122
+ supabase: Client = create_client(
123
+ os.environ.get("SUPABASE_URL"),
124
+ os.environ.get("SUPABASE_SERVICE_KEY"))
125
+ vector_store = SupabaseVectorStore(
126
+ client=supabase,
127
+ embedding= embeddings,
128
+ table_name="documents",
129
+ query_name="match_documents_langchain",
130
+ )
131
+ create_retriever_tool = create_retriever_tool(
132
+ retriever=vector_store.as_retriever(),
133
+ name="Question Search",
134
+ description="A tool to retrieve similar questions from a vector store.",
135
+ )
136
+
137
+ tools = [
138
+ multiply,
139
+ add,
140
+ subtract,
141
+ divide,
142
+ modulus,
143
+ wiki_search,
144
+ web_search,
145
+ arvix_search,
146
+ ]
147
+
148
+ def build_graph():
149
+
150
+ llm = ChatHuggingFace(
151
+ llm=HuggingFaceEndpoint(
152
+ repo_id="meta-llama/Llama-2-7b-chat-hf",
153
+ temperature=0,
154
+ )
155
+ )
156
+ llm_with_tools = llm.bind_tools(tools)
157
+
158
+ def assistant(state: MessagesState):
159
+ """Assistant node"""
160
+ return{"messages":[llm_with_tools.invoke(state["messages"])]}
161
+
162
+ def retriever(state: MessagesState):
163
+ """Retriever node"""
164
+ similar_question = vector_store.similarity_search(state["messages"][0].content)
165
+ example_msg = HumanMessage(
166
+ content=f"Here I provide a similar question and answer for reference: \n\n{similar_question[0].page_content}",
167
+ )
168
+ return {"messages": [sys_msg] + state["messages"] + [example_msg]}
169
+
170
+ builder = StateGraph(MessagesState)
171
+ builder.add_node("retriever", retriever)
172
+ builder.add_node("assistant", assistant)
173
+ builder.add_node("tools", ToolNode(tools))
174
+
175
+ builder.add_edge(START,"retriever")
176
+ builder.add_edge("retriever","assistant")
177
+ builder.add_edge("retriever","assistant")
178
+ builder.add_conditional_edges(
179
+ "assistant",
180
+ tools_condition,
181
+ )
182
+ builder.add_edge("tools","assistant")
183
+
184
+ return builder.compile()
185
+
186
+ if __name__ == "__main__":
187
+ question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
188
+
189
+ # Build the graph
190
+ graph = build_graph()
191
+ messages = [HumanMessage(content=question)]
192
+ messages = graph.invoke({"messages":messages})
193
+ for m in messages["messages"]:
194
+ m.preetty_print()
metadata.jsonl ADDED
The diff for this file is too large to render. See raw diff
 
system_prompt.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ You are a helpful assistant tasked with answering questions using a set of tools.
2
+ Now, I will ask you a question. Report your thoughts, and finish your answer with the following template:
3
+ FINAL ANSWER: [YOUR FINAL ANSWER].
4
+ YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, don't use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, don't use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string.
5
+ Your answer should only start with "FINAL ANSWER: ", then follows with the answer.