gabrielaltay committed
Commit
42554ac
1 Parent(s): 2b72dfd
Files changed (2)
  1. app.py +254 -314
  2. usage.py +72 -0
app.py CHANGED
@@ -6,36 +6,22 @@ import json
  import os
  import re
 
- from langchain.tools.retriever import create_retriever_tool
- from langchain.agents import AgentExecutor
- from langchain.agents import create_openai_tools_agent
- from langchain.agents.format_scratchpad.openai_tools import (
-     format_to_openai_tool_messages,
- )
- from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
  from langchain_core.documents import Document
- from langchain_core.prompts import PromptTemplate
  from langchain_core.prompts import ChatPromptTemplate
- from langchain_core.prompts import MessagesPlaceholder
- from langchain_core.messages import AIMessage
- from langchain_core.messages import HumanMessage
  from langchain_core.runnables import RunnableParallel
  from langchain_core.runnables import RunnablePassthrough
- from langchain_core.output_parsers import StrOutputParser
- from langchain_community.callbacks import get_openai_callback
- from langchain_community.callbacks import StreamlitCallbackHandler
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
  from langchain_community.vectorstores.utils import DistanceStrategy
  from langchain_openai import ChatOpenAI
  from langchain_anthropic import ChatAnthropic
  from langchain_together import ChatTogether
  from langchain_pinecone import PineconeVectorStore
- from pinecone import Pinecone
  import streamlit as st
 
 
- st.set_page_config(layout="wide", page_title="LegisQA")
 
  os.environ["LANGCHAIN_API_KEY"] = st.secrets["langchain_api_key"]
  os.environ["LANGCHAIN_TRACING_V2"] = "true"
  os.environ["LANGCHAIN_PROJECT"] = st.secrets["langchain_project"]
@@ -70,7 +56,9 @@ TOGETHER_CHAT_MODELS = {
      "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
          "cost": {"pmi": 0.88, "pmo": 0.88}
      },
-     "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {"cost": {"pmi": 5.00, "pmo": 5.00}},
  }
 
  PROVIDER_MODELS = {
@@ -128,6 +116,12 @@ def render_outreach_links():
      st.subheader(f":pancakes: Inference [together.ai]({together_url})")
 
 
  def group_docs(docs) -> list[tuple[str, list[Document]]]:
      doc_grps = defaultdict(list)
 
@@ -151,7 +145,7 @@ def group_docs(docs) -> list[tuple[str, list[Document]]]:
      return doc_grps
 
 
- def format_docs(docs):
      """JSON grouped"""
 
      doc_grps = group_docs(docs)
@@ -168,26 +162,26 @@ def format_docs(docs):
      return json.dumps(out, indent=4)
 
 
- def escape_markdown(text):
      MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
      for char in MD_SPECIAL_CHARS:
          text = text.replace(char, "\\" + char)
      return text
 
 
- def get_vectorstore_filter(key_prefix: str):
      vs_filter = {}
-     if SS[f"{key_prefix}|filter_legis_id"] != "":
-         vs_filter["legis_id"] = SS[f"{key_prefix}|filter_legis_id"]
-     if SS[f"{key_prefix}|filter_bioguide_id"] != "":
-         vs_filter["sponsor_bioguide_id"] = SS[f"{key_prefix}|filter_bioguide_id"]
      vs_filter = {
          **vs_filter,
-         "congress_num": {"$in": SS[f"{key_prefix}|filter_congress_nums"]},
      }
      vs_filter = {
          **vs_filter,
-         "sponsor_party": {"$in": SS[f"{key_prefix}|filter_sponsor_parties"]},
      }
      return vs_filter
 
@@ -288,163 +282,137 @@ Suggest reforms that would benefit the Medicaid program.
  )
 
 
- def render_generative_config(key_prefix: str):
-     st.selectbox(
-         label="provider", options=PROVIDER_MODELS.keys(), key=f"{key_prefix}|provider"
      )
-     st.selectbox(
-         label="model name",
-         options=PROVIDER_MODELS[SS[f"{key_prefix}|provider"]],
-         key=f"{key_prefix}|model_name",
      )
-     st.slider(
-         "temperature",
          min_value=0.0,
          max_value=2.0,
-         value=0.01,
-         key=f"{key_prefix}|temperature",
      )
-     st.slider(
-         "max_output_tokens",
          min_value=1024,
          max_value=2048,
-         key=f"{key_prefix}|max_output_tokens",
      )
-     st.slider(
-         "top_p", min_value=0.0, max_value=1.0, value=0.9, key=f"{key_prefix}|top_p"
      )
-     st.checkbox(
-         "escape markdown in answer", key=f"{key_prefix}|response_escape_markdown"
      )
-     st.checkbox(
-         "add legis urls in answer",
          value=True,
-         key=f"{key_prefix}|response_add_legis_urls",
      )
 
 
- def render_retrieval_config(key_prefix: str):
-     st.slider(
          "Number of chunks to retrieve",
          min_value=1,
          max_value=32,
          value=8,
-         key=f"{key_prefix}|n_ret_docs",
      )
-     st.text_input("Bill ID (e.g. 118-s-2293)", key=f"{key_prefix}|filter_legis_id")
-     st.text_input("Bioguide ID (e.g. R000595)", key=f"{key_prefix}|filter_bioguide_id")
-     st.multiselect(
          "Congress Numbers",
          CONGRESS_NUMBERS,
          default=CONGRESS_NUMBERS,
-         key=f"{key_prefix}|filter_congress_nums",
      )
-     st.multiselect(
          "Sponsor Party",
          SPONSOR_PARTIES,
          default=SPONSOR_PARTIES,
-         key=f"{key_prefix}|filter_sponsor_parties",
      )
 
 
- def get_llm(key_prefix: str):
-
-     if SS[f"{key_prefix}|model_name"] in OPENAI_CHAT_MODELS:
-         llm = ChatOpenAI(
-             model=SS[f"{key_prefix}|model_name"],
-             temperature=SS[f"{key_prefix}|temperature"],
-             api_key=st.secrets["openai_api_key"],
-             top_p=SS[f"{key_prefix}|top_p"],
-             seed=SEED,
-             max_tokens=SS[f"{key_prefix}|max_output_tokens"],
-         )
-     elif SS[f"{key_prefix}|model_name"] in ANTHROPIC_CHAT_MODELS:
-         llm = ChatAnthropic(
-             model_name=SS[f"{key_prefix}|model_name"],
-             temperature=SS[f"{key_prefix}|temperature"],
-             api_key=st.secrets["anthropic_api_key"],
-             top_p=SS[f"{key_prefix}|top_p"],
-             max_tokens_to_sample=SS[f"{key_prefix}|max_output_tokens"],
-         )
-     elif SS[f"{key_prefix}|model_name"] in TOGETHER_CHAT_MODELS:
-         llm = ChatTogether(
-             model=SS[f"{key_prefix}|model_name"],
-             temperature=SS[f"{key_prefix}|temperature"],
-             max_tokens=SS[f"{key_prefix}|max_output_tokens"],
-             top_p=SS[f"{key_prefix}|top_p"],
-             seed=SEED,
-             api_key=st.secrets["together_api_key"],
-         )
-     else:
-         raise ValueError()
 
-     return llm
 
 
- def get_token_usage(key_prefix: str, metadata: dict):
-     if SS[f"{key_prefix}|model_name"] in OPENAI_CHAT_MODELS:
-         model_info = PROVIDER_MODELS["OpenAI"][SS[f"{key_prefix}|model_name"]]
-         return get_openai_token_usage(metadata, model_info)
-     elif SS[f"{key_prefix}|model_name"] in ANTHROPIC_CHAT_MODELS:
-         model_info = PROVIDER_MODELS["Anthropic"][SS[f"{key_prefix}|model_name"]]
-         return get_anthropic_token_usage(metadata, model_info)
-     elif SS[f"{key_prefix}|model_name"] in TOGETHER_CHAT_MODELS:
-         model_info = PROVIDER_MODELS["Together"][SS[f"{key_prefix}|model_name"]]
-         return get_together_token_usage(metadata, model_info)
-     else:
-         raise ValueError()
-
-
- def get_openai_token_usage(metadata: dict, model_info: dict):
-     input_tokens = metadata["token_usage"]["prompt_tokens"]
-     output_tokens = metadata["token_usage"]["completion_tokens"]
-     cost = (
-         input_tokens * 1e-6 * model_info["cost"]["pmi"]
-         + output_tokens * 1e-6 * model_info["cost"]["pmo"]
-     )
-     return {
-         "input_tokens": input_tokens,
-         "output_tokens": output_tokens,
-         "cost": cost,
-     }
-
-
- def get_anthropic_token_usage(metadata: dict, model_info: dict):
-     input_tokens = metadata["usage"]["input_tokens"]
-     output_tokens = metadata["usage"]["output_tokens"]
-     cost = (
-         input_tokens * 1e-6 * model_info["cost"]["pmi"]
-         + output_tokens * 1e-6 * model_info["cost"]["pmo"]
-     )
-     return {
-         "input_tokens": input_tokens,
-         "output_tokens": output_tokens,
-         "cost": cost,
-     }
-
-
- def get_together_token_usage(metadata: dict, model_info: dict):
-     input_tokens = metadata["token_usage"]["prompt_tokens"]
-     output_tokens = metadata["token_usage"]["completion_tokens"]
-     cost = (
-         input_tokens * 1e-6 * model_info["cost"]["pmi"]
-         + output_tokens * 1e-6 * model_info["cost"]["pmo"]
-     )
-     return {
-         "input_tokens": input_tokens,
-         "output_tokens": output_tokens,
-         "cost": cost,
-     }
 
- def render_sidebar():
 
-     with st.container(border=True):
-         render_outreach_links()
 
 
- def render_query_rag_tab():
 
      QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
 
  ---
@@ -463,219 +431,191 @@ Query: {query}"""
          ]
      )
 
      key_prefix = "query_rag"
      render_example_queries()
 
      with st.form(f"{key_prefix}|query_form"):
-         st.text_area(
-             "Enter a query that can be answered with congressional legislation:",
-             key=f"{key_prefix}|query",
          )
-         query_submitted = st.form_submit_button("Submit")
 
      col1, col2 = st.columns(2)
      with col1:
          with st.expander("Generative Config"):
-             render_generative_config(key_prefix)
      with col2:
          with st.expander("Retrieval Config"):
-             render_retrieval_config(key_prefix)
 
      if query_submitted:
-
-         llm = get_llm(key_prefix)
-         vs_filter = get_vectorstore_filter(key_prefix)
-         retriever = vectorstore.as_retriever(
-             search_kwargs={"k": SS[f"{key_prefix}|n_ret_docs"], "filter": vs_filter},
          )
 
-         rag_chain = (
-             RunnableParallel(
-                 {
-                     "docs": retriever,  # list of docs
-                     "query": RunnablePassthrough(),  # str
-                 }
-             )
-             .assign(context=(lambda x: format_docs(x["docs"])))
-             .assign(output=prompt | llm)
-         )
-
-         SS[f"{key_prefix}|out"] = rag_chain.invoke(SS[f"{key_prefix}|query"])
-
-     if f"{key_prefix}|out" in SS:
-
-         out_display = SS[f"{key_prefix}|out"]["output"].content
-         if SS[f"{key_prefix}|response_escape_markdown"]:
-             out_display = escape_markdown(out_display)
-         if SS[f"{key_prefix}|response_add_legis_urls"]:
-             out_display = replace_legis_ids_with_urls(out_display)
-         with st.container(border=True):
-             st.write("Response")
-             st.info(out_display)
-
-         with st.container(border=True):
-             st.write("API Usage")
-             token_usage = get_token_usage(
-                 key_prefix, SS[f"{key_prefix}|out"]["output"].response_metadata
-             )
-             col1, col2, col3 = st.columns(3)
-             with col1:
-                 st.metric("Input Tokens", token_usage["input_tokens"])
-             with col2:
-                 st.metric("Output Tokens", token_usage["output_tokens"])
-             with col3:
-                 st.metric("Cost", f"${token_usage['cost']:.4f}")
-             with st.expander("Response Metadata"):
-                 st.warning(SS[f"{key_prefix}|out"]["output"].response_metadata)
-
-         with st.container(border=True):
-             doc_grps = group_docs(SS[f"{key_prefix}|out"]["docs"])
-             st.write(
-                 "Retrieved Chunks (note that you may need to 'right click' on links in the expanders to follow them)"
-             )
-             for legis_id, doc_grp in doc_grps:
-                 render_doc_grp(legis_id, doc_grp)
-
          with st.expander("Debug"):
-             st.write(SS[f"{key_prefix}|out"])
 
 
  def render_query_rag_sbs_tab():
-
-     QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
-
- ---
-
- Congressional Legislation Excerpts:
-
- {context}
-
- ---
-
- Query: {query}"""
-
      base_key_prefix = "query_rag_sbs"
 
-     prompt = ChatPromptTemplate.from_messages(
-         [
-             ("human", QUERY_RAG_TEMPLATE),
-         ]
-     )
-
      with st.form(f"{base_key_prefix}|query_form"):
-         st.text_area(
-             "Enter a query that can be answered with congressional legislation:",
-             key=f"{base_key_prefix}|query",
          )
-         query_submitted = st.form_submit_button("Submit")
 
      grp1a, grp2a = st.columns(2)
 
      with grp1a:
          st.header("Group 1")
          key_prefix = f"{base_key_prefix}|grp1"
          with st.expander("Generative Config"):
-             render_generative_config(key_prefix)
          with st.expander("Retrieval Config"):
-             render_retrieval_config(key_prefix)
 
      with grp2a:
          st.header("Group 2")
          key_prefix = f"{base_key_prefix}|grp2"
          with st.expander("Generative Config"):
-             render_generative_config(key_prefix)
          with st.expander("Retrieval Config"):
-             render_retrieval_config(key_prefix)
 
      grp1b, grp2b = st.columns(2)
      sbs_cols = {"grp1": grp1b, "grp2": grp2b}
 
      for post_key_prefix in ["grp1", "grp2"]:
 
-         key_prefix = f"{base_key_prefix}|{post_key_prefix}"
 
-         if query_submitted:
-             llm = get_llm(key_prefix)
-             vs_filter = get_vectorstore_filter(key_prefix)
-             retriever = vectorstore.as_retriever(
-                 search_kwargs={
-                     "k": SS[f"{key_prefix}|n_ret_docs"],
-                     "filter": vs_filter,
-                 },
-             )
-             rag_chain = (
-                 RunnableParallel(
-                     {
-                         "docs": retriever,  # list of docs
-                         "query": RunnablePassthrough(),  # str
-                     }
-                 )
-                 .assign(context=(lambda x: format_docs(x["docs"])))
-                 .assign(output=prompt | llm)
-             )
-             SS[f"{key_prefix}|out"] = rag_chain.invoke(SS[f"{base_key_prefix}|query"])
-
-         if f"{key_prefix}|out" in SS:
-             with sbs_cols[post_key_prefix]:
-                 out_display = SS[f"{key_prefix}|out"]["output"].content
-                 if SS[f"{key_prefix}|response_escape_markdown"]:
-                     out_display = escape_markdown(out_display)
-                 if SS[f"{key_prefix}|response_add_legis_urls"]:
-                     out_display = replace_legis_ids_with_urls(out_display)
-                 with st.container(border=True):
-                     st.write("Response")
-                     st.info(out_display)
-
-                 with st.container(border=True):
-                     st.write("API Usage")
-                     token_usage = get_token_usage(
-                         key_prefix, SS[f"{key_prefix}|out"]["output"].response_metadata
-                     )
-                     col1, col2, col3 = st.columns(3)
-                     with col1:
-                         st.metric("Input Tokens", token_usage["input_tokens"])
-                     with col2:
-                         st.metric("Output Tokens", token_usage["output_tokens"])
-                     with col3:
-                         st.metric("Cost", f"${token_usage['cost']:.4f}")
-                     with st.expander("Response Metadata"):
-                         st.warning(SS[f"{key_prefix}|out"]["output"].response_metadata)
-
-                 with st.container(border=True):
-                     doc_grps = group_docs(SS[f"{key_prefix}|out"]["docs"])
-                     st.write(
-                         "Retrieved Chunks (note that you may need to 'right click' on links in the expanders to follow them)"
-                     )
-                     for legis_id, doc_grp in doc_grps:
-                         render_doc_grp(legis_id, doc_grp)
-
-
- ##################
-
-
- st.title(":classical_building: LegisQA :classical_building:")
- st.header("Chat With Congressional Bills")
-
-
- with st.sidebar:
-     render_sidebar()
-
-
- vectorstore = load_pinecone_vectorstore()
-
- query_rag_tab, query_rag_sbs_tab, guide_tab = st.tabs(
-     [
-         "RAG",
-         "RAG (side-by-side)",
-         "Guide",
-     ]
- )
-
- with query_rag_tab:
-     render_query_rag_tab()
-
- with query_rag_sbs_tab:
-     render_query_rag_sbs_tab()
-
- with guide_tab:
-     render_guide()
 
  import os
  import re
 
  from langchain_core.documents import Document
  from langchain_core.prompts import ChatPromptTemplate
  from langchain_core.runnables import RunnableParallel
  from langchain_core.runnables import RunnablePassthrough
  from langchain_community.embeddings import HuggingFaceBgeEmbeddings
  from langchain_community.vectorstores.utils import DistanceStrategy
  from langchain_openai import ChatOpenAI
  from langchain_anthropic import ChatAnthropic
  from langchain_together import ChatTogether
  from langchain_pinecone import PineconeVectorStore
  import streamlit as st
 
+ import usage
 
 
+ st.set_page_config(layout="wide", page_title="LegisQA")
  os.environ["LANGCHAIN_API_KEY"] = st.secrets["langchain_api_key"]
  os.environ["LANGCHAIN_TRACING_V2"] = "true"
  os.environ["LANGCHAIN_PROJECT"] = st.secrets["langchain_project"]
 
      "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
          "cost": {"pmi": 0.88, "pmo": 0.88}
      },
+     "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {
+         "cost": {"pmi": 5.00, "pmo": 5.00}
+     },
  }
 
  PROVIDER_MODELS = {
 
      st.subheader(f":pancakes: Inference [together.ai]({together_url})")
 
 
+ def render_sidebar():
+
+     with st.container(border=True):
+         render_outreach_links()
+
+
  def group_docs(docs) -> list[tuple[str, list[Document]]]:
      doc_grps = defaultdict(list)
 
      return doc_grps
 
 
+ def format_docs(docs: list[Document]) -> str:
      """JSON grouped"""
 
      doc_grps = group_docs(docs)
 
      return json.dumps(out, indent=4)
 
 
+ def escape_markdown(text: str) -> str:
      MD_SPECIAL_CHARS = r"\`*_{}[]()#+-.!$"
      for char in MD_SPECIAL_CHARS:
          text = text.replace(char, "\\" + char)
      return text
 
 
+ def get_vectorstore_filter(ret_config: dict) -> dict:
      vs_filter = {}
+     if ret_config["filter_legis_id"] != "":
+         vs_filter["legis_id"] = ret_config["filter_legis_id"]
+     if ret_config["filter_bioguide_id"] != "":
+         vs_filter["sponsor_bioguide_id"] = ret_config["filter_bioguide_id"]
      vs_filter = {
          **vs_filter,
+         "congress_num": {"$in": ret_config["filter_congress_nums"]},
      }
      vs_filter = {
          **vs_filter,
+         "sponsor_party": {"$in": ret_config["filter_sponsor_parties"]},
      }
      return vs_filter
 
  )
 
 
+ def get_generative_config(key_prefix: str) -> dict:
+     output = {}
+
+     key = "provider"
+     output[key] = st.selectbox(
+         label=key, options=PROVIDER_MODELS.keys(), key=f"{key_prefix}|{key}"
      )
+
+     key = "model_name"
+     output[key] = st.selectbox(
+         label=key,
+         options=PROVIDER_MODELS[output["provider"]],
+         key=f"{key_prefix}|{key}",
      )
+
+     key = "temperature"
+     output[key] = st.slider(
+         key,
          min_value=0.0,
          max_value=2.0,
+         value=0.0,
+         key=f"{key_prefix}|{key}",
      )
+
+     key = "max_output_tokens"
+     output[key] = st.slider(
+         key,
          min_value=1024,
          max_value=2048,
+         key=f"{key_prefix}|{key}",
      )
+
+     key = "top_p"
+     output[key] = st.slider(
+         key, min_value=0.0, max_value=1.0, value=0.9, key=f"{key_prefix}|{key}"
      )
+
+     key = "should_escape_markdown"
+     output[key] = st.checkbox(
+         key,
+         value=False,
+         key=f"{key_prefix}|{key}",
      )
+
+     key = "should_add_legis_urls"
+     output[key] = st.checkbox(
+         key,
          value=True,
+         key=f"{key_prefix}|{key}",
      )
 
+     return output
+
+
+ def get_retrieval_config(key_prefix: str) -> dict:
+     output = {}
 
+     key = "n_ret_docs"
+     output[key] = st.slider(
          "Number of chunks to retrieve",
          min_value=1,
          max_value=32,
          value=8,
+         key=f"{key_prefix}|{key}",
      )
+
+     key = "filter_legis_id"
+     output[key] = st.text_input("Bill ID (e.g. 118-s-2293)", key=f"{key_prefix}|{key}")
+
+     key = "filter_bioguide_id"
+     output[key] = st.text_input("Bioguide ID (e.g. R000595)", key=f"{key_prefix}|{key}")
+
+     key = "filter_congress_nums"
+     output[key] = st.multiselect(
          "Congress Numbers",
          CONGRESS_NUMBERS,
          default=CONGRESS_NUMBERS,
+         key=f"{key_prefix}|{key}",
      )
+
+     key = "filter_sponsor_parties"
+     output[key] = st.multiselect(
          "Sponsor Party",
          SPONSOR_PARTIES,
          default=SPONSOR_PARTIES,
+         key=f"{key_prefix}|{key}",
      )
 
+     return output
 
 
+ def get_llm(gen_config: dict):
 
+     match gen_config["provider"]:
 
+         case "OpenAI":
+             llm = ChatOpenAI(
+                 model=gen_config["model_name"],
+                 temperature=gen_config["temperature"],
+                 api_key=st.secrets["openai_api_key"],
+                 top_p=gen_config["top_p"],
+                 seed=SEED,
+                 max_tokens=gen_config["max_output_tokens"],
+             )
 
+         case "Anthropic":
+             llm = ChatAnthropic(
+                 model_name=gen_config["model_name"],
+                 temperature=gen_config["temperature"],
+                 api_key=st.secrets["anthropic_api_key"],
+                 top_p=gen_config["top_p"],
+                 max_tokens_to_sample=gen_config["max_output_tokens"],
+             )
 
+         case "Together":
+             llm = ChatTogether(
+                 model=gen_config["model_name"],
+                 temperature=gen_config["temperature"],
+                 max_tokens=gen_config["max_output_tokens"],
+                 top_p=gen_config["top_p"],
+                 seed=SEED,
+                 api_key=st.secrets["together_api_key"],
+             )
 
+         case _:
+             raise ValueError()
 
+     return llm
 
 
+ def create_rag_chain(llm, retriever):
      QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
 
  ---
 
          ]
      )
 
+     rag_chain = (
+         RunnableParallel(
+             {
+                 "docs": retriever,
+                 "query": RunnablePassthrough(),
+             }
+         )
+         .assign(context=lambda x: format_docs(x["docs"]))
+         .assign(aimessage=prompt | llm)
+     )
+
+     return rag_chain
+
+
+ def process_query(gen_config: dict, ret_config: dict, query: str):
+     vectorstore = load_pinecone_vectorstore()
+     llm = get_llm(gen_config)
+     vs_filter = get_vectorstore_filter(ret_config)
+     retriever = vectorstore.as_retriever(
+         search_kwargs={"k": ret_config["n_ret_docs"], "filter": vs_filter},
+     )
+     rag_chain = create_rag_chain(llm, retriever)
+     response = rag_chain.invoke(query)
+     return response
+
+
+ def display_retrieved_chunks(response):
+     with st.container(border=True):
+         doc_grps = group_docs(response["docs"])
+         st.write(
+             "Retrieved Chunks (note that you may need to 'right click' on links in the expanders to follow them)"
+         )
+         for legis_id, doc_grp in doc_grps:
+             render_doc_grp(legis_id, doc_grp)
+
+
+ def display_response(
+     response, model_info, provider, should_escape_markdown, should_add_legis_urls
+ ):
+     out_display = response["aimessage"].content
+     if should_escape_markdown:
+         out_display = escape_markdown(out_display)
+     if should_add_legis_urls:
+         out_display = replace_legis_ids_with_urls(out_display)
+
+     with st.container(border=True):
+         st.write("Response")
+         st.info(out_display)
+
+     usage.display_api_usage(response, model_info, provider)
+     display_retrieved_chunks(response)
+
+
+ def render_query_rag_tab():
      key_prefix = "query_rag"
      render_example_queries()
 
      with st.form(f"{key_prefix}|query_form"):
+         query = st.text_area(
+             "Enter a query that can be answered with congressional legislation:"
          )
+         cols = st.columns(2)
+         with cols[0]:
+             query_submitted = st.form_submit_button("Submit")
+         with cols[1]:
+             status_placeholder = st.empty()
 
      col1, col2 = st.columns(2)
      with col1:
          with st.expander("Generative Config"):
+             gen_config = get_generative_config(key_prefix)
      with col2:
          with st.expander("Retrieval Config"):
+             ret_config = get_retrieval_config(key_prefix)
 
+     rkey = f"{key_prefix}|response"
      if query_submitted:
+         with status_placeholder:
+             with st.spinner("generating response"):
+                 SS[rkey] = process_query(gen_config, ret_config, query)
+
+     if response := SS.get(rkey):
+         model_info = PROVIDER_MODELS[gen_config["provider"]][gen_config["model_name"]]
+         display_response(
+             response,
+             model_info,
+             gen_config["provider"],
+             gen_config["should_escape_markdown"],
+             gen_config["should_add_legis_urls"],
          )
 
          with st.expander("Debug"):
+             st.write(response)
 
 
  def render_query_rag_sbs_tab():
      base_key_prefix = "query_rag_sbs"
 
      with st.form(f"{base_key_prefix}|query_form"):
+         query = st.text_area(
+             "Enter a query that can be answered with congressional legislation:"
          )
+         cols = st.columns(2)
+         with cols[0]:
+             query_submitted = st.form_submit_button("Submit")
+         with cols[1]:
+             status_placeholder = st.empty()
 
      grp1a, grp2a = st.columns(2)
 
+     gen_configs = {}
+     ret_configs = {}
      with grp1a:
          st.header("Group 1")
          key_prefix = f"{base_key_prefix}|grp1"
          with st.expander("Generative Config"):
+             gen_configs["grp1"] = get_generative_config(key_prefix)
          with st.expander("Retrieval Config"):
+             ret_configs["grp1"] = get_retrieval_config(key_prefix)
 
      with grp2a:
          st.header("Group 2")
          key_prefix = f"{base_key_prefix}|grp2"
          with st.expander("Generative Config"):
+             gen_configs["grp2"] = get_generative_config(key_prefix)
          with st.expander("Retrieval Config"):
+             ret_configs["grp2"] = get_retrieval_config(key_prefix)
 
      grp1b, grp2b = st.columns(2)
      sbs_cols = {"grp1": grp1b, "grp2": grp2b}
+     grp_names = {"grp1": "Group 1", "grp2": "Group 2"}
 
      for post_key_prefix in ["grp1", "grp2"]:
+         with sbs_cols[post_key_prefix]:
+             key_prefix = f"{base_key_prefix}|{post_key_prefix}"
+             rkey = f"{key_prefix}|response"
+             if query_submitted:
+                 with status_placeholder:
+                     with st.spinner(
+                         "generating response for {}".format(grp_names[post_key_prefix])
+                     ):
+                         SS[rkey] = process_query(
+                             gen_configs[post_key_prefix],
+                             ret_configs[post_key_prefix],
+                             query,
+                         )
+
+             if response := SS.get(rkey):
+                 model_info = PROVIDER_MODELS[gen_configs[post_key_prefix]["provider"]][
+                     gen_configs[post_key_prefix]["model_name"]
+                 ]
+                 display_response(
+                     response,
+                     model_info,
+                     gen_configs[post_key_prefix]["provider"],
+                     gen_configs[post_key_prefix]["should_escape_markdown"],
+                     gen_configs[post_key_prefix]["should_add_legis_urls"],
+                 )
 
 
+ def main():
+
+     st.title(":classical_building: LegisQA :classical_building:")
+     st.header("Query Congressional Bills")
+
+     with st.sidebar:
+         render_sidebar()
+
+     query_rag_tab, query_rag_sbs_tab, guide_tab = st.tabs(
+         [
+             "RAG",
+             "RAG (side-by-side)",
+             "Guide",
+         ]
+     )
+
+     with query_rag_tab:
+         render_query_rag_tab()
+
+     with query_rag_sbs_tab:
+         render_query_rag_sbs_tab()
 
+     with guide_tab:
+         render_guide()
 
+ if __name__ == "__main__":
+     main()
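
For reference, a minimal sketch of the dicts that now flow between the refactored helpers in app.py; the key names mirror get_generative_config / get_retrieval_config above, while the concrete values (model name, congress numbers, parties) are illustrative assumptions only:

# Illustrative only: config dicts shaped like the ones returned by
# get_generative_config() and get_retrieval_config() in app.py.
gen_config = {
    "provider": "OpenAI",            # one of the PROVIDER_MODELS keys
    "model_name": "gpt-4o-mini",     # hypothetical model choice
    "temperature": 0.0,
    "max_output_tokens": 1024,
    "top_p": 0.9,
    "should_escape_markdown": False,
    "should_add_legis_urls": True,
}
ret_config = {
    "n_ret_docs": 8,
    "filter_legis_id": "",                 # empty string means "no filter"
    "filter_bioguide_id": "",
    "filter_congress_nums": [117, 118],    # hypothetical selection
    "filter_sponsor_parties": ["D", "R"],  # hypothetical selection
}
# In the app these would be passed straight to
# process_query(gen_config, ret_config, query).
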
usage.py ADDED
@@ -0,0 +1,72 @@
+ import streamlit as st
+
+
+ def get_openai_token_usage(metadata: dict, model_info: dict):
+     input_tokens = metadata["token_usage"]["prompt_tokens"]
+     output_tokens = metadata["token_usage"]["completion_tokens"]
+     cost = (
+         input_tokens * 1e-6 * model_info["cost"]["pmi"]
+         + output_tokens * 1e-6 * model_info["cost"]["pmo"]
+     )
+     return {
+         "input_tokens": input_tokens,
+         "output_tokens": output_tokens,
+         "cost": cost,
+     }
+
+
+ def get_anthropic_token_usage(metadata: dict, model_info: dict):
+     input_tokens = metadata["usage"]["input_tokens"]
+     output_tokens = metadata["usage"]["output_tokens"]
+     cost = (
+         input_tokens * 1e-6 * model_info["cost"]["pmi"]
+         + output_tokens * 1e-6 * model_info["cost"]["pmo"]
+     )
+     return {
+         "input_tokens": input_tokens,
+         "output_tokens": output_tokens,
+         "cost": cost,
+     }
+
+
+ def get_together_token_usage(metadata: dict, model_info: dict):
+     input_tokens = metadata["token_usage"]["prompt_tokens"]
+     output_tokens = metadata["token_usage"]["completion_tokens"]
+     cost = (
+         input_tokens * 1e-6 * model_info["cost"]["pmi"]
+         + output_tokens * 1e-6 * model_info["cost"]["pmo"]
+     )
+     return {
+         "input_tokens": input_tokens,
+         "output_tokens": output_tokens,
+         "cost": cost,
+     }
+
+
+ def get_token_usage(metadata: dict, model_info: dict, provider: str):
+     match provider:
+         case "OpenAI":
+             return get_openai_token_usage(metadata, model_info)
+         case "Anthropic":
+             return get_anthropic_token_usage(metadata, model_info)
+         case "Together":
+             return get_together_token_usage(metadata, model_info)
+         case _:
+             raise ValueError()
+
+
+ def display_api_usage(response, model_info, provider: str):
+     with st.container(border=True):
+         st.write("API Usage")
+         token_usage = get_token_usage(
+             response["aimessage"].response_metadata, model_info, provider
+         )
+         col1, col2, col3 = st.columns(3)
+         with col1:
+             st.metric("Input Tokens", token_usage["input_tokens"])
+         with col2:
+             st.metric("Output Tokens", token_usage["output_tokens"])
+         with col3:
+             st.metric("Cost", f"${token_usage['cost']:.4f}")
+         with st.expander("Response Metadata"):
+             st.warning(response["aimessage"].response_metadata)
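
As a sanity check on the cost arithmetic above (the pmi/pmo prices are dollars per million tokens, hence the 1e-6 factor), a small worked example with hypothetical prices and token counts:

# Hypothetical prices ($ per million input/output tokens) and token counts.
model_info = {"cost": {"pmi": 0.50, "pmo": 1.50}}
metadata = {"token_usage": {"prompt_tokens": 2000, "completion_tokens": 500}}

input_tokens = metadata["token_usage"]["prompt_tokens"]
output_tokens = metadata["token_usage"]["completion_tokens"]
cost = (
    input_tokens * 1e-6 * model_info["cost"]["pmi"]
    + output_tokens * 1e-6 * model_info["cost"]["pmo"]
)
# 2000 * 1e-6 * 0.50 + 500 * 1e-6 * 1.50 = 0.001 + 0.00075 = $0.00175 for this call.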