from collections import abc
from loguru import logger

import pytorch_lightning as pl
from torch import distributed as dist
from torch.utils.data import (
    DataLoader,
    ConcatDataset,
    DistributedSampler
)

from src.datasets.pretrain_dataset import PretrainDataset


class PretrainDataModule(pl.LightningDataModule):
    """ 

    For distributed training, each training process is assgined

    only a part of the training scenes to reduce memory overhead.

    """
    def __init__(self, args, config):
        super().__init__()

        # 1. data config
        # Train and val should come from the same data source
        self.train_data_source = config.DATASET.TRAIN_DATA_SOURCE
        self.val_data_source = config.DATASET.VAL_DATA_SOURCE
        # training and validating
        self.train_data_root = config.DATASET.TRAIN_DATA_ROOT
        self.val_data_root = config.DATASET.VAL_DATA_ROOT

        # 2. dataset config

        # dataset options
        self.pretrain_img_resize = config.DATASET.PRETRAIN_IMG_RESIZE  # 840
        self.pretrain_img_pad = config.DATASET.PRETRAIN_IMG_PAD   # True
        self.pretrain_df = config.DATASET.PRETRAIN_DF  # 8
        self.coarse_scale = 1 / config.XOFTR.RESOLUTION[0]  # 0.125 for training XoFTR
        self.frame_gap = config.DATASET.PRETRAIN_FRAME_GAP

        # 3.loader parameters
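        # NOTE: no 'shuffle' key in the train loader params; shuffling is handled
        # by the DistributedSampler created in train_dataloader().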
        self.train_loader_params = {
            'batch_size': args.batch_size,
            'num_workers': args.num_workers,
            'pin_memory': getattr(args, 'pin_memory', True)
        }
        self.val_loader_params = {
            'batch_size': 1,
            'shuffle': False,
            'num_workers': args.num_workers,
            'pin_memory': getattr(args, 'pin_memory', True)
        }

    def setup(self, stage=None):
        """

        Setup train / val / test dataset. This method will be called by PL automatically.

        Args:

            stage (str): 'fit' in training phase, and 'test' in testing phase.

        """

        assert stage in ['fit', 'test'], "stage must be either fit or test"

        try:
            self.world_size = dist.get_world_size()
            self.rank = dist.get_rank()
            logger.info(f"[rank:{self.rank}] world_size: {self.world_size}")
        except (AssertionError, RuntimeError) as e:
            # torch.distributed is not initialized, e.g. single-process debugging
            self.world_size = 1
            self.rank = 0
            logger.warning(str(e) + " (set world_size=1 and rank=0)")

        if stage == 'fit':
            self.train_dataset = self._setup_dataset(
                self.train_data_root,
                mode='train')
            # setup multiple (optional) validation subsets
            self.val_dataset = []
            self.val_dataset.append(self._setup_dataset(
                self.val_data_root,
                mode='val'))
            logger.info(f'[rank:{self.rank}] Train & Val Dataset loaded!')
        else:  # stage == 'test'
            raise ValueError("only the 'fit' stage is implemented")

    def _setup_dataset(self, data_root, mode='train'):
        """ Setup train / val / test set. """
        dataset_builder = self._build_concat_dataset
        return dataset_builder(data_root, mode=mode)

    def _build_concat_dataset(self, data_root, mode):
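        # Even a single data root is wrapped in a ConcatDataset, so additional
        # roots could be appended to `datasets` without changing the callers.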
        datasets = []

        datasets.append(
            PretrainDataset(data_root,
                            mode=mode,
                            img_resize=self.pretrain_img_resize,
                            df=self.pretrain_df,
                            img_padding=self.pretrain_img_pad,
                            coarse_scale=self.coarse_scale,
                            frame_gap=self.frame_gap))

        return ConcatDataset(datasets)

    def train_dataloader(self):
        """ Build training dataloader for KAIST dataset. """
        sampler = DistributedSampler(self.train_dataset, shuffle=True)
        dataloader = DataLoader(self.train_dataset, sampler=sampler, **self.train_loader_params)
        return dataloader
    
    def val_dataloader(self):
        """ Build validation dataloader KAIST dataset. """
        if not isinstance(self.val_dataset, abc.Sequence):
            return DataLoader(self.val_dataset, sampler=sampler, **self.val_loader_params)
        else:
            dataloaders = []
            for dataset in self.val_dataset:
                sampler = DistributedSampler(dataset, shuffle=False)
                dataloaders.append(DataLoader(dataset, sampler=sampler, **self.val_loader_params))
            return dataloaders
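

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): the config attribute names below mirror the
# ones read in __init__, but the concrete values and paths are made-up
# placeholders; nothing is read from disk until setup() / the dataloaders run.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(batch_size=2, num_workers=4, pin_memory=True)
    config = SimpleNamespace(
        DATASET=SimpleNamespace(
            TRAIN_DATA_SOURCE="KAIST",           # placeholder source name
            VAL_DATA_SOURCE="KAIST",
            TRAIN_DATA_ROOT="data/kaist/train",  # placeholder path
            VAL_DATA_ROOT="data/kaist/val",      # placeholder path
            PRETRAIN_IMG_RESIZE=840,
            PRETRAIN_IMG_PAD=True,
            PRETRAIN_DF=8,
            PRETRAIN_FRAME_GAP=2,                # placeholder gap
        ),
        XOFTR=SimpleNamespace(RESOLUTION=(8, 2)),
    )

    dm = PretrainDataModule(args, config)
    logger.info(f"coarse_scale={dm.coarse_scale}, train_root={dm.train_data_root}")
    # In a real run, a pl.Trainer drives setup() / train_dataloader(), e.g.
    #   trainer.fit(model, datamodule=dm)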