Chidam Gopal commited on
Commit
84af19a
1 Parent(s): 2cba4b1

updates for onnx

Browse files
Files changed (1) hide show
  1. infer_intent.py +19 -14
infer_intent.py CHANGED
@@ -41,20 +41,25 @@ class IntentClassifier:
41
  truncation=True, # Truncate if the text is too long
42
  max_length=64)
43
 
44
- self.intent_model.eval()
45
- with torch.no_grad():
46
- outputs = self.intent_model(**inputs)
47
- logits = outputs.logits
48
- prediction = torch.argmax(logits, dim=1).item()
49
- probabilities = torch.softmax(logits, dim=1)
50
- rounded_probabilities = torch.round(probabilities, decimals=3)
51
-
52
- pred_result = self.id2label[prediction]
53
- proba_result = dict(zip(self.label2id.keys(), rounded_probabilities.tolist()[0]))
54
- if verbose:
55
- print(sequence + " -> " + pred_result)
56
- print(proba_result, "\n")
57
- return pred_result, proba_result
 
 
 
 
 
58
 
59
 
60
  def main():
 
41
  truncation=True, # Truncate if the text is too long
42
  max_length=64)
43
 
44
+ # Convert inputs to NumPy arrays
45
+ onnx_inputs = {k: v for k, v in inputs.items()}
46
+
47
+ # Run the ONNX model
48
+ logits = self.ort_session.run(None, onnx_inputs)[0]
49
+
50
+ # Get the prediction
51
+ prediction = np.argmax(logits, axis=1)[0]
52
+ probabilities = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
53
+ rounded_probabilities = np.round(probabilities, decimals=3)
54
+
55
+ pred_result = self.id2label[prediction]
56
+ proba_result = dict(zip(self.label2id.keys(), rounded_probabilities[0].tolist()))
57
+
58
+ if verbose:
59
+ print(sequence + " -> " + pred_result)
60
+ print(proba_result, "\n")
61
+
62
+ return pred_result, proba_result
63
 
64
 
65
  def main():