File size: 1,511 Bytes
4e5c5cb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a3a3881
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from sentence_transformers import SentenceTransformer
import litserve as ls
from huggingface_hub import login
import os

# Authenticate with the Hugging Face Hub only when a token is actually
# provided. Calling login(token=None) makes huggingface_hub fall back to
# an interactive prompt, which hangs non-interactive deployments.
_hf_token = os.getenv("HF_TOKEN")
if _hf_token:
    login(token=_hf_token)

# Deployment configuration, injected via environment variables:
#   DATA_PATH              -- base directory holding the local model folders
#   RETRIEVAL_MODEL_NAME   -- subdirectory of the retrieval ("default") model
#   SIMILARITY_MODEL_NAME  -- subdirectory of the similarity model
# NOTE(review): these may be None if unset; setup() would then fail on
# os.path.join — presumably they are guaranteed by the deploy environment.
DATA_PATH = os.getenv("DATA_PATH")
RETRIEVAL_MODEL_NAME = os.getenv("RETRIEVAL_MODEL_NAME")
SIMILARITY_MODEL_NAME = os.getenv("SIMILARITY_MODEL_NAME")

class EmbeddingModelAPI(ls.LitAPI):
    """LitServe API that serves two ONNX-backed sentence-embedding models.

    A request selects which model to use via its ``"type"`` field:
    ``"default"`` routes to the retrieval model, ``"similarity"`` to the
    similarity model.
    """

    def setup(self, device):
        """Load both ONNX SentenceTransformer models from local DATA_PATH folders."""
        self.retrieval_model = SentenceTransformer(
            os.path.join(DATA_PATH, RETRIEVAL_MODEL_NAME),
            backend="onnx",
            model_kwargs={"file_name": "onnx/model.onnx"},
            trust_remote_code=True,
        )
        self.similarity_model = SentenceTransformer(
            os.path.join(DATA_PATH, SIMILARITY_MODEL_NAME),
            backend="onnx",
            model_kwargs={"file_name": "onnx/model.onnx"},
            trust_remote_code=True,
        )

    def decode_request(self, request, **kwargs):
        """Extract ``(sentences, embedding_type)`` from the request payload.

        Missing ``"type"`` falls back to ``"default"`` so clients that omit
        it still get retrieval embeddings instead of a bare KeyError.
        """
        sentences = request["sentences"]
        # Renamed from `type` to avoid shadowing the builtin.
        embedding_type = request.get("type", "default")
        return sentences, embedding_type

    def predict(self, x, **kwargs):
        """Encode the chunks with the model selected by the embedding type.

        Returns the embeddings as a plain nested list (JSON-serializable).
        Raises ValueError on an unknown type — previously this silently
        returned None, giving clients an unexplained ``{"data": null}``.
        """
        chunks, embedding_type = x
        if embedding_type == "default":
            return self.retrieval_model.encode(chunks).tolist()
        if embedding_type == "similarity":
            return self.similarity_model.encode(chunks).tolist()
        raise ValueError(f"Unknown embedding type: {embedding_type!r}")

    def encode_response(self, output, **kwargs):
        """Wrap the embeddings in the response envelope expected by clients."""
        return {"data": output}
    

if __name__ == "__main__":
    # Stand up the LitServe HTTP server for the embedding API on port 7860;
    # client-file generation is disabled for this deployment.
    embedding_api = EmbeddingModelAPI()
    embedding_server = ls.LitServer(embedding_api)
    embedding_server.run(generate_client_file=False, port=7860)