Spaces:
Running
Running
gabrielaltay
commited on
Commit
·
a01d550
1
Parent(s):
cb6c0bd
update
Browse files- app.py +12 -6
- custom_tools.py +0 -98
- retriever_tools.py +79 -0
app.py
CHANGED
@@ -1,3 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from collections import defaultdict
|
2 |
import json
|
3 |
import os
|
@@ -140,6 +145,7 @@ def format_docs(docs):
|
|
140 |
dd = {
|
141 |
"legis_id": doc_grp[0].metadata["legis_id"],
|
142 |
"title": doc_grp[0].metadata["title"],
|
|
|
143 |
"sponsor": doc_grp[0].metadata["sponsor_full_name"],
|
144 |
"snippets": [doc.page_content for doc in doc_grp],
|
145 |
}
|
@@ -308,7 +314,7 @@ def render_query_rag_tab():
|
|
308 |
|
309 |
render_example_queries()
|
310 |
|
311 |
-
QUERY_TEMPLATE = """Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "title", "
|
312 |
|
313 |
---
|
314 |
|
@@ -328,7 +334,7 @@ Query: {query}"""
|
|
328 |
)
|
329 |
|
330 |
with st.form("query_form"):
|
331 |
-
st.text_area("Enter query:", key="query")
|
332 |
query_submitted = st.form_submit_button("Submit")
|
333 |
|
334 |
if query_submitted:
|
@@ -354,6 +360,7 @@ Query: {query}"""
|
|
354 |
SS["out"] = rag_chain.invoke(SS["query"])
|
355 |
SS["cb"] = cb
|
356 |
else:
|
|
|
357 |
SS["out"] = rag_chain.invoke(SS["query"])
|
358 |
|
359 |
if "out" in SS:
|
@@ -386,7 +393,7 @@ Query: {query}"""
|
|
386 |
|
387 |
def render_query_agent_tab():
|
388 |
|
389 |
-
from
|
390 |
|
391 |
from langchain_community.tools import WikipediaQueryRun
|
392 |
from langchain_community.utilities import WikipediaAPIWrapper
|
@@ -465,9 +472,8 @@ def render_chat_agent_tab():
|
|
465 |
##################
|
466 |
|
467 |
|
468 |
-
st.title(
|
469 |
-
|
470 |
-
)
|
471 |
|
472 |
|
473 |
with st.sidebar:
|
|
|
1 |
+
"""
|
2 |
+
TODO: checkout langgraph
|
3 |
+
TODO: clear screen between agent calls (see here https://github.com/langchain-ai/streamlit-agent/blob/main/streamlit_agent/clear_results.py)
|
4 |
+
"""
|
5 |
+
|
6 |
from collections import defaultdict
|
7 |
import json
|
8 |
import os
|
|
|
145 |
dd = {
|
146 |
"legis_id": doc_grp[0].metadata["legis_id"],
|
147 |
"title": doc_grp[0].metadata["title"],
|
148 |
+
"introduced_date": doc_grp[0].metadata["introduced_date"],
|
149 |
"sponsor": doc_grp[0].metadata["sponsor_full_name"],
|
150 |
"snippets": [doc.page_content for doc in doc_grp],
|
151 |
}
|
|
|
314 |
|
315 |
render_example_queries()
|
316 |
|
317 |
+
QUERY_TEMPLATE = """Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
|
318 |
|
319 |
---
|
320 |
|
|
|
334 |
)
|
335 |
|
336 |
with st.form("query_form"):
|
337 |
+
st.text_area("Enter a query that can be answered with congressional legislation:", key="query")
|
338 |
query_submitted = st.form_submit_button("Submit")
|
339 |
|
340 |
if query_submitted:
|
|
|
360 |
SS["out"] = rag_chain.invoke(SS["query"])
|
361 |
SS["cb"] = cb
|
362 |
else:
|
363 |
+
SS.pop("cb", None)
|
364 |
SS["out"] = rag_chain.invoke(SS["query"])
|
365 |
|
366 |
if "out" in SS:
|
|
|
393 |
|
394 |
def render_query_agent_tab():
|
395 |
|
396 |
+
from retriever_tools import get_retriever_tool
|
397 |
|
398 |
from langchain_community.tools import WikipediaQueryRun
|
399 |
from langchain_community.utilities import WikipediaAPIWrapper
|
|
|
472 |
##################
|
473 |
|
474 |
|
475 |
+
st.title(":classical_building: LegisQA :classical_building:")
|
476 |
+
st.header("Chat With Congressional Bills")
|
|
|
477 |
|
478 |
|
479 |
with st.sidebar:
|
custom_tools.py
DELETED
@@ -1,98 +0,0 @@
|
|
1 |
-
"""
|
2 |
-
TODO clean all this up
|
3 |
-
modified from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/tools/retriever.py
|
4 |
-
"""
|
5 |
-
|
6 |
-
from functools import partial
|
7 |
-
from typing import Optional
|
8 |
-
|
9 |
-
from langchain_core.callbacks.manager import Callbacks
|
10 |
-
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
|
11 |
-
from langchain_core.pydantic_v1 import BaseModel, Field
|
12 |
-
from langchain_core.retrievers import BaseRetriever
|
13 |
-
from langchain.tools import Tool
|
14 |
-
|
15 |
-
|
16 |
-
def get_retriever_tool(
|
17 |
-
retriever,
|
18 |
-
name,
|
19 |
-
description,
|
20 |
-
format_docs,
|
21 |
-
*,
|
22 |
-
document_prompt: Optional[BasePromptTemplate] = None,
|
23 |
-
document_separator: str = "\n\n",
|
24 |
-
):
|
25 |
-
|
26 |
-
class RetrieverInput(BaseModel):
|
27 |
-
"""Input to the retriever."""
|
28 |
-
|
29 |
-
query: str = Field(description="query to look up in retriever")
|
30 |
-
|
31 |
-
|
32 |
-
def _get_relevant_documents(
|
33 |
-
query: str,
|
34 |
-
retriever: BaseRetriever,
|
35 |
-
document_prompt: BasePromptTemplate,
|
36 |
-
document_separator: str,
|
37 |
-
callbacks: Callbacks = None,
|
38 |
-
) -> str:
|
39 |
-
docs = retriever.get_relevant_documents(query, callbacks=callbacks)
|
40 |
-
return format_docs(docs)
|
41 |
-
|
42 |
-
async def _aget_relevant_documents(
|
43 |
-
query: str,
|
44 |
-
retriever: BaseRetriever,
|
45 |
-
document_prompt: BasePromptTemplate,
|
46 |
-
document_separator: str,
|
47 |
-
callbacks: Callbacks = None,
|
48 |
-
) -> str:
|
49 |
-
docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
|
50 |
-
return format_docs(docs)
|
51 |
-
|
52 |
-
def create_retriever_tool(
|
53 |
-
retriever: BaseRetriever,
|
54 |
-
name: str,
|
55 |
-
description: str,
|
56 |
-
*,
|
57 |
-
document_prompt: Optional[BasePromptTemplate] = None,
|
58 |
-
document_separator: str = "\n\n",
|
59 |
-
) -> Tool:
|
60 |
-
"""Create a tool to do retrieval of documents.
|
61 |
-
|
62 |
-
Args:
|
63 |
-
retriever: The retriever to use for the retrieval
|
64 |
-
name: The name for the tool. This will be passed to the language model,
|
65 |
-
so should be unique and somewhat descriptive.
|
66 |
-
description: The description for the tool. This will be passed to the language
|
67 |
-
model, so should be descriptive.
|
68 |
-
|
69 |
-
Returns:
|
70 |
-
Tool class to pass to an agent
|
71 |
-
"""
|
72 |
-
document_prompt = document_prompt or PromptTemplate.from_template("{page_content}")
|
73 |
-
func = partial(
|
74 |
-
_get_relevant_documents,
|
75 |
-
retriever=retriever,
|
76 |
-
document_prompt=document_prompt,
|
77 |
-
document_separator=document_separator,
|
78 |
-
)
|
79 |
-
afunc = partial(
|
80 |
-
_aget_relevant_documents,
|
81 |
-
retriever=retriever,
|
82 |
-
document_prompt=document_prompt,
|
83 |
-
document_separator=document_separator,
|
84 |
-
)
|
85 |
-
return Tool(
|
86 |
-
name=name,
|
87 |
-
description=description,
|
88 |
-
func=func,
|
89 |
-
coroutine=afunc,
|
90 |
-
args_schema=RetrieverInput,
|
91 |
-
)
|
92 |
-
|
93 |
-
|
94 |
-
return create_retriever_tool(
|
95 |
-
retriever,
|
96 |
-
name,
|
97 |
-
description,
|
98 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
retriever_tools.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
modified from https://github.com/langchain-ai/langchain/blob/master/libs/langchain/langchain/tools/retriever.py
|
3 |
+
"""
|
4 |
+
|
5 |
+
from functools import partial
|
6 |
+
from typing import Callable
|
7 |
+
from typing import Iterable
|
8 |
+
from typing import Optional
|
9 |
+
|
10 |
+
from langchain.schema import Document
|
11 |
+
from langchain.tools import Tool
|
12 |
+
from langchain_core.callbacks.manager import Callbacks
|
13 |
+
from langchain_core.pydantic_v1 import BaseModel
|
14 |
+
from langchain_core.pydantic_v1 import Field
|
15 |
+
from langchain_core.retrievers import BaseRetriever
|
16 |
+
|
17 |
+
|
18 |
+
class RetrieverInput(BaseModel):
|
19 |
+
"""Input to the retriever."""
|
20 |
+
query: str = Field(description="query to look up in retriever")
|
21 |
+
|
22 |
+
|
23 |
+
def _get_relevant_documents(
|
24 |
+
query: str,
|
25 |
+
retriever: BaseRetriever,
|
26 |
+
format_docs: Callable[[Iterable[Document]], str],
|
27 |
+
callbacks: Callbacks = None,
|
28 |
+
) -> str:
|
29 |
+
docs = retriever.get_relevant_documents(query, callbacks=callbacks)
|
30 |
+
return format_docs(docs)
|
31 |
+
|
32 |
+
|
33 |
+
async def _aget_relevant_documents(
|
34 |
+
query: str,
|
35 |
+
retriever: BaseRetriever,
|
36 |
+
format_docs: Callable[[Iterable[Document]], str],
|
37 |
+
callbacks: Callbacks = None,
|
38 |
+
) -> str:
|
39 |
+
docs = await retriever.aget_relevant_documents(query, callbacks=callbacks)
|
40 |
+
return format_docs(docs)
|
41 |
+
|
42 |
+
|
43 |
+
def get_retriever_tool(
|
44 |
+
retriever: BaseRetriever,
|
45 |
+
name: str,
|
46 |
+
description: str,
|
47 |
+
format_docs: Callable[[Iterable[Document]], str],
|
48 |
+
) -> Tool:
|
49 |
+
|
50 |
+
"""Create a tool to do retrieval of documents.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
retriever: The retriever to use for the retrieval
|
54 |
+
name: The name for the tool. This will be passed to the language model,
|
55 |
+
so should be unique and somewhat descriptive.
|
56 |
+
description: The description for the tool. This will be passed to the language
|
57 |
+
model, so should be descriptive.
|
58 |
+
format_docs: A function to turn an iterable of docs into a string.
|
59 |
+
|
60 |
+
Returns:
|
61 |
+
Tool class to pass to an agent
|
62 |
+
"""
|
63 |
+
func = partial(
|
64 |
+
_get_relevant_documents,
|
65 |
+
retriever=retriever,
|
66 |
+
format_docs=format_docs,
|
67 |
+
)
|
68 |
+
afunc = partial(
|
69 |
+
_aget_relevant_documents,
|
70 |
+
retriever=retriever,
|
71 |
+
format_docs=format_docs,
|
72 |
+
)
|
73 |
+
return Tool(
|
74 |
+
name=name,
|
75 |
+
description=description,
|
76 |
+
func=func,
|
77 |
+
coroutine=afunc,
|
78 |
+
args_schema=RetrieverInput,
|
79 |
+
)
|