gabrielaltay commited on
Commit
793d0f2
·
1 Parent(s): 42554ac
Files changed (2) hide show
  1. app.py +23 -9
  2. usage.py +5 -2
app.py CHANGED
@@ -457,18 +457,28 @@ def process_query(gen_config: dict, ret_config: dict, query: str):
457
  return response
458
 
459
 
460
- def display_retrieved_chunks(response):
461
  with st.container(border=True):
462
- doc_grps = group_docs(response["docs"])
463
- st.write(
464
- "Retrieved Chunks (note that you may need to 'right click' on links in the expanders to follow them)"
465
- )
 
 
 
 
 
466
  for legis_id, doc_grp in doc_grps:
467
  render_doc_grp(legis_id, doc_grp)
468
 
469
 
470
  def display_response(
471
- response, model_info, provider, should_escape_markdown, should_add_legis_urls
 
 
 
 
 
472
  ):
473
  out_display = response["aimessage"].content
474
  if should_escape_markdown:
@@ -477,11 +487,14 @@ def display_response(
477
  out_display = replace_legis_ids_with_urls(out_display)
478
 
479
  with st.container(border=True):
480
- st.write("Response")
 
 
 
481
  st.info(out_display)
482
 
483
- usage.display_api_usage(response, model_info, provider)
484
- display_retrieved_chunks(response)
485
 
486
 
487
  def render_query_rag_tab():
@@ -588,6 +601,7 @@ def render_query_rag_sbs_tab():
588
  gen_configs[post_key_prefix]["provider"],
589
  gen_configs[post_key_prefix]["should_escape_markdown"],
590
  gen_configs[post_key_prefix]["should_add_legis_urls"],
 
591
  )
592
 
593
 
 
457
  return response
458
 
459
 
460
+ def display_retrieved_chunks(docs: list[Document], tag: str|None=None):
461
  with st.container(border=True):
462
+ doc_grps = group_docs(docs)
463
+ if tag is None:
464
+ st.write(
465
+ "Retrieved Chunks\n\nleft click to expand, right click to follow links"
466
+ )
467
+ else:
468
+ st.write(
469
+ f"Retrieved Chunks ({tag})\n\nleft click to expand, right click to follow links"
470
+ )
471
  for legis_id, doc_grp in doc_grps:
472
  render_doc_grp(legis_id, doc_grp)
473
 
474
 
475
  def display_response(
476
+ response,
477
+ model_info: dict,
478
+ provider: str,
479
+ should_escape_markdown: bool,
480
+ should_add_legis_urls: bool,
481
+ tag: str|None=None
482
  ):
483
  out_display = response["aimessage"].content
484
  if should_escape_markdown:
 
487
  out_display = replace_legis_ids_with_urls(out_display)
488
 
489
  with st.container(border=True):
490
+ if tag is None:
491
+ st.write("Response")
492
+ else:
493
+ st.write(f"Response ({tag})")
494
  st.info(out_display)
495
 
496
+ usage.display_api_usage(response, model_info, provider, tag=tag)
497
+ display_retrieved_chunks(response["docs"], tag=tag)
498
 
499
 
500
  def render_query_rag_tab():
 
601
  gen_configs[post_key_prefix]["provider"],
602
  gen_configs[post_key_prefix]["should_escape_markdown"],
603
  gen_configs[post_key_prefix]["should_add_legis_urls"],
604
+ tag = grp_names[post_key_prefix],
605
  )
606
 
607
 
usage.py CHANGED
@@ -55,9 +55,12 @@ def get_token_usage(metadata: dict, model_info: dict, provider: str):
55
  raise ValueError()
56
 
57
 
58
- def display_api_usage(response, model_info, provider: str):
59
  with st.container(border=True):
60
- st.write("API Usage")
 
 
 
61
  token_usage = get_token_usage(
62
  response["aimessage"].response_metadata, model_info, provider
63
  )
 
55
  raise ValueError()
56
 
57
 
58
+ def display_api_usage(response, model_info, provider: str, tag: str|None=None):
59
  with st.container(border=True):
60
+ if tag is None:
61
+ st.write("API Usage")
62
+ else:
63
+ st.write(f"API Usage ({tag})")
64
  token_usage = get_token_usage(
65
  response["aimessage"].response_metadata, model_info, provider
66
  )