# NOTE: removed stray "Spaces: / Running / Running" hosting-page chrome that was
# captured along with the source; it is not Python and broke the file.
import streamlit as st
from langchain_core.messages import AIMessage
def get_openai_token_usage(aimessage: "AIMessage", model_info: dict):
    """Compute token counts and dollar cost for an OpenAI response.

    Args:
        aimessage: Model response; its ``usage_metadata`` mapping supplies
            the token counts. LangChain allows this field to be ``None``
            when the provider reports no usage, so missing data is treated
            as zero tokens instead of crashing.
        model_info: Pricing config; ``model_info["cost"]["pmi"]`` and
            ``["pmo"]`` are dollars per million input / output tokens.

    Returns:
        dict with ``input_tokens``, ``output_tokens`` and ``cost`` (USD).
    """
    # usage_metadata may be None -> fall back to an empty mapping.
    usage = aimessage.usage_metadata or {}
    input_tokens = usage.get("input_tokens", 0)
    output_tokens = usage.get("output_tokens", 0)
    # Prices are quoted per million tokens, hence the 1e-6 scaling.
    cost = (
        input_tokens * 1e-6 * model_info["cost"]["pmi"]
        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
    )
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cost": cost,
    }
def get_anthropic_token_usage(aimessage: "AIMessage", model_info: dict):
    """Compute token counts and dollar cost for an Anthropic response.

    Args:
        aimessage: Model response; token counts come from its
            ``usage_metadata`` mapping, which LangChain permits to be
            ``None`` — treated here as zero usage rather than an error.
        model_info: Pricing config; ``cost.pmi`` / ``cost.pmo`` are
            dollars per million input / output tokens.

    Returns:
        dict with ``input_tokens``, ``output_tokens`` and ``cost`` (USD).
    """
    # Guard against a missing usage report (usage_metadata is Optional).
    usage = aimessage.usage_metadata or {}
    input_tokens = usage.get("input_tokens", 0)
    output_tokens = usage.get("output_tokens", 0)
    # Per-million-token prices -> scale counts by 1e-6.
    cost = (
        input_tokens * 1e-6 * model_info["cost"]["pmi"]
        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
    )
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cost": cost,
    }
def get_together_token_usage(aimessage: "AIMessage", model_info: dict):
    """Compute token counts and dollar cost for a Together AI response.

    Args:
        aimessage: Model response whose ``usage_metadata`` mapping holds
            the token counts; ``None`` (no usage reported) is handled as
            zero tokens instead of raising.
        model_info: Pricing config; ``cost.pmi`` / ``cost.pmo`` are
            dollars per million input / output tokens.

    Returns:
        dict with ``input_tokens``, ``output_tokens`` and ``cost`` (USD).
    """
    # Tolerate an absent usage report (usage_metadata is Optional).
    usage = aimessage.usage_metadata or {}
    input_tokens = usage.get("input_tokens", 0)
    output_tokens = usage.get("output_tokens", 0)
    # 1e-6 converts per-million-token pricing to per-token.
    cost = (
        input_tokens * 1e-6 * model_info["cost"]["pmi"]
        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
    )
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cost": cost,
    }
def get_google_token_usage(aimessage: "AIMessage", model_info: dict):
    """Compute token counts and dollar cost for a Google (Gemini) response.

    Args:
        aimessage: Model response; token counts are read from its
            ``usage_metadata`` mapping, which may be ``None`` when the
            provider returns no usage — treated as zero tokens.
        model_info: Pricing config; ``cost.pmi`` / ``cost.pmo`` are
            dollars per million input / output tokens.

    Returns:
        dict with ``input_tokens``, ``output_tokens`` and ``cost`` (USD).
    """
    # Defensive: usage_metadata is Optional on LangChain AIMessage.
    usage = aimessage.usage_metadata or {}
    input_tokens = usage.get("input_tokens", 0)
    output_tokens = usage.get("output_tokens", 0)
    # Scale by 1e-6 because rates are dollars per million tokens.
    cost = (
        input_tokens * 1e-6 * model_info["cost"]["pmi"]
        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
    )
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cost": cost,
    }
def get_token_usage(aimessage: AIMessage, model_info: dict, provider: str): | |
match provider: | |
case "OpenAI": | |
return get_openai_token_usage(aimessage, model_info) | |
case "Anthropic": | |
return get_anthropic_token_usage(aimessage, model_info) | |
case "Together": | |
return get_together_token_usage(aimessage, model_info) | |
case "Google": | |
return get_google_token_usage(aimessage, model_info) | |
case _: | |
raise ValueError() | |
def display_api_usage(
    aimessage: AIMessage, model_info: dict, provider: str, tag: str | None = None
):
    """Render a bordered Streamlit panel summarizing one API call.

    Shows three metrics — input tokens, output tokens, and USD cost — plus
    an expander listing the message's metadata (everything but "content").

    Args:
        aimessage: Model response whose usage is displayed.
        model_info: Pricing config forwarded to ``get_token_usage``.
        provider: Provider name used to pick the cost calculation.
        tag: Optional label appended to the panel title.
    """
    with st.container(border=True):
        header = "API Usage" if tag is None else f"API Usage ({tag})"
        st.write(header)
        usage = get_token_usage(aimessage, model_info, provider)
        labeled_values = (
            ("Input Tokens", usage["input_tokens"]),
            ("Output Tokens", usage["output_tokens"]),
            ("Cost", f"${usage['cost']:.4f}"),
        )
        for column, (label, value) in zip(st.columns(3), labeled_values):
            with column:
                st.metric(label, value)
        with st.expander("AIMessage Metadata"):
            # NOTE(review): .dict() is the pydantic-v1-era spelling and is
            # deprecated under pydantic v2 (model_dump()) — confirm the
            # pinned langchain/pydantic versions before changing it.
            metadata = {
                key: val for key, val in aimessage.dict().items() if key != "content"
            }
            st.write(metadata)