"""Token-usage helpers for a Streamlit app built on LangChain chat models.

Each provider helper reads the standardized `usage_metadata` from an
`AIMessage` and prices the call from `model_info["cost"]`, whose "pmi" and
"pmo" entries are treated as per-million-token prices for input and output.
"""

import streamlit as st
from langchain_core.messages import AIMessage


def get_openai_token_usage(aimessage: AIMessage, model_info: dict):
    input_tokens = aimessage.usage_metadata["input_tokens"]
    output_tokens = aimessage.usage_metadata["output_tokens"]
    cost = (
        input_tokens * 1e-6 * model_info["cost"]["pmi"]
        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
    )
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cost": cost,
    }


def get_anthropic_token_usage(aimessage: AIMessage, model_info: dict):
    input_tokens = aimessage.usage_metadata["input_tokens"]
    output_tokens = aimessage.usage_metadata["output_tokens"]
    cost = (
        input_tokens * 1e-6 * model_info["cost"]["pmi"]
        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
    )
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cost": cost,
    }


def get_together_token_usage(aimessage: AIMessage, model_info: dict):
    input_tokens = aimessage.usage_metadata["input_tokens"]
    output_tokens = aimessage.usage_metadata["output_tokens"]
    cost = (
        input_tokens * 1e-6 * model_info["cost"]["pmi"]
        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
    )
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cost": cost,
    }


def get_google_token_usage(aimessage: AIMessage, model_info: dict):
    input_tokens = aimessage.usage_metadata["input_tokens"]
    output_tokens = aimessage.usage_metadata["output_tokens"]
    cost = (
        input_tokens * 1e-6 * model_info["cost"]["pmi"]
        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
    )
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cost": cost,
    }


def get_token_usage(aimessage: AIMessage, model_info: dict, provider: str):
    """Dispatch to the matching provider-specific helper."""
    match provider:
        case "OpenAI":
            return get_openai_token_usage(aimessage, model_info)
        case "Anthropic":
            return get_anthropic_token_usage(aimessage, model_info)
        case "Together":
            return get_together_token_usage(aimessage, model_info)
        case "Google":
            return get_google_token_usage(aimessage, model_info)
        case _:
            raise ValueError(f"Unsupported provider: {provider}")


def display_api_usage(
    aimessage: AIMessage, model_info: dict, provider: str, tag: str | None = None
):
    """Render token counts and estimated cost for one model response."""
    with st.container(border=True):
        if tag is None:
            st.write("API Usage")
        else:
            st.write(f"API Usage ({tag})")
        token_usage = get_token_usage(aimessage, model_info, provider)
        col1, col2, col3 = st.columns(3)
        with col1:
            st.metric("Input Tokens", token_usage["input_tokens"])
        with col2:
            st.metric("Output Tokens", token_usage["output_tokens"])
        with col3:
            st.metric("Cost", f"${token_usage['cost']:.4f}")
        with st.expander("AIMessage Metadata"):
            # Show all response metadata except the message content itself.
            dd = {key: val for key, val in aimessage.dict().items() if key != "content"}
            st.write(dd)
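

# --- Illustrative usage sketch (not part of the original module) ---
# Assumptions: langchain-openai is installed, OPENAI_API_KEY is set, and the
# model name and prices below are placeholders, not real quotes. The shape of
# `model_info` follows the lookups above: "cost" holds per-million-token
# prices, "pmi" for input and "pmo" for output. Run with `streamlit run`.
if __name__ == "__main__":
    from langchain_openai import ChatOpenAI

    model_info = {"cost": {"pmi": 0.15, "pmo": 0.60}}  # placeholder prices
    llm = ChatOpenAI(model="gpt-4o-mini")  # placeholder model name
    response = llm.invoke("Say hello in one sentence.")
    display_api_usage(response, model_info, provider="OpenAI", tag="demo")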