# NLP / zeroshot_clf.py
# ashish rai — "updated zero shot fig" (0ca767f)
import pandas as pd
import streamlit
import torch
from transformers import AutoModelForSequenceClassification,AutoTokenizer
import numpy as np
import plotly.express as px
# Load the sequence-classification checkpoint and its tokenizer from the local
# 'zero_shot_clf/' directory once, at import time; zero_shot_classification
# binds these objects as its default `model` / `tokenizer` arguments.
model, tokenizer = (
    AutoModelForSequenceClassification.from_pretrained('zero_shot_clf/'),
    AutoTokenizer.from_pretrained('zero_shot_clf/'),
)
def zero_shot_classification(premise: str, labels: str, model=model, tokenizer=tokenizer):
    """Score *premise* against comma-separated candidate *labels* and plot the result.

    Each label is turned into the hypothesis ``"this is an example of <label>"``
    and encoded together with the lower-cased premise. For every pair, the
    model's logits at columns 0 and 2 are softmaxed against each other and the
    column-2 probability is kept. The per-label scores are then rescaled to
    percentages (rounded to one decimal) and drawn as a Plotly bar chart.

    Args:
        premise: Text to classify.
        labels: Comma-separated candidate labels, e.g. ``"science, sports, museum"``.
            Surrounding whitespace is stripped, labels are lower-cased, and empty
            entries (e.g. from a trailing comma) are dropped.
        model: Sequence-classification model; defaults to the module-level one.
        tokenizer: Matching tokenizer; defaults to the module-level one.

    Returns:
        A plotly figure with one bar per label showing its normalized
        probability in percent.

    Raises:
        TypeError: if *labels* is not a string.
        ValueError: if *labels* yields fewer than two non-empty labels.
    """
    if not isinstance(labels, str):
        raise TypeError("labels must be a comma-separated string")
    # Strip whitespace so "science, sports" yields "sports", not " sports",
    # and drop empty entries left by stray or trailing commas.
    label_list = [part.strip().lower() for part in labels.split(',')]
    label_list = [part for part in label_list if part]
    # With a single label the normalized score is always 100%, so require two.
    if len(label_list) < 2:
        raise ValueError("please pass atleast 2 labels to classify")

    premise = premise.lower()
    scores = []
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for label in label_list:
            hypothesis = f'this is an example of {label}'
            input_ids = tokenizer.encode(premise, hypothesis,
                                         return_tensors='pt',
                                         truncation='only_first')
            logits = model(input_ids)['logits']
            # NOTE(review): assumes logit column 0 = contradiction and column 2 =
            # entailment (the MNLI ordering used by e.g. bart-large-mnli); confirm
            # against model.config.id2label. Column 1 (neutral) is discarded.
            scores.append(logits[:, [0, 2]].softmax(dim=1)[:, 1].item())

    # Rescale so the labels' scores form a percentage split; sum computed once.
    scores = np.asarray(scores)
    labels_prob_norm = np.round(100 * scores / scores.sum(), 1)

    df = pd.DataFrame({'labels': label_list,
                       'Probability': labels_prob_norm})
    fig = px.bar(x='Probability',
                 y='labels',
                 text='Probability',
                 data_frame=df,
                 title='Zero Shot Normalized Probabilities')
    return fig
# Example usage:
# zero_shot_classification(premise='Tiny worms and breath analyzers could screen for disease while it’s early and treatable',
#                          labels='science, sports, museum')