import random
import time

import cv2
import numpy as np
import torch
from torch.utils import data as data

from basicsr.data.fmix import sample_mask
from basicsr.data.transforms import rgb2lab
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY


@DATASET_REGISTRY.register()
class LabDataset(data.Dataset):
    """Dataset used for Lab colorization."""

    def __init__(self, opt):
        super(LabDataset, self).__init__()
        self.opt = opt
        # file client (io backend)
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.gt_folder = opt['dataroot_gt']

        meta_info_file = self.opt['meta_info_file']
        assert meta_info_file is not None
        if not isinstance(meta_info_file, list):
            meta_info_file = [meta_info_file]
        self.paths = []
        for meta_info in meta_info_file:
            with open(meta_info, 'r') as fin:
                self.paths.extend([line.strip() for line in fin])

        self.min_ab, self.max_ab = -128, 128
        self.interval_ab = 4
        self.ab_palette = [i for i in range(self.min_ab, self.max_ab + self.interval_ab, self.interval_ab)]
        # print(self.ab_palette)

        self.do_fmix = opt['do_fmix']
        self.fmix_params = {'alpha': 1., 'decay_power': 3., 'shape': (256, 256), 'max_soft': 0.0, 'reformulate': False}
        self.fmix_p = opt['fmix_p']
        self.do_cutmix = opt['do_cutmix']
        self.cutmix_params = {'alpha': 1.}
        self.cutmix_p = opt['cutmix_p']

    def __getitem__(self, index):
        if self.file_client is None:
            self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)

        # -------------------------------- Load gt images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        gt_size = self.opt['gt_size']
        # avoid errors caused by high latency in reading files
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path, 'gt')
            except Exception as e:
                logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {retry - 1}')
                # switch to another file to read
                index = random.randint(0, self.__len__() - 1)
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
            finally:
                retry -= 1
        img_gt = imfrombytes(img_bytes, float32=True)
        img_gt = cv2.resize(img_gt, (gt_size, gt_size))  # TODO: is a plain resize the best choice here?

        # -------------------------------- (Optional) CutMix & FMix -------------------------------- #
        if self.do_fmix and np.random.uniform(0., 1., size=1)[0] > self.fmix_p:
            with torch.no_grad():
                lam, mask = sample_mask(**self.fmix_params)

                fmix_index = random.randint(0, self.__len__() - 1)
                fmix_img_path = self.paths[fmix_index]
                fmix_img_bytes = self.file_client.get(fmix_img_path, 'gt')
                fmix_img = imfrombytes(fmix_img_bytes, float32=True)
                fmix_img = cv2.resize(fmix_img, (gt_size, gt_size))

                mask = mask.transpose(1, 2, 0)  # (1, 256, 256) -> (256, 256, 1)
                img_gt = mask * img_gt + (1. - mask) * fmix_img
                img_gt = img_gt.astype(np.float32)

        if self.do_cutmix and np.random.uniform(0., 1., size=1)[0] > self.cutmix_p:
            with torch.no_grad():
                cmix_index = random.randint(0, self.__len__() - 1)
                cmix_img_path = self.paths[cmix_index]
                cmix_img_bytes = self.file_client.get(cmix_img_path, 'gt')
                cmix_img = imfrombytes(cmix_img_bytes, float32=True)
                cmix_img = cv2.resize(cmix_img, (gt_size, gt_size))

                lam = np.clip(np.random.beta(self.cutmix_params['alpha'], self.cutmix_params['alpha']), 0.3, 0.4)
                bbx1, bby1, bbx2, bby2 = rand_bbox(cmix_img.shape[:2], lam)

                # img_gt is (h, w, c), so slice the spatial dimensions first
                img_gt[bbx1:bbx2, bby1:bby2, :] = cmix_img[bbx1:bbx2, bby1:bby2, :]

        # ----------------------------- Get gray lq, to tensor ----------------------------- #
        # convert BGR to RGB before the Lab conversion
        img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2RGB)
        img_l, img_ab = rgb2lab(img_gt)
        target_a, target_b = self.ab2int(img_ab)

        # numpy to tensor
        img_l, img_ab = img2tensor([img_l, img_ab], bgr2rgb=False, float32=True)
        target_a, target_b = torch.LongTensor(target_a), torch.LongTensor(target_b)

        return_d = {
            'lq': img_l,
            'gt': img_ab,
            'target_a': target_a,
            'target_b': target_b,
            'lq_path': gt_path,
            'gt_path': gt_path
        }
        return return_d

    def ab2int(self, img_ab):
        img_a, img_b = img_ab[:, :, 0], img_ab[:, :, 1]
        int_a = (img_a - self.min_ab) / self.interval_ab
        int_b = (img_b - self.min_ab) / self.interval_ab
        return np.round(int_a), np.round(int_b)

    def __len__(self):
        return len(self.paths)


def rand_bbox(size, lam):
    """Sample a random bbox for CutMix.

    Args:
        size (tuple): spatial size of the image, e.g. (256, 256).
        lam (float): mixing ratio; the cut area covers (1 - lam) of the image.

    Returns:
        (int, int, int, int): top-left and bottom-right coordinates of the bbox.
    """
    W = size[0]  # width of the image to cut from
    H = size[1]  # height of the image to cut from
    cut_rat = np.sqrt(1. - lam)  # side ratio of the bbox to cut
    cut_w = int(W * cut_rat)  # bbox width
    cut_h = int(H * cut_rat)  # bbox height

    cx = np.random.randint(W)  # uniformly sample the bbox center x coordinate
    cy = np.random.randint(H)  # uniformly sample the bbox center y coordinate

    bbx1 = np.clip(cx - cut_w // 2, 0, W)  # top-left x
    bby1 = np.clip(cy - cut_h // 2, 0, H)  # top-left y
    bbx2 = np.clip(cx + cut_w // 2, 0, W)  # bottom-right x
    bby2 = np.clip(cy + cut_h // 2, 0, H)  # bottom-right y

    return bbx1, bby1, bbx2, bby2
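

# A minimal, hypothetical usage sketch (not part of the original module): the option
# keys mirror the ones read in __init__, while the data root and meta-info path below
# are placeholders. It assumes a 'disk' io backend and a meta-info file that lists one
# image path per line, readable by FileClient.
if __name__ == '__main__':
    opt = {
        'io_backend': {'type': 'disk'},
        'dataroot_gt': 'datasets/gt',                 # placeholder folder of ground-truth images
        'meta_info_file': 'datasets/meta_info.txt',   # placeholder meta-info list
        'gt_size': 256,
        'do_fmix': False,
        'fmix_p': 0.5,
        'do_cutmix': False,
        'cutmix_p': 0.5,
    }
    dataset = LabDataset(opt)
    sample = dataset[0]
    # 'lq' is the L channel, 'gt' the ab channels, and 'target_a'/'target_b' are the
    # quantized ab bin indices produced by ab2int.
    print(sample['lq'].shape, sample['gt'].shape, sample['target_a'].shape)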