# averitec/models/SequenceClassificationModule.py
import pytorch_lightning as pl
import torch
import torchmetrics
from torchmetrics.classification import F1Score
from transformers.optimization import AdamW


class SequenceClassificationModule(pl.LightningModule):
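    """Claim-verdict classifier wrapped as a PyTorch Lightning module.

    With ``use_question_stance_approach`` enabled, per-question stance
    predictions are aggregated into a single claim-level verdict at the
    end of each validation/test epoch (see ``handle_end_of_epoch_scoring``);
    otherwise accuracy and F1 are tracked per batch.
    """
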
# Instantiate the model
def __init__(self, tokenizer, model, use_question_stance_approach=True, learning_rate=1e-3):
super().__init__()
self.tokenizer = tokenizer
self.model = model
self.learning_rate = learning_rate
self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=model.num_labels)
self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=model.num_labels)
self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=model.num_labels)
self.train_f1 = F1Score(task="multiclass", num_classes=model.num_labels, average="macro")
self.val_f1 = F1Score(task="multiclass", num_classes=model.num_labels, average=None)
self.test_f1 = F1Score(task="multiclass", num_classes=model.num_labels, average=None)
self.use_question_stance_approach = use_question_stance_approach

    # Do a forward pass through the model
    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

    def training_step(self, batch, batch_idx):
        x, x_mask, y = batch
        outputs = self(x, attention_mask=x_mask, labels=y)
        logits = outputs.logits
        # The model computes the cross-entropy loss itself, since
        # labels are passed to the forward call.
        loss = outputs.loss
        preds = torch.argmax(logits, dim=1)
        # Track per-step training accuracy alongside the loss.
        self.train_acc(preds, y)
        self.log("train_acc_step", self.train_acc)
        self.log("train_loss", loss)
        return {"loss": loss}

    def validation_step(self, batch, batch_idx):
        x, x_mask, y = batch
        outputs = self(x, attention_mask=x_mask, labels=y)
        logits = outputs.logits
        loss = outputs.loss
        preds = torch.argmax(logits, dim=1)
        # Under the question-stance approach, per-example metrics are
        # skipped here; scoring happens at epoch end, once the
        # question-level predictions are aggregated per claim.
        if not self.use_question_stance_approach:
            self.val_acc(preds, y)
            self.log("val_acc_step", self.val_acc)
            self.val_f1(preds, y)
        self.log("val_loss", loss)
        return {"val_loss": loss, "src": x, "pred": preds, "target": y}

    def validation_epoch_end(self, outs):
        if self.use_question_stance_approach:
            # Collapse the per-question predictions into one verdict per
            # claim, updating val_acc/val_f1 in the process.
            self.handle_end_of_epoch_scoring(outs, self.val_acc, self.val_f1)
        self.log("val_acc_epoch", self.val_acc)
        f1 = self.val_f1.compute()
        self.val_f1.reset()
        self.log("val_f1_epoch", torch.mean(f1))
        class_names = ["supported", "refuted", "nei", "conflicting"]
        for i, c_name in enumerate(class_names):
            self.log("val_f1_" + c_name, f1[i])

    def test_step(self, batch, batch_idx):
        x, x_mask, y = batch
        outputs = self(x, attention_mask=x_mask)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=1)
        if not self.use_question_stance_approach:
            self.test_acc(preds, y)
            self.log("test_acc_step", self.test_acc)
            self.test_f1(preds, y)
        return {"src": x, "pred": preds, "target": y}

    def test_epoch_end(self, outs):
        if self.use_question_stance_approach:
            self.handle_end_of_epoch_scoring(outs, self.test_acc, self.test_f1)
        self.log("test_acc_epoch", self.test_acc)
        f1 = self.test_f1.compute()
        self.test_f1.reset()
        self.log("test_f1_epoch", torch.mean(f1))
        class_names = ["supported", "refuted", "nei", "conflicting"]
        for i, c_name in enumerate(class_names):
            self.log("test_f1_" + c_name, f1[i])

def handle_end_of_epoch_scoring(self, outputs, acc_scorer, f1_scorer):
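        """Aggregate question-level stance predictions into claim verdicts.

        Predictions are grouped by claim (the text preceding the first
        "[ question ]" marker in the decoded source). Each claim's verdict
        is then: NEI (2) if any question was unanswerable, supported (0)
        if all stances are true, refuted (1) if all are false, and
        conflicting (3) if true and false stances are mixed.
        """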
gold_labels = {}
question_support = {}
for out in outputs:
srcs = out['src']
preds = out['pred']
tgts = out['target']
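            # Decode input ids back to text so each example's claim can
            # be recovered from the prompt below.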
tokens = self.tokenizer.batch_decode(
srcs,
skip_special_tokens=True,
clean_up_tokenization_spaces=True
)
            for src, pred, tgt in zip(tokens, preds, tgts):
                # Everything before the first "[ question ]" marker
                # identifies the claim this question belongs to.
                claim_id = src.split("[ question ]")[0]
if claim_id not in gold_labels:
gold_labels[claim_id] = tgt
question_support[claim_id] = []
question_support[claim_id].append(pred)
        for k, gold_label in gold_labels.items():
            support = question_support[k]
            has_unansw = False
            has_true = False
            has_false = False
            for v in support:
                if v == 0:
                    has_true = True
                if v == 1:
                    has_false = True
                # TODO: ugly hack -- train and test cannot use different
                # label sets, so both labels 2 and 3 are treated as
                # unanswerable at the question level.
                if v == 2 or v == 3:
                    has_unansw = True
if has_unansw:
answer = 2
elif has_true and not has_false:
answer = 0
elif has_false and not has_true:
answer = 1
elif has_true and has_false:
answer = 3
            # Use the module's own device instead of a hard-coded
            # "cuda:0", so scoring also works on CPU or other GPUs.
            answer_t = torch.as_tensor([answer]).to(self.device)
            gold_t = torch.as_tensor([gold_label]).to(self.device)
            acc_scorer(answer_t, gold_t)
            f1_scorer(answer_t, gold_t)
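

# A minimal usage sketch, not part of the training pipeline: the
# checkpoint name, label count, and (commented-out) dataloaders below
# are illustrative assumptions, not values fixed by this repository.
if __name__ == "__main__":
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=4
    )
    module = SequenceClassificationModule(tokenizer, model, learning_rate=1e-5)

    # trainer = pl.Trainer(max_epochs=3, accelerator="auto")
    # trainer.fit(module, train_dataloader, val_dataloader)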