Spaces:
Runtime error
Runtime error
File size: 1,575 Bytes
e867b58 0ca767f e867b58 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
import pandas as pd
import streamlit
import torch
from transformers import AutoModelForSequenceClassification,AutoTokenizer
import numpy as np
import plotly.express as px
# Load the sequence-classification checkpoint and its tokenizer once, at
# import time, from the local 'zero_shot_clf/' directory. Both objects are
# bound as the default `model` / `tokenizer` arguments of
# zero_shot_classification below.
_CHECKPOINT_DIR = 'zero_shot_clf/'
model = AutoModelForSequenceClassification.from_pretrained(_CHECKPOINT_DIR)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT_DIR)
def _parse_labels(labels: str) -> list:
    """Split a comma-separated label string into cleaned, lower-cased labels.

    Surrounding whitespace is stripped and empty entries are dropped, so
    ``'science, sports, museum'`` yields ``['science', 'sports', 'museum']``.

    Raises:
        ValueError: if ``labels`` is not a string or yields fewer than two
            non-empty labels.
    """
    message = "please pass atleast 2 labels to classify"
    try:
        parts = labels.split(',')
    except AttributeError as err:
        # Non-string input (e.g. None) has no .split().
        raise ValueError(message) from err
    cleaned = [part.strip().lower() for part in parts]
    cleaned = [part for part in cleaned if part]
    # A single label always normalizes to 100%, which is meaningless.
    if len(cleaned) < 2:
        raise ValueError(message)
    return cleaned


def _nli_label_indices(model) -> tuple:
    """Return ``(entailment_index, contradiction_index)`` for an NLI model.

    Looks the indices up by name in ``model.config.label2id`` (matching
    names starting with "entail" / "contra"), falling back to the
    MNLI-style ordering (contradiction=0, entailment=2) that this code
    originally hard-coded when the config uses generic label names.
    """
    entail_idx, contra_idx = 2, 0
    label2id = getattr(getattr(model, 'config', None), 'label2id', None) or {}
    for name, idx in label2id.items():
        lowered = str(name).lower()
        if lowered.startswith('entail'):
            entail_idx = idx
        elif lowered.startswith('contra'):
            contra_idx = idx
    return entail_idx, contra_idx


def _entailment_probability(premise, hypothesis, model, tokenizer,
                            entail_idx, contra_idx) -> float:
    """Return P(entailment) for one premise/hypothesis pair.

    The softmax is taken over the entailment and contradiction logits only.
    The neutral class is ignored, so the value is entailment vs.
    contradiction.
    """
    token_ids = tokenizer.encode(premise, hypothesis,
                                 return_tensors='pt',
                                 # Only the premise may be truncated, so the
                                 # hypothesis always stays intact.
                                 truncation='only_first')
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        logits = model(token_ids)['logits']
    return logits[:, [contra_idx, entail_idx]].softmax(dim=1)[:, 1].item()


def zero_shot_classification(premise: str, labels: str, model=model, tokenizer=tokenizer):
    """Score ``premise`` against each candidate label and plot the result.

    Each label is turned into the hypothesis "this is an example of <label>".
    The model's entailment-vs-contradiction probability is computed for each
    hypothesis, and the probabilities are normalized to percentages that sum
    to (roughly, after rounding) 100.

    Args:
        premise: Text to classify. It is lower-cased before scoring.
        labels: Comma-separated candidate labels (at least two), e.g.
            ``'science, sports, museum'``.
        model: NLI sequence-classification model. Defaults to the
            module-level checkpoint.
        tokenizer: Tokenizer matching ``model``.

    Returns:
        A plotly bar figure of normalized probability (%) per label.

    Raises:
        ValueError: if fewer than two non-empty labels are supplied or
            ``labels`` is not a string.
    """
    labels = _parse_labels(labels)
    premise = premise.lower()
    entail_idx, contra_idx = _nli_label_indices(model)

    labels_prob = [
        _entailment_probability(premise, f'this is an example of {label}',
                                model, tokenizer, entail_idx, contra_idx)
        for label in labels
    ]

    probs = np.asarray(labels_prob)
    labels_prob_norm = np.round(100 * probs / probs.sum(), 1)

    df = pd.DataFrame({'labels': labels,
                       'Probability': labels_prob_norm})
    fig = px.bar(x='Probability',
                 y='labels',
                 text='Probability',
                 data_frame=df,
                 title='Zero Shot Normalized Probabilities')
    return fig
# Example usage:
# zero_shot_classification(premise='Tiny worms and breath analyzers could screen for disease while it’s early and treatable',
#                          labels='science, sports, museum')
|