"""Create and return a model.""" | |
# ruff: noqa: F841 | |
import os | |
import re | |
from platform import node | |
from loguru import logger | |
from smolagents import InferenceClientModel as HfApiModel | |
from smolagents import LiteLLMRouterModel, OpenAIServerModel | |
# FutureWarning: HfApiModel was renamed to InferenceClientModel in version 1.14.0 and will be removed in 1.17.0. | |
from get_gemini_keys import get_gemini_keys | |


def get_model(cat: str = "hf", provider=None, model_id=None):
    """
    Create and return a model.

    Args:
        cat: category, one of "hf", "gemini", "llama" (default and fallback: "hf")
        provider: provider passed to HfApiModel (cat="hf")
        model_id: model name

    If no gemini api keys are found, fall back to HfApiModel().
    """
    if cat.lower() in ["hf"]:
        logger.info(" using HfApiModel, make sure you set HF_TOKEN")
        return HfApiModel(provider=provider, model_id=model_id)

    # set up a proxy for gemini and llama on golay (local testing)
    if "golay" in node() and cat.lower() in ["gemini", "llama"]:
        os.environ.update(
            HTTPS_PROXY="http://localhost:8081",
            HTTP_PROXY="http://localhost:8081",
            ALL_PROXY="http://localhost:8081",
            NO_PROXY="localhost,127.0.0.1",
        )
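        # NOTE: these proxy env vars are honored by common HTTP clients
        # (e.g. httpx, requests) used by the gemini/llama backends in this
        # process; they are not needed for cat="hf".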
if cat.lower() in ["gemini"]: | |
# get gemini_api_keys | |
# dedup | |
_ = re.findall(r"AIzaSy[A-Z][\w-]{32}", os.getenv("GEMINI_API_KEYS", "")) | |
gemini_api_keys = [*dict.fromkeys(get_gemini_keys() + _)] | |
# assert gemini_api_keys, "No GEMINI_API_KEYS, set env var GEMINI_API_KEYS or put them in .env-gemini and try again." | |
if not gemini_api_keys: | |
logger.warning("cat='gemini' but no GEMINI_API_KEYS found, returning HfApiModel()... Set env var GEMINI_API_KEYS and/or .env-gemini with free space gemini-api-keys if you want to try 'gemini' ") | |
logger.info(" set gemini but return HfApiModel()") | |
return HfApiModel() | |
if model_id is None: | |
model_id = "gemini-2.5-flash-preview-04-17" | |
# model_id = "gemini-2.5-flash-preview-04-17" | |

        # one router entry per gemini api key, all in the same litellm model group
        llm_loadbalancer_model_list_gemini = []
        for api_key in gemini_api_keys:
            llm_loadbalancer_model_list_gemini.append(
                {
                    "model_name": "model-group-1",
                    "litellm_params": {
                        "model": f"gemini/{model_id}",
                        "api_key": api_key,
                    },
                },
            )
model_id = "deepseek-ai/DeepSeek-V3" | |
llm_loadbalancer_model_list_siliconflow = [ | |
{ | |
"model_name": "model-group-2", | |
"litellm_params": { | |
"model": f"openai/{model_id}", | |
"api_key": os.getenv("SILICONFLOW_API_KEY"), | |
"api_base": "https://api.siliconflow.cn/v1", | |
}, | |
}, | |
] | |

        # gemma-3-27b-it
        llm_loadbalancer_model_list_gemma = [
            {
                "model_name": "model-group-3",
                "litellm_params": {
                    "model": "gemini/gemma-3-27b-it",
                    "api_key": os.getenv("GEMINI_API_KEY"),
                },
            },
        ]

        fallbacks = []
        model_list = llm_loadbalancer_model_list_gemini
        if os.getenv("SILICONFLOW_API_KEY"):
            fallbacks = [{"model-group-1": "model-group-2"}]
            model_list += llm_loadbalancer_model_list_siliconflow
        model_list += llm_loadbalancer_model_list_gemma
        fallbacks13 = [{"model-group-1": "model-group-3"}]
        fallbacks31 = [{"model-group-3": "model-group-1"}]

        model = LiteLLMRouterModel(
            model_id="model-group-1",
            model_list=model_list,
            client_kwargs={
                "routing_strategy": "simple-shuffle",
                "num_retries": 3,
                "retry_after": 180,  # minimum seconds to wait before retrying a request
                "fallbacks": fallbacks13,  # fallbacks don't seem to work
            },
        )
        if os.getenv("SILICONFLOW_API_KEY"):
            logger.info(" cat='gemini', returning LiteLLMRouterModel + fallbacks")
        else:
            logger.info(" cat='gemini', returning LiteLLMRouterModel")
        return model
if cat.lower() in ["llama"]: | |
api_key = os.getenv("LLAMA_API_KEY") | |
if api_key is None: | |
logger.warning(" LLAMA_API_EY not set, using HfApiModel(), make sure you set HF_TOKEN") | |
return HfApiModel() | |
# default model_id | |
if model_id is None: | |
model_id = "Llama-4-Maverick-17B-128E-Instruct-FP8" | |
model_id = "Llama-4-Scout-17B-16E-Instruct-FP8" | |
model_llama = OpenAIServerModel( | |
model_id, | |
api_base="https://api.llama.com/compat/v1", | |
api_key=api_key, | |
# temperature=0., | |
) | |
return model_llama | |
logger.info(" default return default HfApiModel(provider=None, model_id=None)") | |
# if cat.lower() in ["hf"]: default | |
return HfApiModel(provider=provider, model_id=model_id) | |
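

# A minimal usage sketch (not part of the module above): it exercises each
# category and assumes HF_TOKEN / GEMINI_API_KEYS / LLAMA_API_KEY are set in the
# environment as described above; unset keys simply fall back to HfApiModel().
if __name__ == "__main__":
    model = get_model()  # default: HfApiModel, needs HF_TOKEN
    logger.info("hf model: {}", model)

    gemini_model = get_model(cat="gemini")  # LiteLLMRouterModel if gemini keys are found
    logger.info("gemini model: {}", gemini_model)

    llama_model = get_model(cat="llama")  # OpenAIServerModel if LLAMA_API_KEY is set
    logger.info("llama model: {}", llama_model)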