"""OpenAI-compatible text-embedding service built on FastAPI and a Hugging Face encoder."""
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List
from transformers import AutoTokenizer, AutoModel
import torch
class EmbeddingModel:
    """Wraps a Hugging Face encoder and produces mean-pooled text embeddings."""

    def __init__(self, model_name="intfloat/multilingual-e5-large"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)
        # Inference-only service: switch off dropout and other train-mode layers.
        self.model.eval()

    def get_embedding(self, text):
        """Return a 1-D numpy embedding for *text*.

        Pools the encoder's last hidden state with an attention-mask-weighted
        mean: identical to a plain mean for a single unpadded text, but stays
        correct if batched/padded inputs are ever passed through.
        """
        inputs = self.tokenizer(
            text, return_tensors="pt", padding=True, truncation=True, max_length=512
        )
        with torch.no_grad():
            outputs = self.model(**inputs)
        # Zero out padding positions before averaging; clamp avoids div-by-zero.
        mask = inputs["attention_mask"].unsqueeze(-1).float()
        summed = (outputs.last_hidden_state * mask).sum(dim=1)
        counts = mask.sum(dim=1).clamp(min=1e-9)
        return (summed / counts).squeeze().numpy()
# ASGI application exposing an OpenAI-compatible /v1/embeddings endpoint.
app = FastAPI()
# Model is loaded eagerly at import time (first run downloads the weights).
# NOTE(review): this blocks startup until the model is in memory — confirm
# that is acceptable, or consider lazy initialization.
embedding_model = EmbeddingModel()
class EmbeddingRequest(BaseModel):
    """Request body for POST /v1/embeddings (OpenAI-compatible shape)."""

    # Batch of texts to embed; emptiness is rejected in the route handler.
    input: List[str]
    # Requested model id; echoed back in the response. NOTE(review): it does
    # not actually switch models — the server always uses the one loaded at
    # startup.
    model: str = "intfloat/multilingual-e5-large"
class EmbeddingResponse(BaseModel):
    """Response envelope for POST /v1/embeddings (OpenAI-compatible shape)."""

    # Per the OpenAI embeddings API the top-level object is "list"; the
    # individual items inside `data` carry object == "embedding" (which the
    # route already sets). The old default "embedding" was the wrong level.
    object: str = "list"
    # One dict per input: {"object", "embedding", "index"}.
    data: List[dict]
    # Model id echoed from the request.
    model: str
    # Token accounting: {"prompt_tokens", "total_tokens"}.
    usage: dict
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest):
    """Embed each input text and return an OpenAI-style embeddings payload.

    Raises:
        HTTPException: 400 when `input` is an empty list.
    """
    if not request.input:
        raise HTTPException(status_code=400, detail="Input text cannot be empty")

    # NOTE(review): get_embedding is CPU/GPU-bound and runs inside an async
    # handler, so it blocks the event loop for the duration of inference —
    # consider a plain `def` handler or run_in_executor if throughput matters.
    embeddings = [
        {
            "object": "embedding",
            "embedding": embedding_model.get_embedding(text).tolist(),
            "index": idx,
        }
        for idx, text in enumerate(request.input)
    ]

    # Count real model tokens with the tokenizer rather than approximating
    # with whitespace splitting, and compute the sum once instead of twice.
    total_tokens = sum(
        len(embedding_model.tokenizer.encode(text)) for text in request.input
    )

    return EmbeddingResponse(
        data=embeddings,
        model=request.model,
        usage={"prompt_tokens": total_tokens, "total_tokens": total_tokens},
    )