# %%
from argparse import ArgumentParser

import torch
import lightning as L
from torch import nn, optim
from transformers import BertModel
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, EarlyStopping
from lightning.pytorch.loggers import TensorBoardLogger

from data import Squad_v2


# %%
class BertSquad_v2(L.LightningModule):

    def __init__(self, model_name_or_path="google-bert/bert-base-uncased"):
        super().__init__()
        self.save_hyperparameters()
        # BERT backbone; Lightning toggles train/eval mode automatically,
        # so no manual .train() call is needed here
        self.bert_model = BertModel.from_pretrained(model_name_or_path)

        hidden_size = self.bert_model.config.hidden_size

        # QA Heads
        self.qa_start = nn.Linear(hidden_size, 1)
        self.qa_end = nn.Linear(hidden_size, 1)

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-5)  # lowered learning rate
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.1, patience=2
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "train_loss"},
        }

    def forward(self, input_ids, attention_mask):
        # input_ids shape : (batch_size, seq_len)
        # attention_mask shape : (batch_size, seq_len)
        outputs = self.bert_model(input_ids, attention_mask=attention_mask)

        hidden_states = outputs.last_hidden_state
        # hidden_states shape : (batch_size, seq_len, hidden_size)

        # Mask out padding positions additively: multiplying the logits by
        # attention_mask would pin them to 0, which can still win the softmax
        # against negative logits, so push them to a large negative value.
        pad_mask = attention_mask == 0

        qa_start_score = self.qa_start(hidden_states).squeeze(-1)
        qa_start_score = qa_start_score.masked_fill(pad_mask, -1e4)
        # qa_start_score shape : (batch_size, seq_len)

        qa_end_score = self.qa_end(hidden_states).squeeze(-1)
        qa_end_score = qa_end_score.masked_fill(pad_mask, -1e4)
        # qa_end_score shape : (batch_size, seq_len)

        return qa_start_score, qa_end_score

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        start_positions = batch["start_positions"]
        end_positions = batch["end_positions"]

        qa_start_score, qa_end_score = self.forward(input_ids, attention_mask)

        loss_fn = nn.CrossEntropyLoss()
        loss_start = loss_fn(qa_start_score, start_positions)
        loss_end = loss_fn(qa_end_score, end_positions)

        loss = loss_start + loss_end
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # Lightning already switches the model to eval mode and disables
        # gradients during validation, so no manual eval()/no_grad() is needed.
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        start_positions = batch["start_positions"]
        end_positions = batch["end_positions"]

        qa_start_score, qa_end_score = self.forward(input_ids, attention_mask)

        loss_fn = nn.CrossEntropyLoss()
        loss_start = loss_fn(qa_start_score, start_positions)
        loss_end = loss_fn(qa_end_score, end_positions)

        loss = loss_start + loss_end
        self.log("val_loss", loss, prog_bar=True)
        return loss
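
    # --- Illustrative decode helper (not part of the original script) ---
    # A minimal sketch of turning the start/end logits into answer text at
    # inference time: argmax both heads and decode the token span. The
    # `tokenizer` argument is an assumption (e.g. a BertTokenizerFast matching
    # the backbone); full SQuAD v2 decoding would also score the no-answer
    # (CLS) option and search top-k start/end pairs with start <= end.
    @torch.no_grad()
    def extract_answer(self, input_ids, attention_mask, tokenizer):
        qa_start_score, qa_end_score = self.forward(input_ids, attention_mask)
        start_idx = qa_start_score.argmax(dim=-1)  # (batch_size,)
        end_idx = qa_end_score.argmax(dim=-1)      # (batch_size,)
        answers = []
        for ids, s, e in zip(input_ids, start_idx, end_idx):
            s, e = int(s), int(e)
            span = ids[s : e + 1] if s <= e else ids[:0]  # empty if inconsistent
            answers.append(tokenizer.decode(span, skip_special_tokens=True))
        return answers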


# # %%

# data = Squad_v2(data_from_hf="eming/squad_v2_processed")
# data.setup("fit")
# # %%
# # quick sanity check of the forward pass
# sample = next(iter(data.train_dataloader()))
# input_ids = sample["input_ids"]
# attention_mask = sample["attention_mask"]
# qa_start_score, qa_end_score = BertSquad_v2().forward(input_ids, attention_mask)
# print(qa_start_score.shape, qa_end_score.shape)

# %%
if __name__ == "__main__":

    parser = ArgumentParser()

    parser.add_argument("--data_model_name_or_path", type=str, default="rajpurkar/squad_v2")
    parser.add_argument("--model_name_or_path", type=str, default="google-bert/bert-base-uncased")
    parser.add_argument("--data_from_hf", type=str, default=None)

    parser.add_argument("--ckpt_path", type=str, default="./checkpoints/last.ckpt")
    parser.add_argument("--checkpoint_dir", type=str, default="./checkpoints")
    parser.add_argument("--max_epochs", type=int, default=3)

    args = parser.parse_args()

    data = Squad_v2(
        data_from_hf=args.data_from_hf,
        tokenizer_model_name_or_path=args.model_name_or_path,
        data_model_name_or_path=args.data_model_name_or_path,
    )
    data.setup("fit")
    bert = BertSquad_v2(model_name_or_path=args.model_name_or_path)

    checkpoint_callback = ModelCheckpoint(
        dirpath=args.checkpoint_dir,
        save_top_k=2,
        monitor="val_loss",
        filename="squad-v2-{epoch:02d}-{train_loss:.2f}-{val_loss:.2f}",
        save_last=True,
        every_n_train_steps=1000,
    )

    # Stop when the training loss plateaus; check_finite also halts the run
    # if the loss becomes NaN/inf.
    early_stop_callback = EarlyStopping(
        monitor="train_loss",
        patience=3,
        check_finite=True,
        verbose=True,
        mode="min",
    )

    torch.set_float32_matmul_precision("medium")

    trainer = Trainer(
        callbacks=[checkpoint_callback, early_stop_callback],
        max_epochs=args.max_epochs,
        gradient_clip_val=1.0,
        accelerator="auto",
        devices=1,
        logger=TensorBoardLogger("lightning_logs", name="squad-v2-bert"),
        enable_model_summary=True,
        val_check_interval=0.25,
    )
    if args.ckpt_path:
        trainer.fit(bert, data, ckpt_path=args.ckpt_path)
    else:
        trainer.fit(bert, data)
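
# Example invocation (a sketch; the script filename "train.py" is an
# assumption, and data.py must provide the Squad_v2 DataModule):
#
#   python train.py --max_epochs 3 \
#       --model_name_or_path google-bert/bert-base-uncased
#
# Pass --ckpt_path checkpoints/last.ckpt to resume a previous run.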