Spaces:
Runtime error
Runtime error
ashish rai
commited on
Commit
·
13315ca
1
Parent(s):
81f2986
updated zeroshot
Browse files- zeroshot_clf.py +2 -2
zeroshot_clf.py
CHANGED
@@ -8,7 +8,7 @@ import plotly.express as px
|
|
8 |
model=AutoModelForSequenceClassification.from_pretrained('zero_shot_clf/')
|
9 |
tokenizer=AutoTokenizer.from_pretrained('zero_shot_clf/')
|
10 |
|
11 |
-
def zero_shot_classification(premise:str,labels:str,model=model,tokenizer=tokenizer):
|
12 |
try:
|
13 |
labels=labels.split(',')
|
14 |
labels=[l.lower() for l in labels]
|
@@ -27,7 +27,7 @@ def zero_shot_classification(premise:str,labels:str,model=model,tokenizer=tokeni
|
|
27 |
return_tensors='pt',
|
28 |
truncation_strategy='only_first')
|
29 |
output = model(input)
|
30 |
-
entail_contra_prob = output['logits'][:,[0,2]].softmax(dim=1)[:,1].item()
|
31 |
labels_prob.append(entail_contra_prob)
|
32 |
|
33 |
labels_prob_norm=[np.round(100*c/np.sum(labels_prob),1) for c in labels_prob]
|
|
|
8 |
model=AutoModelForSequenceClassification.from_pretrained('zero_shot_clf/')
|
9 |
tokenizer=AutoTokenizer.from_pretrained('zero_shot_clf/')
|
10 |
|
11 |
+
def zero_shot_classification(premise: str, labels: str, model= model, tokenizer= tokenizer):
|
12 |
try:
|
13 |
labels=labels.split(',')
|
14 |
labels=[l.lower() for l in labels]
|
|
|
27 |
return_tensors='pt',
|
28 |
truncation_strategy='only_first')
|
29 |
output = model(input)
|
30 |
+
entail_contra_prob = output['logits'][:,[0,2]].softmax(dim=1)[:,1].item() #only normalizing entail & contradict probabilties
|
31 |
labels_prob.append(entail_contra_prob)
|
32 |
|
33 |
labels_prob_norm=[np.round(100*c/np.sum(labels_prob),1) for c in labels_prob]
|