"""Create and return a model."""
# ruff: noqa: F841

import os
import re
from platform import node

from loguru import logger
# FutureWarning from smolagents: HfApiModel was renamed to InferenceClientModel in version 1.14.0
# and will be removed in 1.17.0; the old name is kept here as an alias.
from smolagents import InferenceClientModel as HfApiModel
from smolagents import LiteLLMRouterModel, OpenAIServerModel

from get_gemini_keys import get_gemini_keys


def get_model(cat: str = "hf", provider=None, model_id=None):
    """
    Create and return a model.

    Args:
        cat: category, hf, gemin, llama (default and fallback: hf)
        provider: for HfApiModel (cat='hf')
        model_id: model name

    if no gemini_api_keys, return HfApiModel()

    """
    if cat.lower() in ["hf"]:
        logger.info(" usiing HfApiModel, make sure you set HF_TOKEN")
        return HfApiModel(provider=provider, model_id=model_id)

    # set up a local proxy for gemini and llama when running on the "golay" host (local testing)
    if "golay" in node() and cat.lower() in ["gemini", "llama"]:
        os.environ.update(
            HTTPS_PROXY="http://localhost:8081",
            HTTP_PROXY="http://localhost:8081",
            ALL_PROXY="http://localhost:8081",
            NO_PROXY="localhost,127.0.0.1",
        )

    if cat.lower() in ["gemini"]:
        # collect gemini api keys from get_gemini_keys() plus any in the GEMINI_API_KEYS env var,
        # then deduplicate while preserving order
        keys_from_env = re.findall(r"AIzaSy[A-Z][\w-]{32}", os.getenv("GEMINI_API_KEYS", ""))
        gemini_api_keys = [*dict.fromkeys(get_gemini_keys() + keys_from_env)]

        if not gemini_api_keys:
            logger.warning(
                "cat='gemini' but no Gemini API keys were found, returning HfApiModel(). "
                "Set the GEMINI_API_KEYS env var and/or put free-tier Gemini API keys in .env-gemini to use 'gemini'."
            )
            return HfApiModel()

        if model_id is None:
            model_id = "gemini-2.5-flash-preview-04-17"
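
        # Router setup: model-group-1 = Gemini keys (load-balanced), model-group-2 = SiliconFlow
        # DeepSeek-V3 (optional fallback), model-group-3 = Gemma 3 27B served via the Gemini API.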
        llm_loadbalancer_model_list_gemini = [
            {
                "model_name": "model-group-1",
                "litellm_params": {
                    "model": f"gemini/{model_id}",
                    "api_key": api_key,
                },
            }
            for api_key in gemini_api_keys
        ]

        model_id = "deepseek-ai/DeepSeek-V3"
        llm_loadbalancer_model_list_siliconflow = [
            {
                "model_name": "model-group-2",
                "litellm_params": {
                    "model": f"openai/{model_id}",
                    "api_key": os.getenv("SILICONFLOW_API_KEY"),
                    "api_base": "https://api.siliconflow.cn/v1",
                },
            },
        ]

        # gemma-3-27b-it
        llm_loadbalancer_model_list_gemma = [
            {
                "model_name": "model-group-3",
                "litellm_params": {
                    "model": "gemini/gemma-3-27b-it",
                    "api_key": os.getenv("GEMINI_API_KEY")                },
            },
        ]

        # always include the gemini group; add the siliconflow group when a key is available
        fallbacks = []
        model_list = llm_loadbalancer_model_list_gemini
        if os.getenv("SILICONFLOW_API_KEY"):
            fallbacks = [{"model-group-1": "model-group-2"}]
            model_list += llm_loadbalancer_model_list_siliconflow

        model_list += llm_loadbalancer_model_list_gemma
        # candidate fallback routes: group-1 -> group-3 (used below), group-3 -> group-1 (currently unused)
        fallbacks13 = [{"model-group-1": "model-group-3"}]
        fallbacks31 = [{"model-group-3": "model-group-1"}]

        model = LiteLLMRouterModel(
            model_id="model-group-1",
            model_list=model_list,
            client_kwargs={
                "routing_strategy": "simple-shuffle",
                "num_retries": 3,
                "retry_after": 180,   # waits min  s before retrying request
                "fallbacks": fallbacks13,  # falllacks dont seem to work
            },
        )

        if os.getenv("SILICONFLOW_API_KEY"):
            logger.info(" set gemini, return LiteLLMRouterModel + fallbacks")
        else:
            logger.info(" set gemini, return LiteLLMRouterModel")

        return model

    if cat.lower() in ["llama"]:
        api_key = os.getenv("LLAMA_API_KEY")
        if api_key is None:
            logger.warning(" LLAMA_API_EY not set, using HfApiModel(), make sure you set HF_TOKEN")
            return HfApiModel()

        # default model_id (Scout; Maverick is the larger alternative)
        if model_id is None:
            # model_id = "Llama-4-Maverick-17B-128E-Instruct-FP8"
            model_id = "Llama-4-Scout-17B-16E-Instruct-FP8"
        model_llama = OpenAIServerModel(
            model_id,
            api_base="https://api.llama.com/compat/v1",
            api_key=api_key,
            # temperature=0.,
        )
        return model_llama

    logger.info(" default return default HfApiModel(provider=None, model_id=None)")
    # if cat.lower() in ["hf"]:  default
    return HfApiModel(provider=provider, model_id=model_id)
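

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original module: it only exercises get_model() with the
    # categories handled above and assumes HF_TOKEN / GEMINI_API_KEYS / LLAMA_API_KEY are exported
    # as needed for the chosen category.
    model = get_model()  # default: HfApiModel backed by the Hugging Face Inference API
    # model = get_model(cat="gemini")  # LiteLLM router over the collected Gemini keys
    # model = get_model(cat="llama")  # OpenAIServerModel pointed at api.llama.com
    logger.info(f"model: {model}")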