cguynup committed
Commit ed8a7fb · 1 Parent(s): 29013a7

Upload 2 files

Files changed (2)
  1. handler.py +26 -0
  2. requirements.txt +63 -0
handler.py ADDED
@@ -0,0 +1,26 @@
+ from optimum.onnxruntime import ORTModelForSequenceClassification
+ from transformers import AutoTokenizer
+ import torch
+
+
+ class EndpointHandler():
+     def __init__(self, path=""):
+         # load the optimized model
+         self.model = ORTModelForSequenceClassification.from_pretrained(path)
+         self.tokenizer = AutoTokenizer.from_pretrained(path)
+
+
+     def __call__(self, data):
+
+         answers = data.pop("answers")
+         paraphrases = data.pop("paraphrases")
+
+         inputs = self.tokenizer(answers, paraphrases, max_length=253, padding=True, truncation=True, return_tensors='pt')
+
+         with torch.no_grad():
+             outputs = self.model(**inputs)
+
+         logits = outputs.logits
+         predictions = torch.argmax(logits, dim=-1).numpy()
+
+         return list(predictions)
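
The handler above follows the custom-handler convention used by Hugging Face Inference Endpoints: the service instantiates EndpointHandler with the model repository path and then calls it with the parsed request payload. A minimal local smoke test might look like the sketch below; the model directory "." and the example "answers"/"paraphrases" strings are placeholders for illustration, not part of this commit.

# Hypothetical local check of the handler above; assumes the current
# directory holds the exported ONNX model and tokenizer files.
from handler import EndpointHandler

handler = EndpointHandler(path=".")

payload = {
    "answers": ["The capital of France is Paris."],
    "paraphrases": ["Paris is France's capital city."],
}

# Returns one predicted class index per answer/paraphrase pair.
predictions = handler(payload)
print(predictions)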
requirements.txt ADDED
@@ -0,0 +1,63 @@
+ aiohttp==3.9.0
+ aiosignal==1.3.1
+ attrs==23.1.0
+ certifi==2023.11.17
+ charset-normalizer==3.3.2
+ coloredlogs==15.0.1
+ datasets==2.15.0
+ dill==0.3.7
+ evaluate==0.4.1
+ filelock==3.13.1
+ flatbuffers==23.5.26
+ frozenlist==1.4.0
+ fsspec==2023.10.0
+ huggingface-hub==0.19.4
+ humanfriendly==10.0
+ idna==3.4
+ Jinja2==3.1.2
+ MarkupSafe==2.1.3
+ mpmath==1.3.0
+ multidict==6.0.4
+ multiprocess==0.70.15
+ networkx==3.2.1
+ numpy==1.26.2
+ nvidia-cublas-cu12==12.1.3.1
+ nvidia-cuda-cupti-cu12==12.1.105
+ nvidia-cuda-nvrtc-cu12==12.1.105
+ nvidia-cuda-runtime-cu12==12.1.105
+ nvidia-cudnn-cu12==8.9.2.26
+ nvidia-cufft-cu12==11.0.2.54
+ nvidia-curand-cu12==10.3.2.106
+ nvidia-cusolver-cu12==11.4.5.107
+ nvidia-cusparse-cu12==12.1.0.106
+ nvidia-nccl-cu12==2.18.1
+ nvidia-nvjitlink-cu12==12.3.101
+ nvidia-nvtx-cu12==12.1.105
+ onnx==1.15.0
+ onnxruntime==1.16.3
+ optimum==1.14.1
+ packaging==23.2
+ pandas==2.1.3
+ protobuf==4.25.1
+ pyarrow==14.0.1
+ pyarrow-hotfix==0.6
+ python-dateutil==2.8.2
+ pytz==2023.3.post1
+ PyYAML==6.0.1
+ regex==2023.10.3
+ requests==2.31.0
+ responses==0.18.0
+ safetensors==0.4.0
+ sentencepiece==0.1.99
+ six==1.16.0
+ sympy==1.12
+ tokenizers==0.15.0
+ torch==2.1.1
+ tqdm==4.66.1
+ transformers==4.35.2
+ triton==2.1.0
+ typing_extensions==4.8.0
+ tzdata==2023.3
+ urllib3==2.1.0
+ xxhash==3.4.1
+ yarl==1.9.3