# NOTE: Hugging Face Spaces page residue ("Spaces:" / "Sleeping" status lines)
# removed — it was an extraction artifact, not part of the program.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
from transformers import AutoTokenizer, AutoModel
import torch
import os
class EmbeddingModel:
    """Wraps a Hugging Face encoder for sentence-embedding inference.

    Tokenizer and model weights are downloaded to (or loaded from) the
    directory named by the MODEL_CACHE_DIR environment variable
    (default ``./model_cache``) at construction time.
    """

    def __init__(self, model_name="intfloat/multilingual-e5-large"):
        """Load tokenizer and model, caching weights under MODEL_CACHE_DIR."""
        cache_dir = os.getenv("MODEL_CACHE_DIR", "./model_cache")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        self.model = AutoModel.from_pretrained(model_name, cache_dir=cache_dir)
        # Inference-only service: disable dropout / training-mode behavior.
        self.model.eval()

    def get_embedding(self, text):
        """Return a 1-D numpy embedding vector for *text*.

        Pools the encoder's last hidden state with an attention-mask-weighted
        mean, so padding tokens (present whenever inputs are padded) do not
        skew the average — the original unweighted ``mean(dim=1)`` included
        pad positions.

        NOTE(review): e5-family models are trained with "query: "/"passage: "
        input prefixes; confirm whether callers should add them.
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True, max_length=512
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
        hidden = outputs.last_hidden_state  # (batch, seq, dim)
        # Mask-aware mean pooling: zero out pad positions, divide by the
        # number of real tokens (clamped to avoid division by zero).
        mask = inputs["attention_mask"].unsqueeze(-1).type_as(hidden)
        summed = (hidden * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return (summed / counts).squeeze().numpy()
# Module-level singletons: the ASGI application and one shared model
# instance. The model (and its weights download) loads at import time,
# so startup blocks until the encoder is ready.
app = FastAPI()
embedding_model = EmbeddingModel()
class EmbeddingRequest(BaseModel):
    """Request body for the embeddings endpoint (OpenAI-compatible shape)."""
    # One or more texts to embed.
    input: List[str]
    # Model identifier echoed back in the response. The served model is
    # fixed at startup; this field does NOT switch models.
    model: str = "intfloat/multilingual-e5-large"
class EmbeddingResponse(BaseModel):
    """Response envelope for the embeddings endpoint (OpenAI-compatible).

    Fix: the OpenAI embeddings API marks the response envelope with
    ``object="list"``; ``"embedding"`` is the per-item object type already
    set on each entry of ``data``, so the previous default duplicated the
    item-level tag at the envelope level.
    """
    # Envelope type tag per the OpenAI API ("list", not "embedding").
    object: str = "list"
    # One {"object", "embedding", "index"} dict per input text.
    data: List[dict]
    # Model identifier echoed from the request.
    model: str
    # {"prompt_tokens": int, "total_tokens": int} usage accounting.
    usage: dict
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """OpenAI-compatible embeddings endpoint.

    Embeds each input string independently and returns vectors in input
    order. The route decorator restores registration with ``app`` — the
    handler was otherwise never reachable (likely lost in extraction).

    Raises:
        HTTPException: 400 when ``request.input`` is empty.
    """
    if not request.input:
        raise HTTPException(status_code=400, detail="Input text cannot be empty")

    data = [
        {
            "object": "embedding",
            "embedding": embedding_model.get_embedding(text).tolist(),
            "index": idx,
        }
        for idx, text in enumerate(request.input)
    ]

    # NOTE(review): whitespace word count only approximates the model's
    # token count; use the tokenizer for exact usage figures. Computed
    # once (the original summed the same generator twice).
    total_tokens = sum(len(text.split()) for text in request.input)
    return EmbeddingResponse(
        data=data,
        model=request.model,
        usage={"prompt_tokens": total_tokens, "total_tokens": total_tokens},
    )