Spaces:
Running
Running
File size: 1,883 Bytes
1286203 ded6a94 ddd62b8 ded6a94 1286203 ded6a94 1286203 ded6a94 b8a7810 ded6a94 b8a7810 ae798e4 ddd62b8 ae798e4 065b0d2 ded6a94 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import re
import urllib
import json
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModel
import torch
from torch import Tensor
import torch.nn.functional as F
from Allam_Backend_HF import (
allam_llm
)
import os
# Point the Hugging Face home/cache directory at the filesystem root.
# NOTE(review): this assignment runs after `transformers` has already been
# imported above; confirm the Hugging Face libraries still pick the value up
# (they may read the variable at import time).
os.environ['HF_HOME'] = '/'
# ASGI application object; the route decorators further down register their
# handlers on it.
app = FastAPI()
# Enable CORS for every origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive -- confirm this is intended for the deployment.
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Load the embedding tokenizer and model once, when the module is imported.
model_name = "intfloat/multilingual-e5-large"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mean-pool hidden states over dim 1, ignoring masked-out positions.

    Positions where ``attention_mask`` is zero are zeroed before summing, and
    the sum is divided by the per-row count of non-masked positions.

    Args:
        last_hidden_states: Hidden states to pool; presumably shaped
            (batch, seq_len, hidden) -- confirm against the caller.
        attention_mask: Mask with one entry per position of dim 1; non-zero
            marks positions to keep.

    Returns:
        The pooled tensor with dim 1 reduced away.
    """
    # Add a trailing axis so the mask broadcasts across the hidden dimension.
    keep = attention_mask.unsqueeze(-1).bool()
    zeroed = last_hidden_states.masked_fill(~keep, 0.0)
    totals = zeroed.sum(dim=1)
    counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return totals / counts
def embed_single_text(text: str) -> Tensor:
    """Embed ``text`` with the multilingual E5 model and L2-normalise the result.

    The input is tokenised (truncated to 512 tokens), run through the model
    without gradient tracking, mean-pooled with ``average_pool`` and
    normalised to unit L2 norm along dim 1.

    Args:
        text: Text to embed. The visible caller passes a one-element list of
            strings, which the tokenizer call also accepts.

    Returns:
        A 2-D tensor of normalised embeddings (presumably one row per input
        text -- confirm if batching is relied upon).
    """
    # Reuse the module-level `tokenizer` and `model` (the same
    # "intfloat/multilingual-e5-large" checkpoint) instead of reloading both
    # from disk on every call, which made each request pay the full model
    # load cost.
    batch_dict = tokenizer(text, max_length=512, padding=True, truncation=True, return_tensors='pt')
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**batch_dict)
    embedding = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
    return F.normalize(embedding, p=2, dim=1)
@app.get("/e5_embeddings")
def e5_embeddings(query: str = Query(...)):
    """GET /e5_embeddings: return the E5 embedding of ``query`` as nested lists.

    The query string is wrapped in a one-element list before being embedded.

    Raises:
        HTTPException: status 500 when the embedding helper yields ``None``.
    """
    embedding = embed_single_text([query])
    if embedding is None:
        raise HTTPException(status_code=500)
    # Convert the tensor to plain Python lists so it can be returned as JSON.
    return embedding.tolist()
@app.get("/allam_response")
def allam_response(query: str = Query(...)):
    """GET /allam_response: forward ``query`` to ``allam_llm`` and return its result.

    Args:
        query: Required query-string parameter passed unchanged to
            ``allam_llm``.

    Returns:
        Whatever ``allam_llm`` returns, unmodified.

    Raises:
        HTTPException: status 500 when ``allam_llm`` returns ``None``.
    """
    answer = allam_llm(query)
    if answer is None:
        # The stray trailing `|` that followed this statement was removed; it
        # left a dangling binary operator and made the module unparsable.
        raise HTTPException(status_code=500)
    return answer