Spaces:
Sleeping
Sleeping
gabrielaltay
commited on
Commit
•
723ae91
1
Parent(s):
2029299
update
Browse files
app.py
CHANGED
@@ -1,11 +1,13 @@
|
|
1 |
from collections import defaultdict
|
2 |
import json
|
|
|
3 |
|
4 |
from langchain_core.documents import Document
|
5 |
from langchain_core.prompts import PromptTemplate
|
6 |
from langchain_core.runnables import RunnableParallel
|
7 |
from langchain_core.runnables import RunnablePassthrough
|
8 |
from langchain_core.output_parsers import StrOutputParser
|
|
|
9 |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
10 |
from langchain_community.vectorstores.utils import DistanceStrategy
|
11 |
from langchain_openai import ChatOpenAI
|
@@ -19,6 +21,7 @@ SS = st.session_state
|
|
19 |
|
20 |
SEED = 292764
|
21 |
CONGRESS_NUMBERS = [113, 114, 115, 116, 117, 118]
|
|
|
22 |
CONGRESS_GOV_TYPE_MAP = {
|
23 |
"hconres": "house-concurrent-resolution",
|
24 |
"hjres": "house-joint-resolution",
|
@@ -29,7 +32,6 @@ CONGRESS_GOV_TYPE_MAP = {
|
|
29 |
"sjres": "senate-joint-resolution",
|
30 |
"sres": "senate-resolution",
|
31 |
}
|
32 |
-
|
33 |
OPENAI_CHAT_MODELS = [
|
34 |
"gpt-3.5-turbo-0125",
|
35 |
"gpt-4-0125-preview",
|
@@ -115,6 +117,7 @@ def write_outreach_links():
|
|
115 |
st.subheader(f":hugging_face: Raw [huggingface datasets]({hf_url})")
|
116 |
st.subheader(f":evergreen_tree: Index [pinecone serverless]({pc_url})")
|
117 |
|
|
|
118 |
def group_docs(docs) -> list[tuple[str, list[Document]]]:
|
119 |
doc_grps = defaultdict(list)
|
120 |
|
@@ -219,15 +222,96 @@ def escape_markdown(text):
|
|
219 |
return text
|
220 |
|
221 |
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
227 |
|
228 |
-
st.
|
|
|
|
|
229 |
|
230 |
-
st.write("""
|
231 |
```
|
232 |
What are the themes around artificial intelligence?
|
233 |
```
|
@@ -239,8 +323,15 @@ Write a well cited 3 paragraph essay on food insecurity.
|
|
239 |
```
|
240 |
Create a table summarizing the major climate change ideas with columns legis_id, title, idea.
|
241 |
```
|
242 |
-
|
243 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
244 |
|
245 |
|
246 |
with st.sidebar:
|
@@ -249,6 +340,7 @@ with st.sidebar:
|
|
249 |
write_outreach_links()
|
250 |
|
251 |
st.checkbox("escape markdown in answer", key="response_escape_markdown")
|
|
|
252 |
|
253 |
with st.expander("Generative Config"):
|
254 |
st.selectbox(label="model name", options=OPENAI_CHAT_MODELS, key="model_name")
|
@@ -261,20 +353,24 @@ with st.sidebar:
|
|
261 |
st.slider(
|
262 |
"Number of chunks to retrieve",
|
263 |
min_value=1,
|
264 |
-
max_value=
|
265 |
-
value=
|
266 |
key="n_ret_docs",
|
267 |
)
|
268 |
st.text_input("Bill ID (e.g. 118-s-2293)", key="filter_legis_id")
|
269 |
st.text_input("Bioguide ID (e.g. R000595)", key="filter_bioguide_id")
|
270 |
-
# st.text_input("Congress (e.g. 118)", key="filter_congress_num")
|
271 |
st.multiselect(
|
272 |
"Congress Numbers",
|
273 |
CONGRESS_NUMBERS,
|
274 |
default=CONGRESS_NUMBERS,
|
275 |
key="filter_congress_nums",
|
276 |
)
|
277 |
-
|
|
|
|
|
|
|
|
|
|
|
278 |
|
279 |
with st.expander("Prompt Config"):
|
280 |
st.selectbox(
|
@@ -297,97 +393,65 @@ llm = ChatOpenAI(
|
|
297 |
openai_api_key=st.secrets["openai_api_key"],
|
298 |
model_kwargs={"top_p": SS["top_p"], "seed": SEED},
|
299 |
)
|
300 |
-
|
301 |
vectorstore = load_pinecone_vectorstore()
|
302 |
format_docs = DOC_FORMATTERS[SS["prompt_version"]]
|
|
|
303 |
|
304 |
-
|
305 |
-
st.text_area("Enter query:", key="query")
|
306 |
-
query_submitted = st.form_submit_button("Submit")
|
307 |
|
308 |
-
|
309 |
-
|
310 |
-
vs_filter = {}
|
311 |
-
if SS["filter_legis_id"] != "":
|
312 |
-
vs_filter["legis_id"] = SS["filter_legis_id"]
|
313 |
-
if SS["filter_bioguide_id"] != "":
|
314 |
-
vs_filter["sponsor_bioguide_id"] = SS["filter_bioguide_id"]
|
315 |
-
# if SS["filter_congress_num"] != "":
|
316 |
-
# vs_filter["congress_num"] = int(SS["filter_congress_num"])
|
317 |
-
vs_filter = {"congress_num": {"$in": SS["filter_congress_nums"]}}
|
318 |
-
return vs_filter
|
319 |
|
320 |
|
321 |
-
|
322 |
|
323 |
-
|
324 |
-
with st.sidebar:
|
325 |
-
with st.expander("Debug vs_filter"):
|
326 |
-
st.write(vs_filter)
|
327 |
-
retriever = vectorstore.as_retriever(
|
328 |
-
search_kwargs={"k": SS["n_ret_docs"], "filter": vs_filter},
|
329 |
-
)
|
330 |
-
prompt = PromptTemplate.from_template(SS["prompt_template"])
|
331 |
-
rag_chain_from_docs = (
|
332 |
-
RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
|
333 |
-
| prompt
|
334 |
-
| llm
|
335 |
-
| StrOutputParser()
|
336 |
-
)
|
337 |
-
rag_chain_with_source = RunnableParallel(
|
338 |
-
{"context": retriever, "question": RunnablePassthrough()}
|
339 |
-
).assign(answer=rag_chain_from_docs)
|
340 |
-
out = rag_chain_with_source.invoke(SS["query"])
|
341 |
-
SS["out"] = out
|
342 |
|
|
|
|
|
|
|
343 |
|
344 |
-
|
345 |
-
first_doc = doc_grp[0]
|
346 |
-
|
347 |
-
congress_gov_url = get_congress_gov_url(
|
348 |
-
first_doc.metadata["congress_num"],
|
349 |
-
first_doc.metadata["legis_type"],
|
350 |
-
first_doc.metadata["legis_num"],
|
351 |
-
)
|
352 |
-
congress_gov_link = f"[congress.gov]({congress_gov_url})"
|
353 |
-
|
354 |
-
gov_track_url = get_govtrack_url(
|
355 |
-
first_doc.metadata["congress_num"],
|
356 |
-
first_doc.metadata["legis_type"],
|
357 |
-
first_doc.metadata["legis_num"],
|
358 |
-
)
|
359 |
-
gov_track_link = f"[govtrack.us]({gov_track_url})"
|
360 |
-
|
361 |
-
ref = "{} chunks from {}\n\n{}\n\n{} | {}\n\n[{} ({}) ]({})".format(
|
362 |
-
len(doc_grp),
|
363 |
-
first_doc.metadata["legis_id"],
|
364 |
-
first_doc.metadata["title"],
|
365 |
-
congress_gov_link,
|
366 |
-
gov_track_link,
|
367 |
-
first_doc.metadata["sponsor_full_name"],
|
368 |
-
first_doc.metadata["sponsor_bioguide_id"],
|
369 |
-
get_sponsor_url(first_doc.metadata["sponsor_bioguide_id"]),
|
370 |
-
)
|
371 |
-
doc_contents = [
|
372 |
-
"[start_index={}] ".format(int(doc.metadata["start_index"])) + doc.page_content
|
373 |
-
for doc in doc_grp
|
374 |
-
]
|
375 |
-
with st.expander(ref):
|
376 |
-
st.write(escape_markdown("\n\n...\n\n".join(doc_contents)))
|
377 |
-
|
378 |
-
|
379 |
-
out = SS.get("out")
|
380 |
-
if out:
|
381 |
-
|
382 |
-
if SS["response_escape_markdown"]:
|
383 |
-
st.info(escape_markdown(out["answer"]))
|
384 |
-
else:
|
385 |
-
st.info(out["answer"])
|
386 |
-
|
387 |
-
doc_grps = group_docs(out["context"])
|
388 |
-
for legis_id, doc_grp in doc_grps:
|
389 |
-
write_doc_grp(legis_id, doc_grp)
|
390 |
-
|
391 |
-
with st.expander("Debug doc format"):
|
392 |
-
st.text_area("formatted docs", value=format_docs(out["context"]), height=600)
|
393 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from collections import defaultdict
|
2 |
import json
|
3 |
+
import re
|
4 |
|
5 |
from langchain_core.documents import Document
|
6 |
from langchain_core.prompts import PromptTemplate
|
7 |
from langchain_core.runnables import RunnableParallel
|
8 |
from langchain_core.runnables import RunnablePassthrough
|
9 |
from langchain_core.output_parsers import StrOutputParser
|
10 |
+
from langchain_community.callbacks import get_openai_callback
|
11 |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
|
12 |
from langchain_community.vectorstores.utils import DistanceStrategy
|
13 |
from langchain_openai import ChatOpenAI
|
|
|
21 |
|
22 |
SEED = 292764
|
23 |
CONGRESS_NUMBERS = [113, 114, 115, 116, 117, 118]
|
24 |
+
SPONSOR_PARTIES = ["D", "R", "L", "I"]
|
25 |
CONGRESS_GOV_TYPE_MAP = {
|
26 |
"hconres": "house-concurrent-resolution",
|
27 |
"hjres": "house-joint-resolution",
|
|
|
32 |
"sjres": "senate-joint-resolution",
|
33 |
"sres": "senate-resolution",
|
34 |
}
|
|
|
35 |
OPENAI_CHAT_MODELS = [
|
36 |
"gpt-3.5-turbo-0125",
|
37 |
"gpt-4-0125-preview",
|
|
|
117 |
st.subheader(f":hugging_face: Raw [huggingface datasets]({hf_url})")
|
118 |
st.subheader(f":evergreen_tree: Index [pinecone serverless]({pc_url})")
|
119 |
|
120 |
+
|
121 |
def group_docs(docs) -> list[tuple[str, list[Document]]]:
|
122 |
doc_grps = defaultdict(list)
|
123 |
|
|
|
222 |
return text
|
223 |
|
224 |
|
225 |
+
def get_vectorstore_filter():
|
226 |
+
vs_filter = {}
|
227 |
+
if SS["filter_legis_id"] != "":
|
228 |
+
vs_filter["legis_id"] = SS["filter_legis_id"]
|
229 |
+
if SS["filter_bioguide_id"] != "":
|
230 |
+
vs_filter["sponsor_bioguide_id"] = SS["filter_bioguide_id"]
|
231 |
+
vs_filter = {**vs_filter, "congress_num": {"$in": SS["filter_congress_nums"]}}
|
232 |
+
vs_filter = {**vs_filter, "sponsor_party": {"$in": SS["filter_sponsor_parties"]}}
|
233 |
+
return vs_filter
|
234 |
+
|
235 |
+
|
236 |
+
def write_doc_grp(legis_id: str, doc_grp: list[Document]):
|
237 |
+
first_doc = doc_grp[0]
|
238 |
+
|
239 |
+
congress_gov_url = get_congress_gov_url(
|
240 |
+
first_doc.metadata["congress_num"],
|
241 |
+
first_doc.metadata["legis_type"],
|
242 |
+
first_doc.metadata["legis_num"],
|
243 |
+
)
|
244 |
+
congress_gov_link = f"[congress.gov]({congress_gov_url})"
|
245 |
+
|
246 |
+
gov_track_url = get_govtrack_url(
|
247 |
+
first_doc.metadata["congress_num"],
|
248 |
+
first_doc.metadata["legis_type"],
|
249 |
+
first_doc.metadata["legis_num"],
|
250 |
+
)
|
251 |
+
gov_track_link = f"[govtrack.us]({gov_track_url})"
|
252 |
+
|
253 |
+
ref = "{} chunks from {}\n\n{}\n\n{}\n\n[{} ({}) ]({})".format(
|
254 |
+
len(doc_grp),
|
255 |
+
first_doc.metadata["legis_id"],
|
256 |
+
first_doc.metadata["title"],
|
257 |
+
congress_gov_link,
|
258 |
+
first_doc.metadata["sponsor_full_name"],
|
259 |
+
first_doc.metadata["sponsor_bioguide_id"],
|
260 |
+
get_sponsor_url(first_doc.metadata["sponsor_bioguide_id"]),
|
261 |
+
)
|
262 |
+
doc_contents = [
|
263 |
+
"[start_index={}] ".format(int(doc.metadata["start_index"])) + doc.page_content
|
264 |
+
for doc in doc_grp
|
265 |
+
]
|
266 |
+
with st.expander(ref):
|
267 |
+
st.write(escape_markdown("\n\n...\n\n".join(doc_contents)))
|
268 |
+
|
269 |
+
|
270 |
+
def legis_id_to_link(legis_id: str) -> str:
|
271 |
+
congress_num, legis_type, legis_num = legis_id.split("-")
|
272 |
+
return get_congress_gov_url(congress_num, legis_type, legis_num)
|
273 |
+
|
274 |
+
|
275 |
+
def legis_id_match_to_link(matchobj):
|
276 |
+
mstring = matchobj.string[matchobj.start() : matchobj.end()]
|
277 |
+
url = legis_id_to_link(mstring)
|
278 |
+
link = f"[{mstring}]({url})"
|
279 |
+
return link
|
280 |
+
|
281 |
+
|
282 |
+
def replace_legis_ids_with_urls(text):
|
283 |
+
pattern = "11[345678]-[a-z]+-\d{1,5}"
|
284 |
+
rtext = re.sub(pattern, legis_id_match_to_link, text)
|
285 |
+
return rtext
|
286 |
+
|
287 |
+
|
288 |
+
def write_guide():
|
289 |
+
|
290 |
+
st.write(
|
291 |
+
"""
|
292 |
+
When you send a query to LegisQA, it will attempt to retrieve relevant content from the past six congresses ([113th-118th](https://en.wikipedia.org/wiki/List_of_United_States_Congresses)) covering 2013 to the present, pass it to a [large language model (LLM)](https://en.wikipedia.org/wiki/Large_language_model), and generate a response. This technique is known as Retrieval Augmented Generation (RAG). You can read [an academic paper](https://proceedings.neurips.cc/paper/2020/hash/6b493230205f780e1bc26945df7481e5-Abstract.html) or [a high level summary](https://research.ibm.com/blog/retrieval-augmented-generation-RAG) to get more details. Once the response is generated, the retrieved content will be available for inspection with links to the bills and sponsors.
|
293 |
+
|
294 |
+
|
295 |
+
## Disclaimer
|
296 |
+
|
297 |
+
This is a research project. The RAG technique helps to ground the LLM response by providing context from a trusted source, but it does not guarantee a high quality response. We encourage you to play around, find questions that work and find questions that fail. There is a small monthly budget dedicated to the OpenAI endpoints. Once that is used up each month, queries will no longer work.
|
298 |
+
|
299 |
+
|
300 |
+
## Sidebar Config
|
301 |
+
|
302 |
+
Use the `Generative Config` to change LLM parameters.
|
303 |
+
Use the `Retrieval Config` to change the number of chunks retrieved from our congress corpus and to apply various filters to the content before it is retrieved (e.g. filter to a specific set of congresses). Use the `Prompt Config` to try out different document formatting and prompting strategies.
|
304 |
+
|
305 |
+
"""
|
306 |
+
)
|
307 |
+
|
308 |
+
|
309 |
+
def write_example_queries():
|
310 |
|
311 |
+
with st.expander("Example Queries"):
|
312 |
+
st.write(
|
313 |
+
"""
|
314 |
|
|
|
315 |
```
|
316 |
What are the themes around artificial intelligence?
|
317 |
```
|
|
|
323 |
```
|
324 |
Create a table summarizing the major climate change ideas with columns legis_id, title, idea.
|
325 |
```
|
326 |
+
|
327 |
+
"""
|
328 |
+
)
|
329 |
+
|
330 |
+
|
331 |
+
##################
|
332 |
+
|
333 |
+
|
334 |
+
st.title(":classical_building: LegisQA :classical_building:")
|
335 |
|
336 |
|
337 |
with st.sidebar:
|
|
|
340 |
write_outreach_links()
|
341 |
|
342 |
st.checkbox("escape markdown in answer", key="response_escape_markdown")
|
343 |
+
st.checkbox("add legis urls in answer", value=True, key="response_add_legis_urls")
|
344 |
|
345 |
with st.expander("Generative Config"):
|
346 |
st.selectbox(label="model name", options=OPENAI_CHAT_MODELS, key="model_name")
|
|
|
353 |
st.slider(
|
354 |
"Number of chunks to retrieve",
|
355 |
min_value=1,
|
356 |
+
max_value=32,
|
357 |
+
value=8,
|
358 |
key="n_ret_docs",
|
359 |
)
|
360 |
st.text_input("Bill ID (e.g. 118-s-2293)", key="filter_legis_id")
|
361 |
st.text_input("Bioguide ID (e.g. R000595)", key="filter_bioguide_id")
|
|
|
362 |
st.multiselect(
|
363 |
"Congress Numbers",
|
364 |
CONGRESS_NUMBERS,
|
365 |
default=CONGRESS_NUMBERS,
|
366 |
key="filter_congress_nums",
|
367 |
)
|
368 |
+
st.multiselect(
|
369 |
+
"Sponsor Party",
|
370 |
+
SPONSOR_PARTIES,
|
371 |
+
default=SPONSOR_PARTIES,
|
372 |
+
key="filter_sponsor_parties",
|
373 |
+
)
|
374 |
|
375 |
with st.expander("Prompt Config"):
|
376 |
st.selectbox(
|
|
|
393 |
openai_api_key=st.secrets["openai_api_key"],
|
394 |
model_kwargs={"top_p": SS["top_p"], "seed": SEED},
|
395 |
)
|
|
|
396 |
vectorstore = load_pinecone_vectorstore()
|
397 |
format_docs = DOC_FORMATTERS[SS["prompt_version"]]
|
398 |
+
vs_filter = get_vectorstore_filter()
|
399 |
|
400 |
+
query_tab, guide_tab = st.tabs(["query", "guide"])
|
|
|
|
|
401 |
|
402 |
+
with guide_tab:
|
403 |
+
write_guide()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
|
405 |
|
406 |
+
with query_tab:
|
407 |
|
408 |
+
write_example_queries()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
409 |
|
410 |
+
with st.form("my_form"):
|
411 |
+
st.text_area("Enter query:", key="query")
|
412 |
+
query_submitted = st.form_submit_button("Submit")
|
413 |
|
414 |
+
if query_submitted:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
415 |
|
416 |
+
retriever = vectorstore.as_retriever(
|
417 |
+
search_kwargs={"k": SS["n_ret_docs"], "filter": vs_filter},
|
418 |
+
)
|
419 |
+
prompt = PromptTemplate.from_template(SS["prompt_template"])
|
420 |
+
rag_chain_from_docs = (
|
421 |
+
RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
|
422 |
+
| prompt
|
423 |
+
| llm
|
424 |
+
| StrOutputParser()
|
425 |
+
)
|
426 |
+
rag_chain_with_source = RunnableParallel(
|
427 |
+
{"context": retriever, "question": RunnablePassthrough()}
|
428 |
+
).assign(answer=rag_chain_from_docs)
|
429 |
+
|
430 |
+
with get_openai_callback() as cb:
|
431 |
+
SS["out"] = rag_chain_with_source.invoke(SS["query"])
|
432 |
+
SS["cb"] = cb
|
433 |
+
|
434 |
+
if "out" in SS:
|
435 |
+
|
436 |
+
out_display = SS["out"]["answer"]
|
437 |
+
if SS["response_escape_markdown"]:
|
438 |
+
out_display = escape_markdown(out_display)
|
439 |
+
if SS["response_add_legis_urls"]:
|
440 |
+
out_display = replace_legis_ids_with_urls(out_display)
|
441 |
+
with st.container(border=True):
|
442 |
+
st.write("Response")
|
443 |
+
st.info(out_display)
|
444 |
+
with st.container(border=True):
|
445 |
+
st.write("API Usage")
|
446 |
+
st.warning(SS["cb"])
|
447 |
+
|
448 |
+
with st.container(border=True):
|
449 |
+
doc_grps = group_docs(SS["out"]["context"])
|
450 |
+
st.write(
|
451 |
+
"Retrieved Chunks (note that you may need to 'right click' on links in the expanders to follow them)"
|
452 |
+
)
|
453 |
+
for legis_id, doc_grp in doc_grps:
|
454 |
+
write_doc_grp(legis_id, doc_grp)
|
455 |
+
|
456 |
+
# with st.expander("Debug doc format"):
|
457 |
+
# st.text_area("formatted docs", value=format_docs(SS["out"]["context"]), height=600)
|