# filename: service.py

from __future__ import annotations

import asyncio
import logging
import threading
from enum import Enum
from typing import List, Union, Dict, Optional, NamedTuple, Any
from dataclasses import dataclass
from pathlib import Path
from io import BytesIO
from hashlib import md5
from cachetools import LRUCache

import httpx
import numpy as np
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoProcessor, AutoModel

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


class TextModelType(str, Enum):
    """
    Enumeration of supported text models.
    """
    MULTILINGUAL_E5_SMALL = "multilingual-e5-small"
    MULTILINGUAL_E5_BASE = "multilingual-e5-base"
    MULTILINGUAL_E5_LARGE = "multilingual-e5-large"
    SNOWFLAKE_ARCTIC_EMBED_L_V2 = "snowflake-arctic-embed-l-v2.0"
    PARAPHRASE_MULTILINGUAL_MINILM_L12_V2 = "paraphrase-multilingual-MiniLM-L12-v2"
    PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2 = "paraphrase-multilingual-mpnet-base-v2"
    BGE_M3 = "bge-m3"
    GTE_MULTILINGUAL_BASE = "gte-multilingual-base"


class ImageModelType(str, Enum):
    """
    Enumeration of supported image models.
    """
    SIGLIP_BASE_PATCH16_256_MULTILINGUAL = "siglip-base-patch16-256-multilingual"


class MaxModelLength(int, Enum):
    """
    Enumeration of maximum token lengths for supported text models.
    """
    MULTILINGUAL_E5_SMALL = 512
    MULTILINGUAL_E5_BASE = 512
    MULTILINGUAL_E5_LARGE = 512
    SNOWFLAKE_ARCTIC_EMBED_L_V2 = 8192
    PARAPHRASE_MULTILINGUAL_MINILM_L12_V2 = 128
    PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2 = 128
    BGE_M3 = 8192
    GTE_MULTILINGUAL_BASE = 8192


class ModelInfo(NamedTuple):
    """
    Container mapping a model type to its model identifier and optional ONNX file.
    """
    model_id: str
    onnx_file: Optional[str] = None


@dataclass
class ModelConfig:
    """
    Configuration for text and image models.
    """
    text_model_type: TextModelType = TextModelType.MULTILINGUAL_E5_SMALL
    image_model_type: ImageModelType = ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL
    logit_scale: float = 4.60517  # exp(4.60517) ≈ 100, a CLIP/SigLIP-style temperature for cross-modal similarity

    @property
    def text_model_info(self) -> ModelInfo:
        """
        Return model information for the configured text model.
        """
        text_configs = {
            TextModelType.MULTILINGUAL_E5_SMALL: ModelInfo(
                model_id="Xenova/multilingual-e5-small",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.MULTILINGUAL_E5_BASE: ModelInfo(
                model_id="Xenova/multilingual-e5-base",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.MULTILINGUAL_E5_LARGE: ModelInfo(
                model_id="Xenova/multilingual-e5-large",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.SNOWFLAKE_ARCTIC_EMBED_L_V2: ModelInfo(
                model_id="Snowflake/snowflake-arctic-embed-l-v2.0",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.PARAPHRASE_MULTILINGUAL_MINILM_L12_V2: ModelInfo(
                model_id="Xenova/paraphrase-multilingual-MiniLM-L12-v2",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2: ModelInfo(
                model_id="Xenova/paraphrase-multilingual-mpnet-base-v2",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.BGE_M3: ModelInfo(
                model_id="Xenova/bge-m3",
                onnx_file="onnx/model_quantized.onnx",
            ),
            TextModelType.GTE_MULTILINGUAL_BASE: ModelInfo(
                model_id="onnx-community/gte-multilingual-base",
                onnx_file="onnx/model_quantized.onnx",
            ),
        }
        return text_configs[self.text_model_type]

    @property
    def image_model_info(self) -> ModelInfo:
        """
        Return model information for the configured image model.
        """
        image_configs = {
            ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL: ModelInfo(
                model_id="google/siglip-base-patch16-256-multilingual"
            ),
        }
        return image_configs[self.image_model_type]


class ModelKind(str, Enum):
    """
    Indicates the type of model: text or image.
    """
    TEXT = "text"
    IMAGE = "image"


def detect_model_kind(model_id: str) -> ModelKind:
    """
    Detect whether the model identifier corresponds to a text or image model.

    Raises:
        ValueError: If the model identifier is unrecognized.
    """
    if model_id in [m.value for m in TextModelType]:
        return ModelKind.TEXT
    elif model_id in [m.value for m in ImageModelType]:
        return ModelKind.IMAGE
    else:
        raise ValueError(
            f"Unrecognized model ID: {model_id}.\n"
            f"Valid text: {[m.value for m in TextModelType]}\n"
            f"Valid image: {[m.value for m in ImageModelType]}"
        )


class EmbeddingsService:
    """
    Service for generating text/image embeddings and performing similarity ranking.
    Asynchronous methods are used to maximize throughput and avoid blocking the event loop.
    """

    def __init__(self, config: Optional[ModelConfig] = None):
        """
        Initialize the service by setting up model caches, device configuration,
        and asynchronous HTTP client.
        """
        self.lru_cache = LRUCache(maxsize=10_000)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.config = config or ModelConfig()

        # Dictionaries to hold preloaded models.
        self.text_models: Dict[TextModelType, SentenceTransformer] = {}
        self.image_models: Dict[ImageModelType, AutoModel] = {}
        self.image_processors: Dict[ImageModelType, AutoProcessor] = {}

        # Create reentrant locks for each text model to ensure thread safety.
        self.text_model_locks: Dict[TextModelType, threading.RLock] = {
            t: threading.RLock() for t in TextModelType
        }

        # Create a persistent asynchronous HTTP client.
        self.async_http_client = httpx.AsyncClient(timeout=10)

        # Preload all models.
        self._load_all_models()

    def _load_all_models(self) -> None:
        """
        Pre-load all text and image models to minimize latency at request time.
        """
        try:
            # Preload text models.
            for t_model_type in TextModelType:
                info = ModelConfig(text_model_type=t_model_type).text_model_info
                logger.info("Loading text model: %s", info.model_id)
                if info.onnx_file:
                    logger.info("Using ONNX file: %s", info.onnx_file)
                    self.text_models[t_model_type] = SentenceTransformer(
                        info.model_id,
                        device=self.device,
                        backend="onnx",
                        model_kwargs={
                            "provider": "CPUExecutionProvider",
                            "file_name": info.onnx_file,
                        },
                        trust_remote_code=True,
                    )
                else:
                    self.text_models[t_model_type] = SentenceTransformer(
                        info.model_id,
                        device=self.device,
                        trust_remote_code=True,
                    )
                # Set maximum sequence length based on configuration.
                max_length = MaxModelLength[t_model_type.name].value
                self.text_models[t_model_type].max_seq_length = max_length
                logger.info("Set max_seq_length=%d for text model: %s", max_length, info.model_id)

            # Preload image models.
            for i_model_type in ImageModelType:
                model_id = ModelConfig(image_model_type=i_model_type).image_model_info.model_id
                logger.info("Loading image model: %s", model_id)
                model = AutoModel.from_pretrained(model_id).to(self.device)
                model.eval()  # Set the model to evaluation mode.
                processor = AutoProcessor.from_pretrained(model_id)
                self.image_models[i_model_type] = model
                self.image_processors[i_model_type] = processor

            logger.info("All models loaded successfully.")
        except Exception as e:
            msg = f"Error loading models: {str(e)}"
            logger.error(msg)
            raise RuntimeError(msg) from e

    @staticmethod
    def _validate_text_list(input_text: Union[str, List[str]]) -> List[str]:
        """
        Validate and convert text input into a non-empty list of strings.

        Raises:
            ValueError: If the input is invalid.
        """
        if isinstance(input_text, str):
            if not input_text.strip():
                raise ValueError("Text input cannot be empty.")
            return [input_text]

        if not isinstance(input_text, list) or not all(isinstance(x, str) for x in input_text):
            raise ValueError("Text input must be a string or a list of strings.")

        if len(input_text) == 0:
            raise ValueError("Text input list cannot be empty.")

        return input_text

    @staticmethod
    def _validate_image_list(input_images: Union[str, List[str]]) -> List[str]:
        """
        Validate and convert image input into a non-empty list of image paths/URLs.

        Raises:
            ValueError: If the input is invalid.
        """
        if isinstance(input_images, str):
            if not input_images.strip():
                raise ValueError("Image input cannot be empty.")
            return [input_images]

        if not isinstance(input_images, list) or not all(isinstance(x, str) for x in input_images):
            raise ValueError("Image input must be a string or a list of strings.")

        if len(input_images) == 0:
            raise ValueError("Image input list cannot be empty.")

        return input_images

    def _truncate_text(self, text: str, model: SentenceTransformer) -> str:
        """
        Truncate the input text to the maximum allowed tokens for the given model.

        Args:
            text: The input text.
            model: The SentenceTransformer model used for tokenization.

        Returns:
            The truncated text if token length exceeds the maximum allowed length,
            otherwise the original text.
        """
        try:
            # Attempt to get the tokenizer from the first module of the SentenceTransformer.
            module = model._first_module()
            if not hasattr(module, 'tokenizer'):
                return text
            tokenizer = module.tokenizer
            # Tokenize without truncation.
            encoded = tokenizer(text, add_special_tokens=True, truncation=False)
            max_length = model.max_seq_length
            if len(encoded['input_ids']) > max_length:
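                # Decoding then re-encoding may not round-trip exactly, but
                # SentenceTransformer enforces max_seq_length again at encode time,
                # so a slight overshoot here is harmless.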
                truncated_ids = encoded['input_ids'][:max_length]
                truncated_text = tokenizer.decode(truncated_ids, skip_special_tokens=True)
                return truncated_text
        except Exception as e:
            logger.warning("Error during text truncation: %s", str(e))
        return text

    async def _fetch_image(self, path_or_url: str) -> Image.Image:
        """
        Asynchronously fetch an image from a URL or load from a local path.

        Args:
            path_or_url: The URL or file path of the image.

        Returns:
            A PIL Image in RGB mode.

        Raises:
            ValueError: If image fetching or processing fails.
        """
        try:
            if path_or_url.startswith(("http://", "https://")):
                # Asynchronously fetch the image bytes.
                response = await self.async_http_client.get(path_or_url)
                response.raise_for_status()
                # Offload the blocking I/O (PIL image opening) to a thread.
                img = await asyncio.to_thread(Image.open, BytesIO(response.content))
            else:
                # Offload file I/O to a thread.
                img = await asyncio.to_thread(Image.open, Path(path_or_url))
            return img.convert("RGB")
        except Exception as e:
            raise ValueError(f"Error fetching image '{path_or_url}': {str(e)}") from e

    async def _process_image(
        self, path_or_url: str, model_type: Optional[ImageModelType] = None
    ) -> Dict[str, torch.Tensor]:
        """
        Asynchronously load and process a single image.

        Args:
            path_or_url: The image URL or local path.
            model_type: The image model whose processor should be used
                (defaults to the configured image model).

        Returns:
            A dictionary of processed tensors ready for model input.

        Raises:
            ValueError: If image processing fails.
        """
        img = await self._fetch_image(path_or_url)
        processor = self.image_processors[model_type or self.config.image_model_type]
        # The processor is CPU-bound; run it in a thread so the event loop stays responsive.
        processed_data = await asyncio.to_thread(
            lambda: processor(images=img, return_tensors="pt").to(self.device)
        )
        return processed_data

    def _generate_text_embeddings(self, model_id: TextModelType, texts: List[str]) -> np.ndarray:
        """
        Generate text embeddings using the SentenceTransformer model.
        Single-text requests are cached using an LRU cache.

        Args:
            model_id: The text model type.
            texts: A list of input texts.

        Returns:
            A NumPy array of text embeddings.

        Raises:
            RuntimeError: If text embedding generation fails.
        """
        try:
            model = self.text_models[model_id]
            lock = self.text_model_locks[model_id]
            with lock:
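                # SentenceTransformer.encode is not documented as thread-safe, so the
                # per-model lock serializes encodes that may run on worker threads.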
                if len(texts) == 1:
                    single_text = texts[0]
                    # Use the full digest; a truncated key would risk cache collisions.
                    key = md5(f"{model_id}:{single_text}".encode("utf-8")).hexdigest()
                    if key in self.lru_cache:
                        return self.lru_cache[key]
                    emb = model.encode([single_text])
                    self.lru_cache[key] = emb
                    return emb

                return model.encode(texts)
        except Exception as e:
            raise RuntimeError(
                f"Error generating text embeddings with model '{model_id}': {e}"
            ) from e

    async def _async_generate_image_embeddings(self, model_id: ImageModelType, images: List[str]) -> np.ndarray:
        """
        Asynchronously generate image embeddings.

        This method concurrently processes multiple images and offloads
        the blocking model inference to a separate thread.

        Args:
            model_id: The image model type.
            images: A list of image URLs or file paths.

        Returns:
            A NumPy array of image embeddings.

        Raises:
            RuntimeError: If image embedding generation fails.
        """
        try:
            # Concurrently process all images.
            processed_tensors = await asyncio.gather(
                *[self._process_image(img_path, model_id) for img_path in images]
            )
            # Assume all processed outputs have the same keys.
            keys = processed_tensors[0].keys()
            combined = {k: torch.cat([pt[k] for pt in processed_tensors], dim=0) for k in keys}

            def infer():
                with torch.no_grad():
                    embeddings = self.image_models[model_id].get_image_features(**combined)
                return embeddings.cpu().numpy()

            return await asyncio.to_thread(infer)
        except Exception as e:
            raise RuntimeError(
                f"Error generating image embeddings with model '{model_id}': {e}"
            ) from e

    async def generate_embeddings(self, model: str, inputs: Union[str, List[str]]) -> np.ndarray:
        """
        Asynchronously generate embeddings for text or image inputs based on model type.

        Args:
            model: The model identifier.
            inputs: The text or image input(s).

        Returns:
            A NumPy array of embeddings.
        """
        modality = detect_model_kind(model)
        if modality == ModelKind.TEXT:
            text_model_enum = TextModelType(model)
            text_list = self._validate_text_list(inputs)
            model_instance = self.text_models[text_model_enum]
            lock = self.text_model_locks[text_model_enum]
            def _truncate_all() -> List[str]:
                # Truncate each text that exceeds the maximum allowed token length.
                # Tokenization can be slow, so keep it off the event loop as well.
                with lock:
                    return [
                        self._truncate_text(text, model_instance) for text in text_list
                    ]

            truncated_texts = await asyncio.to_thread(_truncate_all)
            return await asyncio.to_thread(
                self._generate_text_embeddings, text_model_enum, truncated_texts
            )
        elif modality == ModelKind.IMAGE:
            image_model_enum = ImageModelType(model)
            image_list = self._validate_image_list(inputs)
            return await self._async_generate_image_embeddings(image_model_enum, image_list)

    async def rank(self, model: str, queries: Union[str, List[str]], candidates: Union[str, List[str]]) -> Dict[str, Any]:
        """
        Asynchronously rank candidate texts/images against the provided queries.
        Embeddings for queries and candidates are generated concurrently.

        Args:
            model: The model identifier.
            queries: The query input(s).
            candidates: The candidate input(s).

        Returns:
            A dictionary containing probabilities, cosine similarities, and usage statistics.
        """
        modality = detect_model_kind(model)
        if modality == ModelKind.TEXT:
            model_enum = TextModelType(model)
        else:
            model_enum = ImageModelType(model)

        # Concurrently generate embeddings.
        query_task = asyncio.create_task(self.generate_embeddings(model, queries))
        candidate_task = asyncio.create_task(self.generate_embeddings(model, candidates))
        query_embeds, candidate_embeds = await asyncio.gather(query_task, candidate_task)

        # Compute cosine similarity.
        sim_matrix = self.cosine_similarity(query_embeds, candidate_embeds)
        scaled = np.exp(self.config.logit_scale) * sim_matrix
        probs = self.softmax(scaled)

        if modality == ModelKind.TEXT:
            query_tokens = self.estimate_tokens(queries, model_enum)
            candidate_tokens = self.estimate_tokens(candidates, model_enum)
            total_tokens = query_tokens + candidate_tokens
        else:
            total_tokens = 0

        usage = {
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens,
        }

        return {
            "probabilities": probs.tolist(),
            "cosine_similarities": sim_matrix.tolist(),
            "usage": usage,
        }

    def estimate_tokens(
        self,
        input_data: Union[str, List[str]],
        model_type: Optional[TextModelType] = None,
    ) -> int:
        """
        Estimate the token count for the given text input using the tokenizer of the
        specified text model (defaults to the configured text model).

        Args:
            input_data: The text input(s).
            model_type: The text model whose tokenizer should be used.

        Returns:
            The total number of tokens.
        """
        texts = self._validate_text_list(input_data)
        model = self.text_models[model_type or self.config.text_model_type]
        tokenized = model.tokenize(texts)
        # Batched input_ids are padded to a common length; count only real tokens.
        if "attention_mask" in tokenized:
            return int(tokenized["attention_mask"].sum().item())
        return sum(len(ids) for ids in tokenized["input_ids"])

    @staticmethod
    def softmax(scores: np.ndarray) -> np.ndarray:
        """
        Compute the softmax over the last dimension of the input array.

        Args:
            scores: A NumPy array of scores.

        Returns:
            A NumPy array of softmax probabilities.
        """
        exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
        return exps / np.sum(exps, axis=-1, keepdims=True)

    @staticmethod
    def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """
        Compute the pairwise cosine similarity between all rows of arrays a and b.

        Args:
            a: A NumPy array.
            b: A NumPy array.

        Returns:
            A (N x M) matrix of cosine similarities.
        """
        a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-9)
        b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-9)
        return np.dot(a_norm, b_norm.T)

    async def close(self) -> None:
        """
        Close the asynchronous HTTP client.
        """
        await self.async_http_client.aclose()
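

if __name__ == "__main__":
    # Minimal usage sketch, not part of the service API: it assumes the model
    # weights above can be downloaded, and the query/passage strings are
    # illustrative placeholders following the multilingual-e5 prefix convention.

    async def _demo() -> None:
        service = EmbeddingsService()
        try:
            # Embed a single text with the default multilingual-e5-small model.
            embeds = await service.generate_embeddings(
                "multilingual-e5-small", ["query: what is a transformer?"]
            )
            print("text embedding shape:", embeds.shape)

            # Rank two candidate passages against the same query.
            result = await service.rank(
                "multilingual-e5-small",
                "query: what is a transformer?",
                [
                    "passage: A transformer is a neural network architecture.",
                    "passage: A transformer steps voltage up or down.",
                ],
            )
            print("probabilities:", result["probabilities"])
        finally:
            await service.close()

    asyncio.run(_demo())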