ashish rai committed
Commit 13315ca · 1 Parent(s): 81f2986

updated zeroshot

Files changed (1)
  1. zeroshot_clf.py +2 -2
zeroshot_clf.py CHANGED
@@ -8,7 +8,7 @@ import plotly.express as px
 model=AutoModelForSequenceClassification.from_pretrained('zero_shot_clf/')
 tokenizer=AutoTokenizer.from_pretrained('zero_shot_clf/')
 
-def zero_shot_classification(premise:str,labels:str,model=model,tokenizer=tokenizer):
+def zero_shot_classification(premise: str, labels: str, model= model, tokenizer= tokenizer):
     try:
         labels=labels.split(',')
         labels=[l.lower() for l in labels]
@@ -27,7 +27,7 @@ def zero_shot_classification(premise:str,labels:str,model=model,tokenizer=tokeni
             return_tensors='pt',
             truncation_strategy='only_first')
         output = model(input)
-        entail_contra_prob = output['logits'][:,[0,2]].softmax(dim=1)[:,1].item()
+        entail_contra_prob = output['logits'][:,[0,2]].softmax(dim=1)[:,1].item() #only normalizing entail & contradict probabilties
         labels_prob.append(entail_contra_prob)
 
         labels_prob_norm=[np.round(100*c/np.sum(labels_prob),1) for c in labels_prob]
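
For context, here is a minimal, self-contained sketch of the classifier this diff touches, reconstructed from the visible hunks. The `'zero_shot_clf/'` checkpoint path and the changed lines come from the commit; the hypothesis template `This example is {label}.`, the MNLI-style label order `[contradiction, neutral, entailment]`, and the returned dict are assumptions added for illustration and are not confirmed by the diff.

```python
# Minimal sketch of an NLI-based zero-shot classifier, assuming an MNLI-style
# checkpoint whose label order is [contradiction, neutral, entailment]
# (e.g. facebook/bart-large-mnli). The hypothesis template below is illustrative.
import numpy as np
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

model = AutoModelForSequenceClassification.from_pretrained('zero_shot_clf/')
tokenizer = AutoTokenizer.from_pretrained('zero_shot_clf/')

def zero_shot_classification(premise: str, labels: str, model=model, tokenizer=tokenizer):
    # Comma-separated label string -> list of lowercase candidate labels.
    labels = [l.strip().lower() for l in labels.split(',')]
    labels_prob = []
    for label in labels:
        # Score (premise, hypothesis) with the NLI head; one hypothesis per label.
        inputs = tokenizer(premise,
                           f'This example is {label}.',   # assumed template
                           return_tensors='pt',
                           truncation=True)
        with torch.no_grad():
            logits = model(**inputs).logits
        # Softmax over contradiction vs. entailment only (columns 0 and 2),
        # keeping the entailment probability, which is the line changed in this commit.
        entail_contra_prob = logits[:, [0, 2]].softmax(dim=1)[:, 1].item()
        labels_prob.append(entail_contra_prob)
    # Normalize across labels so the reported percentages sum to roughly 100.
    labels_prob_norm = [np.round(100 * c / np.sum(labels_prob), 1) for c in labels_prob]
    return dict(zip(labels, labels_prob_norm))
```

Under these assumptions, a call like `zero_shot_classification('The battery lasts all day', 'positive,negative')` returns a dict mapping each candidate label to its normalized percentage.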