import itertools

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torchmetrics
from torchmetrics.classification import F1Score
from torchmetrics.text.rouge import ROUGEScore

import datasets
from transformers import MaxLengthCriteria, StoppingCriteriaList
from transformers.optimization import AdamW

from utils import count_stats, f1_metric, pairwise_meteor
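
# LightningModule for 4-way sequence classification
# (supported / refuted / nei / conflicting) on top of a HuggingFace
# sequence-classification model.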
class NaiveSeqClassModule(pl.LightningModule):
    # Store the tokenizer and classification model, and set up
    # per-split metrics.
    def __init__(self, tokenizer, model, use_question_stance_approach=True, learning_rate=1e-3):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.learning_rate = learning_rate

        # Accuracy is tracked per split; F1 is macro-averaged during
        # training and reported per class for validation and testing.
        self.train_acc = torchmetrics.Accuracy()
        self.val_acc = torchmetrics.Accuracy()
        self.test_acc = torchmetrics.Accuracy()
        self.train_f1 = F1Score(num_classes=4, average="macro")
        self.val_f1 = F1Score(num_classes=4, average=None)
        self.test_f1 = F1Score(num_classes=4, average=None)

        # When enabled, val/test metrics are accumulated once per epoch
        # from the raw step outputs instead of per batch.
        self.use_question_stance_approach = use_question_stance_approach
    # Do a forward pass through the wrapped model
    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)
    def configure_optimizers(self):
        # AdamW with a fixed learning rate; no scheduler.
        return AdamW(self.parameters(), lr=self.learning_rate)
    def training_step(self, batch, batch_idx):
        x, x_mask, y = batch
        # Passing labels lets the model compute its own cross-entropy loss.
        outputs = self(x, attention_mask=x_mask, labels=y)
        logits = outputs.logits
        loss = outputs.loss
        preds = torch.argmax(logits, dim=1)
        self.train_acc(preds.cpu(), y.cpu())
        self.train_f1(preds.cpu(), y.cpu())
        self.log("train_loss", loss)
        return {"loss": loss}
    def training_epoch_end(self, outs):
        self.log("train_acc_epoch", self.train_acc)
        self.log("train_f1_epoch", self.train_f1)
    def validation_step(self, batch, batch_idx):
        x, x_mask, y = batch
        outputs = self(x, attention_mask=x_mask, labels=y)
        logits = outputs.logits
        loss = outputs.loss
        preds = torch.argmax(logits, dim=1)
        # In the standard setup, metrics are updated per batch; otherwise
        # the raw outputs are aggregated in validation_epoch_end.
        if not self.use_question_stance_approach:
            self.val_acc(preds, y)
            self.log("val_acc_step", self.val_acc)
            self.val_f1(preds, y)
        self.log("val_loss", loss)
        return {"val_loss": loss, "src": x, "pred": preds, "target": y}
    def validation_epoch_end(self, outs):
        if self.use_question_stance_approach:
            self.handle_end_of_epoch_scoring(outs, self.val_acc, self.val_f1)
        self.log("val_acc_epoch", self.val_acc)
        f1 = self.val_f1.compute()
        self.val_f1.reset()
        # Log macro F1 plus one score per class.
        self.log("val_f1_epoch", torch.mean(f1))
        class_names = ["supported", "refuted", "nei", "conflicting"]
        for i, c_name in enumerate(class_names):
            self.log("val_f1_" + c_name, f1[i])
    def test_step(self, batch, batch_idx):
        x, x_mask, y = batch
        # No labels are passed here, so the model returns logits only.
        outputs = self(x, attention_mask=x_mask)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=1)
        if not self.use_question_stance_approach:
            self.test_acc(preds, y)
            self.log("test_acc_step", self.test_acc)
            self.test_f1(preds, y)
        return {"src": x, "pred": preds, "target": y}
    def test_epoch_end(self, outs):
        if self.use_question_stance_approach:
            self.handle_end_of_epoch_scoring(outs, self.test_acc, self.test_f1)
        self.log("test_acc_epoch", self.test_acc)
        f1 = self.test_f1.compute()
        self.test_f1.reset()
        self.log("test_f1_epoch", torch.mean(f1))
        class_names = ["supported", "refuted", "nei", "conflicting"]
        for i, c_name in enumerate(class_names):
            self.log("test_f1_" + c_name, f1[i])
    def handle_end_of_epoch_scoring(self, outputs, acc_scorer, f1_scorer):
        gold_labels = {}
        question_support = {}
        for out in outputs:
            srcs = out["src"]
            preds = out["pred"]
            tgts = out["target"]
            # Decode the input ids back to text so predictions can be
            # matched against their source sequences.
            tokens = self.tokenizer.batch_decode(
                srcs,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
            for src, pred, tgt in zip(tokens, preds, tgts):
                # Use self.device rather than a hard-coded "cuda:0" so the
                # module also runs on CPU or other accelerators.
                pred_t = torch.as_tensor([pred]).to(self.device)
                tgt_t = torch.as_tensor([tgt]).to(self.device)
                acc_scorer(pred_t, tgt_t)
                f1_scorer(pred_t, tgt_t)
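
if __name__ == "__main__":
    # Minimal smoke test, assuming a HuggingFace checkpoint with
    # num_labels=4; "bert-base-uncased" and the tiny random DataLoader
    # below are illustrative placeholders, not the original pipeline.
    from torch.utils.data import DataLoader, TensorDataset
    from transformers import AutoTokenizer, AutoModelForSequenceClassification

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModelForSequenceClassification.from_pretrained(
        "bert-base-uncased", num_labels=4
    )

    # Batches unpack as (input_ids, attention_mask, label), matching the
    # (x, x_mask, y) destructuring in the *_step methods above.
    input_ids = torch.randint(0, tokenizer.vocab_size, (8, 32))
    attention_mask = torch.ones_like(input_ids)
    labels = torch.randint(0, 4, (8,))
    loader = DataLoader(TensorDataset(input_ids, attention_mask, labels), batch_size=4)

    module = NaiveSeqClassModule(
        tokenizer, model, use_question_stance_approach=False, learning_rate=1e-5
    )
    trainer = pl.Trainer(max_epochs=1, accelerator="cpu", devices=1)
    trainer.fit(module, train_dataloaders=loader, val_dataloaders=loader)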