Spaces:
Running
Running
File size: 5,087 Bytes
54a110c 4d96293 54a110c 2027c04 4d96293 2027c04 4d96293 54a110c 4d96293 54a110c 4d96293 54a110c 48ec86e 54a110c 93b3b82 54a110c 48ec86e 54a110c 4d96293 48ec86e 4d96293 48ec86e 4d96293 54a110c 4d96293 48ec86e 4d96293 54a110c 48ec86e 54a110c 93b3b82 54a110c 4d96293 93b3b82 54a110c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
"""Create and return a model."""
# ruff: noqa: F841
import os
import re
from platform import node
from loguru import logger
from smolagents import InferenceClientModel as HfApiModel
from smolagents import LiteLLMRouterModel, OpenAIServerModel
# FutureWarning: HfApiModel was renamed to InferenceClientModel in version 1.14.0 and will be removed in 1.17.0.
from get_gemini_keys import get_gemini_keys
def get_model(cat: str = "hf", provider=None, model_id=None):
"""
Create and return a model.
Args:
cat: category, hf, gemin, llama (default and fallback: hf)
provider: for HfApiModel (cat='hf')
model_id: model name
if no gemini_api_keys, return HfApiModel()
"""
if cat.lower() in ["hf"]:
logger.info(" usiing HfApiModel, make sure you set HF_TOKEN")
return HfApiModel(provider=provider, model_id=model_id)
# setup proxy for gemini and for golay (local tetsin)
if "golay" in node() and cat.lower() in ["gemini", "llama"]:
os.environ.update(
HTTPS_PROXY="http://localhost:8081",
HTTP_PROXY="http://localhost:8081",
ALL_PROXY="http://localhost:8081",
NO_PROXY="localhost,127.0.0.1",
)
if cat.lower() in ["gemini"]:
# get gemini_api_keys
# dedup
_ = re.findall(r"AIzaSy[A-Z][\w-]{32}", os.getenv("GEMINI_API_KEYS", ""))
gemini_api_keys = [*dict.fromkeys(get_gemini_keys() + _)]
# assert gemini_api_keys, "No GEMINI_API_KEYS, set env var GEMINI_API_KEYS or put them in .env-gemini and try again."
if not gemini_api_keys:
logger.warning("cat='gemini' but no GEMINI_API_KEYS found, returning HfApiModel()... Set env var GEMINI_API_KEYS and/or .env-gemini with free space gemini-api-keys if you want to try 'gemini' ")
logger.info(" set gemini but return HfApiModel()")
return HfApiModel()
if model_id is None:
model_id = "gemini-2.5-flash-preview-04-17"
# model_id = "gemini-2.5-flash-preview-04-17"
llm_loadbalancer_model_list_gemini = []
for api_key in gemini_api_keys:
llm_loadbalancer_model_list_gemini.append(
{
"model_name": "model-group-1",
"litellm_params": {
"model": f"gemini/{model_id}",
"api_key": api_key,
},
},
)
model_id = "deepseek-ai/DeepSeek-V3"
llm_loadbalancer_model_list_siliconflow = [
{
"model_name": "model-group-2",
"litellm_params": {
"model": f"openai/{model_id}",
"api_key": os.getenv("SILICONFLOW_API_KEY"),
"api_base": "https://api.siliconflow.cn/v1",
},
},
]
# gemma-3-27b-it
llm_loadbalancer_model_list_gemma = [
{
"model_name": "model-group-3",
"litellm_params": {
"model": "gemini/gemma-3-27b-it",
"api_key": os.getenv("GEMINI_API_KEY") },
},
]
fallbacks = []
model_list = llm_loadbalancer_model_list_gemini
if os.getenv("SILICONFLOW_API_KEY"):
fallbacks = [{"model-group-1": "model-group-2"}]
model_list += llm_loadbalancer_model_list_siliconflow
model_list += llm_loadbalancer_model_list_gemma
fallbacks13 = [{"model-group-1": "model-group-3"}]
fallbacks31 = [{"model-group-3": "model-group-1"}]
model = LiteLLMRouterModel(
model_id="model-group-1",
model_list=model_list,
client_kwargs={
"routing_strategy": "simple-shuffle",
"num_retries": 3,
"retry_after": 180, # waits min s before retrying request
"fallbacks": fallbacks13, # falllacks dont seem to work
},
)
if os.getenv("SILICONFLOW_API_KEY"):
logger.info(" set gemini, return LiteLLMRouterModel + fallbacks")
else:
logger.info(" set gemini, return LiteLLMRouterModel")
return model
if cat.lower() in ["llama"]:
api_key = os.getenv("LLAMA_API_KEY")
if api_key is None:
logger.warning(" LLAMA_API_EY not set, using HfApiModel(), make sure you set HF_TOKEN")
return HfApiModel()
# default model_id
if model_id is None:
model_id = "Llama-4-Maverick-17B-128E-Instruct-FP8"
model_id = "Llama-4-Scout-17B-16E-Instruct-FP8"
model_llama = OpenAIServerModel(
model_id,
api_base="https://api.llama.com/compat/v1",
api_key=api_key,
# temperature=0.,
)
return model_llama
logger.info(" default return default HfApiModel(provider=None, model_id=None)")
# if cat.lower() in ["hf"]: default
return HfApiModel(provider=provider, model_id=model_id)
|