# filename: router.py
"""
FastAPI Router for Embeddings Service

This file exposes the EmbeddingsService functionality via a RESTful API
to generate embeddings and rank candidates.

Supported Text Model IDs:
- "multilingual-e5-small"
- "paraphrase-multilingual-MiniLM-L12-v2"
- "bge-m3"

Supported Image Model ID:
- "google/siglip-base-patch16-256-multilingual"
"""
from __future__ import annotations

import logging
from typing import List, Union
from enum import Enum

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field

from .service import ModelConfig, TextModelType, EmbeddingsService

logger = logging.getLogger(__name__)

# Initialize the FastAPI router
router = APIRouter(
    tags=["v1"],
    responses={404: {"description": "Not found"}},
)

class ModelType(str, Enum):
    """
    High-level distinction for text vs. image models.
    """

    TEXT = "text"
    IMAGE = "image"

def detect_model_type(model_id: str) -> ModelType:
    """
    Detect whether the provided model ID refers to a text or an image model.

    Supported text model IDs:
    - "multilingual-e5-small"
    - "paraphrase-multilingual-MiniLM-L12-v2"
    - "bge-m3"

    Supported image model ID:
    - "google/siglip-base-patch16-256-multilingual"
      (or any model ID containing "siglip").

    Args:
        model_id: String identifier of the model.

    Returns:
        ModelType.TEXT if it matches one of the recognized text model IDs,
        ModelType.IMAGE if it contains "siglip".

    Raises:
        ValueError: If the model_id is not recognized as either text or image.
    """
    # Gather all known text model IDs (from the TextModelType enum)
    text_model_ids = {m.value for m in TextModelType}

    # Simple check: if the ID is a known text model, it's text;
    # if 'siglip' appears in the model ID, it's treated as an image model.
    if model_id in text_model_ids:
        return ModelType.TEXT
    if "siglip" in model_id.lower():
        return ModelType.IMAGE

    error_msg = (
        f"Unsupported model ID: '{model_id}'.\n"
        "Valid text model IDs are: "
        "'multilingual-e5-small', 'paraphrase-multilingual-MiniLM-L12-v2', 'bge-m3'.\n"
        "A valid image model ID contains 'siglip', for example: "
        "'google/siglip-base-patch16-256-multilingual'."
    )
    raise ValueError(error_msg)

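# Illustrative behaviour of detect_model_type (examples only, not executed here):
#
#     detect_model_type("bge-m3")                                       # -> ModelType.TEXT
#     detect_model_type("google/siglip-base-patch16-256-multilingual")  # -> ModelType.IMAGE
#     detect_model_type("unknown-model")                                # -> raises ValueError
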
# Pydantic models for request/response payloads
class EmbeddingRequest(BaseModel):
    """
    Request body for embedding creation.

    Model IDs (text):
    - "multilingual-e5-small"
    - "paraphrase-multilingual-MiniLM-L12-v2"
    - "bge-m3"

    Model ID (image):
    - "google/siglip-base-patch16-256-multilingual"
    """

    model: str = Field(
        default=TextModelType.MULTILINGUAL_E5_SMALL.value,
        description=(
            "Model ID to use. Possible text models include: 'multilingual-e5-small', "
            "'paraphrase-multilingual-MiniLM-L12-v2', 'bge-m3'. "
            "For images, you can use: 'google/siglip-base-patch16-256-multilingual' "
            "or any ID containing 'siglip'."
        ),
    )
    input: Union[str, List[str]] = Field(
        ...,
        description=(
            "Input text(s) or image path(s)/URL(s). "
            "Accepts a single string or a list of strings."
        ),
    )

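# Example request body accepted by EmbeddingRequest (illustrative values only):
#
#     {
#         "model": "multilingual-e5-small",
#         "input": ["hello world", "xin chào"]
#     }
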
class RankRequest(BaseModel):
    """
    Request body for ranking candidates against queries.

    Model IDs (text):
    - "multilingual-e5-small"
    - "paraphrase-multilingual-MiniLM-L12-v2"
    - "bge-m3"

    Model ID (image):
    - "google/siglip-base-patch16-256-multilingual"
    """

    model: str = Field(
        default=TextModelType.MULTILINGUAL_E5_SMALL.value,
        description=(
            "Model ID to use for the queries. Supported text models: "
            "'multilingual-e5-small', 'paraphrase-multilingual-MiniLM-L12-v2', 'bge-m3'. "
            "For image queries, use an ID containing 'siglip' such as "
            "'google/siglip-base-patch16-256-multilingual'."
        ),
    )
    queries: Union[str, List[str]] = Field(
        ...,
        description=(
            "Query input(s): can be text(s) or image path(s)/URL(s). "
            "If using an image model, ensure your inputs reference valid image paths or URLs."
        ),
    )
    candidates: List[str] = Field(
        ...,
        description=(
            "List of candidate texts to rank against the given queries. "
            "Currently, all candidates must be text."
        ),
    )

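# Example request body accepted by RankRequest (illustrative values only):
#
#     {
#         "model": "bge-m3",
#         "queries": "what is the capital of France?",
#         "candidates": ["Paris is the capital of France.", "Berlin is in Germany."]
#     }
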
class EmbeddingResponse(BaseModel):
    """
    Response structure for embedding creation.
    """

    object: str = "list"
    data: List[dict]
    model: str
    usage: dict


class RankResponse(BaseModel):
    """
    Response structure for ranking results.
    """

    probabilities: List[List[float]]
    cosine_similarities: List[List[float]]

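# Example RankResponse shape (illustrative values): one row of scores per query,
# one column per candidate.
#
#     {
#         "probabilities": [[0.93, 0.07]],
#         "cosine_similarities": [[0.81, 0.12]]
#     }
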
# Initialize the service with default configuration
service_config = ModelConfig()
embeddings_service = EmbeddingsService(config=service_config)

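# Note: both endpoints below mutate this shared, module-level config before each
# call. Assuming EmbeddingsService reads the config at call time, concurrent
# requests that target different models may race on which model is active.
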
@router.post("/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """
    Generate embeddings for the provided input text(s) or image(s).

    Supported Model IDs for text:
    - "multilingual-e5-small"
    - "paraphrase-multilingual-MiniLM-L12-v2"
    - "bge-m3"

    Supported Model ID for image:
    - "google/siglip-base-patch16-256-multilingual"

    Steps:
    1. Detect the model type (text or image) based on the model ID.
    2. Adjust the service configuration accordingly.
    3. Produce embeddings via the EmbeddingsService.
    4. Return the embedding vectors along with usage information.

    Raises:
        HTTPException: For any errors during model detection or embedding generation.
    """
    try:
        modality = detect_model_type(request.model)

        # Adjust the shared config based on the detected modality
        if modality == ModelType.TEXT:
            service_config.text_model_type = TextModelType(request.model)
        else:
            service_config.image_model_id = request.model

        # Generate embeddings asynchronously
        embeddings = await embeddings_service.generate_embeddings(
            input_data=request.input, modality=modality.value
        )

        # Estimate token usage only for text inputs
        total_tokens = 0
        if modality == ModelType.TEXT:
            total_tokens = embeddings_service.estimate_tokens(request.input)

        return {
            "object": "list",
            "data": [
                {
                    "object": "embedding",
                    "index": idx,
                    "embedding": emb.tolist(),
                }
                for idx, emb in enumerate(embeddings)
            ],
            "model": request.model,
            "usage": {
                "prompt_tokens": total_tokens,
                "total_tokens": total_tokens,
            },
        }
    except Exception as e:
        error_msg = (
            "Failed to generate embeddings. Please verify your model ID, input data, and server logs.\n"
            f"Error Details: {str(e)}"
        )
        logger.error(error_msg)
        raise HTTPException(status_code=500, detail=error_msg)

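# Example response from the embeddings endpoint (illustrative values, truncated vector):
#
#     {
#         "object": "list",
#         "data": [{"object": "embedding", "index": 0, "embedding": [0.01, -0.02, ...]}],
#         "model": "multilingual-e5-small",
#         "usage": {"prompt_tokens": 3, "total_tokens": 3}
#     }
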
@router.post("/rank", response_model=RankResponse)
async def rank_candidates(request: RankRequest):
    """
    Rank the given candidate texts against the provided queries.

    Supported Model IDs for text queries:
    - "multilingual-e5-small"
    - "paraphrase-multilingual-MiniLM-L12-v2"
    - "bge-m3"

    Supported Model ID for image queries:
    - "google/siglip-base-patch16-256-multilingual"

    Steps:
    1. Detect the model type (text or image) based on the query model ID.
    2. Adjust the service configuration accordingly.
    3. Generate embeddings for the queries (text or image).
    4. Generate embeddings for the candidates (always text).
    5. Compute cosine similarities and return softmax-normalized probabilities.

    Raises:
        HTTPException: For any errors during model detection or ranking.
    """
    try:
        modality = detect_model_type(request.model)

        # Adjust the shared config based on the detected modality
        if modality == ModelType.TEXT:
            service_config.text_model_type = TextModelType(request.model)
        else:
            service_config.image_model_id = request.model

        # Perform the ranking
        results = await embeddings_service.rank(
            queries=request.queries,
            candidates=request.candidates,
            modality=modality.value,
        )
        return results
    except Exception as e:
        error_msg = (
            "Failed to rank candidates. Please verify your model ID, input data, and server logs.\n"
            f"Error Details: {str(e)}"
        )
        logger.error(error_msg)
        raise HTTPException(status_code=500, detail=error_msg)

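# Example call to the rank endpoint (illustrative; the "/v1/rank" path assumes the
# router is mounted with a "/v1" prefix as sketched at the top of this file):
#
#     curl -X POST http://localhost:8000/v1/rank \
#         -H "Content-Type: application/json" \
#         -d '{"model": "bge-m3", "queries": "hello", "candidates": ["hi there", "goodbye"]}'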