import math

import numpy as np
import numpy.random as npr
import torch
import torch.distributed as dist

from ...log_service import print_log
from ... import sync
def singleton(class_):
    # Cache one instance per decorated class and reuse it on later calls.
    instances = {}
    def getinstance(*args, **kwargs):
        if class_ not in instances:
            instances[class_] = class_(*args, **kwargs)
        return instances[class_]
    return getinstance
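# Illustration (hypothetical class, sketch only): the decorator above turns a
# class into a process-wide singleton, e.g.
#
#     @singleton
#     class Registry:
#         def __init__(self):
#             self.items = {}
#
#     assert Registry() is Registry()  # every call returns the same instance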
@singleton
class get_sampler(object):
    def __init__(self):
        self.sampler = {}

    def register(self, sampler):
        self.sampler[sampler.__name__] = sampler

    def __call__(self, dataset, cfg):
        if cfg == 'default_train':
            return GlobalDistributedSampler(dataset, shuffle=True, extend=False)
        elif cfg == 'default_eval':
            return GlobalDistributedSampler(dataset, shuffle=False, extend=True)
        else:
            t = cfg.type
            return self.sampler[t](dataset=dataset, **cfg.args)

def register():
    def wrapper(class_):
        get_sampler().register(class_)
        return class_
    return wrapper
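# Usage sketch (hypothetical sampler name and cfg object, for illustration
# only): register a custom sampler, then retrieve it through the registry.
#
#     @register()
#     class MySampler(torch.utils.data.Sampler):
#         def __init__(self, dataset, seed=0):
#             ...
#
#     sampler = get_sampler()(dataset, 'default_train')  # built-in default
#     sampler = get_sampler()(dataset, cfg)              # cfg.type == 'MySampler',
#                                                        # cfg.args passed as kwargs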
######################
# DistributedSampler #
######################
class GlobalDistributedSampler(torch.utils.data.Sampler):
    """
    A distributed sampler that syncs across all GPUs and nodes.
    """
    def __init__(self,
                 dataset,
                 shuffle=True,
                 extend=False,):
        """
        Arguments:
            dataset: Dataset used for sampling.
            shuffle: If true, the sampler will shuffle the indices.
            extend: If true, the sampler will extend the indices so they can be
                evenly distributed across ranks; otherwise it will truncate
                the indices to make the count even.
        """
        self.ddp = sync.is_ddp()
        self.rank = sync.get_rank('global')
        self.world_size = sync.get_world_size('global')
        self.dataset = dataset
        self.shuffle = shuffle
        self.extend = extend
        num_samples = len(dataset) // self.world_size
        if extend and (len(dataset) % self.world_size != 0):
            num_samples += 1
        self.num_samples = num_samples
        self.total_size = num_samples * self.world_size
    def __iter__(self):
        indices = self.get_sync_order()
        if self.extend:
            # extend using the front indices
            indices = indices + indices[0:self.total_size - len(indices)]
        else:
            # truncate
            indices = indices[0:self.total_size]
        # subsample
        indices = indices[self.rank : len(indices) : self.world_size]
        return iter(indices)

    def __len__(self):
        return self.num_samples
    def get_sync_order(self):
        if self.shuffle:
            # the permutation drawn on rank 0 is broadcast so every rank sees
            # the same order (assumes the CUDA device index matches the rank)
            indices = torch.randperm(len(self.dataset)).to(self.rank)
            if self.ddp:
                dist.broadcast(indices, src=0)
            indices = indices.to('cpu').tolist()
        else:
            indices = list(range(len(self.dataset)))
        print_log('Sampler : {}'.format(str(indices[0:5])))
        return indices
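# Worked example (sketch only, not part of the original module): 10 samples,
# world_size=4, shuffle=False.
# With extend=False: num_samples=2, total_size=8; indices are truncated to
# [0..7] and rank r takes every 4th index starting at r:
#     rank0 -> [0, 4], rank1 -> [1, 5], rank2 -> [2, 6], rank3 -> [3, 7]
# With extend=True: num_samples=3, total_size=12; indices are extended with
# the front indices to [0..9, 0, 1]:
#     rank0 -> [0, 4, 8], rank1 -> [1, 5, 9], rank2 -> [2, 6, 0], rank3 -> [3, 7, 1]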
class LocalDistributedSampler(GlobalDistributedSampler):
    """
    A distributed sampler that syncs across GPUs within a node,
    but not across nodes.
    """
    def __init__(self,
                 dataset,
                 shuffle=True,
                 extend=False,):
        super().__init__(dataset, shuffle, extend)
        self.rank = sync.get_rank('local')
        self.world_size = sync.get_world_size('local')
        # recompute the counts with the local world size so that __len__
        # matches what __iter__ yields (the parent computed them globally)
        self.num_samples = len(dataset) // self.world_size
        if extend and (len(dataset) % self.world_size != 0):
            self.num_samples += 1
        self.total_size = self.num_samples * self.world_size
    def get_sync_order(self):
        if self.shuffle:
            if self.rank == 0:
                indices = list(npr.permutation(len(self.dataset)))
                sync.nodewise_sync().broadcast_r0(indices)
            else:
                indices = sync.nodewise_sync().broadcast_r0(None)
        else:
            indices = list(range(len(self.dataset)))
        print_log('Sampler : {}'.format(str(indices[0:5])))
        return indices
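# Behavior sketch (an inference from the code above, not a documented
# guarantee): only the node-local rank 0 draws the permutation and broadcasts
# it within its node, so GPUs on the same node share one order while different
# nodes may iterate in different orders; with the local world size used for
# subsampling, every node still covers the full dataset.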
############################
# random sample with group #
############################

# Deprecated
class GroupSampler(torch.utils.data.Sampler):
    """
    A DistributedSampler that samples all indices by group.
    i.e.
    if group_size=3, num_replicas=2, train mode:
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
        ==> (group)          [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10]
        ==> (distribute)     process0: [3, 4, 5], (leftover [6, 7, 8, 9, 10])
                             process1: [0, 1, 2]
        ==> (group leftover) process0: [3, 4, 5], (leftover [6, 7], [8, 9], 10)
                             process1: [0, 1, 2]
        ==> (distribute)     process0: [3, 4, 5], [6, 7] (remove 10)
                             process1: [0, 1, 2], [8, 9]
    it will avoid batch_size=1:
        0, 1, 2, 3, 4, 5, 6, 7, 8,
        ==> (group)          [0, 1, 2], [3, 4, 5], [6, 7, 8]
        ==> (distribute)     process0: [3, 4, 5], (leftover [6, 7, 8])
                             process1: [0, 1, 2]
        ==> (group leftover) process0: [3, 4, 5], (leftover [6], [7], [8])
                             process1: [0, 1, 2]
        ==> (distribute)     process0: [3, 4, 5], (remove 6, 7, 8, because
                                 distributing them would make batch_size 1)
                             process1: [0, 1, 2]
    if group_size=3, num_replicas=2, eval mode:
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10
        ==> (extend)         0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 10
        ==> (group)          [0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 10]
        ==> (distribute)     process0: [0, 1, 2], [6, 7, 8]
                             process1: [3, 4, 5], [9, 10, 10]
    """
    def __init__(self,
                 dataset,
                 group_size,
                 num_replicas=None,
                 rank=None,
                 mode='train',):
        if num_replicas is None:
            if not dist.is_available():
                raise ValueError('torch.distributed is not available')
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise ValueError('torch.distributed is not available')
            rank = dist.get_rank()
        self.dataset = dataset
        self.len_dataset = len(dataset)
        self.group_size = group_size
        self.num_replicas = num_replicas
        self.rank = rank
        self.mode = mode
        len_dataset = self.len_dataset
        if (len_dataset % num_replicas != 0) and (mode == 'train'):
            # drop the non-aligned tail
            aligned_indices = np.arange(len_dataset)[:-(len_dataset % num_replicas)]
            aligned_len_dataset = aligned_indices.shape[0]
        elif (len_dataset % num_replicas != 0) and (mode == 'eval'):
            # pad with repeats of the last index until aligned
            extend = np.array([len_dataset-1 for _ in range(num_replicas - len_dataset % num_replicas)])
            aligned_indices = np.concatenate([range(len_dataset), extend])
            aligned_len_dataset = aligned_indices.shape[0]
        else:
            aligned_indices = np.arange(len_dataset)
            aligned_len_dataset = len_dataset

        num_even_distributed_groups = aligned_len_dataset // (group_size * num_replicas)
        num_even = num_even_distributed_groups * group_size * num_replicas
        self.regular_groups = aligned_indices[0:num_even].reshape(-1, group_size)
        self.leftover_groups = aligned_indices[num_even:].reshape(num_replicas, -1)
        if self.leftover_groups.size == 0:
            self.leftover_groups = None
        elif (self.leftover_groups.shape[-1] == 1) and (mode == 'train'):
            # avoid bs=1
            self.leftover_groups = None

        # per-rank sample count, so __len__ matches what __iter__ yields
        self.num_samples = num_even // num_replicas
        if self.leftover_groups is not None:
            self.num_samples += self.leftover_groups.shape[-1]
        # an ugly way to modify dataset.load_info according to the grouping:
        # every index in a group takes the image_size of the group's middle
        # index as its ref_size
        for groupi in self.regular_groups:
            for idx in groupi:
                idx_lowerbd = groupi[0]
                idx_upperbd = groupi[-1]
                idx_reference = (idx_lowerbd + idx_upperbd) // 2
                dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size']
        if self.leftover_groups is not None:
            for groupi in self.leftover_groups:
                for idx in groupi:
                    idx_lowerbd = groupi[0]
                    idx_upperbd = groupi[-1]
                    idx_reference = (idx_lowerbd + idx_upperbd) // 2
                    dataset.load_info[idx]['ref_size'] = dataset.load_info[idx_reference]['image_size']
    def concat(self, nparrays, axis=0):
        # a helper for safe concatenation that skips empty arrays
        nparrays = [i for i in nparrays if i.size > 0]
        return np.concatenate(nparrays, axis=axis)
    def __iter__(self):
        indices = self.get_sync_order()
        return iter(indices)

    def __len__(self):
        return self.num_samples
    def get_sync_order(self):
        # g = torch.Generator()
        # g.manual_seed(self.epoch)
        mode = self.mode
        rank = self.rank
        num_replicas = self.num_replicas
        group_size = self.group_size
        num_groups = len(self.regular_groups)
        if mode == 'train':
            # shuffle whole groups, broadcast the order from rank 0, and give
            # each rank a contiguous block of groups
            g_indices = torch.randperm(num_groups).to(rank)
            dist.broadcast(g_indices, src=0)
            g_indices = g_indices.to('cpu').tolist()
            num_groups_per_rank = num_groups // num_replicas
            groups = self.regular_groups[g_indices][num_groups_per_rank*rank : num_groups_per_rank*(rank+1)]
            indices = groups.flatten()
            if self.leftover_groups is not None:
                leftg_indices = torch.randperm(len(self.leftover_groups)).to(rank)
                dist.broadcast(leftg_indices, src=0)
                leftg_indices = leftg_indices.to('cpu').tolist()
                last = self.leftover_groups[leftg_indices][rank]
                indices = np.concatenate([indices, last], axis=0)
        elif mode == 'eval':
            # deterministic round-robin assignment of groups to ranks
            groups = self.regular_groups.reshape(-1, num_replicas, group_size)[:, rank, :]
            indices = groups.flatten()
            if self.leftover_groups is not None:
                last = self.leftover_groups[rank]
                indices = np.concatenate([indices, last], axis=0)
        else:
            raise ValueError('mode must be "train" or "eval"')
        print_log('Sampler RANK {} : {}'.format(rank, str(indices[0:group_size+1])))
        return indices
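# Sketch of the eval-mode reshape above (illustration only): with
# group_size=3, num_replicas=2 and
#     regular_groups = [[0,1,2], [3,4,5], [6,7,8], [9,10,10]]
# reshape(-1, 2, 3)[:, rank, :] yields
#     rank0 -> [[0,1,2], [6,7,8]]   and   rank1 -> [[3,4,5], [9,10,10]]
# which flattens to the per-process index streams shown in the class docstring.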