Chidam Gopal committed
Commit 2cba4b1
1 Parent(s): 7538db6

directly use the onnx quantized file

Files changed (2):
  1. infer_intent.py +19 -5
  2. requirements.txt +3 -1
infer_intent.py CHANGED
@@ -1,5 +1,9 @@
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from transformers import AutoTokenizer
 import torch
+import onnxruntime as ort
+import numpy as np
+import requests
+import os
 
 
 class IntentClassifier:
@@ -15,10 +19,20 @@ class IntentClassifier:
         self.label2id = {label:id for id,label in self.id2label.items()}
 
         self.tokenizer = AutoTokenizer.from_pretrained("Mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier")
-        self.intent_model = AutoModelForSequenceClassification.from_pretrained('Mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier',
-                                                                               num_labels=8,
-                                                                               id2label=self.id2label,
-                                                                               label2id=self.label2id)
+        model_url = "https://huggingface.co/Mozilla/mobilebert-uncased-finetuned-LoRA-intent-classifier/resolve/main/onnx/model_quantized.onnx"
+        model_dir_path = "models"
+        model_path = f"{model_dir_path}/mobilebert-uncased-finetuned-LoRA-intent-classifier_model_quantized.onnx"
+        if not os.path.exists(model_dir_path):
+            os.makedirs(model_dir_path)
+        if not os.path.exists(model_path):
+            print("Downloading ONNX model...")
+            response = requests.get(model_url)
+            with open(model_path, "wb") as f:
+                f.write(response.content)
+            print("ONNX model downloaded.")
+
+        # Load the ONNX model
+        self.ort_session = ort.InferenceSession(model_path)
 
     def find_intent(self, sequence, verbose=False):
         inputs = self.tokenizer(sequence,
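
Note: the hunk above ends before the body of find_intent, so the committed inference path is not visible in this diff. Below is a minimal sketch of how the new self.ort_session could be used in place of the removed torch model; the helper name run_quantized_intent_model and the assumption that the exported graph's input names match the tokenizer's keys (input_ids, attention_mask, token_type_ids) are illustrative, not taken from the commit.

import numpy as np

def run_quantized_intent_model(sequence, tokenizer, ort_session, id2label):
    """Hypothetical sketch: classify `sequence` with the quantized ONNX graph
    loaded in `ort_session`; the committed find_intent may differ."""
    enc = tokenizer(sequence, return_tensors="np", truncation=True, padding=True)
    # Feed only the tensors the exported graph actually declares as inputs.
    ort_inputs = {
        inp.name: enc[inp.name].astype(np.int64)
        for inp in ort_session.get_inputs()
        if inp.name in enc
    }
    logits = ort_session.run(None, ort_inputs)[0]  # shape: (batch_size, num_labels)
    return id2label[int(np.argmax(logits, axis=-1)[0])]

Filtering on ort_session.get_inputs() keeps the call working whether or not the exported graph expects token_type_ids.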
requirements.txt CHANGED
@@ -1,4 +1,6 @@
 transformers==4.45.1
 torch==2.4.1
 streamlit==1.38.0
-matplotlib==3.9.2
+matplotlib==3.9.2
+## onnx
+onnxruntime==1.19.2
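
For reference, a minimal end-to-end usage sketch under these pinned requirements could look like the following. It relies only on the IntentClassifier class and find_intent method visible in the diff; the example query string is illustrative and the shape of the return value is an assumption.

# Install the pinned dependencies first, e.g.: pip install -r requirements.txt
from infer_intent import IntentClassifier

clf = IntentClassifier()  # first run downloads models/..._model_quantized.onnx
result = clf.find_intent("what is the weather in san francisco", verbose=True)
print(result)  # exact return format depends on find_intent's implementation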