# %% import lightning as L from torch import optim, nn import torch from transformers import BertModel from argparse import ArgumentParser from data import Squad_v2 from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping from lightning.pytorch import Trainer from lightning.pytorch.loggers import TensorBoardLogger # %% class BertSquard_v2(L.LightningModule): def __init__(self, model_name_or_path="google-bert/bert-base-uncased"): super().__init__() self.save_hyperparameters() # BERT backbone self.bert_model = BertModel.from_pretrained(model_name_or_path) self.bert_model.train() # self.tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-uncased") hidden_size = self.bert_model.config.hidden_size # QA Heads self.qa_start = nn.Linear(hidden_size, 1) self.qa_end = nn.Linear(hidden_size, 1) def configure_optimizers(self): optimizer = optim.Adam(self.parameters(), lr=1e-5) # 降低学习率 scheduler = optim.lr_scheduler.ReduceLROnPlateau( optimizer, mode="min", factor=0.1, patience=2 ) return { "optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "monitor": "train_loss"}, } def forward(self, input_ids, attention_mask): # input_ids shape : (batch_size, seq_len) # attention_mask shape : (batch_size, seq_len) outputs = self.bert_model(input_ids, attention_mask, output_hidden_states=True) hidden_states = outputs.hidden_states[-1] # hidden_states shape : (batch_size, seq_len, hidden_size) qa_start_score = self.qa_start(hidden_states).squeeze(-1) * attention_mask # qa_start_score shape : (batch_size, seq_len) qa_end_score = self.qa_end(hidden_states).squeeze(-1) * attention_mask # qa_end_score shape : (batch_size, seq_len) return qa_start_score, qa_end_score def training_step(self, batch, batch_idx): input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] start_positions = batch["start_positions"] end_positions = batch["end_positions"] qa_start_score, qa_end_score = self.forward(input_ids, attention_mask) loss_fn = nn.CrossEntropyLoss() loss_start = loss_fn(qa_start_score, start_positions) loss_end = loss_fn(qa_end_score, end_positions) loss = loss_start + loss_end self.log("train_loss", loss, prog_bar=True) return loss def validation_step(self, batch, batch_idx): self.bert_model.eval() with torch.no_grad(): input_ids = batch["input_ids"] attention_mask = batch["attention_mask"] start_positions = batch["start_positions"] end_positions = batch["end_positions"] qa_start_score, qa_end_score = self.forward(input_ids, attention_mask) loss_fn = nn.CrossEntropyLoss() loss_start = loss_fn(qa_start_score, start_positions) loss_end = loss_fn(qa_end_score, end_positions) loss = loss_start + loss_end self.log("val_loss", loss, prog_bar=True) return loss # # %% # data = Squad_v2(loaddata_from_hf="eming/squad_v2_processed") # data.setup("fit") # # %% # # val the model # sample = next(iter(data.train_dataloader())) # input_ids = sample["input_ids"] # attention_mask = sample["attention_mask"] # qa_start_socre, qa_end_socre = BertSquard_v2().forward(input_ids, attention_mask) # print(qa_start_socre.shape, qa_end_socre.shape) # %% if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("--data_model_name_or_path", type=str, default="rajpurkar/squad_v2") parser.add_argument("--model_name_or_path", type=str, default="google-bert/bert-base-uncased") parser.add_argument("--data_from_hf", type=str, default=None) parser.add_argument("--ckpt_path", type=str, default="./checkpoints/last.ckpt") parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints") parser.add_argument("--max_epochs", type=int, default=3) args = parser.parse_args() data = Squad_v2( data_from_hf=args.data_from_hf, tokenizer_model_name_or_path=args.model_name_or_path, data_model_name_or_path=args.data_model_name_or_path, ) data.setup("fit") bert = BertSquard_v2(model_name_or_path=args.model_name_or_path) checkpoint_callback = ModelCheckpoint( dirpath=args.checkpoint_dir, save_top_k=2, monitor="val_loss", filename="squad-v2-{epoch:02d}-{train_loss:.2f}-{val_loss:.2f}", save_last=True, every_n_train_steps=1000, ) early_stop_callback = EarlyStopping( monitor="train_loss", patience=3, check_finite=True, verbose=True, mode="min", ) torch.set_float32_matmul_precision("medium") trainer = Trainer( callbacks=[checkpoint_callback, early_stop_callback], max_epochs=args.max_epochs, gradient_clip_val=1.0, accelerator="auto", devices=1, logger=TensorBoardLogger("lightning_logs", name="squad-v2-bert"), enable_model_summary=True, val_check_interval=0.25, ) if args.ckpt_path: trainer.fit(bert, data, ckpt_path=args.ckpt_path) else: trainer.fit(bert, data)