unit_test / uniperceiver /datasets /batch_sampler.py
herrius's picture
Upload 259 files
32b542e
import numpy as np
import os
import math
import torch
import torch.distributed as dist
from torch.utils.data.sampler import Sampler
import random
from uniperceiver.utils import comm
import itertools
from .sampler import TrainingSampler, NaiveSampler, NodeDistributedSampler
from uniperceiver.datasets.unified_dataset import UnifiedDataset
# DeepSpeed is an optional dependency. Record whether its expert-parallel
# group helpers could be imported so callers can guard their use of `groups`.
# `except Exception` keeps the best-effort fallback for any import-time failure
# while no longer swallowing KeyboardInterrupt / SystemExit like a bare except.
try:
    import deepspeed.utils.groups as groups
    DEEPSPEED_INSTALLED = True
except Exception:
    DEEPSPEED_INSTALLED = False
class WeightedBatchSampler(torch.utils.data.sampler.BatchSampler):
    """Endless batch sampler that draws each batch from one task of a unified dataset.

    Every task in ``task_cfg`` gets its own unit sampler over
    ``dataset.datasets[task]``. On each step one task is selected (according to
    ``cfg.DATALOADER.STRATEGY``), ``TRAIN_BATCH_SIZE`` indices are pulled from
    that task's sampler, and the task's entry of ``dataset.dataset_scale`` is
    added to every index before the batch is yielded.

    Task selection strategies (``cfg.DATALOADER.STRATEGY``):

    * ``'uniform'``   -- task fixed by ``local_rank % num_tasks``.
    * ``'uniformv2'`` -- ``(count + rank) % num_tasks`` with a rotating counter.
    * ``'turn'``      -- plain round-robin over the tasks.
    * anything else   -- random choice weighted by ``DATALOADER.SAMPLING_WEIGHT``.

    NOTE(review): ``BatchSampler.__init__`` is not invoked by this class; only
    the base type is reused. Confirm nothing relies on the parent's
    ``sampler`` / ``batch_size`` / ``drop_last`` attributes.
    """

    # Dataset names served by each kind of unit sampler.
    _TRAINING_SAMPLER_DATASETS = (
        "MSCOCO", "FLICKR", "ImageNet22k", "ImageNet1k", "VG", "VideoDataSet",
        "K700", 'K400', 'MiT', 'MSVDDataset', 'MSRVTTDataset',
        "RTE", "CoLA", "SST-2", "MRPC", "QQP", "QNLI", "MNLI", "MNLI_Match",
        "VQA",
    )
    _NAIVE_SAMPLER_DATASETS = ("BooksWiki",)  # block cache
    _NODE_DISTRIBUTED_DATASETS = ('YFCC', 'CC12M', 'CC3M', 'SBU', 'TQAPretrain')

    def __init__(self,
                 dataset: "UnifiedDataset",
                 cfg,
                 task_cfg,
                 stage='train',
                 shuffle=True,
                 drop_last=True):
        """Build one unit sampler per task and cache per-task settings.

        Args:
            dataset: unified dataset exposing ``datasets`` (per-task mapping),
                ``dataset_list`` and ``dataset_scale``.
            cfg: global config; ``DATALOADER.STRATEGY`` and ``MOE.MOE`` are read.
            task_cfg: mapping of task name to that task's config.
            stage: only ``'train'`` is supported.
            shuffle: currently unused.
            drop_last: currently unused.

        Raises:
            NotImplementedError: if a task's dataset name has no known sampler,
                or if ``stage`` is not ``'train'``.
        """
        self.dataset = dataset
        self.cfg = cfg
        self.task_cfg = task_cfg
        self._tasks = list(self.task_cfg.keys())

        self.unit_sampler = {
            name: self._build_unit_sampler(name, new_cfg.DATASETS.DATASET_NAME)
            for name, new_cfg in self.task_cfg.items()
        }
        self.unit_sampler_iter = {
            name: iter(sampler)
            for name, sampler in self.unit_sampler.items()
        }

        self.sampling_weights = {
            name: task.DATALOADER.SAMPLING_WEIGHT
            for name, task in self.task_cfg.items()
        }
        # Same order as self._tasks so it can be passed to random.choices.
        self._weights = [self.sampling_weights[name] for name in self._tasks]

        self.stage = stage
        if self.stage != 'train':
            raise NotImplementedError('only train dataset supportted now!')
        self.task_batch_size = {
            name: task.DATALOADER.TRAIN_BATCH_SIZE
            for name, task in self.task_cfg.items()
        }

        # Number of full batches per dataset.
        # NOTE(review): pairs dataset_list with task_cfg purely by position --
        # confirm both are ordered identically.
        self.len = [
            len(ds) // batch_size
            for ds, batch_size in zip(self.dataset.dataset_list,
                                      self.task_batch_size.values())
        ]

        self.special_strategy = cfg.DATALOADER.STRATEGY
        self.count = 0  # rotating counter used by 'uniformv2' and 'turn'

        # Value added to every index drawn for a task.
        # Presumably the task's start position inside the unified dataset --
        # verify against UnifiedDataset.dataset_scale.
        self.task_index_offset = dict(
            zip(self.task_cfg.keys(), self.dataset.dataset_scale.tolist()))

    def _build_unit_sampler(self, name, dataset_name):
        """Return the unit sampler for task ``name`` based on its dataset name."""
        if dataset_name in self._TRAINING_SAMPLER_DATASETS:
            return TrainingSampler(self.dataset.datasets[name])
        if dataset_name in self._NAIVE_SAMPLER_DATASETS:
            return NaiveSampler(self.dataset.datasets[name])
        if dataset_name in self._NODE_DISTRIBUTED_DATASETS:
            return NodeDistributedSampler(
                self.dataset.datasets[name],
                shuffle=True,
                num_replicas=comm.get_world_size(),
                rank=comm.get_rank(),
                local_rank=comm.get_local_rank(),
                local_size=comm.get_local_size())
        raise NotImplementedError(
            f'please check the sampler used for this dataset {dataset_name}'
        )

    def _select_task(self):
        """Pick the task for the next batch according to the configured strategy."""
        num_tasks = len(self._tasks)
        if self.special_strategy == 'uniform':
            task = self._tasks[comm.get_local_rank() % num_tasks]
        elif self.special_strategy == 'uniformv2':
            task = self._tasks[(self.count + comm.get_rank()) % num_tasks]
            self.count = (self.count + 1) % num_tasks
        elif self.special_strategy == 'turn':
            task = self._tasks[self.count % num_tasks]
            self.count = (self.count + 1) % num_tasks
        else:
            task = random.choices(self._tasks, weights=self._weights)[0]

        # With MoE expert parallelism active, replace the local choice with the
        # one broadcast from the first rank of this rank's expert-parallel
        # block, so all ranks in the group work on the same task.
        if (self.cfg.MOE.MOE and DEEPSPEED_INSTALLED
                and groups.expert_parallel_is_initialized()
                and groups.get_expert_data_parallel_world_size() > 1):
            rank = comm.get_rank()
            task = comm.broadcast_object(
                task,
                src=rank - rank % groups.get_expert_parallel_world_size(),
                group=groups.get_expert_parallel_group())
        return task

    def __len__(self):
        """Return the total number of full batches summed over all datasets."""
        return sum(self.len)

    def __iter__(self):
        """Yield batches of offset sample indices forever.

        Each yielded list has ``task_batch_size[task]`` entries, all drawn from
        the single task selected for that step. The generator never terminates.
        """
        while True:
            task = self._select_task()
            sample_index_offset = self.task_index_offset[task]

            # Every unit sampler is treated as an infinite stream: a sampler
            # that runs dry is restarted. Only StopIteration triggers the
            # restart, so genuine sampler errors propagate to the caller
            # instead of being silently retried.
            batch = []
            for _ in range(self.task_batch_size[task]):
                try:
                    index = next(self.unit_sampler_iter[task])
                except StopIteration:
                    self.unit_sampler_iter[task] = iter(self.unit_sampler[task])
                    index = next(self.unit_sampler_iter[task])
                batch.append(index + sample_index_offset)

            assert len(batch) == self.task_batch_size[task]
            yield batch