Spaces:
Running
Running
File size: 3,237 Bytes
54a110c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
"""Create and return a model."""
import os
import re
from platform import node
from get_gemini_keys import get_gemini_keys
from loguru import logger
from smolagents import HfApiModel, LiteLLMRouterModel
def get_model(cat: str = "hf", provider=None, model_id=None):
    """
    Create and return a model.

    Args:
        cat: model category, "hf" (default) or "gemini".
        provider: inference provider, forwarded to HfApiModel (cat="hf" only).
        model_id: model name; for cat="gemini" defaults to
            "gemini-2.5-flash-preview-04-17".

    Returns:
        When cat="gemini" and at least one Gemini API key is found, a
        LiteLLMRouterModel load-balanced across the keys, with an optional
        DeepSeek-V3 fallback group when SILICONFLOW_API_KEY is set.
        Otherwise (cat="hf", unknown cat, or no Gemini keys) an HfApiModel.
    """
    if cat.lower() == "gemini":
        # Collect keys from get_gemini_keys() plus any embedded in the
        # GEMINI_API_KEYS env var; dict.fromkeys dedups preserving order.
        env_keys = re.findall(r"AIzaSy[A-Z][\w-]{32}", os.getenv("GEMINI_API_KEYS", ""))
        gemini_api_keys = dict.fromkeys(get_gemini_keys() + env_keys)
        if not gemini_api_keys:
            # Best-effort fallback rather than a hard failure.
            logger.warning(
                "cat='gemini' but no GEMINI_API_KEYS found, "
                " returning HfApiModel()..."
                " Set env var GEMINI_API_KEYS and/or .env-gemini "
                " with free space gemini-api-keys if you want to try 'gemini' "
            )
            return HfApiModel()

        # Route through a local proxy when running on the 'golay' host.
        if "golay" in node():
            os.environ.update(
                HTTPS_PROXY="http://localhost:8081",
                HTTP_PROXY="http://localhost:8081",
                ALL_PROXY="http://localhost:8081",
                NO_PROXY="localhost,127.0.0.1,oracle",
            )

        if model_id is None:
            model_id = "gemini-2.5-flash-preview-04-17"

        # One router entry per API key -> "simple-shuffle" spreads requests
        # across all keys in the same model group.
        model_list = [
            {
                "model_name": "model-group-1",
                "litellm_params": {
                    "model": f"gemini/{model_id}",
                    "api_key": api_key,
                },
            }
            for api_key in gemini_api_keys
        ]

        # Optional fallback group via SiliconFlow's OpenAI-compatible API;
        # only wired up when the key is actually available.
        fallbacks = []
        if os.getenv("SILICONFLOW_API_KEY"):
            fallbacks = [{"model-group-1": "model-group-2"}]
            model_list.append(
                {
                    "model_name": "model-group-2",
                    "litellm_params": {
                        "model": "openai/deepseek-ai/DeepSeek-V3",
                        "api_key": os.getenv("SILICONFLOW_API_KEY"),
                        "api_base": "https://api.siliconflow.cn/v1",
                    },
                }
            )

        return LiteLLMRouterModel(
            model_id="model-group-1",
            model_list=model_list,
            client_kwargs={
                "routing_strategy": "simple-shuffle",
                "num_retries": 3,
                "fallbacks": fallbacks,
            },
        )

    # cat "hf" (default) or anything unrecognized: plain HfApiModel.
    return HfApiModel(provider=provider, model_id=model_id)
|