"""Create and return a model."""

import os
import re
from platform import node

from get_gemini_keys import get_gemini_keys
from loguru import logger
from smolagents import HfApiModel, LiteLLMRouterModel


def get_model(cat: str = "hf", provider=None, model_id=None):
    """
    Create and return a model.

    Args:
        cat: model category, "hf" (default) or "gemini"
        provider: inference provider passed to HfApiModel (only used when cat="hf")
        model_id: model name; the default depends on the category

    If cat="gemini" but no Gemini API keys can be found, an HfApiModel() is returned instead.
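
    Example (illustrative; assumes the relevant API keys are set in the environment):

        model = get_model()               # HfApiModel with defaults
        model = get_model(cat="gemini")   # LiteLLMRouterModel over Gemini keys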

    """
    if cat.lower() == "gemini":
        # collect Gemini API keys from get_gemini_keys() and the GEMINI_API_KEYS env var, deduplicated
        env_keys = re.findall(r"AIzaSy[A-Z][\w-]{32}", os.getenv("GEMINI_API_KEYS", ""))
        gemini_api_keys = [*dict.fromkeys(get_gemini_keys() + env_keys)]

        if not gemini_api_keys:
            logger.warning(
                "cat='gemini' but no Gemini API keys were found; returning HfApiModel(). "
                "Set the GEMINI_API_KEYS env var and/or .env-gemini with Gemini API keys to use cat='gemini'."
            )
            return HfApiModel()

        # when running on the "golay" host (local), route Gemini traffic through a local proxy
        if "golay" in node():
            os.environ.update(
                HTTPS_PROXY="http://localhost:8081",
                HTTP_PROXY="http://localhost:8081",
                ALL_PROXY="http://localhost:8081",
                NO_PROXY="localhost,127.0.0.1,oracle",
            )

        if model_id is None:
            model_id = "gemini-2.5-flash-preview-04-17"

        # one load-balancer entry per API key so requests rotate across keys
        llm_loadbalancer_model_list_gemini = []
        for api_key in gemini_api_keys:
            llm_loadbalancer_model_list_gemini.append(
                {
                    "model_name": "model-group-1",
                    "litellm_params": {
                        "model": f"gemini/{model_id}",
                        "api_key": api_key,
                    },
                },
            )

        model_id = "deepseek-ai/DeepSeek-V3"
        llm_loadbalancer_model_list_siliconflow = [
            {
                "model_name": "model-group-2",
                "litellm_params": {
                    "model": f"openai/{model_id}",
                    "api_key": os.getenv("SILICONFLOW_API_KEY"),
                    "api_base": "https://api.siliconflow.cn/v1",
                },
            },
        ]
        
        # gemma-3-27b-it as a third model group (uses the GEMINI_API_KEY env var)
        llm_loadbalancer_model_list_gemma = [
            {
                "model_name": "model-group-3",
                "litellm_params": {
                    "model": "gemini/gemma-3-27b-it",
                    "api_key": os.getenv("GEMINI_API_KEY"),
                },
            },
        ]

        # group-1 (gemini) is primary; gemma is always a fallback, siliconflow only when a key is set
        model_list = llm_loadbalancer_model_list_gemini + llm_loadbalancer_model_list_gemma
        # litellm expects fallback targets as a list of model-group names
        fallbacks = [{"model-group-1": ["model-group-3"]}]
        if os.getenv("SILICONFLOW_API_KEY"):
            model_list += llm_loadbalancer_model_list_siliconflow
            fallbacks = [{"model-group-1": ["model-group-2", "model-group-3"]}]

        model = LiteLLMRouterModel(
            model_id="model-group-1",
            model_list=model_list,
            client_kwargs={
                "routing_strategy": "simple-shuffle",
                "num_retries": 3,
                "retry_after": 180,  # minimum seconds to wait before retrying a request
                "fallbacks": fallbacks,
            },
        )

        if os.getenv("SILICONFLOW_API_KEY"):
            logger.info(" set gemini, return LiteLLMRouterModel + fallbacks")
        else:
            logger.info(" set gemini, return LiteLLMRouterModel")

        return model

    logger.info(" default return default HfApiModel(provider=None, model_id=None)")
    # if cat.lower() in ["hf"]:  default
    return HfApiModel(provider=provider, model_id=model_id)
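

if __name__ == "__main__":
    # Minimal usage sketch, not part of the module proper: it assumes the relevant
    # credentials (HF token, GEMINI_API_KEYS, SILICONFLOW_API_KEY) are already set in
    # the environment and that the smolagents chat-message format below is accepted.
    model = get_model()  # default HfApiModel
    # model = get_model(cat="gemini")  # router over Gemini keys, if configured
    messages = [{"role": "user", "content": [{"type": "text", "text": "Say hello in one word."}]}]
    response = model(messages)
    print(response.content)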