final-assignment/get_model.py
"""Create and return a model."""
import os
import re
from platform import node
from get_gemini_keys import get_gemini_keys
from loguru import logger
from smolagents import HfApiModel, LiteLLMRouterModel
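

# Note (assumption, not stated in the original): get_gemini_keys is a local helper
# (not shown here) that is expected to return a list of Gemini API key strings,
# e.g. read from .env-gemini; they are merged below with any keys found in the
# GEMINI_API_KEYS env var.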
def get_model(cat: str = "hf", provider=None, model_id=None):
    """Create and return a model.

    Args:
        cat: model category, "hf" (default) or "gemini"
        provider: inference provider passed to HfApiModel (cat="hf")
        model_id: model name

    If cat="gemini" but no Gemini API keys are found, an HfApiModel() is returned instead.
    """
    if cat.lower() in ["gemini"]:
        # Collect Gemini API keys: keys from get_gemini_keys() plus any keys embedded
        # in the GEMINI_API_KEYS env var, deduplicated while preserving order.
        _ = re.findall(r"AIzaSy[A-Z][\w-]{32}", os.getenv("GEMINI_API_KEYS", ""))
        gemini_api_keys = [*dict.fromkeys(get_gemini_keys() + _)]

        # assert gemini_api_keys, "No GEMINI_API_KEYS, set env var GEMINI_API_KEYS or put them in .env-gemini and try again."
        if not gemini_api_keys:
            logger.warning(
                "cat='gemini' but no GEMINI_API_KEYS found, returning HfApiModel(). "
                "Set the env var GEMINI_API_KEYS and/or put free Gemini API keys in "
                ".env-gemini if you want to try 'gemini'."
            )
            logger.info(" cat='gemini' requested but returning HfApiModel()")
            return HfApiModel()
        # set up a proxy for gemini when running on golay (local box)
        if "golay" in node():
            os.environ.update(
                HTTPS_PROXY="http://localhost:8081",
                HTTP_PROXY="http://localhost:8081",
                ALL_PROXY="http://localhost:8081",
                NO_PROXY="localhost,127.0.0.1,oracle",
            )
        if model_id is None:
            model_id = "gemini-2.5-flash-preview-04-17"
        # One router entry per Gemini API key, all under the same model group.
        llm_loadbalancer_model_list_gemini = []
        for api_key in gemini_api_keys:
            llm_loadbalancer_model_list_gemini.append(
                {
                    "model_name": "model-group-1",
                    "litellm_params": {
                        "model": f"gemini/{model_id}",
                        "api_key": api_key,
                    },
                },
            )
model_id = "deepseek-ai/DeepSeek-V3"
llm_loadbalancer_model_list_siliconflow = [
{
"model_name": "model-group-2",
"litellm_params": {
"model": f"openai/{model_id}",
"api_key": os.getenv("SILICONFLOW_API_KEY"),
"api_base": "https://api.siliconflow.cn/v1",
},
},
]
        # gemma-3-27b-it as another fallback model group.
        llm_loadbalancer_model_list_gemma = [
            {
                "model_name": "model-group-3",
                "litellm_params": {
                    "model": "gemini/gemma-3-27b-it",
                    "api_key": os.getenv("GEMINI_API_KEY"),
                },
            },
        ]
        fallbacks = []  # currently unused; fallbacks13 is what gets passed to the router
        model_list = llm_loadbalancer_model_list_gemini
        if os.getenv("SILICONFLOW_API_KEY"):
            fallbacks = [{"model-group-1": "model-group-2"}]
            model_list += llm_loadbalancer_model_list_siliconflow
        model_list += llm_loadbalancer_model_list_gemma

        fallbacks13 = [{"model-group-1": "model-group-3"}]
        fallbacks31 = [{"model-group-3": "model-group-1"}]  # currently unused
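        # Note (assumption, not from the original): litellm's Router docs show fallback
        # targets as a *list* of model groups, e.g. [{"model-group-1": ["model-group-3"]}];
        # if the string-valued form above is rejected, that may be why fallbacks never fire.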
        model = LiteLLMRouterModel(
            model_id="model-group-1",
            model_list=model_list,
            client_kwargs={
                "routing_strategy": "simple-shuffle",
                "num_retries": 3,
                "retry_after": 180,  # wait at least this many seconds before retrying a request
                "fallbacks": fallbacks13,  # fallbacks don't seem to work
            },
        )
        if os.getenv("SILICONFLOW_API_KEY"):
            logger.info(" cat='gemini', returning LiteLLMRouterModel + fallbacks")
        else:
            logger.info(" cat='gemini', returning LiteLLMRouterModel")
        return model
    # default: cat.lower() in ["hf"]
    logger.info(" default: returning HfApiModel(provider=provider, model_id=model_id)")
    return HfApiModel(provider=provider, model_id=model_id)
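

# A minimal usage sketch (not part of the original module): shows how get_model
# might be plugged into a smolagents agent. Assumes the relevant credentials
# (GEMINI_API_KEYS / SILICONFLOW_API_KEY / an HF token) are set in the environment.
if __name__ == "__main__":
    from smolagents import CodeAgent

    model = get_model(cat="gemini")  # falls back to HfApiModel() if no Gemini keys are found
    agent = CodeAgent(tools=[], model=model)
    print(agent.run("What is 2 + 2?"))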