chienweichang's picture
Update app.py
d118bab verified
raw
history blame
1.76 kB
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
from transformers import AutoTokenizer, AutoModel
import torch
class EmbeddingModel:
    """Text-embedding wrapper around a Hugging Face tokenizer/model pair.

    The tokenizer and model named by ``model_name`` are loaded once in the
    constructor and reused by every :meth:`get_embedding` call.
    """

    def __init__(self, model_name="intfloat/multilingual-e5-large"):
        """Load the tokenizer and model for ``model_name``.

        Args:
            model_name: Identifier passed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def get_embedding(self, text):
        """Return the mean-pooled last-hidden-state embedding of ``text``.

        The input is tokenized with padding enabled and truncated to at most
        512 tokens, run through the model with gradient tracking disabled,
        and averaged along dimension 1 of ``last_hidden_state``. Positions
        whose attention mask is 0 (padding) are left out of the average, so
        the result does not depend on how much padding the tokenizer added.

        Args:
            text: Text handed directly to the tokenizer.

        Returns:
            The pooled tensor after ``squeeze()``, converted to a NumPy
            array. For a single input string this is presumably a 1-D vector
            of the model's hidden size (confirm against the model used).
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            outputs = self.model(**inputs)

        hidden = outputs.last_hidden_state
        # Broadcast the mask over the feature axis so padded positions
        # contribute zero to the sum and are not counted in the divisor.
        mask = inputs["attention_mask"].unsqueeze(-1).to(hidden.dtype)
        summed = (hidden * mask).sum(dim=1)
        # clamp guards against division by zero for a fully masked row.
        counts = mask.sum(dim=1).clamp(min=1)
        return (summed / counts).squeeze().numpy()
# FastAPI application instance; the /v1/embeddings route is registered on it.
app = FastAPI()
# Shared module-level instance built with the default model name. Because it
# is created at import time, the tokenizer and model are loaded once when this
# module is imported and then reused by every request.
embedding_model = EmbeddingModel()
# Request body for POST /v1/embeddings.
class EmbeddingRequest(BaseModel):
    # Texts to embed; the handler produces one embedding entry per element,
    # in the same order.
    input: List[str]
    # Model name supplied by the client. The handler echoes this value in the
    # response but does not use it to select which model runs.
    model: str = "intfloat/multilingual-e5-large"
# Response body for POST /v1/embeddings.
# NOTE(review): the envelope's "object" defaults to "embedding", the same tag
# the handler puts on each entry of `data`. OpenAI-style embedding responses
# use "list" for the envelope -- confirm whether that compatibility is intended
# before changing the default, since it is part of the wire format.
class EmbeddingResponse(BaseModel):
    # Type tag of the response envelope.
    object: str = "embedding"
    # One dict per input text; the handler fills each with the keys
    # "object", "embedding" and "index".
    data: List[dict]
    # Model name reported back to the client.
    model: str
    # Usage counters; the handler fills "prompt_tokens" and "total_tokens".
    usage: dict
# POST /v1/embeddings: embed every text in the request, in order.
#
# Responds with HTTP 400 when `input` is empty. Otherwise returns an
# EmbeddingResponse holding one {"object", "embedding", "index"} entry per
# input text, the model name copied from the request, and usage counters.
#
# Documented with comments rather than a docstring so the generated OpenAPI
# description of this route stays exactly as it was.
#
# NOTE(review): the usage figures are whitespace-separated word counts, not
# tokenizer token counts. Also, get_embedding runs synchronously inside this
# async handler, so no other request is served while a text is being embedded.
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    texts = request.input
    if not texts:
        raise HTTPException(status_code=400, detail="Input text cannot be empty")

    # Each text is embedded on its own; `position` is its index in the request.
    items = [
        {
            "object": "embedding",
            "embedding": embedding_model.get_embedding(text).tolist(),
            "index": position,
        }
        for position, text in enumerate(texts)
    ]

    # The same figure is reported for both counters.
    word_count = sum(len(text.split()) for text in texts)

    return EmbeddingResponse(
        data=items,
        model=request.model,
        usage={
            "prompt_tokens": word_count,
            "total_tokens": word_count,
        },
    )