Spaces:
Sleeping
Sleeping
gabrielaltay
committed on
Commit
•
b25bfc6
1
Parent(s):
662d960
agent update
Browse files- app.py +118 -138
- custom_tools.py +98 -0
- requirements.txt +5 -3
app.py
CHANGED
@@ -1,9 +1,13 @@
|
|
1 |
from collections import defaultdict
|
2 |
import json
|
3 |
-
from operator import itemgetter
|
4 |
import os
|
5 |
import re
|
6 |
|
|
|
|
|
|
|
|
|
|
|
7 |
from langchain_core.documents import Document
|
8 |
from langchain_core.prompts import PromptTemplate
|
9 |
from langchain_core.prompts import ChatPromptTemplate
|
@@ -14,6 +18,7 @@ from langchain_core.runnables import RunnableParallel
|
|
14 |
from langchain_core.runnables import RunnablePassthrough
|
15 |
from langchain_core.output_parsers import StrOutputParser
|
16 |
from langchain_community.callbacks import get_openai_callback
|
|
|
17 |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
18 |
from langchain_community.vectorstores.utils import DistanceStrategy
|
19 |
from langchain_openai import ChatOpenAI
|
@@ -28,6 +33,8 @@ st.set_page_config(layout="wide", page_title="LegisQA")
|
|
28 |
os.environ["LANGCHAIN_API_KEY"] = st.secrets["langchain_api_key"]
|
29 |
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
30 |
os.environ["LANGCHAIN_PROJECT"] = st.secrets["langchain_project"]
|
|
|
|
|
31 |
|
32 |
SS = st.session_state
|
33 |
SEED = 292764
|
@@ -54,34 +61,6 @@ ANTHROPIC_CHAT_MODELS = [
|
|
54 |
]
|
55 |
CHAT_MODELS = OPENAI_CHAT_MODELS + ANTHROPIC_CHAT_MODELS
|
56 |
|
57 |
-
PREAMBLE = "You are an expert analyst. Use the following excerpts from US congressional legislation to respond to the user's query."
|
58 |
-
PROMPT_TEMPLATES = {
|
59 |
-
"v1": PREAMBLE
|
60 |
-
+ """ If you don't know how to respond, just tell the user.
|
61 |
-
|
62 |
-
{context}
|
63 |
-
|
64 |
-
Question: {query}""",
|
65 |
-
"v2": PREAMBLE
|
66 |
-
+ """ Each snippet starts with a header that includes a unique snippet number (snippet_num), a legis_id, and a title. Your response should cite particular snippets using legis_id and title. If you don't know how to respond, just tell the user.
|
67 |
-
|
68 |
-
{context}
|
69 |
-
|
70 |
-
Question: {query}""",
|
71 |
-
"v3": PREAMBLE
|
72 |
-
+ """ Each excerpt starts with a header that includes a legis_id, and a title followed by one or more text snippets. When using text snippets in your response, you should cite the legis_id and title. If you don't know how to respond, just tell the user.
|
73 |
-
|
74 |
-
{context}
|
75 |
-
|
76 |
-
Question: {query}""",
|
77 |
-
"v4": PREAMBLE
|
78 |
-
+ """ The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "title" and "legis_id" in the response. If you don't know how to respond, just tell the user.
|
79 |
-
|
80 |
-
{context}
|
81 |
-
|
82 |
-
Query: {query}""",
|
83 |
-
}
|
84 |
-
|
85 |
|
86 |
def get_sponsor_url(bioguide_id: str) -> str:
|
87 |
return f"https://bioguide.congress.gov/search/bio/{bioguide_id}"
|
@@ -92,10 +71,6 @@ def get_congress_gov_url(congress_num: int, legis_type: str, legis_num: int) ->
|
|
92 |
return f"https://www.congress.gov/bill/{int(congress_num)}th-congress/{lt}/{int(legis_num)}"
|
93 |
|
94 |
|
95 |
-
def get_govtrack_url(congress_num: int, legis_type: str, legis_num: int) -> str:
|
96 |
-
return f"https://www.govtrack.us/congress/bills/{int(congress_num)}/{legis_type}{int(legis_num)}"
|
97 |
-
|
98 |
-
|
99 |
def load_bge_embeddings():
|
100 |
model_name = "BAAI/bge-small-en-v1.5"
|
101 |
model_kwargs = {"device": "cpu"}
|
@@ -156,58 +131,7 @@ def group_docs(docs) -> list[tuple[str, list[Document]]]:
|
|
156 |
return doc_grps
|
157 |
|
158 |
|
159 |
-
def
|
160 |
-
"""Simple double new line join"""
|
161 |
-
return "\n\n".join([doc.page_content for doc in docs])
|
162 |
-
|
163 |
-
|
164 |
-
def format_docs_v2(docs):
|
165 |
-
"""Format with snippet_num, legis_id, and title"""
|
166 |
-
|
167 |
-
def format_doc(idoc, doc):
|
168 |
-
return "snippet_num: {}\nlegis_id: {}\ntitle: {}\n... {} ...\n".format(
|
169 |
-
idoc,
|
170 |
-
doc.metadata["legis_id"],
|
171 |
-
doc.metadata["title"],
|
172 |
-
doc.page_content,
|
173 |
-
)
|
174 |
-
|
175 |
-
snips = []
|
176 |
-
for idoc, doc in enumerate(docs):
|
177 |
-
txt = format_doc(idoc, doc)
|
178 |
-
snips.append(txt)
|
179 |
-
|
180 |
-
return "\n===\n".join(snips)
|
181 |
-
|
182 |
-
|
183 |
-
def format_docs_v3(docs):
|
184 |
-
|
185 |
-
def format_header(doc):
|
186 |
-
return "legis_id: {}\ntitle: {}".format(
|
187 |
-
doc.metadata["legis_id"],
|
188 |
-
doc.metadata["title"],
|
189 |
-
)
|
190 |
-
|
191 |
-
def format_content(doc):
|
192 |
-
return "... {} ...\n".format(
|
193 |
-
doc.page_content,
|
194 |
-
)
|
195 |
-
|
196 |
-
snips = []
|
197 |
-
doc_grps = group_docs(docs)
|
198 |
-
for legis_id, doc_grp in doc_grps:
|
199 |
-
first_doc = doc_grp[0]
|
200 |
-
head = format_header(first_doc)
|
201 |
-
contents = []
|
202 |
-
for idoc, doc in enumerate(doc_grp):
|
203 |
-
txt = format_content(doc)
|
204 |
-
contents.append(txt)
|
205 |
-
snips.append("{}\n\n{}".format(head, "\n".join(contents)))
|
206 |
-
|
207 |
-
return "\n===\n".join(snips)
|
208 |
-
|
209 |
-
|
210 |
-
def format_docs_v4(docs):
|
211 |
"""JSON grouped"""
|
212 |
|
213 |
doc_grps = group_docs(docs)
|
@@ -216,20 +140,13 @@ def format_docs_v4(docs):
|
|
216 |
dd = {
|
217 |
"legis_id": doc_grp[0].metadata["legis_id"],
|
218 |
"title": doc_grp[0].metadata["title"],
|
|
|
219 |
"snippets": [doc.page_content for doc in doc_grp],
|
220 |
}
|
221 |
out.append(dd)
|
222 |
return json.dumps(out, indent=4)
|
223 |
|
224 |
|
225 |
-
DOC_FORMATTERS = {
|
226 |
-
"v1": format_docs_v1,
|
227 |
-
"v2": format_docs_v2,
|
228 |
-
"v3": format_docs_v3,
|
229 |
-
"v4": format_docs_v4,
|
230 |
-
}
|
231 |
-
|
232 |
-
|
233 |
def escape_markdown(text):
|
234 |
MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
|
235 |
for char in MD_SPECIAL_CHARS:
|
@@ -258,12 +175,6 @@ def render_doc_grp(legis_id: str, doc_grp: list[Document]):
|
|
258 |
)
|
259 |
congress_gov_link = f"[congress.gov]({congress_gov_url})"
|
260 |
|
261 |
-
gov_track_url = get_govtrack_url(
|
262 |
-
first_doc.metadata["congress_num"],
|
263 |
-
first_doc.metadata["legis_type"],
|
264 |
-
first_doc.metadata["legis_num"],
|
265 |
-
)
|
266 |
-
gov_track_link = f"[govtrack.us]({gov_track_url})"
|
267 |
|
268 |
ref = "{} chunks from {}\n\n{}\n\n{}\n\n[{} ({}) ]({})".format(
|
269 |
len(doc_grp),
|
@@ -392,26 +303,31 @@ def render_sidebar():
|
|
392 |
key="filter_sponsor_parties",
|
393 |
)
|
394 |
|
395 |
-
with st.expander("Prompt Config"):
|
396 |
-
st.selectbox(
|
397 |
-
label="prompt version",
|
398 |
-
options=PROMPT_TEMPLATES.keys(),
|
399 |
-
index=3,
|
400 |
-
key="prompt_version",
|
401 |
-
)
|
402 |
-
st.text_area(
|
403 |
-
"prompt template",
|
404 |
-
PROMPT_TEMPLATES[SS["prompt_version"]],
|
405 |
-
height=300,
|
406 |
-
key="prompt_template",
|
407 |
-
)
|
408 |
-
|
409 |
|
410 |
-
def
|
411 |
|
412 |
render_example_queries()
|
413 |
|
414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
415 |
st.text_area("Enter query:", key="query")
|
416 |
query_submitted = st.form_submit_button("Submit")
|
417 |
|
@@ -421,7 +337,7 @@ def render_query_tab():
|
|
421 |
retriever = vectorstore.as_retriever(
|
422 |
search_kwargs={"k": SS["n_ret_docs"], "filter": vs_filter},
|
423 |
)
|
424 |
-
|
425 |
rag_chain = (
|
426 |
RunnableParallel(
|
427 |
{
|
@@ -430,7 +346,7 @@ def render_query_tab():
|
|
430 |
}
|
431 |
)
|
432 |
.assign(context=(lambda x: format_docs(x["docs"])))
|
433 |
-
.assign(
|
434 |
)
|
435 |
|
436 |
if SS["model_name"] in OPENAI_CHAT_MODELS:
|
@@ -442,7 +358,7 @@ def render_query_tab():
|
|
442 |
|
443 |
if "out" in SS:
|
444 |
|
445 |
-
out_display = SS["out"]["
|
446 |
if SS["response_escape_markdown"]:
|
447 |
out_display = escape_markdown(out_display)
|
448 |
if SS["response_add_legis_urls"]:
|
@@ -451,7 +367,7 @@ def render_query_tab():
|
|
451 |
st.write("Response")
|
452 |
st.info(out_display)
|
453 |
|
454 |
-
if
|
455 |
with st.container(border=True):
|
456 |
st.write("API Usage")
|
457 |
st.warning(SS["cb"])
|
@@ -468,24 +384,82 @@ def render_query_tab():
|
|
468 |
st.write(SS["out"])
|
469 |
|
470 |
|
471 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
472 |
|
473 |
-
|
474 |
-
|
475 |
-
|
476 |
-
|
477 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
478 |
[
|
479 |
-
("system",
|
480 |
-
|
481 |
-
("
|
482 |
]
|
483 |
)
|
484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
485 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
486 |
st.write("Coming Soon")
|
487 |
|
488 |
-
|
|
|
489 |
|
490 |
|
491 |
##################
|
@@ -521,16 +495,22 @@ else:
|
|
521 |
|
522 |
|
523 |
vectorstore = load_pinecone_vectorstore()
|
524 |
-
format_docs = DOC_FORMATTERS[SS["prompt_version"]]
|
525 |
|
|
|
|
|
|
|
|
|
|
|
|
|
526 |
|
527 |
-
|
|
|
528 |
|
529 |
-
with
|
530 |
-
|
531 |
|
532 |
-
with
|
533 |
-
|
534 |
|
535 |
-
with
|
536 |
-
|
|
|
1 |
from collections import defaultdict
|
2 |
import json
|
|
|
3 |
import os
|
4 |
import re
|
5 |
|
6 |
+
from langchain.tools.retriever import create_retriever_tool
|
7 |
+
from langchain.agents import AgentExecutor
|
8 |
+
from langchain.agents import create_openai_tools_agent
|
9 |
+
from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
|
10 |
+
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
|
11 |
from langchain_core.documents import Document
|
12 |
from langchain_core.prompts import PromptTemplate
|
13 |
from langchain_core.prompts import ChatPromptTemplate
|
|
|
18 |
from langchain_core.runnables import RunnablePassthrough
|
19 |
from langchain_core.output_parsers import StrOutputParser
|
20 |
from langchain_community.callbacks import get_openai_callback
|
21 |
+
from langchain_community.callbacks import StreamlitCallbackHandler
|
22 |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
23 |
from langchain_community.vectorstores.utils import DistanceStrategy
|
24 |
from langchain_openai import ChatOpenAI
|
|
|
33 |
os.environ["LANGCHAIN_API_KEY"] = st.secrets["langchain_api_key"]
|
34 |
os.environ["LANGCHAIN_TRACING_V2"] = "true"
|
35 |
os.environ["LANGCHAIN_PROJECT"] = st.secrets["langchain_project"]
|
36 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
37 |
+
|
38 |
|
39 |
SS = st.session_state
|
40 |
SEED = 292764
|
|
|
61 |
]
|
62 |
CHAT_MODELS = OPENAI_CHAT_MODELS + ANTHROPIC_CHAT_MODELS
|
63 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
|
65 |
def get_sponsor_url(bioguide_id: str) -> str:
    """Return the bioguide.congress.gov biography URL for ``bioguide_id``."""
    base_url = "https://bioguide.congress.gov/search/bio"
    return "{}/{}".format(base_url, bioguide_id)
|
|
|
71 |
return f"https://www.congress.gov/bill/{int(congress_num)}th-congress/{lt}/{int(legis_num)}"
|
72 |
|
73 |
|
|
|
|
|
|
|
|
|
74 |
def load_bge_embeddings():
|
75 |
model_name = "BAAI/bge-small-en-v1.5"
|
76 |
model_kwargs = {"device": "cpu"}
|
|
|
131 |
return doc_grps
|
132 |
|
133 |
|
134 |
+
def format_docs(docs):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
"""JSON grouped"""
|
136 |
|
137 |
doc_grps = group_docs(docs)
|
|
|
140 |
dd = {
|
141 |
"legis_id": doc_grp[0].metadata["legis_id"],
|
142 |
"title": doc_grp[0].metadata["title"],
|
143 |
+
"sponsor": doc_grp[0].metadata["sponsor_full_name"],
|
144 |
"snippets": [doc.page_content for doc in doc_grp],
|
145 |
}
|
146 |
out.append(dd)
|
147 |
return json.dumps(out, indent=4)
|
148 |
|
149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
150 |
def escape_markdown(text):
|
151 |
MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
|
152 |
for char in MD_SPECIAL_CHARS:
|
|
|
175 |
)
|
176 |
congress_gov_link = f"[congress.gov]({congress_gov_url})"
|
177 |
|
|
|
|
|
|
|
|
|
|
|
|
|
178 |
|
179 |
ref = "{} chunks from {}\n\n{}\n\n{}\n\n[{} ({}) ]({})".format(
|
180 |
len(doc_grp),
|
|
|
303 |
key="filter_sponsor_parties",
|
304 |
)
|
305 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
306 |
|
307 |
+
def render_query_rag_tab():
|
308 |
|
309 |
render_example_queries()
|
310 |
|
311 |
+
QUERY_TEMPLATE = """Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "title", "legis_id", and "sponsor" in the response. If you don't know how to respond, just tell the user.
|
312 |
+
|
313 |
+
---
|
314 |
+
|
315 |
+
Congressional Legislation Excerpts:
|
316 |
+
|
317 |
+
{context}
|
318 |
+
|
319 |
+
---
|
320 |
+
|
321 |
+
Query: {query}"""
|
322 |
+
|
323 |
+
prompt = ChatPromptTemplate.from_messages(
|
324 |
+
[
|
325 |
+
("system", "You are an expert legislative analyst."),
|
326 |
+
("human", QUERY_TEMPLATE),
|
327 |
+
]
|
328 |
+
)
|
329 |
+
|
330 |
+
with st.form("query_form"):
|
331 |
st.text_area("Enter query:", key="query")
|
332 |
query_submitted = st.form_submit_button("Submit")
|
333 |
|
|
|
337 |
retriever = vectorstore.as_retriever(
|
338 |
search_kwargs={"k": SS["n_ret_docs"], "filter": vs_filter},
|
339 |
)
|
340 |
+
|
341 |
rag_chain = (
|
342 |
RunnableParallel(
|
343 |
{
|
|
|
346 |
}
|
347 |
)
|
348 |
.assign(context=(lambda x: format_docs(x["docs"])))
|
349 |
+
.assign(output=prompt | llm | StrOutputParser())
|
350 |
)
|
351 |
|
352 |
if SS["model_name"] in OPENAI_CHAT_MODELS:
|
|
|
358 |
|
359 |
if "out" in SS:
|
360 |
|
361 |
+
out_display = SS["out"]["output"]
|
362 |
if SS["response_escape_markdown"]:
|
363 |
out_display = escape_markdown(out_display)
|
364 |
if SS["response_add_legis_urls"]:
|
|
|
367 |
st.write("Response")
|
368 |
st.info(out_display)
|
369 |
|
370 |
+
if "cb" in SS:
|
371 |
with st.container(border=True):
|
372 |
st.write("API Usage")
|
373 |
st.warning(SS["cb"])
|
|
|
384 |
st.write(SS["out"])
|
385 |
|
386 |
|
387 |
+
def render_query_agent_tab():
    """Render the single-query agent tab.

    Builds a ReAct agent (prompt pulled from the LangChain hub) over three
    tools -- a legislation retriever, Wikipedia search, and DuckDuckGo
    search -- and runs it on whatever the user types into the chat input.
    A ``StreamlitCallbackHandler`` bound to a container is passed as a
    callback for the run, and the agent's final ``output`` is written to
    the page.

    Relies on module-level names: ``SS``, ``OPENAI_CHAT_MODELS``,
    ``vectorstore``, ``llm``, ``format_docs`` and ``get_vectorstore_filter``.
    Returns early with a notice when the selected model is not an OpenAI
    chat model.
    """
    from custom_tools import get_retriever_tool

    from langchain_community.tools import WikipediaQueryRun
    from langchain_community.utilities import WikipediaAPIWrapper
    from langchain.agents import load_tools
    from langchain.agents import create_react_agent
    from langchain import hub

    if SS["model_name"] not in OPENAI_CHAT_MODELS:
        st.write("only supported with OpenAI for now")
        return

    # Retriever honours the sidebar filters and the configured number of docs.
    vs_filter = get_vectorstore_filter()
    retriever = vectorstore.as_retriever(
        search_kwargs={"k": SS["n_ret_docs"], "filter": vs_filter},
    )
    legis_retrieval_tool = get_retriever_tool(
        retriever,
        "search_legislation",
        "Searches and returns excerpts from congressional legislation. Always call this tool first.",
        format_docs,
    )

    api_wrapper = WikipediaAPIWrapper(top_k_results=4, doc_content_chars_max=800)
    wiki_search_tool = WikipediaQueryRun(api_wrapper=api_wrapper)

    ddg_tool = load_tools(["ddg-search"])[0]

    tools = [legis_retrieval_tool, wiki_search_tool, ddg_tool]

    # Only the ReAct agent is constructed. An OpenAI tool-calling pipeline
    # (llm.bind_tools + OpenAIToolsAgentOutputParser) used to be assembled
    # here as well, but it was immediately overwritten by the ReAct agent
    # and never executed, so it has been dropped.
    prompt = hub.pull("hwchase17/react")
    agent = create_react_agent(llm, tools, prompt)
    agent_executor = AgentExecutor(
        agent=agent,
        tools=tools,
        return_intermediate_steps=True,
        handle_parsing_errors=True,
        verbose=True,
    )

    if user_input := st.chat_input(key="single_query_agent_input"):
        st.chat_message("user").write(user_input)
        with st.chat_message("assistant"):
            st_callback = StreamlitCallbackHandler(st.container())
            response = agent_executor.invoke(
                {"input": user_input},
                {"callbacks": [st_callback]},
            )
            st.write(response["output"])
|
456 |
+
|
457 |
+
|
458 |
+
def render_chat_agent_tab():
    """Placeholder for the chat-agent tab; it only shows a notice for now."""
    notice = "Coming Soon"
    st.write(notice)
|
460 |
|
461 |
+
|
462 |
+
|
463 |
|
464 |
|
465 |
##################
|
|
|
495 |
|
496 |
|
497 |
vectorstore = load_pinecone_vectorstore()
|
|
|
498 |
|
499 |
+
query_rag_tab, query_agent_tab, chat_agent_tab, guide_tab = st.tabs([
|
500 |
+
"query_rag",
|
501 |
+
"query_agent",
|
502 |
+
"chat_agent",
|
503 |
+
"guide",
|
504 |
+
])
|
505 |
|
506 |
+
with query_rag_tab:
|
507 |
+
render_query_rag_tab()
|
508 |
|
509 |
+
with query_agent_tab:
|
510 |
+
render_query_agent_tab()
|
511 |
|
512 |
+
with chat_agent_tab:
|
513 |
+
render_chat_agent_tab()
|
514 |
|
515 |
+
with guide_tab:
|
516 |
+
render_guide()
|
custom_tools.py
ADDED
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
TODO clean all this up
|
3 |
+
modified from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/tools/retriever.py
|
4 |
+
"""
|
5 |
+
|
6 |
+
from functools import partial
|
7 |
+
from typing import Optional
|
8 |
+
|
9 |
+
from langchain_core.callbacks.manager import Callbacks
|
10 |
+
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
11 |
+
from langchain_core.pydantic_v1 import BaseModel, Field
|
12 |
+
from langchain_core.retrievers import BaseRetriever
|
13 |
+
from langchain.tools import Tool
|
14 |
+
|
15 |
+
|
16 |
+
def get_retriever_tool(
    retriever,
    name,
    description,
    format_docs,
    *,
    document_prompt: Optional[BasePromptTemplate] = None,
    document_separator: str = "\n\n",
):
    """Create a retrieval ``Tool`` whose output is rendered by ``format_docs``.

    Adapted from LangChain's ``create_retriever_tool``; the difference is
    that the retrieved documents are turned into the tool's string result by
    the caller-supplied ``format_docs`` instead of a per-document prompt.

    Args:
        retriever: Retriever used for the lookup. Its
            ``get_relevant_documents`` / ``aget_relevant_documents`` methods
            are called with the query and the run's callbacks.
        name: The name for the tool. This will be passed to the language
            model, so should be unique and somewhat descriptive.
        description: The description for the tool. This will be passed to
            the language model, so should be descriptive.
        format_docs: Callable that receives the retrieved documents and
            returns the string handed back to the agent.
        document_prompt: Accepted for signature compatibility only; it is
            not used because ``format_docs`` does the formatting.
        document_separator: Accepted for signature compatibility only; it is
            not used because ``format_docs`` does the formatting.

    Returns:
        Tool class to pass to an agent.
    """

    class RetrieverInput(BaseModel):
        """Input to the retriever."""

        query: str = Field(description="query to look up in retriever")

    # The helpers close over ``retriever`` and ``format_docs`` directly, so
    # no ``partial`` plumbing of unused arguments is needed.
    def _get_relevant_documents(query: str, callbacks: Callbacks = None) -> str:
        docs = retriever.get_relevant_documents(query, callbacks=callbacks)
        return format_docs(docs)

    async def _aget_relevant_documents(
        query: str, callbacks: Callbacks = None
    ) -> str:
        docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
        return format_docs(docs)

    return Tool(
        name=name,
        description=description,
        func=_get_relevant_documents,
        coroutine=_aget_relevant_documents,
        args_schema=RetrieverInput,
    )
|
requirements.txt
CHANGED
@@ -37,12 +37,14 @@ jsonpatch==1.33
|
|
37 |
jsonpointer==2.4
|
38 |
jsonschema==4.21.1
|
39 |
jsonschema-specifications==2023.12.1
|
|
|
40 |
langchain-anthropic==0.1.1
|
41 |
-
langchain-community==0.0.
|
42 |
-
langchain-core==0.1.
|
43 |
langchain-openai==0.0.7
|
44 |
langchain-pinecone==0.0.3
|
45 |
-
|
|
|
46 |
markdown-it-py==3.0.0
|
47 |
MarkupSafe==2.1.5
|
48 |
marshmallow==3.20.2
|
|
|
37 |
jsonpointer==2.4
|
38 |
jsonschema==4.21.1
|
39 |
jsonschema-specifications==2023.12.1
|
40 |
+
langchain==0.1.13
|
41 |
langchain-anthropic==0.1.1
|
42 |
+
langchain-community==0.0.29
|
43 |
+
langchain-core==0.1.36
|
44 |
langchain-openai==0.0.7
|
45 |
langchain-pinecone==0.0.3
|
46 |
+
langchain-text-splitters==0.0.1
|
47 |
+
langsmith==0.1.38
|
48 |
markdown-it-py==3.0.0
|
49 |
MarkupSafe==2.1.5
|
50 |
marshmallow==3.20.2
|