import pytorch_lightning as pl
import torch
import numpy as np
import datasets
from transformers import MaxLengthCriteria, StoppingCriteriaList
from transformers.optimization import AdamW
import itertools
from averitec.models.utils import count_stats, f1_metric, pairwise_meteor
from torchmetrics.text.rouge import ROUGEScore
import torch.nn.functional as F
import torchmetrics
from torchmetrics.classification import F1Score


def freeze_params(model):
    for param in model.parameters():
        param.requires_grad = False


class JustificationGenerationModule(pl.LightningModule):

    def __init__(self, tokenizer, model, learning_rate=1e-3, gen_num_beams=2,
                 gen_max_length=100, should_pad_gen=True):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.learning_rate = learning_rate
        self.gen_num_beams = gen_num_beams
        self.gen_max_length = gen_max_length
        self.should_pad_gen = should_pad_gen
        # self.metrics = datasets.load_metric('meteor')

        freeze_params(self.model.get_encoder())
        self.freeze_embeds()

    def freeze_embeds(self):
        '''Freeze the positional and token embedding parameters of the model; adapted from finetune.py.'''
        freeze_params(self.model.model.shared)
        for d in [self.model.model.encoder, self.model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)

    # Do a forward pass through the model
    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
        """
        Shift input ids one token to the right.
        Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/bart/modeling_bart.py.
        """
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # Replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
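    # Worked illustration of shift_tokens_right (hypothetical ids, added for
    # clarity): with pad_token_id=1 and decoder_start_token_id=2,
    #   labels  = [[5, 7, -100, -100]]
    #   shifted = [[2, 5, 7, -100]] -> [[2, 5, 7, 1]] after masked_fill_,
    # i.e. the decoder is fed the start token followed by the gold prefix,
    # and label padding (-100) never reaches the embedding layer.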
    def run_model(self, batch):
        src_ids, src_mask, tgt_ids = batch[0], batch[1], batch[2]

        decoder_input_ids = self.shift_tokens_right(
            tgt_ids,
            self.tokenizer.pad_token_id,
            self.tokenizer.pad_token_id  # BART uses the EOS token to start generation as well. Might have to change for other models.
        )

        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        return outputs

    def compute_loss(self, batch):
        tgt_ids = batch[2]
        logits = self.run_model(batch)[0]
        cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
        loss = cross_entropy(logits.view(-1, logits.shape[-1]), tgt_ids.view(-1))
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)
        self.log("train_loss", loss, on_epoch=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        preds, loss, tgts = self.generate_and_compute_loss_and_tgts(batch)
        if self.should_pad_gen:
            preds = F.pad(preds, pad=(0, self.gen_max_length - preds.shape[1]), value=self.tokenizer.pad_token_id)
        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        return {'loss': loss, 'pred': preds, 'target': tgts}

    def test_step(self, batch, batch_idx):
        test_preds, test_loss, test_tgts = self.generate_and_compute_loss_and_tgts(batch)
        if self.should_pad_gen:
            test_preds = F.pad(test_preds, pad=(0, self.gen_max_length - test_preds.shape[1]), value=self.tokenizer.pad_token_id)
        self.log('test_loss', test_loss, prog_bar=True, sync_dist=True)
        return {'loss': test_loss, 'pred': test_preds, 'target': test_tgts}

    def test_epoch_end(self, outputs):
        self.handle_end_of_epoch_scoring(outputs, "test")

    def validation_epoch_end(self, outputs):
        self.handle_end_of_epoch_scoring(outputs, "val")

    def handle_end_of_epoch_scoring(self, outputs, prefix):
        rouge = ROUGEScore()
        rouge_metric = lambda x, y: rouge(x, y)["rougeL_precision"]

        for out in outputs:
            preds = out['pred']
            tgts = out['target']

            preds = self.do_batch_detokenize(preds)
            tgts = self.do_batch_detokenize(tgts)

            for pred, tgt in zip(preds, tgts):
                rouge_d = rouge_metric(pred, tgt)
                self.log(prefix + "_rouge", rouge_d)

                meteor_d = pairwise_meteor(pred, tgt)
                self.log(prefix + "_meteor", meteor_d)

    def generate_and_compute_loss_and_tgts(self, batch):
        src_ids = batch[0]
        loss = self.compute_loss(batch)
        pred_ids, _ = self.generate_for_batch(src_ids)
        tgt_ids = batch[2]
        return pred_ids, loss, tgt_ids

    def do_batch_detokenize(self, batch):
        tokens = self.tokenizer.batch_decode(
            batch,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

        # Huggingface skipping of special tokens doesn't work for all models,
        # so we also strip BART's special tokens manually for safety:
        tokens = [p.replace("<s>", "") for p in tokens]
        tokens = [p.replace("</s>", "") for p in tokens]
        tokens = [p.replace("<pad>", "") for p in tokens]

        return [t for t in tokens if t != ""]

    def generate_for_batch(self, batch):
        generated_ids = self.model.generate(
            batch,
            decoder_start_token_id=self.tokenizer.pad_token_id,
            num_beams=self.gen_num_beams,
            max_length=self.gen_max_length
        )

        generated_tokens = self.tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

        return generated_ids, generated_tokens

    def generate(self, text, max_input_length=512, device=None):
        encoded_dict = self.tokenizer(
            [text],
            max_length=max_input_length,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            add_prefix_space=True
        )

        input_ids = encoded_dict['input_ids']
        if device is not None:
            input_ids = input_ids.to(device)

        with torch.no_grad():
            _, generated_tokens = self.generate_for_batch(input_ids)

        return generated_tokens[0]
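

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module; the checkpoint name,
# example text, and single-example inference call are illustrative
# assumptions). It shows how the module would be instantiated with a BART
# tokenizer/model pair, whose encoder and embeddings get frozen in __init__;
# training would go through pl.Trainer.fit() with a DataLoader yielding
# (src_ids, src_mask, tgt_ids) batches.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import BartTokenizer, BartForConditionalGeneration

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
    bart = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

    module = JustificationGenerationModule(tokenizer, bart)

    # Hypothetical claim/evidence input; the real input format is defined by
    # the dataset code elsewhere in the repository.
    text = "Claim: ... Evidence: ..."
    print(module.generate(text, device="cpu"))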