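"""Multi-scale batch sampler: each batch is assigned one of several image
resolutions, and the batch size can optionally be adjusted so that the pixel
count per batch stays roughly constant across scales."""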

import random

import numpy as np
import torch.distributed as dist
from torch.utils.data import Sampler


class MultiScaleSampler(Sampler):
    def __init__(
        self,
        data_source,
        scales,
        first_bs=128,
        fix_bs=True,
        divided_factor=[8, 16],
        is_training=True,
        ratio_wh=0.8,
        max_w=480.0,
        seed=None,
    ):
""" | |
multi scale samper | |
Args: | |
data_source(dataset) | |
scales(list): several scales for image resolution | |
first_bs(int): batch size for the first scale in scales | |
divided_factor(list[w, h]): ImageNet models down-sample images by a factor, ensure that width and height dimensions are multiples are multiple of devided_factor. | |
is_training(boolean): mode | |
""" | |
        # carry over ordering and width/height-ratio metadata from the dataset
        self.data_source = data_source
        self.data_idx_order_list = np.array(data_source.data_idx_order_list)
        self.ds_width = data_source.ds_width
        self.seed = data_source.seed
        if self.ds_width:
            self.wh_ratio = data_source.wh_ratio
            self.wh_ratio_sort = data_source.wh_ratio_sort
        self.n_data_samples = len(self.data_source)
        self.ratio_wh = ratio_wh
        self.max_w = max_w

        if isinstance(scales[0], list):
            width_dims = [i[0] for i in scales]
            height_dims = [i[1] for i in scales]
        elif isinstance(scales[0], int):
            width_dims = scales
            height_dims = scales
        else:
            raise TypeError(
                "scales must contain ints or [width, height] pairs")
        base_im_w = width_dims[0]
        base_im_h = height_dims[0]
        base_batch_size = first_bs
        # Get the GPU and node related information
        if dist.is_initialized():
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
        else:
            num_replicas = 1
            rank = 0

        # give every replica the same number of samples (remainder is dropped)
        num_samples_per_replica = int(self.n_data_samples * 1.0 / num_replicas)
        img_indices = list(range(self.n_data_samples))

        self.shuffle = False
        if is_training:
            # Compute the spatial dimensions and corresponding batch sizes.
            # Ensure that the width and height dimensions are multiples of
            # divided_factor, since the backbone down-samples by that factor.
            width_dims = [
                int((w // divided_factor[0]) * divided_factor[0])
                for w in width_dims
            ]
            height_dims = [
                int((h // divided_factor[1]) * divided_factor[1])
                for h in height_dims
            ]

            img_batch_pairs = list()
            base_elements = base_im_w * base_im_h * base_batch_size
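            # When fix_bs is False, shrink the batch size for larger scales so
            # that every batch holds roughly base_elements pixels in total.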
            for h, w in zip(height_dims, width_dims):
                if fix_bs:
                    batch_size = base_batch_size
                else:
                    batch_size = int(max(1, (base_elements / (h * w))))
                img_batch_pairs.append((w, h, batch_size))
            self.img_batch_pairs = img_batch_pairs
            self.shuffle = True
        else:
            self.img_batch_pairs = [(base_im_w, base_im_h, base_batch_size)]

        self.img_indices = img_indices
        self.n_samples_per_replica = num_samples_per_replica
        self.epoch = 0
        self.rank = rank
        self.num_replicas = num_replicas

        self.batch_list = []
        self.current = 0
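        # Each replica takes every num_replicas-th index, offset by its rank,
        # so the replicas cover disjoint slices of the truncated index list.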
        last_index = num_samples_per_replica * num_replicas
        indices_rank_i = self.img_indices[self.rank:last_index:
                                          self.num_replicas]
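        # Pre-plan one epoch of batches, cycling through the scale list; short
        # tails are padded with indices from the front of the replica's slice.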
        while self.current < self.n_samples_per_replica:
            for curr_w, curr_h, curr_bsz in self.img_batch_pairs:
                end_index = min(self.current + curr_bsz,
                                self.n_samples_per_replica)
                batch_ids = indices_rank_i[self.current:end_index]
                n_batch_samples = len(batch_ids)
                if n_batch_samples != curr_bsz:
                    batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
                self.current += curr_bsz

                if len(batch_ids) > 0:
                    batch = [curr_w, curr_h, len(batch_ids)]
                    self.batch_list.append(batch)
        random.shuffle(self.batch_list)
        self.length = len(self.batch_list)
        self.batchs_in_one_epoch = self.iter()
        self.batchs_in_one_epoch_id = list(
            range(len(self.batchs_in_one_epoch)))
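
    # iter() materializes the batches once; each epoch, __iter__ only
    # reshuffles the order in which those precomputed batches are yielded.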
    def __iter__(self):
        if self.seed is None:
            random.seed(self.epoch)
            self.epoch += 1
        else:
            random.seed(self.seed)
        random.shuffle(self.batchs_in_one_epoch_id)
        for batch_tuple_id in self.batchs_in_one_epoch_id:
            yield self.batchs_in_one_epoch[batch_tuple_id]

    def iter(self):
        if self.shuffle:
            if self.seed is not None:
                random.seed(self.seed)
            else:
                random.seed(self.epoch)
            if not self.ds_width:
                random.shuffle(self.img_indices)
                random.shuffle(self.img_batch_pairs)
        # the rank striding is identical whether or not shuffling happened
        indices_rank_i = self.img_indices[self.rank:len(self.img_indices):
                                          self.num_replicas]
        start_index = 0
        batchs_in_one_epoch = []
        for batch_tuple in self.batch_list:
            curr_w, curr_h, curr_bsz = batch_tuple
            end_index = min(start_index + curr_bsz, self.n_samples_per_replica)
            batch_ids = indices_rank_i[start_index:end_index]
            n_batch_samples = len(batch_ids)
            if n_batch_samples != curr_bsz:
                batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
            start_index += curr_bsz

            if len(batch_ids) > 0:
                if self.ds_width:
                    wh_ratio_current = self.wh_ratio[
                        self.wh_ratio_sort[batch_ids]]
                    ratio_current = wh_ratio_current.mean()
                    # clamp the mean ratio so the resulting width
                    # (ratio * height) never exceeds max_w
                    if ratio_current * curr_h >= self.max_w:
                        ratio_current = self.max_w / curr_h
                else:
                    ratio_current = None
                batch = [(curr_w, curr_h, b_id, ratio_current)
                         for b_id in batch_ids]
                batchs_in_one_epoch.append(batch)
        return batchs_in_one_epoch

    def set_epoch(self, epoch: int):
        self.epoch = epoch

    def __len__(self):
        return self.length
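

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). DummyDataset is
    # a hypothetical stand-in that only provides the attributes this sampler
    # reads: data_idx_order_list, ds_width, and seed.
    from torch.utils.data import Dataset

    class DummyDataset(Dataset):
        def __init__(self, n=1000):
            self.data_idx_order_list = list(range(n))
            self.ds_width = False  # skip width/height-ratio handling
            self.seed = None
            self.n = n

        def __len__(self):
            return self.n

        def __getitem__(self, idx):
            return idx

    sampler = MultiScaleSampler(
        DummyDataset(),
        scales=[[320, 32], [384, 32], [448, 32]],
        first_bs=16,
        fix_bs=False,  # batch size shrinks as the scale grows
    )
    # Each yielded batch is a list of (width, height, sample_index, wh_ratio)
    # tuples; a dataset aware of this format can resize images accordingly.
    first_batch = next(iter(sampler))
    print(len(first_batch), first_batch[0])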