Spaces:
Running
Running
from collections import defaultdict | |
import json | |
from langchain_core.documents import Document | |
from langchain_core.prompts import PromptTemplate | |
from langchain_core.runnables import RunnableParallel | |
from langchain_core.runnables import RunnablePassthrough | |
from langchain_core.output_parsers import StrOutputParser | |
from langchain_community.embeddings import HuggingFaceBgeEmbeddings | |
from langchain_community.vectorstores.utils import DistanceStrategy | |
from langchain_openai import ChatOpenAI | |
from langchain_pinecone import PineconeVectorStore | |
from pinecone import Pinecone | |
import streamlit as st | |
st.set_page_config(page_title="LegisQA", layout="wide")
# Short alias for Streamlit's per-session state; widgets below store values under their keys.
SS = st.session_state
# Fixed seed passed to the chat model via model_kwargs.
SEED = 292_764
# Maps a legis_type code to the path segment congress.gov uses in bill URLs.
CONGRESS_GOV_TYPE_MAP = dict(
    hconres="house-concurrent-resolution",
    hjres="house-joint-resolution",
    hr="house-bill",
    hres="house-resolution",
    s="senate-bill",
    sconres="senate-concurrent-resolution",
    sjres="senate-joint-resolution",
    sres="senate-resolution",
)

# Chat models offered in the sidebar "model name" selector.
OPENAI_CHAT_MODELS = ["gpt-3.5-turbo-0125", "gpt-4-0125-preview"]
# Opening instruction shared by every prompt template.
PREAMBLE = (
    "You are an expert analyst. "
    "Use the following excerpts from US congressional legislation "
    "to respond to the user's query."
)

# Closing instruction sentence shared by every prompt template.
_FALLBACK = " If you don't know how to respond, just tell the user."

# Prompt templates keyed by version. Each has {context} and {question}
# placeholders; the {context} layout each one describes is produced by the
# doc formatter with the same version key.
PROMPT_TEMPLATES = {
    "v1": PREAMBLE + _FALLBACK + "\n{context}\nQuestion: {question}",
    "v2": (
        PREAMBLE
        + " Each snippet starts with a header that includes a unique snippet"
        " number (snippet_num), a legis_id, and a title. Your response should"
        " reference particular snippets using legis_id and title."
        + _FALLBACK
        + "\n{context}\nQuestion: {question}"
    ),
    "v3": (
        PREAMBLE
        + " Each excerpt starts with a header that includes a legis_id, and a"
        " title followed by one or more text snippets. When using text"
        " snippets in your response, you should mention the legis_id and"
        " title."
        + _FALLBACK
        + "\n{context}\nQuestion: {question}"
    ),
    "v4": (
        PREAMBLE
        + ' The excerpts are formatted as a JSON list. Each JSON object has'
        ' "legis_id", "title", and "snippets" keys. If a snippet is useful in'
        ' writing part of your response, then mention the "title" and'
        ' "legis_id" inline as you write.'
        + _FALLBACK
        + "\n{context}\nQuery: {question}"
    ),
}
def get_sponsor_url(bioguide_id: str) -> str:
    """Return the bioguide.congress.gov page URL for the given bioguide ID."""
    return "https://bioguide.congress.gov/search/bio/{}".format(bioguide_id)
def congress_ordinal(congress_num: int) -> str:
    """Return *congress_num* as an English ordinal string.

    congress.gov embeds the Congress number in this form ("118th-congress",
    "101st-congress", "102nd-congress"). Floats such as ``118.0`` are
    accepted and truncated via ``int``.

    >>> congress_ordinal(118), congress_ordinal(101), congress_ordinal(112)
    ('118th', '101st', '112th')
    """
    n = int(congress_num)
    # 11th-13th (and 111th-113th, ...) are exceptions to the last-digit rule.
    if 10 <= n % 100 <= 20:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
    return f"{n}{suffix}"


def get_congress_gov_url(congress_num: int, legis_type: str, legis_num: int) -> str:
    """Return the congress.gov bill page URL.

    Args:
        congress_num: Congress number (e.g. 118); floats are truncated.
        legis_type: legislation type code, a key of CONGRESS_GOV_TYPE_MAP
            (e.g. "s", "hr").
        legis_num: bill/resolution number; floats are truncated.

    Raises:
        KeyError: if *legis_type* is not in CONGRESS_GOV_TYPE_MAP.
    """
    lt = CONGRESS_GOV_TYPE_MAP[legis_type]
    # Use a real ordinal suffix; a hard-coded "th" breaks e.g. the 101st/121st Congress.
    return (
        f"https://www.congress.gov/bill/{congress_ordinal(congress_num)}-congress"
        f"/{lt}/{int(legis_num)}"
    )
def get_govtrack_url(congress_num: int, legis_type: str, legis_num: int) -> str:
    """Return the GovTrack.us bill page URL, e.g. ``.../congress/bills/118/s2293``.

    Numeric arguments are passed through ``int`` so float metadata values work.
    """
    congress, number = int(congress_num), int(legis_num)
    return "https://www.govtrack.us/congress/bills/{}/{}{}".format(congress, legis_type, number)
def load_bge_embeddings():
    """Build the BAAI/bge-small-en-v1.5 embedding function used for queries.

    The model runs on CPU with normalized embeddings, and each query is
    prefixed with the BGE retrieval instruction string.
    """
    return HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-small-en-v1.5",
        model_kwargs=dict(device="cpu"),
        encode_kwargs=dict(normalize_embeddings=True),
        query_instruction="Represent this question for searching relevant passages: ",
    )
@st.cache_resource
def load_pinecone_vectorstore():
    """Connect to the configured Pinecone index and wrap it as a LangChain vector store.

    The API key and index name are read from ``st.secrets``. Queries are
    embedded with the BGE embedding function, scored by cosine distance, and
    chunk text is taken from the ``text`` field.

    Cached with ``st.cache_resource`` because Streamlit re-executes this
    script on every widget interaction. Without caching, each rerun reloads
    the embedding model and rebuilds the Pinecone client. The cached object
    is shared by all sessions.
    """
    emb_fn = load_bge_embeddings()
    pc = Pinecone(api_key=st.secrets["pinecone_api_key"])
    index = pc.Index(st.secrets["pinecone_index_name"])
    return PineconeVectorStore(
        index=index,
        embedding=emb_fn,
        text_key="text",
        distance_strategy=DistanceStrategy.COSINE,
    )
def write_outreach_links():
    """Write three linked subheaders: the hyperdemocracy site, a Nomic Atlas map, and the Hugging Face datasets."""
    nomic_base_url = "https://atlas.nomic.ai/data/gabrielhyperdemocracy"
    nomic_map_name = "us-congressional-legislation-s1024o256nomic"
    nomic_url = f"{nomic_base_url}/{nomic_map_name}/map"
    hf_url = "https://huggingface.co/hyperdemocracy"
    st.subheader(":brain: Learn about [hyperdemocracy](https://hyperdemocracy.us)")
    st.subheader(f":world_map: Visualize with [nomic atlas]({nomic_url})")
    # hf_url must be interpolated; without braces the link target was the literal text "hf_url".
    st.subheader(f":hugging_face: Explore the [huggingface datasets]({hf_url})")
def group_docs(docs) -> list[tuple[str, list[Document]]]:
    """Group retrieved chunks by bill.

    Chunks sharing ``metadata["legis_id"]`` are collected together and ordered
    by ``metadata["start_index"]``. Groups are then ordered from most to
    fewest chunks; ties keep the order in which each legis_id first appeared.

    Returns:
        List of ``(legis_id, chunks)`` pairs.
    """
    by_legis_id = defaultdict(list)
    for doc in docs:
        by_legis_id[doc.metadata["legis_id"]].append(doc)

    grouped = [
        (legis_id, sorted(chunks, key=lambda d: d.metadata["start_index"]))
        for legis_id, chunks in by_legis_id.items()
    ]
    # list.sort is stable even with reverse=True, so equal-size groups keep first-seen order.
    grouped.sort(key=lambda pair: len(pair[1]), reverse=True)
    return grouped
def format_docs_v1(docs):
    """Concatenate chunk texts, separated by one blank line."""
    separator = "\n\n"
    return separator.join(d.page_content for d in docs)
def format_docs_v2(docs):
    """Render each chunk with a snippet_num / legis_id / title header.

    Snippets are numbered from 0 in input order, the text is wrapped as
    ``... text ...``, and snippets are separated by ``===`` lines.
    """
    template = "snippet_num: {}\nlegis_id: {}\ntitle: {}\n... {} ...\n"
    return "\n===\n".join(
        template.format(num, doc.metadata["legis_id"], doc.metadata["title"], doc.page_content)
        for num, doc in enumerate(docs)
    )
def format_docs_v3(docs):
    """Render chunks grouped by bill, with one header per bill.

    Each section is a ``legis_id``/``title`` header taken from the group's
    first chunk, a blank line, then that bill's chunks. Each chunk is wrapped
    as ``... text ...``. Sections follow ``group_docs`` order and are
    separated by ``===`` lines.
    """
    sections = []
    for _, chunks in group_docs(docs):
        lead = chunks[0]
        header = "legis_id: {}\ntitle: {}".format(lead.metadata["legis_id"], lead.metadata["title"])
        body = "\n".join("... {} ...\n".format(chunk.page_content) for chunk in chunks)
        sections.append(header + "\n\n" + body)
    return "\n===\n".join(sections)
def format_docs_v4(docs):
    """Render chunks as a 4-space-indented JSON list with one object per bill.

    Each object has ``legis_id`` and ``title`` (from the group's first chunk)
    and ``snippets`` (the group's chunk texts in ``group_docs`` order).
    """
    payload = [
        {
            "legis_id": chunks[0].metadata["legis_id"],
            "title": chunks[0].metadata["title"],
            "snippets": [chunk.page_content for chunk in chunks],
        }
        for _, chunks in group_docs(docs)
    ]
    return json.dumps(payload, indent=4)
# Prompt version -> function that renders retrieved docs into the {context} text.
# Keys match PROMPT_TEMPLATES so a version selects both template and layout.
DOC_FORMATTERS = dict(
    v1=format_docs_v1,
    v2=format_docs_v2,
    v3=format_docs_v3,
    v4=format_docs_v4,
)
def escape_markdown(text):
    """Backslash-escape Markdown-significant characters (plus ``$``) in *text*.

    Existing backslashes are escaped as well, so they render literally.
    """
    specials = r"\`*_{}[]()#+-.!$"
    # One simultaneous pass; equivalent to replacing each char in turn with
    # the backslash handled first.
    table = str.maketrans({ch: "\\" + ch for ch in specials})
    return text.translate(table)
with st.sidebar:
    # Project links, then the answer-rendering toggle.
    with st.container(border=True):
        write_outreach_links()
    st.checkbox("escape markdown in answer", key="response_escape_markdown")

    # Sampling settings; read back from session state when the chat model is built.
    with st.expander("Generative Config"):
        st.selectbox(label="model name", options=OPENAI_CHAT_MODELS, key="model_name")
        st.slider(
            label="temperature",
            min_value=0.0,
            max_value=2.0,
            value=0.0,
            key="temperature",
        )
        st.slider(label="top_p", min_value=0.0, max_value=1.0, value=1.0, key="top_p")

    with st.expander("Retrieval Config"):
        st.slider(
            label="Number of chunks to retrieve",
            min_value=1,
            max_value=40,
            value=10,
            key="n_ret_docs",
        )
        # Optional metadata filters; each input is stored under its key.
        for filter_label, filter_key in (
            ("Bill ID (e.g. 118-s-2293)", "filter_legis_id"),
            ("Bioguide ID (e.g. R000595)", "filter_bioguide_id"),
            ("Congress (e.g. 118)", "filter_congress_num"),
        ):
            st.text_input(filter_label, key=filter_key)

    with st.expander("Prompt Config"):
        st.selectbox(
            label="prompt version",
            options=PROMPT_TEMPLATES.keys(),
            index=3,
            key="prompt_version",
        )
        # Editable template, prefilled with the selected version's text.
        st.text_area(
            "prompt template",
            value=PROMPT_TEMPLATES[SS["prompt_version"]],
            height=300,
            key="prompt_template",
        )
# Chat model built from the sidebar settings stored in session state.
_chat_settings = {
    "model_name": SS["model_name"],
    "temperature": SS["temperature"],
    "openai_api_key": st.secrets["openai_api_key"],
    # Extra request parameters handed to ChatOpenAI via model_kwargs.
    "model_kwargs": {"top_p": SS["top_p"], "seed": SEED},
}
llm = ChatOpenAI(**_chat_settings)
vectorstore = load_pinecone_vectorstore()
# The formatter is keyed by the same version as the prompt template.
format_docs = DOC_FORMATTERS[SS["prompt_version"]]
# Question entry; the app only answers after an explicit Submit click.
query_form = st.form("my_form")
with query_form:
    st.text_area("Enter question:", key="query")
    query_submitted = st.form_submit_button("Submit")
def get_vectorstore_filter():
    """Build the vector-store metadata filter from the sidebar filter inputs.

    Each non-blank input adds one exact-match constraint. Surrounding
    whitespace is stripped first, so a pasted " R000595 " still matches and
    whitespace-only input is treated as blank.

    Returns:
        dict mapping metadata field name to required value (possibly empty).

    Raises:
        ValueError: if the congress input is non-blank but not an integer.
    """
    vs_filter = {}

    legis_id = SS["filter_legis_id"].strip()
    if legis_id:
        vs_filter["legis_id"] = legis_id

    bioguide_id = SS["filter_bioguide_id"].strip()
    if bioguide_id:
        vs_filter["sponsor_bioguide_id"] = bioguide_id

    congress = SS["filter_congress_num"].strip()
    if congress:
        try:
            vs_filter["congress_num"] = int(congress)
        except ValueError as err:
            raise ValueError(
                f"Congress filter must be a whole number (e.g. 118), got {congress!r}"
            ) from err

    return vs_filter
if query_submitted:
    if not SS["query"].strip():
        # Nothing to ask: skip retrieval and the LLM call entirely.
        st.warning("Please enter a question before submitting.")
    else:
        vs_filter = get_vectorstore_filter()
        # Top-k similarity search, restricted by any sidebar metadata filters.
        retriever = vectorstore.as_retriever(
            search_kwargs={"k": SS["n_ret_docs"], "filter": vs_filter},
        )
        prompt = PromptTemplate.from_template(SS["prompt_template"])
        # Render the retrieved docs into {context}, fill the prompt, call the LLM.
        rag_chain_from_docs = (
            RunnablePassthrough.assign(context=(lambda x: format_docs(x["context"])))
            | prompt
            | llm
            | StrOutputParser()
        )
        # Retrieve docs and pass the question through in parallel, then add "answer".
        # The result keeps the raw docs under "context" for display below.
        rag_chain_with_source = RunnableParallel(
            {"context": retriever, "question": RunnablePassthrough()}
        ).assign(answer=rag_chain_from_docs)
        out = rag_chain_with_source.invoke(SS["query"])
        SS["out"] = out
def write_doc_grp(legis_id: str, doc_grp: list[Document]):
    """Render one bill's retrieved chunks inside a Streamlit expander.

    The expander label shows the chunk count, legis_id, title, congress.gov
    and GovTrack links, and a sponsor link. The body lists the chunks, each
    prefixed with its start_index, with Markdown characters escaped.

    The label is built entirely from the first chunk's metadata; the
    ``legis_id`` argument itself is not used.
    """
    meta = doc_grp[0].metadata
    bill_key = (meta["congress_num"], meta["legis_type"], meta["legis_num"])
    congress_gov_link = "[congress.gov]({})".format(get_congress_gov_url(*bill_key))
    gov_track_link = "[govtrack.us]({})".format(get_govtrack_url(*bill_key))
    sponsor_id = meta["sponsor_bioguide_id"]
    sponsor_link = "[{} ({}) ]({})".format(
        meta["sponsor_full_name"], sponsor_id, get_sponsor_url(sponsor_id)
    )
    label = "\n\n".join(
        [
            f"{len(doc_grp)} chunks from {meta['legis_id']}",
            f"{meta['title']}",
            f"{congress_gov_link} | {gov_track_link}",
            sponsor_link,
        ]
    )

    chunk_text = "\n\n...\n\n".join(
        "[start_index={}] ".format(int(doc.metadata["start_index"])) + doc.page_content
        for doc in doc_grp
    )
    with st.expander(label):
        st.write(escape_markdown(chunk_text))
# Show the latest result; it lives in session state so it survives reruns.
out = SS.get("out")
if out:
    answer = out["answer"]
    st.info(escape_markdown(answer) if SS["response_escape_markdown"] else answer)
    for legis_id, doc_grp in group_docs(out["context"]):
        write_doc_grp(legis_id, doc_grp)
    with st.expander("Debug doc format"):
        st.text_area("formatted docs", value=format_docs(out["context"]), height=600)