# final-assignment/get_model.py
"""Create and return a model."""
# ruff: noqa: F841
import os
import re
from platform import node
from loguru import logger
from smolagents import HfApiModel, LiteLLMRouterModel, OpenAIServerModel
from get_gemini_keys import get_gemini_keys

def get_model(cat: str = "hf", provider=None, model_id=None):
    """
    Create and return a model.

    Args:
        cat: category: "hf", "gemini" or "llama" (default and fallback: "hf")
        provider: inference provider, passed to HfApiModel (cat="hf")
        model_id: model name

    If no gemini api keys can be found, fall back to HfApiModel().
    """
if cat.lower() in ["hf"]:
logger.info(" usiing HfApiModel, make sure you set HF_TOKEN")
return HfApiModel(provider=provider, model_id=model_id)
    # set up a proxy for gemini/llama when running on golay (local testing)
    if "golay" in node() and cat.lower() in ["gemini", "llama"]:
        os.environ.update(
            HTTPS_PROXY="http://localhost:8081",
            HTTP_PROXY="http://localhost:8081",
            ALL_PROXY="http://localhost:8081",
            NO_PROXY="localhost,127.0.0.1",
        )

if cat.lower() in ["gemini"]:
# get gemini_api_keys
# dedup
_ = re.findall(r"AIzaSy[A-Z][\w-]{32}", os.getenv("GEMINI_API_KEYS", ""))
gemini_api_keys = [*dict.fromkeys(get_gemini_keys() + _)]
# assert gemini_api_keys, "No GEMINI_API_KEYS, set env var GEMINI_API_KEYS or put them in .env-gemini and try again."
if not gemini_api_keys:
logger.warning("cat='gemini' but no GEMINI_API_KEYS found, returning HfApiModel()... Set env var GEMINI_API_KEYS and/or .env-gemini with free space gemini-api-keys if you want to try 'gemini' ")
logger.info(" set gemini but return HfApiModel()")
return HfApiModel()
if model_id is None:
model_id = "gemini-2.5-flash-preview-04-17"
# model_id = "gemini-2.5-flash-preview-04-17"

        # one router entry per api key so litellm can shuffle requests across keys
        llm_loadbalancer_model_list_gemini = []
        for api_key in gemini_api_keys:
            llm_loadbalancer_model_list_gemini.append(
                {
                    "model_name": "model-group-1",
                    "litellm_params": {
                        "model": f"gemini/{model_id}",
                        "api_key": api_key,
                    },
                },
            )
model_id = "deepseek-ai/DeepSeek-V3"
llm_loadbalancer_model_list_siliconflow = [
{
"model_name": "model-group-2",
"litellm_params": {
"model": f"openai/{model_id}",
"api_key": os.getenv("SILICONFLOW_API_KEY"),
"api_base": "https://api.siliconflow.cn/v1",
},
},
]
# gemma-3-27b-it
llm_loadbalancer_model_list_gemma = [
{
"model_name": "model-group-3",
"litellm_params": {
"model": "gemini/gemma-3-27b-it",
"api_key": os.getenv("GEMINI_API_KEY") },
},
]

        fallbacks = []
        model_list = llm_loadbalancer_model_list_gemini
        if os.getenv("SILICONFLOW_API_KEY"):
            fallbacks = [{"model-group-1": ["model-group-2"]}]
            model_list += llm_loadbalancer_model_list_siliconflow
        model_list += llm_loadbalancer_model_list_gemma

        fallbacks13 = [{"model-group-1": ["model-group-3"]}]
        fallbacks31 = [{"model-group-3": ["model-group-1"]}]
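
        # litellm's Router expects each fallback mapping's value to be a list of
        # model group names: on failure of "model-group-1" it retries the groups
        # listed for it (here "model-group-3").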
        model = LiteLLMRouterModel(
            model_id="model-group-1",
            model_list=model_list,
            client_kwargs={
                "routing_strategy": "simple-shuffle",
                "num_retries": 3,
                "retry_after": 180,  # wait at least this many seconds before retrying a request
                "fallbacks": fallbacks13,  # NB: fallbacks do not seem to take effect
            },
        )
        if os.getenv("SILICONFLOW_API_KEY"):
            logger.info(" cat='gemini', returning LiteLLMRouterModel with fallbacks")
        else:
            logger.info(" cat='gemini', returning LiteLLMRouterModel")
        return model

if cat.lower() in ["llama"]:
api_key = os.getenv("LLAMA_API_KEY")
if api_key is None:
logger.warning(" LLAMA_API_EY not set, using HfApiModel(), make sure you set HF_TOKEN")
return HfApiModel()
# default model_id
if model_id is None:
model_id = "Llama-4-Maverick-17B-128E-Instruct-FP8"
model_id = "Llama-4-Scout-17B-16E-Instruct-FP8"
model_llama = OpenAIServerModel(
model_id,
api_base="https://api.llama.com/compat/v1",
api_key=api_key,
# temperature=0.,
)
return model_llama
logger.info(" default return default HfApiModel(provider=None, model_id=None)")
# if cat.lower() in ["hf"]: default
return HfApiModel(provider=provider, model_id=model_id)
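

# Minimal usage sketch (assumes HF_TOKEN is set; GEMINI_API_KEYS / LLAMA_API_KEY
# are only needed for the respective cats, which otherwise fall back to HfApiModel):
if __name__ == "__main__":
    model = get_model()  # default: HfApiModel
    # model = get_model(cat="gemini")  # LiteLLMRouterModel balanced over gemini keys
    # model = get_model(cat="llama", model_id="Llama-4-Scout-17B-16E-Instruct-FP8")
    logger.info(f"model: {model}")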