import pytorch_lightning as pl
import torch
import numpy as np
import datasets
from transformers import MaxLengthCriteria, StoppingCriteriaList
from transformers.optimization import AdamW
import itertools
from averitec.models.utils import count_stats, f1_metric, pairwise_meteor
from torchmetrics.text.rouge import ROUGEScore
import torch.nn.functional as F
import torchmetrics
from torchmetrics.classification import F1Score


def freeze_params(model):
    # Disable gradient updates for all parameters of the given (sub)module.
    for layer in model.parameters():
        layer.requires_grad = False


class JustificationGenerationModule(pl.LightningModule):
    def __init__(self, tokenizer, model, learning_rate=1e-3, gen_num_beams=2, gen_max_length=100, should_pad_gen=True):
        super().__init__()
        self.tokenizer = tokenizer
        self.model = model
        self.learning_rate = learning_rate
        self.gen_num_beams = gen_num_beams
        self.gen_max_length = gen_max_length
        self.should_pad_gen = should_pad_gen
        # self.metrics = datasets.load_metric('meteor')

        # Only the decoder layers are fine-tuned: the encoder and all embeddings are frozen.
        freeze_params(self.model.get_encoder())
        self.freeze_embeds()

    def freeze_embeds(self):
        '''Freeze the shared, token, and positional embedding parameters of the model; adapted from finetune.py.'''
        freeze_params(self.model.model.shared)
        for d in [self.model.model.encoder, self.model.model.decoder]:
            freeze_params(d.embed_positions)
            freeze_params(d.embed_tokens)

    # Do a forward pass through the model
    def forward(self, input_ids, **kwargs):
        return self.model(input_ids, **kwargs)

    def configure_optimizers(self):
        optimizer = AdamW(self.parameters(), lr=self.learning_rate)
        return optimizer

    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
        """
        Shift input ids one token to the right.
        Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/bart/modeling_bart.py.
        """
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id

        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")
        # Replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids
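
    # Worked example (illustrative only, not part of the original file): with pad_token_id=1 and
    # decoder_start_token_id=1, labels [[0, 713, 16, 2, -100, -100]] become decoder inputs
    # [[1, 0, 713, 16, 2, 1]] -- the sequence is shifted right, the start token is prepended,
    # and any remaining -100 label padding is mapped back to pad_token_id.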

    def run_model(self, batch):
        src_ids, src_mask, tgt_ids = batch[0], batch[1], batch[2]
        # BART uses the EOS token to start generation as well. Might have to change for other models.
        decoder_input_ids = self.shift_tokens_right(
            tgt_ids, self.tokenizer.pad_token_id, self.tokenizer.pad_token_id
        )
        outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False)
        return outputs

    def compute_loss(self, batch):
        tgt_ids = batch[2]
        logits = self.run_model(batch)[0]
        # Flatten to (batch * seq_len, vocab_size) vs (batch * seq_len,) and ignore padding positions.
        cross_entropy = torch.nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_token_id)
        loss = cross_entropy(logits.view(-1, logits.shape[-1]), tgt_ids.view(-1))
        return loss

    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)
        self.log("train_loss", loss, on_epoch=True)
        return {'loss': loss}

    def validation_step(self, batch, batch_idx):
        preds, loss, tgts = self.generate_and_compute_loss_and_tgts(batch)

        if self.should_pad_gen:
            preds = F.pad(preds, pad=(0, self.gen_max_length - preds.shape[1]), value=self.tokenizer.pad_token_id)

        self.log('val_loss', loss, prog_bar=True, sync_dist=True)
        return {'loss': loss, 'pred': preds, 'target': tgts}

    def test_step(self, batch, batch_idx):
        test_preds, test_loss, test_tgts = self.generate_and_compute_loss_and_tgts(batch)

        if self.should_pad_gen:
            test_preds = F.pad(test_preds, pad=(0, self.gen_max_length - test_preds.shape[1]), value=self.tokenizer.pad_token_id)

        self.log('test_loss', test_loss, prog_bar=True, sync_dist=True)
        return {'loss': test_loss, 'pred': test_preds, 'target': test_tgts}

    def test_epoch_end(self, outputs):
        self.handle_end_of_epoch_scoring(outputs, "test")

    def validation_epoch_end(self, outputs):
        self.handle_end_of_epoch_scoring(outputs, "val")

    def handle_end_of_epoch_scoring(self, outputs, prefix):
        rouge = ROUGEScore()
        rouge_metric = lambda x, y: rouge(x, y)["rougeL_precision"]

        for out in outputs:
            preds = out['pred']
            tgts = out['target']

            preds = self.do_batch_detokenize(preds)
            tgts = self.do_batch_detokenize(tgts)

            for pred, t in zip(preds, tgts):
                rouge_d = rouge_metric(pred, t)
                self.log(prefix + "_rouge", rouge_d)

                meteor_d = pairwise_meteor(pred, t)
                self.log(prefix + "_meteor", meteor_d)

    def generate_and_compute_loss_and_tgts(self, batch):
        src_ids = batch[0]
        loss = self.compute_loss(batch)
        pred_ids, _ = self.generate_for_batch(src_ids)
        tgt_ids = batch[2]
        return pred_ids, loss, tgt_ids

    def do_batch_detokenize(self, batch):
        tokens = self.tokenizer.batch_decode(
            batch,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

        # Huggingface skipping of special tokens doesn't work for all models, so we do it manually as well for safety:
        tokens = [p.replace("<pad>", "") for p in tokens]
        tokens = [p.replace("<s>", "") for p in tokens]
        tokens = [p.replace("</s>", "") for p in tokens]

        return [t for t in tokens if t != ""]

    def generate_for_batch(self, batch):
        generated_ids = self.model.generate(
            batch,
            decoder_start_token_id=self.tokenizer.pad_token_id,
            num_beams=self.gen_num_beams,
            max_length=self.gen_max_length
        )

        generated_tokens = self.tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        )

        return generated_ids, generated_tokens

    def generate(self, text, max_input_length=512, device=None):
        encoded_dict = self.tokenizer(
            [text],
            max_length=max_input_length,
            padding="longest",
            truncation=True,
            return_tensors="pt",
            add_prefix_space=True
        )

        input_ids = encoded_dict['input_ids']
        if device is not None:
            input_ids = input_ids.to(device)

        with torch.no_grad():
            _, generated_tokens = self.generate_for_batch(input_ids)

        return generated_tokens[0]
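

# A minimal usage sketch, not part of the original training pipeline. It assumes a BART-style
# checkpoint (here "facebook/bart-large", since the module accesses model.model.shared and
# model.get_encoder()); the encoder and embeddings are frozen on construction, and generations
# are only meaningful once the decoder has been fine-tuned.
if __name__ == "__main__":
    from transformers import BartForConditionalGeneration, BartTokenizer

    tokenizer = BartTokenizer.from_pretrained("facebook/bart-large")
    model = BartForConditionalGeneration.from_pretrained("facebook/bart-large")

    module = JustificationGenerationModule(tokenizer, model)
    # Input format is a placeholder; the real pipeline concatenates the claim with its evidence.
    print(module.generate("Claim: ... Evidence: ..."))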