File size: 16,076 Bytes
e8f9d10
 
 
 
ddd02b3
e8f9d10
 
 
aaf7e4c
 
e8f9d10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65c747d
 
 
e8f9d10
65c747d
e8f9d10
de24ee4
e8f9d10
 
65c747d
 
 
 
 
 
 
 
e8f9d10
 
b6efbf5
65c747d
 
e8f9d10
 
 
65c747d
e8f9d10
 
 
 
 
65c747d
e8f9d10
 
 
65c747d
 
 
b6efbf5
e8f9d10
 
 
 
b6efbf5
e8f9d10
65c747d
e8f9d10
65c747d
 
 
 
 
 
 
 
 
 
 
 
 
 
e8f9d10
 
24dd113
 
65c747d
 
24dd113
 
e8f9d10
 
de24ee4
 
 
 
 
 
e8f9d10
 
65c747d
 
 
 
 
b6efbf5
65c747d
 
 
 
 
 
 
e8f9d10
 
ddd02b3
 
 
 
 
 
 
b6efbf5
 
ddd02b3
 
 
 
 
 
 
 
 
 
 
 
 
e8f9d10
 
b6efbf5
 
e8f9d10
 
65c747d
b6efbf5
e8f9d10
 
 
b6efbf5
e8f9d10
65c747d
 
e8f9d10
b6efbf5
65c747d
e8f9d10
65c747d
e8f9d10
65c747d
e8f9d10
 
b6efbf5
65c747d
 
 
 
 
 
 
 
 
b6efbf5
65c747d
b6efbf5
65c747d
 
f9ce04f
65c747d
 
 
073aa83
 
 
65c747d
e8f9d10
b6efbf5
65c747d
 
 
 
 
 
 
 
 
 
 
 
 
e8f9d10
65c747d
 
 
e8f9d10
 
b6efbf5
e8f9d10
b6efbf5
 
e8f9d10
 
65c747d
 
e8f9d10
65c747d
e8f9d10
 
 
65c747d
 
 
e8f9d10
65c747d
e8f9d10
 
 
b6efbf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8f9d10
b6efbf5
e8f9d10
b6efbf5
 
e8f9d10
 
b6efbf5
65c747d
 
 
b6efbf5
 
65c747d
b6efbf5
 
 
e8f9d10
65c747d
e8f9d10
ddd02b3
 
 
 
 
e8f9d10
b6efbf5
 
e8f9d10
 
aaf7e4c
b6efbf5
 
aaf7e4c
 
b6efbf5
 
 
 
 
 
 
ddd02b3
b6efbf5
aaf7e4c
e8f9d10
65c747d
b6efbf5
65c747d
e8f9d10
 
65c747d
ddd02b3
b6efbf5
e8f9d10
 
b6efbf5
 
e8f9d10
 
ddd02b3
b6efbf5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e8f9d10
 
65c747d
b6efbf5
65c747d
e8f9d10
 
 
ddd02b3
 
e8f9d10
 
b6efbf5
e8f9d10
ddd02b3
 
b6efbf5
 
 
 
 
 
 
 
 
e8f9d10
 
 
ddd02b3
e8f9d10
b6efbf5
073aa83
e8f9d10
b6efbf5
 
 
 
 
e8f9d10
ddd02b3
b6efbf5
 
 
 
 
 
073aa83
65c747d
b6efbf5
 
 
 
 
65c747d
073aa83
65c747d
073aa83
b6efbf5
65c747d
 
e8f9d10
b6efbf5
 
 
 
 
 
 
 
073aa83
 
 
 
 
e8f9d10
65c747d
 
073aa83
e8f9d10
 
 
073aa83
b6efbf5
 
073aa83
b6efbf5
073aa83
 
b6efbf5
073aa83
e8f9d10
 
 
 
b6efbf5
e8f9d10
b6efbf5
65c747d
 
e8f9d10
 
65c747d
e8f9d10
b6efbf5
65c747d
 
b6efbf5
e8f9d10
65c747d
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
from __future__ import annotations

import logging
from enum import Enum
from typing import List, Union, Dict, Optional, NamedTuple, Any
from dataclasses import dataclass
from pathlib import Path
from io import BytesIO
from hashlib import md5
from cachetools import LRUCache

import requests
import numpy as np
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoProcessor, AutoModel

# NOTE(review): configuring the root logger at import time affects every
# importer of this module; consider moving this call to the application entry point.
logging.basicConfig(level=logging.INFO)
# Logger used by this module.
logger = logging.getLogger(__name__)


class TextModelType(str, Enum):
    """
    Names of the text embedding models this module knows how to load.

    Members are ``str`` subclasses, so each compares equal to its value.
    """

    # E5 family
    MULTILINGUAL_E5_SMALL = "multilingual-e5-small"
    MULTILINGUAL_E5_BASE = "multilingual-e5-base"
    MULTILINGUAL_E5_LARGE = "multilingual-e5-large"
    # Snowflake Arctic
    SNOWFLAKE_ARCTIC_EMBED_L_V2 = "snowflake-arctic-embed-l-v2.0"
    # Paraphrase models
    PARAPHRASE_MULTILINGUAL_MINILM_L12_V2 = "paraphrase-multilingual-MiniLM-L12-v2"
    PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2 = "paraphrase-multilingual-mpnet-base-v2"
    # Other multilingual models
    BGE_M3 = "bge-m3"
    GTE_MULTILINGUAL_BASE = "gte-multilingual-base"


class ImageModelType(str, Enum):
    """
    Names of the image embedding models this module knows how to load.

    Members are ``str`` subclasses, so each compares equal to its value.
    """

    # SigLIP family
    SIGLIP_BASE_PATCH16_256_MULTILINGUAL = "siglip-base-patch16-256-multilingual"


class ModelInfo(NamedTuple):
    """
    Immutable record describing where a model is loaded from.

    Fields:
      - model_id: Hugging Face model ID (or a local path).
      - onnx_file: path of an ONNX export inside the model repo, or None when
        the model should be loaded without the ONNX backend.
    """

    model_id: str
    onnx_file: Optional[str] = None


@dataclass
class ModelConfig:
    """
    Selection of the text and image models a service instance works with.

    Attributes:
        text_model_type: text model selected by this configuration.
        image_model_type: image model selected by this configuration.
        logit_scale: natural log of the multiplier applied to cosine
            similarities before softmax (4.60517 is approximately ln(100)).
    """

    text_model_type: TextModelType = TextModelType.MULTILINGUAL_E5_SMALL
    image_model_type: ImageModelType = (
        ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL
    )
    logit_scale: float = 4.60517  # Example scale used in cross-modal similarity

    # The attributes below carry no annotation, so @dataclass treats them as
    # plain class attributes rather than fields (they do not appear in __init__).
    _QUANTIZED_ONNX = "onnx/model_quantized.onnx"

    _TEXT_MODEL_INFOS = {
        TextModelType.MULTILINGUAL_E5_SMALL: ModelInfo(
            model_id="Xenova/multilingual-e5-small",
            onnx_file=_QUANTIZED_ONNX,
        ),
        TextModelType.MULTILINGUAL_E5_BASE: ModelInfo(
            model_id="Xenova/multilingual-e5-base",
            onnx_file=_QUANTIZED_ONNX,
        ),
        TextModelType.MULTILINGUAL_E5_LARGE: ModelInfo(
            model_id="Xenova/multilingual-e5-large",
            onnx_file=_QUANTIZED_ONNX,
        ),
        TextModelType.SNOWFLAKE_ARCTIC_EMBED_L_V2: ModelInfo(
            model_id="Snowflake/snowflake-arctic-embed-l-v2.0",
            onnx_file=_QUANTIZED_ONNX,
        ),
        TextModelType.PARAPHRASE_MULTILINGUAL_MINILM_L12_V2: ModelInfo(
            model_id="Xenova/paraphrase-multilingual-MiniLM-L12-v2",
            onnx_file=_QUANTIZED_ONNX,
        ),
        TextModelType.PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2: ModelInfo(
            model_id="Xenova/paraphrase-multilingual-mpnet-base-v2",
            onnx_file=_QUANTIZED_ONNX,
        ),
        TextModelType.BGE_M3: ModelInfo(
            model_id="Xenova/bge-m3",
            onnx_file=_QUANTIZED_ONNX,
        ),
        TextModelType.GTE_MULTILINGUAL_BASE: ModelInfo(
            model_id="onnx-community/gte-multilingual-base",
            onnx_file=_QUANTIZED_ONNX,
        ),
    }

    _IMAGE_MODEL_INFOS = {
        ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL: ModelInfo(
            model_id="google/siglip-base-patch16-256-multilingual"
        ),
    }

    @property
    def text_model_info(self) -> ModelInfo:
        """
        Look up the ModelInfo for ``text_model_type``.

        Raises KeyError if the type has no registered entry.
        """
        return self._TEXT_MODEL_INFOS[self.text_model_type]

    @property
    def image_model_info(self) -> ModelInfo:
        """
        Look up the ModelInfo for ``image_model_type``.

        Raises KeyError if the type has no registered entry.
        """
        return self._IMAGE_MODEL_INFOS[self.image_model_type]


class ModelKind(str, Enum):
    """Modality of a model: text or image."""

    TEXT = "text"
    IMAGE = "image"


def detect_model_kind(model_id: str) -> ModelKind:
    """
    Classify a model name as belonging to a text model or an image model.

    Text model names are checked first, then image model names.
    Raises ValueError (listing every accepted name) when model_id matches neither.
    """
    text_names = [member.value for member in TextModelType]
    image_names = [member.value for member in ImageModelType]

    if model_id in text_names:
        return ModelKind.TEXT
    if model_id in image_names:
        return ModelKind.IMAGE

    raise ValueError(
        f"Unrecognized model ID: {model_id}.\n"
        f"Valid text: {text_names}\n"
        f"Valid image: {image_names}"
    )


class EmbeddingsService:
    """
    Service for generating text/image embeddings and performing similarity ranking.

    Every known text and image model is loaded once in ``__init__`` and kept in
    memory, so callers can pick a model by name per request. Public methods
    accept either a single string or a list of strings.

    NOTE(review): the ``async`` methods run model inference synchronously, so
    they block the event loop for the duration of a call.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        # Bounded cache of single-text embeddings, keyed by (model, text digest).
        self.lru_cache = LRUCache(maxsize=10_000)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.config = config or ModelConfig()

        # Dictionaries to hold preloaded models, keyed by their enum member.
        self.text_models: Dict[TextModelType, SentenceTransformer] = {}
        self.image_models: Dict[ImageModelType, AutoModel] = {}
        self.image_processors: Dict[ImageModelType, AutoProcessor] = {}

        # Load all relevant models on init
        self._load_all_models()

    def _load_all_models(self) -> None:
        """
        Pre-load all known text and image models for quick switching.

        Raises:
            RuntimeError: wrapping whatever exception occurred while loading.
        """
        try:
            # Preload text models
            for t_model_type in TextModelType:
                info = ModelConfig(text_model_type=t_model_type).text_model_info
                logger.info("Loading text model: %s", info.model_id)

                if info.onnx_file:
                    logger.info("Using ONNX file: %s", info.onnx_file)
                    # The ONNX execution provider is pinned to CPU here,
                    # independent of self.device.
                    self.text_models[t_model_type] = SentenceTransformer(
                        info.model_id,
                        device=self.device,
                        backend="onnx",
                        model_kwargs={
                            "provider": "CPUExecutionProvider",
                            "file_name": info.onnx_file,
                        },
                        trust_remote_code=True,
                    )
                else:
                    self.text_models[t_model_type] = SentenceTransformer(
                        info.model_id,
                        device=self.device,
                        trust_remote_code=True,
                    )

            # Preload image models together with their matching processors
            for i_model_type in ImageModelType:
                model_id = ModelConfig(
                    image_model_type=i_model_type
                ).image_model_info.model_id
                logger.info("Loading image model: %s", model_id)

                model = AutoModel.from_pretrained(model_id).to(self.device)
                processor = AutoProcessor.from_pretrained(model_id)

                self.image_models[i_model_type] = model
                self.image_processors[i_model_type] = processor

            logger.info("All models loaded successfully.")
        except Exception as e:
            msg = f"Error loading models: {str(e)}"
            logger.error(msg)
            raise RuntimeError(msg) from e

    @staticmethod
    def _validate_str_list(value: Union[str, List[str]], label: str) -> List[str]:
        """
        Normalize ``value`` into a non-empty list of strings.

        ``label`` ("Text" or "Image") prefixes every error message.
        A list input is returned as-is (not copied).

        Raises:
            ValueError: for a blank string, a non-list/non-str value,
                a list containing non-strings, or an empty list.
        """
        if isinstance(value, str):
            if not value.strip():
                raise ValueError(f"{label} input cannot be empty.")
            return [value]

        if not isinstance(value, list) or not all(isinstance(x, str) for x in value):
            raise ValueError(f"{label} input must be a string or a list of strings.")

        if not value:
            raise ValueError(f"{label} input list cannot be empty.")

        return value

    @staticmethod
    def _validate_text_list(input_text: Union[str, List[str]]) -> List[str]:
        """
        Convert text input into a non-empty list of strings.
        Raises ValueError if the input is invalid.
        """
        return EmbeddingsService._validate_str_list(input_text, "Text")

    @staticmethod
    def _validate_image_list(input_images: Union[str, List[str]]) -> List[str]:
        """
        Convert image input into a non-empty list of image paths/URLs.
        Raises ValueError if the input is invalid.
        """
        return EmbeddingsService._validate_str_list(input_images, "Image")

    def _process_image(
        self,
        path_or_url: str,
        model_type: Optional[ImageModelType] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Load one image from a local path or an http(s) URL and run it through
        the image processor of ``model_type`` (default: the configured image model).

        Returns the processor output (tensors moved to ``self.device``).

        Raises:
            ValueError: wrapping any download, decoding or processing error.
        """
        processor_key = (
            self.config.image_model_type if model_type is None else model_type
        )
        try:
            if path_or_url.startswith(("http://", "https://")):
                resp = requests.get(path_or_url, timeout=10)
                resp.raise_for_status()
                source = BytesIO(resp.content)
            else:
                source = Path(path_or_url)

            # Context manager closes the underlying file; convert() returns a
            # new, independent image that stays usable afterwards.
            with Image.open(source) as img:
                rgb_image = img.convert("RGB")

            processor = self.image_processors[processor_key]
            return processor(images=rgb_image, return_tensors="pt").to(self.device)
        except Exception as e:
            raise ValueError(f"Error processing image '{path_or_url}': {str(e)}") from e

    def _generate_text_embeddings(
        self,
        model_id: TextModelType,
        texts: List[str],
    ) -> np.ndarray:
        """
        Encode ``texts`` with the preloaded SentenceTransformer ``model_id``.

        Single-text requests go through the LRU cache. The key includes the
        model, so different models never share cached vectors. A copy is
        returned so callers cannot mutate the cached array. Multi-text requests
        are encoded directly without caching.

        Raises:
            RuntimeError: wrapping any lookup or encoding error.
        """
        try:
            model = self.text_models[model_id]

            if len(texts) != 1:
                # For multiple texts, no LRU cache is used
                return model.encode(texts)

            single_text = texts[0]
            key = (model_id, md5(single_text.encode("utf-8")).hexdigest())
            if key in self.lru_cache:
                return np.copy(self.lru_cache[key])

            emb = model.encode([single_text])
            self.lru_cache[key] = emb
            return np.copy(emb)

        except Exception as e:
            raise RuntimeError(
                f"Error generating text embeddings with model '{model_id}': {e}"
            ) from e

    def _generate_image_embeddings(
        self,
        model_id: ImageModelType,
        images: List[str],
    ) -> np.ndarray:
        """
        Encode ``images`` (paths or URLs) with the preloaded image model ``model_id``.

        Each image is preprocessed with the processor belonging to the same
        model; the results are concatenated along dim 0 into one batch and
        passed to ``get_image_features`` without gradient tracking.

        Raises:
            RuntimeError: wrapping any loading, processing or inference error.
        """
        try:
            model = self.image_models[model_id]
            processed_tensors = [
                self._process_image(img_path, model_id) for img_path in images
            ]

            # Keys should be the same for all processed outputs
            keys = processed_tensors[0].keys()
            # Concatenate along the batch dimension
            combined = {
                k: torch.cat([pt[k] for pt in processed_tensors], dim=0) for k in keys
            }

            with torch.no_grad():
                embeddings = model.get_image_features(**combined)
            return embeddings.cpu().numpy()

        except Exception as e:
            raise RuntimeError(
                f"Error generating image embeddings with model '{model_id}': {e}"
            ) from e

    async def generate_embeddings(
        self,
        model: str,
        inputs: Union[str, List[str]],
    ) -> np.ndarray:
        """
        Generate embeddings for ``inputs`` with the model named ``model``.

        Text model names embed ``inputs`` as texts; image model names treat
        ``inputs`` as image paths/URLs.

        Raises:
            ValueError: unknown model name or invalid inputs.
            RuntimeError: failure inside the model.
        """
        modality = detect_model_kind(model)

        if modality == ModelKind.TEXT:
            text_model_id = TextModelType(model)
            text_list = self._validate_text_list(inputs)
            return self._generate_text_embeddings(text_model_id, text_list)

        elif modality == ModelKind.IMAGE:
            image_model_id = ImageModelType(model)
            image_list = self._validate_image_list(inputs)
            return self._generate_image_embeddings(image_model_id, image_list)

    async def rank(
        self,
        model: str,
        queries: Union[str, List[str]],
        candidates: Union[str, List[str]],
    ) -> Dict[str, Any]:
        """
        Rank ``candidates`` against ``queries`` using the model named ``model``.

        Queries and candidates are embedded with the same model. For a text
        model both are texts; for an image model both are treated as image
        paths/URLs. For true cross-modal ranking you would need separate
        text/image encoders or a shared model.

        Returns a dict with:
            probabilities: per query, softmax over candidates of
                exp(config.logit_scale) * cosine similarity.
            cosine_similarities: (num_queries, num_candidates) nested list.
            usage: prompt/total token counts (non-padding tokens, counted with
                the same text model); 0 for image models.
        """
        modality = detect_model_kind(model)

        # Convert the string model to the appropriate enum
        if modality == ModelKind.TEXT:
            model_enum = TextModelType(model)
        else:
            model_enum = ImageModelType(model)

        # 1) Generate embeddings for queries
        query_embeds = await self.generate_embeddings(model_enum.value, queries)

        # 2) Generate embeddings for candidates with the same model/modality.
        candidate_embeds = await self.generate_embeddings(model_enum.value, candidates)

        # 3) Compute cosine similarity
        sim_matrix = self.cosine_similarity(query_embeds, candidate_embeds)

        # 4) Apply logit scale + softmax to obtain probabilities
        scaled = np.exp(self.config.logit_scale) * sim_matrix
        probs = self.softmax(scaled)

        # 5) Count tokens with the tokenizer of the model actually used.
        if modality == ModelKind.TEXT:
            query_tokens = self.estimate_tokens(queries, model_enum)
            candidate_tokens = self.estimate_tokens(candidates, model_enum)
            total_tokens = query_tokens + candidate_tokens
        else:
            total_tokens = 0

        usage = {
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens,
        }

        return {
            "probabilities": probs.tolist(),
            "cosine_similarities": sim_matrix.tolist(),
            "usage": usage,
        }

    def estimate_tokens(
        self,
        input_data: Union[str, List[str]],
        model: Optional[Union[str, TextModelType]] = None,
    ) -> int:
        """
        Count tokens in ``input_data`` using a text model's tokenizer.

        Args:
            input_data: a string or list of strings.
            model: text model whose tokenizer to use; defaults to the
                configured ``text_model_type``.

        Returns:
            Number of real (non-padding) tokens. The count comes from the
            attention mask when the tokenizer returns one; otherwise the
            ``input_ids`` row lengths are summed.
        """
        texts = self._validate_text_list(input_data)
        model_type = (
            self.config.text_model_type if model is None else TextModelType(model)
        )
        tokenized = self.text_models[model_type].tokenize(texts)

        # Batched tokenization pads every row to the longest one, so plain
        # row lengths would over-count; the attention mask marks real tokens.
        attention_mask = tokenized.get("attention_mask")
        if attention_mask is not None:
            return int(attention_mask.sum())
        return sum(len(ids) for ids in tokenized["input_ids"])

    @staticmethod
    def softmax(scores: np.ndarray) -> np.ndarray:
        """
        Applies the standard softmax function along the last dimension.
        """
        # Stabilize scores by subtracting max
        exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
        return exps / np.sum(exps, axis=-1, keepdims=True)

    @staticmethod
    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """
        Computes the pairwise cosine similarity between all rows of a and b.
        a: (N, D)
        b: (M, D)
        Return: (N, M) matrix of cosine similarities
        """
        # Small epsilon avoids division by zero for all-zero rows.
        a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-9)
        b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-9)
        return np.dot(a_norm, b_norm.T)