# chienweichang's picture
# Update app.py
# d118bab verified
# NOTE(review): the three lines above are Hugging Face page residue
# (author avatar / commit message / commit hash), not code — commented
# out so this module parses as Python.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
from transformers import AutoTokenizer, AutoModel
import torch
class EmbeddingModel:
    """Sentence-embedding wrapper around a Hugging Face encoder.

    Embeddings are the attention-mask-weighted mean of the final hidden
    states, so padding positions contribute nothing to the average.
    (The original code averaged over *all* positions, pads included,
    which skews results whenever a padded batch is encoded.)

    NOTE(review): e5-family models recommend prefixing inputs with
    "query: " / "passage: " before embedding — confirm against the
    model card; callers here pass raw text.
    """

    def __init__(self, model_name="intfloat/multilingual-e5-large", max_length=512):
        # max_length was a hard-coded literal; now a parameter with the
        # same default, so existing callers are unaffected.
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        self.model.eval()  # from_pretrained already loads eval mode; made explicit
        self.max_length = max_length

    def get_embedding(self, text):
        """Return a 1-D numpy embedding vector for *text*.

        *text* may be a single string (returns shape ``(hidden,)`` after
        the squeeze) — inputs longer than ``max_length`` tokens are
        truncated.
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length,
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Mask-aware mean pooling: zero out pad positions, then divide by
        # the number of real tokens (clamped to avoid division by zero).
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        summed = (outputs.last_hidden_state * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return (summed / counts).squeeze().numpy()
# FastAPI application exposing an OpenAI-style /v1/embeddings endpoint.
app = FastAPI()
# Model is loaded once at import time: requests never pay the
# download/initialization cost, but importing this module is slow and
# requires network access for the first run.
embedding_model = EmbeddingModel()
class EmbeddingRequest(BaseModel):
    """Request body for POST /v1/embeddings (OpenAI-compatible shape)."""
    # Texts to embed; the endpoint rejects an empty list with HTTP 400.
    input: List[str]
    # Echoed back in the response. NOTE(review): the server always uses
    # the model loaded at startup regardless of this value — it is never
    # validated against the loaded model name.
    model: str = "intfloat/multilingual-e5-large"
class EmbeddingResponse(BaseModel):
    """Response body for POST /v1/embeddings, mirroring the OpenAI API.

    Per the OpenAI embeddings API the top-level ``object`` is ``"list"``;
    ``"embedding"`` is the per-item value inside ``data``. The original
    default here ("embedding") duplicated the per-item value at the
    envelope level, which breaks clients that check the envelope type.
    """
    object: str = "list"  # OpenAI top-level envelope type (was "embedding")
    data: List[dict]      # one {"object", "embedding", "index"} dict per input
    model: str            # model name echoed from the request
    usage: dict           # {"prompt_tokens", "total_tokens"}
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """Embed each input string and return an OpenAI-shaped response.

    Raises:
        HTTPException: 400 when ``request.input`` is empty.
    """
    if not request.input:
        raise HTTPException(status_code=400, detail="Input text cannot be empty")
    data = [
        {
            "object": "embedding",
            "embedding": embedding_model.get_embedding(text).tolist(),
            "index": idx,
        }
        for idx, text in enumerate(request.input)
    ]
    # Hoisted: the original evaluated this sum twice (once per usage key).
    # NOTE(review): whitespace split only approximates the tokenizer's
    # token count — consider embedding_model.tokenizer for exact usage.
    token_count = sum(len(text.split()) for text in request.input)
    return EmbeddingResponse(
        object="list",  # explicit OpenAI envelope type
        data=data,
        model=request.model,
        usage={"prompt_tokens": token_count, "total_tokens": token_count},
    )