gabrielaltay committed
Commit da0f003
1 Parent(s): 69c42d0

side by side

Files changed (1): app.py (+161 −56)
app.py CHANGED
@@ -304,8 +304,8 @@ def render_generative_config(key_prefix: str):
     )
     st.slider(
         "max_output_tokens",
-        min_value=512,
-        max_value=1024,
+        min_value=1024,
+        max_value=2048,
         key=f"{key_prefix}|max_output_tokens",
     )
     st.slider(
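A note on the pattern in this hunk: Streamlit widgets given a `key` store their value in `st.session_state`, so the composite `f"{key_prefix}|max_output_tokens"` key is what lets several tabs (and, later in this commit, two side-by-side config groups) hold independent copies of the same slider. A minimal sketch of the mechanism, with an illustrative prefix:

```python
import streamlit as st

key_prefix = "demo"  # illustrative; app.py passes its own prefixes in

st.slider(
    "max_output_tokens",
    min_value=1024,
    max_value=2048,
    key=f"{key_prefix}|max_output_tokens",
)

# Any later code can read the widget's current value from session state.
max_tokens = st.session_state[f"{key_prefix}|max_output_tokens"]
```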
@@ -379,6 +379,62 @@ def get_llm(key_prefix: str):
     return llm


+def get_token_usage(key_prefix: str, metadata: dict):
+    if SS[f"{key_prefix}|model_name"] in OPENAI_CHAT_MODELS:
+        model_info = PROVIDER_MODELS["OpenAI"][SS[f"{key_prefix}|model_name"]]
+        return get_openai_token_usage(metadata, model_info)
+    elif SS[f"{key_prefix}|model_name"] in ANTHROPIC_CHAT_MODELS:
+        model_info = PROVIDER_MODELS["Anthropic"][SS[f"{key_prefix}|model_name"]]
+        return get_anthropic_token_usage(metadata, model_info)
+    elif SS[f"{key_prefix}|model_name"] in TOGETHER_CHAT_MODELS:
+        model_info = PROVIDER_MODELS["Together"][SS[f"{key_prefix}|model_name"]]
+        return get_together_token_usage(metadata, model_info)
+    else:
+        raise ValueError()
+
+
+def get_openai_token_usage(metadata: dict, model_info: dict):
+    input_tokens = metadata["token_usage"]["prompt_tokens"]
+    output_tokens = metadata["token_usage"]["completion_tokens"]
+    cost = (
+        input_tokens * 1e-6 * model_info["cost"]["pmi"]
+        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
+    )
+    return {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "cost": cost,
+    }
+
+
+def get_anthropic_token_usage(metadata: dict, model_info: dict):
+    input_tokens = metadata["usage"]["input_tokens"]
+    output_tokens = metadata["usage"]["output_tokens"]
+    cost = (
+        input_tokens * 1e-6 * model_info["cost"]["pmi"]
+        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
+    )
+    return {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "cost": cost,
+    }
+
+
+def get_together_token_usage(metadata: dict, model_info: dict):
+    input_tokens = metadata["token_usage"]["prompt_tokens"]
+    output_tokens = metadata["token_usage"]["completion_tokens"]
+    cost = (
+        input_tokens * 1e-6 * model_info["cost"]["pmi"]
+        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
+    )
+    return {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "cost": cost,
+    }
+
+
 def render_sidebar():

     with st.container(border=True):
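The three provider-specific helpers added here differ only in where the token counts live in the response metadata; the cost formula assumes `model_info["cost"]["pmi"]` and `["pmo"]` are USD prices per million input and output tokens, hence the `1e-6` factor. A worked example with made-up prices:

```python
# Hypothetical pricing: $3 per million input tokens, $15 per million output tokens.
model_info = {"cost": {"pmi": 3.00, "pmo": 15.00}}

input_tokens = 1_200
output_tokens = 400

cost = (
    input_tokens * 1e-6 * model_info["cost"]["pmi"]
    + output_tokens * 1e-6 * model_info["cost"]["pmo"]
)
print(f"${cost:.4f}")  # $0.0096 (= $0.0036 input + $0.0060 output)
```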
@@ -398,7 +454,7 @@ def render_query_rag_tab():
     with st.expander("Retrieval Config"):
         render_retrieval_config(key_prefix)

-    QUERY_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
+    QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.

---
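The renamed template tells the model to expect `{context}` as a JSON list of objects with `"legis_id"`, `"title"`, `"introduced_date"`, `"sponsor"`, and `"snippets"` keys. The `format_docs` helper that produces that JSON is not part of this diff; a hypothetical sketch of what such a function could look like, assuming each retrieved `Document` carries those fields in its metadata:

```python
import json
from collections import defaultdict

from langchain_core.documents import Document


def format_docs_sketch(docs: list[Document]) -> str:
    """Hypothetical reconstruction of format_docs; app.py's version may differ."""
    snippets = defaultdict(list)
    metadata = {}
    for doc in docs:
        legis_id = doc.metadata["legis_id"]
        snippets[legis_id].append(doc.page_content)
        metadata[legis_id] = doc.metadata
    return json.dumps(
        [
            {
                "legis_id": legis_id,
                "title": metadata[legis_id].get("title"),
                "introduced_date": metadata[legis_id].get("introduced_date"),
                "sponsor": metadata[legis_id].get("sponsor"),
                "snippets": snips,
            }
            for legis_id, snips in snippets.items()
        ],
        indent=2,
    )
```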
@@ -412,11 +468,11 @@ Query: {query}"""

     prompt = ChatPromptTemplate.from_messages(
         [
-            ("human", QUERY_TEMPLATE),
+            ("human", QUERY_RAG_TEMPLATE),
         ]
     )

-    with st.form("query_form"):
+    with st.form(f"{key_prefix}|query_form"):
         st.text_area(
             "Enter a query that can be answered with congressional legislation:",
             key=f"{key_prefix}|query",
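`ChatPromptTemplate.from_messages` with a single `("human", ...)` entry produces one user message whose `{context}` and `{query}` placeholders are filled at invoke time. A self-contained check of that behavior (template text abbreviated):

```python
from langchain_core.prompts import ChatPromptTemplate

prompt = ChatPromptTemplate.from_messages(
    [
        ("human", "Excerpts:\n{context}\n\n---\n\nQuery: {query}"),
    ]
)

# Prompt templates are Runnables: invoke() fills the placeholders.
value = prompt.invoke({"context": "[]", "query": "Which bills address solar energy?"})
print(value.to_messages()[0].content)
```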
@@ -437,8 +493,8 @@ Query: {query}"""
                 "docs": retriever,  # list of docs
                 "query": RunnablePassthrough(),  # str
             }
-        ).assign(context=(lambda x: format_docs(x["docs"])))
-        # .assign(output=prompt | llm | StrOutputParser())
+        )
+        .assign(context=(lambda x: format_docs(x["docs"])))
         .assign(output=prompt | llm)
     )
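The restructured chain reads as three stages: `RunnableParallel` fans the input query out into a dict, then each `.assign` adds one key computed from the dict so far, so the final output carries `docs`, `query`, `context`, and `output` together. A runnable sketch with stand-ins for the retriever and LLM (illustrative only):

```python
from langchain_core.runnables import RunnableLambda, RunnableParallel, RunnablePassthrough

# Stand-ins so the composition runs without a vector store or an API key.
fake_retriever = RunnableLambda(lambda query: [f"snippet about {query}"])


def format_docs(docs):
    return "\n".join(docs)


chain = (
    RunnableParallel(
        {
            "docs": fake_retriever,  # list of docs
            "query": RunnablePassthrough(),  # str
        }
    )
    .assign(context=lambda x: format_docs(x["docs"]))  # adds "context"
    .assign(output=lambda x: f"model sees: {x['context']}")  # stands in for prompt | llm
)

print(chain.invoke("solar energy"))
# {'docs': [...], 'query': 'solar energy', 'context': '...', 'output': '...'}
```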
 
@@ -482,65 +538,114 @@ Query: {query}"""
482
  st.write(SS[f"{key_prefix}|out"])
483
 
484
 
485
- def get_token_usage(key_prefix: str, metadata: dict):
486
- if SS[f"{key_prefix}|model_name"] in OPENAI_CHAT_MODELS:
487
- model_info = PROVIDER_MODELS["OpenAI"][SS[f"{key_prefix}|model_name"]]
488
- return get_openai_token_usage(metadata, model_info)
489
- elif SS[f"{key_prefix}|model_name"] in ANTHROPIC_CHAT_MODELS:
490
- model_info = PROVIDER_MODELS["Anthropic"][SS[f"{key_prefix}|model_name"]]
491
- return get_anthropic_token_usage(metadata, model_info)
492
- elif SS[f"{key_prefix}|model_name"] in TOGETHER_CHAT_MODELS:
493
- model_info = PROVIDER_MODELS["Together"][SS[f"{key_prefix}|model_name"]]
494
- return get_together_token_usage(metadata, model_info)
495
- else:
496
- raise ValueError()
497
 
 
498
 
499
- def get_openai_token_usage(metadata: dict, model_info: dict):
500
- input_tokens = metadata["token_usage"]["prompt_tokens"]
501
- output_tokens = metadata["token_usage"]["completion_tokens"]
502
- cost = (
503
- input_tokens * 1e-6 * model_info["cost"]["pmi"]
504
- + output_tokens * 1e-6 * model_info["cost"]["pmo"]
505
- )
506
- return {
507
- "input_tokens": input_tokens,
508
- "output_tokens": output_tokens,
509
- "cost": cost,
510
- }
511
 
 
512
 
513
- def get_anthropic_token_usage(metadata: dict, model_info: dict):
514
- input_tokens = metadata["usage"]["input_tokens"]
515
- output_tokens = metadata["usage"]["output_tokens"]
516
- cost = (
517
- input_tokens * 1e-6 * model_info["cost"]["pmi"]
518
- + output_tokens * 1e-6 * model_info["cost"]["pmo"]
519
- )
520
- return {
521
- "input_tokens": input_tokens,
522
- "output_tokens": output_tokens,
523
- "cost": cost,
524
- }
525
 
 
526
 
527
- def get_together_token_usage(metadata: dict, model_info: dict):
528
- input_tokens = metadata["token_usage"]["prompt_tokens"]
529
- output_tokens = metadata["token_usage"]["completion_tokens"]
530
- cost = (
531
- input_tokens * 1e-6 * model_info["cost"]["pmi"]
532
- + output_tokens * 1e-6 * model_info["cost"]["pmo"]
 
 
533
  )
534
- return {
535
- "input_tokens": input_tokens,
536
- "output_tokens": output_tokens,
537
- "cost": cost,
538
- }
539
 
 
 
 
 
 
 
540
 
541
- def render_query_rag_sbs_tab():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
542
 
543
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544
 
545
 
546
  ##################
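One pattern worth noting in the new tab: the chain output is written into `st.session_state` on submit and re-rendered from there on every rerun, which is what keeps both columns populated while the user toggles expanders or edits config. A minimal self-contained sketch of that submit-then-render-from-state pattern (keys and work are illustrative):

```python
import streamlit as st

SS = st.session_state

with st.form("demo|query_form"):
    st.text_area("Query:", key="demo|query")
    submitted = st.form_submit_button("Submit")

if submitted:
    # Expensive work happens only on submit; the result is cached in session state.
    SS["demo|out"] = SS["demo|query"].upper()  # stand-in for rag_chain.invoke(...)

if "demo|out" in SS:
    # This renders on every rerun, not just the one triggered by the submit click.
    st.info(SS["demo|out"])
```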
 