import torch
import numpy as np
import numpy.random as npr
import torch.distributed as dist
import math
from ...log_service import print_log
from ... import sync
def singleton(class_):
instances = {}
def getinstance(*args, **kwargs):
if class_ not in instances:
instances[class_] = class_(*args, **kwargs)
return instances[class_]
return getinstance
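# Illustrative note (not from the original file): the decorator above caches one
# instance per decorated class, so repeated constructor calls return the same object:
#
#   @singleton
#   class Registry(object):      # hypothetical example class
#       pass
#
#   assert Registry() is Registry()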
@singleton
class get_sampler(object):
def __init__(self):
self.sampler = {}
def register(self, sampler):
self.sampler[sampler.__name__] = sampler
def __call__(self, dataset, cfg):
if cfg == 'default_train':
return GlobalDistributedSampler(dataset, shuffle=True, extend=False)
elif cfg == 'default_eval':
return GlobalDistributedSampler(dataset, shuffle=False, extend=True)
else:
t = cfg.type
return self.sampler[t](dataset=dataset, **cfg.args)
def register():
def wrapper(class_):
get_sampler().register(class_)
return class_
return wrapper
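# Usage sketch (assumed, not part of the original file): the sampler classes below are
# added to the registry with @register(); a caller then resolves one through the
# singleton registry, e.g.
#
#   sampler = get_sampler()(dataset, 'default_train')
#   # or, with a config object exposing .type and .args (object itself is assumed here):
#   #   sampler = get_sampler()(dataset, cfg)   # cfg.type == 'GroupSampler', cfg.args == {...}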
######################
# DistributedSampler #
######################
@register()
class GlobalDistributedSampler(torch.utils.data.Sampler):
"""
This is a distributed sampler that sync accross gpus and nodes.
"""
def __init__(self,
dataset,
shuffle=True,
extend=False,):
"""
Arguments:
dataset: Dataset used for sampling.
shuffle: If true, sampler will shuffle the indices
extend: If true, sampler will extend the indices that can be even distributed by ranks
otherwise sampler will truncate the indices to make it even.
"""
self.ddp = sync.is_ddp()
self.rank = sync.get_rank('global')
self.world_size = sync.get_world_size('global')
self.dataset = dataset
self.shuffle = shuffle
self.extend = extend
num_samples = len(dataset) // self.world_size
if extend and (len(dataset)%self.world_size != 0):
num_samples+=1
self.num_samples = num_samples
self.total_size = num_samples * self.world_size
def __iter__(self):
indices = self.get_sync_order()
if self.extend:
# extend using the front indices
indices = indices+indices[0:self.total_size-len(indices)]
else:
# truncate
indices = indices[0:self.total_size]
# subsample
indices = indices[self.rank : len(indices) : self.world_size]
return iter(indices)
def __len__(self):
return self.num_samples
def get_sync_order(self):
if self.shuffle:
indices = torch.randperm(len(self.dataset)).to(self.rank)
if self.ddp:
dist.broadcast(indices, src=0)
indices = indices.to('cpu').tolist()
else:
indices = list(range(len(self.dataset)))
print_log('Sampler : {}'.format(str(indices[0:5])) )
return indices
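# Usage sketch (assumed, not part of the original file): like torch's built-in
# DistributedSampler, GlobalDistributedSampler is meant to be handed to a DataLoader:
#
#   sampler = GlobalDistributedSampler(train_dataset, shuffle=True, extend=False)
#   loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, sampler=sampler)
#
# Each rank then iterates over a disjoint len(dataset) // world_size slice of the
# synchronized index order.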
@register()
class LocalDistributedSampler(GlobalDistributedSampler):
"""
This is a distributed sampler that sync across gpus within the nodes.
But not sync across nodes.
"""
def __init__(self,
dataset,
shuffle=True,
extend=False,):
        super().__init__(dataset, shuffle, extend)
        self.rank = sync.get_rank('local')
        self.world_size = sync.get_world_size('local')
        # recompute the per-rank sample count with the local world size so that
        # __len__ and the subsampling in __iter__ stay consistent
        self.num_samples = len(dataset) // self.world_size
        if extend and (len(dataset) % self.world_size != 0):
            self.num_samples += 1
        self.total_size = self.num_samples * self.world_size
def get_sync_order(self):
if self.shuffle:
if self.rank == 0:
indices = list(npr.permutation(len(self.dataset)))
sync.nodewise_sync().broadcast_r0(indices)
else:
indices = sync.nodewise_sync().broadcast_r0(None)
else:
indices = list(range(len(self.dataset)))
print_log('Sampler : {}'.format(str(indices[0:5])) )
return indices
############################
# random sample with group #
############################
# Deprecated
@register()
class GroupSampler(torch.utils.data.Sampler):
"""
    This is a DistributedSampler that samples all indices group by group.
    For example:
if group_size=3, num_replicas=2, train mode:
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
==> (group) [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]
==> (distribute) process0: [3, 4, 5], (leftover [6, 7, 8, 9, 10])
process1: [0, 1, 2]
==> (group leftover) process0: [3, 4, 5], (leftover [6, 7], [8, 9], 10)
process1: [0, 1, 2]
==> (distribute) process0: [3, 4, 5], [6, 7] (remove 10)
process1: [0, 1, 2], [8, 9]
    It also avoids a batch size of 1:
0, 1, 2, 3, 4, 5, 6, 7, 8,
==> (group) [0, 1, 2], [3, 4, 5], [6, 7, 8]
==> (distribute) process0: [3, 4, 5], (leftover [6, 7, 8])
process1: [0, 1, 2]
==> (group leftover) process0: [3, 4, 5], (leftover [6], [7], [8])
process1: [0, 1, 2]
        ==> (distribute) process0: [3, 4, 5], (remove 6, 7, 8, because distributing them would make the batch size 1)
process1: [0, 1, 2]
if group_size=3, num_replicas=2, eval mode:
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
==> (extend) 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10
==> (group) [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 10]
==> (distribute) process0: [0, 1, 2], [6, 7, 8],
process1: [3, 4, 5], [9, 10, 10]
"""
def __init__(self,
dataset,
group_size,
num_replicas=None,
rank=None,
mode='train',):
        if num_replicas is None:
            if not dist.is_available():
                raise ValueError("Requires the distributed package, but it is not available.")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise ValueError("Requires the distributed package, but it is not available.")
            rank = dist.get_rank()
self.dataset = dataset
self.len_dataset = len(dataset)
self.group_size = group_size
self.num_replicas = num_replicas
self.rank = rank
self.mode = mode
len_dataset = self.len_dataset
if (len_dataset % num_replicas != 0) and (mode == 'train'):
            # drop the tail that cannot be evenly distributed across ranks
aligned_indices = np.arange(len_dataset)[:-(len_dataset % num_replicas)]
aligned_len_dataset = aligned_indices.shape[0]
elif (len_dataset % num_replicas != 0) and (mode == 'eval'):
extend = np.array([len_dataset-1 for _ in range(num_replicas - len_dataset % num_replicas)])
aligned_indices = np.concatenate([range(len_dataset), extend])
aligned_len_dataset = aligned_indices.shape[0]
else:
aligned_indices = np.arange(len_dataset)
aligned_len_dataset = len_dataset
num_even_distributed_groups = aligned_len_dataset // (group_size * num_replicas)
num_even = num_even_distributed_groups * group_size * num_replicas
self.regular_groups = aligned_indices[0:num_even].reshape(-1, group_size)
self.leftover_groups = aligned_indices[num_even:].reshape(num_replicas, -1)
if self.leftover_groups.size == 0:
self.leftover_groups = None
elif (self.leftover_groups.shape[-1]==1) and (mode == 'train'):
# avoid bs=1
self.leftover_groups = None
        # an ugly way to modify dataset.load_info according to the grouping
for groupi in self.regular_groups:
for idx in groupi:
idx_lowerbd = groupi[0]
idx_upperbd = groupi[-1]
idx_reference = (idx_lowerbd+idx_upperbd)//2
dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size']
if self.leftover_groups is not None:
for groupi in self.leftover_groups:
for idx in groupi:
idx_lowerbd = groupi[0]
idx_upperbd = groupi[-1]
idx_reference = (idx_lowerbd+idx_upperbd)//2
dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size']
def concat(self, nparrays, axis=0):
        # a helper for safe concatenation that drops empty arrays first
nparrays = [i for i in nparrays if i.size > 0]
return np.concatenate(nparrays, axis=axis)
def __iter__(self):
indices = self.get_sync_order()
return iter(indices)
    def __len__(self):
        # per-rank length: evenly distributed regular groups plus this rank's leftover group
        num_samples = self.regular_groups.shape[0] // self.num_replicas * self.group_size
        return num_samples if self.leftover_groups is None else num_samples + self.leftover_groups.shape[-1]
def get_sync_order(self):
# g = torch.Generator()
# g.manual_seed(self.epoch)
mode = self.mode
rank = self.rank
num_replicas = self.num_replicas
group_size = self.group_size
num_groups = len(self.regular_groups)
if mode == 'train':
g_indices = torch.randperm(num_groups).to(rank)
dist.broadcast(g_indices, src=0)
g_indices = g_indices.to('cpu').tolist()
num_groups_per_rank = num_groups // num_replicas
groups = self.regular_groups[g_indices][num_groups_per_rank*rank : num_groups_per_rank*(rank+1)]
indices = groups.flatten()
if self.leftover_groups is not None:
leftg_indices = torch.randperm(len(self.leftover_groups)).to(rank)
dist.broadcast(leftg_indices, src=0)
leftg_indices = leftg_indices.to('cpu').tolist()
last = self.leftover_groups[leftg_indices][rank]
indices = np.concatenate([indices, last], axis=0)
elif mode == 'eval':
groups = self.regular_groups.reshape(-1, num_replicas, group_size)[:, rank, :]
indices = groups.flatten()
if self.leftover_groups is not None:
last = self.leftover_groups[rank]
indices = np.concatenate([indices, last], axis=0)
else:
raise ValueError
print_log('Sampler RANK {} : {}'.format(rank, str(indices[0:group_size+1])))
return indices
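# Worked example (illustrative, not part of the original file), mirroring the
# GroupSampler docstring: with group_size=3, num_replicas=2 and an 11-sample dataset
# in train mode, the 11 % 2 == 1 trailing index (10) is dropped, the first
# 10 // (3 * 2) * 6 == 6 indices become regular_groups == [[0, 1, 2], [3, 4, 5]], and
# the remaining [6, 7, 8, 9] is reshaped into leftover_groups == [[6, 7], [8, 9]];
# after the synchronized shuffle each rank draws one three-index regular group plus
# one two-index leftover group, for 5 indices per rank.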