lamhieu's picture
chore: support `bge-m3` and `gte-multilingual-base` models
de24ee4
raw
history blame
5.73 kB
"""
FastAPI Router for Embeddings Service (Revised & Simplified)
Exposes the EmbeddingsService methods via a RESTful API.
Supported Text Model IDs:
- "multilingual-e5-small"
- "multilingual-e5-base"
- "multilingual-e5-large"
- "snowflake-arctic-embed-l-v2.0"
- "paraphrase-multilingual-MiniLM-L12-v2"
- "paraphrase-multilingual-mpnet-base-v2"
- "bge-m3"
- "gte-multilingual-base"
Supported Image Model IDs:
- "siglip-base-patch16-256-multilingual"
"""
from __future__ import annotations
import logging
from typing import List, Union
from enum import Enum
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from .service import (
ModelConfig,
TextModelType,
ImageModelType,
EmbeddingsService,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router that the endpoints in this module are registered on. Every route is
# tagged "v1" and advertises a 404 "Not found" response in the API docs.
router = APIRouter(
    tags=["v1"],
    responses={404: {"description": "Not found"}},
)
class ModelKind(str, Enum):
    """
    Modality of a model ID: text or image.

    Mixes in `str`, so members compare equal to their plain string values.
    The endpoints in this module pass `.value` to the service as `modality`.
    """

    # Model ID belongs to TextModelType.
    TEXT = "text"
    # Model ID belongs to ImageModelType.
    IMAGE = "image"
def detect_model_kind(model_id: str) -> ModelKind:
    """
    Classify a model ID as belonging to a text model or an image model.

    Text model IDs are checked first, then image model IDs.

    Args:
        model_id: The model identifier supplied by the caller.

    Returns:
        ModelKind.TEXT if the ID is a TextModelType value,
        ModelKind.IMAGE if it is an ImageModelType value.

    Raises:
        ValueError: If the ID matches neither enum; the message lists the
            valid text and image IDs.
    """
    text_ids = [member.value for member in TextModelType]
    if model_id in text_ids:
        return ModelKind.TEXT

    image_ids = [member.value for member in ImageModelType]
    if model_id in image_ids:
        return ModelKind.IMAGE

    # Neither enum knows this ID: report it together with the accepted values.
    raise ValueError(
        f"Unrecognized model ID: {model_id}.\n"
        f"Valid text: {text_ids}\n"
        f"Valid image: {image_ids}"
    )
class EmbeddingRequest(BaseModel):
    """
    Input to /v1/embeddings
    """

    # Model ID to embed with; defaults to the small multilingual E5 text model.
    # The ID lists in the description are hand-written: keep them in sync with
    # the "Supported ... Model IDs" lists in the module docstring.
    model: str = Field(
        default=TextModelType.MULTILINGUAL_E5_SMALL.value,
        description=(
            "Which model ID to use? "
            "Text: ['multilingual-e5-small', 'multilingual-e5-base', 'multilingual-e5-large', 'snowflake-arctic-embed-l-v2.0', 'paraphrase-multilingual-MiniLM-L12-v2', 'paraphrase-multilingual-mpnet-base-v2', 'bge-m3', 'gte-multilingual-base']. "
            "Image: ['siglip-base-patch16-256-multilingual']."
        ),
    )
    # A single item or a list of items to embed (required).
    input: Union[str, List[str]] = Field(
        ..., description="Text(s) or Image URL(s)/path(s)."
    )
class RankRequest(BaseModel):
    """
    Input to /v1/rank
    """

    # Model ID used for the queries; defaults to the small multilingual E5
    # text model.
    model: str = Field(
        default=TextModelType.MULTILINGUAL_E5_SMALL.value,
        description=(
            "Model ID for the queries. "
            "Text or Image model, e.g. 'siglip-base-patch16-256-multilingual' for images."
        ),
    )
    # One query or a list of queries (required).
    queries: Union[str, List[str]] = Field(
        ..., description="Query text or image(s) depending on the model type."
    )
    # Candidate texts the queries are ranked against (required).
    candidates: List[str] = Field(
        ..., description="Candidate texts to rank. Must be text."
    )
class EmbeddingResponse(BaseModel):
    """
    Response of /v1/embeddings
    """

    # create_embeddings always sets this to "list".
    object: str
    # One dict per embedding, with keys "object", "index" and "embedding"
    # (as built by create_embeddings).
    data: List[dict]
    # The model ID echoed back from the request.
    model: str
    # Token accounting: "prompt_tokens" and "total_tokens".
    usage: dict
class RankResponse(BaseModel):
    """
    Response of /v1/rank
    """

    # Both fields are filled from whatever EmbeddingsService.rank returns.
    # NOTE(review): presumably one row per query and one column per
    # candidate — confirm against the service implementation.
    probabilities: List[List[float]]
    cosine_similarities: List[List[float]]
# Module-level configuration and service instance, created once at import time
# and shared by every request. The endpoints below overwrite
# `service_config.text_model_type` / `service_config.image_model_type` on each
# call to select the requested model.
service_config = ModelConfig()
embeddings_service = EmbeddingsService(config=service_config)
@router.post("/embeddings", response_model=EmbeddingResponse, tags=["embeddings"])
async def create_embeddings(request: EmbeddingRequest):
    """
    Generates embeddings for the given input (text or image).
    """
    # Contract (kept out of the docstring, which FastAPI publishes as the
    # operation description):
    #   - returns a dict shaped like EmbeddingResponse;
    #   - raises HTTPException 400 when the model ID is not recognized;
    #   - raises HTTPException 500 when embedding generation itself fails.

    # 1) Determine if it's text or image. An unknown model ID is a client
    #    error, so it is reported as 400 here instead of falling into the
    #    generic 500 handler below.
    try:
        mkind = detect_model_kind(request.model)
    except ValueError as e:
        msg = (
            "Failed to generate embeddings. Check model ID, inputs, etc.\n"
            f"Details: {str(e)}"
        )
        logger.warning(msg)
        raise HTTPException(status_code=400, detail=msg) from e

    try:
        # 2) Update global service config so it uses the correct model.
        # NOTE(review): service_config is shared by all requests; concurrent
        # requests asking for different models could interleave here. Confirm
        # that EmbeddingsService reads the config before its first await, or
        # pass the model to the service explicitly.
        if mkind == ModelKind.TEXT:
            service_config.text_model_type = TextModelType(request.model)
        else:
            service_config.image_model_type = ImageModelType(request.model)

        # 3) Generate
        embeddings = await embeddings_service.generate_embeddings(
            input_data=request.input, modality=mkind.value
        )

        # 4) Estimate tokens for text only; image requests report 0.
        total_tokens = 0
        if mkind == ModelKind.TEXT:
            total_tokens = embeddings_service.estimate_tokens(request.input)

        return {
            "object": "list",
            # Each embedding must expose .tolist() (presumably a numpy array —
            # confirm against the service).
            "data": [
                {
                    "object": "embedding",
                    "index": idx,
                    "embedding": emb.tolist(),
                }
                for idx, emb in enumerate(embeddings)
            ],
            "model": request.model,
            "usage": {
                "prompt_tokens": total_tokens,
                "total_tokens": total_tokens,
            },
        }
    except Exception as e:
        msg = (
            "Failed to generate embeddings. Check model ID, inputs, etc.\n"
            f"Details: {str(e)}"
        )
        # logger.exception records the traceback along with the message.
        logger.exception(msg)
        raise HTTPException(status_code=500, detail=msg) from e
@router.post("/rank", response_model=RankResponse, tags=["rank"])
async def rank_candidates(request: RankRequest):
    """
    Ranks candidate texts against the given queries (which can be text or image).
    """
    # Contract (kept out of the docstring, which FastAPI publishes as the
    # operation description):
    #   - returns whatever EmbeddingsService.rank produces, validated against
    #     RankResponse by FastAPI;
    #   - raises HTTPException 400 when the model ID is not recognized;
    #   - raises HTTPException 500 when ranking itself fails.

    # An unknown model ID is a client error, so it is reported as 400 here
    # instead of falling into the generic 500 handler below.
    try:
        mkind = detect_model_kind(request.model)
    except ValueError as e:
        msg = (
            "Failed to rank candidates. Check model ID, inputs, etc.\n"
            f"Details: {str(e)}"
        )
        logger.warning(msg)
        raise HTTPException(status_code=400, detail=msg) from e

    try:
        # Point the shared service config at the requested model.
        # NOTE(review): service_config is shared by all requests; concurrent
        # requests asking for different models could interleave here. Confirm
        # that EmbeddingsService reads the config before its first await, or
        # pass the model to the service explicitly.
        if mkind == ModelKind.TEXT:
            service_config.text_model_type = TextModelType(request.model)
        else:
            service_config.image_model_type = ImageModelType(request.model)

        return await embeddings_service.rank(
            queries=request.queries,
            candidates=request.candidates,
            modality=mkind.value,
        )
    except Exception as e:
        msg = (
            "Failed to rank candidates. Check model ID, inputs, etc.\n"
            f"Details: {str(e)}"
        )
        # logger.exception records the traceback along with the message.
        logger.exception(msg)
        raise HTTPException(status_code=500, detail=msg) from e