File size: 3,056 Bytes
42554ac
 
 
76cbdff
 
 
42554ac
 
 
 
 
 
 
 
 
 
 
76cbdff
 
 
42554ac
 
 
 
 
 
 
 
 
 
 
76cbdff
 
 
42554ac
 
 
 
 
 
 
 
 
 
 
3ead889
 
 
 
 
 
 
 
 
 
 
 
 
 
76cbdff
42554ac
 
76cbdff
42554ac
76cbdff
42554ac
76cbdff
3ead889
 
42554ac
 
 
 
76cbdff
 
 
42554ac
793d0f2
 
 
 
76cbdff
42554ac
 
 
 
 
 
 
 
76cbdff
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import streamlit as st


def get_openai_token_usage(response_metadata: dict, model_info: dict):
    """Summarize token counts and dollar cost for an OpenAI chat response.

    Reads prompt/completion counts from ``response_metadata["token_usage"]``
    and prices them with ``model_info["cost"]`` (``pmi``/``pmo`` are dollars
    per million input/output tokens).
    """
    usage = response_metadata["token_usage"]
    prompt_tokens = usage["prompt_tokens"]
    completion_tokens = usage["completion_tokens"]
    rates = model_info["cost"]
    # 1e-6 converts per-million-token rates to per-token dollars.
    total_cost = (
        prompt_tokens * 1e-6 * rates["pmi"]
        + completion_tokens * 1e-6 * rates["pmo"]
    )
    return {
        "input_tokens": prompt_tokens,
        "output_tokens": completion_tokens,
        "cost": total_cost,
    }


def get_anthropic_token_usage(response_metadata: dict, model_info: dict):
    """Summarize token counts and dollar cost for an Anthropic response.

    Anthropic nests counts under ``response_metadata["usage"]`` with
    ``input_tokens``/``output_tokens`` keys; pricing comes from
    ``model_info["cost"]`` (dollars per million tokens).
    """
    usage = response_metadata["usage"]
    tokens_in = usage["input_tokens"]
    tokens_out = usage["output_tokens"]
    pricing = model_info["cost"]
    # 1e-6 converts per-million-token rates to per-token dollars.
    return {
        "input_tokens": tokens_in,
        "output_tokens": tokens_out,
        "cost": tokens_in * 1e-6 * pricing["pmi"] + tokens_out * 1e-6 * pricing["pmo"],
    }


def get_together_token_usage(response_metadata: dict, model_info: dict):
    """Summarize token counts and dollar cost for a Together AI response.

    Together uses OpenAI-style metadata: counts live under
    ``response_metadata["token_usage"]`` as ``prompt_tokens`` /
    ``completion_tokens``; rates in ``model_info["cost"]`` are dollars
    per million tokens.
    """
    counts = response_metadata["token_usage"]
    n_in, n_out = counts["prompt_tokens"], counts["completion_tokens"]
    cost_table = model_info["cost"]
    # 1e-6 converts per-million-token rates to per-token dollars.
    dollars = n_in * 1e-6 * cost_table["pmi"] + n_out * 1e-6 * cost_table["pmo"]
    return {"input_tokens": n_in, "output_tokens": n_out, "cost": dollars}


def get_google_token_usage(response_metadata: dict, model_info: dict):
    """Summarize token counts and dollar cost for a Google (Gemini) response.

    Previously this hard-coded both counts to 0, so the reported cost was
    always $0. Gemini reports usage under a ``usage_metadata`` mapping with
    ``prompt_token_count`` / ``candidates_token_count`` keys; we read those
    when present and fall back to 0 (the old behavior) when absent.
    NOTE(review): key names assumed from the Gemini API ``usageMetadata``
    schema — confirm against the actual payload this app receives.
    """
    usage = response_metadata.get("usage_metadata", {})
    input_tokens = usage.get("prompt_token_count", 0)
    output_tokens = usage.get("candidates_token_count", 0)
    # Rates in model_info["cost"] are dollars per million tokens.
    cost = (
        input_tokens * 1e-6 * model_info["cost"]["pmi"]
        + output_tokens * 1e-6 * model_info["cost"]["pmo"]
    )
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "cost": cost,
    }


def get_token_usage(response_metadata: dict, model_info: dict, provider: str):
    match provider:
        case "OpenAI":
            return get_openai_token_usage(response_metadata, model_info)
        case "Anthropic":
            return get_anthropic_token_usage(response_metadata, model_info)
        case "Together":
            return get_together_token_usage(response_metadata, model_info)
        case "Google":
            return get_google_token_usage(response_metadata, model_info)
        case _:
            raise ValueError()


def display_api_usage(
    response_metadata: dict, model_info: dict, provider: str, tag: str | None = None
):
    """Render a bordered Streamlit panel with token counts and cost.

    Shows three metrics (input tokens, output tokens, dollar cost) plus an
    expander exposing the raw response metadata, optionally labeled with
    ``tag`` in the panel title.
    """
    with st.container(border=True):
        title = "API Usage" if tag is None else f"API Usage ({tag})"
        st.write(title)
        usage = get_token_usage(response_metadata, model_info, provider)
        metrics = [
            ("Input Tokens", usage["input_tokens"]),
            ("Output Tokens", usage["output_tokens"]),
            ("Cost", f"${usage['cost']:.4f}"),
        ]
        for column, (label, value) in zip(st.columns(3), metrics):
            with column:
                st.metric(label, value)
        with st.expander("Response Metadata"):
            # Rendered via st.warning, matching the original behavior.
            st.warning(response_metadata)