from collections import abc

from loguru import logger
import pytorch_lightning as pl
from torch import distributed as dist
from torch.utils.data import (
    DataLoader,
    ConcatDataset,
    DistributedSampler
)

from src.datasets.pretrain_dataset import PretrainDataset


class PretrainDataModule(pl.LightningDataModule):
    """
    For distributed training, each training process is assigned only a part of the
    training scenes to reduce memory overhead.
    """
    def __init__(self, args, config):
        super().__init__()

        # 1. data config
        # Train and Val should come from the same data source
        self.train_data_source = config.DATASET.TRAIN_DATA_SOURCE
        self.val_data_source = config.DATASET.VAL_DATA_SOURCE
        # training and validating
        self.train_data_root = config.DATASET.TRAIN_DATA_ROOT
        self.val_data_root = config.DATASET.VAL_DATA_ROOT

        # 2. dataset config
        # dataset options
        self.pretrain_img_resize = config.DATASET.PRETRAIN_IMG_RESIZE  # 840
        self.pretrain_img_pad = config.DATASET.PRETRAIN_IMG_PAD  # True
        self.pretrain_df = config.DATASET.PRETRAIN_DF  # 8
        self.coarse_scale = 1 / config.XOFTR.RESOLUTION[0]  # 0.125, for training xoftr
        self.frame_gap = config.DATASET.PRETRAIN_FRAME_GAP

        # 3. loader parameters
        self.train_loader_params = {
            'batch_size': args.batch_size,
            'num_workers': args.num_workers,
            'pin_memory': getattr(args, 'pin_memory', True)
        }
        self.val_loader_params = {
            'batch_size': 1,
            'shuffle': False,
            'num_workers': args.num_workers,
            'pin_memory': getattr(args, 'pin_memory', True)
        }

    def setup(self, stage=None):
        """
        Setup train / val / test dataset. This method will be called by PL automatically.

        Args:
            stage (str): 'fit' in training phase, and 'test' in testing phase.
        """
        assert stage in ['fit', 'test'], "stage must be either fit or test"

        try:
            self.world_size = dist.get_world_size()
            self.rank = dist.get_rank()
            logger.info(f"[rank:{self.rank}] world_size: {self.world_size}")
        except AssertionError as ae:
            # distributed is not initialized (e.g. single-process run)
            self.world_size = 1
            self.rank = 0
            logger.warning(str(ae) + " (set world_size=1 and rank=0)")

        if stage == 'fit':
            self.train_dataset = self._setup_dataset(
                self.train_data_root,
                mode='train')
            # setup multiple (optional) validation subsets
            self.val_dataset = []
            self.val_dataset.append(self._setup_dataset(
                self.val_data_root,
                mode='val'))
            logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
        else:  # stage == 'test'
            raise ValueError("only 'fit' implemented")

    def _setup_dataset(self, data_root, mode='train'):
        """ Setup train / val / test set. """
        dataset_builder = self._build_concat_dataset
        return dataset_builder(data_root, mode=mode)

    def _build_concat_dataset(self, data_root, mode):
        datasets = []
        datasets.append(
            PretrainDataset(data_root,
                            mode=mode,
                            img_resize=self.pretrain_img_resize,
                            df=self.pretrain_df,
                            img_padding=self.pretrain_img_pad,
                            coarse_scale=self.coarse_scale,
                            frame_gap=self.frame_gap))
        return ConcatDataset(datasets)

    def train_dataloader(self):
        """ Build the training dataloader for the KAIST dataset. """
        sampler = DistributedSampler(self.train_dataset, shuffle=True)
        dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
        return dataloader

    def val_dataloader(self):
        """ Build the validation dataloader(s) for the KAIST dataset. """
        if not isinstance(self.val_dataset, abc.Sequence):
            # single validation set
            sampler = DistributedSampler(self.val_dataset, shuffle=False)
            return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params)
        else:
            # one dataloader per validation subset
            dataloaders = []
            for dataset in self.val_dataset:
                sampler = DistributedSampler(dataset, shuffle=False)
                dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
            return dataloaders
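

# ------------------------------------------------------------------------------
# Hypothetical usage sketch (not part of the original module): a minimal example
# of constructing this DataModule and handing it to a pl.Trainer. The namespace
# fields and yacs-style config keys below mirror the attributes read in
# __init__(); the concrete values, the paths, and the use of yacs itself are
# assumptions, not the project's actual defaults.
if __name__ == "__main__":
    from types import SimpleNamespace
    from yacs.config import CfgNode as CN

    args = SimpleNamespace(batch_size=2, num_workers=4, pin_memory=True)

    config = CN()
    config.DATASET = CN()
    config.DATASET.TRAIN_DATA_SOURCE = "KAIST"            # assumed value
    config.DATASET.VAL_DATA_SOURCE = "KAIST"
    config.DATASET.TRAIN_DATA_ROOT = "data/kaist/train"   # assumed paths
    config.DATASET.VAL_DATA_ROOT = "data/kaist/val"
    config.DATASET.PRETRAIN_IMG_RESIZE = 840
    config.DATASET.PRETRAIN_IMG_PAD = True
    config.DATASET.PRETRAIN_DF = 8
    config.DATASET.PRETRAIN_FRAME_GAP = 2                 # assumed value
    config.XOFTR = CN()
    config.XOFTR.RESOLUTION = (8, 2)                      # coarse stride 8 -> coarse_scale 0.125

    dm = PretrainDataModule(args, config)

    # dm.setup(stage='fit') would need the data roots above to exist on disk;
    # in a real run the Trainer calls setup() internally, e.g.:
    #   trainer = pl.Trainer(accelerator="gpu", devices=1)
    #   trainer.fit(model, datamodule=dm)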