File size: 2,196 Bytes
21eb680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2e43905
21eb680
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import asyncio
import logging
from contextlib import asynccontextmanager
from google import genai
import threading

from config import settings

logger = logging.getLogger(__name__)


class ApiKeyPool:
    """Round-robin pool of Google API keys.

    Keys are read lazily, on the first request, from the comma-separated
    secret ``settings.gemini_api_keys``. Blank entries are ignored. Async
    (``get_key``) and sync (``get_key_sync``) callers share a single lock,
    so both advance the same rotation without interfering with each other.
    """

    def __init__(self) -> None:
        # None until the first request triggers _load_keys().
        self._keys: list[str] | None = None
        # Position of the key the next request will receive.
        self._index = 0
        # One lock for both entry points; a separate asyncio lock would not
        # exclude threads using get_key_sync().
        self._lock = threading.Lock()

    def _load_keys(self) -> list[str]:
        """Parse, cache and return the configured keys.

        Raises:
            ValueError: if the setting is missing, empty, or contains only
                blank entries.
        """
        secret = settings.gemini_api_keys
        raw = secret.get_secret_value() if secret is not None else ""
        keys = [part.strip() for part in (raw or "").split(',') if part.strip()]
        if not keys:
            msg = "Google API keys are not configured or invalid"
            logger.error(msg)
            raise ValueError(msg)
        self._keys = keys
        return keys

    def _next_key(self) -> str:
        """Return the current key and advance the rotation.

        The caller must hold ``self._lock``.
        """
        keys = self._keys
        if keys is None:
            keys = self._load_keys()
        index = self._index
        self._index = (index + 1) % len(keys)
        # Log the index actually handed out, not the next one.
        logger.debug("Using Google API key index %s", index)
        return keys[index]

    async def get_key(self) -> str:
        """Return the next API key (coroutine-friendly entry point)."""
        # The critical section is O(1) and contains no await, so holding a
        # threading lock here cannot deadlock the event loop. It also keeps
        # coroutine and thread callers mutually exclusive.
        with self._lock:
            return self._next_key()

    def get_key_sync(self) -> str:
        """Synchronous helper for environments without an event loop."""
        with self._lock:
            return self._next_key()


class GoogleClientFactory:
    """Factory for Google GenAI async clients that share one API-key pool.

    ``image()`` and ``audio()`` are async context managers. Each one:

    - draws the next key from the class-level ``ApiKeyPool`` (round-robin);
    - builds a fresh ``genai.Client`` with that key;
    - yields the client's async surface (``client.aio``).

    On exit the async client is closed when the installed google-genai SDK
    exposes ``aclose()``.
    """

    # One pool for the whole process, created when the class is defined.
    _pool = ApiKeyPool()

    @classmethod
    @asynccontextmanager
    async def _open(cls, **client_kwargs):
        """Yield ``genai.Client(api_key=<next key>, **client_kwargs).aio``.

        Shared implementation for the public factories, so key selection and
        cleanup live in one place.
        """
        key = await cls._pool.get_key()
        # Keep a reference to the sync client for the lifetime of the
        # context; the async surface is derived from it.
        client = genai.Client(api_key=key, **client_kwargs)
        aio = client.aio
        try:
            yield aio
        finally:
            # Release the per-call client's HTTP resources instead of leaking
            # one client per call. aclose() only exists in newer google-genai
            # releases, so look it up rather than calling it unconditionally.
            aclose = getattr(aio, "aclose", None)
            if aclose is not None:
                await aclose()

    @classmethod
    @asynccontextmanager
    async def image(cls):
        """Yield an async GenAI client with no ``http_options`` override."""
        async with cls._open() as aio:
            yield aio

    @classmethod
    @asynccontextmanager
    async def audio(cls):
        """Yield an async GenAI client pinned to the ``v1alpha`` API version."""
        async with cls._open(http_options={"api_version": "v1alpha"}) as aio:
            yield aio