"""Create and return a model (HfApiModel, or a gemini-backed LiteLLMRouterModel)."""
import os
import re
from platform import node

from loguru import logger
from smolagents import HfApiModel, LiteLLMRouterModel

from get_gemini_keys import get_gemini_keys
def get_model(cat: str = "hf", provider=None, model_id=None):
    """Create and return a model.

    Args:
        cat: model category, "hf" (default) or "gemini".
        provider: inference provider forwarded to HfApiModel (cat="hf" only).
        model_id: model name; when None a category-specific default is used.

    Returns:
        For cat="gemini" with at least one API key available: a
        LiteLLMRouterModel load-balanced across all gemini keys, with an
        optional SiliconFlow (DeepSeek-V3) fallback group when
        SILICONFLOW_API_KEY is set. Otherwise an HfApiModel.
    """
    if cat.lower() in ["gemini"]:
        # Gather keys from the GEMINI_API_KEYS env var and from get_gemini_keys()
        # (.env-gemini); dict.fromkeys dedups while preserving first-seen order.
        env_keys = re.findall(r"AIzaSy[A-Z][\w-]{32}", os.getenv("GEMINI_API_KEYS", ""))
        gemini_api_keys = dict.fromkeys(get_gemini_keys() + env_keys)
        if not gemini_api_keys:
            logger.warning(
                "cat='gemini' but no GEMINI_API_KEYS found, "
                " returning HfApiModel()..."
                " Set env var GEMINI_API_KEYS and/or .env-gemini "
                " with free space gemini-api-keys if you want to try 'gemini' "
            )
            return HfApiModel()

        # Proxy setup needed for gemini access on the local 'golay' host only.
        if "golay" in node():
            os.environ.update(
                HTTPS_PROXY="http://localhost:8081",
                HTTP_PROXY="http://localhost:8081",
                ALL_PROXY="http://localhost:8081",
                NO_PROXY="localhost,127.0.0.1,oracle",
            )

        if model_id is None:
            model_id = "gemini-2.5-flash-preview-04-17"

        # One router entry per key so litellm can load-balance across them.
        llm_loadbalancer_model_list_gemini = [
            {
                "model_name": "model-group-1",
                "litellm_params": {
                    "model": f"gemini/{model_id}",
                    "api_key": api_key,
                },
            }
            for api_key in gemini_api_keys
        ]

        # Fallback group served by SiliconFlow's OpenAI-compatible endpoint.
        # Use a dedicated name instead of clobbering the model_id parameter
        # (the original reassigned model_id here, which only worked because
        # the gemini list was already built — fragile and misleading).
        fallback_model_id = "deepseek-ai/DeepSeek-V3"
        llm_loadbalancer_model_list_siliconflow = [
            {
                "model_name": "model-group-2",
                "litellm_params": {
                    "model": f"openai/{fallback_model_id}",
                    "api_key": os.getenv("SILICONFLOW_API_KEY"),
                    "api_base": "https://api.siliconflow.cn/v1",
                },
            }
        ]

        fallbacks = []
        model_list = llm_loadbalancer_model_list_gemini
        if os.getenv("SILICONFLOW_API_KEY"):
            # Only register the fallback group when a key is actually available.
            fallbacks = [{"model-group-1": "model-group-2"}]
            model_list += llm_loadbalancer_model_list_siliconflow

        return LiteLLMRouterModel(
            model_id="model-group-1",
            model_list=model_list,
            client_kwargs={
                "routing_strategy": "simple-shuffle",
                "num_retries": 3,
                "fallbacks": fallbacks,
            },
        )

    # Default path: cat="hf" (or anything unrecognized).
    return HfApiModel(provider=provider, model_id=model_id)