File size: 3,624 Bytes
a48f2db
 
 
 
 
 
 
a57c1e5
 
 
 
 
 
 
 
 
 
 
 
 
a48f2db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a57c1e5
a48f2db
 
a57c1e5
a48f2db
 
9111b95
a48f2db
a57c1e5
 
 
a48f2db
9811800
 
 
 
a48f2db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
from onnxruntime.quantization import quantize_dynamic,QuantType
import os
import subprocess
import numpy as np
import pandas as pd

import yaml
def read_yaml(file_path):
    """Parse the YAML file at ``file_path`` and return its content.

    Args:
        file_path: Path of the YAML file to read.

    Returns:
        The object produced by ``yaml.safe_load`` for the file.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Pin the encoding: without it ``open`` falls back to the locale's
    # preferred encoding, which is not UTF-8 on every platform.
    with open(file_path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

# Read the project configuration once at import time; every zero-shot
# setting below is taken from its ZEROSHOT_CLF section.
config = read_yaml('config.yaml')
_zs_cfg = config['ZEROSHOT_CLF']

zs_chkpt = _zs_cfg['zs_chkpt']
zs_mdl_dir = _zs_cfg['zs_mdl_dir']
zs_onnx_mdl_dir = _zs_cfg['zs_onnx_mdl_dir']
zs_onnx_mdl_name = _zs_cfg['zs_onnx_mdl_name']
zs_onnx_quant_mdl_name = _zs_cfg['zs_onnx_quant_mdl_name']


def zero_shot_classification(premise: str, labels: str, model, tokenizer):
    """Score a text against comma-separated candidate labels.

    For each label the hypothesis ``"this is an example of <label>"`` is
    encoded together with the premise and passed to ``model``. Only logit
    columns 0 and 2 are kept (the original author's note describes them as
    the entail/contradict pair); they are soft-maxed and the share of
    column 2 becomes the label's raw score. Raw scores are then rescaled so
    that they sum to 100 across all labels.

    Args:
        premise: Text to classify; lower-cased before encoding.
        labels: Candidate labels as a single comma-separated string. Labels
            are lower-cased but not stripped, so whitespace around the
            commas stays part of the label.
        model: Callable that receives the encoded input and returns a
            mapping whose ``'logits'`` entry supports ``[:, [0, 2]]``
            (presumably a sequence-classification model -- confirm against
            the caller).
        tokenizer: Object exposing ``encode(premise, hypothesis, ...)``.

    Returns:
        A ``pandas.DataFrame`` with a ``labels`` column and a
        ``Probability`` column holding each label's share in percent,
        rounded to one decimal.

    Raises:
        TypeError: If ``labels`` cannot be split and lower-cased as a
            string. ``TypeError`` is an ``Exception`` subclass, so callers
            catching ``Exception`` keep working.
    """
    try:
        candidates = [label.lower() for label in labels.split(',')]
    except (AttributeError, TypeError) as err:
        # Only "not a string" failures are translated; anything else
        # (e.g. KeyboardInterrupt) must propagate untouched.
        raise TypeError("please pass atleast 2 labels to classify") from err

    premise = premise.lower()

    scores = []
    # Inference only: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        for label in candidates:
            hypothesis = f'this is an example of {label}'

            encoded = tokenizer.encode(premise, hypothesis,
                                       return_tensors='pt',
                                       truncation_strategy='only_first')
            output = model(encoded)
            # Normalise over columns 0 and 2 only, then keep column 2's share.
            pair = output['logits'][:, [0, 2]]
            scores.append(pair.softmax(dim=1)[:, 1].item())

    # Sum once instead of once per label.
    total = np.sum(scores)
    shares = [np.round(100 * score / total, 1) for score in scores]

    return pd.DataFrame({'labels': candidates,
                         'Probability': shares})

## example
# zero_shot_classification(premise='Tiny worms and breath analyzers could screen for disease while it’s early and treatable',
#                          labels='science, sports, museum',
#                          model=model, tokenizer=tokenizer)


def create_onnx_model_zs(zs_onnx_mdl_dir=zs_onnx_mdl_dir,
                         model_chkpt='valhalla/distilbart-mnli-12-1'):
    """Export a sequence-classification checkpoint to ONNX if not done yet.

    Runs ``python -m transformers.onnx`` in a child process, writing into
    ``zs_onnx_mdl_dir``. Nothing happens when that path already exists.
    Export failures are printed rather than raised, so the call stays
    best-effort.

    Args:
        zs_onnx_mdl_dir: Output directory for the exported model.
        model_chkpt: Checkpoint name passed as ``--model=``; the default is
            the checkpoint that used to be hard-coded here.
    """
    # Local import: only this function needs the interpreter path.
    import sys

    if os.path.exists(zs_onnx_mdl_dir):
        # An existing path is treated as "already exported".
        return

    try:
        # Use the running interpreter so the export sees the same installed
        # packages; check=True turns a non-zero exit status into an error
        # instead of letting it pass unnoticed.
        subprocess.run([sys.executable or 'python3', '-m', 'transformers.onnx',
                        f'--model={model_chkpt}',
                        '--feature=sequence-classification',
                        zs_onnx_mdl_dir],
                       check=True)
    except (OSError, subprocess.SubprocessError) as e:
        print(e)

    # #create quanitzed model from vanila onnx
    # quantize_dynamic(f"{zs_onnx_mdl_dir}/{zs_onnx_mdl_name}",
    #                  f"{zs_onnx_mdl_dir}/{zs_onnx_quant_mdl_name}",
    #                  weight_type=QuantType.QUInt8)

def zero_shot_classification_onnx(premise, labels, _session, _tokenizer):
    """Score a text against comma-separated labels through an ONNX session.

    For each label the hypothesis ``"this is an example of <label>"`` is
    tokenized together with the premise and fed to ``_session``. Only logit
    columns 0 and 2 are kept (the original author's note describes them as
    the entail/contradict pair); they are soft-maxed and the share of
    column 2 becomes the label's raw score. Raw scores are then rescaled so
    that they sum to 100 across all labels.

    Args:
        premise: Text to classify; lower-cased before tokenization.
        labels: Candidate labels as a single comma-separated string. Labels
            are lower-cased but not stripped, so whitespace around the
            commas stays part of the label.
        _session: Object whose ``run(output_names=..., input_feed=...)``
            returns a sequence with the logits array first (presumably an
            ``onnxruntime.InferenceSession`` -- confirm against the caller).
        _tokenizer: Callable returning a mapping with ``'input_ids'`` and
            ``'attention_mask'`` entries.

    Returns:
        A ``pandas.DataFrame`` with a ``labels`` column and a
        ``Probability`` column holding each label's share in percent,
        rounded to one decimal.

    Raises:
        TypeError: If ``labels`` cannot be split and lower-cased as a
            string. ``TypeError`` is an ``Exception`` subclass, so callers
            catching ``Exception`` keep working.
    """
    try:
        candidates = [label.lower() for label in labels.split(',')]
    except (AttributeError, TypeError) as err:
        # Only "not a string" failures are translated; anything else
        # (e.g. KeyboardInterrupt) must propagate untouched.
        raise TypeError("please pass atleast 2 labels to classify") from err

    premise = premise.lower()

    scores = []
    for label in candidates:
        hypothesis = f'this is an example of {label}'

        inputs = _tokenizer(premise, hypothesis,
                            return_tensors='pt',
                            truncation_strategy='only_first')

        # The session is fed numpy arrays, so convert the tokenizer output.
        input_feed = {
            "input_ids": np.array(inputs['input_ids']),
            "attention_mask": np.array(inputs['attention_mask']),
        }

        logits = _session.run(output_names=["logits"],
                              input_feed=input_feed)[0]
        # Normalise over columns 0 and 2 only, then keep column 2's share.
        pair = torch.from_numpy(logits)[:, [0, 2]]
        scores.append(pair.softmax(dim=1)[:, 1].item())

    # Sum once instead of once per label.
    total = np.sum(scores)
    shares = [np.round(100 * score / total, 1) for score in scores]

    return pd.DataFrame({'labels': candidates,
                         'Probability': shares})