ashishraics commited on
Commit
a57c1e5
·
1 Parent(s): ce75bf1

update model

Browse files
Files changed (1) hide show
  1. zeroshot_clf_helper.py +23 -6
zeroshot_clf_helper.py CHANGED
@@ -5,6 +5,19 @@ import subprocess
5
  import numpy as np
6
  import pandas as pd
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def zero_shot_classification(premise: str, labels: str, model, tokenizer):
10
  try:
@@ -40,23 +53,27 @@ def zero_shot_classification(premise: str, labels: str, model, tokenizer):
40
  # labels='science, sports, museum')
41
 
42
 
43
- def create_onnx_model_zs(art_path='zeroshot_onnx_dir'):
44
 
45
  # create onnx model using
46
- if not os.path.exists(art_path):
47
  try:
48
  subprocess.run(['python3', '-m', 'transformers.onnx',
49
  '--model=valhalla/distilbart-mnli-12-1',
50
  '--feature=sequence-classification',
51
- art_path])
52
- except:
53
- pass
54
 
55
  #create quanitzed model from vanila onnx
56
- quantize_dynamic(f"{art_path}/model.onnx",f"{art_path}/model_quant.onnx",weight_type=QuantType.QUInt8)
 
 
57
  else:
58
  pass
59
 
 
 
60
  def zero_shot_classification_onnx(premise,labels,_session,_tokenizer):
61
  try:
62
  labels=labels.split(',')
 
5
  import numpy as np
6
  import pandas as pd
7
 
8
+ import yaml
9
+ def read_yaml(file_path):
10
+ with open(file_path, "r") as f:
11
+ return yaml.safe_load(f)
12
+
13
+ config = read_yaml('config.yaml')
14
+
15
+ zs_chkpt=config['ZEROSHOT_CLF']['zs_chkpt']
16
+ zs_mdl_dir=config['ZEROSHOT_CLF']['zs_mdl_dir']
17
+ zs_onnx_mdl_dir=config['ZEROSHOT_CLF']['zs_onnx_mdl_dir']
18
+ zs_onnx_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_mdl_name']
19
+ zs_onnx_quant_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_quant_mdl_name']
20
+
21
 
22
  def zero_shot_classification(premise: str, labels: str, model, tokenizer):
23
  try:
 
53
  # labels='science, sports, museum')
54
 
55
 
56
+ def create_onnx_model_zs(zs_onnx_mdl_dir=zs_onnx_mdl_dir):
57
 
58
  # create onnx model using
59
+ if not os.path.exists(zs_onnx_mdl_dir):
60
  try:
61
  subprocess.run(['python3', '-m', 'transformers.onnx',
62
  '--model=valhalla/distilbart-mnli-12-1',
63
  '--feature=sequence-classification',
64
+ zs_onnx_mdl_dir])
65
+ except Exception as e:
66
+ print(e)
67
 
68
  #create quanitzed model from vanila onnx
69
+ quantize_dynamic(f"{zs_onnx_mdl_dir}/{zs_onnx_mdl_name}",
70
+ f"{zs_onnx_mdl_dir}/{zs_onnx_quant_mdl_name}",
71
+ weight_type=QuantType.QUInt8)
72
  else:
73
  pass
74
 
75
+ create_onnx_model_zs(zs_onnx_mdl_dir=zs_onnx_mdl_dir)
76
+
77
  def zero_shot_classification_onnx(premise,labels,_session,_tokenizer):
78
  try:
79
  labels=labels.split(',')