"""Fine-tune BERT for extractive question answering on SQuAD v2 with PyTorch Lightning."""

import os
from argparse import ArgumentParser

import lightning as L
import torch
import torch.nn.functional as F
from torch import nn, optim
from transformers import BertModel

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import EarlyStopping, ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger

from data import Squad_v2
|
|
class BertSquad_v2(L.LightningModule):
    """BERT encoder with two linear heads that score each token as the
    answer-span start or end (extractive QA on SQuAD v2)."""

    def __init__(self, model_name_or_path="google-bert/bert-base-uncased"):
        super().__init__()
        # Stores model_name_or_path in self.hparams and in checkpoints.
        self.save_hyperparameters()

        self.bert_model = BertModel.from_pretrained(model_name_or_path)

        # One scalar logit per token for the span start, and one for the end.
        hidden_size = self.bert_model.config.hidden_size
        self.qa_start = nn.Linear(hidden_size, 1)
        self.qa_end = nn.Linear(hidden_size, 1)
|
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-5)
        # Lightning passes the metric named under "monitor" to the scheduler,
        # which cuts the LR by 10x when that metric plateaus.
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=2
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "train_loss"},
        }
|
    def forward(self, input_ids, attention_mask):
        outputs = self.bert_model(input_ids=input_ids, attention_mask=attention_mask)

        # Final-layer representation of every token: (batch, seq_len, hidden).
        hidden_states = outputs.last_hidden_state

        # Project each token to a start/end logit, then push padding positions
        # to a large negative value so they cannot win the softmax. (Multiplying
        # by the mask would leave padding at logit 0, which still competes.)
        padding = attention_mask == 0
        qa_start_score = self.qa_start(hidden_states).squeeze(-1).masked_fill(padding, -1e4)
        qa_end_score = self.qa_end(hidden_states).squeeze(-1).masked_fill(padding, -1e4)

        return qa_start_score, qa_end_score
|
    def _step_loss(self, batch):
        # Shared by training and validation: cross-entropy over token positions
        # for the start and end heads, summed.
        qa_start_score, qa_end_score = self(batch["input_ids"], batch["attention_mask"])
        loss_start = F.cross_entropy(qa_start_score, batch["start_positions"])
        loss_end = F.cross_entropy(qa_end_score, batch["end_positions"])
        return loss_start + loss_end

    def training_step(self, batch, batch_idx):
        loss = self._step_loss(batch)
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # Lightning already runs validation in eval mode with gradients
        # disabled, so no explicit .eval() or torch.no_grad() is needed here.
        loss = self._step_loss(batch)
        self.log("val_loss", loss, prog_bar=True)
        return loss
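
# A minimal sketch of inference-time span decoding, included to illustrate how
# the two heads are consumed; it is not called during training. Real SQuAD v2
# decoding would also score the no-answer option and search top-k start/end
# pairs instead of taking a single greedy argmax.
def decode_span(qa_start_score, qa_end_score):
    """Greedy (start, end) index per batch element from the two logit tensors."""
    start_idx = qa_start_score.argmax(dim=-1)
    # Forbid ends before the chosen start, then take the best remaining end.
    positions = torch.arange(qa_end_score.size(-1), device=qa_end_score.device)
    end_logits = qa_end_score.masked_fill(
        positions.unsqueeze(0) < start_idx.unsqueeze(1), -1e4
    )
    return start_idx, end_logits.argmax(dim=-1)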
|
if __name__ == "__main__":
    parser = ArgumentParser()

    # Model and data sources.
    parser.add_argument("--data_model_name_or_path", type=str, default="rajpurkar/squad_v2")
    parser.add_argument("--model_name_or_path", type=str, default="google-bert/bert-base-uncased")
    parser.add_argument("--data_from_hf", type=str, default=None)

    # Checkpointing and training length.
    parser.add_argument("--ckpt_path", type=str, default="./checkpoints/last.ckpt")
    parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints")
    parser.add_argument("--max_epochs", type=int, default=3)

    args = parser.parse_args()

    # The Squad_v2 datamodule (local data.py) is assumed to yield batches with
    # input_ids, attention_mask, start_positions and end_positions, matching
    # what training_step consumes.
    data = Squad_v2(
        data_from_hf=args.data_from_hf,
        tokenizer_model_name_or_path=args.model_name_or_path,
        data_model_name_or_path=args.data_model_name_or_path,
    )
    data.setup("fit")
    bert = BertSquad_v2(model_name_or_path=args.model_name_or_path)
|
    # Keep the two best checkpoints by validation loss plus the most recent
    # one (last.ckpt), saving every 1000 training steps.
    checkpoint_callback = ModelCheckpoint(
        dirpath=args.checkpoint_dir,
        save_top_k=2,
        monitor="val_loss",
        filename="squad-v2-{epoch:02d}-{train_loss:.2f}-{val_loss:.2f}",
        save_last=True,
        every_n_train_steps=1000,
    )
|
    # Stop once the training loss fails to improve for 3 consecutive checks;
    # check_finite aborts immediately if the loss goes NaN/inf.
    early_stop_callback = EarlyStopping(
        monitor="train_loss",
        patience=3,
        check_finite=True,
        verbose=True,
        mode="min",
    )
|
    # Trade a little float32 matmul precision for speed on tensor-core GPUs.
    torch.set_float32_matmul_precision("medium")
|
    trainer = Trainer(
        callbacks=[checkpoint_callback, early_stop_callback],
        max_epochs=args.max_epochs,
        gradient_clip_val=1.0,
        accelerator="auto",
        devices=1,
        logger=TensorBoardLogger("lightning_logs", name="squad-v2-bert"),
        enable_model_summary=True,
        val_check_interval=0.25,  # validate four times per training epoch
    )
|
    # Resume only if the checkpoint actually exists, so a fresh run does not
    # fail on the default --ckpt_path.
    ckpt_path = args.ckpt_path if args.ckpt_path and os.path.exists(args.ckpt_path) else None
    trainer.fit(bert, data, ckpt_path=ckpt_path)
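
    # Example invocation (the script filename here is illustrative):
    #   python train_squad_v2.py --max_epochs 3 --checkpoint_dir ./checkpoints
    # Metrics go to TensorBoard under lightning_logs/squad-v2-bert:
    #   tensorboard --logdir lightning_logs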
|
|