File size: 1,575 Bytes
e867b58
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ca767f
 
 
 
 
 
 
 
e867b58
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import pandas as pd
import streamlit
import torch
from transformers import AutoModelForSequenceClassification,AutoTokenizer
import numpy as np
import plotly.express as px

# Load the sequence-classification model and its tokenizer once, at import
# time, from the local checkpoint directory. Both module-level objects are
# used as the default `model` / `tokenizer` of zero_shot_classification.
_CHECKPOINT_DIR = 'zero_shot_clf/'
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT_DIR)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT_DIR)

def _parse_labels(labels):
    """Return ``labels`` as a list of stripped, lower-cased, non-empty strings.

    ``labels`` may be a comma-separated string or any iterable of strings.
    Raises ``ValueError`` when fewer than two usable labels remain.
    """
    # A plain string is split on commas; any other iterable is used as-is.
    parts = labels.split(',') if isinstance(labels, str) else labels
    try:
        cleaned = [part.strip().lower() for part in parts]
    except (TypeError, AttributeError):
        # labels was not iterable (e.g. None) or contained non-string items.
        cleaned = []
    cleaned = [part for part in cleaned if part]
    if len(cleaned) < 2:
        raise ValueError("please pass atleast 2 labels to classify")
    return cleaned


def zero_shot_classification(premise: str, labels: str, model=model, tokenizer=tokenizer):
    """Score ``premise`` against each candidate label and return a bar chart.

    Each label is turned into the hypothesis ``"this is an example of <label>"``
    and the (premise, hypothesis) pair is scored by ``model``. For each label,
    the logits at class indices 0 and 2 are softmaxed against each other, and
    the probability at index 2 is kept. The kept values are then rescaled to
    percentages that sum to (approximately, after rounding) 100 across labels.

    Args:
        premise: Text to classify. It is lower-cased before encoding.
        labels: Candidate labels, either as a comma-separated string
            (e.g. ``"science, sports, museum"``) or as an iterable of strings.
            Each label is stripped of surrounding whitespace and lower-cased,
            and empty entries are ignored. At least two labels are required.
        model: Classifier called as ``model(input_ids)``. Its output must
            support ``output['logits']``. Defaults to the module-level model.
        tokenizer: Tokenizer paired with ``model``. Defaults to the
            module-level tokenizer.

    Returns:
        A plotly express bar figure with ``Probability`` (percent, rounded to
        one decimal) on the x axis and the labels on the y axis.

    Raises:
        ValueError: if fewer than two non-empty labels are supplied. This is a
            subclass of ``Exception``, so existing ``except Exception``
            handlers still catch it.
    """
    label_list = _parse_labels(labels)
    premise = premise.lower()

    labels_prob = []
    # Inference only: disable autograd so no computation graph is built.
    with torch.no_grad():
        for label in label_list:
            hypothesis = f'this is an example of {label}'
            # Encode as a sentence pair. If the pair is too long, truncate only
            # the premise (first sequence) so the hypothesis stays intact.
            input_ids = tokenizer.encode(premise, hypothesis,
                                         return_tensors='pt',
                                         truncation='only_first')
            output = model(input_ids)
            # NOTE(review): assumes class 0 = contradiction and class 2 =
            # entailment (the usual MNLI label order). Confirm against
            # model.config.id2label for this checkpoint.
            entail_prob = output['logits'][:, [0, 2]].softmax(dim=1)[:, 1].item()
            labels_prob.append(entail_prob)

    # Rescale the independent per-label scores to percentages across labels.
    probs = np.asarray(labels_prob)
    labels_prob_norm = np.round(100 * probs / probs.sum(), 1)

    df = pd.DataFrame({'labels': label_list,
                       'Probability': labels_prob_norm})
    fig = px.bar(x='Probability',
                 y='labels',
                 text='Probability',
                 data_frame=df,
                 title='Zero Shot Normalized Probabilities')
    return fig

# Example usage:
# zero_shot_classification(premise='Tiny worms and breath analyzers could screen for disease while it’s early and treatable',
#                          labels='science, sports, museum')