ashish rai commited on
Commit
0d4914a
·
1 Parent(s): 19da0ee

added script for onnx sent clf

Browse files
Files changed (1) hide show
  1. sentiment_onnx_classify.py +67 -0
sentiment_onnx_classify.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime as ort
2
+ import torch
3
+ from transformers import AutoTokenizer
4
+ import numpy as np
5
+
6
+ tokenizer=AutoTokenizer.from_pretrained("sentiment_classifier/")
7
+
8
+ #create onnx & onnx_int_8 sessions
9
+ session=ort.InferenceSession("sent_clf_onnx/sentiment_classifier_onnx.onnx")
10
+ session_int8=ort.InferenceSession("sent_clf_onnx/sentiment_classifier_onnx_int8.onnx")
11
+
12
+
13
+
14
+ def classify_sentiment_onnx(texts,_model=session,_tokenizer=tokenizer):
15
+ """
16
+ user will pass texts separated by comma
17
+ """
18
+ try:
19
+ texts=texts.split(',')
20
+ except:
21
+ pass
22
+
23
+ _inputs = _tokenizer(texts, padding=True, truncation=True,
24
+ return_tensors="np")
25
+
26
+ input_feed={
27
+ "input_ids":np.array(_inputs['input_ids']),
28
+ "attention_mask":np.array((_inputs['attention_mask']))
29
+ }
30
+
31
+ output = _model.run(input_feed=input_feed, output_names=['output_0'])[0]
32
+
33
+ output=np.argmax(output,axis=1)
34
+ output = ['Positive' if i == 1 else 'Negative' for i in output]
35
+ return output
36
+
37
+ def classify_sentiment_onnx_quant(texts, _model=session_int8, _tokenizer=tokenizer):
38
+ """
39
+ user will pass texts separated by comma
40
+ """
41
+ try:
42
+ texts=texts.split(',')
43
+ except:
44
+ pass
45
+
46
+ _inputs = _tokenizer(texts, padding=True, truncation=True,
47
+ return_tensors="np")
48
+
49
+
50
+ input_feed={
51
+ "input_ids":np.array(_inputs['input_ids']),
52
+ "attention_mask":np.array((_inputs['attention_mask']))
53
+ }
54
+
55
+ output = _model.run(input_feed=input_feed, output_names=['output_0'])[0]
56
+
57
+ output=np.argmax(output,axis=1)
58
+ output = ['Positive' if i == 1 else 'Negative' for i in output]
59
+
60
+ return output
61
+
62
+
63
+
64
+
65
+
66
+
67
+