zhenyundeng commited on
Commit
cc6c0eb
1 Parent(s): 8e3188e
Files changed (1) hide show
  1. app.py +9 -9
app.py CHANGED
@@ -75,19 +75,19 @@ LABEL = [
75
  ]
76
 
77
  # Veracity
78
- # device = "cuda:0" if torch.cuda.is_available() else "cpu"
79
  veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
80
  bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
81
  veracity_checkpoint_path = os.getcwd() + "/averitec/pretrained_models/bert_veracity.ckpt"
82
- veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to('cuda')
83
- # veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to(device)
84
 
85
  # Justification
86
  justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
87
  bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
88
  best_checkpoint = os.getcwd() + '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
89
- justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to('cuda')
90
- # justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
91
  # ---------------------------------------------------------------------------
92
 
93
  # ----------------------------------------------------------------------------
@@ -281,8 +281,8 @@ def veracity_prediction(claim, evidence):
281
  return pred_label
282
 
283
  tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
284
- example_support = torch.argmax(veracity_model(tokenized_strings.to('cuda'), attention_mask=attention_mask.to('cuda')).logits, axis=1)
285
- # example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
286
 
287
  has_unanswerable = False
288
  has_true = False
@@ -344,8 +344,8 @@ def justification_generation(claim, evidence, verdict_label):
344
  #
345
  claim_str = extract_claim_str(claim, evidence, verdict_label)
346
  claim_str.strip()
347
- pred_justification = justification_model.generate(claim_str, device='cuda')
348
- # pred_justification = justification_model.generate(claim_str, device=device)
349
 
350
  return pred_justification.strip()
351
 
 
75
  ]
76
 
77
  # Veracity
78
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
79
  veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
80
  bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
81
  veracity_checkpoint_path = os.getcwd() + "/averitec/pretrained_models/bert_veracity.ckpt"
82
+ # veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to('cuda')
83
+ veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to(device)
84
 
85
  # Justification
86
  justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
87
  bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
88
  best_checkpoint = os.getcwd() + '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
89
+ # justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to('cuda')
90
+ justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
91
  # ---------------------------------------------------------------------------
92
 
93
  # ----------------------------------------------------------------------------
 
281
  return pred_label
282
 
283
  tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
284
+ # example_support = torch.argmax(veracity_model(tokenized_strings.to('cuda'), attention_mask=attention_mask.to('cuda')).logits, axis=1)
285
+ example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
286
 
287
  has_unanswerable = False
288
  has_true = False
 
344
  #
345
  claim_str = extract_claim_str(claim, evidence, verdict_label)
346
  claim_str.strip()
347
+ # pred_justification = justification_model.generate(claim_str, device='cuda')
348
+ pred_justification = justification_model.generate(claim_str, device=device)
349
 
350
  return pred_justification.strip()
351