hf-legisqa / app.py
gabrielaltay's picture
add gemini
3ead889
raw
history blame
13 kB
"""
"""
from collections import defaultdict
import json
import os
import re
from langchain_core.documents import Document
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI
from langchain_anthropic import ChatAnthropic
from langchain_together import ChatTogether
from langchain_google_genai import ChatGoogleGenerativeAI
import streamlit as st
import utils_mod
import doc_format_mod
import guide_mod
import sidebar_mod
import usage_mod
import vectorstore_mod
# Configure the Streamlit page: wide layout and the page title.
st.set_page_config(layout="wide", page_title="LegisQA")
# Export LangChain settings (API key and project come from Streamlit secrets)
# as environment variables.
# NOTE(review): presumably these switch on LangSmith tracing for the chains
# built in this file - confirm.
os.environ["LANGCHAIN_API_KEY"] = st.secrets["langchain_api_key"]
os.environ["LANGCHAIN_TRACING_V2"] = "true"
os.environ["LANGCHAIN_PROJECT"] = st.secrets["langchain_project"]
# NOTE(review): presumably silences the Hugging Face tokenizers parallelism
# warning - confirm which dependency needs it.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Short alias for the per-session state store used to keep responses.
SS = st.session_state
# Seed passed to the OpenAI and Together chat clients in get_llm.
SEED = 292764
# Options (and default selection) of the "Congress Numbers" retrieval filter.
CONGRESS_NUMBERS = [113, 114, 115, 116, 117, 118]
# Options (and default selection) of the "Sponsor Party" retrieval filter.
SPONSOR_PARTIES = ["D", "R", "L", "I"]
# Per-provider tables of selectable chat models. Each model name maps to a
# dict with a "cost" entry holding "pmi" and "pmo" numbers.
# NOTE(review): "pmi"/"pmo" presumably mean price per million input/output
# tokens (USD) - confirm against usage_mod, which receives these entries.
OPENAI_CHAT_MODELS = {
    "gpt-4o-mini": {"cost": {"pmi": 0.15, "pmo": 0.60}},
    "gpt-4o": {"cost": {"pmi": 5.00, "pmo": 15.0}},
}
ANTHROPIC_CHAT_MODELS = {
    "claude-3-haiku-20240307": {"cost": {"pmi": 0.25, "pmo": 1.25}},
    "claude-3-5-sonnet-20240620": {"cost": {"pmi": 3.00, "pmo": 15.0}},
    "claude-3-opus-20240229": {"cost": {"pmi": 15.0, "pmo": 75.0}},
}
TOGETHER_CHAT_MODELS = {
    "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": {"cost": {"pmi": 0.18, "pmo": 0.18}},
    "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
        "cost": {"pmi": 0.88, "pmo": 0.88}
    },
    "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {
        "cost": {"pmi": 5.00, "pmo": 5.00}
    },
}
# NOTE(review): every Gemini cost is 0.0 here - confirm whether these are
# placeholders or intentional.
GOOGLE_CHAT_MODELS = {
    "gemini-1.5-flash": {"cost": {"pmi": 0.0, "pmo": 0.0}},
    "gemini-1.5-pro": {"cost": {"pmi": 0.0, "pmo": 0.0}},
    "gemini-1.5-pro-exp-0801": {"cost": {"pmi": 0.0, "pmo": 0.0}},
}
# Provider name -> model table. The keys are the options of the provider
# selectbox and are the values matched on in get_llm.
PROVIDER_MODELS = {
    "OpenAI": OPENAI_CHAT_MODELS,
    "Anthropic": ANTHROPIC_CHAT_MODELS,
    "Together": TOGETHER_CHAT_MODELS,
    "Google": GOOGLE_CHAT_MODELS,
}
def render_example_queries():
    """Render a collapsible "Example Queries" panel.

    Each sample query is written inside its own fenced code block.
    """
    example_queries = (
        "What are the themes around artificial intelligence?",
        "Write a well cited 3 paragraph essay on food insecurity.",
        "Create a table summarizing major climate change ideas with columns legis_id, title, idea.",
        "Write an action plan to keep social security solvent.",
        "Suggest reforms that would benefit the Medicaid program.",
    )
    # One fenced block per query; the leading newline matches the original
    # triple-quoted literal, which started on the line after the quotes.
    fenced_blocks = "".join(f"```\n{query}\n```\n" for query in example_queries)
    with st.expander("Example Queries"):
        st.write("\n" + fenced_blocks)
def get_generative_config(key_prefix: str) -> dict:
    """Render the generation widgets and return their current values.

    Args:
        key_prefix: Prefix prepended (with a "|" separator) to every widget
            key so several copies of these widgets can coexist on one page.

    Returns:
        A dict with keys "provider", "model_name", "temperature",
        "max_output_tokens", "top_p", "should_escape_markdown" and
        "should_add_legis_urls".
    """

    def widget_key(name: str) -> str:
        return f"{key_prefix}|{name}"

    provider = st.selectbox(
        label="provider",
        options=PROVIDER_MODELS.keys(),
        key=widget_key("provider"),
    )
    # The model choices depend on which provider is currently selected.
    model_name = st.selectbox(
        label="model_name",
        options=PROVIDER_MODELS[provider],
        key=widget_key("model_name"),
    )
    temperature = st.slider(
        "temperature",
        min_value=0.0,
        max_value=2.0,
        value=0.0,
        key=widget_key("temperature"),
    )
    max_output_tokens = st.slider(
        "max_output_tokens",
        min_value=1024,
        max_value=2048,
        key=widget_key("max_output_tokens"),
    )
    top_p = st.slider(
        "top_p",
        min_value=0.0,
        max_value=1.0,
        value=0.9,
        key=widget_key("top_p"),
    )
    should_escape_markdown = st.checkbox(
        "should_escape_markdown",
        value=False,
        key=widget_key("should_escape_markdown"),
    )
    should_add_legis_urls = st.checkbox(
        "should_add_legis_urls",
        value=True,
        key=widget_key("should_add_legis_urls"),
    )
    return {
        "provider": provider,
        "model_name": model_name,
        "temperature": temperature,
        "max_output_tokens": max_output_tokens,
        "top_p": top_p,
        "should_escape_markdown": should_escape_markdown,
        "should_add_legis_urls": should_add_legis_urls,
    }
def get_retrieval_config(key_prefix: str) -> dict:
    """Render the retrieval widgets and return their current values.

    Args:
        key_prefix: Prefix prepended (with a "|" separator) to every widget
            key so several copies of these widgets can coexist on one page.

    Returns:
        A dict with keys "n_ret_docs", "filter_legis_id",
        "filter_bioguide_id", "filter_congress_nums" and
        "filter_sponsor_parties".
    """
    prefix = f"{key_prefix}|"
    # Dict values are evaluated top to bottom, so the widgets are created in
    # the order they are listed here.
    return {
        "n_ret_docs": st.slider(
            "Number of chunks to retrieve",
            min_value=1,
            max_value=32,
            value=8,
            key=prefix + "n_ret_docs",
        ),
        "filter_legis_id": st.text_input(
            "Bill ID (e.g. 118-s-2293)",
            key=prefix + "filter_legis_id",
        ),
        "filter_bioguide_id": st.text_input(
            "Bioguide ID (e.g. R000595)",
            key=prefix + "filter_bioguide_id",
        ),
        "filter_congress_nums": st.multiselect(
            "Congress Numbers",
            CONGRESS_NUMBERS,
            default=CONGRESS_NUMBERS,
            key=prefix + "filter_congress_nums",
        ),
        "filter_sponsor_parties": st.multiselect(
            "Sponsor Party",
            SPONSOR_PARTIES,
            default=SPONSOR_PARTIES,
            key=prefix + "filter_sponsor_parties",
        ),
    }
def get_llm(gen_config: dict):
match gen_config["provider"]:
case "OpenAI":
llm = ChatOpenAI(
model=gen_config["model_name"],
temperature=gen_config["temperature"],
api_key=st.secrets["openai_api_key"],
top_p=gen_config["top_p"],
seed=SEED,
max_tokens=gen_config["max_output_tokens"],
)
case "Anthropic":
llm = ChatAnthropic(
model_name=gen_config["model_name"],
temperature=gen_config["temperature"],
api_key=st.secrets["anthropic_api_key"],
top_p=gen_config["top_p"],
max_tokens_to_sample=gen_config["max_output_tokens"],
)
case "Together":
llm = ChatTogether(
model=gen_config["model_name"],
temperature=gen_config["temperature"],
max_tokens=gen_config["max_output_tokens"],
top_p=gen_config["top_p"],
seed=SEED,
api_key=st.secrets["together_api_key"],
)
case "Google":
llm = ChatGoogleGenerativeAI(
model=gen_config["model_name"],
temperature=gen_config["temperature"],
api_key=st.secrets["google_api_key"],
max_output_tokens=gen_config["max_output_tokens"],
top_p=gen_config["top_p"],
)
case _:
raise ValueError()
return llm
def create_rag_chain(llm, retriever):
    """Assemble the retrieval-augmented generation chain.

    The chain is invoked with the query. Its output carries "docs" (the
    retriever's result), "query" (the input passed through), "context"
    (the docs formatted by doc_format_mod.format_docs) and "aimessage"
    (the result of piping the prompt into the model).

    Args:
        llm: Chat model the prompt is piped into.
        retriever: Retriever used to fetch documents for the query.

    Returns:
        The composed runnable chain.
    """
    template_text = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. When citing legis_id, use the same format as the excerpts (e.g. "116-hr-125"). If you don't know how to respond, just tell the user.
---
Congressional Legislation Excerpts:
{context}
---
Query: {query}"""
    prompt = ChatPromptTemplate.from_messages([("human", template_text)])

    # Stage 1: fetch documents while passing the query through unchanged.
    retrieval_step = RunnableParallel(
        {
            "docs": retriever,
            "query": RunnablePassthrough(),
        }
    )
    # Stage 2: add the formatted context built from the retrieved documents.
    with_context = retrieval_step.assign(
        context=lambda state: doc_format_mod.format_docs(state["docs"])
    )
    # Stage 3: add the model's answer.
    return with_context.assign(aimessage=prompt | llm)
def process_query(gen_config: dict, ret_config: dict, query: str):
    """Run one query through a freshly built RAG chain.

    Args:
        gen_config: Generative settings, forwarded to get_llm.
        ret_config: Retrieval settings; "n_ret_docs" is used as the
            retriever's "k", and the whole dict is passed to
            vectorstore_mod.get_vectorstore_filter to build the filter.
        query: The user's query text.

    Returns:
        Whatever the chain from create_rag_chain returns when invoked with
        the query.
    """
    store = vectorstore_mod.load_pinecone_vectorstore()
    chat_model = get_llm(gen_config)
    metadata_filter = vectorstore_mod.get_vectorstore_filter(ret_config)
    search_kwargs = {"k": ret_config["n_ret_docs"], "filter": metadata_filter}
    retriever = store.as_retriever(search_kwargs=search_kwargs)
    return create_rag_chain(chat_model, retriever).invoke(query)
def render_response(
    response: dict,
    model_info: dict,
    provider: str,
    should_escape_markdown: bool,
    should_add_legis_urls: bool,
    tag: str | None = None,
):
    """Display a chain response: answer text, API usage and retrieved chunks.

    Args:
        response: Chain output with "aimessage" and "docs" entries.
        model_info: Model table entry forwarded to the usage display.
        provider: Provider name forwarded to the usage display.
        should_escape_markdown: If true, pass the answer text through
            utils_mod.escape_markdown first.
        should_add_legis_urls: If true, pass the answer text through
            utils_mod.replace_legis_ids_with_urls.
        tag: Optional label added to the heading and forwarded to the usage
            and chunk renderers.
    """
    ai_message = response["aimessage"]

    text = ai_message.content
    if should_escape_markdown:
        text = utils_mod.escape_markdown(text)
    if should_add_legis_urls:
        text = utils_mod.replace_legis_ids_with_urls(text)

    heading = "Response" if tag is None else f"Response ({tag})"
    with st.container(border=True):
        st.write(heading)
        st.info(text)

    usage_mod.display_api_usage(
        ai_message.response_metadata, model_info, provider, tag=tag
    )
    doc_format_mod.render_retrieved_chunks(response["docs"], tag=tag)
def render_query_rag_tab():
    """Render the single-configuration RAG tab.

    Shows the example queries, a query form, and the generative and
    retrieval config panels. On submit the query is processed and the
    response is stored in session state; a stored response is rendered on
    every run, followed by a "Debug" expander with the raw response.

    The generative config that produced the stored response is saved
    alongside it. The provider and model passed to the renderer come from
    that snapshot, so the usage display keeps describing the model that
    actually answered even after the selector widgets are changed.
    """
    key_prefix = "query_rag"
    render_example_queries()

    with st.form(f"{key_prefix}|query_form"):
        query = st.text_area(
            "Enter a query that can be answered with congressional legislation:"
        )
        cols = st.columns(2)
        with cols[0]:
            query_submitted = st.form_submit_button("Submit")
        with cols[1]:
            status_placeholder = st.empty()

    col1, col2 = st.columns(2)
    with col1:
        with st.expander("Generative Config"):
            gen_config = get_generative_config(key_prefix)
    with col2:
        with st.expander("Retrieval Config"):
            ret_config = get_retrieval_config(key_prefix)

    rkey = f"{key_prefix}|response"
    used_config_key = f"{key_prefix}|response_gen_config"
    if query_submitted:
        with status_placeholder:
            with st.spinner("generating response"):
                SS[rkey] = process_query(gen_config, ret_config, query)
        # Saved only after a successful run, so a failed query leaves the
        # previous response and its snapshot paired.
        SS[used_config_key] = dict(gen_config)

    if response := SS.get(rkey):
        # Fall back to the live config if no snapshot exists (e.g. a response
        # stored before the snapshot key was introduced).
        used_config = SS.get(used_config_key, gen_config)
        model_info = PROVIDER_MODELS[used_config["provider"]][used_config["model_name"]]
        render_response(
            response,
            model_info,
            used_config["provider"],
            # Display toggles stay live so they can be flipped after the fact.
            gen_config["should_escape_markdown"],
            gen_config["should_add_legis_urls"],
        )
        with st.expander("Debug"):
            st.write(response)
def render_query_rag_sbs_tab():
    """Render the side-by-side RAG tab comparing two configurations.

    One query form feeds two groups ("Group 1" and "Group 2"), each with
    its own generative and retrieval config panels. On submit the query is
    processed once per group and each response is stored in session state;
    stored responses are rendered in two columns on every run.

    The generative config that produced each stored response is saved
    alongside it. The provider and model passed to the renderer come from
    that snapshot, so the usage display keeps describing the model that
    actually answered even after the selector widgets are changed.
    """
    base_key_prefix = "query_rag_sbs"
    grp_names = {"grp1": "Group 1", "grp2": "Group 2"}

    with st.form(f"{base_key_prefix}|query_form"):
        query = st.text_area(
            "Enter a query that can be answered with congressional legislation:"
        )
        cols = st.columns(2)
        with cols[0]:
            query_submitted = st.form_submit_button("Submit")
        with cols[1]:
            status_placeholder = st.empty()

    # Config panels: one column per group.
    gen_configs = {}
    ret_configs = {}
    for group, column in zip(grp_names, st.columns(2)):
        with column:
            st.header(grp_names[group])
            key_prefix = f"{base_key_prefix}|{group}"
            with st.expander("Generative Config"):
                gen_configs[group] = get_generative_config(key_prefix)
            with st.expander("Retrieval Config"):
                ret_configs[group] = get_retrieval_config(key_prefix)

    # Responses: a second row of columns, again one per group.
    for group, column in zip(grp_names, st.columns(2)):
        key_prefix = f"{base_key_prefix}|{group}"
        rkey = f"{key_prefix}|response"
        used_config_key = f"{key_prefix}|response_gen_config"
        gen_config = gen_configs[group]
        with column:
            if query_submitted:
                with status_placeholder:
                    with st.spinner(
                        "generating response for {}".format(grp_names[group])
                    ):
                        SS[rkey] = process_query(
                            gen_config,
                            ret_configs[group],
                            query,
                        )
                # Saved only after a successful run, so a failed query leaves
                # the previous response and its snapshot paired.
                SS[used_config_key] = dict(gen_config)

            if response := SS.get(rkey):
                # Fall back to the live config if no snapshot exists.
                used_config = SS.get(used_config_key, gen_config)
                model_info = PROVIDER_MODELS[used_config["provider"]][
                    used_config["model_name"]
                ]
                render_response(
                    response,
                    model_info,
                    used_config["provider"],
                    # Display toggles stay live so they can be flipped later.
                    gen_config["should_escape_markdown"],
                    gen_config["should_add_legis_urls"],
                    tag=grp_names[group],
                )
def main():
    """Render the page: title, header, sidebar and the three tabs."""
    st.title(":classical_building: LegisQA :classical_building:")
    st.header("Query Congressional Bills")

    with st.sidebar:
        sidebar_mod.render_sidebar()

    # (tab label, function that fills the tab), in display order.
    tab_specs = (
        ("RAG", render_query_rag_tab),
        ("RAG (side-by-side)", render_query_rag_sbs_tab),
        ("Guide", guide_mod.render_guide),
    )
    tabs = st.tabs([label for label, _ in tab_specs])
    for tab, (_, render_tab) in zip(tabs, tab_specs):
        with tab:
            render_tab()
# Build the page only when this file is executed as the main script.
if __name__ == "__main__":
    main()