# Scraped file-page header (not part of main.py): "E-slam's picture",
# "Update main.py", commit 78d5601 (verified), raw / history / blame, 1.61 kB.
import json
import logging
import os
import re
import urllib

import torch
import torch.nn.functional as F
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
# Point the Hugging Face cache root at '/'.
# NOTE(review): huggingface_hub / transformers appear to resolve their cache
# paths when they are first imported, so assigning HF_HOME here (after the
# imports above) may not relocate the cache -- confirm, and if needed set the
# variable in the process environment before Python starts.
os.environ['HF_HOME'] = '/'

app = FastAPI()

# Enable CORS: accept any origin, method and header, with credentials allowed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the tokenizer and encoder once, at import time.
model_name = "intfloat/multilingual-e5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mean-pool hidden states over dim 1, ignoring masked-out positions.

    Positions where ``attention_mask`` is zero are zeroed before summing, and
    the sum is divided by the per-row count of non-masked positions.

    Args:
        last_hidden_states: Hidden states; dim 1 is the axis being pooled.
        attention_mask: Mask whose non-zero entries mark positions to keep.
            It is broadcast over the trailing dimension of the hidden states.

    Returns:
        The pooled tensor with dim 1 reduced away.
    """
    # True where a position must be dropped from the average.
    is_padding = attention_mask.unsqueeze(-1).bool().logical_not()
    summed = last_hidden_states.masked_fill(is_padding, 0.0).sum(dim=1)
    kept_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return summed / kept_counts
def embed_single_text(text: str) -> Tensor:
    """Embed *text* with the E5 model and return L2-normalised vectors.

    The input is tokenised (truncated to 512 tokens), run through the encoder
    without gradient tracking, mean-pooled over the attention mask and then
    normalised to unit L2 norm along dim 1.

    Args:
        text: Input handed straight to the tokenizer. The endpoint in this
            file passes a one-element list rather than a bare string.

    Returns:
        A 2-D tensor with one normalised embedding per input row.
    """
    # Reuse the tokenizer and model loaded once at module import instead of
    # calling from_pretrained() again on every request, which re-read the full
    # checkpoint each time and shadowed the module-level objects.
    batch_dict = tokenizer(
        text,
        max_length=512,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )
    with torch.no_grad():
        outputs = model(**batch_dict)
    embedding = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
    return F.normalize(embedding, p=2, dim=1)
@app.get("/e5_embeddings")
def e5_embeddings(query: str = Query(...)):
    """Return the E5 embedding of *query* as a nested list of floats.

    Args:
        query: Required query-string parameter holding the text to embed.

    Returns:
        The embedding tensor converted with ``tolist()`` (one row per input).

    Raises:
        HTTPException: With status 500 when computing the embedding fails.
    """
    # embed_single_text never returns None, so the old `is not None` check
    # could not reach its 500 branch; failures surface as exceptions instead.
    try:
        result = embed_single_text([query])
    except Exception as exc:
        # Endpoint boundary: keep the traceback in the server log, then answer
        # with the HTTP 500 the original code intended.
        logging.getLogger(__name__).exception("Failed to compute embedding")
        raise HTTPException(status_code=500) from exc
    return result.tolist()