gabrielaltay committed on
Commit
3ead889
·
1 Parent(s): 76cbdff

add gemini

Browse files
Files changed (2) hide show
  1. app.py +17 -0
  2. usage_mod.py +16 -0
app.py CHANGED
@@ -13,6 +13,7 @@ from langchain_core.runnables import RunnablePassthrough
13
  from langchain_openai import ChatOpenAI
14
  from langchain_anthropic import ChatAnthropic
15
  from langchain_together import ChatTogether
 
16
  import streamlit as st
17
 
18
  import utils_mod
@@ -53,11 +54,18 @@ TOGETHER_CHAT_MODELS = {
53
  "cost": {"pmi": 5.00, "pmo": 5.00}
54
  },
55
  }
 
 
 
 
 
 
56
 
57
  PROVIDER_MODELS = {
58
  "OpenAI": OPENAI_CHAT_MODELS,
59
  "Anthropic": ANTHROPIC_CHAT_MODELS,
60
  "Together": TOGETHER_CHAT_MODELS,
 
61
  }
62
 
63
 
@@ -215,6 +223,15 @@ def get_llm(gen_config: dict):
215
  api_key=st.secrets["together_api_key"],
216
  )
217
 
 
 
 
 
 
 
 
 
 
218
  case _:
219
  raise ValueError()
220
 
 
13
  from langchain_openai import ChatOpenAI
14
  from langchain_anthropic import ChatAnthropic
15
  from langchain_together import ChatTogether
16
+ from langchain_google_genai import ChatGoogleGenerativeAI
17
  import streamlit as st
18
 
19
  import utils_mod
 
54
  "cost": {"pmi": 5.00, "pmo": 5.00}
55
  },
56
  }
57
+ GOOGLE_CHAT_MODELS = {
58
+ "gemini-1.5-flash": {"cost": {"pmi": 0.0, "pmo": 0.0}},
59
+ "gemini-1.5-pro": {"cost": {"pmi": 0.0, "pmo": 0.0}},
60
+ "gemini-1.5-pro-exp-0801": {"cost": {"pmi": 0.0, "pmo": 0.0}},
61
+ }
62
+
63
 
64
  PROVIDER_MODELS = {
65
  "OpenAI": OPENAI_CHAT_MODELS,
66
  "Anthropic": ANTHROPIC_CHAT_MODELS,
67
  "Together": TOGETHER_CHAT_MODELS,
68
+ "Google": GOOGLE_CHAT_MODELS,
69
  }
70
 
71
 
 
223
  api_key=st.secrets["together_api_key"],
224
  )
225
 
226
+ case "Google":
227
+ llm = ChatGoogleGenerativeAI(
228
+ model=gen_config["model_name"],
229
+ temperature=gen_config["temperature"],
230
+ api_key=st.secrets["google_api_key"],
231
+ max_output_tokens=gen_config["max_output_tokens"],
232
+ top_p=gen_config["top_p"],
233
+ )
234
+
235
  case _:
236
  raise ValueError()
237
 
usage_mod.py CHANGED
@@ -43,6 +43,20 @@ def get_together_token_usage(response_metadata: dict, model_info: dict):
43
  }
44
 
45
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  def get_token_usage(response_metadata: dict, model_info: dict, provider: str):
47
  match provider:
48
  case "OpenAI":
@@ -51,6 +65,8 @@ def get_token_usage(response_metadata: dict, model_info: dict, provider: str):
51
  return get_anthropic_token_usage(response_metadata, model_info)
52
  case "Together":
53
  return get_together_token_usage(response_metadata, model_info)
 
 
54
  case _:
55
  raise ValueError()
56
 
 
43
  }
44
 
45
 
46
+ def get_google_token_usage(response_metadata: dict, model_info: dict):
47
+ input_tokens = 0
48
+ output_tokens = 0
49
+ cost = (
50
+ input_tokens * 1e-6 * model_info["cost"]["pmi"]
51
+ + output_tokens * 1e-6 * model_info["cost"]["pmo"]
52
+ )
53
+ return {
54
+ "input_tokens": input_tokens,
55
+ "output_tokens": output_tokens,
56
+ "cost": cost,
57
+ }
58
+
59
+
60
  def get_token_usage(response_metadata: dict, model_info: dict, provider: str):
61
  match provider:
62
  case "OpenAI":
 
65
  return get_anthropic_token_usage(response_metadata, model_info)
66
  case "Together":
67
  return get_together_token_usage(response_metadata, model_info)
68
+ case "Google":
69
+ return get_google_token_usage(response_metadata, model_info)
70
  case _:
71
  raise ValueError()
72