# %%
import os
from argparse import ArgumentParser

import lightning as L
import torch
from torch import optim, nn
from transformers import BertModel
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger

from data import Squad_v2
# %%
class BertSquard_v2(L.LightningModule):
    def __init__(self, model_name_or_path="google-bert/bert-base-uncased"):
        super().__init__()
        self.save_hyperparameters()
        # BERT backbone
        self.bert_model = BertModel.from_pretrained(model_name_or_path)
        self.bert_model.train()
        # self.tokenizer = BertTokenizerFast.from_pretrained("google-bert/bert-base-uncased")
        hidden_size = self.bert_model.config.hidden_size
        # QA heads: per-token scores for answer start and end
        self.qa_start = nn.Linear(hidden_size, 1)
        self.qa_end = nn.Linear(hidden_size, 1)
    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-5)  # low learning rate for fine-tuning
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=2
        )
        # "monitor" must name a metric logged via self.log (see training_step)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "train_loss"},
        }
    def forward(self, input_ids, attention_mask):
        # input_ids, attention_mask shape: (batch_size, seq_len)
        outputs = self.bert_model(input_ids, attention_mask)
        hidden_states = outputs.last_hidden_state
        # hidden_states shape: (batch_size, seq_len, hidden_size)
        # zero out scores at padded positions (see the masking note after the class)
        qa_start_score = self.qa_start(hidden_states).squeeze(-1) * attention_mask
        # qa_start_score shape: (batch_size, seq_len)
        qa_end_score = self.qa_end(hidden_states).squeeze(-1) * attention_mask
        # qa_end_score shape: (batch_size, seq_len)
        return qa_start_score, qa_end_score
    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        start_positions = batch["start_positions"]
        end_positions = batch["end_positions"]
        qa_start_score, qa_end_score = self.forward(input_ids, attention_mask)
        # cross-entropy over the seq_len positions: each position is a class,
        # and the target is the index of the answer's start/end token
        loss_fn = nn.CrossEntropyLoss()
        loss_start = loss_fn(qa_start_score, start_positions)
        loss_end = loss_fn(qa_end_score, end_positions)
        loss = loss_start + loss_end
        self.log("train_loss", loss, prog_bar=True)
        return loss
    def validation_step(self, batch, batch_idx):
        # Lightning already puts the model in eval mode and disables gradients
        # during validation, so no explicit eval()/no_grad() is needed here.
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        start_positions = batch["start_positions"]
        end_positions = batch["end_positions"]
        qa_start_score, qa_end_score = self.forward(input_ids, attention_mask)
        loss_fn = nn.CrossEntropyLoss()
        loss_start = loss_fn(qa_start_score, start_positions)
        loss_end = loss_fn(qa_end_score, end_positions)
        loss = loss_start + loss_end
        self.log("val_loss", loss, prog_bar=True)
        return loss
# # %%
# data = Squad_v2(data_from_hf="eming/squad_v2_processed")
# data.setup("fit")
# # %%
# # smoke-test the model on one batch
# sample = next(iter(data.train_dataloader()))
# input_ids = sample["input_ids"]
# attention_mask = sample["attention_mask"]
# qa_start_score, qa_end_score = BertSquard_v2().forward(input_ids, attention_mask)
# print(qa_start_score.shape, qa_end_score.shape)
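# %%
# Illustrative helper, not part of the original training code: a minimal
# greedy span decode over the scores returned by forward(). Assumes
# (batch_size, seq_len) score tensors; max_answer_len is an arbitrary cap.
def decode_answer_span(qa_start_score, qa_end_score, max_answer_len=30):
    # most likely start token per example
    start_idx = qa_start_score.argmax(dim=-1)  # (batch_size,)
    end_idx = torch.empty_like(start_idx)
    for i, s in enumerate(start_idx):
        # best end within [s, s + max_answer_len), so that end >= start
        window = qa_end_score[i, s : s + max_answer_len]
        end_idx[i] = s + window.argmax()
    # token indices of the predicted span; map back to text with the
    # tokenizer's offset mappings
    return start_idx, end_idx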
# %%
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("--data_model_name_or_path", type=str, default="rajpurkar/squad_v2")
    parser.add_argument("--model_name_or_path", type=str, default="google-bert/bert-base-uncased")
    parser.add_argument("--data_from_hf", type=str, default=None)
    parser.add_argument("--ckpt_path", type=str, default="./checkpoints/last.ckpt")
    parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints")
    parser.add_argument("--max_epochs", type=int, default=3)
    args = parser.parse_args()

    data = Squad_v2(
        data_from_hf=args.data_from_hf,
        tokenizer_model_name_or_path=args.model_name_or_path,
        data_model_name_or_path=args.data_model_name_or_path,
    )
    data.setup("fit")

    bert = BertSquard_v2(model_name_or_path=args.model_name_or_path)

    checkpoint_callback = ModelCheckpoint(
        dirpath=args.checkpoint_dir,
        save_top_k=2,
        monitor="val_loss",
        filename="squad-v2-{epoch:02d}-{train_loss:.2f}-{val_loss:.2f}",
        save_last=True,
        every_n_train_steps=1000,
    )
    early_stop_callback = EarlyStopping(
        monitor="train_loss",
        patience=3,
        check_finite=True,
        verbose=True,
        mode="min",
    )

    torch.set_float32_matmul_precision("medium")
    trainer = Trainer(
        callbacks=[checkpoint_callback, early_stop_callback],
        max_epochs=args.max_epochs,
        gradient_clip_val=1.0,
        accelerator="auto",
        devices=1,
        logger=TensorBoardLogger("lightning_logs", name="squad-v2-bert"),
        enable_model_summary=True,
        val_check_interval=0.25,
    )

    # resume only if the checkpoint actually exists; the default path will
    # not exist on a fresh run
    if args.ckpt_path and os.path.exists(args.ckpt_path):
        trainer.fit(bert, data, ckpt_path=args.ckpt_path)
    else:
        trainer.fit(bert, data)
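# %%
# Example invocation (flags as defined in the parser above; the dataset id is
# the one from the commented-out smoke test and may need adjusting):
#   python model.py --data_from_hf eming/squad_v2_processed --max_epochs 3 \
#       --checkpoint_dir ./checkpoints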