C-Stuti commited on
Commit
901aaed
·
verified ·
1 Parent(s): 4a6b688

Upload 3 files

Browse files
Files changed (4) hide show
  1. .gitattributes +1 -0
  2. handler.py +121 -0
  3. model_optimized.onnx +3 -0
  4. onnx-mxbai.mar +3 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ onnx-mxbai.mar filter=lfs diff=lfs merge=lfs -text
handler.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Pipeline
2
+ import torch.nn.functional as F
3
+ import torch
4
+ from ts.torch_handler.base_handler import BaseHandler
5
+ import logging
6
+ import os
7
+ import transformers
8
+ from transformers import AutoTokenizer
9
+ logger = logging.getLogger(__name__)
10
+ logger.info("Transformers version %s", transformers.__version__)
11
+ from optimum.onnxruntime import ORTModelForFeatureExtraction
12
+
13
+ def mean_pooling(model_output, attention_mask):
14
+ token_embeddings = model_output[0]
15
+ input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
16
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
17
+
18
+ class SentenceEmbeddingHandler(BaseHandler):
19
+ def __init__(self):
20
+ super(SentenceEmbeddingHandler, self).__init__()
21
+ self._context = None
22
+ self.initialized = False
23
+ class SentenceEmbeddingPipeline(Pipeline):
24
+ def initialize(self, context):
25
+ """
26
+ Initialize function loads the model and the tokenizer
27
+
28
+ Args:
29
+ context (context): It is a JSON Object containing information
30
+ pertaining to the model artifacts parameters.
31
+
32
+ Raises:
33
+ RuntimeError: Raises the Runtime error when the model or
34
+ tokenizer is missing
35
+ """
36
+
37
+ properties = context.system_properties
38
+ self.manifest = context.manifest
39
+ model_dir = properties.get("model_dir")
40
+
41
+ # use GPU if available
42
+ self.device = torch.device(
43
+ "cuda:" + str(properties.get("gpu_id"))
44
+ if torch.cuda.is_available() and properties.get("gpu_id") is not None
45
+ else "cpu"
46
+ )
47
+ logger.info(f'Using device {self.device}')
48
+
49
+ # load the model
50
+ model_file = self.manifest['model']['modelFile']
51
+ model_path = os.path.join(model_dir, model_file)
52
+
53
+ if os.path.isfile(model_path):
54
+ # self.model = AutoModel.from_pretrained(model_dir)
55
+ self.model = ORTModelForFeatureExtraction.from_pretrained(model_dir, file_name="model_optimized.onnx")
56
+ self.model.to(self.device)
57
+
58
+ logger.info(f'Successfully loaded model from {model_file}')
59
+ else:
60
+ raise RuntimeError('Missing the model file')
61
+
62
+ # load tokenizer
63
+ self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
64
+ if self.tokenizer is not None:
65
+ logger.info('Successfully loaded tokenizer')
66
+ else:
67
+ raise RuntimeError('Missing tokenizer')
68
+
69
+ self.initialized = True
70
+ def _sanitize_parameters(self, **kwargs):
71
+ # we don't have any hyperameters to sanitize
72
+ preprocess_kwargs = {}
73
+ return preprocess_kwargs, {}, {}
74
+
75
+ def preprocess_text(self, inputs):
76
+ encoded_inputs = self.tokenizer(inputs, padding=True, truncation=True, return_tensors='pt')
77
+ return encoded_inputs
78
+
79
+ def preprocess(self, requests):
80
+ """
81
+ Tokenize the input text using the suitable tokenizer and convert
82
+ it to tensor
83
+
84
+ If token_ids is provided, the json must be of the form
85
+ {'input_ids': [[101, 102]], 'token_type_ids': [[0, 0]], 'attention_mask': [[1, 1]]}
86
+
87
+ Args:
88
+ requests: A list containing a dictionary, might be in the form
89
+ of [{'body': json_file}] or [{'data': json_file}] or [{'token_ids': json_file}]
90
+ Returns:
91
+ the tensor containing the batch of token vectors.
92
+ """
93
+
94
+ # unpack the data
95
+ data = requests[0].get('body')
96
+ if data is None:
97
+ data = requests[0].get('data')
98
+
99
+ texts = data.get('input')
100
+ if texts is not None:
101
+ logger.info('Text provided')
102
+ return self.preprocess_text(texts)
103
+
104
+ encodings = data.get('encodings')
105
+ if encodings is not None:
106
+ logger.info('Encodings provided')
107
+ return transformers.BatchEncoding(data={k: torch.tensor(v) for k, v in encodings.items()})
108
+
109
+ raise Exception("unsupported payload")
110
+ def inference(self, model_inputs):
111
+ outputs = self.model(**model_inputs)
112
+ sentence_embeddings = mean_pooling(outputs, model_inputs['attention_mask'])
113
+ sentence_embeddings = F.normalize(sentence_embeddings, p=2, dim=1)
114
+ return sentence_embeddings
115
+
116
+ def postprocess(self, outputs):
117
+ formatted_outputs = []
118
+ data=[outputs.tolist()]
119
+ for dat in data:
120
+ formatted_outputs.append({"status":"success","data":dat})
121
+ return formatted_outputs
model_optimized.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4c8222c98632a250933d2e1685aa2ba6bd8003cbcf13bf20d91f32b6965974f6
3
+ size 1336607159
onnx-mxbai.mar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:329b7ba82830bf75553d3e3024d9442b2b0a8d8cb81042c2d214e1d139b43099
3
+ size 592590703