Spaces:
Sleeping
Sleeping
Chidam Gopal
commited on
Commit
•
84af19a
1
Parent(s):
2cba4b1
updates for onnx
Browse files- infer_intent.py +19 -14
infer_intent.py
CHANGED
@@ -41,20 +41,25 @@ class IntentClassifier:
|
|
41 |
truncation=True, # Truncate if the text is too long
|
42 |
max_length=64)
|
43 |
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
|
|
|
|
|
|
|
|
|
|
58 |
|
59 |
|
60 |
def main():
|
|
|
41 |
truncation=True, # Truncate if the text is too long
|
42 |
max_length=64)
|
43 |
|
44 |
+
# Convert inputs to NumPy arrays
|
45 |
+
onnx_inputs = {k: v for k, v in inputs.items()}
|
46 |
+
|
47 |
+
# Run the ONNX model
|
48 |
+
logits = self.ort_session.run(None, onnx_inputs)[0]
|
49 |
+
|
50 |
+
# Get the prediction
|
51 |
+
prediction = np.argmax(logits, axis=1)[0]
|
52 |
+
probabilities = np.exp(logits) / np.sum(np.exp(logits), axis=1, keepdims=True)
|
53 |
+
rounded_probabilities = np.round(probabilities, decimals=3)
|
54 |
+
|
55 |
+
pred_result = self.id2label[prediction]
|
56 |
+
proba_result = dict(zip(self.label2id.keys(), rounded_probabilities[0].tolist()))
|
57 |
+
|
58 |
+
if verbose:
|
59 |
+
print(sequence + " -> " + pred_result)
|
60 |
+
print(proba_result, "\n")
|
61 |
+
|
62 |
+
return pred_result, proba_result
|
63 |
|
64 |
|
65 |
def main():
|