# NLP / zeroshot_clf.py
# Author: ashishraics · last commit ec4681c ("deleted model bin") · 1.66 kB
import pandas as pd
import streamlit
import torch
from transformers import AutoModelForSequenceClassification,AutoTokenizer
import numpy as np
import plotly.express as px
# Checkpoint name passed to `from_pretrained` for the sequence-classification weights.
chkpt = 'valhalla/distilbart-mnli-12-1'
# NOTE(review): the tokenizer is loaded from a local directory rather than from
# `chkpt`; presumably it was saved from the same checkpoint — confirm they match.
_TOKENIZER_DIR = 'zero_shot_clf/'

# Loaded once at import time; zero_shot_classification uses these as its defaults.
model = AutoModelForSequenceClassification.from_pretrained(chkpt)
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_DIR)
def zero_shot_classification(premise: str, labels: str, model=model, tokenizer=tokenizer,
                             hypothesis_template: str = 'this is an example of {}'):
    """Score ``premise`` against candidate labels with an NLI model and plot the result.

    Each label is inserted into ``hypothesis_template`` and paired with the
    lower-cased premise. For every pair, the model's logits at indices 0 and 2
    are soft-maxed against each other (index 1 is ignored) and the probability
    at index 2 is kept. The per-label scores are then rescaled to sum to ~100
    (rounded to one decimal) and rendered as a horizontal bar chart.

    Args:
        premise: Text to classify. Lower-cased before encoding.
        labels: Comma-separated label string (e.g. ``"science, sports"``) or a
            sequence of labels. Each label is stripped and lower-cased; empty
            entries are dropped.
        model: Sequence-classification model whose output supports
            ``output['logits']``. Defaults to the module-level model.
        tokenizer: Tokenizer paired with ``model``. Defaults to the
            module-level tokenizer.
        hypothesis_template: Hypothesis text with a single ``{}`` placeholder
            that is filled with each label.

    Returns:
        A plotly figure: bar chart of normalized probabilities per label.

    Raises:
        ValueError: If ``labels`` cannot be parsed or yields fewer than two
            non-empty labels.
    """
    try:
        raw_labels = labels.split(',') if isinstance(labels, str) else list(labels)
    except TypeError as err:
        raise ValueError("please pass atleast 2 labels to classify") from err
    # Strip so "a, b" does not yield " b" (an extra token in the hypothesis and
    # a leading space in the chart), and drop empties from stray commas.
    labels = [str(label).strip().lower() for label in raw_labels]
    labels = [label for label in labels if label]
    if len(labels) < 2:
        # With a single label the normalized score is always 100%.
        raise ValueError("please pass atleast 2 labels to classify")

    premise = premise.lower()
    labels_prob = []
    # Inference only: no_grad avoids building an autograd graph per forward pass.
    with torch.no_grad():
        for label in labels:
            hypothesis = hypothesis_template.format(label)
            encoded = tokenizer.encode(premise, hypothesis,
                                       return_tensors='pt',
                                       truncation='only_first')
            output = model(encoded)
            # Only normalize the entailment vs. contradiction logits (neutral ignored),
            # keeping the entailment probability.
            # NOTE(review): assumes label ids 0=contradiction, 2=entailment for this
            # checkpoint — confirm against model.config.label2id.
            entail_prob = output['logits'][:, [0, 2]].softmax(dim=1)[:, 1].item()
            labels_prob.append(entail_prob)

    total = np.sum(labels_prob)
    labels_prob_norm = [np.round(100 * p / total, 1) for p in labels_prob]
    df = pd.DataFrame({'labels': labels,
                       'Probability': labels_prob_norm})
    fig = px.bar(x='Probability',
                 y='labels',
                 text='Probability',
                 data_frame=df,
                 title='Zero Shot Normalized Probabilities')
    return fig
# Example usage:
# zero_shot_classification(premise='Tiny worms and breath analyzers could screen for disease while it’s early and treatable',
#                          labels='science, sports, museum')