Commit cc6c0eb
zhenyundeng committed
1 parent: 8e3188e

update

app.py CHANGED
@@ -75,19 +75,19 @@ LABEL = [
 ]

 # Veracity
-
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
 veracity_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
 bert_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4, problem_type="single_label_classification")
 veracity_checkpoint_path = os.getcwd() + "/averitec/pretrained_models/bert_veracity.ckpt"
-veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to('cuda')
-
+# veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to('cuda')
+veracity_model = SequenceClassificationModule.load_from_checkpoint(veracity_checkpoint_path,tokenizer=veracity_tokenizer, model=bert_model).to(device)

 # Justification
 justification_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large', add_prefix_space=True)
 bart_model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")
 best_checkpoint = os.getcwd() + '/averitec/pretrained_models/bart_justifications_verdict-epoch=13-val_loss=2.03-val_meteor=0.28.ckpt'
-justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to('cuda')
-
+# justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to('cuda')
+justification_model = JustificationGenerationModule.load_from_checkpoint(best_checkpoint, tokenizer=justification_tokenizer, model=bart_model).to(device)

 # ---------------------------------------------------------------------------

 # ----------------------------------------------------------------------------

@@ -281,8 +281,8 @@ def veracity_prediction(claim, evidence):
 return pred_label

 tokenized_strings, attention_mask = dataLoader.tokenize_strings(evidence_strings)
-example_support = torch.argmax(veracity_model(tokenized_strings.to('cuda'), attention_mask=attention_mask.to('cuda')).logits, axis=1)
-
+# example_support = torch.argmax(veracity_model(tokenized_strings.to('cuda'), attention_mask=attention_mask.to('cuda')).logits, axis=1)
+example_support = torch.argmax(veracity_model(tokenized_strings.to(device), attention_mask=attention_mask.to(device)).logits, axis=1)

 has_unanswerable = False
 has_true = False

@@ -344,8 +344,8 @@ def justification_generation(claim, evidence, verdict_label):
 #
 claim_str = extract_claim_str(claim, evidence, verdict_label)
 claim_str.strip()
-pred_justification = justification_model.generate(claim_str, device='cuda')
-
+# pred_justification = justification_model.generate(claim_str, device='cuda')
+pred_justification = justification_model.generate(claim_str, device=device)

 return pred_justification.strip()
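For context, the change replaces every hardcoded 'cuda' with a device variable chosen at runtime, presumably so the Space can still start when no GPU is assigned. Below is a minimal standalone sketch of that device-fallback pattern, reusing the bert-base-uncased classifier named in the diff; the input text, variable names, and inference step are illustrative assumptions, not the Space's actual pipeline.

import torch
from transformers import BertTokenizer, BertForSequenceClassification

# Same fallback line the commit introduces: prefer the first GPU, else CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=4, problem_type="single_label_classification"
).to(device)  # move the model to whichever device was selected
model.eval()

# Hypothetical claim/evidence pair, only to exercise the forward pass.
inputs = tokenizer("Example claim.", "Example evidence sentence.", return_tensors="pt")

with torch.no_grad():
    # Input tensors must sit on the same device as the model, hence .to(device) here too.
    logits = model(**{k: v.to(device) for k, v in inputs.items()}).logits

pred_label_id = torch.argmax(logits, dim=1).item()

The same idea carries through the other two hunks: the tensors handed to veracity_model and the device argument passed to justification_model.generate now follow the device variable instead of assuming CUDA is available.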