Spaces:
Runtime error
Runtime error
Commit
·
a57c1e5
1
Parent(s):
ce75bf1
update model
Browse files- zeroshot_clf_helper.py +23 -6
zeroshot_clf_helper.py
CHANGED
@@ -5,6 +5,19 @@ import subprocess
|
|
5 |
import numpy as np
|
6 |
import pandas as pd
|
7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
def zero_shot_classification(premise: str, labels: str, model, tokenizer):
|
10 |
try:
|
@@ -40,23 +53,27 @@ def zero_shot_classification(premise: str, labels: str, model, tokenizer):
|
|
40 |
# labels='science, sports, museum')
|
41 |
|
42 |
|
43 |
-
def create_onnx_model_zs(
|
44 |
|
45 |
# create onnx model using
|
46 |
-
if not os.path.exists(
|
47 |
try:
|
48 |
subprocess.run(['python3', '-m', 'transformers.onnx',
|
49 |
'--model=valhalla/distilbart-mnli-12-1',
|
50 |
'--feature=sequence-classification',
|
51 |
-
|
52 |
-
except:
|
53 |
-
|
54 |
|
55 |
#create quanitzed model from vanila onnx
|
56 |
-
quantize_dynamic(f"{
|
|
|
|
|
57 |
else:
|
58 |
pass
|
59 |
|
|
|
|
|
60 |
def zero_shot_classification_onnx(premise,labels,_session,_tokenizer):
|
61 |
try:
|
62 |
labels=labels.split(',')
|
|
|
5 |
import numpy as np
|
6 |
import pandas as pd
|
7 |
|
8 |
+
import yaml
|
9 |
+
def read_yaml(file_path):
|
10 |
+
with open(file_path, "r") as f:
|
11 |
+
return yaml.safe_load(f)
|
12 |
+
|
13 |
+
config = read_yaml('config.yaml')
|
14 |
+
|
15 |
+
zs_chkpt=config['ZEROSHOT_CLF']['zs_chkpt']
|
16 |
+
zs_mdl_dir=config['ZEROSHOT_CLF']['zs_mdl_dir']
|
17 |
+
zs_onnx_mdl_dir=config['ZEROSHOT_CLF']['zs_onnx_mdl_dir']
|
18 |
+
zs_onnx_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_mdl_name']
|
19 |
+
zs_onnx_quant_mdl_name=config['ZEROSHOT_CLF']['zs_onnx_quant_mdl_name']
|
20 |
+
|
21 |
|
22 |
def zero_shot_classification(premise: str, labels: str, model, tokenizer):
|
23 |
try:
|
|
|
53 |
# labels='science, sports, museum')
|
54 |
|
55 |
|
56 |
+
def create_onnx_model_zs(zs_onnx_mdl_dir=zs_onnx_mdl_dir):
|
57 |
|
58 |
# create onnx model using
|
59 |
+
if not os.path.exists(zs_onnx_mdl_dir):
|
60 |
try:
|
61 |
subprocess.run(['python3', '-m', 'transformers.onnx',
|
62 |
'--model=valhalla/distilbart-mnli-12-1',
|
63 |
'--feature=sequence-classification',
|
64 |
+
zs_onnx_mdl_dir])
|
65 |
+
except Exception as e:
|
66 |
+
print(e)
|
67 |
|
68 |
#create quanitzed model from vanila onnx
|
69 |
+
quantize_dynamic(f"{zs_onnx_mdl_dir}/{zs_onnx_mdl_name}",
|
70 |
+
f"{zs_onnx_mdl_dir}/{zs_onnx_quant_mdl_name}",
|
71 |
+
weight_type=QuantType.QUInt8)
|
72 |
else:
|
73 |
pass
|
74 |
|
75 |
+
create_onnx_model_zs(zs_onnx_mdl_dir=zs_onnx_mdl_dir)
|
76 |
+
|
77 |
def zero_shot_classification_onnx(premise,labels,_session,_tokenizer):
|
78 |
try:
|
79 |
labels=labels.split(',')
|