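"""Factory helpers that build training components from a config object.

Each `build_*` function looks up a class by name in the corresponding project
module (`data`, `losses`, `optim`, `tasks`, `models.engine`) and instantiates
it with the parameters given in the config. The config is accessed
attribute-style (e.g., `cfg.data.dataset.name`), as with OmegaConf-like
configs.
"""
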
import re

import numpy as np
import torch
from torch.utils.data import DataLoader, DistributedSampler

from . import models
# Sibling project modules referenced throughout this file; assumed to live
# alongside `models` in the same package.
from . import data, losses, optim, tasks


def get_name_and_params(base):
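    """Return the `name` and `params` fields of a config node (`params` defaults to an empty dict)."""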
    name = getattr(base, 'name')
    params = getattr(base, 'params') or {}
    return name, params


def get_transform(base, transform, mode=None):
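    """Instantiate the transform named under `base.<transform>`, or return None if unset.

    If `mode` is given it is forwarded as an extra parameter (used by the crop
    transform to behave differently at train and inference time).
    """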
    if not base:
        return None
    transform_cfg = getattr(base, transform)
    if not transform_cfg:
        return None
    name, params = get_name_and_params(transform_cfg)
    if mode:
        params.update({'mode': mode})
    return getattr(data.transforms, name)(**params)


def build_transforms(cfg, mode):
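    """Build the resize/augment/crop/preprocess transforms defined under `cfg.transform`.

    Augmentation is only built for `mode == 'train'`; any transform missing
    from the config comes back as None.
    """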
    # 1-Resize
    resizer = get_transform(cfg.transform, 'resize')
    # 2-(Optional) Data augmentation
    augmenter = None
    if mode == "train":
        augmenter = get_transform(cfg.transform, 'augment')
    # 3-(Optional) Crop
    cropper = get_transform(cfg.transform, 'crop', mode=mode)
    # 4-Preprocess
    preprocessor = get_transform(cfg.transform, 'preprocess')
    return {
        'resize': resizer,
        'augment': augmenter,
        'crop': cropper,
        'preprocess': preprocessor
    }


def build_dataset(cfg, data_info, mode):
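    """Instantiate the dataset class named in `cfg.data.dataset`.

    Transforms are built and attached unless a FeatureDataset is requested.
    `data_info` is a dict of split-specific keyword arguments (the exact keys
    depend on the dataset class) merged into the dataset parameters.
    """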
    dataset_class = getattr(data.datasets, cfg.data.dataset.name)
    dataset_params = cfg.data.dataset.params
    dataset_params.test_mode = mode != 'train'
    dataset_params = dict(dataset_params)
    if "FeatureDataset" not in cfg.data.dataset.name:
        transforms = build_transforms(cfg, mode)
        dataset_params.update(transforms)
    dataset_params.update(data_info)
    return dataset_class(**dataset_params)


def build_dataloader(cfg, dataset, mode):
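    """Wrap `dataset` in a torch DataLoader configured for `mode`.

    Train/valid loaders take their batch size from `cfg.train`/`cfg.evaluate`
    and may use a custom sampler (wrapped appropriately when `cfg.strategy`
    is 'ddp'); inference under DDP is not supported.
    """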

    def worker_init_fn(worker_id):
        # Give each worker a distinct NumPy seed derived from the parent process state
        np.random.seed(np.random.get_state()[1][0] + worker_id)

    dataloader_params = {}
    dataloader_params['num_workers'] = cfg.data.num_workers
    dataloader_params['drop_last'] = mode == 'train'
    dataloader_params['shuffle'] = mode == 'train'
    dataloader_params["pin_memory"] = cfg.data.get("pin_memory", True)
    if mode in ('train', 'valid'):
        if mode == "train":
            dataloader_params['batch_size'] = cfg.train.batch_size
        elif mode == "valid":
            dataloader_params["batch_size"] = cfg.evaluate.get("batch_size") or cfg.train.batch_size
        sampler = None
        if cfg.data.get("sampler") and mode == 'train':
            name, params = get_name_and_params(cfg.data.sampler)
            sampler = getattr(data.samplers, name)(dataset, **params)
        if sampler:
            dataloader_params['shuffle'] = False
            if cfg.strategy == 'ddp':
                sampler = data.samplers.DistributedSamplerWrapper(sampler)
            dataloader_params['sampler'] = sampler
            print(f'Using sampler {sampler} for training ...')
        elif cfg.strategy == 'ddp':
            dataloader_params["shuffle"] = False
            dataloader_params['sampler'] = DistributedSampler(dataset, shuffle=(mode == "train"))
    else:
        assert cfg.strategy != "ddp", "DDP currently not supported for inference"
        dataloader_params['batch_size'] = cfg.evaluate.get("batch_size") or cfg.train.batch_size

    loader = DataLoader(
        dataset,
        **dataloader_params,
        worker_init_fn=worker_init_fn)
    return loader


def build_model(cfg):
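    """Instantiate the model named in `cfg.model`, optionally loading pretrained weights.

    Backbone and checkpoint paths containing the placeholder 'foldx' are
    rewritten to the current outer fold before loading.
    """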
    name, params = get_name_and_params(cfg.model)
    if cfg.model.params.get("cnn_params", None):
        cnn_params = cfg.model.params.cnn_params
        if cnn_params.get("load_pretrained_backbone", None):
            if "foldx" in cnn_params.load_pretrained_backbone:
                cfg.model.params.cnn_params.load_pretrained_backbone = (
                    cnn_params.load_pretrained_backbone.replace("foldx", f"fold{cfg.data.outer_fold}")
                )
    print(f'Creating model <{name}> ...')
    model = getattr(models.engine, name)(**params)
    if 'backbone' in cfg.model.params:
        print(f'  Using backbone <{cfg.model.params.backbone}> ...')
    if 'pretrained' in cfg.model.params:
        print(f'  Pretrained : {cfg.model.params.pretrained}')
    if "load_pretrained" in cfg.model:
        if "foldx" in cfg.model.load_pretrained:
            cfg.model.load_pretrained = cfg.model.load_pretrained.replace("foldx", f"fold{cfg.data.outer_fold}")
        print(f"  Loading pretrained checkpoint from {cfg.model.load_pretrained}")
        weights = torch.load(cfg.model.load_pretrained, map_location=lambda storage, loc: storage)['state_dict']
        # Strip the leading 'model.' prefix (keys saved from the task wrapper) and drop loss-function buffers
        weights = {re.sub(r'^model\.', '', k): v for k, v in weights.items() if "loss_fn" not in k}
        model.load_state_dict(weights)
    return model


def build_loss(cfg):
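    """Instantiate the loss function named in `cfg.loss`, converting `pos_weight` to a tensor if present."""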
    name, params = get_name_and_params(cfg.loss)
    print(f'Using loss function <{name}> ...')
    params = dict(params)
    if "pos_weight" in params:
        params["pos_weight"] = torch.tensor(params["pos_weight"])
    criterion = getattr(losses, name)(**params)
    return criterion


def build_scheduler(cfg, optimizer):
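    """Instantiate the LR scheduler named in `cfg.scheduler` and tag its update frequency.

    One-cycle and cosine schedules are stepped per batch, ReduceLROnPlateau
    after validation, and everything else once per epoch.
    """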
    # Some schedulers require remapping config params: the config exposes more
    # intuitive fields (e.g., final_lr), which are translated below into the
    # scheduler's native arguments.
    name, params = get_name_and_params(cfg.scheduler)
    print(f'Using learning rate schedule <{name}> ...')

    if name == 'CosineAnnealingLR':
        # eta_min corresponds to final_lr in the config.
        # T_max is a placeholder; it is overwritten in the on_train_start()
        # method of the LightningModule task.
        params = {
            'T_max': 100000,
            'eta_min': max(params.final_lr, 1.0e-8)
        }

    if name in ('OneCycleLR', 'CustomOneCycleLR'):
        # Use learning rate from optimizer parameters as initial learning rate
        lr_0 = cfg.optimizer.params.lr
        lr_1 = params.max_lr
        lr_2 = params.final_lr
        # lr_0 -> lr_1 -> lr_2 
        pct_start = params.pct_start
        params = {}
        params['steps_per_epoch'] = 100000  # placeholder; corrected in the LightningModule task (see note above)
        params['epochs'] = cfg.train.num_epochs
        params['max_lr'] = lr_1
        params['pct_start'] = pct_start
        params['div_factor'] = lr_1 / lr_0 # max/init
        params['final_div_factor'] = lr_0 / max(lr_2, 1.0e-8) # init/final

    scheduler = getattr(optim, name)(optimizer=optimizer, **params)
    
    # Some schedulers might need more manipulation after instantiation
    if name in ('OneCycleLR', 'CustomOneCycleLR'):
        scheduler.pct_start = params['pct_start']

    # Set update frequency
    if name in ('OneCycleLR', 'CustomOneCycleLR', 'CosineAnnealingLR'):
        scheduler.update_frequency = 'on_batch'
    elif name == 'ReduceLROnPlateau':
        scheduler.update_frequency = 'on_valid'
    else:
        scheduler.update_frequency = 'on_epoch'

    return scheduler


def build_optimizer(cfg, parameters):
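    """Instantiate the optimizer named in `cfg.optimizer` over `parameters`."""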
    name, params = get_name_and_params(cfg.optimizer)
    print(f'Using optimizer <{name}> ...')
    optimizer = getattr(optim, name)(parameters, **params)
    return optimizer


def build_task(cfg, model):
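    """Instantiate the task named in `cfg.task`, wrapping the config and model."""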
    name, params = get_name_and_params(cfg.task)
    print(f'Building task <{name}> ...')
    return getattr(tasks, name)(cfg, model, **params)
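

# Usage sketch (illustrative, not part of the module's API): how these builders
# are typically composed in a training entry point, assuming an attribute-style
# config `cfg` with the fields referenced above and a `data_info` dict of
# split-specific dataset arguments.
#
#   dataset   = build_dataset(cfg, data_info, mode='train')
#   loader    = build_dataloader(cfg, dataset, mode='train')
#   model     = build_model(cfg)
#   criterion = build_loss(cfg)
#   optimizer = build_optimizer(cfg, model.parameters())
#   scheduler = build_scheduler(cfg, optimizer)
#   task      = build_task(cfg, model)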