import numpy as np
import torch

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# NOTE: data, losses, optim, and tasks are assumed to be sibling modules of this
# package (they are referenced throughout the builders below).
from . import data, losses, models, optim, tasks


def get_name_and_params(base):
    name = getattr(base, 'name')
    params = getattr(base, 'params') or {}
    return name, params


def get_transform(base, transform, mode=None):
    if not base:
        return None
    transform = getattr(base, transform)
    if not transform:
        return None
    name, params = get_name_and_params(transform)
    if mode:
        params.update({'mode': mode})
    return getattr(data.transforms, name)(**params)


def build_transforms(cfg, mode):
    # 1- Resize
    resizer = get_transform(cfg.transform, 'resize')
    # 2- (Optional) Data augmentation
    augmenter = None
    if mode == "train":
        augmenter = get_transform(cfg.transform, 'augment')
    # 3- (Optional) Crop
    cropper = get_transform(cfg.transform, 'crop', mode=mode)
    # 4- Preprocess
    preprocessor = get_transform(cfg.transform, 'preprocess')
    return {
        'resize': resizer,
        'augment': augmenter,
        'crop': cropper,
        'preprocess': preprocessor
    }


def build_dataset(cfg, data_info, mode):
    dataset_class = getattr(data.datasets, cfg.data.dataset.name)
    dataset_params = cfg.data.dataset.params
    dataset_params.test_mode = mode != 'train'
    dataset_params = dict(dataset_params)
    if "FeatureDataset" not in cfg.data.dataset.name:
        transforms = build_transforms(cfg, mode)
        dataset_params.update(transforms)
    dataset_params.update(data_info)
    return dataset_class(**dataset_params)


def build_dataloader(cfg, dataset, mode):

    def worker_init_fn(worker_id):
        # Give each dataloader worker a distinct numpy seed
        np.random.seed(np.random.get_state()[1][0] + worker_id)

    dataloader_params = {}
    dataloader_params['num_workers'] = cfg.data.num_workers
    dataloader_params['drop_last'] = mode == 'train'
    dataloader_params['shuffle'] = mode == 'train'
    dataloader_params["pin_memory"] = cfg.data.get("pin_memory", True)

    if mode in ('train', 'valid'):
        if mode == "train":
            dataloader_params['batch_size'] = cfg.train.batch_size
        elif mode == "valid":
            dataloader_params["batch_size"] = cfg.evaluate.get("batch_size") or cfg.train.batch_size
        sampler = None
        if cfg.data.get("sampler") and mode == 'train':
            name, params = get_name_and_params(cfg.data.sampler)
            sampler = getattr(data.samplers, name)(dataset, **params)
        if sampler:
            # A custom sampler replaces shuffling
            dataloader_params['shuffle'] = False
            if cfg.strategy == 'ddp':
                sampler = data.samplers.DistributedSamplerWrapper(sampler)
            dataloader_params['sampler'] = sampler
            print(f'Using sampler {sampler} for training ...')
        elif cfg.strategy == 'ddp':
            dataloader_params["shuffle"] = False
            dataloader_params['sampler'] = DistributedSampler(dataset, shuffle=(mode == "train"))
    else:
        assert cfg.strategy != "ddp", "DDP currently not supported for inference"
        dataloader_params['batch_size'] = cfg.evaluate.get("batch_size") or cfg.train.batch_size

    loader = DataLoader(dataset,
                        **dataloader_params,
                        worker_init_fn=worker_init_fn)
    return loader


def build_model(cfg):
    name, params = get_name_and_params(cfg.model)
    if cfg.model.params.get("cnn_params", None):
        cnn_params = cfg.model.params.cnn_params
        if cnn_params.get("load_pretrained_backbone", None):
            if "foldx" in cnn_params.load_pretrained_backbone:
                cfg.model.params.cnn_params.load_pretrained_backbone = \
                    cnn_params.load_pretrained_backbone.replace("foldx", f"fold{cfg.data.outer_fold}")
    print(f'Creating model <{name}> ...')
    model = getattr(models.engine, name)(**params)
    if 'backbone' in cfg.model.params:
        print(f'  Using backbone <{cfg.model.params.backbone}> ...')
    if 'pretrained' in cfg.model.params:
        print(f'  Pretrained : {cfg.model.params.pretrained}')
    if "load_pretrained" in cfg.model:
        import re
        if "foldx" in cfg.model.load_pretrained:
            cfg.model.load_pretrained = cfg.model.load_pretrained.replace("foldx",
                                                                          f"fold{cfg.data.outer_fold}")
        print(f"  Loading pretrained checkpoint from {cfg.model.load_pretrained}")
        weights = torch.load(cfg.model.load_pretrained,
                             map_location=lambda storage, loc: storage)['state_dict']
        # Strip the `model.` prefix added by the task wrapper and drop loss-function weights
        weights = {re.sub(r'^model\.', '', k): v for k, v in weights.items() if "loss_fn" not in k}
        model.load_state_dict(weights)
    return model


def build_loss(cfg):
    name, params = get_name_and_params(cfg.loss)
    print(f'Using loss function <{name}> ...')
    params = dict(params)
    if "pos_weight" in params:
        params["pos_weight"] = torch.tensor(params["pos_weight"])
    criterion = getattr(losses, name)(**params)
    return criterion


def build_scheduler(cfg, optimizer):
    # Some schedulers require manipulating the config params here;
    # this keeps the config specification itself more intuitive
    name, params = get_name_and_params(cfg.scheduler)
    print(f'Using learning rate schedule <{name}> ...')

    if name == 'CosineAnnealingLR':
        # eta_min <-> final_lr
        # Set T_max to 100000 ... this is changed in the on_train_start() method
        # of the LightningModule task
        params = {
            'T_max': 100000,
            'eta_min': max(params.final_lr, 1.0e-8)
        }

    if name in ('OneCycleLR', 'CustomOneCycleLR'):
        # Use the learning rate from the optimizer parameters as the initial learning rate
        lr_0 = cfg.optimizer.params.lr
        lr_1 = params.max_lr
        lr_2 = params.final_lr
        # lr_0 -> lr_1 -> lr_2
        pct_start = params.pct_start
        params = {}
        params['steps_per_epoch'] = 100000  # see above - fixed up in the task
        params['epochs'] = cfg.train.num_epochs
        params['max_lr'] = lr_1
        params['pct_start'] = pct_start
        params['div_factor'] = lr_1 / lr_0  # max/init
        params['final_div_factor'] = lr_0 / max(lr_2, 1.0e-8)  # init/final

    scheduler = getattr(optim, name)(optimizer=optimizer, **params)

    # Some schedulers might need more manipulation after instantiation
    if name in ('OneCycleLR', 'CustomOneCycleLR'):
        scheduler.pct_start = params['pct_start']

    # Set update frequency
    if name in ('OneCycleLR', 'CustomOneCycleLR', 'CosineAnnealingLR'):
        scheduler.update_frequency = 'on_batch'
    elif name in ('ReduceLROnPlateau',):
        scheduler.update_frequency = 'on_valid'
    else:
        scheduler.update_frequency = 'on_epoch'

    return scheduler


def build_optimizer(cfg, parameters):
    name, params = get_name_and_params(cfg.optimizer)
    print(f'Using optimizer <{name}> ...')
    optimizer = getattr(optim, name)(parameters, **params)
    return optimizer


def build_task(cfg, model):
    name, params = get_name_and_params(cfg.task)
    print(f'Building task <{name}> ...')
    return getattr(tasks, name)(cfg, model, **params)