hf-legisqa / app.py
gabrielaltay's picture
side by side
da0f003
raw
history blame
23.5 kB
"""
"""
from collections import defaultdict
import json
import os
import re
from langchain.tools.retriever import create_retriever_tool
from langchain.agents import AgentExecutor
from langchain.agents import create_openai_tools_agent
from langchain.agents.format_scratchpad.openai_tools import (
format_to_openai_tool_messages,
)
from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
from langchain_core.documents import Document
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts import MessagesPlaceholder
from langchain_core.messages import AIMessage
from langchain_core.messages import HumanMessage
from langchain_core.runnables import RunnableParallel
from langchain_core.runnables import RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_community.callbacks import get_openai_callback
from langchain_community.callbacks import StreamlitCallbackHandler
from langchain_community.embeddings import HuggingFaceBgeEmbeddings
from langchain_community.vectorstores.utils import DistanceStrategy
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_together import ChatTogether
from langchain_pinecone import PineconeVectorStore
from pinecone import Pinecone
import streamlit as st
# Page setup; this is the first Streamlit call made by the script.
st.set_page_config(layout="wide", page_title="LegisQA")
# LangChain settings: the API key and project name come from Streamlit secrets
# and are exported as environment variables, with tracing switched on.
os.environ["LANGCHAIN_API_KEY"] = st.secrets["langchain_api_key"]
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = st.secrets["langchain_project"]
# NOTE(review): presumably silences the Hugging Face tokenizers parallelism
# warning when the embedding model is loaded -- confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Short alias: widget values and query results are read/written through SS
# using "<key_prefix>|<name>" keys throughout this file.
SS = st.session_state
# Seed handed to the chat models that are constructed with a `seed` argument.
SEED = 292764

# Choices (and defaults) for the retrieval filter multiselects.
CONGRESS_NUMBERS = list(range(113, 119))
SPONSOR_PARTIES = ["D", "R", "L", "I"]

# legis_type code (middle part of a legis_id) -> congress.gov URL path segment.
CONGRESS_GOV_TYPE_MAP = {
    "hconres": "house-concurrent-resolution",
    "hjres": "house-joint-resolution",
    "hr": "house-bill",
    "hres": "house-resolution",
    "s": "senate-bill",
    "sconres": "senate-concurrent-resolution",
    "sjres": "senate-joint-resolution",
    "sres": "senate-resolution",
}

# Per-provider model tables. "pmi"/"pmo" are the prices per million input and
# output tokens used by the token-usage helpers to compute the displayed cost.
# Commented-out entries are models that are currently not offered in the UI.
OPENAI_CHAT_MODELS = {
    "gpt-4o-mini": {
        "cost": {"pmi": 0.15, "pmo": 0.60},
    },
    # "gpt-4o": {"cost": {"pmi": 5.00, "pmo": 15.0}},
}

ANTHROPIC_CHAT_MODELS = {
    "claude-3-haiku-20240307": {
        "cost": {"pmi": 0.25, "pmo": 1.25},
    },
    # "claude-3-5-sonnet-20240620": {"cost": {"pmi": 3.00, "pmo": 15.0}},
    # "claude-3-opus-20240229": {"cost": {"pmi": 15.0, "pmo": 75.0}},
}

TOGETHER_CHAT_MODELS = {
    "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": {
        "cost": {"pmi": 0.18, "pmo": 0.18},
    },
    "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
        "cost": {"pmi": 0.88, "pmo": 0.88},
    },
    # "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {"cost": {"pmi": 5.00, "pmo": 5.00}},
}

# Provider label shown in the UI -> that provider's model table.
PROVIDER_MODELS = {
    "OpenAI": OPENAI_CHAT_MODELS,
    "Anthropic": ANTHROPIC_CHAT_MODELS,
    "Together": TOGETHER_CHAT_MODELS,
}
def get_sponsor_url(bioguide_id: str) -> str:
    """Return the bioguide.congress.gov biography URL for ``bioguide_id``."""
    base_url = "https://bioguide.congress.gov/search/bio"
    return "{}/{}".format(base_url, bioguide_id)
def get_congress_gov_url(congress_num: int, legis_type: str, legis_num: int) -> str:
    """Build the congress.gov URL of a piece of legislation.

    Args:
        congress_num: Congress number; anything ``int()`` accepts (callers in
            this file also pass strings and metadata values).
        legis_type: Short type code, a key of ``CONGRESS_GOV_TYPE_MAP``.
        legis_num: Bill/resolution number; coerced with ``int()``.

    Returns:
        A URL of the form
        ``https://www.congress.gov/bill/118th-congress/senate-bill/2293``.

    Raises:
        KeyError: if ``legis_type`` is not in ``CONGRESS_GOV_TYPE_MAP``.
    """
    lt = CONGRESS_GOV_TYPE_MAP[legis_type]
    congress = int(congress_num)
    # English ordinal suffix: 11-13 (and the other teens) take "th"; otherwise
    # the last digit decides. The previous hard-coded "th" was only right for
    # the congresses currently indexed (113-118).
    if 10 <= congress % 100 <= 20:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(congress % 10, "th")
    return f"https://www.congress.gov/bill/{congress}{suffix}-congress/{lt}/{int(legis_num)}"
@st.cache_resource
def load_bge_embeddings():
    """Return the BGE query/document embedding function used by the vector store.

    The model is ``BAAI/bge-small-en-v1.5`` run on CPU with normalized
    embeddings, using the query instruction prefix given below.

    Streamlit re-executes the whole script on every widget interaction, so the
    result is cached with ``st.cache_resource``; without it the embedding model
    would be constructed again on each rerun.
    """
    return HuggingFaceBgeEmbeddings(
        model_name="BAAI/bge-small-en-v1.5",
        model_kwargs={"device": "cpu"},
        encode_kwargs={"normalize_embeddings": True},
        query_instruction="Represent this question for searching relevant passages: ",
    )
def load_pinecone_vectorstore():
    """Create the Pinecone vector store backed by the BGE embedding function.

    The Pinecone API key and index name are read from Streamlit secrets; the
    store reads chunk text from the ``"text"`` key and uses cosine distance.
    """
    return PineconeVectorStore(
        embedding=load_bge_embeddings(),
        text_key="text",
        distance_strategy=DistanceStrategy.COSINE,
        pinecone_api_key=st.secrets["pinecone_api_key"],
        index_name=st.secrets["pinecone_index_name"],
    )
def render_outreach_links():
    """Render the four project/outreach links as Streamlit subheaders."""
    nomic_url = "/".join(
        [
            "https://atlas.nomic.ai/data/gabrielhyperdemocracy",
            "us-congressional-legislation-s1024o256nomic-1",
            "map",
        ]
    )
    hf_url = "https://huggingface.co/hyperdemocracy"
    pc_url = "https://www.pinecone.io/blog/serverless"

    # One markdown line per link, rendered in this order.
    link_lines = (
        ":brain: About [hyperdemocracy](https://hyperdemocracy.us)",
        f":world_map: Visualize [nomic atlas]({nomic_url})",
        f":hugging_face: Raw [huggingface datasets]({hf_url})",
        f":evergreen_tree: Index [pinecone serverless]({pc_url})",
    )
    for line in link_lines:
        st.subheader(line)
def group_docs(docs) -> list[tuple[str, list[Document]]]:
    """Group retrieved chunks by bill.

    Returns a list of ``(legis_id, chunks)`` tuples. Chunks inside a group are
    ordered by their ``start_index`` metadata, and groups are ordered from the
    most chunks to the fewest (ties keep first-seen order).
    """
    by_bill = defaultdict(list)
    for chunk in docs:
        by_bill[chunk.metadata["legis_id"]].append(chunk)

    grouped = [
        (bill_id, sorted(chunks, key=lambda c: c.metadata["start_index"]))
        for bill_id, chunks in by_bill.items()
    ]
    # Negated length keeps the sort stable while putting larger groups first.
    grouped.sort(key=lambda pair: -len(pair[1]))
    return grouped
def format_docs(docs):
    """Serialize retrieved chunks as a JSON list with one object per bill.

    Bill-level fields are taken from the first chunk of each group; the
    ``snippets`` list holds the text of every chunk in the group.
    """

    def summarize(chunks):
        head = chunks[0].metadata
        return {
            "legis_id": head["legis_id"],
            "title": head["title"],
            "introduced_date": head["introduced_date"],
            "sponsor": head["sponsor_full_name"],
            "snippets": [chunk.page_content for chunk in chunks],
        }

    payload = [summarize(chunks) for _, chunks in group_docs(docs)]
    return json.dumps(payload, indent=4)
def escape_markdown(text):
    """Return ``text`` with every markdown special character backslash-escaped."""
    specials = r"\`*_{}[]()#+-.!$"
    # Single pass: each special character maps to itself preceded by a backslash.
    escape_table = str.maketrans({ch: "\\" + ch for ch in specials})
    return text.translate(escape_table)
def get_vectorstore_filter(key_prefix: str):
    """Build the vector store metadata filter from the retrieval widgets.

    The bill id and bioguide id filters are only added when their text inputs
    are non-empty; the congress number and sponsor party ``$in`` filters are
    always added.
    """

    def picked(name):
        return SS[f"{key_prefix}|{name}"]

    vs_filter = {}

    legis_id = picked("filter_legis_id")
    if legis_id != "":
        vs_filter["legis_id"] = legis_id

    bioguide_id = picked("filter_bioguide_id")
    if bioguide_id != "":
        vs_filter["sponsor_bioguide_id"] = bioguide_id

    vs_filter["congress_num"] = {"$in": picked("filter_congress_nums")}
    vs_filter["sponsor_party"] = {"$in": picked("filter_sponsor_parties")}
    return vs_filter
def render_doc_grp(legis_id: str, doc_grp: list[Document]):
    """Render one bill's retrieved chunks inside a Streamlit expander.

    The expander label shows the chunk count, bill id, title, a congress.gov
    link and a sponsor link, all taken from the first chunk's metadata. The
    body shows each chunk prefixed with its start index, markdown-escaped.
    ``legis_id`` is accepted for the caller's convenience but not used here.
    """
    meta = doc_grp[0].metadata
    bill_url = get_congress_gov_url(
        meta["congress_num"],
        meta["legis_type"],
        meta["legis_num"],
    )
    sponsor_id = meta["sponsor_bioguide_id"]

    label = "{} chunks from {}\n\n{}\n\n{}\n\n[{} ({}) ]({})".format(
        len(doc_grp),
        meta["legis_id"],
        meta["title"],
        f"[congress.gov]({bill_url})",
        meta["sponsor_full_name"],
        sponsor_id,
        get_sponsor_url(sponsor_id),
    )

    pieces = []
    for chunk in doc_grp:
        prefix = "[start_index={}] ".format(int(chunk.metadata["start_index"]))
        pieces.append(prefix + chunk.page_content)

    with st.expander(label):
        st.write(escape_markdown("\n\n...\n\n".join(pieces)))
def legis_id_to_link(legis_id: str) -> str:
    """Return the congress.gov URL for a ``<congress>-<type>-<number>`` legis id."""
    congress, bill_type, bill_number = legis_id.split("-")
    return get_congress_gov_url(
        congress_num=congress,
        legis_type=bill_type,
        legis_num=bill_number,
    )
def legis_id_match_to_link(matchobj):
    """Replacement callback for ``re.sub``: turn a matched legis id into a markdown link.

    Args:
        matchobj: ``re.Match`` whose full match is a legis id such as
            ``118-s-2293``.

    Returns:
        ``[<legis_id>](<congress.gov url>)``, or the matched text unchanged
        when its type code has no congress.gov mapping.
    """
    mstring = matchobj.group(0)
    try:
        url = legis_id_to_link(mstring)
    except KeyError:
        # The id pattern accepts any lowercase type token, but only known type
        # codes can be mapped to a URL. Model output is free text, so leave an
        # unrecognized id as plain text instead of crashing the page.
        return mstring
    return f"[{mstring}]({url})"
def replace_legis_ids_with_urls(text):
    """Replace legis ids (e.g. ``118-s-2293``) in ``text`` with markdown links.

    Matches a congress number 113-118, a lowercase type code and a 1-5 digit
    number joined by dashes; each match is rewritten by
    ``legis_id_match_to_link``.
    """
    # Raw string: "\d" inside a normal string literal is an invalid escape
    # sequence and triggers a warning on newer Python versions. The pattern
    # itself is unchanged.
    pattern = r"11[345678]-[a-z]+-\d{1,5}"
    return re.sub(pattern, legis_id_match_to_link, text)
def render_guide():
    """Write the user guide (overview, disclaimer, config notes) as markdown."""
    guide_markdown = """
When you send a query to LegisQA, it will attempt to retrieve relevant content from the past six congresses ([113th-118th](https://en.wikipedia.org/wiki/List_of_United_States_Congresses)) covering 2013 to the present, pass it to a [large language model (LLM)](https://en.wikipedia.org/wiki/Large_language_model), and generate a response. This technique is known as Retrieval Augmented Generation (RAG). You can read [an academic paper](https://proceedings.neurips.cc/paper/2020/hash/6b493230205f780e1bc26945df7481e5-Abstract.html) or [a high level summary](https://research.ibm.com/blog/retrieval-augmented-generation-RAG) to get more details. Once the response is generated, the retrieved content will be available for inspection with links to the bills and sponsors.
## Disclaimer
This is a research project. The RAG technique helps to ground the LLM response by providing context from a trusted source, but it does not guarantee a high quality response. We encourage you to play around, find questions that work and find questions that fail. There is a small monthly budget dedicated to the OpenAI endpoints. Once that is used up each month, queries will no longer work.
## Sidebar Config
Use the `Generative Config` to change LLM parameters.
Use the `Retrieval Config` to change the number of chunks retrieved from our congress corpus and to apply various filters to the content before it is retrieved (e.g. filter to a specific set of congresses). Use the `Prompt Config` to try out different document formatting and prompting strategies.
"""
    st.write(guide_markdown)
def render_example_queries():
    """Show a collapsed expander listing example queries as fenced code blocks."""
    examples_markdown = """
```
What are the themes around artificial intelligence?
```
```
Write a well cited 3 paragraph essay on food insecurity.
```
```
Create a table summarizing major climate change ideas with columns legis_id, title, idea.
```
```
Write an action plan to keep social security solvent.
```
```
Suggest reforms that would benefit the Medicaid program.
```
"""
    examples_box = st.expander("Example Queries")
    with examples_box:
        st.write(examples_markdown)
def render_generative_config(key_prefix: str):
    """Render the generation widgets; each stores its value under ``key_prefix|<name>``.

    Widgets: provider, model name, temperature, max_output_tokens, top_p, and
    two checkboxes controlling post-processing of the answer.
    """

    def wkey(name):
        return f"{key_prefix}|{name}"

    st.selectbox(
        label="provider",
        options=PROVIDER_MODELS.keys(),
        key=wkey("provider"),
    )
    # The model choices depend on the provider picked in the selectbox above.
    chosen_provider = SS[wkey("provider")]
    st.selectbox(
        label="model name",
        options=PROVIDER_MODELS[chosen_provider],
        key=wkey("model_name"),
    )

    st.slider(
        "temperature",
        min_value=0.0,
        max_value=2.0,
        value=0.01,
        key=wkey("temperature"),
    )
    st.slider(
        "max_output_tokens",
        min_value=1024,
        max_value=2048,
        key=wkey("max_output_tokens"),
    )
    st.slider(
        "top_p",
        min_value=0.0,
        max_value=1.0,
        value=0.9,
        key=wkey("top_p"),
    )

    st.checkbox(
        "escape markdown in answer",
        key=wkey("response_escape_markdown"),
    )
    st.checkbox(
        "add legis urls in answer",
        value=True,
        key=wkey("response_add_legis_urls"),
    )
def render_retrieval_config(key_prefix: str):
    """Render the retrieval widgets; each stores its value under ``key_prefix|<name>``.

    Widgets: number of chunks, bill id and bioguide id text filters, and the
    congress number and sponsor party multiselects (all options selected by
    default).
    """

    def wkey(name):
        return f"{key_prefix}|{name}"

    st.slider(
        "Number of chunks to retrieve",
        min_value=1,
        max_value=32,
        value=8,
        key=wkey("n_ret_docs"),
    )

    st.text_input("Bill ID (e.g. 118-s-2293)", key=wkey("filter_legis_id"))
    st.text_input("Bioguide ID (e.g. R000595)", key=wkey("filter_bioguide_id"))

    st.multiselect(
        "Congress Numbers",
        CONGRESS_NUMBERS,
        default=CONGRESS_NUMBERS,
        key=wkey("filter_congress_nums"),
    )
    st.multiselect(
        "Sponsor Party",
        SPONSOR_PARTIES,
        default=SPONSOR_PARTIES,
        key=wkey("filter_sponsor_parties"),
    )
def get_llm(key_prefix: str):
    """Construct the chat model selected by the widgets stored under ``key_prefix``.

    The model name decides the provider class (OpenAI, Anthropic or Together).
    Temperature, top_p and max output tokens are read from session state and
    the provider API key from Streamlit secrets. ``SEED`` is passed to the
    OpenAI and Together models only.

    Raises:
        ValueError: if the selected model name is in none of the provider
            model tables.
    """
    model_name = SS[f"{key_prefix}|model_name"]
    if model_name in OPENAI_CHAT_MODELS:
        llm = ChatOpenAI(
            model=model_name,
            temperature=SS[f"{key_prefix}|temperature"],
            api_key=st.secrets["openai_api_key"],
            top_p=SS[f"{key_prefix}|top_p"],
            seed=SEED,
            max_tokens=SS[f"{key_prefix}|max_output_tokens"],
        )
    elif model_name in ANTHROPIC_CHAT_MODELS:
        llm = ChatAnthropic(
            model_name=model_name,
            temperature=SS[f"{key_prefix}|temperature"],
            api_key=st.secrets["anthropic_api_key"],
            top_p=SS[f"{key_prefix}|top_p"],
            max_tokens_to_sample=SS[f"{key_prefix}|max_output_tokens"],
        )
    elif model_name in TOGETHER_CHAT_MODELS:
        llm = ChatTogether(
            model=model_name,
            temperature=SS[f"{key_prefix}|temperature"],
            max_tokens=SS[f"{key_prefix}|max_output_tokens"],
            top_p=SS[f"{key_prefix}|top_p"],
            seed=SEED,
            api_key=st.secrets["together_api_key"],
        )
    else:
        # Same exception type as before, now saying which model was rejected.
        raise ValueError(f"unsupported model name: {model_name!r}")
    return llm
def get_token_usage(key_prefix: str, metadata: dict):
    """Compute token counts and cost for a response, using the provider-specific parser.

    The provider is chosen from the model name currently stored in session
    state under ``key_prefix``.

    Args:
        key_prefix: Session-state key prefix of the tab/group.
        metadata: The response metadata dict of the model output.

    Returns:
        Dict with ``input_tokens``, ``output_tokens`` and ``cost``.

    Raises:
        ValueError: if the selected model name is in none of the provider
            model tables.
    """
    model_name = SS[f"{key_prefix}|model_name"]
    if model_name in OPENAI_CHAT_MODELS:
        model_info = PROVIDER_MODELS["OpenAI"][model_name]
        return get_openai_token_usage(metadata, model_info)
    elif model_name in ANTHROPIC_CHAT_MODELS:
        model_info = PROVIDER_MODELS["Anthropic"][model_name]
        return get_anthropic_token_usage(metadata, model_info)
    elif model_name in TOGETHER_CHAT_MODELS:
        model_info = PROVIDER_MODELS["Together"][model_name]
        return get_together_token_usage(metadata, model_info)
    else:
        # Same exception type as before, now saying which model was rejected.
        raise ValueError(f"unsupported model name: {model_name!r}")
def get_openai_token_usage(metadata: dict, model_info: dict):
    """Return token counts and cost from OpenAI-style response metadata.

    Counts come from ``metadata["token_usage"]``; the cost applies the model's
    per-million-token input ("pmi") and output ("pmo") prices.
    """
    n_in = metadata["token_usage"]["prompt_tokens"]
    n_out = metadata["token_usage"]["completion_tokens"]
    prices = model_info["cost"]
    total = n_in * 1e-6 * prices["pmi"] + n_out * 1e-6 * prices["pmo"]
    return dict(input_tokens=n_in, output_tokens=n_out, cost=total)
def get_anthropic_token_usage(metadata: dict, model_info: dict):
    """Return token counts and cost from Anthropic-style response metadata.

    Counts come from ``metadata["usage"]``; the cost applies the model's
    per-million-token input ("pmi") and output ("pmo") prices.
    """
    n_in = metadata["usage"]["input_tokens"]
    n_out = metadata["usage"]["output_tokens"]
    prices = model_info["cost"]
    total = n_in * 1e-6 * prices["pmi"] + n_out * 1e-6 * prices["pmo"]
    return dict(input_tokens=n_in, output_tokens=n_out, cost=total)
def get_together_token_usage(metadata: dict, model_info: dict):
    """Return token counts and cost from Together-style response metadata.

    Counts come from ``metadata["token_usage"]``; the cost applies the model's
    per-million-token input ("pmi") and output ("pmo") prices.
    """
    n_in = metadata["token_usage"]["prompt_tokens"]
    n_out = metadata["token_usage"]["completion_tokens"]
    prices = model_info["cost"]
    total = n_in * 1e-6 * prices["pmi"] + n_out * 1e-6 * prices["pmo"]
    return dict(input_tokens=n_in, output_tokens=n_out, cost=total)
def render_sidebar():
    """Render the sidebar content: the outreach links inside a bordered container."""
    links_box = st.container(border=True)
    with links_box:
        render_outreach_links()
def render_query_rag_tab():
    """Render the single-configuration RAG tab.

    Shows example queries, the generative/retrieval config expanders and the
    query form. On submit, retrieves chunks from the module-level
    ``vectorstore``, runs the prompt through the selected chat model and
    stores the result in session state. Whenever a stored result exists, the
    response, API usage, retrieved chunks and a debug view are rendered.
    """
    key_prefix = "query_rag"

    def skey(name):
        return f"{key_prefix}|{name}"

    render_example_queries()

    gen_col, ret_col = st.columns(2)
    with gen_col, st.expander("Generative Config"):
        render_generative_config(key_prefix)
    with ret_col, st.expander("Retrieval Config"):
        render_retrieval_config(key_prefix)

    QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
---
Congressional Legislation Excerpts:
{context}
---
Query: {query}"""
    prompt = ChatPromptTemplate.from_messages([("human", QUERY_RAG_TEMPLATE)])

    with st.form(skey("query_form")):
        st.text_area(
            "Enter a query that can be answered with congressional legislation:",
            key=skey("query"),
        )
        query_submitted = st.form_submit_button("Submit")

    if query_submitted:
        llm = get_llm(key_prefix)
        vs_filter = get_vectorstore_filter(key_prefix)
        retriever = vectorstore.as_retriever(
            search_kwargs={"k": SS[skey("n_ret_docs")], "filter": vs_filter},
        )
        # Stage 1: fetch docs and pass the query through unchanged.
        fetch = RunnableParallel(
            {
                "docs": retriever,  # list of docs
                "query": RunnablePassthrough(),  # str
            }
        )
        # Stage 2: add the JSON context; stage 3: add the model output.
        with_context = fetch.assign(context=(lambda x: format_docs(x["docs"])))
        rag_chain = with_context.assign(output=prompt | llm)
        SS[skey("out")] = rag_chain.invoke(SS[skey("query")])

    # Results live in session state so they survive reruns caused by widgets.
    if skey("out") in SS:
        result = SS[skey("out")]
        answer = result["output"]

        # NOTE(review): escaping runs before link insertion; escape_markdown
        # escapes "-", so with both options on the legis-id pattern no longer
        # matches and no links are added. Confirm whether that is intended.
        out_display = answer.content
        if SS[skey("response_escape_markdown")]:
            out_display = escape_markdown(out_display)
        if SS[skey("response_add_legis_urls")]:
            out_display = replace_legis_ids_with_urls(out_display)

        with st.container(border=True):
            st.write("Response")
            st.info(out_display)

        with st.container(border=True):
            st.write("API Usage")
            token_usage = get_token_usage(key_prefix, answer.response_metadata)
            metric_cols = st.columns(3)
            with metric_cols[0]:
                st.metric("Input Tokens", token_usage["input_tokens"])
            with metric_cols[1]:
                st.metric("Output Tokens", token_usage["output_tokens"])
            with metric_cols[2]:
                st.metric("Cost", f"${token_usage['cost']:.4f}")
            with st.expander("Response Metadata"):
                st.warning(answer.response_metadata)

        with st.container(border=True):
            doc_grps = group_docs(result["docs"])
            st.write(
                "Retrieved Chunks (note that you may need to 'right click' on links in the expanders to follow them)"
            )
            for legis_id, doc_grp in doc_grps:
                render_doc_grp(legis_id, doc_grp)

        with st.expander("Debug"):
            st.write(result)
def render_query_rag_sbs_tab():
    """Render the side-by-side RAG tab: one query, two independent configurations.

    A single query form feeds two groups ("grp1", "grp2"), each with its own
    generative and retrieval config. On submit each group retrieves from the
    module-level ``vectorstore``, runs the prompt through its selected chat
    model and stores its result in session state; stored results are rendered
    in two columns below the configs.
    """
    QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
---
Congressional Legislation Excerpts:
{context}
---
Query: {query}"""
    base_key_prefix = "query_rag_sbs"
    prompt = ChatPromptTemplate.from_messages([("human", QUERY_RAG_TEMPLATE)])

    with st.form(f"{base_key_prefix}|query_form"):
        st.text_area(
            "Enter a query that can be answered with congressional legislation:",
            key=f"{base_key_prefix}|query",
        )
        query_submitted = st.form_submit_button("Submit")

    groups = (("grp1", "Group 1"), ("grp2", "Group 2"))

    # Top row: one config column per group.
    config_cols = st.columns(2)
    for (grp, heading), config_col in zip(groups, config_cols):
        with config_col:
            st.header(heading)
            key_prefix = f"{base_key_prefix}|{grp}"
            with st.expander("Generative Config"):
                render_generative_config(key_prefix)
            with st.expander("Retrieval Config"):
                render_retrieval_config(key_prefix)

    # Bottom row: one result column per group.
    result_cols = st.columns(2)
    for (grp, _heading), result_col in zip(groups, result_cols):
        key_prefix = f"{base_key_prefix}|{grp}"
        out_key = f"{key_prefix}|out"

        if query_submitted:
            llm = get_llm(key_prefix)
            vs_filter = get_vectorstore_filter(key_prefix)
            retriever = vectorstore.as_retriever(
                search_kwargs={
                    "k": SS[f"{key_prefix}|n_ret_docs"],
                    "filter": vs_filter,
                },
            )
            fetch = RunnableParallel(
                {
                    "docs": retriever,  # list of docs
                    "query": RunnablePassthrough(),  # str
                }
            )
            with_context = fetch.assign(context=(lambda x: format_docs(x["docs"])))
            rag_chain = with_context.assign(output=prompt | llm)
            # Both groups answer the same query from the shared form.
            SS[out_key] = rag_chain.invoke(SS[f"{base_key_prefix}|query"])

        if out_key not in SS:
            continue

        result = SS[out_key]
        answer = result["output"]
        with result_col:
            # NOTE(review): escaping runs before link insertion; escape_markdown
            # escapes "-", so with both options on the legis-id pattern no
            # longer matches and no links are added. Confirm whether intended.
            out_display = answer.content
            if SS[f"{key_prefix}|response_escape_markdown"]:
                out_display = escape_markdown(out_display)
            if SS[f"{key_prefix}|response_add_legis_urls"]:
                out_display = replace_legis_ids_with_urls(out_display)

            with st.container(border=True):
                st.write("Response")
                st.info(out_display)

            with st.container(border=True):
                st.write("API Usage")
                token_usage = get_token_usage(key_prefix, answer.response_metadata)
                metric_cols = st.columns(3)
                with metric_cols[0]:
                    st.metric("Input Tokens", token_usage["input_tokens"])
                with metric_cols[1]:
                    st.metric("Output Tokens", token_usage["output_tokens"])
                with metric_cols[2]:
                    st.metric("Cost", f"${token_usage['cost']:.4f}")
                with st.expander("Response Metadata"):
                    st.warning(answer.response_metadata)

            with st.container(border=True):
                doc_grps = group_docs(result["docs"])
                st.write(
                    "Retrieved Chunks (note that you may need to 'right click' on links in the expanders to follow them)"
                )
                for legis_id, doc_grp in doc_grps:
                    render_doc_grp(legis_id, doc_grp)
# ---------------------------------------------------------------------------
# App layout: title, sidebar, shared vector store, then one renderer per tab.
# ---------------------------------------------------------------------------
st.title(":classical_building: LegisQA :classical_building:")
st.header("Chat With Congressional Bills")

with st.sidebar:
    render_sidebar()

# Module-level global: the query tab renderers read `vectorstore` by name.
vectorstore = load_pinecone_vectorstore()

query_rag_tab, query_rag_sbs_tab, guide_tab = st.tabs(
    ["query_rag", "query_rag_sbs", "guide"]
)
for _tab, _render_tab in (
    (query_rag_tab, render_query_rag_tab),
    (query_rag_sbs_tab, render_query_rag_sbs_tab),
    (guide_tab, render_guide),
):
    with _tab:
        _render_tab()