gabrielaltay committed
Commit
69c42d0
1 Parent(s): 6ef5143
Files changed (1)
  1. app.py +227 -200
app.py CHANGED
@@ -1,6 +1,4 @@
  """
- TODO: checkout langgraph
- TODO: clear screen between agent calls (see here https://github.com/langchain-ai/streamlit-agent/blob/main/streamlit_agent/clear_results.py)
  """

  from collections import defaultdict
@@ -11,7 +9,9 @@ import re
  from langchain.tools.retriever import create_retriever_tool
  from langchain.agents import AgentExecutor
  from langchain.agents import create_openai_tools_agent
- from langchain.agents.format_scratchpad.openai_tools import format_to_openai_tool_messages
+ from langchain.agents.format_scratchpad.openai_tools import (
+     format_to_openai_tool_messages,
+ )
  from langchain.agents.output_parsers.openai_tools import OpenAIToolsAgentOutputParser
  from langchain_core.documents import Document
  from langchain_core.prompts import PromptTemplate
@@ -56,23 +56,22 @@ CONGRESS_GOV_TYPE_MAP = {
      "sjres": "senate-joint-resolution",
      "sres": "senate-resolution",
  }
- OPENAI_CHAT_MODELS = [
-     "gpt-4o-mini",
-     "gpt-4o",
- ]
- ANTHROPIC_CHAT_MODELS = [
-     "claude-3-haiku-20240307",
-     "claude-3-5-sonnet-20240620",
-     "claude-3-opus-20240229",
- ]
- TOGETHER_CHAT_MODELS = [
-     "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",
-     "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo",
-     "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo",
- ]
-
-
- CHAT_MODELS = OPENAI_CHAT_MODELS + ANTHROPIC_CHAT_MODELS + TOGETHER_CHAT_MODELS
+ OPENAI_CHAT_MODELS = {
+     "gpt-4o-mini": {"cost": {"pmi": 0.15, "pmo": 0.60}},
+     # "gpt-4o": {"cost": {"pmi": 5.00, "pmo": 15.0}},
+ }
+ ANTHROPIC_CHAT_MODELS = {
+     "claude-3-haiku-20240307": {"cost": {"pmi": 0.25, "pmo": 1.25}},
+     # "claude-3-5-sonnet-20240620": {"cost": {"pmi": 3.00, "pmo": 15.0}},
+     # "claude-3-opus-20240229": {"cost": {"pmi": 15.0, "pmo": 75.0}},
+ }
+ TOGETHER_CHAT_MODELS = {
+     "meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": {"cost": {"pmi": 0.18, "pmo": 0.18}},
+     "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
+         "cost": {"pmi": 0.88, "pmo": 0.88}
+     },
+     # "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {"cost": {"pmi": 5.00, "pmo": 5.00}},
+ }

  PROVIDER_MODELS = {
      "OpenAI": OPENAI_CHAT_MODELS,
@@ -174,14 +173,20 @@ def escape_markdown(text):
      return text


- def get_vectorstore_filter():
+ def get_vectorstore_filter(key_prefix: str):
      vs_filter = {}
-     if SS["filter_legis_id"] != "":
-         vs_filter["legis_id"] = SS["filter_legis_id"]
-     if SS["filter_bioguide_id"] != "":
-         vs_filter["sponsor_bioguide_id"] = SS["filter_bioguide_id"]
-     vs_filter = {**vs_filter, "congress_num": {"$in": SS["filter_congress_nums"]}}
-     vs_filter = {**vs_filter, "sponsor_party": {"$in": SS["filter_sponsor_parties"]}}
+     if SS[f"{key_prefix}|filter_legis_id"] != "":
+         vs_filter["legis_id"] = SS[f"{key_prefix}|filter_legis_id"]
+     if SS[f"{key_prefix}|filter_bioguide_id"] != "":
+         vs_filter["sponsor_bioguide_id"] = SS[f"{key_prefix}|filter_bioguide_id"]
+     vs_filter = {
+         **vs_filter,
+         "congress_num": {"$in": SS[f"{key_prefix}|filter_congress_nums"]},
+     }
+     vs_filter = {
+         **vs_filter,
+         "sponsor_party": {"$in": SS[f"{key_prefix}|filter_sponsor_parties"]},
+     }
      return vs_filter

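For illustration, a sketch (not part of the commit) of the filter this function builds, assuming hypothetical widget values in Streamlit session state; the $in operators are the Mongo-style comparators that Pinecone metadata filters accept:

    SS = {
        "query_rag|filter_legis_id": "118-s-2293",
        "query_rag|filter_bioguide_id": "",
        "query_rag|filter_congress_nums": [117, 118],     # assumed example values
        "query_rag|filter_sponsor_parties": ["D", "R"],   # assumed example values
    }
    # get_vectorstore_filter("query_rag") would then return:
    # {
    #     "legis_id": "118-s-2293",
    #     "congress_num": {"$in": [117, 118]},
    #     "sponsor_party": {"$in": ["D", "R"]},
    # }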
@@ -195,7 +200,6 @@ def render_doc_grp(legis_id: str, doc_grp: list[Document]):
      )
      congress_gov_link = f"[congress.gov]({congress_gov_url})"

-
      ref = "{} chunks from {}\n\n{}\n\n{}\n\n[{} ({}) ]({})".format(
          len(doc_grp),
          first_doc.metadata["legis_id"],
@@ -282,53 +286,118 @@ Suggest reforms that would benefit the Medicaid program.
      )


- def render_sidebar():
-
-     with st.container(border=True):
-         render_outreach_links()
-
-     st.checkbox("escape markdown in answer", key="response_escape_markdown")
-     st.checkbox("add legis urls in answer", value=True, key="response_add_legis_urls")
-
-     with st.expander("Generative Config"):
-         st.selectbox(label="provider", options=PROVIDER_MODELS.keys(), key="provider")
-         st.selectbox(label="model name", options=PROVIDER_MODELS[SS["provider"]], key="model_name")
-         st.slider(
-             "temperature", min_value=0.0, max_value=2.0, value=0.01, key="temperature"
-         )
-         st.slider(
-             "max_output_tokens", min_value=512, max_value=1024, key="max_output_tokens"
-         )
-         st.slider("top_p", min_value=0.0, max_value=1.0, value=1.0, key="top_p")
-
-     with st.expander("Retrieval Config"):
-         st.slider(
-             "Number of chunks to retrieve",
-             min_value=1,
-             max_value=32,
-             value=8,
-             key="n_ret_docs",
-         )
-         st.text_input("Bill ID (e.g. 118-s-2293)", key="filter_legis_id")
-         st.text_input("Bioguide ID (e.g. R000595)", key="filter_bioguide_id")
-         st.multiselect(
-             "Congress Numbers",
-             CONGRESS_NUMBERS,
-             default=CONGRESS_NUMBERS,
-             key="filter_congress_nums",
-         )
-         st.multiselect(
-             "Sponsor Party",
-             SPONSOR_PARTIES,
-             default=SPONSOR_PARTIES,
-             key="filter_sponsor_parties",
-         )
+ def render_generative_config(key_prefix: str):
+     st.selectbox(
+         label="provider", options=PROVIDER_MODELS.keys(), key=f"{key_prefix}|provider"
+     )
+     st.selectbox(
+         label="model name",
+         options=PROVIDER_MODELS[SS[f"{key_prefix}|provider"]],
+         key=f"{key_prefix}|model_name",
+     )
+     st.slider(
+         "temperature",
+         min_value=0.0,
+         max_value=2.0,
+         value=0.01,
+         key=f"{key_prefix}|temperature",
+     )
+     st.slider(
+         "max_output_tokens",
+         min_value=512,
+         max_value=1024,
+         key=f"{key_prefix}|max_output_tokens",
+     )
+     st.slider(
+         "top_p", min_value=0.0, max_value=1.0, value=0.9, key=f"{key_prefix}|top_p"
+     )
+     st.checkbox(
+         "escape markdown in answer", key=f"{key_prefix}|response_escape_markdown"
+     )
+     st.checkbox(
+         "add legis urls in answer",
+         value=True,
+         key=f"{key_prefix}|response_add_legis_urls",
+     )
+
+
+ def render_retrieval_config(key_prefix: str):
+     st.slider(
+         "Number of chunks to retrieve",
+         min_value=1,
+         max_value=32,
+         value=8,
+         key=f"{key_prefix}|n_ret_docs",
+     )
+     st.text_input("Bill ID (e.g. 118-s-2293)", key=f"{key_prefix}|filter_legis_id")
+     st.text_input("Bioguide ID (e.g. R000595)", key=f"{key_prefix}|filter_bioguide_id")
+     st.multiselect(
+         "Congress Numbers",
+         CONGRESS_NUMBERS,
+         default=CONGRESS_NUMBERS,
+         key=f"{key_prefix}|filter_congress_nums",
+     )
+     st.multiselect(
+         "Sponsor Party",
+         SPONSOR_PARTIES,
+         default=SPONSOR_PARTIES,
+         key=f"{key_prefix}|filter_sponsor_parties",
+     )
+
+
+ def get_llm(key_prefix: str):
+
+     if SS[f"{key_prefix}|model_name"] in OPENAI_CHAT_MODELS:
+         llm = ChatOpenAI(
+             model=SS[f"{key_prefix}|model_name"],
+             temperature=SS[f"{key_prefix}|temperature"],
+             api_key=st.secrets["openai_api_key"],
+             top_p=SS[f"{key_prefix}|top_p"],
+             seed=SEED,
+             max_tokens=SS[f"{key_prefix}|max_output_tokens"],
+         )
+     elif SS[f"{key_prefix}|model_name"] in ANTHROPIC_CHAT_MODELS:
+         llm = ChatAnthropic(
+             model_name=SS[f"{key_prefix}|model_name"],
+             temperature=SS[f"{key_prefix}|temperature"],
+             api_key=st.secrets["anthropic_api_key"],
+             top_p=SS[f"{key_prefix}|top_p"],
+             max_tokens_to_sample=SS[f"{key_prefix}|max_output_tokens"],
+         )
+     elif SS[f"{key_prefix}|model_name"] in TOGETHER_CHAT_MODELS:
+         llm = ChatTogether(
+             model=SS[f"{key_prefix}|model_name"],
+             temperature=SS[f"{key_prefix}|temperature"],
+             max_tokens=SS[f"{key_prefix}|max_output_tokens"],
+             top_p=SS[f"{key_prefix}|top_p"],
+             seed=SEED,
+             api_key=st.secrets["together_api_key"],
+         )
+     else:
+         raise ValueError()
+
+     return llm
+
+
+ def render_sidebar():
+
+     with st.container(border=True):
+         render_outreach_links()


  def render_query_rag_tab():

+     key_prefix = "query_rag"
      render_example_queries()

+     col1, col2 = st.columns(2)
+     with col1:
+         with st.expander("Generative Config"):
+             render_generative_config(key_prefix)
+     with col2:
+         with st.expander("Retrieval Config"):
+             render_retrieval_config(key_prefix)
+
      QUERY_TEMPLATE = """You are an expert legislative analyst. Use the following excerpts from US congressional legislation to respond to the user's query. The excerpts are formatted as a JSON list. Each JSON object has "legis_id", "title", "introduced_date", "sponsor", and "snippets" keys. If a snippet is useful in writing part of your response, then cite the "legis_id", "title", "introduced_date", and "sponsor" in the response. If you don't know how to respond, just tell the user.

  ---
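
Streamlit widget keys are global to the session, so the f"{key_prefix}|..." convention introduced above is what allows the same config widgets to appear in more than one tab without key collisions. A sketch (not part of the commit), with a hypothetical second prefix:

    render_generative_config("query_rag")      # keys like "query_rag|temperature"
    render_generative_config("query_rag_sbs")  # keys like "query_rag_sbs|temperature"
    # Each tab then reads back its own state, e.g. SS["query_rag|temperature"]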
@@ -348,14 +417,18 @@ Query: {query}"""
      )

      with st.form("query_form"):
-         st.text_area("Enter a query that can be answered with congressional legislation:", key="query")
+         st.text_area(
+             "Enter a query that can be answered with congressional legislation:",
+             key=f"{key_prefix}|query",
+         )
          query_submitted = st.form_submit_button("Submit")

      if query_submitted:

-         vs_filter = get_vectorstore_filter()
+         llm = get_llm(key_prefix)
+         vs_filter = get_vectorstore_filter(key_prefix)
          retriever = vectorstore.as_retriever(
-             search_kwargs={"k": SS["n_ret_docs"], "filter": vs_filter},
+             search_kwargs={"k": SS[f"{key_prefix}|n_ret_docs"], "filter": vs_filter},
          )

          rag_chain = (
@@ -364,37 +437,41 @@ Query: {query}"""
                      "docs": retriever,  # list of docs
                      "query": RunnablePassthrough(),  # str
                  }
-             )
-             .assign(context=(lambda x: format_docs(x["docs"])))
-             .assign(output=prompt | llm | StrOutputParser())
+             ).assign(context=(lambda x: format_docs(x["docs"])))
+             # .assign(output=prompt | llm | StrOutputParser())
+             .assign(output=prompt | llm)
          )

-         if SS["model_name"] in OPENAI_CHAT_MODELS:
-             with get_openai_callback() as cb:
-                 SS["out"] = rag_chain.invoke(SS["query"])
-                 SS["cb"] = cb
-         else:
-             SS.pop("cb", None)
-             SS["out"] = rag_chain.invoke(SS["query"])
+         SS[f"{key_prefix}|out"] = rag_chain.invoke(SS[f"{key_prefix}|query"])

-     if "out" in SS:
+     if f"{key_prefix}|out" in SS:

-         out_display = SS["out"]["output"]
-         if SS["response_escape_markdown"]:
+         out_display = SS[f"{key_prefix}|out"]["output"].content
+         if SS[f"{key_prefix}|response_escape_markdown"]:
              out_display = escape_markdown(out_display)
-         if SS["response_add_legis_urls"]:
+         if SS[f"{key_prefix}|response_add_legis_urls"]:
              out_display = replace_legis_ids_with_urls(out_display)
          with st.container(border=True):
              st.write("Response")
              st.info(out_display)

-         if "cb" in SS:
-             with st.container(border=True):
-                 st.write("API Usage")
-                 st.warning(SS["cb"])
+         with st.container(border=True):
+             st.write("API Usage")
+             token_usage = get_token_usage(
+                 key_prefix, SS[f"{key_prefix}|out"]["output"].response_metadata
+             )
+             col1, col2, col3 = st.columns(3)
+             with col1:
+                 st.metric("Input Tokens", token_usage["input_tokens"])
+             with col2:
+                 st.metric("Output Tokens", token_usage["output_tokens"])
+             with col3:
+                 st.metric("Cost", f"${token_usage['cost']:.4f}")
+             with st.expander("Response Metadata"):
+                 st.warning(SS[f"{key_prefix}|out"]["output"].response_metadata)

          with st.container(border=True):
-             doc_grps = group_docs(SS["out"]["docs"])
+             doc_grps = group_docs(SS[f"{key_prefix}|out"]["docs"])
              st.write(
                  "Retrieved Chunks (note that you may need to 'right click' on links in the expanders to follow them)"
              )
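
Swapping .assign(output=prompt | llm | StrOutputParser()) for .assign(output=prompt | llm) keeps the full chat-model message instead of a bare string, which is why the code above reads .content and .response_metadata from out["output"]. A minimal sketch (not part of the commit), assuming the rag_chain defined above:

    out = rag_chain.invoke("example query")         # hypothetical invocation
    answer_text = out["output"].content             # the model's answer (str)
    usage_meta = out["output"].response_metadata    # provider-specific usage info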
@@ -402,86 +479,68 @@ Query: {query}"""
                  render_doc_grp(legis_id, doc_grp)

          with st.expander("Debug"):
-             st.write(SS["out"])
-
-
- def render_query_agent_tab():
-
-     from retriever_tools import get_retriever_tool
-
-     from langchain_community.tools import WikipediaQueryRun
-     from langchain_community.utilities import WikipediaAPIWrapper
-     # from langchain.agents import load_tools
-     from langchain_community.agent_toolkits.load_tools import load_tools
-     from langchain.agents import create_react_agent
-     from langchain import hub
-
-     if SS["model_name"] not in OPENAI_CHAT_MODELS:
-         st.write("only supported with OpenAI for now")
-         return
-
-     vs_filter = get_vectorstore_filter()
-     retriever = vectorstore.as_retriever(
-         search_kwargs={"k": SS["n_ret_docs"], "filter": vs_filter},
-     )
-     legis_retrieval_tool = get_retriever_tool(
-         retriever,
-         "search_legislation",
-         "Searches and returns excerpts from congressional legislation. Always call this tool first.",
-         format_docs,
-     )
-
-     api_wrapper = WikipediaAPIWrapper(top_k_results=4, doc_content_chars_max=800)
-     wiki_search_tool = WikipediaQueryRun(api_wrapper=api_wrapper)
-
-     ddg_tool = load_tools(["ddg-search"])[0]
-
-     avatars = {"human": "user", "ai": "assistant"}
-     tools = [legis_retrieval_tool, wiki_search_tool, ddg_tool]
-     llm_with_tools = llm.bind_tools(tools)
-
-     agent_prompt = ChatPromptTemplate.from_messages(
-         [
-             ("system", "You are a helpful assistant."),
-             ("human", "{input}"),
-             MessagesPlaceholder(variable_name="agent_scratchpad"),
-         ]
-     )
-     agent = (
-         {
-             "input": lambda x: x["input"],
-             "agent_scratchpad": lambda x: format_to_openai_tool_messages(
-                 x["intermediate_steps"]
-             ),
-         }
-         | agent_prompt
-         | llm_with_tools
-         | OpenAIToolsAgentOutputParser()
-     )
-
-     prompt = hub.pull("hwchase17/react")
-     agent = create_react_agent(llm, tools, prompt)
-     agent_executor = AgentExecutor(
-         agent=agent,
-         tools=tools,
-         return_intermediate_steps=True,
-         handle_parsing_errors=True,
-         verbose=True,
-     )
-
-     if user_input := st.chat_input(key="single_query_agent_input"):
-         st.chat_message("user").write(user_input)
-         with st.chat_message("assistant"):
-             st_callback = StreamlitCallbackHandler(st.container())
-             response = agent_executor.invoke({"input": user_input}, {"callbacks": [st_callback]})
-             st.write(response["output"])
-
-
- def render_chat_agent_tab():
-     st.write("Coming Soon")
+             st.write(SS[f"{key_prefix}|out"])
+
+
+ def get_token_usage(key_prefix: str, metadata: dict):
+     if SS[f"{key_prefix}|model_name"] in OPENAI_CHAT_MODELS:
+         model_info = PROVIDER_MODELS["OpenAI"][SS[f"{key_prefix}|model_name"]]
+         return get_openai_token_usage(metadata, model_info)
+     elif SS[f"{key_prefix}|model_name"] in ANTHROPIC_CHAT_MODELS:
+         model_info = PROVIDER_MODELS["Anthropic"][SS[f"{key_prefix}|model_name"]]
+         return get_anthropic_token_usage(metadata, model_info)
+     elif SS[f"{key_prefix}|model_name"] in TOGETHER_CHAT_MODELS:
+         model_info = PROVIDER_MODELS["Together"][SS[f"{key_prefix}|model_name"]]
+         return get_together_token_usage(metadata, model_info)
+     else:
+         raise ValueError()
+
+
+ def get_openai_token_usage(metadata: dict, model_info: dict):
+     input_tokens = metadata["token_usage"]["prompt_tokens"]
+     output_tokens = metadata["token_usage"]["completion_tokens"]
+     cost = (
+         input_tokens * 1e-6 * model_info["cost"]["pmi"]
+         + output_tokens * 1e-6 * model_info["cost"]["pmo"]
+     )
+     return {
+         "input_tokens": input_tokens,
+         "output_tokens": output_tokens,
+         "cost": cost,
+     }
+
+
+ def get_anthropic_token_usage(metadata: dict, model_info: dict):
+     input_tokens = metadata["usage"]["input_tokens"]
+     output_tokens = metadata["usage"]["output_tokens"]
+     cost = (
+         input_tokens * 1e-6 * model_info["cost"]["pmi"]
+         + output_tokens * 1e-6 * model_info["cost"]["pmo"]
+     )
+     return {
+         "input_tokens": input_tokens,
+         "output_tokens": output_tokens,
+         "cost": cost,
+     }
+
+
+ def get_together_token_usage(metadata: dict, model_info: dict):
+     input_tokens = metadata["token_usage"]["prompt_tokens"]
+     output_tokens = metadata["token_usage"]["completion_tokens"]
+     cost = (
+         input_tokens * 1e-6 * model_info["cost"]["pmi"]
+         + output_tokens * 1e-6 * model_info["cost"]["pmo"]
+     )
+     return {
+         "input_tokens": input_tokens,
+         "output_tokens": output_tokens,
+         "cost": cost,
+     }
+
+
+ def render_query_rag_sbs_tab():
+
+     return


  ##################
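
The three provider helpers above differ only in where the counts live in response_metadata (OpenAI and Together nest them under "token_usage", Anthropic under "usage"); the cost formula is shared. A worked example (not part of the commit) with hypothetical counts:

    metadata = {"token_usage": {"prompt_tokens": 2000, "completion_tokens": 300}}
    model_info = {"cost": {"pmi": 0.15, "pmo": 0.60}}   # gpt-4o-mini prices from above
    usage = get_openai_token_usage(metadata, model_info)
    # {"input_tokens": 2000, "output_tokens": 300, "cost": 0.00048}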
@@ -495,53 +554,21 @@ with st.sidebar:
      render_sidebar()


- if SS["model_name"] in OPENAI_CHAT_MODELS:
-     llm = ChatOpenAI(
-         model=SS["model_name"],
-         temperature=SS["temperature"],
-         api_key=st.secrets["openai_api_key"],
-         top_p=SS["top_p"],
-         seed=SEED,
-         max_tokens=SS["max_output_tokens"],
-     )
- elif SS["model_name"] in ANTHROPIC_CHAT_MODELS:
-     llm = ChatAnthropic(
-         model_name=SS["model_name"],
-         temperature=SS["temperature"],
-         api_key=st.secrets["anthropic_api_key"],
-         top_p=SS["top_p"],
-         max_tokens_to_sample=SS["max_output_tokens"],
-     )
- elif SS["model_name"] in TOGETHER_CHAT_MODELS:
-     llm = ChatTogether(
-         model=SS["model_name"],
-         temperature=SS["temperature"],
-         max_tokens=SS["max_output_tokens"],
-         top_p=SS["top_p"],
-         seed=SEED,
-         api_key=st.secrets["together_api_key"],
-     )
- else:
-     raise ValueError()
-
-
  vectorstore = load_pinecone_vectorstore()

- query_rag_tab, query_agent_tab, chat_agent_tab, guide_tab = st.tabs([
-     "query_rag",
-     "query_agent",
-     "chat_agent",
-     "guide",
- ])
+ query_rag_tab, query_rag_sbs_tab, guide_tab = st.tabs(
+     [
+         "query_rag",
+         "query_rag_sbs",
+         "guide",
+     ]
+ )

  with query_rag_tab:
      render_query_rag_tab()

- with query_agent_tab:
-     render_query_agent_tab()
-
- with chat_agent_tab:
-     render_chat_agent_tab()
+ with query_rag_sbs_tab:
+     render_query_rag_sbs_tab()

  with guide_tab:
      render_guide()
 