gabrielaltay committed · Commit 69c42d0 · Parent(s): 6ef5143

llama 3.1

app.py CHANGED
@@ -1,6 +1,4 @@
 """
-TODO: checkout langgraph
-TODO: clear screen between agent calls (see here https://github.com/langchain-ai/streamlit-agent/blob/main/streamlit_agent/clear_results.py)
 """
 
 from collections import defaultdict
@@ -11,7 +9,9 @@ import re
 from langchain.tools.retriever import create_retriever_tool
 from langchain.agents import AgentExecutor
 from langchain.agents import create_openai_tools_agent
-from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
+from langchain.agents.format_scratchpad.openai_tools import (
+    format_to_openai_tool_messages,
+)
 from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
 from langchain_core.documents import Document
 from langchain_core.prompts import PromptTemplate
@@ -56,23 +56,22 @@ CONGRESS_GOV_TYPE_MAP = {
     "sjres": "senate-joint-resolution",
     "sres": "senate-resolution",
 }
-OPENAI_CHAT_MODELS = [
-    "gpt-4o-mini",
-    "gpt-4o",
-]
-ANTHROPIC_CHAT_MODELS = [
-    "claude-3-haiku-20240307",
-    "claude-3-5-sonnet-20240620",
-    "claude-3-opus-20240229",
-]
-TOGETHER_CHAT_MODELS = [
-    "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
-    "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
-]
-
-
-CHAT_MODELS = OPENAI_CHAT_MODELS + ANTHROPIC_CHAT_MODELS + TOGETHER_CHAT_MODELS
+OPENAI_CHAT_MODELS = {
+    "gpt-4o-mini": {"cost": {"pmi": 0.15, "pmo": 0.60}},
+    # "gpt-4o": {"cost": {"pmi": 5.00, "pmo": 15.0}},
+}
+ANTHROPIC_CHAT_MODELS = {
+    "claude-3-haiku-20240307": {"cost": {"pmi": 0.25, "pmo": 1.25}},
+    # "claude-3-5-sonnet-20240620": {"cost": {"pmi": 3.00, "pmo": 15.0}},
+    # "claude-3-opus-20240229": {"cost": {"pmi": 15.0, "pmo": 75.0}},
+}
+TOGETHER_CHAT_MODELS = {
+    "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": {"cost": {"pmi": 0.18, "pmo": 0.18}},
+    "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
+        "cost": {"pmi": 0.88, "pmo": 0.88}
+    },
+    # "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {"cost": {"pmi": 5.00, "pmo": 5.00}},
+}
 
 PROVIDER_MODELS = {
     "OpenAI": OPENAI_CHAT_MODELS,
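Aside: the new `cost` entries read as US dollars per million input (`pmi`) and output (`pmo`) tokens; the token-usage helpers added later in this diff turn them into a per-call price. A minimal sketch of that arithmetic (the token counts are made up for illustration):

```python
# "pmi"/"pmo" = USD per 1M input/output tokens, per the cost helpers below.
model_info = {"cost": {"pmi": 0.18, "pmo": 0.18}}  # Meta-Llama-3.1-8B rates above

input_tokens, output_tokens = 3_000, 500  # hypothetical single call
cost = (
    input_tokens * 1e-6 * model_info["cost"]["pmi"]
    + output_tokens * 1e-6 * model_info["cost"]["pmo"]
)
print(f"${cost:.6f}")  # -> $0.000630
```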
@@ -174,14 +173,20 @@ def escape_markdown(text):
     return text
 
 
-def get_vectorstore_filter():
+def get_vectorstore_filter(key_prefix: str):
     vs_filter = {}
-    if SS["filter_legis_id"] != "":
-        vs_filter["legis_id"] = SS["filter_legis_id"]
-    if SS["filter_bioguide_id"] != "":
-        vs_filter["sponsor_bioguide_id"] = SS["filter_bioguide_id"]
-    vs_filter = {
-        **vs_filter,
-        "congress_num": {"$in": SS["filter_congress_nums"]},
-        "sponsor_party": {"$in": SS["filter_sponsor_parties"]},
-    }
+    if SS[f"{key_prefix}|filter_legis_id"] != "":
+        vs_filter["legis_id"] = SS[f"{key_prefix}|filter_legis_id"]
+    if SS[f"{key_prefix}|filter_bioguide_id"] != "":
+        vs_filter["sponsor_bioguide_id"] = SS[f"{key_prefix}|filter_bioguide_id"]
+    vs_filter = {
+        **vs_filter,
+        "congress_num": {"$in": SS[f"{key_prefix}|filter_congress_nums"]},
+    }
+    vs_filter = {
+        **vs_filter,
+        "sponsor_party": {"$in": SS[f"{key_prefix}|filter_sponsor_parties"]},
+    }
     return vs_filter
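For a concrete sense of the output: with a bill ID set, an empty bioguide ID, and both multiselects left at their defaults, `get_vectorstore_filter("query_rag")` builds a Pinecone-style metadata filter like the sketch below (the session-state values are hypothetical):

```python
# Hypothetical session state under the "query_rag" prefix ...
SS = {
    "query_rag|filter_legis_id": "118-s-2293",
    "query_rag|filter_bioguide_id": "",
    "query_rag|filter_congress_nums": [117, 118],
    "query_rag|filter_sponsor_parties": ["D", "R"],
}

# ... and the filter the function would assemble from it:
expected = {
    "legis_id": "118-s-2293",
    "congress_num": {"$in": [117, 118]},
    "sponsor_party": {"$in": ["D", "R"]},
}
```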
@@ -195,7 +200,6 @@ def render_doc_grp(legis_id: str, doc_grp: list[Document]):
     )
     congress_gov_link = f"[congress.gov]({congress_gov_url})"
 
-
     ref = "{} chunks from {}\n\n{}\n\n{}\n\n[{} ({}) ]({})".format(
         len(doc_grp),
         first_doc.metadata["legis_id"],
@@ -282,53 +286,118 @@ Suggest reforms that would benefit the Medicaid program.
     )
 
 
-def render_sidebar():
-
-    with st.container(border=True):
-        render_outreach_links()
-
-    st.selectbox(label="provider", options=PROVIDER_MODELS.keys(), key="provider")
-    st.selectbox(
-        label="model name",
-        options=PROVIDER_MODELS[SS["provider"]],
-        key="model_name",
-    )
-    st.slider(
-        "temperature",
-        min_value=0.0,
-        max_value=2.0,
-        value=0.01,
-        key="temperature",
-    )
-    st.slider(
-        "max_output_tokens",
-        min_value=512,
-        max_value=1024,
-        key="max_output_tokens",
-    )
-    st.slider("top_p", min_value=0.0, max_value=1.0, value=0.9, key="top_p")
-    st.checkbox("escape markdown in answer", key="response_escape_markdown")
-    st.checkbox("add legis urls in answer", value=True, key="response_add_legis_urls")
-    with st.expander("Retrieval Config"):
-        st.slider(
-            "Number of chunks to retrieve",
-            min_value=1,
-            max_value=32,
-            value=8,
-            key="n_ret_docs",
-        )
-        st.text_input("Bill ID (e.g. 118-s-2293)", key="filter_legis_id")
-        st.text_input("Bioguide ID (e.g. R000595)", key="filter_bioguide_id")
-        st.multiselect(
-            "Congress Numbers",
-            CONGRESS_NUMBERS,
-            default=CONGRESS_NUMBERS,
-            key="filter_congress_nums",
-        )
-        st.multiselect(
-            "Sponsor Party",
-            SPONSOR_PARTIES,
-            default=SPONSOR_PARTIES,
-            key="filter_sponsor_parties",
-        )
+def render_generative_config(key_prefix: str):
+    st.selectbox(
+        label="provider", options=PROVIDER_MODELS.keys(), key=f"{key_prefix}|provider"
+    )
+    st.selectbox(
+        label="model name",
+        options=PROVIDER_MODELS[SS[f"{key_prefix}|provider"]],
+        key=f"{key_prefix}|model_name",
+    )
+    st.slider(
+        "temperature",
+        min_value=0.0,
+        max_value=2.0,
+        value=0.01,
+        key=f"{key_prefix}|temperature",
+    )
+    st.slider(
+        "max_output_tokens",
+        min_value=512,
+        max_value=1024,
+        key=f"{key_prefix}|max_output_tokens",
+    )
+    st.slider(
+        "top_p", min_value=0.0, max_value=1.0, value=0.9, key=f"{key_prefix}|top_p"
+    )
+    st.checkbox(
+        "escape markdown in answer", key=f"{key_prefix}|response_escape_markdown"
+    )
+    st.checkbox(
+        "add legis urls in answer",
+        value=True,
+        key=f"{key_prefix}|response_add_legis_urls",
+    )
+
+
+def render_retrieval_config(key_prefix: str):
+    st.slider(
+        "Number of chunks to retrieve",
+        min_value=1,
+        max_value=32,
+        value=8,
+        key=f"{key_prefix}|n_ret_docs",
+    )
+    st.text_input("Bill ID (e.g. 118-s-2293)", key=f"{key_prefix}|filter_legis_id")
+    st.text_input("Bioguide ID (e.g. R000595)", key=f"{key_prefix}|filter_bioguide_id")
+    st.multiselect(
+        "Congress Numbers",
+        CONGRESS_NUMBERS,
+        default=CONGRESS_NUMBERS,
+        key=f"{key_prefix}|filter_congress_nums",
+    )
+    st.multiselect(
+        "Sponsor Party",
+        SPONSOR_PARTIES,
+        default=SPONSOR_PARTIES,
+        key=f"{key_prefix}|filter_sponsor_parties",
+    )
+
+
+def get_llm(key_prefix: str):
+
+    if SS[f"{key_prefix}|model_name"] in OPENAI_CHAT_MODELS:
+        llm = ChatOpenAI(
+            model=SS[f"{key_prefix}|model_name"],
+            temperature=SS[f"{key_prefix}|temperature"],
+            api_key=st.secrets["openai_api_key"],
+            top_p=SS[f"{key_prefix}|top_p"],
+            seed=SEED,
+            max_tokens=SS[f"{key_prefix}|max_output_tokens"],
+        )
+    elif SS[f"{key_prefix}|model_name"] in ANTHROPIC_CHAT_MODELS:
+        llm = ChatAnthropic(
+            model_name=SS[f"{key_prefix}|model_name"],
+            temperature=SS[f"{key_prefix}|temperature"],
+            api_key=st.secrets["anthropic_api_key"],
+            top_p=SS[f"{key_prefix}|top_p"],
+            max_tokens_to_sample=SS[f"{key_prefix}|max_output_tokens"],
+        )
+    elif SS[f"{key_prefix}|model_name"] in TOGETHER_CHAT_MODELS:
+        llm = ChatTogether(
+            model=SS[f"{key_prefix}|model_name"],
+            temperature=SS[f"{key_prefix}|temperature"],
+            max_tokens=SS[f"{key_prefix}|max_output_tokens"],
+            top_p=SS[f"{key_prefix}|top_p"],
+            seed=SEED,
+            api_key=st.secrets["together_api_key"],
+        )
+    else:
+        raise ValueError()
+
+    return llm
+
+
+def render_sidebar():
+
+    with st.container(border=True):
+        render_outreach_links()
 
 
 def render_query_rag_tab():
 
+    key_prefix = "query_rag"
     render_example_queries()
 
+    col1, col2 = st.columns(2)
+    with col1:
+        with st.expander("Generative Config"):
+            render_generative_config(key_prefix)
+    with col2:
+        with st.expander("Retrieval Config"):
+            render_retrieval_config(key_prefix)
+
     QUERY_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.
 
 ---
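The thread running through this hunk is the new `key_prefix` argument: every widget key is namespaced as `f"{key_prefix}|<name>"`, so the same config panels can be mounted in several tabs without Streamlit session-state key collisions. A standalone sketch of the pattern (names shortened for illustration):

```python
import streamlit as st

def render_config(key_prefix: str):
    # The prefixed key gives each tab its own independent copy of this widget.
    st.slider("temperature", 0.0, 2.0, 0.01, key=f"{key_prefix}|temperature")

tab_a, tab_b = st.tabs(["query_rag", "query_rag_sbs"])
with tab_a:
    render_config("query_rag")      # state lives at "query_rag|temperature"
with tab_b:
    render_config("query_rag_sbs")  # state lives at "query_rag_sbs|temperature"
```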
@@ -348,14 +417,18 @@ Query: {query}"""
     )
 
     with st.form("query_form"):
-        st.text_area("Enter a query that can be answered with congressional legislation:", key="query")
+        st.text_area(
+            "Enter a query that can be answered with congressional legislation:",
+            key=f"{key_prefix}|query",
+        )
         query_submitted = st.form_submit_button("Submit")
 
     if query_submitted:
 
-        vs_filter = get_vectorstore_filter()
+        llm = get_llm(key_prefix)
+        vs_filter = get_vectorstore_filter(key_prefix)
         retriever = vectorstore.as_retriever(
-            search_kwargs={"k": SS["n_ret_docs"], "filter": vs_filter},
+            search_kwargs={"k": SS[f"{key_prefix}|n_ret_docs"], "filter": vs_filter},
         )
 
         rag_chain = (
@@ -364,37 +437,41 @@ Query: {query}"""
                     "docs": retriever,  # list of docs
                     "query": RunnablePassthrough(),  # str
                 }
-            )
-            .assign(context=(lambda x: format_docs(x["docs"])))
-            .assign(output=prompt | llm | StrOutputParser())
+            ).assign(context=(lambda x: format_docs(x["docs"])))
+            # .assign(output=prompt | llm | StrOutputParser())
+            .assign(output=prompt | llm)
         )
 
-        if SS["model_name"] in OPENAI_CHAT_MODELS:
-            with get_openai_callback() as cb:
-                SS["out"] = rag_chain.invoke(SS["query"])
-                SS["cb"] = cb
-        else:
-            SS.pop("cb", None)
-            SS["out"] = rag_chain.invoke(SS["query"])
+        SS[f"{key_prefix}|out"] = rag_chain.invoke(SS[f"{key_prefix}|query"])
 
-    if "out" in SS:
+    if f"{key_prefix}|out" in SS:
 
-        out_display = SS["out"]["output"]
-        if SS["response_escape_markdown"]:
+        out_display = SS[f"{key_prefix}|out"]["output"].content
+        if SS[f"{key_prefix}|response_escape_markdown"]:
             out_display = escape_markdown(out_display)
-        if SS["response_add_legis_urls"]:
+        if SS[f"{key_prefix}|response_add_legis_urls"]:
             out_display = replace_legis_ids_with_urls(out_display)
         with st.container(border=True):
             st.write("Response")
             st.info(out_display)
 
-        if "cb" in SS:
-            with st.container(border=True):
-                st.write("API Usage")
-                st.warning(SS["cb"])
+        with st.container(border=True):
+            st.write("API Usage")
+            token_usage = get_token_usage(
+                key_prefix, SS[f"{key_prefix}|out"]["output"].response_metadata
+            )
+            col1, col2, col3 = st.columns(3)
+            with col1:
+                st.metric("Input Tokens", token_usage["input_tokens"])
+            with col2:
+                st.metric("Output Tokens", token_usage["output_tokens"])
+            with col3:
+                st.metric("Cost", f"${token_usage['cost']:.4f}")
+            with st.expander("Response Metadata"):
+                st.warning(SS[f"{key_prefix}|out"]["output"].response_metadata)
 
         with st.container(border=True):
-            doc_grps = group_docs(SS["out"]["docs"])
+            doc_grps = group_docs(SS[f"{key_prefix}|out"]["docs"])
             st.write(
                 "Retrieved Chunks (note that you may need to 'right click' on links in the expanders to follow them)"
            )
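Worth noting about the reworked chain: dropping `StrOutputParser()` means `"output"` is now the model's `AIMessage` rather than a plain string, which is what lets the display code read both `.content` and `.response_metadata`. Roughly, the invoke result stored in session state has this shape (values are illustrative):

```python
# Illustrative shape of SS["query_rag|out"] after rag_chain.invoke(query):
out = {
    "docs": [],      # list[Document] straight from the retriever
    "query": "",     # the original query string, passed through
    "context": "",   # format_docs(docs), the excerpt block fed to the prompt
    "output": None,  # AIMessage: .content = text, .response_metadata = usage
}
```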
@@ -402,86 +479,68 @@ Query: {query}"""
             render_doc_grp(legis_id, doc_grp)
 
     with st.expander("Debug"):
-        st.write(SS["out"])
-
-
-def render_single_query_agent_tab():
-
-    vs_filter = get_vectorstore_filter()
-    retriever = vectorstore.as_retriever(
-        search_kwargs={"k": SS["n_ret_docs"], "filter": vs_filter},
+        st.write(SS[f"{key_prefix}|out"])
+
+
+def get_token_usage(key_prefix: str, metadata: dict):
+    if SS[f"{key_prefix}|model_name"] in OPENAI_CHAT_MODELS:
+        model_info = PROVIDER_MODELS["OpenAI"][SS[f"{key_prefix}|model_name"]]
+        return get_openai_token_usage(metadata, model_info)
+    elif SS[f"{key_prefix}|model_name"] in ANTHROPIC_CHAT_MODELS:
+        model_info = PROVIDER_MODELS["Anthropic"][SS[f"{key_prefix}|model_name"]]
+        return get_anthropic_token_usage(metadata, model_info)
+    elif SS[f"{key_prefix}|model_name"] in TOGETHER_CHAT_MODELS:
+        model_info = PROVIDER_MODELS["Together"][SS[f"{key_prefix}|model_name"]]
+        return get_together_token_usage(metadata, model_info)
+    else:
+        raise ValueError()
+
+
+def get_openai_token_usage(metadata: dict, model_info: dict):
+    input_tokens = metadata["token_usage"]["prompt_tokens"]
+    output_tokens = metadata["token_usage"]["completion_tokens"]
+    cost = (
+        input_tokens * 1e-6 * model_info["cost"]["pmi"]
+        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
     )
-    retriever_tool = create_retriever_tool(
-        retriever,
-        "...",
-        "...",
+    return {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "cost": cost,
+    }
+
+
+def get_anthropic_token_usage(metadata: dict, model_info: dict):
+    input_tokens = metadata["usage"]["input_tokens"]
+    output_tokens = metadata["usage"]["output_tokens"]
+    cost = (
+        input_tokens * 1e-6 * model_info["cost"]["pmi"]
+        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
     )
-    tools = [retriever_tool]
-    llm_with_tools = llm.bind_tools(tools)
-
-    agent_prompt = ChatPromptTemplate.from_messages(
-        [
-            ("system", "..."),
-            ("human", "{input}"),
-            MessagesPlaceholder(variable_name="agent_scratchpad"),
-        ]
+    return {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "cost": cost,
+    }
+
+
+def get_together_token_usage(metadata: dict, model_info: dict):
+    input_tokens = metadata["token_usage"]["prompt_tokens"]
+    output_tokens = metadata["token_usage"]["completion_tokens"]
+    cost = (
+        input_tokens * 1e-6 * model_info["cost"]["pmi"]
+        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
     )
-    agent = (
-        {
-            "input": lambda x: x["input"],
-            "agent_scratchpad": lambda x: format_to_openai_tool_messages(
-                x["intermediate_steps"]
-            ),
-        }
-        | agent_prompt
-        | llm_with_tools
-        | OpenAIToolsAgentOutputParser()
-    )
-
-    prompt = hub.pull("hwchase17/react")
-    agent = create_react_agent(llm, tools, prompt)
-    agent_executor = AgentExecutor(
-        agent=agent,
-        tools=tools,
-        return_intermediate_steps=True,
-        handle_parsing_errors=True,
-        verbose=True,
-    )
-
-    if user_input := st.chat_input(key="single_query_agent_input"):
-        st.chat_message("user").write(user_input)
-        with st.chat_message("assistant"):
-            st_callback = StreamlitCallbackHandler(st.container())
-            response = agent_executor.invoke(
-                {"input": user_input}, {"callbacks": [st_callback]}
-            )
-            st.write(response["output"])
-
-
-def render_chat_agent_tab():
-    st.write("Coming Soon")
+    return {
+        "input_tokens": input_tokens,
+        "output_tokens": output_tokens,
+        "cost": cost,
+    }
+
+
+def render_query_rag_sbs_tab():
+
+    return
 
 
 ##################
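The three provider helpers differ only in where the counts live in the response metadata (`token_usage` for OpenAI and Together, `usage` for Anthropic); the dollar arithmetic is identical. A worked example against hypothetical gpt-4o-mini metadata:

```python
metadata = {"token_usage": {"prompt_tokens": 2_000, "completion_tokens": 400}}
model_info = {"cost": {"pmi": 0.15, "pmo": 0.60}}  # gpt-4o-mini rates above

usage = get_openai_token_usage(metadata, model_info)
# input:  2000 * 1e-6 * 0.15 = $0.00030
# output:  400 * 1e-6 * 0.60 = $0.00024
assert abs(usage["cost"] - 0.00054) < 1e-12
```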
@@ -495,53 +554,21 @@ with st.sidebar:
     render_sidebar()
 
 
-if SS["model_name"] in OPENAI_CHAT_MODELS:
-    llm = ChatOpenAI(
-        model=SS["model_name"],
-        temperature=SS["temperature"],
-        api_key=st.secrets["openai_api_key"],
-        top_p=SS["top_p"],
-        seed=SEED,
-        max_tokens=SS["max_output_tokens"],
-    )
-elif SS["model_name"] in ANTHROPIC_CHAT_MODELS:
-    llm = ChatAnthropic(
-        model_name=SS["model_name"],
-        temperature=SS["temperature"],
-        api_key=st.secrets["anthropic_api_key"],
-        top_p=SS["top_p"],
-        max_tokens_to_sample=SS["max_output_tokens"],
-    )
-elif SS["model_name"] in TOGETHER_CHAT_MODELS:
-    llm = ChatTogether(
-        model=SS["model_name"],
-        temperature=SS["temperature"],
-        max_tokens=SS["max_output_tokens"],
-        top_p=SS["top_p"],
-        seed=SEED,
-        api_key=st.secrets["together_api_key"],
-    )
-else:
-    raise ValueError()
-
-
 vectorstore = load_pinecone_vectorstore()
 
-(
-    query_rag_tab,
-    single_query_agent_tab,
-    chat_agent_tab,
-    guide_tab,
-) = st.tabs(
-    [
-        "query_rag",
-        "single_query_agent",
-        "chat_agent",
-        "guide",
-    ]
-)
+query_rag_tab, query_rag_sbs_tab, guide_tab = st.tabs(
+    [
+        "query_rag",
+        "query_rag_sbs",
+        "guide",
+    ]
+)
 
 with query_rag_tab:
     render_query_rag_tab()
 
-with single_query_agent_tab:
-    render_single_query_agent_tab()
-
-with chat_agent_tab:
-    render_chat_agent_tab()
+with query_rag_sbs_tab:
+    render_query_rag_sbs_tab()
 
 with guide_tab:
     render_guide()
|