"""Create and return a model."""

import os
import re
from platform import node

from get_gemini_keys import get_gemini_keys
from loguru import logger
from smolagents import HfApiModel, LiteLLMRouterModel


def get_model(cat: str = "hf", provider=None, model_id=None):
    """
    Create and return a model.

    Args:
        cat: category; "gemini" builds a load-balanced LiteLLM router over
            all available Gemini API keys, anything else falls through to
            a plain HfApiModel.
        provider: for HfApiModel (cat='hf')
        model_id: model name (for cat='gemini' defaults to
            "gemini-2.5-flash-preview-04-17")

    Returns:
        LiteLLMRouterModel when cat == "gemini" and at least one Gemini API
        key is available; otherwise HfApiModel.
        if no gemini_api_keys, return HfApiModel()
    """
    if cat.lower() in ["gemini"]:
        # Collect Gemini API keys from the env var (regex-extracted) and the
        # project helper; dedup while preserving order via dict.fromkeys.
        env_keys = re.findall(r"AIzaSy[A-Z][\w-]{32}", os.getenv("GEMINI_API_KEYS", ""))
        gemini_api_keys = [*dict.fromkeys(get_gemini_keys() + env_keys)]

        if not gemini_api_keys:
            logger.warning("cat='gemini' but no GEMINI_API_KEYS found, returning HfApiModel()... Set env var GEMINI_API_KEYS and/or .env-gemini with free space gemini-api-keys if you want to try 'gemini' ")
            logger.info(" set gemini but return HfApiModel()")
            return HfApiModel()

        # Setup proxy for gemini when running on the "golay" host (local box
        # behind a proxy) — mutates the process environment.
        if "golay" in node():
            os.environ.update(
                HTTPS_PROXY="http://localhost:8081",
                HTTP_PROXY="http://localhost:8081",
                ALL_PROXY="http://localhost:8081",
                NO_PROXY="localhost,127.0.0.1,oracle",
            )

        if model_id is None:
            model_id = "gemini-2.5-flash-preview-04-17"

        # One router entry per Gemini key, all sharing one load-balanced group.
        llm_loadbalancer_model_list_gemini = [
            {
                "model_name": "model-group-1",
                "litellm_params": {
                    "model": f"gemini/{model_id}",
                    "api_key": api_key,
                },
            }
            for api_key in gemini_api_keys
        ]

        # SiliconFlow fallback group. Use a distinct local instead of
        # clobbering the `model_id` parameter (the original reassigned it
        # after the Gemini list was built, which only worked by accident
        # of statement order).
        siliconflow_model_id = "deepseek-ai/DeepSeek-V3"
        llm_loadbalancer_model_list_siliconflow = [
            {
                "model_name": "model-group-2",
                "litellm_params": {
                    "model": f"openai/{siliconflow_model_id}",
                    "api_key": os.getenv("SILICONFLOW_API_KEY"),
                    "api_base": "https://api.siliconflow.cn/v1",
                },
            },
        ]

        # gemma-3-27b-it fallback group (single key from GEMINI_API_KEY).
        llm_loadbalancer_model_list_gemma = [
            {
                "model_name": "model-group-3",
                "litellm_params": {
                    # fixed: was an f-string with no placeholder (F541)
                    "model": "gemini/gemma-3-27b-it",
                    "api_key": os.getenv("GEMINI_API_KEY"),
                },
            },
        ]

        model_list = llm_loadbalancer_model_list_gemini
        if os.getenv("SILICONFLOW_API_KEY"):
            model_list += llm_loadbalancer_model_list_siliconflow
        model_list += llm_loadbalancer_model_list_gemma

        # Only the group-1 -> group-3 fallback is actually passed to the
        # router; the original also built 1->2 and 3->1 mappings that were
        # never used (dead locals, removed).
        fallbacks13 = [{"model-group-1": "model-group-3"}]

        model = LiteLLMRouterModel(
            model_id="model-group-1",
            model_list=model_list,
            client_kwargs={
                "routing_strategy": "simple-shuffle",
                "num_retries": 3,
                "retry_after": 180,  # waits min s before retrying request
                "fallbacks": fallbacks13,  # fallbacks don't seem to work
            },
        )
        if os.getenv("SILICONFLOW_API_KEY"):
            logger.info(" set gemini, return LiteLLMRouterModel + fallbacks")
        else:
            logger.info(" set gemini, return LiteLLMRouterModel")
        return model

    logger.info(" default return default HfApiModel(provider=None, model_id=None)")
    # if cat.lower() in ["hf"]: default return
    return HfApiModel(provider=provider, model_id=model_id)