NLP / zeroshot_clf.py
ashish rai
updated zeroshot
13315ca
raw
history blame
1.63 kB
import pandas as pd
import streamlit
import torch
from transformers import AutoModelForSequenceClassification,AutoTokenizer
import numpy as np
import plotly.express as px
# Local checkpoint directory shared by the classifier and its tokenizer.
# Presumably a fine-tuned NLI (entailment) checkpoint, given how
# zero_shot_classification reads its logits — confirm against model.config.
_CHECKPOINT_DIR = 'zero_shot_clf/'

# Both objects are loaded once at import time and bound as the default
# `model` / `tokenizer` arguments of zero_shot_classification below.
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT_DIR)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT_DIR)
def _parse_labels(labels):
    """Turn comma-separated (or iterable) labels into non-empty, lower-cased strings.

    Raises:
        TypeError: if ``labels`` is neither a string nor an iterable.
        ValueError: if fewer than two non-empty labels remain.
    """
    if isinstance(labels, str):
        raw_labels = labels.split(',')
    else:
        try:
            raw_labels = list(labels)
        except TypeError as err:
            raise TypeError("labels must be a comma-separated string or an iterable of strings") from err
    # Strip the whitespace left by "a, b, c" style input and drop empty entries
    # (e.g. from a trailing comma) so they are not scored as labels.
    cleaned = [str(label).strip().lower() for label in raw_labels]
    cleaned = [label for label in cleaned if label]
    if len(cleaned) < 2:
        raise ValueError("please pass atleast 2 labels to classify")
    return cleaned


def _entailment_probability(premise, label, model, tokenizer):
    """Return the probability of logit column 2 vs column 0 for ``premise`` and ``label``'s hypothesis."""
    hypothesis = f'this is an example of {label}'
    # Truncate only the first sequence (the premise), never the hypothesis.
    input_ids = tokenizer.encode(premise, hypothesis,
                                 return_tensors='pt',
                                 truncation='only_first')
    with torch.no_grad():  # inference only: no autograd graph needed
        logits = model(input_ids)['logits']
    # Drop column 1 and renormalize over columns 0 and 2 only, keeping column 2.
    # NOTE(review): assumes id2label order is (contradiction, neutral, entailment),
    # as in MNLI-style heads — confirm against model.config.id2label.
    return logits[:, [0, 2]].softmax(dim=1)[:, 1].item()


def zero_shot_classification(premise: str, labels: str, model= model, tokenizer= tokenizer):
    """Score ``premise`` against each candidate label and return a bar chart.

    Each label becomes the hypothesis ``"this is an example of <label>"``. It is
    scored against the lower-cased premise by ``_entailment_probability``. The
    per-label scores are then rescaled to percentages (rounded to one decimal)
    that sum to roughly 100.

    Args:
        premise: Text to classify. It is lower-cased before encoding.
        labels: Comma-separated candidate labels (e.g. ``"science, sports"``) or
            an iterable of label strings. Labels are stripped and lower-cased,
            and empty entries are ignored.
        model: Sequence-classification model. Defaults to the module-level model.
        tokenizer: Tokenizer matching ``model``. Defaults to the module-level
            tokenizer.

    Returns:
        A plotly bar figure with one horizontal bar per label showing its
        normalized probability.

    Raises:
        ValueError: if fewer than two non-empty labels are given.
        TypeError: if ``labels`` is neither a string nor an iterable.
    """
    labels = _parse_labels(labels)
    premise = premise.lower()

    labels_prob = [_entailment_probability(premise, label, model, tokenizer)
                   for label in labels]
    total = np.sum(labels_prob)  # computed once rather than once per label
    labels_prob_norm = [np.round(100 * prob / total, 1) for prob in labels_prob]

    df = pd.DataFrame({'labels': labels,
                       'Probability': labels_prob_norm})
    fig = px.bar(x='Probability',
                 y='labels',
                 text='Probability',
                 data_frame=df,
                 title='Zero Shot Normalized Probabilities')
    return fig
# zero_shot_classification(premise='Tiny worms and breath analyzers could screen for disease while it’s early and treatable',
# labels='science, sports, museum')