lamhieu committed
Commit b6efbf5 · 1 Parent(s): 7234381

chore: update something

lightweight_embeddings/__init__.py CHANGED
@@ -1,24 +1,3 @@
-# filename: __init__.py
-
-"""
-LightweightEmbeddings - FastAPI Application Entry Point
-
-This application provides text and image embeddings using multiple text models and one image model.
-
-Supported text model IDs:
-- "multilingual-e5-small"
-- "multilingual-e5-base"
-- "multilingual-e5-large"
-- "snowflake-arctic-embed-l-v2.0"
-- "paraphrase-multilingual-MiniLM-L12-v2"
-- "paraphrase-multilingual-mpnet-base-v2"
-- "bge-m3"
-- "gte-multilingual-base"
-
-Supported image model ID:
-- "siglip-base-patch16-256-multilingual"
-"""
-
 import gradio as gr
 import requests
 import json
lightweight_embeddings/router.py CHANGED
@@ -1,22 +1,3 @@
-"""
-FastAPI Router for Embeddings Service (Revised & Simplified)
-
-Exposes the EmbeddingsService methods via a RESTful API.
-
-Supported Text Model IDs:
-- "multilingual-e5-small"
-- "multilingual-e5-base"
-- "multilingual-e5-large"
-- "snowflake-arctic-embed-l-v2.0"
-- "paraphrase-multilingual-MiniLM-L12-v2"
-- "paraphrase-multilingual-mpnet-base-v2"
-- "bge-m3"
-- "gte-multilingual-base"
-
-Supported Image Model IDs:
-- "siglip-base-patch16-256-multilingual"
-"""
-
 from __future__ import annotations
 
 import logging
@@ -158,10 +139,6 @@ async def create_embeddings(
             },
         }
 
-        background_tasks.add_task(
-            analytics.access, request.model, resp["usage"]["total_tokens"]
-        )
-
        for idx, emb in enumerate(embeddings):
            resp["data"].append(
                {
@@ -171,6 +148,10 @@ async def create_embeddings(
                }
            )
 
+        background_tasks.add_task(
+            analytics.access, request.model, resp["usage"]["total_tokens"]
+        )
+
         return resp
 
     except Exception as e:
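
Note: FastAPI's BackgroundTasks run only after the response has been sent, so moving the analytics call below the data loop changes bookkeeping order without adding client latency. A minimal, self-contained sketch of the pattern (log_usage is a hypothetical stand-in for analytics.access):

    from fastapi import BackgroundTasks, FastAPI

    app = FastAPI()

    def log_usage(model: str, total_tokens: int) -> None:
        # Hypothetical stand-in for analytics.access; runs after the response is sent
        print(f"usage: model={model} tokens={total_tokens}")

    @app.post("/v1/embeddings")
    async def create_embeddings(background_tasks: BackgroundTasks):
        resp = {"model": "multilingual-e5-small", "usage": {"total_tokens": 3}, "data": []}
        background_tasks.add_task(log_usage, resp["model"], resp["usage"]["total_tokens"])
        return resp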
lightweight_embeddings/service.py CHANGED
@@ -1,29 +1,3 @@
-"""
-Lightweight Embeddings Service Module (Revised & Simplified)
-
-This module provides a service for generating and comparing embeddings from text and images
-using state-of-the-art transformer models. It supports both CPU and GPU inference.
-
-Features:
-- Text and image embedding generation
-- Cross-modal similarity ranking
-- Batch processing support
-- Asynchronous API support
-
-Supported Text Model IDs:
-- "multilingual-e5-small"
-- "multilingual-e5-base"
-- "multilingual-e5-large"
-- "snowflake-arctic-embed-l-v2.0"
-- "paraphrase-multilingual-MiniLM-L12-v2"
-- "paraphrase-multilingual-mpnet-base-v2"
-- "bge-m3"
-- "gte-multilingual-base"
-
-Supported Image Model IDs:
-- "google/siglip-base-patch16-256-multilingual" (default, but extensible)
-"""
-
 from __future__ import annotations
 
 import logging
@@ -49,7 +23,6 @@ logging.basicConfig(level=logging.INFO)
 class TextModelType(str, Enum):
     """
     Enumeration of supported text models.
-    Adjust as needed for your environment.
     """
 
     MULTILINGUAL_E5_SMALL = "multilingual-e5-small"
@@ -72,7 +45,7 @@ class ImageModelType(str, Enum):
 
 class ModelInfo(NamedTuple):
     """
-    Simple container that maps an enum to:
+    This container maps an enum to:
     - model_id: Hugging Face model ID (or local path)
     - onnx_file: Path to ONNX file (if available)
     """
@@ -91,14 +64,12 @@ class ModelConfig:
     image_model_type: ImageModelType = (
         ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL
     )
-
-    # If you need extra parameters like `logit_scale`, etc., keep them here
-    logit_scale: float = 4.60517
+    logit_scale: float = 4.60517  # Example scale used in cross-modal similarity
 
     @property
     def text_model_info(self) -> ModelInfo:
         """
-        Return ModelInfo for the configured text_model_type.
+        Returns ModelInfo for the configured text_model_type.
         """
         text_configs = {
             TextModelType.MULTILINGUAL_E5_SMALL: ModelInfo(
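
Note: the retained constant is ln(100), so np.exp(config.logit_scale) in rank() below scales cosine similarities by roughly 100 before the softmax, the temperature commonly used by CLIP-style models:

    import math
    print(math.exp(4.60517))  # ≈ 100.0, i.e. logit_scale = ln(100)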
@@ -139,7 +110,7 @@ class ModelConfig:
     @property
     def image_model_info(self) -> ModelInfo:
         """
-        Return ModelInfo for the configured image_model_type.
+        Returns ModelInfo for the configured image_model_type.
         """
         image_configs = {
             ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL: ModelInfo(
@@ -156,8 +127,8 @@ class ModelKind(str, Enum):
 
 def detect_model_kind(model_id: str) -> ModelKind:
     """
-    Detect whether model_id is for a text or an image model.
-    Raises ValueError if unrecognized.
+    Detect whether model_id belongs to a text or an image model.
+    Raises ValueError if the model is not recognized.
     """
     if model_id in [m.value for m in TextModelType]:
         return ModelKind.TEXT
@@ -173,21 +144,21 @@ def detect_model_kind(model_id: str) -> ModelKind:
 
 class EmbeddingsService:
     """
-    Service for generating text/image embeddings and performing ranking.
+    Service for generating text/image embeddings and performing similarity ranking.
+    Batch size has been removed. Single or multiple inputs are handled uniformly.
     """
 
     def __init__(self, config: Optional[ModelConfig] = None):
-        self.lru_cache = LRUCache(maxsize=10_000)  # Approximate for ~100MB usage
-
+        self.lru_cache = LRUCache(maxsize=10_000)
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.config = config or ModelConfig()
 
-        # Preloaded text & image models
+        # Dictionaries to hold preloaded models
         self.text_models: Dict[TextModelType, SentenceTransformer] = {}
         self.image_models: Dict[ImageModelType, AutoModel] = {}
         self.image_processors: Dict[ImageModelType, AutoProcessor] = {}
 
-        # Load all models
+        # Load all relevant models on init
         self._load_all_models()
 
     def _load_all_models(self) -> None:
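
Note: a minimal sketch of the single-text cache wired up here, assuming LRUCache is cachetools.LRUCache (consistent with the maxsize-only constructor) and md5-digest keys as used in _generate_text_embeddings below:

    from hashlib import md5
    from cachetools import LRUCache

    cache = LRUCache(maxsize=10_000)
    key = md5("hello world".encode("utf-8")).hexdigest()
    cache[key] = [[0.1, 0.2, 0.3]]  # stand-in for a computed embedding
    print(key in cache)             # True: the same text hits the cache next time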
@@ -195,40 +166,37 @@ class EmbeddingsService:
         Pre-load all known text and image models for quick switching.
         """
         try:
+            # Preload text models
             for t_model_type in TextModelType:
                 info = ModelConfig(text_model_type=t_model_type).text_model_info
                 logger.info("Loading text model: %s", info.model_id)
 
-                # If you have an ONNX file AND your SentenceTransformer supports ONNX
                 if info.onnx_file:
                     logger.info("Using ONNX file: %s", info.onnx_file)
-                    # The following 'backend' & 'model_kwargs' parameters
-                    # are recognized only in special/certain distributions of SentenceTransformer
                     self.text_models[t_model_type] = SentenceTransformer(
                         info.model_id,
                         device=self.device,
-                        backend="onnx",  # or "ort" in some custom forks
+                        backend="onnx",
                         model_kwargs={
-                            "provider": "CPUExecutionProvider",  # or "CUDAExecutionProvider"
+                            "provider": "CPUExecutionProvider",
                             "file_name": info.onnx_file,
                         },
                         trust_remote_code=True,
                     )
                 else:
-                    # Fallback: standard HF loading
                     self.text_models[t_model_type] = SentenceTransformer(
                         info.model_id,
                         device=self.device,
                         trust_remote_code=True,
                     )
 
+            # Preload image models
             for i_model_type in ImageModelType:
                 model_id = ModelConfig(
                     image_model_type=i_model_type
                 ).image_model_info.model_id
                 logger.info("Loading image model: %s", model_id)
 
-                # Typically, for CLIP-like models:
                 model = AutoModel.from_pretrained(model_id).to(self.device)
                 processor = AutoProcessor.from_pretrained(model_id)
 
@@ -242,9 +210,10 @@ class EmbeddingsService:
             raise RuntimeError(msg) from e
 
     @staticmethod
-    def _validate_text_input(input_text: Union[str, List[str]]) -> List[str]:
+    def _validate_text_list(input_text: Union[str, List[str]]) -> List[str]:
         """
-        Ensure input_text is a non-empty string or list of strings.
+        Convert text input into a non-empty list of strings.
+        Raises ValueError if the input is invalid.
         """
         if isinstance(input_text, str):
             if not input_text.strip():
@@ -262,27 +231,42 @@ class EmbeddingsService:
         return input_text
 
     @staticmethod
-    def _validate_modality(modality: str) -> None:
-        if modality not in ("text", "image"):
-            raise ValueError("Unsupported modality. Must be 'text' or 'image'.")
+    def _validate_image_list(input_images: Union[str, List[str]]) -> List[str]:
+        """
+        Convert image input into a non-empty list of image paths/URLs.
+        Raises ValueError if the input is invalid.
+        """
+        if isinstance(input_images, str):
+            if not input_images.strip():
+                raise ValueError("Image input cannot be empty.")
+            return [input_images]
+
+        if not isinstance(input_images, list) or not all(
+            isinstance(x, str) for x in input_images
+        ):
+            raise ValueError("Image input must be a string or a list of strings.")
+
+        if len(input_images) == 0:
+            raise ValueError("Image input list cannot be empty.")
+
+        return input_images
 
-    def _process_image(self, path_or_url: Union[str, Path]) -> torch.Tensor:
+    def _process_image(self, path_or_url: str) -> Dict[str, torch.Tensor]:
         """
-        Download/Load image from path/URL and apply transformations.
+        Loads and processes a single image from local path or URL.
+        Returns a dictionary of tensors ready for the model.
         """
         try:
-            if isinstance(path_or_url, Path) or not path_or_url.startswith("http"):
-                # Local file path
-                img = Image.open(path_or_url).convert("RGB")
-            else:
-                # URL
+            if path_or_url.startswith("http"):
                 resp = requests.get(path_or_url, timeout=10)
                 resp.raise_for_status()
                 img = Image.open(BytesIO(resp.content)).convert("RGB")
+            else:
+                img = Image.open(Path(path_or_url)).convert("RGB")
 
-            proc = self.image_processors[self.config.image_model_type]
-            data = proc(images=img, return_tensors="pt").to(self.device)
-            return data
+            processor = self.image_processors[self.config.image_model_type]
+            processed_data = processor(images=img, return_tensors="pt").to(self.device)
+            return processed_data
         except Exception as e:
             raise ValueError(f"Error processing image '{path_or_url}': {str(e)}") from e
 
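Note: _process_image above reduces to standard transformers usage; a minimal sketch, assuming a hypothetical local file photo.jpg and the default SigLIP checkpoint (get_image_features is called on the result in the next hunk):

    import torch
    from PIL import Image
    from transformers import AutoModel, AutoProcessor

    model_id = "google/siglip-base-patch16-256-multilingual"
    model = AutoModel.from_pretrained(model_id)
    processor = AutoProcessor.from_pretrained(model_id)

    inputs = processor(images=Image.open("photo.jpg").convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        feats = model.get_image_features(**inputs)  # shape: (1, hidden_dim)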
@@ -292,145 +276,125 @@ class EmbeddingsService:
         texts: List[str],
     ) -> np.ndarray:
         """
-        Generate text embeddings using the currently configured text model
-        with an LRU cache for single-text requests.
+        Generates text embeddings using the SentenceTransformer-based model.
+        Utilizes an LRU cache for single-input scenarios.
         """
         try:
             if len(texts) == 1:
-                key = md5(texts[0].encode("utf-8")).hexdigest()
+                single_text = texts[0]
+                key = md5(single_text.encode("utf-8")).hexdigest()
                 if key in self.lru_cache:
                     return self.lru_cache[key]
+
+                model = self.text_models[model_id]
+                emb = model.encode([single_text])
+                self.lru_cache[key] = emb
+                return emb
+
+            # For multiple texts, no LRU cache is used
             model = self.text_models[model_id]
-            embeddings = model.encode(texts)
+            return model.encode(texts)
 
-            if len(texts) == 1:
-                self.lru_cache[key] = embeddings
-            return embeddings
         except Exception as e:
             raise RuntimeError(
-                f"Error generating text embeddings for model '{self.config.text_model_type}': {e}"
+                f"Error generating text embeddings with model '{model_id}': {e}"
             ) from e
 
     def _generate_image_embeddings(
         self,
         model_id: ImageModelType,
-        images: Union[str, List[str]],
-        batch_size: Optional[int] = None,
+        images: List[str],
     ) -> np.ndarray:
         """
-        Generate image embeddings using the currently configured image model.
-        If `batch_size` is None, all images are processed at once.
+        Generates image embeddings using the CLIP-like transformer model.
+        Handles single or multiple images uniformly (no batch size parameter).
         """
         try:
             model = self.image_models[model_id]
-
-            # Single image
-            if isinstance(images, str):
-                processed = self._process_image(images)
-                with torch.no_grad():
-                    emb = model.get_image_features(**processed)
-                return emb.cpu().numpy()
-
-            # Multiple images
-            if batch_size is None:
-                # Process them all in one batch
-                tensors = []
-                for img_path in images:
-                    tensors.append(self._process_image(img_path))
-                # Concatenate
-                keys = tensors[0].keys()
-                combined = {k: torch.cat([t[k] for t in tensors], dim=0) for k in keys}
-                with torch.no_grad():
-                    emb = model.get_image_features(**combined)
-                return emb.cpu().numpy()
-
-            # Process in smaller batches
-            all_embeddings = []
-            for i in range(0, len(images), batch_size):
-                batch_images = images[i : i + batch_size]
-                # Process each sub-batch
-                tensors = []
-                for img_path in batch_images:
-                    tensors.append(self._process_image(img_path))
-                keys = tensors[0].keys()
-                combined = {k: torch.cat([t[k] for t in tensors], dim=0) for k in keys}
-
-                with torch.no_grad():
-                    emb = model.get_image_features(**combined)
-                all_embeddings.append(emb.cpu().numpy())
-
-            return np.vstack(all_embeddings)
+            # Collect processed inputs in a single batch
+            processed_tensors = []
+            for img_path in images:
+                processed_tensors.append(self._process_image(img_path))
+
+            # Keys should be the same for all processed outputs
+            keys = processed_tensors[0].keys()
+            # Concatenate along the batch dimension
+            combined = {
+                k: torch.cat([pt[k] for pt in processed_tensors], dim=0) for k in keys
+            }
+
+            with torch.no_grad():
+                embeddings = model.get_image_features(**combined)
+            return embeddings.cpu().numpy()
 
         except Exception as e:
             raise RuntimeError(
-                f"Error generating image embeddings for model '{self.config.image_model_type}': {e}"
+                f"Error generating image embeddings with model '{model_id}': {e}"
            ) from e
 
     async def generate_embeddings(
         self,
         model: str,
         inputs: Union[str, List[str]],
-        batch_size: Optional[int] = None,
     ) -> np.ndarray:
         """
-        Asynchronously generate embeddings for text or image.
+        Asynchronously generates embeddings for either text or image based on the model type.
         """
-        # Determine if it's text or image
         modality = detect_model_kind(model)
-        model_id = (
-            TextModelType(model)
-            if modality == ModelKind.TEXT
-            else ImageModelType(model)
-        )
 
-        self._validate_modality(modality)
-        if modality == "text" and isinstance(model_id, TextModelType):
-            text_list = self._validate_text_input(inputs)
-            return self._generate_text_embeddings(model_id=model_id, texts=text_list)
-        elif modality == "image" and isinstance(model_id, ImageModelType):
-            return self._generate_image_embeddings(
-                model_id=model_id, images=inputs, batch_size=batch_size
-            )
+        if modality == ModelKind.TEXT:
+            text_model_id = TextModelType(model)
+            text_list = self._validate_text_list(inputs)
+            return self._generate_text_embeddings(text_model_id, text_list)
+
+        elif modality == ModelKind.IMAGE:
+            image_model_id = ImageModelType(model)
+            image_list = self._validate_image_list(inputs)
+            return self._generate_image_embeddings(image_model_id, image_list)
 
     async def rank(
         self,
         model: str,
         queries: Union[str, List[str]],
-        candidates: List[str],
-        batch_size: Optional[int] = None,
+        candidates: Union[str, List[str]],
     ) -> Dict[str, Any]:
         """
-        Rank candidates (always text) against the queries, which may be text or image.
-        Returns dict of { probabilities, cosine_similarities, usage }.
+        Ranks text `candidates` given `queries`, which can be text or images.
+        Always returns a dictionary of { probabilities, cosine_similarities, usage }.
+
+        Note: This implementation uses the same model for both queries and candidates.
+        For true cross-modal ranking, you might need separate models or a shared model.
         """
-        # Determine if it's text or image
         modality = detect_model_kind(model)
-        model_id = (
-            TextModelType(model)
-            if modality == ModelKind.TEXT
-            else ImageModelType(model)
-        )
+
+        # Convert the string model to the appropriate enum
+        if modality == ModelKind.TEXT:
+            model_enum = TextModelType(model)
+        else:
+            model_enum = ImageModelType(model)
 
         # 1) Generate embeddings for queries
-        query_embeds = await self.generate_embeddings(
-            model=model_id, inputs=queries, batch_size=batch_size
-        )
-        # 2) Generate embeddings for text candidates
-        candidate_embeds = await self.generate_embeddings(
-            model=model_id, inputs=candidates, batch_size=batch_size
-        )
+        query_embeds = await self.generate_embeddings(model_enum.value, queries)
+
+        # 2) Generate embeddings for candidates (assumed text if queries are text;
+        #    or if queries are images, also use the image model for candidates).
+        candidate_embeds = await self.generate_embeddings(model_enum.value, candidates)
 
         # 3) Compute cosine similarity
         sim_matrix = self.cosine_similarity(query_embeds, candidate_embeds)
 
-        # 4) Apply logit scale + softmax
+        # 4) Apply logit scale + softmax to obtain probabilities
         scaled = np.exp(self.config.logit_scale) * sim_matrix
         probs = self.softmax(scaled)
 
-        # 5) Compute usage (similar to embeddings)
-        query_tokens = self.estimate_tokens(queries) if modality == "text" else 0
-        candidate_tokens = self.estimate_tokens(candidates) if modality == "text" else 0
-        total_tokens = query_tokens + candidate_tokens
+        # 5) Estimate token usage if we're dealing with text
+        if modality == ModelKind.TEXT:
+            query_tokens = self.estimate_tokens(queries)
+            candidate_tokens = self.estimate_tokens(candidates)
+            total_tokens = query_tokens + candidate_tokens
+        else:
+            total_tokens = 0
+
         usage = {
             "prompt_tokens": total_tokens,
             "total_tokens": total_tokens,
@@ -444,27 +408,31 @@ class EmbeddingsService:
 
     def estimate_tokens(self, input_data: Union[str, List[str]]) -> int:
         """
-        Estimate token count using the model's tokenizer.
+        Estimates token count using the SentenceTransformer tokenizer.
+        Only applicable if the current configured model is a text model.
         """
-        texts = self._validate_text_input(input_data)
+        texts = self._validate_text_list(input_data)
         model = self.text_models[self.config.text_model_type]
         tokenized = model.tokenize(texts)
+        # Summing over the lengths of input_ids for each example
         return sum(len(ids) for ids in tokenized["input_ids"])
 
     @staticmethod
     def softmax(scores: np.ndarray) -> np.ndarray:
         """
-        Standard softmax along the last dimension.
+        Applies the standard softmax function along the last dimension.
         """
+        # Stabilize scores by subtracting max
         exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
         return exps / np.sum(exps, axis=-1, keepdims=True)
 
     @staticmethod
     def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
         """
+        Computes the pairwise cosine similarity between all rows of a and b.
         a: (N, D)
         b: (M, D)
-        Return: (N, M) of cos sim
+        Return: (N, M) matrix of cosine similarities
         """
         a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-9)
         b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-9)
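
Note: to make the ranking arithmetic concrete, here is a self-contained sketch using the same formulas as the two static methods above (the final matrix product is not shown in the hunk; the matmul of the normalized rows is the assumed completion). With logit_scale = ln(100), an exact match takes essentially all of the probability mass:

    import numpy as np

    def softmax(scores):
        exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
        return exps / np.sum(exps, axis=-1, keepdims=True)

    def cosine_similarity(a, b):
        a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-9)
        b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-9)
        return a_norm @ b_norm.T

    q = np.array([[1.0, 0.0]])               # one query embedding
    c = np.array([[1.0, 0.0], [0.0, 1.0]])   # two candidate embeddings
    sim = cosine_similarity(q, c)            # [[1.0, 0.0]]
    probs = softmax(np.exp(4.60517) * sim)   # ≈ [[1.0, 0.0]] after scaling by ~100
    print(sim, probs)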