File size: 1,883 Bytes
1286203
 
 
 
 
ded6a94
 
 
 
ddd62b8
 
 
ded6a94
 
1286203
 
 
 
 
 
 
 
 
 
 
 
ded6a94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1286203
 
 
 
ded6a94
b8a7810
ded6a94
b8a7810
 
ae798e4
 
 
ddd62b8
ae798e4
 
 
 
 
 
 
065b0d2
ded6a94
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import re
import urllib
import json
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModel
import torch
from torch import Tensor
import torch.nn.functional as F
from Allam_Backend_HF import (
    allam_llm
    ) 
import os
# Point the Hugging Face cache root (HF_HOME) at the filesystem root.
# NOTE(review): this assignment runs *after* `transformers` has already been
# imported above. huggingface_hub/transformers appear to read HF_HOME into
# module constants at import time, so this line likely does not change where
# the models below are cached. If a custom cache location is intended, set
# HF_HOME before those imports (or in the process environment) — TODO confirm.
os.environ['HF_HOME'] = '/'

# FastAPI application; the endpoints below are registered on it.
app = FastAPI()

# Enable CORS for every origin, method and header.
# NOTE(review): allow_origins=["*"] together with allow_credentials=True
# permits credentialed cross-origin requests from any site (Starlette answers
# with the caller's Origin instead of "*" in that case — verify against the
# installed version). Restrict the origin list if this API is not meant to be
# fully public.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Hugging Face hub id of the multilingual E5 sentence-embedding model.
model_name = "intfloat/multilingual-e5-large"
# Tokenizer and encoder are loaded once, when this module is imported
# (from_pretrained downloads the weights if they are not already cached).
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mean-pool token vectors along dim 1, counting only unmasked tokens.

    Positions where ``attention_mask`` is 0 are zeroed out before summing, and
    the sum is divided by the number of non-zero mask entries per row.

    Args:
        last_hidden_states: Token representations; presumably shaped
            (batch, seq_len, hidden) — the mask is broadcast over the last dim.
        attention_mask: 0/1 mask presumably shaped (batch, seq_len).

    Returns:
        The pooled tensor with dim 1 reduced away, e.g. (batch, hidden).
    """
    # Add a trailing axis so the mask broadcasts across the hidden dimension.
    keep = attention_mask.unsqueeze(-1).bool()
    masked = last_hidden_states.masked_fill(~keep, 0.0)
    # Number of real (unmasked) tokens per row, shaped to broadcast likewise.
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return masked.sum(dim=1) / token_counts

def embed_single_text(text: "str | list[str]") -> Tensor:
    """Return L2-normalized multilingual-E5 embeddings for ``text``.

    Uses the module-level ``tokenizer`` and ``model`` that are loaded once at
    import time. Previously, the tokenizer and the full model were re-loaded
    from disk on every call.

    Args:
        text: A single string or a list of strings. Each input is truncated to
            512 tokens, and the batch is padded to its longest item.

    Returns:
        A tensor with one embedding row per input text (a single string is
        encoded as a batch of one). Each row is scaled to unit L2 norm.

    NOTE(review): the E5 model card recommends prefixing inputs with
    "query: " or "passage: "; callers currently pass raw text — confirm that
    this is intended.
    """
    batch_dict = tokenizer(
        text,
        max_length=512,
        padding=True,
        truncation=True,
        return_tensors='pt',
    )

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(**batch_dict)

    # Average over real (non-padding) tokens, then normalize each row so that
    # a dot product between two embeddings equals their cosine similarity.
    embedding = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
    return F.normalize(embedding, p=2, dim=1)


@app.get("/e5_embeddings")
def e5_embeddings(query: str = Query(...)):
    """Embed the ``query`` string and return the embedding as nested lists.

    The query is wrapped in a one-element list, so the response is a list
    holding a single embedding vector.

    Raises:
        HTTPException: status 500 if the embedding helper returns None.
    """
    embedding = embed_single_text([query])

    # Guard clause: treat a missing result as an internal server error.
    if embedding is None:
        raise HTTPException(status_code=500)
    return embedding.tolist()


@app.get("/allam_response")
def allam_response(query: str = Query(...)):
    """Forward ``query`` to ``allam_llm`` and return its result unchanged.

    Raises:
        HTTPException: status 500 if ``allam_llm`` returns None.
    """
    answer = allam_llm(query)

    # Guard clause: a None result is reported as an internal server error.
    if answer is None:
        raise HTTPException(status_code=500)
    return answer