zhenyundeng commited on
Commit
f58b8d2
·
1 Parent(s): c052247

update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -9
app.py CHANGED
@@ -85,13 +85,9 @@ veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_chec
85
  # Justification
86
  justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
87
  bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
88
- best_checkpoint = os.getcwd()+ '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
89
  justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to('cuda')
90
  # justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
91
-
92
-
93
- print("veracity_model_device_0:{}".format(veracity_model.device))
94
- print("justification_model_device_0:{}".format(justification_model.device))
95
  # ---------------------------------------------------------------------------
96
 
97
  # ----------------------------------------------------------------------------
@@ -285,9 +281,8 @@ def veracity_prediction(claim, evidence):
285
  return pred_label
286
 
287
  tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
288
- example_support = torch.argmax(veracity_model(tokenized_strings, attention_mask=attention_mask).logits, axis=1)
289
  # example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
290
- print("veracity_model_device_1:{}".format(veracity_model.device))
291
 
292
  has_unanswerable = False
293
  has_true = False
@@ -349,9 +344,8 @@ def justification_generation(claim, evidence, verdict_label):
349
  #
350
  claim_str = extract_claim_str(claim, evidence, verdict_label)
351
  claim_str.strip()
352
- pred_justification = justification_model.generate(claim_str)
353
  # pred_justification = justification_model.generate(claim_str, device=device)
354
- print("justification_model_device_1:{}".format(justification_model.device))
355
 
356
  return pred_justification.strip()
357
 
 
85
  # Justification
86
  justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
87
  bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
88
+ best_checkpoint = os.getcwd() + '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
89
  justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to('cuda')
90
  # justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)
 
 
 
 
91
  # ---------------------------------------------------------------------------
92
 
93
  # ----------------------------------------------------------------------------
 
281
  return pred_label
282
 
283
  tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
284
+ example_support = torch.argmax(veracity_model(tokenized_strings.to('cuda'), attention_mask=attention_mask.to('cuda')).logits, axis=1)
285
  # example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)
 
286
 
287
  has_unanswerable = False
288
  has_true = False
 
344
  #
345
  claim_str = extract_claim_str(claim, evidence, verdict_label)
346
  claim_str.strip()
347
+ pred_justification = justification_model.generate(claim_str, device='cuda')
348
  # pred_justification = justification_model.generate(claim_str, device=device)
 
349
 
350
  return pred_justification.strip()
351