File size: 2,509 Bytes
699342a
 
 
 
 
 
 
 
 
8afb176
699342a
 
 
f036ad4
699342a
 
 
 
f036ad4
 
699342a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import datetime
from argparse import ArgumentParser

import torch

from lightning import Trainer
from lightning.pytorch.loggers import TensorBoardLogger
from lightning.pytorch.callbacks import ModelSummary

from src.trainer import ViTLightningModule


def _parse_args():
    """Build and run the CLI argument parser; return the parsed namespace."""
    parser = ArgumentParser(description='KAUST-SDAIA Diabetic Retinopathy')
    parser.add_argument('--tag', action='store', type=str,
                        help='Extra suffix to put on the artefact dir name')
    parser.add_argument('--debug', action='store_true',
                        help="Dummy training cycle for testing purposes")
    parser.add_argument('--convert-checkpoint', action='store', type=str,
                        help='Convert a checkpoint from training to pickle-independent '
                             'predictor-compatible directory')
    return parser.parse_args()


def _convert_checkpoint(ckpt_path):
    """Load a training checkpoint and re-save it in the predictor-compatible
    directory format via ``save_checkpoint_dk``.

    :param ckpt_path: path to the Lightning ``.ckpt`` file to convert.
    """
    print("Converting checkpoint", ckpt_path)

    # SECURITY: torch.load unpickles arbitrary objects from ckpt_path —
    # only run this on checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    print(list(checkpoint.keys()))

    # NOTE(review): hparams_file here points at a scratch YAML file;
    # load_from_checkpoint normally *reads* this file — confirm it exists
    # or is actually needed for this model.
    model = ViTLightningModule.load_from_checkpoint(
        ckpt_path,
        map_location="cpu",
        hparams_file="tmp_ckpt_deleteme.yaml")

    model.save_checkpoint_dk("tmp_checkp_path_deleteme")

    print("Saved checkpoint. Done.")


def _train(tag, debug):
    """Run a full (or, with ``debug``, a fast-dev-run) training cycle.

    :param tag: optional suffix appended to the artefact directory name,
        or ``None`` for no suffix.
    :param debug: when True, run Lightning's ``fast_dev_run`` smoke test
        instead of real training.
    """
    print("Start training")

    model = ViTLightningModule(debug)

    # Artefact dir name: timestamp, optionally suffixed with the user tag.
    datetime_str = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    art_dir_name = (f"{datetime_str}" +
                    (f"_{tag}" if tag is not None else ""))
    logger = TensorBoardLogger(save_dir=".", name="lightning_logs", version=art_dir_name)

    trainer = Trainer(
        logger=logger,
        benchmark=True,           # cudnn autotune: fixed input sizes assumed
        devices="auto",
        accelerator="auto",
        max_epochs=-1,            # train until externally stopped
        callbacks=[
            ModelSummary(max_depth=-1),
            ],
        fast_dev_run=debug,
        log_every_n_steps=10,
        )

    trainer.fit(
        model,
        train_dataloaders=model._train_dataloader,
        val_dataloaders=model._val_dataloader,
        )

    print("Training done")


def main():
    """ Neural network trainer entry point.

    Parses CLI arguments and either converts an existing checkpoint
    (``--convert-checkpoint PATH``) or starts a training run.
    """
    args = _parse_args()

    torch.set_float32_matmul_precision('high') # for V100/A100

    if args.convert_checkpoint is not None:
        _convert_checkpoint(args.convert_checkpoint)
    else:
        # args.debug is already a bool (store_true), so it maps directly
        # onto Lightning's fast_dev_run flag.
        _train(args.tag, args.debug)


# Script entry point: run the trainer only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()