zhenyundeng committed on
Commit 7156297 · 1 Parent(s): 287a9db

update app.py

Files changed (1)
  1. app.py +11 -6
app.py CHANGED
@@ -75,16 +75,20 @@ LABEL = [
 ]
 
 # Veracity
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
+# device = "cuda:0" if torch.cuda.is_available() else "cpu"
 veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
 bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
 veracity_checkpoint_path = os.getcwd() + "/averitec/pretrained_models/bert_veracity.ckpt"
-veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to(device)
+veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model)
+# veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to(device)
+
 # Justification
 justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
 bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
 best_checkpoint = os.getcwd()+ '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
-justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
+justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model)
+# justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
+
 
 print("veracity_model_device_0:{}".format(veracity_model.device))
 print("justification_model_device_0:{}".format(justification_model.device))
@@ -281,8 +285,8 @@ def veracity_prediction(claim, evidence):
         return pred_label
 
     tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
-    example_support = torch.argmax(
-        veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
+    example_support = torch.argmax(veracity_model(tokenized_strings, attention_mask=attention_mask).logits, axis=1)
+    # example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
     print("veracity_model_device_1:{}".format(veracity_model.device))
 
     has_unanswerable = False
@@ -345,7 +349,8 @@ def justification_generation(claim, evidence, verdict_label):
     #
     claim_str = extract_claim_str(claim, evidence, verdict_label)
     claim_str.strip()
-    pred_justification = justification_model.generate(claim_str, device=device)
+    pred_justification = justification_model.generate(claim_str)
+    # pred_justification = justification_model.generate(claim_str, device=device)
     print("justification_model_device_1:{}".format(justification_model.device))
 
     return pred_justification.strip()
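
For context, the edit pattern above drops explicit device placement: loading a Lightning checkpoint without .to(device) typically leaves the module on the default device (CPU), so the input tensors no longer need matching .to(device) calls either. Below is a minimal sketch of that behaviour, assuming only standard torch, pytorch_lightning, and transformers APIs. DemoSequenceClassificationModule is a hypothetical stand-in for the repository's SequenceClassificationModule, and checkpoint loading is skipped; only the device handling is illustrated.

# Sketch only: a stand-in LightningModule, not the repository's actual class.
import torch
import pytorch_lightning as pl
from transformers import BertTokenizer, BertForSequenceClassification

class DemoSequenceClassificationModule(pl.LightningModule):
    def __init__(self, tokenizer, model):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model

    def forward(self, input_ids, attention_mask=None):
        # Delegate to the wrapped Hugging Face model.
        return self.model(input_ids, attention_mask=attention_mask)

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
bert = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=4, problem_type="single_label_classification")
module = DemoSequenceClassificationModule(tokenizer, bert)

# Without an explicit .to(device), the module stays on the default device (CPU);
# a module restored via load_from_checkpoint behaves the same way by default.
print(module.device)  # cpu

# Inputs can then be used as-is: model and tensors already share the CPU.
batch = tokenizer(["example claim [SEP] example evidence"],
                  return_tensors="pt", padding=True, truncation=True)
with torch.no_grad():
    logits = module(batch["input_ids"], attention_mask=batch["attention_mask"]).logits
pred_label = torch.argmax(logits, axis=1)

Restoring GPU execution would only require re-enabling the commented-out device lines in the diff, keeping the module and its input tensors on the same device.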