import math
import os
import random

import numpy as np
import torch
from torch.utils.data import Sampler


class RatioSampler(Sampler):

    def __init__(self,
                 data_source,
                 scales,
                 first_bs=512,
                 fix_bs=True,
                 divided_factor=[8, 16],
                 is_training=True,
                 max_ratio=10,
                 max_bs=1024,
                 seed=None):
        """Multi-scale sampler that batches images of similar aspect ratio.

        Args:
            data_source (dataset): dataset exposing ``ds_width``, ``seed``,
                ``wh_ratio`` and ``wh_ratio_sort``.
            scales (list): several scales for image resolution, either a list
                of ``[w, h]`` pairs or a list of ints (square images).
            first_bs (int): batch size for the first scale in ``scales``.
            fix_bs (bool): if True, use ``first_bs`` for every scale;
                otherwise scale the batch size so the number of pixels per
                batch stays roughly constant.
            divided_factor (list[w, h]): backbones down-sample images by a
                fixed factor; width and height are rounded down to multiples
                of ``divided_factor``.
            is_training (bool): mode.
            max_ratio (float): upper bound on the width/height ratio.
            max_bs (int): upper bound on the per-ratio batch size.
            seed (int): random seed used when shuffling in training mode.
        """
        self.data_source = data_source
        # self.data_idx_order_list = np.array(data_source.data_idx_order_list)
        self.ds_width = data_source.ds_width
        self.seed = data_source.seed
        if self.ds_width:
            self.wh_ratio = data_source.wh_ratio
            self.wh_ratio_sort = data_source.wh_ratio_sort
        self.n_data_samples = len(self.data_source)
        self.max_ratio = max_ratio
        self.max_bs = max_bs

        if isinstance(scales[0], list):
            width_dims = [i[0] for i in scales]
            height_dims = [i[1] for i in scales]
        elif isinstance(scales[0], int):
            width_dims = scales
            height_dims = scales
        else:
            raise TypeError('scales must be a list of [w, h] pairs or ints')

        base_im_w = width_dims[0]
        base_im_h = height_dims[0]
        base_batch_size = first_bs
        # pixel budget per batch: width * height * batch size at base scale
        base_elements = base_im_w * base_im_h * base_batch_size
        self.base_elements = base_elements
        self.base_batch_size = base_batch_size
        self.base_im_h = base_im_h
        self.base_im_w = base_im_w

        # Get the GPU and node related information; fall back to a single
        # replica on CPU-only machines to avoid a division by zero below.
        num_replicas = max(1, torch.cuda.device_count())
        # rank = dist.get_rank()
        rank = (int(os.environ['LOCAL_RANK'])
                if 'LOCAL_RANK' in os.environ else 0)

        # adjust the total samples to avoid batch dropping
        num_samples_per_replica = int(
            math.ceil(self.n_data_samples * 1.0 / num_replicas))

        img_indices = list(range(self.n_data_samples))
        self.shuffle = False
        if is_training:
            # Compute the spatial dimensions and corresponding batch size.
            # Backbones down-sample images by a fixed factor, so round width
            # and height down to multiples of divided_factor.
            width_dims = [
                int((w // divided_factor[0]) * divided_factor[0])
                for w in width_dims
            ]
            height_dims = [
                int((h // divided_factor[1]) * divided_factor[1])
                for h in height_dims
            ]

            img_batch_pairs = list()
            for (h, w) in zip(height_dims, width_dims):
                if fix_bs:
                    batch_size = base_batch_size
                else:
                    # keep the pixel budget constant across scales
                    batch_size = int(max(1, (base_elements / (h * w))))
                img_batch_pairs.append((w, h, batch_size))
            self.img_batch_pairs = img_batch_pairs
            self.shuffle = True
            np.random.seed(seed)
            random.seed(seed)
        else:
            self.img_batch_pairs = [(base_im_w, base_im_h, base_batch_size)]

        self.img_indices = img_indices
        self.n_samples_per_replica = num_samples_per_replica
        self.epoch = 0
        self.rank = rank
        self.num_replicas = num_replicas
        self.current = 0
        self.is_training = is_training

        # each replica sees an interleaved slice of the dataset
        if is_training:
            indices_rank_i = self.img_indices[
                self.rank:len(self.img_indices):self.num_replicas]
        else:
            indices_rank_i = self.img_indices
        # sample indices for this replica, ordered by aspect ratio
        self.indices_rank_i_ori = np.array(self.wh_ratio_sort[indices_rank_i])
        self.indices_rank_i_ratio = self.wh_ratio[self.indices_rank_i_ori]
        indices_rank_i_ratio_unique = np.unique(self.indices_rank_i_ratio)
        self.indices_rank_i_ratio_unique = indices_rank_i_ratio_unique.tolist()
        self.batch_list = self.create_batch()
        self.length = len(self.batch_list)
        self.batchs_in_one_epoch_id = list(range(self.length))

    def create_batch(self):
        """Group samples by aspect ratio into [w, h, idx, ratio] batches."""
        batch_list = []
        for ratio in self.indices_rank_i_ratio_unique:
            ratio_ids = np.where(self.indices_rank_i_ratio == ratio)[0]
            ratio_ids = self.indices_rank_i_ori[ratio_ids]
            if self.shuffle:
                random.shuffle(ratio_ids)
            num_ratio = ratio_ids.shape[0]
            if ratio < 5:
                batch_size_ratio = self.base_batch_size
            else:
                # wide images get smaller batches so the pixel budget
                # (base_elements) per batch stays roughly constant
                batch_size_ratio = min(
                    self.max_bs,
                    int(
                        max(1, (self.base_elements /
                                (self.base_im_h * ratio * self.base_im_h)))))
            if num_ratio > batch_size_ratio:
                batch_num_ratio = num_ratio // batch_size_ratio
                print(self.rank, num_ratio, ratio * self.base_im_h,
                      batch_num_ratio, batch_size_ratio)
                ratio_ids_full = ratio_ids[:batch_num_ratio *
                                           batch_size_ratio].reshape(
                                               batch_num_ratio,
                                               batch_size_ratio, 1)
                w = np.full_like(ratio_ids_full, ratio * self.base_im_h)
                h = np.full_like(ratio_ids_full, self.base_im_h)
                ra_wh = np.full_like(ratio_ids_full, ratio)
                ratio_ids_full = np.concatenate([w, h, ratio_ids_full, ra_wh],
                                                axis=-1)
                batch_ratio = ratio_ids_full.tolist()
                if batch_num_ratio * batch_size_ratio < num_ratio:
                    # pad the leftover samples into one more full batch by
                    # reusing samples from the head of this ratio group
                    drop = ratio_ids[batch_num_ratio * batch_size_ratio:]
                    if self.is_training:
                        drop_full = ratio_ids[:batch_size_ratio - (
                            num_ratio - batch_num_ratio * batch_size_ratio)]
                        drop = np.append(drop_full, drop)
                    drop = drop.reshape(-1, 1)
                    w = np.full_like(drop, ratio * self.base_im_h)
                    h = np.full_like(drop, self.base_im_h)
                    ra_wh = np.full_like(drop, ratio)
                    drop = np.concatenate([w, h, drop, ra_wh], axis=-1)
                    batch_ratio.append(drop.tolist())
                batch_list += batch_ratio
            else:
                # the whole ratio group fits into a single batch
                print(self.rank, num_ratio, ratio * self.base_im_h,
                      batch_size_ratio)
                ratio_ids = ratio_ids.reshape(-1, 1)
                w = np.full_like(ratio_ids, ratio * self.base_im_h)
                h = np.full_like(ratio_ids, self.base_im_h)
                ra_wh = np.full_like(ratio_ids, ratio)
                ratio_ids = np.concatenate([w, h, ratio_ids, ra_wh], axis=-1)
                batch_list.append(ratio_ids.tolist())
        return batch_list

    def __iter__(self):
        if self.shuffle or self.is_training:
            # reshuffle batches deterministically per epoch
            random.seed(self.epoch)
            self.epoch += 1
            self.batch_list = self.create_batch()
            random.shuffle(self.batchs_in_one_epoch_id)
        for batch_tuple_id in self.batchs_in_one_epoch_id:
            yield self.batch_list[batch_tuple_id]

    def set_epoch(self, epoch: int):
        self.epoch = epoch

    def __len__(self):
        return self.length
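

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It shows how
# RatioSampler is typically plugged into a DataLoader as a ``batch_sampler``:
# the sampler yields batches of ``[w, h, idx, ratio]`` entries, and the
# dataset's ``__getitem__`` is expected to unpack them. ``ToyRatioDataset``
# below is a hypothetical stand-in that only exists to provide the attributes
# the sampler reads from ``data_source`` (``ds_width``, ``seed``,
# ``wh_ratio``, ``wh_ratio_sort`` and ``__len__``).
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from torch.utils.data import DataLoader, Dataset

    class ToyRatioDataset(Dataset):
        """Hypothetical dataset exposing the fields RatioSampler expects."""

        def __init__(self, n=64, seed=0):
            rng = np.random.RandomState(seed)
            self.ds_width = True
            self.seed = seed
            # width/height ratio per sample, quantized so samples share bins
            self.wh_ratio = np.round(rng.uniform(1, 4, size=n))
            # indices that sort the samples by their ratio
            self.wh_ratio_sort = np.argsort(self.wh_ratio)
            self.n = n

        def __len__(self):
            return self.n

        def __getitem__(self, prop):
            # RatioSampler yields [w, h, idx, ratio] per sample; a real
            # dataset would load the image at `idx` and resize it to (w, h).
            w, h, idx, ratio = prop
            return torch.zeros(3, int(h), int(w)), int(idx)

    dataset = ToyRatioDataset()
    sampler = RatioSampler(dataset,
                           scales=[[128, 32]],
                           first_bs=8,
                           is_training=True,
                           seed=0)
    # batch_sampler hands each yielded batch straight to the dataset, so all
    # images in a batch share one (w, h) and stack cleanly under collation
    loader = DataLoader(dataset, batch_sampler=sampler, num_workers=0)
    for images, idx in loader:
        print(images.shape, idx[:4])
        break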