# averitec/models/JustificationGenerationModule.py
import pytorch_lightning as pl
import torch
import numpy as np
import datasets
from transformers import MaxLengthCriteria, StoppingCriteriaList
from transformers.optimization import AdamW
import itertools
from averitec.models.utils import count_stats, f1_metric, pairwise_meteor
from torchmetrics.text.rouge import ROUGEScore
import torch.nn.functional as F
import torchmetrics
from torchmetrics.classification import F1Score


def freeze_params(model):
    """Disable gradient updates for all parameters of the given module."""
    for param in model.parameters():
        param.requires_grad = False


class JustificationGenerationModule(pl.LightningModule):
    """Seq2seq (BART-style) LightningModule that generates justifications and
    scores them with ROUGE-L precision and METEOR during validation and testing."""

    def __init__(self, tokenizer, model, learning_rate=1e-3, gen_num_beams=2, gen_max_length=100, should_pad_gen=True):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.learning_rate = learning_rate
        self.gen_num_beams = gen_num_beams
        self.gen_max_length = gen_max_length
        self.should_pad_gen = should_pad_gen
        # self.metrics = datasets.load_metric('meteor')

        freeze_params(self.model.get_encoder())
        self.freeze_embeds()

    def freeze_embeds(self):
        """Freeze the token and positional embedding parameters of the model; adapted from finetune.py."""
        freeze_params(self.model.model.shared)
        for d in [self.model.model.encoder, self.model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)

    # Do a forward pass through the model
    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
        """
        Shift input ids one token to the right.
        https://github.com/huggingface/transformers/blob/main/src/transformers/models/bart/modeling_bart.py.
        """
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
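
    # Worked example (added for clarity; token ids are illustrative): with
    # decoder_start_token_id = pad_token_id = 1, as used in run_model below,
    #   labels            = [[5, 6, 7, 2]]
    #   decoder_input_ids = [[1, 5, 6, 7]]
    # i.e. the target sequence shifted one position to the right.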

    def run_model(self, batch):
        src_ids, src_mask, tgt_ids = batch[0], batch[1], batch[2]
        decoder_input_ids = self.shift_tokens_right(
            tgt_ids, self.tokenizer.pad_token_id, self.tokenizer.pad_token_id  # BART uses the EOS token to start generation as well. Might have to change for other models.
        )
        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        return outputs

    def compute_loss(self, batch):
        tgt_ids = batch[2]
        logits = self.run_model(batch)[0]
        cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
        loss = cross_entropy(logits.view(-1, logits.shape[-1]), tgt_ids.view(-1))
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)
        self.log("train_loss", loss, on_epoch=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        preds, loss, tgts = self.generate_and_compute_loss_and_tgts(batch)
        if self.should_pad_gen:
            preds = F.pad(preds, pad=(0, self.gen_max_length - preds.shape[1]), value=self.tokenizer.pad_token_id)
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        return {'loss': loss, 'pred': preds, 'target': tgts}

    def test_step(self, batch, batch_idx):
        test_preds, test_loss, test_tgts = self.generate_and_compute_loss_and_tgts(batch)
        if self.should_pad_gen:
            test_preds = F.pad(test_preds, pad=(0, self.gen_max_length - test_preds.shape[1]), value=self.tokenizer.pad_token_id)
        self.log('test_loss', test_loss, prog_bar=True, sync_dist=True)
        return {'loss': test_loss, 'pred': test_preds, 'target': test_tgts}

    def test_epoch_end(self, outputs):
        self.handle_end_of_epoch_scoring(outputs, "test")

    def validation_epoch_end(self, outputs):
        self.handle_end_of_epoch_scoring(outputs, "val")

    def handle_end_of_epoch_scoring(self, outputs, prefix):
        gen = {}
        tgt = {}

        rouge = ROUGEScore()
        rouge_metric = lambda x, y: rouge(x, y)["rougeL_precision"]

        for out in outputs:
            preds = out['pred']
            tgts = out['target']

            preds = self.do_batch_detokenize(preds)
            tgts = self.do_batch_detokenize(tgts)

            for pred, t in zip(preds, tgts):
                rouge_d = rouge_metric(pred, t)
                self.log(prefix + "_rouge", rouge_d)

                meteor_d = pairwise_meteor(pred, t)
                self.log(prefix + "_meteor", meteor_d)

    def generate_and_compute_loss_and_tgts(self, batch):
        src_ids = batch[0]

        loss = self.compute_loss(batch)
        pred_ids, _ = self.generate_for_batch(src_ids)

        tgt_ids = batch[2]
        return pred_ids, loss, tgt_ids

    def do_batch_detokenize(self, batch):
        tokens = self.tokenizer.batch_decode(
            batch,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

        # Huggingface skipping of special tokens doesn't work for all models, so we do it manually as well for safety:
        tokens = [p.replace("<pad>", "") for p in tokens]
        tokens = [p.replace("<s>", "") for p in tokens]
        tokens = [p.replace("</s>", "") for p in tokens]

        return [t for t in tokens if t != ""]

    def generate_for_batch(self, batch):
        generated_ids = self.model.generate(
            batch,
            decoder_start_token_id=self.tokenizer.pad_token_id,
            num_beams=self.gen_num_beams,
            max_length=self.gen_max_length
        )

        generated_tokens = self.tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

        return generated_ids, generated_tokens

    def generate(self, text, max_input_length=512, device=None):
        encoded_dict = self.tokenizer(
            [text],
            max_length=max_input_length,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            add_prefix_space=True
        )

        input_ids = encoded_dict['input_ids']
        if device is not None:
            input_ids = input_ids.to(device)

        with torch.no_grad():
            _, generated_tokens = self.generate_for_batch(input_ids)

        return generated_tokens[0]
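

# ---------------------------------------------------------------------------
# Minimal usage sketch (added for illustration; not part of the original file).
# The checkpoint name and input text below are assumptions: the module only
# requires a seq2seq tokenizer/model pair with a BART-like layout
# (model.model.shared, model.model.encoder/decoder, get_encoder()), e.g. a
# BART checkpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import BartForConditionalGeneration, BartTokenizer

    # "facebook/bart-base" is an illustrative choice, not necessarily the
    # checkpoint used to train this module.
    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    module = JustificationGenerationModule(tokenizer, bart)
    module.eval()

    # Hypothetical claim/evidence input; the exact prompt format expected by a
    # trained checkpoint may differ.
    text = "Claim: The Eiffel Tower is in Berlin. Evidence: The Eiffel Tower is located in Paris, France."
    print(module.generate(text, device="cpu"))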