Spaces:
Sleeping
Sleeping
gabrielaltay
commited on
Commit
•
80275c5
1
Parent(s):
7632f66
lets go
Browse files- app.py +307 -8
- requirements.txt +1 -0
app.py
CHANGED
@@ -1,13 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
2 |
-
from langchain_community.vectorstores import Pinecone
|
3 |
from langchain_community.vectorstores.utils import DistanceStrategy
|
4 |
from langchain_openai import ChatOpenAI
|
5 |
-
from
|
|
|
6 |
import streamlit as st
|
7 |
|
8 |
-
st.set_page_config(layout="wide", page_title="LegisQA")
|
9 |
|
|
|
|
|
10 |
|
|
|
11 |
CONGRESS_GOV_TYPE_MAP = {
|
12 |
"hconres": "house-concurrent-resolution",
|
13 |
"hjres": "house-joint-resolution",
|
@@ -25,6 +35,48 @@ OPENAI_CHAT_MODELS = [
|
|
25 |
]
|
26 |
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
def load_bge_embeddings():
|
29 |
model_name = "BAAI/bge-small-en-v1.5"
|
30 |
model_kwargs = {"device": "cpu"}
|
@@ -40,9 +92,9 @@ def load_bge_embeddings():
|
|
40 |
|
41 |
def load_pinecone_vectorstore():
|
42 |
emb_fn = load_bge_embeddings()
|
43 |
-
pc =
|
44 |
index = pc.Index(st.secrets["pinecone_index_name"])
|
45 |
-
vectorstore =
|
46 |
index=index,
|
47 |
embedding=emb_fn,
|
48 |
text_key="text",
|
@@ -51,8 +103,255 @@ def load_pinecone_vectorstore():
|
|
51 |
return vectorstore
|
52 |
|
53 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
vectorstore = load_pinecone_vectorstore()
|
56 |
-
|
57 |
-
|
58 |
-
st.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import defaultdict
|
2 |
+
import json
|
3 |
+
|
4 |
+
from langchain_core.documents import Document
|
5 |
+
from langchain_core.prompts import PromptTemplate
|
6 |
+
from langchain_core.runnables import RunnableParallel
|
7 |
+
from langchain_core.runnables import RunnablePassthrough
|
8 |
+
from langchain_core.output_parsers import StrOutputParser
|
9 |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
|
|
10 |
from langchain_community.vectorstores.utils import DistanceStrategy
|
11 |
from langchain_openai import ChatOpenAI
|
12 |
+
from langchain_pinecone import PineconeVectorStore
|
13 |
+
from pinecone import Pinecone
|
14 |
import streamlit as st
|
15 |
|
|
|
16 |
|
17 |
+
st.set_page_config(layout="wide", page_title="LegisQA")
|
18 |
+
SS = st.session_state
|
19 |
|
20 |
+
SEED = 292764
|
21 |
CONGRESS_GOV_TYPE_MAP = {
|
22 |
"hconres": "house-concurrent-resolution",
|
23 |
"hjres": "house-joint-resolution",
|
|
|
35 |
]
|
36 |
|
37 |
|
38 |
+
PREAMBLE = "You are an expert analyst. Use the following excerpts from US congressional legislation to respond to the user's query."
|
39 |
+
PROMPT_TEMPLATES = {
|
40 |
+
"v1": PREAMBLE
|
41 |
+
+ """ If you don't know how to respond, just tell the user.
|
42 |
+
|
43 |
+
{context}
|
44 |
+
|
45 |
+
Question: {question}""",
|
46 |
+
"v2": PREAMBLE
|
47 |
+
+ """ Each snippet starts with a header that includes a unique snippet number (snippet_num), a legis_id, and a title. Your response should reference particular snippets using legis_id and title. If you don't know how to respond, just tell the user.
|
48 |
+
|
49 |
+
{context}
|
50 |
+
|
51 |
+
Question: {question}""",
|
52 |
+
"v3": PREAMBLE
|
53 |
+
+ """ Each excerpt starts with a header that includes a legis_id, and a title followed by one or more text snippets. When using text snippets in your response, you should mention the legis_id and title. If you don't know how to respond, just tell the user.
|
54 |
+
|
55 |
+
{context}
|
56 |
+
|
57 |
+
Question: {question}""",
|
58 |
+
"v4": PREAMBLE
|
59 |
+
+ """ The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", and "snippets" keys. If a snippet is useful in writing part of your response, then mention the "title" and "legis_id" inline as you write. If you don't know how to respond, just tell the user.
|
60 |
+
|
61 |
+
{context}
|
62 |
+
|
63 |
+
Query: {question}""",
|
64 |
+
}
|
65 |
+
|
66 |
+
|
67 |
+
def get_sponsor_url(bioguide_id: str) -> str:
|
68 |
+
return f"https://bioguide.congress.gov/search/bio/{bioguide_id}"
|
69 |
+
|
70 |
+
|
71 |
+
def get_congress_gov_url(congress_num: int, legis_type: str, legis_num: int) -> str:
|
72 |
+
lt = CONGRESS_GOV_TYPE_MAP[legis_type]
|
73 |
+
return f"https://www.congress.gov/bill/{int(congress_num)}th-congress/{lt}/{int(legis_num)}"
|
74 |
+
|
75 |
+
|
76 |
+
def get_govtrack_url(congress_num: int, legis_type: str, legis_num: int) -> str:
|
77 |
+
return f"https://www.govtrack.us/congress/bills/{int(congress_num)}/{legis_type}{int(legis_num)}"
|
78 |
+
|
79 |
+
|
80 |
def load_bge_embeddings():
|
81 |
model_name = "BAAI/bge-small-en-v1.5"
|
82 |
model_kwargs = {"device": "cpu"}
|
|
|
92 |
|
93 |
def load_pinecone_vectorstore():
|
94 |
emb_fn = load_bge_embeddings()
|
95 |
+
pc = Pinecone(api_key=st.secrets["pinecone_api_key"])
|
96 |
index = pc.Index(st.secrets["pinecone_index_name"])
|
97 |
+
vectorstore = PineconeVectorStore(
|
98 |
index=index,
|
99 |
embedding=emb_fn,
|
100 |
text_key="text",
|
|
|
103 |
return vectorstore
|
104 |
|
105 |
|
106 |
+
def write_outreach_links():
|
107 |
+
nomic_base_url = "https://atlas.nomic.ai/data/gabrielhyperdemocracy"
|
108 |
+
nomic_map_name = "us-congressional-legislation-s1024o256nomic"
|
109 |
+
nomic_url = f"{nomic_base_url}/{nomic_map_name}/map"
|
110 |
+
hf_url = "https://huggingface.co/hyperdemocracy"
|
111 |
+
st.subheader(":brain: Learn about [hyperdemocracy](https://hyperdemocracy.us)")
|
112 |
+
st.subheader(f":world_map: Visualize with [nomic atlas]({nomic_url})")
|
113 |
+
st.subheader(f":hugging_face: Explore the [huggingface datasets](hf_url)")
|
114 |
+
|
115 |
+
|
116 |
+
def group_docs(docs) -> list[tuple[str, list[Document]]]:
|
117 |
+
doc_grps = defaultdict(list)
|
118 |
+
|
119 |
+
# create legis_id groups
|
120 |
+
for doc in docs:
|
121 |
+
doc_grps[doc.metadata["legis_id"]].append(doc)
|
122 |
+
|
123 |
+
# sort docs in each group by start index
|
124 |
+
for legis_id in doc_grps.keys():
|
125 |
+
doc_grps[legis_id] = sorted(
|
126 |
+
doc_grps[legis_id],
|
127 |
+
key=lambda x: x.metadata["start_index"],
|
128 |
+
)
|
129 |
+
|
130 |
+
# sort groups by number of docs
|
131 |
+
doc_grps = sorted(
|
132 |
+
tuple(doc_grps.items()),
|
133 |
+
key=lambda x: -len(x[1]),
|
134 |
+
)
|
135 |
+
|
136 |
+
return doc_grps
|
137 |
+
|
138 |
+
|
139 |
+
def format_docs_v1(docs):
|
140 |
+
"""Simple double new line join"""
|
141 |
+
return "\n\n".join([doc.page_content for doc in docs])
|
142 |
+
|
143 |
+
|
144 |
+
def format_docs_v2(docs):
|
145 |
+
"""Format with snippet_num, legis_id, and title"""
|
146 |
+
|
147 |
+
def format_doc(idoc, doc):
|
148 |
+
return "snippet_num: {}\nlegis_id: {}\ntitle: {}\n... {} ...\n".format(
|
149 |
+
idoc,
|
150 |
+
doc.metadata["legis_id"],
|
151 |
+
doc.metadata["title"],
|
152 |
+
doc.page_content,
|
153 |
+
)
|
154 |
+
|
155 |
+
snips = []
|
156 |
+
for idoc, doc in enumerate(docs):
|
157 |
+
txt = format_doc(idoc, doc)
|
158 |
+
snips.append(txt)
|
159 |
+
|
160 |
+
return "\n===\n".join(snips)
|
161 |
+
|
162 |
+
|
163 |
+
def format_docs_v3(docs):
|
164 |
+
|
165 |
+
def format_header(doc):
|
166 |
+
return "legis_id: {}\ntitle: {}".format(
|
167 |
+
doc.metadata["legis_id"],
|
168 |
+
doc.metadata["title"],
|
169 |
+
)
|
170 |
+
|
171 |
+
def format_content(doc):
|
172 |
+
return "... {} ...\n".format(
|
173 |
+
doc.page_content,
|
174 |
+
)
|
175 |
+
|
176 |
+
snips = []
|
177 |
+
doc_grps = group_docs(docs)
|
178 |
+
for legis_id, doc_grp in doc_grps:
|
179 |
+
first_doc = doc_grp[0]
|
180 |
+
head = format_header(first_doc)
|
181 |
+
contents = []
|
182 |
+
for idoc, doc in enumerate(doc_grp):
|
183 |
+
txt = format_content(doc)
|
184 |
+
contents.append(txt)
|
185 |
+
snips.append("{}\n\n{}".format(head, "\n".join(contents)))
|
186 |
+
|
187 |
+
return "\n===\n".join(snips)
|
188 |
+
|
189 |
+
|
190 |
+
def format_docs_v4(docs):
|
191 |
+
"""JSON grouped"""
|
192 |
+
|
193 |
+
doc_grps = group_docs(docs)
|
194 |
+
out = []
|
195 |
+
for legis_id, doc_grp in doc_grps:
|
196 |
+
dd = {
|
197 |
+
"legis_id": doc_grp[0].metadata["legis_id"],
|
198 |
+
"title": doc_grp[0].metadata["title"],
|
199 |
+
"snippets": [doc.page_content for doc in doc_grp],
|
200 |
+
}
|
201 |
+
out.append(dd)
|
202 |
+
return json.dumps(out, indent=4)
|
203 |
+
|
204 |
+
|
205 |
+
DOC_FORMATTERS = {
|
206 |
+
"v1": format_docs_v1,
|
207 |
+
"v2": format_docs_v2,
|
208 |
+
"v3": format_docs_v3,
|
209 |
+
"v4": format_docs_v4,
|
210 |
+
}
|
211 |
+
|
212 |
+
|
213 |
+
def escape_markdown(text):
|
214 |
+
MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
|
215 |
+
for char in MD_SPECIAL_CHARS:
|
216 |
+
text = text.replace(char, "\\" + char)
|
217 |
+
return text
|
218 |
+
|
219 |
+
|
220 |
+
with st.sidebar:
|
221 |
+
|
222 |
+
with st.container(border=True):
|
223 |
+
write_outreach_links()
|
224 |
+
|
225 |
+
st.checkbox("escape markdown in answer", key="response_escape_markdown")
|
226 |
+
|
227 |
+
with st.expander("Generative Config"):
|
228 |
+
st.selectbox(label="model name", options=OPENAI_CHAT_MODELS, key="model_name")
|
229 |
+
st.slider(
|
230 |
+
"temperature", min_value=0.0, max_value=2.0, value=0.0, key="temperature"
|
231 |
+
)
|
232 |
+
st.slider("top_p", min_value=0.0, max_value=1.0, value=1.0, key="top_p")
|
233 |
+
|
234 |
+
with st.expander("Retrieval Config"):
|
235 |
+
st.slider(
|
236 |
+
"Number of chunks to retrieve",
|
237 |
+
min_value=1,
|
238 |
+
max_value=40,
|
239 |
+
value=10,
|
240 |
+
key="n_ret_docs",
|
241 |
+
)
|
242 |
+
st.text_input("Bill ID (e.g. 118-s-2293)", key="filter_legis_id")
|
243 |
+
st.text_input("Bioguide ID (e.g. R000595)", key="filter_bioguide_id")
|
244 |
+
st.text_input("Congress (e.g. 118)", key="filter_congress_num")
|
245 |
+
|
246 |
+
with st.expander("Prompt Config"):
|
247 |
+
st.selectbox(
|
248 |
+
label="prompt version",
|
249 |
+
options=PROMPT_TEMPLATES.keys(),
|
250 |
+
index=3,
|
251 |
+
key="prompt_version",
|
252 |
+
)
|
253 |
+
st.text_area(
|
254 |
+
"prompt template",
|
255 |
+
PROMPT_TEMPLATES[SS["prompt_version"]],
|
256 |
+
height=300,
|
257 |
+
key="prompt_template",
|
258 |
+
)
|
259 |
+
|
260 |
+
|
261 |
+
llm = ChatOpenAI(
|
262 |
+
model_name=SS["model_name"],
|
263 |
+
temperature=SS["temperature"],
|
264 |
+
openai_api_key=st.secrets["openai_api_key"],
|
265 |
+
model_kwargs={"top_p": SS["top_p"], "seed": SEED},
|
266 |
+
)
|
267 |
|
268 |
vectorstore = load_pinecone_vectorstore()
|
269 |
+
format_docs = DOC_FORMATTERS[SS["prompt_version"]]
|
270 |
+
|
271 |
+
with st.form("my_form"):
|
272 |
+
st.text_area("Enter question:", key="query")
|
273 |
+
query_submitted = st.form_submit_button("Submit")
|
274 |
+
|
275 |
+
|
276 |
+
def get_vectorstore_filter():
|
277 |
+
vs_filter = {}
|
278 |
+
if SS["filter_legis_id"] != "":
|
279 |
+
vs_filter["legis_id"] = SS["filter_legis_id"]
|
280 |
+
if SS["filter_bioguide_id"] != "":
|
281 |
+
vs_filter["sponsor_bioguide_id"] = SS["filter_bioguide_id"]
|
282 |
+
if SS["filter_congress_num"] != "":
|
283 |
+
vs_filter["congress_num"] = int(SS["filter_congress_num"])
|
284 |
+
return vs_filter
|
285 |
+
|
286 |
+
|
287 |
+
if query_submitted:
|
288 |
+
|
289 |
+
vs_filter = get_vectorstore_filter()
|
290 |
+
retriever = vectorstore.as_retriever(
|
291 |
+
search_kwargs={"k": SS["n_ret_docs"], "filter": vs_filter},
|
292 |
+
)
|
293 |
+
prompt = PromptTemplate.from_template(SS["prompt_template"])
|
294 |
+
rag_chain_from_docs = (
|
295 |
+
RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
|
296 |
+
| prompt
|
297 |
+
| llm
|
298 |
+
| StrOutputParser()
|
299 |
+
)
|
300 |
+
rag_chain_with_source = RunnableParallel(
|
301 |
+
{"context": retriever, "question": RunnablePassthrough()}
|
302 |
+
).assign(answer=rag_chain_from_docs)
|
303 |
+
out = rag_chain_with_source.invoke(SS["query"])
|
304 |
+
SS["out"] = out
|
305 |
+
|
306 |
+
|
307 |
+
def write_doc_grp(legis_id: str, doc_grp: list[Document]):
|
308 |
+
first_doc = doc_grp[0]
|
309 |
+
|
310 |
+
congress_gov_url = get_congress_gov_url(
|
311 |
+
first_doc.metadata["congress_num"],
|
312 |
+
first_doc.metadata["legis_type"],
|
313 |
+
first_doc.metadata["legis_num"],
|
314 |
+
)
|
315 |
+
congress_gov_link = f"[congress.gov]({congress_gov_url})"
|
316 |
+
|
317 |
+
gov_track_url = get_govtrack_url(
|
318 |
+
first_doc.metadata["congress_num"],
|
319 |
+
first_doc.metadata["legis_type"],
|
320 |
+
first_doc.metadata["legis_num"],
|
321 |
+
)
|
322 |
+
gov_track_link = f"[govtrack.us]({gov_track_url})"
|
323 |
+
|
324 |
+
ref = "{} chunks from {}\n\n{}\n\n{} | {}\n\n[{} ({}) ]({})".format(
|
325 |
+
len(doc_grp),
|
326 |
+
first_doc.metadata["legis_id"],
|
327 |
+
first_doc.metadata["title"],
|
328 |
+
congress_gov_link,
|
329 |
+
gov_track_link,
|
330 |
+
first_doc.metadata["sponsor_full_name"],
|
331 |
+
first_doc.metadata["sponsor_bioguide_id"],
|
332 |
+
get_sponsor_url(first_doc.metadata["sponsor_bioguide_id"]),
|
333 |
+
)
|
334 |
+
doc_contents = [
|
335 |
+
"[start_index={}] ".format(int(doc.metadata["start_index"])) + doc.page_content
|
336 |
+
for doc in doc_grp
|
337 |
+
]
|
338 |
+
with st.expander(ref):
|
339 |
+
st.write(escape_markdown("\n\n...\n\n".join(doc_contents)))
|
340 |
+
|
341 |
+
|
342 |
+
out = SS.get("out")
|
343 |
+
if out:
|
344 |
+
|
345 |
+
if SS["response_escape_markdown"]:
|
346 |
+
st.info(escape_markdown(out["answer"]))
|
347 |
+
else:
|
348 |
+
st.info(out["answer"])
|
349 |
+
|
350 |
+
doc_grps = group_docs(out["context"])
|
351 |
+
for legis_id, doc_grp in doc_grps:
|
352 |
+
write_doc_grp(legis_id, doc_grp)
|
353 |
+
|
354 |
+
with st.expander("Debug doc format"):
|
355 |
+
|
356 |
+
st.text_area("formatted docs", value=format_docs(out["context"]), height=600)
|
357 |
+
# st.write(json.loads(format_docs(out["context"])))
|
requirements.txt
CHANGED
@@ -33,6 +33,7 @@ jsonschema-specifications==2023.12.1
|
|
33 |
langchain-community==0.0.24
|
34 |
langchain-core==0.1.26
|
35 |
langchain-openai==0.0.7
|
|
|
36 |
langsmith==0.1.7
|
37 |
markdown-it-py==3.0.0
|
38 |
MarkupSafe==2.1.5
|
|
|
33 |
langchain-community==0.0.24
|
34 |
langchain-core==0.1.26
|
35 |
langchain-openai==0.0.7
|
36 |
+
langchain-pinecone==0.0.3
|
37 |
langsmith==0.1.7
|
38 |
markdown-it-py==3.0.0
|
39 |
MarkupSafe==2.1.5
|