ianpan's picture
Initial commit
231edce
raw
history blame
7.06 kB
import numpy as np
import torch
from . import models
def get_name_and_params(base):
name = getattr(base, 'name')
params = getattr(base, 'params') or {}
return name, params
def get_transform(base, transform, mode=None):
if not base: return None
transform = getattr(base, transform)
if not transform: return None
name, params = get_name_and_params(transform)
if mode:
params.update({'mode': mode})
return getattr(data.transforms, name)(**params)
def build_transforms(cfg, mode):
# 1-Resize
resizer = get_transform(cfg.transform, 'resize')
# 2-(Optional) Data augmentation
augmenter = None
if mode == "train":
augmenter = get_transform(cfg.transform, 'augment')
# 3-(Optional) Crop
cropper = get_transform(cfg.transform, 'crop', mode=mode)
# 4-Preprocess
preprocessor = get_transform(cfg.transform, 'preprocess')
return {
'resize': resizer,
'augment': augmenter,
'crop': cropper,
'preprocess': preprocessor
}
def build_dataset(cfg, data_info, mode):
dataset_class = getattr(data.datasets, cfg.data.dataset.name)
dataset_params = cfg.data.dataset.params
dataset_params.test_mode = mode != 'train'
dataset_params = dict(dataset_params)
if "FeatureDataset" not in cfg.data.dataset.name:
transforms = build_transforms(cfg, mode)
dataset_params.update(transforms)
dataset_params.update(data_info)
return dataset_class(**dataset_params)
def build_dataloader(cfg, dataset, mode):
def worker_init_fn(worker_id):
np.random.seed(np.random.get_state()[1][0] + worker_id)
dataloader_params = {}
dataloader_params['num_workers'] = cfg.data.num_workers
dataloader_params['drop_last'] = mode == 'train'
dataloader_params['shuffle'] = mode == 'train'
dataloader_params["pin_memory"] = cfg.data.get("pin_memory", True)
if mode in ('train', 'valid'):
if mode == "train":
dataloader_params['batch_size'] = cfg.train.batch_size
elif mode == "valid":
dataloader_params["batch_size"] = cfg.evaluate.get("batch_size") or cfg.train.batch_size
sampler = None
if cfg.data.get("sampler") and mode == 'train':
name, params = get_name_and_params(cfg.data.sampler)
sampler = getattr(data.samplers, name)(dataset, **params)
if sampler:
dataloader_params['shuffle'] = False
if cfg.strategy == 'ddp':
sampler = data.samplers.DistributedSamplerWrapper(sampler)
dataloader_params['sampler'] = sampler
print(f'Using sampler {sampler} for training ...')
elif cfg.strategy == 'ddp':
dataloader_params["shuffle"] = False
dataloader_params['sampler'] = DistributedSampler(dataset, shuffle=mode=="train")
else:
assert cfg.strategy != "ddp", "DDP currently not supported for inference"
dataloader_params['batch_size'] = cfg.evaluate.get("batch_size") or cfg.train.batch_size
loader = DataLoader(dataset,
**dataloader_params,
worker_init_fn=worker_init_fn)
return loader
def build_model(cfg):
name, params = get_name_and_params(cfg.model)
if cfg.model.params.get("cnn_params", None):
cnn_params = cfg.model.params.cnn_params
if cnn_params.get("load_pretrained_backbone", None):
if "foldx" in cnn_params.load_pretrained_backbone:
cfg.model.params.cnn_params.load_pretrained_backbone = cnn_params.load_pretrained_backbone.\
replace("foldx", f"fold{cfg.data.outer_fold}")
print(f'Creating model <{name}> ...')
model = getattr(models.engine, name)(**params)
if 'backbone' in cfg.model.params:
print(f' Using backbone <{cfg.model.params.backbone}> ...')
if 'pretrained' in cfg.model.params:
print(f' Pretrained : {cfg.model.params.pretrained}')
if "load_pretrained" in cfg.model:
import re
if "foldx" in cfg.model.load_pretrained:
cfg.model.load_pretrained = cfg.model.load_pretrained.replace("foldx", f"fold{cfg.data.outer_fold}")
print(f" Loading pretrained checkpoint from {cfg.model.load_pretrained}")
weights = torch.load(cfg.model.load_pretrained, map_location=lambda storage, loc: storage)['state_dict']
weights = {re.sub(r'^model.', '', k) : v for k,v in weights.items() if "loss_fn" not in k}
model.load_state_dict(weights)
return model
def build_loss(cfg):
name, params = get_name_and_params(cfg.loss)
print(f'Using loss function <{name}> ...')
params = dict(params)
if "pos_weight" in params:
params["pos_weight"] = torch.tensor(params["pos_weight"])
criterion = getattr(losses, name)(**params)
return criterion
def build_scheduler(cfg, optimizer):
# Some schedulers will require manipulation of config params
# My specifications were to make it more intuitive for me
name, params = get_name_and_params(cfg.scheduler)
print(f'Using learning rate schedule <{name}> ...')
if name == 'CosineAnnealingLR':
# eta_min <-> final_lr
# Set T_max as 100000 ... this is changed in on_train_start() method
# of the LightningModule task
params = {
'T_max': 100000,
'eta_min': max(params.final_lr, 1.0e-8)
}
if name in ('OneCycleLR', 'CustomOneCycleLR'):
# Use learning rate from optimizer parameters as initial learning rate
lr_0 = cfg.optimizer.params.lr
lr_1 = params.max_lr
lr_2 = params.final_lr
# lr_0 -> lr_1 -> lr_2
pct_start = params.pct_start
params = {}
params['steps_per_epoch'] = 100000 # see above- will fix in task
params['epochs'] = cfg.train.num_epochs
params['max_lr'] = lr_1
params['pct_start'] = pct_start
params['div_factor'] = lr_1 / lr_0 # max/init
params['final_div_factor'] = lr_0 / max(lr_2, 1.0e-8) # init/final
scheduler = getattr(optim, name)(optimizer=optimizer, **params)
# Some schedulers might need more manipulation after instantiation
if name in ('OneCycleLR', 'CustomOneCycleLR'):
scheduler.pct_start = params['pct_start']
# Set update frequency
if name in ('OneCycleLR', 'CustomOneCycleLR', 'CosineAnnealingLR'):
scheduler.update_frequency = 'on_batch'
elif name in ('ReduceLROnPlateau'):
scheduler.update_frequency = 'on_valid'
else:
scheduler.update_frequency = 'on_epoch'
return scheduler
def build_optimizer(cfg, parameters):
name, params = get_name_and_params(cfg.optimizer)
print(f'Using optimizer <{name}> ...')
optimizer = getattr(optim, name)(parameters, **params)
return optimizer
def build_task(cfg, model):
name, params = get_name_and_params(cfg.task)
print(f'Building task <{name}> ...')
return getattr(tasks, name)(cfg, model, **params)