gabrielaltay committed • Commit da0f003 • 1 Parent(s): 69c42d0
app.py
CHANGED
@@ -304,8 +304,8 @@ def render_generative_config(key_prefix: str):
     )
     st.slider(
         "max_output_tokens",
-        min_value=
-        max_value=
+        min_value=1024,
+        max_value=2048,
         key=f"{key_prefix}|max_output_tokens",
     )
     st.slider(
@@ -379,6 +379,62 @@ def get_llm(key_prefix: str):
     return llm
 
 
+def get_token_usage(key_prefix: str, metadata: dict):
+    if SS[f"{key_prefix}|model_name"] in OPENAI_CHAT_MODELS:
+        model_info = PROVIDER_MODELS["OpenAI"][SS[f"{key_prefix}|model_name"]]
+        return get_openai_token_usage(metadata, model_info)
+    elif SS[f"{key_prefix}|model_name"] in ANTHROPIC_CHAT_MODELS:
+        model_info = PROVIDER_MODELS["Anthropic"][SS[f"{key_prefix}|model_name"]]
+        return get_anthropic_token_usage(metadata, model_info)
+    elif SS[f"{key_prefix}|model_name"] in TOGETHER_CHAT_MODELS:
+        model_info = PROVIDER_MODELS["Together"][SS[f"{key_prefix}|model_name"]]
+        return get_together_token_usage(metadata, model_info)
+    else:
+        raise ValueError()
+
+
+def get_openai_token_usage(metadata: dict, model_info: dict):
+    input_tokens = metadata["token_usage"]["prompt_tokens"]
+    output_tokens = metadata["token_usage"]["completion_tokens"]
+    cost = (
+        input_tokens * 1e-6 * model_info["cost"]["pmi"]
+        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
+    )
+    return {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "cost": cost,
+    }
+
+
+def get_anthropic_token_usage(metadata: dict, model_info: dict):
+    input_tokens = metadata["usage"]["input_tokens"]
+    output_tokens = metadata["usage"]["output_tokens"]
+    cost = (
+        input_tokens * 1e-6 * model_info["cost"]["pmi"]
+        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
+    )
+    return {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "cost": cost,
+    }
+
+
+def get_together_token_usage(metadata: dict, model_info: dict):
+    input_tokens = metadata["token_usage"]["prompt_tokens"]
+    output_tokens = metadata["token_usage"]["completion_tokens"]
+    cost = (
+        input_tokens * 1e-6 * model_info["cost"]["pmi"]
+        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
+    )
+    return {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "cost": cost,
+    }
+
+
 def render_sidebar():
 
     with st.container(border=True):
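Note on the cost math in the functions added above: "pmi" and "pmo" read as provider prices in USD per million input and output tokens, so the 1e-6 factor converts raw token counts to millions. A quick standalone check of the arithmetic, with made-up prices rather than real PROVIDER_MODELS values:

    # minimal sketch of the cost formula; prices are hypothetical
    model_info = {"cost": {"pmi": 3.0, "pmo": 15.0}}  # assumed $/1M tokens
    input_tokens, output_tokens = 1000, 500
    cost = (
        input_tokens * 1e-6 * model_info["cost"]["pmi"]
        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
    )
    assert round(cost, 6) == 0.0105  # $0.0030 input + $0.0075 output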
@@ -398,7 +454,7 @@ def render_query_rag_tab():
     with st.expander("Retrieval Config"):
         render_retrieval_config(key_prefix)
 
-
+    QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
 
 ---
 
@@ -412,11 +468,11 @@ Query: {query}"""
 
     prompt = ChatPromptTemplate.from_messages(
         [
-            ("human",
+            ("human", QUERY_RAG_TEMPLATE),
         ]
     )
 
-    with st.form("query_form"):
+    with st.form(f"{key_prefix}|query_form"):
         st.text_area(
             "Enter a query that can be answered with congressional legislation:",
             key=f"{key_prefix}|query",
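Note on the st.form key change above: the key goes from the fixed "query_form" to one namespaced by key_prefix, presumably because Streamlit rejects duplicate element keys once a second query form (the side-by-side tab added below) exists on the page. A minimal sketch of the pattern:

    import streamlit as st

    # one form per tab; unique keys avoid duplicate-key errors
    for key_prefix in ["query_rag", "query_rag_sbs"]:
        with st.form(f"{key_prefix}|query_form"):
            st.text_area("Query", key=f"{key_prefix}|query")
            st.form_submit_button("Submit")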
@@ -437,8 +493,8 @@ Query: {query}"""
                 "docs": retriever,  # list of docs
                 "query": RunnablePassthrough(),  # str
             }
-        )
-
+        )
+        .assign(context=(lambda x: format_docs(x["docs"])))
         .assign(output=prompt | llm)
     )
 
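Note on the chain shape above: RunnableParallel fans the incoming query string out into a dict with "docs" and "query" keys, and each .assign adds one more key computed from the dict so far, so invoke returns docs, query, context, and output together. A self-contained sketch with stand-ins for the retriever and the prompt | llm step (all names here are hypothetical):

    from langchain_core.runnables import RunnableParallel, RunnablePassthrough

    chain = (
        RunnableParallel(
            {
                "docs": lambda q: [f"doc about {q}"],  # stand-in retriever
                "query": RunnablePassthrough(),
            }
        )
        .assign(context=lambda x: "\n".join(x["docs"]))  # plays the format_docs role
        .assign(output=lambda x: f"answer based on: {x['context']}")
    )
    result = chain.invoke("renewable energy")
    # result has keys: "docs", "query", "context", "output"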
@@ -482,65 +538,114 @@ Query: {query}"""
             st.write(SS[f"{key_prefix}|out"])
 
 
-def get_token_usage(key_prefix: str, metadata: dict):
-    if SS[f"{key_prefix}|model_name"] in OPENAI_CHAT_MODELS:
-        model_info = PROVIDER_MODELS["OpenAI"][SS[f"{key_prefix}|model_name"]]
-        return get_openai_token_usage(metadata, model_info)
-    elif SS[f"{key_prefix}|model_name"] in ANTHROPIC_CHAT_MODELS:
-        model_info = PROVIDER_MODELS["Anthropic"][SS[f"{key_prefix}|model_name"]]
-        return get_anthropic_token_usage(metadata, model_info)
-    elif SS[f"{key_prefix}|model_name"] in TOGETHER_CHAT_MODELS:
-        model_info = PROVIDER_MODELS["Together"][SS[f"{key_prefix}|model_name"]]
-        return get_together_token_usage(metadata, model_info)
-    else:
-        raise ValueError()
-
-
-def get_openai_token_usage(metadata: dict, model_info: dict):
-    input_tokens = metadata["token_usage"]["prompt_tokens"]
-    output_tokens = metadata["token_usage"]["completion_tokens"]
-    cost = (
-        input_tokens * 1e-6 * model_info["cost"]["pmi"]
-        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
-    )
-    return {
-        "input_tokens": input_tokens,
-        "output_tokens": output_tokens,
-        "cost": cost,
-    }
-
-
-def get_anthropic_token_usage(metadata: dict, model_info: dict):
-    input_tokens = metadata["usage"]["input_tokens"]
-    output_tokens = metadata["usage"]["output_tokens"]
-    cost = (
-        input_tokens * 1e-6 * model_info["cost"]["pmi"]
-        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
-    )
-    return {
-        "input_tokens": input_tokens,
-        "output_tokens": output_tokens,
-        "cost": cost,
-    }
-
-
-def get_together_token_usage(metadata: dict, model_info: dict):
-    input_tokens = metadata["token_usage"]["prompt_tokens"]
-    output_tokens = metadata["token_usage"]["completion_tokens"]
-    cost = (
-        input_tokens * 1e-6 * model_info["cost"]["pmi"]
-        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
-    )
-    return {
-        "input_tokens": input_tokens,
-        "output_tokens": output_tokens,
-        "cost": cost,
-    }
-
-
-
-
-
+def render_query_rag_sbs_tab():
+
+    QUERY_RAG_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
+
+---
+
+Congressional Legislation Excerpts:
+
+{context}
+
+---
+
+Query: {query}"""
+
+    base_key_prefix = "query_rag_sbs"
+
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            ("human", QUERY_RAG_TEMPLATE),
+        ]
+    )
+
+    with st.form(f"{base_key_prefix}|query_form"):
+        st.text_area(
+            "Enter a query that can be answered with congressional legislation:",
+            key=f"{base_key_prefix}|query",
+        )
+        query_submitted = st.form_submit_button("Submit")
+
+    grp1a, grp2a = st.columns(2)
+
+    with grp1a:
+        st.header("Group 1")
+        key_prefix = f"{base_key_prefix}|grp1"
+        with st.expander("Generative Config"):
+            render_generative_config(key_prefix)
+        with st.expander("Retrieval Config"):
+            render_retrieval_config(key_prefix)
+
+    with grp2a:
+        st.header("Group 2")
+        key_prefix = f"{base_key_prefix}|grp2"
+        with st.expander("Generative Config"):
+            render_generative_config(key_prefix)
+        with st.expander("Retrieval Config"):
+            render_retrieval_config(key_prefix)
+
+    grp1b, grp2b = st.columns(2)
+    sbs_cols = {"grp1": grp1b, "grp2": grp2b}
+
+    for post_key_prefix in ["grp1", "grp2"]:
+
+        key_prefix = f"{base_key_prefix}|{post_key_prefix}"
+
+        if query_submitted:
+            llm = get_llm(key_prefix)
+            vs_filter = get_vectorstore_filter(key_prefix)
+            retriever = vectorstore.as_retriever(
+                search_kwargs={
+                    "k": SS[f"{key_prefix}|n_ret_docs"],
+                    "filter": vs_filter,
+                },
+            )
+            rag_chain = (
+                RunnableParallel(
+                    {
+                        "docs": retriever,  # list of docs
+                        "query": RunnablePassthrough(),  # str
+                    }
+                )
+                .assign(context=(lambda x: format_docs(x["docs"])))
+                .assign(output=prompt | llm)
+            )
+            SS[f"{key_prefix}|out"] = rag_chain.invoke(SS[f"{base_key_prefix}|query"])
+
+        if f"{key_prefix}|out" in SS:
+            with sbs_cols[post_key_prefix]:
+                out_display = SS[f"{key_prefix}|out"]["output"].content
+                if SS[f"{key_prefix}|response_escape_markdown"]:
+                    out_display = escape_markdown(out_display)
+                if SS[f"{key_prefix}|response_add_legis_urls"]:
+                    out_display = replace_legis_ids_with_urls(out_display)
+                with st.container(border=True):
+                    st.write("Response")
+                    st.info(out_display)
+
+                with st.container(border=True):
+                    st.write("API Usage")
+                    token_usage = get_token_usage(
+                        key_prefix, SS[f"{key_prefix}|out"]["output"].response_metadata
+                    )
+                    col1, col2, col3 = st.columns(3)
+                    with col1:
+                        st.metric("Input Tokens", token_usage["input_tokens"])
+                    with col2:
+                        st.metric("Output Tokens", token_usage["output_tokens"])
+                    with col3:
+                        st.metric("Cost", f"${token_usage['cost']:.4f}")
+                    with st.expander("Response Metadata"):
+                        st.warning(SS[f"{key_prefix}|out"]["output"].response_metadata)
+
+                with st.container(border=True):
+                    doc_grps = group_docs(SS[f"{key_prefix}|out"]["docs"])
+                    st.write(
+                        "Retrieved Chunks (note that you may need to 'right click' on links in the expanders to follow them)"
+                    )
+                    for legis_id, doc_grp in doc_grps:
+                        render_doc_grp(legis_id, doc_grp)
 
 
 ##################
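Note on the new side-by-side tab added above: it renders one shared query form but duplicates every config widget and output per group, all keyed into st.session_state (SS) under group-scoped prefixes. A sketch of the resulting key layout; the keys follow the diff, the values are placeholders:

    # hypothetical stand-in for st.session_state
    SS = {}
    base_key_prefix = "query_rag_sbs"
    SS[f"{base_key_prefix}|query"] = "example query"  # shared by both groups
    for grp in ["grp1", "grp2"]:
        key_prefix = f"{base_key_prefix}|{grp}"
        SS[f"{key_prefix}|model_name"] = "some-model"  # per-group generative config
        SS[f"{key_prefix}|out"] = None                 # per-group chain output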