import random

import numpy as np
import torch.distributed as dist
from torch.utils.data import Sampler


class MultiScaleSampler(Sampler):
    def __init__(
        self,
        data_source,
        scales,
        first_bs=128,
        fix_bs=True,
        divided_factor=[8, 16],
        is_training=True,
        ratio_wh=0.8,
        max_w=480.0,
        seed=None,
    ):
        """
        Multi-scale sampler: pre-computes the batches for one epoch, where
        each batch is drawn at one of several image resolutions.

        Args:
            data_source(dataset): dataset to sample from
            scales(list): resolutions to sample at, either a list of ints or
                a list of [w, h] pairs
            first_bs(int): batch size for the first scale in scales
            fix_bs(bool): if True, use first_bs at every scale; otherwise
                shrink the batch size so pixels per batch stay roughly
                constant across scales
            divided_factor(list[w, h]): models down-sample images by a fixed
                factor; width and height are rounded down to multiples of
                divided_factor
            is_training(boolean): whether to build shuffled multi-scale
                batches (True) or fixed single-scale batches (False)
            ratio_wh(float): default width/height ratio
            max_w(float): upper bound on the batch width derived from the
                mean aspect ratio
            seed(int): fixed shuffle seed; when None, the epoch index seeds
                the shuffles instead
        """
        # index order and aspect-ratio bookkeeping provided by the data source
        self.data_source = data_source
        self.data_idx_order_list = np.array(data_source.data_idx_order_list)
        self.ds_width = data_source.ds_width
        # prefer an explicit seed argument; fall back to the dataset's seed
        self.seed = seed if seed is not None else data_source.seed
        if self.ds_width:
            self.wh_ratio = data_source.wh_ratio
            self.wh_ratio_sort = data_source.wh_ratio_sort
        self.n_data_samples = len(self.data_source)
        self.ratio_wh = ratio_wh
        self.max_w = max_w

        if isinstance(scales[0], list):
            width_dims = [i[0] for i in scales]
            height_dims = [i[1] for i in scales]
        elif isinstance(scales[0], int):
            width_dims = scales
            height_dims = scales
        else:
            raise ValueError("scales must be a list of ints or of [w, h] pairs")

        base_im_w = width_dims[0]
        base_im_h = height_dims[0]
        base_batch_size = first_bs

        # get the GPU and node related information
        if dist.is_initialized():
            num_replicas = dist.get_world_size()
            rank = dist.get_rank()
        else:
            num_replicas = 1
            rank = 0

        # adjust the total samples to avoid batch dropping
        num_samples_per_replica = int(self.n_data_samples * 1.0 / num_replicas)

        img_indices = list(range(self.n_data_samples))

        self.shuffle = False
        if is_training:
            # compute the spatial dimensions and the corresponding batch
            # size, ensuring width and height are multiples of divided_factor
            width_dims = [
                int((w // divided_factor[0]) * divided_factor[0])
                for w in width_dims
            ]
            height_dims = [
                int((h // divided_factor[1]) * divided_factor[1])
                for h in height_dims
            ]

            img_batch_pairs = list()
            base_elements = base_im_w * base_im_h * base_batch_size
            for h, w in zip(height_dims, width_dims):
                if fix_bs:
                    batch_size = base_batch_size
                else:
                    # keep pixels per batch roughly constant across scales
                    batch_size = int(max(1, base_elements / (h * w)))
                img_batch_pairs.append((w, h, batch_size))
            self.img_batch_pairs = img_batch_pairs
            self.shuffle = True
        else:
            self.img_batch_pairs = [(base_im_w, base_im_h, base_batch_size)]

        self.img_indices = img_indices
        self.n_samples_per_replica = num_samples_per_replica
        self.epoch = 0
        self.rank = rank
        self.num_replicas = num_replicas

        # pre-compute the (w, h, batch_size) schedule for one epoch, cycling
        # through the scales until every sample of this replica is covered
        self.batch_list = []
        self.current = 0
        last_index = num_samples_per_replica * num_replicas
        indices_rank_i = self.img_indices[self.rank:last_index:self.num_replicas]
        while self.current < self.n_samples_per_replica:
            for curr_w, curr_h, curr_bsz in self.img_batch_pairs:
                end_index = min(self.current + curr_bsz,
                                self.n_samples_per_replica)
                batch_ids = indices_rank_i[self.current:end_index]
                n_batch_samples = len(batch_ids)
                if n_batch_samples != curr_bsz:
                    # pad the last batch by wrapping around to the start
                    batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
                self.current += curr_bsz

                if len(batch_ids) > 0:
                    batch = [curr_w, curr_h, len(batch_ids)]
                    self.batch_list.append(batch)
        random.shuffle(self.batch_list)
        self.length = len(self.batch_list)
        self.batchs_in_one_epoch = self.iter()
        self.batchs_in_one_epoch_id = list(
            range(len(self.batchs_in_one_epoch)))

    def __iter__(self):
        # reshuffle the batch order every epoch (or deterministically when a
        # fixed seed is given)
        if self.seed is None:
            random.seed(self.epoch)
            self.epoch += 1
        else:
            random.seed(self.seed)
        random.shuffle(self.batchs_in_one_epoch_id)
        for batch_tuple_id in self.batchs_in_one_epoch_id:
            yield self.batchs_in_one_epoch[batch_tuple_id]

    def iter(self):
        if self.shuffle:
            if self.seed is not None:
                random.seed(self.seed)
            else:
                random.seed(self.epoch)
            if not self.ds_width:
                # only shuffle the sample order when the dataset is not
                # sorted by aspect ratio; otherwise keep similarly shaped
                # images grouped together
                random.shuffle(self.img_indices)
            random.shuffle(self.img_batch_pairs)
        # each replica takes every num_replicas-th index
        indices_rank_i = self.img_indices[
            self.rank:len(self.img_indices):self.num_replicas]

        start_index = 0
        batchs_in_one_epoch = []
        for batch_tuple in self.batch_list:
            curr_w, curr_h, curr_bsz = batch_tuple
            end_index = min(start_index + curr_bsz,
                            self.n_samples_per_replica)
            batch_ids = indices_rank_i[start_index:end_index]
            n_batch_samples = len(batch_ids)
            if n_batch_samples != curr_bsz:
                batch_ids += indices_rank_i[:(curr_bsz - n_batch_samples)]
            start_index += curr_bsz

            if len(batch_ids) > 0:
                if self.ds_width:
                    # use the mean aspect ratio of the batch, capped so the
                    # implied width does not exceed max_w
                    wh_ratio_current = self.wh_ratio[
                        self.wh_ratio_sort[batch_ids]]
                    ratio_current = wh_ratio_current.mean()
                    if ratio_current * curr_h >= self.max_w:
                        ratio_current = self.max_w / curr_h
                else:
                    ratio_current = None
                batch = [(curr_w, curr_h, b_id, ratio_current)
                         for b_id in batch_ids]
                batchs_in_one_epoch.append(batch)
        return batchs_in_one_epoch

    def set_epoch(self, epoch: int):
        self.epoch = epoch

    def __len__(self):
        return self.length
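
if __name__ == "__main__":
    # Minimal usage sketch, not part of the sampler itself. `_StubDataset` is
    # a hypothetical stand-in for a data source exposing the attributes the
    # sampler reads (data_idx_order_list, ds_width, seed, and, when ds_width
    # is True, wh_ratio / wh_ratio_sort). In real training the sampler is
    # passed to a DataLoader as `batch_sampler`, and the dataset's
    # __getitem__ is assumed to accept the yielded
    # (width, height, index, wh_ratio) tuples and resize images accordingly.

    class _StubDataset:
        def __init__(self, n=1000):
            self.data_idx_order_list = list(range(n))
            self.ds_width = True
            self.seed = None
            # random aspect ratios, pre-sorted as the sampler expects
            self.wh_ratio = np.random.uniform(1.0, 10.0, size=n)
            self.wh_ratio_sort = np.argsort(self.wh_ratio)

        def __len__(self):
            return len(self.data_idx_order_list)

    dataset = _StubDataset()
    sampler = MultiScaleSampler(
        dataset,
        scales=[[320, 32], [480, 48], [640, 64]],  # candidate (w, h) scales
        first_bs=16,
        fix_bs=False,  # shrink batch size at larger scales
    )
    # each yielded element is one batch: a list of
    # (width, height, sample_index, wh_ratio) tuples
    for batch in sampler:
        w, h, idx, ratio = batch[0]
        print(f"batch of {len(batch)} samples at {w}x{h}, ratio={ratio:.2f}")
        break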