File size: 557 Bytes
9123ba9
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from dataset import DEFAULTDataset
from torch.utils.data import WeightedRandomSampler


def get_dataset(cfg):
    """Build the train/validation datasets and sampler selected by ``cfg``.

    Args:
        cfg: Config object read via attribute access. This function uses
            ``cfg.dataset.name``, ``cfg.dataset.root_dir`` and
            ``cfg.model.internal_resolution``.

    Returns:
        A ``(train_dataset, val_dataset, sampler)`` tuple. For the
        ``'DEFAULT'`` dataset, ``sampler`` is always ``None``.

    Raises:
        ValueError: If ``cfg.dataset.name`` is not a supported dataset name.
    """
    dataset_name = cfg.dataset.name
    if dataset_name != 'DEFAULT':
        raise ValueError(f'{dataset_name} Dataset is not available')

    dataset_kwargs = {
        'root_dir': cfg.dataset.root_dir,
        'internal_resolution': cfg.model.internal_resolution,
    }
    # NOTE(review): both splits are built from the same root_dir with the
    # same arguments. Unless DEFAULTDataset splits the data internally,
    # validation runs on the training data — confirm this is intended.
    train_split = DEFAULTDataset(**dataset_kwargs)
    val_split = DEFAULTDataset(**dataset_kwargs)
    return train_split, val_split, None