from __future__ import annotations
import logging
from enum import Enum
from typing import List, Union, Dict, Optional, NamedTuple, Any
from dataclasses import dataclass
from pathlib import Path
from io import BytesIO
from hashlib import md5
from cachetools import LRUCache
import requests
import numpy as np
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer
from transformers import AutoProcessor, AutoModel
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class TextModelType(str, Enum):
"""
Enumeration of supported text models.
"""
MULTILINGUAL_E5_SMALL = "multilingual-e5-small"
MULTILINGUAL_E5_BASE = "multilingual-e5-base"
MULTILINGUAL_E5_LARGE = "multilingual-e5-large"
SNOWFLAKE_ARCTIC_EMBED_L_V2 = "snowflake-arctic-embed-l-v2.0"
PARAPHRASE_MULTILINGUAL_MINILM_L12_V2 = "paraphrase-multilingual-MiniLM-L12-v2"
PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2 = "paraphrase-multilingual-mpnet-base-v2"
BGE_M3 = "bge-m3"
GTE_MULTILINGUAL_BASE = "gte-multilingual-base"
class ImageModelType(str, Enum):
"""
Enumeration of supported image models.
"""
SIGLIP_BASE_PATCH16_256_MULTILINGUAL = "siglip-base-patch16-256-multilingual"
class ModelInfo(NamedTuple):
"""
This container maps an enum to:
- model_id: Hugging Face model ID (or local path)
- onnx_file: Path to ONNX file (if available)
"""
model_id: str
onnx_file: Optional[str] = None
@dataclass
class ModelConfig:
"""
Configuration for text and image models.
"""
text_model_type: TextModelType = TextModelType.MULTILINGUAL_E5_SMALL
image_model_type: ImageModelType = (
ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL
)
    logit_scale: float = 4.60517  # ln(100): exp(logit_scale) == 100, the similarity temperature applied in rank()
@property
def text_model_info(self) -> ModelInfo:
"""
Returns ModelInfo for the configured text_model_type.
"""
text_configs = {
TextModelType.MULTILINGUAL_E5_SMALL: ModelInfo(
model_id="Xenova/multilingual-e5-small",
onnx_file="onnx/model_quantized.onnx",
),
TextModelType.MULTILINGUAL_E5_BASE: ModelInfo(
model_id="Xenova/multilingual-e5-base",
onnx_file="onnx/model_quantized.onnx",
),
TextModelType.MULTILINGUAL_E5_LARGE: ModelInfo(
model_id="Xenova/multilingual-e5-large",
onnx_file="onnx/model_quantized.onnx",
),
TextModelType.SNOWFLAKE_ARCTIC_EMBED_L_V2: ModelInfo(
model_id="Snowflake/snowflake-arctic-embed-l-v2.0",
onnx_file="onnx/model_quantized.onnx",
),
TextModelType.PARAPHRASE_MULTILINGUAL_MINILM_L12_V2: ModelInfo(
model_id="Xenova/paraphrase-multilingual-MiniLM-L12-v2",
onnx_file="onnx/model_quantized.onnx",
),
TextModelType.PARAPHRASE_MULTILINGUAL_MPNET_BASE_V2: ModelInfo(
model_id="Xenova/paraphrase-multilingual-mpnet-base-v2",
onnx_file="onnx/model_quantized.onnx",
),
TextModelType.BGE_M3: ModelInfo(
model_id="Xenova/bge-m3",
onnx_file="onnx/model_quantized.onnx",
),
TextModelType.GTE_MULTILINGUAL_BASE: ModelInfo(
model_id="onnx-community/gte-multilingual-base",
onnx_file="onnx/model_quantized.onnx",
),
}
return text_configs[self.text_model_type]
@property
def image_model_info(self) -> ModelInfo:
"""
Returns ModelInfo for the configured image_model_type.
"""
image_configs = {
ImageModelType.SIGLIP_BASE_PATCH16_256_MULTILINGUAL: ModelInfo(
model_id="google/siglip-base-patch16-256-multilingual"
),
}
return image_configs[self.image_model_type]
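# Usage sketch (illustrative, not executed at import time): resolving a
# configured model to its Hugging Face repo.
#
#   cfg = ModelConfig(text_model_type=TextModelType.BGE_M3)
#   cfg.text_model_info.model_id   # -> "Xenova/bge-m3"
#   cfg.image_model_info.model_id  # -> "google/siglip-base-patch16-256-multilingual"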
class ModelKind(str, Enum):
TEXT = "text"
IMAGE = "image"
def detect_model_kind(model_id: str) -> ModelKind:
"""
Detect whether model_id belongs to a text or an image model.
Raises ValueError if the model is not recognized.
"""
if model_id in [m.value for m in TextModelType]:
return ModelKind.TEXT
elif model_id in [m.value for m in ImageModelType]:
return ModelKind.IMAGE
else:
raise ValueError(
f"Unrecognized model ID: {model_id}.\n"
f"Valid text: {[m.value for m in TextModelType]}\n"
f"Valid image: {[m.value for m in ImageModelType]}"
)
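# Example behavior (values mirror the enum definitions above):
#
#   detect_model_kind("bge-m3")                                # ModelKind.TEXT
#   detect_model_kind("siglip-base-patch16-256-multilingual")  # ModelKind.IMAGE
#   detect_model_kind("no-such-model")                         # raises ValueError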
class EmbeddingsService:
"""
Service for generating text/image embeddings and performing similarity ranking.
    Single and multiple inputs are handled uniformly; there is no batch-size parameter.
"""
def __init__(self, config: Optional[ModelConfig] = None):
self.lru_cache = LRUCache(maxsize=10_000)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.config = config or ModelConfig()
# Dictionaries to hold preloaded models
self.text_models: Dict[TextModelType, SentenceTransformer] = {}
self.image_models: Dict[ImageModelType, AutoModel] = {}
self.image_processors: Dict[ImageModelType, AutoProcessor] = {}
# Load all relevant models on init
self._load_all_models()
def _load_all_models(self) -> None:
"""
Pre-load all known text and image models for quick switching.
"""
try:
# Preload text models
for t_model_type in TextModelType:
info = ModelConfig(text_model_type=t_model_type).text_model_info
logger.info("Loading text model: %s", info.model_id)
if info.onnx_file:
logger.info("Using ONNX file: %s", info.onnx_file)
self.text_models[t_model_type] = SentenceTransformer(
info.model_id,
device=self.device,
backend="onnx",
model_kwargs={
"provider": "CPUExecutionProvider",
"file_name": info.onnx_file,
},
trust_remote_code=True,
)
else:
self.text_models[t_model_type] = SentenceTransformer(
info.model_id,
device=self.device,
trust_remote_code=True,
)
# Preload image models
for i_model_type in ImageModelType:
model_id = ModelConfig(
image_model_type=i_model_type
).image_model_info.model_id
logger.info("Loading image model: %s", model_id)
model = AutoModel.from_pretrained(model_id).to(self.device)
processor = AutoProcessor.from_pretrained(model_id)
self.image_models[i_model_type] = model
self.image_processors[i_model_type] = processor
logger.info("All models loaded successfully.")
except Exception as e:
msg = f"Error loading models: {str(e)}"
logger.error(msg)
raise RuntimeError(msg) from e
@staticmethod
def _validate_text_list(input_text: Union[str, List[str]]) -> List[str]:
"""
Convert text input into a non-empty list of strings.
Raises ValueError if the input is invalid.
"""
if isinstance(input_text, str):
if not input_text.strip():
raise ValueError("Text input cannot be empty.")
return [input_text]
if not isinstance(input_text, list) or not all(
isinstance(x, str) for x in input_text
):
raise ValueError("Text input must be a string or a list of strings.")
if len(input_text) == 0:
raise ValueError("Text input list cannot be empty.")
return input_text
@staticmethod
def _validate_image_list(input_images: Union[str, List[str]]) -> List[str]:
"""
Convert image input into a non-empty list of image paths/URLs.
Raises ValueError if the input is invalid.
"""
if isinstance(input_images, str):
if not input_images.strip():
raise ValueError("Image input cannot be empty.")
return [input_images]
if not isinstance(input_images, list) or not all(
isinstance(x, str) for x in input_images
):
raise ValueError("Image input must be a string or a list of strings.")
if len(input_images) == 0:
raise ValueError("Image input list cannot be empty.")
return input_images
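    # Validation examples (both helpers behave the same way):
    #
    #   EmbeddingsService._validate_text_list("hello")     -> ["hello"]
    #   EmbeddingsService._validate_text_list(["a", "b"])  -> ["a", "b"]
    #   EmbeddingsService._validate_text_list("")          -> raises ValueError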
def _process_image(self, path_or_url: str) -> Dict[str, torch.Tensor]:
"""
Loads and processes a single image from local path or URL.
Returns a dictionary of tensors ready for the model.
"""
try:
if path_or_url.startswith("http"):
resp = requests.get(path_or_url, timeout=10)
resp.raise_for_status()
img = Image.open(BytesIO(resp.content)).convert("RGB")
else:
img = Image.open(Path(path_or_url)).convert("RGB")
processor = self.image_processors[self.config.image_model_type]
processed_data = processor(images=img, return_tensors="pt").to(self.device)
return processed_data
except Exception as e:
raise ValueError(f"Error processing image '{path_or_url}': {str(e)}") from e
def _generate_text_embeddings(
self,
model_id: TextModelType,
texts: List[str],
) -> np.ndarray:
"""
Generates text embeddings using the SentenceTransformer-based model.
Utilizes an LRU cache for single-input scenarios.
"""
try:
            if len(texts) == 1:
                single_text = texts[0]
                # Key the cache on both model and text so embeddings cached
                # for one model are never returned for another.
                key = md5(
                    f"{model_id.value}:{single_text}".encode("utf-8")
                ).hexdigest()
                if key in self.lru_cache:
                    return self.lru_cache[key]
                model = self.text_models[model_id]
                emb = model.encode([single_text])
                self.lru_cache[key] = emb
                return emb
# For multiple texts, no LRU cache is used
model = self.text_models[model_id]
return model.encode(texts)
except Exception as e:
raise RuntimeError(
f"Error generating text embeddings with model '{model_id}': {e}"
) from e
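    # Caching note: only single-text requests are memoized; multi-text batches
    # bypass the LRU cache, presumably because batch composition varies from
    # call to call and would rarely hit.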
def _generate_image_embeddings(
self,
model_id: ImageModelType,
images: List[str],
) -> np.ndarray:
"""
Generates image embeddings using the CLIP-like transformer model.
Handles single or multiple images uniformly (no batch size parameter).
"""
try:
model = self.image_models[model_id]
# Collect processed inputs in a single batch
processed_tensors = []
for img_path in images:
processed_tensors.append(self._process_image(img_path))
# Keys should be the same for all processed outputs
keys = processed_tensors[0].keys()
# Concatenate along the batch dimension
combined = {
k: torch.cat([pt[k] for pt in processed_tensors], dim=0) for k in keys
}
with torch.no_grad():
embeddings = model.get_image_features(**combined)
return embeddings.cpu().numpy()
except Exception as e:
raise RuntimeError(
f"Error generating image embeddings with model '{model_id}': {e}"
) from e
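    # Batching note: each processed image yields tensors with a leading batch
    # dimension of 1 (e.g. pixel_values of shape (1, C, H, W)), so torch.cat
    # along dim=0 assembles a single (N, C, H, W) batch for one forward pass.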
async def generate_embeddings(
self,
model: str,
inputs: Union[str, List[str]],
) -> np.ndarray:
"""
Asynchronously generates embeddings for either text or image based on the model type.
"""
modality = detect_model_kind(model)
if modality == ModelKind.TEXT:
text_model_id = TextModelType(model)
text_list = self._validate_text_list(inputs)
return self._generate_text_embeddings(text_model_id, text_list)
        else:
            # detect_model_kind() has already validated the ID, so the only
            # other kind is IMAGE; an explicit else also guarantees a return
            # value on every path.
            image_model_id = ImageModelType(model)
            image_list = self._validate_image_list(inputs)
            return self._generate_image_embeddings(image_model_id, image_list)
async def rank(
self,
model: str,
queries: Union[str, List[str]],
candidates: Union[str, List[str]],
) -> Dict[str, Any]:
"""
Ranks text `candidates` given `queries`, which can be text or images.
Always returns a dictionary of { probabilities, cosine_similarities, usage }.
Note: This implementation uses the same model for both queries and candidates.
For true cross-modal ranking, you might need separate models or a shared model.
"""
modality = detect_model_kind(model)
# Convert the string model to the appropriate enum
if modality == ModelKind.TEXT:
model_enum = TextModelType(model)
else:
model_enum = ImageModelType(model)
# 1) Generate embeddings for queries
query_embeds = await self.generate_embeddings(model_enum.value, queries)
        # 2) Generate embeddings for candidates with the same model
        #    (text candidates for a text model; image paths/URLs for an image model)
candidate_embeds = await self.generate_embeddings(model_enum.value, candidates)
# 3) Compute cosine similarity
sim_matrix = self.cosine_similarity(query_embeds, candidate_embeds)
# 4) Apply logit scale + softmax to obtain probabilities
scaled = np.exp(self.config.logit_scale) * sim_matrix
probs = self.softmax(scaled)
# 5) Estimate token usage if we're dealing with text
if modality == ModelKind.TEXT:
query_tokens = self.estimate_tokens(queries)
candidate_tokens = self.estimate_tokens(candidates)
total_tokens = query_tokens + candidate_tokens
else:
total_tokens = 0
usage = {
"prompt_tokens": total_tokens,
"total_tokens": total_tokens,
}
return {
"probabilities": probs.tolist(),
"cosine_similarities": sim_matrix.tolist(),
"usage": usage,
}
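    # Illustrative `rank` output for 1 query and 3 candidates (shapes only;
    # the numbers are made up for the sketch):
    #
    #   {
    #       "probabilities": [[0.80, 0.15, 0.05]],
    #       "cosine_similarities": [[0.91, 0.62, 0.40]],
    #       "usage": {"prompt_tokens": 27, "total_tokens": 27},
    #   }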
def estimate_tokens(self, input_data: Union[str, List[str]]) -> int:
"""
Estimates token count using the SentenceTransformer tokenizer.
Only applicable if the current configured model is a text model.
"""
texts = self._validate_text_list(input_data)
model = self.text_models[self.config.text_model_type]
tokenized = model.tokenize(texts)
# Summing over the lengths of input_ids for each example
return sum(len(ids) for ids in tokenized["input_ids"])
@staticmethod
def softmax(scores: np.ndarray) -> np.ndarray:
"""
Applies the standard softmax function along the last dimension.
"""
# Stabilize scores by subtracting max
exps = np.exp(scores - np.max(scores, axis=-1, keepdims=True))
return exps / np.sum(exps, axis=-1, keepdims=True)
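    # Worked example: softmax([0.0, 1.0]) -> [0.269, 0.731], since
    # exp(0) = 1 and exp(1) ≈ 2.718, giving 1 / 3.718 ≈ 0.269 and
    # 2.718 / 3.718 ≈ 0.731.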
@staticmethod
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> np.ndarray:
"""
Computes the pairwise cosine similarity between all rows of a and b.
a: (N, D)
b: (M, D)
Return: (N, M) matrix of cosine similarities
"""
a_norm = a / (np.linalg.norm(a, axis=1, keepdims=True) + 1e-9)
b_norm = b / (np.linalg.norm(b, axis=1, keepdims=True) + 1e-9)
return np.dot(a_norm, b_norm.T)
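# Minimal end-to-end sketch (an illustration, not part of the service API):
# instantiating the service downloads and loads every registered model, so
# this assumes network access and enough memory for all of them.
if __name__ == "__main__":
    import asyncio

    service = EmbeddingsService()
    embeds = asyncio.run(
        service.generate_embeddings(
            model="multilingual-e5-small",
            inputs=["query: what is the capital of France?"],
        )
    )
    print(embeds.shape)  # expected (1, 384) for multilingual-e5-small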