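"""Training entry point for the RelikReader model.

Wires together Hydra-configured datasets, a layer-wise LR-decay optimizer,
Lightning callbacks, and a Weights & Biases logger, runs the Lightning
Trainer, and finally exports the best checkpoint in HuggingFace format.
"""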
from pathlib import Path
from pprint import pprint

import hydra
import lightning
import torch  # noqa: F401  (needed if the determinism toggle below is enabled)
from hydra.utils import to_absolute_path
from lightning import Trainer
from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger
from omegaconf import DictConfig, OmegaConf, open_dict
from torch.utils.data import DataLoader

from relik.reader.data.relik_reader_data import RelikDataset
from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule
from relik.reader.pytorch_modules.optim import LayerWiseLRDecayOptimizer
from relik.reader.utils.special_symbols import get_special_symbols
from relik.reader.utils.strong_matching_eval import ELStrongMatchingCallback
from relik.reader.utils.shuffle_train_callback import ShuffleTrainCallback


@hydra.main(config_path="../conf", config_name="config")
def train(cfg: DictConfig) -> None:
    lightning.seed_everything(cfg.training.seed)
    # optionally force deterministic kernels for stricter reproducibility
    # torch.use_deterministic_algorithms(True, warn_only=True)

    # dump the resolved configuration for inspection
    pprint(OmegaConf.to_container(cfg, resolve=True))

    # special tokens for the candidate entities handled in a single forward pass
    special_symbols = get_special_symbols(cfg.model.entities_per_forward)

    # model declaration
    model = RelikReaderPLModule(
        cfg=OmegaConf.to_container(cfg),
        transformer_model=cfg.model.model.transformer_model,
        additional_special_symbols=len(special_symbols),
        training=True,
    )
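    # (assumption: additional_special_symbols is the count of new tokens the
    # module should make room for, e.g. by resizing the token embeddings)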

    # optimizer factory with layer-wise LR decay: learning rates shrink
    # geometrically (factor lr_decay) toward the lower transformer layers
    opt_conf = cfg.model.optimizer
    electra_optimizer_factory = LayerWiseLRDecayOptimizer(
        lr=opt_conf.lr,
        warmup_steps=opt_conf.warmup_steps,
        total_steps=opt_conf.total_steps,
        total_reset=opt_conf.total_reset,
        no_decay_params=opt_conf.no_decay_params,
        weight_decay=opt_conf.weight_decay,
        lr_decay=opt_conf.lr_decay,
    )
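    # sketch of the intended effect, assuming the usual layer-wise scheme:
    # with lr=1e-4 and lr_decay=0.9, the top layer trains at 1e-4, the next
    # one down at 9e-5, then 8.1e-5, and so on toward the embeddings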

    model.set_optimizer_factory(electra_optimizer_factory)

    # datasets declaration
    train_dataset: RelikDataset = hydra.utils.instantiate(
        cfg.data.train_dataset,
        dataset_path=to_absolute_path(cfg.data.train_dataset_path),
        special_symbols=special_symbols,
    )
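    # (hydra.utils.instantiate builds the class named by _target_ in the
    # config node above, forwarding the extra keyword arguments given here)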

    # inject special_symbols into the validation dataset config as well:
    # the EvaluationCallback instantiates its own dataset from this config
    with open_dict(cfg):
        cfg.data.val_dataset.special_symbols = special_symbols

    val_dataset: RelikDataset = hydra.utils.instantiate(
        cfg.data.val_dataset,
        dataset_path=to_absolute_path(cfg.data.val_dataset_path),
    )

    # callbacks: strong-matching EL evaluation, checkpointing on the best
    # val_core_f1, LR logging, and shuffling of the training data
    callbacks = [
        ELStrongMatchingCallback(
            to_absolute_path(cfg.data.val_dataset_path), cfg.data.val_dataset
        ),
        ModelCheckpoint(
            "model",
            filename="{epoch}-{val_core_f1:.2f}",
            monitor="val_core_f1",
            mode="max",
        ),
        LearningRateMonitor(),
        ShuffleTrainCallback(),
    ]

    wandb_logger = WandbLogger(
        cfg.model_name, project=cfg.project_name, offline=cfg.offline
    )

    # trainer declaration
    trainer: Trainer = hydra.utils.instantiate(
        cfg.training.trainer,
        callbacks=callbacks,
        logger=wandb_logger,
    )

    # the datasets yield ready-made batches, so automatic batching is
    # disabled in the DataLoaders with batch_size=None
    trainer.fit(
        model=model,
        train_dataloaders=DataLoader(train_dataset, batch_size=None, num_workers=1),
        val_dataloaders=DataLoader(val_dataset, batch_size=None, num_workers=0),
    )

    # reload the best checkpoint (highest val_core_f1) and export it,
    # together with the tokenizer, in HuggingFace format
    experiment_path = Path(wandb_logger.experiment.dir)
    model = RelikReaderPLModule.load_from_checkpoint(
        trainer.checkpoint_callback.best_model_path
    )
    model.relik_reader_core_model._tokenizer = train_dataset.tokenizer
    model.relik_reader_core_model.save_pretrained(experiment_path / "hf_model")


def main():
    train()


if __name__ == "__main__":
    main()
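

# Example invocation with Hydra overrides (illustrative values; the file
# name train.py is assumed, and the keys training.seed, offline, and
# model.optimizer.lr all appear in the config accessed above):
#   python train.py training.seed=13 offline=true model.optimizer.lr=1e-5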