""" | |
""" | |
import os

from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_openai import ChatOpenAI
from langchain_together import ChatTogether
import streamlit as st

import doc_format_mod
import guide_mod
import sidebar_mod
import usage_mod
import utils_mod
import vectorstore_mod

st.set_page_config(layout="wide", page_title="LegisQA")
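
# LangSmith tracing credentials come from Streamlit secrets;
# TOKENIZERS_PARALLELISM silences fork warnings from the tokenizers library.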
os.environ["LANGCHAIN_API_KEY"] = st.secrets["langchain_api_key"] | |
os.environ["LANGCHAIN_TRACING_V2"] = "true" | |
os.environ["LANGCHAIN_PROJECT"] = st.secrets["langchain_project"] | |
os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
SS = st.session_state
SEED = 292764
CONGRESS_NUMBERS = [113, 114, 115, 116, 117, 118]
SPONSOR_PARTIES = ["D", "R", "L", "I"]
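
# Per-model pricing; "pmi" and "pmo" are assumed to be USD per million input
# and output tokens, respectively.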
OPENAI_CHAT_MODELS = {
    "gpt-4o-mini": {"cost": {"pmi": 0.15, "pmo": 0.60}},
    "gpt-4o": {"cost": {"pmi": 5.00, "pmo": 15.0}},
}
ANTHROPIC_CHAT_MODELS = {
    "claude-3-haiku-20240307": {"cost": {"pmi": 0.25, "pmo": 1.25}},
    "claude-3-5-sonnet-20240620": {"cost": {"pmi": 3.00, "pmo": 15.0}},
    "claude-3-opus-20240229": {"cost": {"pmi": 15.0, "pmo": 75.0}},
}
TOGETHER_CHAT_MODELS = {
    "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": {
        "cost": {"pmi": 0.18, "pmo": 0.18}
    },
    "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
        "cost": {"pmi": 0.88, "pmo": 0.88}
    },
    "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {
        "cost": {"pmi": 5.00, "pmo": 5.00}
    },
}
GOOGLE_CHAT_MODELS = {
    "gemini-1.5-flash": {"cost": {"pmi": 0.0, "pmo": 0.0}},
    "gemini-1.5-pro": {"cost": {"pmi": 0.0, "pmo": 0.0}},
    "gemini-1.5-pro-exp-0801": {"cost": {"pmi": 0.0, "pmo": 0.0}},
}
PROVIDER_MODELS = {
    "OpenAI": OPENAI_CHAT_MODELS,
    "Anthropic": ANTHROPIC_CHAT_MODELS,
    "Together": TOGETHER_CHAT_MODELS,
    "Google": GOOGLE_CHAT_MODELS,
}


def render_example_queries():
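    """Show a collapsible list of example queries users can copy."""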
    with st.expander("Example Queries"):
        st.write(
            """
```
What are the themes around artificial intelligence?
```

```
Write a well cited 3 paragraph essay on food insecurity.
```

```
Create a table summarizing major climate change ideas with columns legis_id, title, idea.
```

```
Write an action plan to keep social security solvent.
```

```
Suggest reforms that would benefit the Medicaid program.
```
"""
        )


def get_generative_config(key_prefix: str) -> dict:
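    """Render generation settings widgets and return their values as a dict."""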
    output = {}
    key = "provider"
    output[key] = st.selectbox(
        label=key, options=list(PROVIDER_MODELS.keys()), key=f"{key_prefix}|{key}"
    )
    key = "model_name"
    output[key] = st.selectbox(
        label=key,
        options=list(PROVIDER_MODELS[output["provider"]].keys()),
        key=f"{key_prefix}|{key}",
    )
    key = "temperature"
    output[key] = st.slider(
        key,
        min_value=0.0,
        max_value=2.0,
        value=0.0,
        key=f"{key_prefix}|{key}",
    )
    key = "max_output_tokens"
    output[key] = st.slider(
        key,
        min_value=1024,
        max_value=2048,
        key=f"{key_prefix}|{key}",
    )
    key = "top_p"
    output[key] = st.slider(
        key, min_value=0.0, max_value=1.0, value=0.9, key=f"{key_prefix}|{key}"
    )
    key = "should_escape_markdown"
    output[key] = st.checkbox(
        key,
        value=False,
        key=f"{key_prefix}|{key}",
    )
    key = "should_add_legis_urls"
    output[key] = st.checkbox(
        key,
        value=True,
        key=f"{key_prefix}|{key}",
    )
    return output


def get_retrieval_config(key_prefix: str) -> dict:
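    """Render retrieval settings widgets (chunk count, filters) and return a dict."""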
    output = {}
    key = "n_ret_docs"
    output[key] = st.slider(
        "Number of chunks to retrieve",
        min_value=1,
        max_value=32,
        value=8,
        key=f"{key_prefix}|{key}",
    )
    key = "filter_legis_id"
    output[key] = st.text_input("Bill ID (e.g. 118-s-2293)", key=f"{key_prefix}|{key}")
    key = "filter_bioguide_id"
    output[key] = st.text_input("Bioguide ID (e.g. R000595)", key=f"{key_prefix}|{key}")
    key = "filter_congress_nums"
    output[key] = st.multiselect(
        "Congress Numbers",
        CONGRESS_NUMBERS,
        default=CONGRESS_NUMBERS,
        key=f"{key_prefix}|{key}",
    )
    key = "filter_sponsor_parties"
    output[key] = st.multiselect(
        "Sponsor Party",
        SPONSOR_PARTIES,
        default=SPONSOR_PARTIES,
        key=f"{key_prefix}|{key}",
    )
    return output


def get_llm(gen_config: dict):
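    """Instantiate the chat model client for the provider selected in gen_config."""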
    match gen_config["provider"]:
        case "OpenAI":
            llm = ChatOpenAI(
                model=gen_config["model_name"],
                temperature=gen_config["temperature"],
                api_key=st.secrets["openai_api_key"],
                top_p=gen_config["top_p"],
                seed=SEED,
                max_tokens=gen_config["max_output_tokens"],
            )
        case "Anthropic":
            llm = ChatAnthropic(
                model_name=gen_config["model_name"],
                temperature=gen_config["temperature"],
                api_key=st.secrets["anthropic_api_key"],
                top_p=gen_config["top_p"],
                max_tokens_to_sample=gen_config["max_output_tokens"],
            )
        case "Together":
            llm = ChatTogether(
                model=gen_config["model_name"],
                temperature=gen_config["temperature"],
                max_tokens=gen_config["max_output_tokens"],
                top_p=gen_config["top_p"],
                seed=SEED,
                api_key=st.secrets["together_api_key"],
            )
        case "Google":
            llm = ChatGoogleGenerativeAI(
                model=gen_config["model_name"],
                temperature=gen_config["temperature"],
                api_key=st.secrets["google_api_key"],
                max_output_tokens=gen_config["max_output_tokens"],
                top_p=gen_config["top_p"],
            )
        case _:
            raise ValueError(f"Unsupported provider: {gen_config['provider']}")
    return llm


def create_rag_chain(llm, retriever):
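    """Build the RAG chain: retrieve chunks, format context, prompt the LLM."""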
    QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. When citing legis_id, use the same format as the excerpts (e.g. "116-hr-125"). If you don't know how to respond, just tell the user.
---
Congressional Legislation Excerpts:
{context}
---
Query: {query}"""
    prompt = ChatPromptTemplate.from_messages(
        [
            ("human", QUERY_RAG_TEMPLATE),
        ]
    )
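    # Run the retriever and pass the raw query through in parallel, then format
    # the retrieved docs into the prompt context and call the LLM.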
    rag_chain = (
        RunnableParallel(
            {
                "docs": retriever,
                "query": RunnablePassthrough(),
            }
        )
        .assign(context=lambda x: doc_format_mod.format_docs(x["docs"]))
        .assign(aimessage=prompt | llm)
    )
    return rag_chain


def process_query(gen_config: dict, ret_config: dict, query: str):
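    """Assemble the vector store retriever and RAG chain, then run the query."""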
    vectorstore = vectorstore_mod.load_pinecone_vectorstore()
    llm = get_llm(gen_config)
    vs_filter = vectorstore_mod.get_vectorstore_filter(ret_config)
    retriever = vectorstore.as_retriever(
        search_kwargs={"k": ret_config["n_ret_docs"], "filter": vs_filter},
    )
    rag_chain = create_rag_chain(llm, retriever)
    response = rag_chain.invoke(query)
    return response


def render_response(
    response: dict,
    model_info: dict,
    provider: str,
    should_escape_markdown: bool,
    should_add_legis_urls: bool,
    tag: str | None = None,
):
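    """Display the model's answer, API usage info, and the retrieved chunks."""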
    response_text = response["aimessage"].content
    if should_escape_markdown:
        response_text = utils_mod.escape_markdown(response_text)
    if should_add_legis_urls:
        response_text = utils_mod.replace_legis_ids_with_urls(response_text)

    with st.container(border=True):
        if tag is None:
            st.write("Response")
        else:
            st.write(f"Response ({tag})")
        st.info(response_text)

    usage_mod.display_api_usage(response["aimessage"], model_info, provider, tag=tag)
    doc_format_mod.render_retrieved_chunks(response["docs"], tag=tag)


def render_query_rag_tab():
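    """Render the single-model RAG tab: query form, configs, and response."""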
key_prefix = "query_rag" | |
render_example_queries() | |
with st.form(f"{key_prefix}|query_form"): | |
query = st.text_area( | |
"Enter a query that can be answered with congressional legislation:" | |
) | |
cols = st.columns(2) | |
with cols[0]: | |
query_submitted = st.form_submit_button("Submit") | |
with cols[1]: | |
status_placeholder = st.empty() | |
col1, col2 = st.columns(2) | |
with col1: | |
with st.expander("Generative Config"): | |
gen_config = get_generative_config(key_prefix) | |
with col2: | |
with st.expander("Retrieval Config"): | |
ret_config = get_retrieval_config(key_prefix) | |
rkey = f"{key_prefix}|response" | |
if query_submitted: | |
with status_placeholder: | |
with st.spinner("generating response"): | |
SS[rkey] = process_query(gen_config, ret_config, query) | |
if response := SS.get(rkey): | |
model_info = PROVIDER_MODELS[gen_config["provider"]][gen_config["model_name"]] | |
render_response( | |
response, | |
model_info, | |
gen_config["provider"], | |
gen_config["should_escape_markdown"], | |
gen_config["should_add_legis_urls"], | |
) | |
with st.expander("Debug"): | |
st.write(response) | |


def render_query_rag_sbs_tab():
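    """Render the side-by-side tab comparing two generation/retrieval configs."""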
    base_key_prefix = "query_rag_sbs"

    with st.form(f"{base_key_prefix}|query_form"):
        query = st.text_area(
            "Enter a query that can be answered with congressional legislation:"
        )
        cols = st.columns(2)
        with cols[0]:
            query_submitted = st.form_submit_button("Submit")
        with cols[1]:
            status_placeholder = st.empty()

    grp1a, grp2a = st.columns(2)
    gen_configs = {}
    ret_configs = {}
    with grp1a:
        st.header("Group 1")
        key_prefix = f"{base_key_prefix}|grp1"
        with st.expander("Generative Config"):
            gen_configs["grp1"] = get_generative_config(key_prefix)
        with st.expander("Retrieval Config"):
            ret_configs["grp1"] = get_retrieval_config(key_prefix)
    with grp2a:
        st.header("Group 2")
        key_prefix = f"{base_key_prefix}|grp2"
        with st.expander("Generative Config"):
            gen_configs["grp2"] = get_generative_config(key_prefix)
        with st.expander("Retrieval Config"):
            ret_configs["grp2"] = get_retrieval_config(key_prefix)

    grp1b, grp2b = st.columns(2)
    sbs_cols = {"grp1": grp1b, "grp2": grp2b}
    grp_names = {"grp1": "Group 1", "grp2": "Group 2"}
    for post_key_prefix in ["grp1", "grp2"]:
        with sbs_cols[post_key_prefix]:
            key_prefix = f"{base_key_prefix}|{post_key_prefix}"
            rkey = f"{key_prefix}|response"
            if query_submitted:
                with status_placeholder:
                    with st.spinner(
                        f"generating response for {grp_names[post_key_prefix]}"
                    ):
                        SS[rkey] = process_query(
                            gen_configs[post_key_prefix],
                            ret_configs[post_key_prefix],
                            query,
                        )
            if response := SS.get(rkey):
                model_info = PROVIDER_MODELS[gen_configs[post_key_prefix]["provider"]][
                    gen_configs[post_key_prefix]["model_name"]
                ]
                render_response(
                    response,
                    model_info,
                    gen_configs[post_key_prefix]["provider"],
                    gen_configs[post_key_prefix]["should_escape_markdown"],
                    gen_configs[post_key_prefix]["should_add_legis_urls"],
                    tag=grp_names[post_key_prefix],
                )


def main():
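    """Entry point: page chrome, sidebar, and the three content tabs."""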
    st.title(":classical_building: LegisQA :classical_building:")
    st.header("Query Congressional Bills")

    with st.sidebar:
        sidebar_mod.render_sidebar()

    query_rag_tab, query_rag_sbs_tab, guide_tab = st.tabs(
        [
            "RAG",
            "RAG (side-by-side)",
            "Guide",
        ]
    )
    with query_rag_tab:
        render_query_rag_tab()
    with query_rag_sbs_tab:
        render_query_rag_sbs_tab()
    with guide_tab:
        guide_mod.render_guide()


if __name__ == "__main__":
    main()