# Source: colorphoto/basicsr/data/lab_dataset.py
# Author: My-AI-Projects — commit aaa2047 ("Add application file")
import cv2
import random
import time
import numpy as np
import torch
from torch.utils import data as data
from basicsr.data.transforms import rgb2lab
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from basicsr.data.fmix import sample_mask
@DATASET_REGISTRY.register()
class LabDataset(data.Dataset):
    """Dataset used for Lab colorization.

    Reads the images listed in one or more meta-info files, resizes each to a
    square ``gt_size`` and optionally augments it with FMix and/or CutMix.
    The image is then converted to Lab and returned as the L channel (network
    input), the ab channels (regression target) and the ab channels quantised
    onto a fixed integer grid (classification targets).

    Keys read from ``opt``: ``io_backend``, ``dataroot_gt``,
    ``meta_info_file``, ``gt_size``, ``do_fmix``, ``fmix_p``, ``do_cutmix``,
    ``cutmix_p``.
    """

    def __init__(self, opt):
        super(LabDataset, self).__init__()
        self.opt = opt
        # The file client (I/O backend) is created lazily in __getitem__.
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        # NOTE(review): stored but not used in this class; meta-file entries
        # are read as-is in __getitem__, so they presumably hold full paths.
        self.gt_folder = opt['dataroot_gt']

        meta_info_file = self.opt['meta_info_file']
        if meta_info_file is None:
            raise ValueError('opt["meta_info_file"] must be set for LabDataset.')
        if not isinstance(meta_info_file, list):
            meta_info_file = [meta_info_file]
        self.paths = []
        for meta_info in meta_info_file:
            with open(meta_info, 'r') as fin:
                # One path per line; blank lines are skipped because they
                # would otherwise become unreadable '' entries.
                self.paths.extend(line.strip() for line in fin if line.strip())

        # ab quantisation grid: -128, -124, ..., 128 (65 values, step 4).
        self.min_ab, self.max_ab = -128, 128
        self.interval_ab = 4
        self.ab_palette = list(range(self.min_ab, self.max_ab + self.interval_ab, self.interval_ab))

        # Optional FMix / CutMix augmentation. NOTE: __getitem__ applies each
        # one when a uniform draw is *greater* than its ``*_p`` value, i.e.
        # with probability ``1 - p``.
        self.do_fmix = opt['do_fmix']
        self.fmix_params = {'alpha': 1., 'decay_power': 3., 'shape': (256, 256), 'max_soft': 0.0, 'reformulate': False}
        self.fmix_p = opt['fmix_p']
        self.do_cutmix = opt['do_cutmix']
        self.cutmix_params = {'alpha': 1.}
        self.cutmix_p = opt['cutmix_p']

    def __getitem__(self, index):
        """Load, optionally augment, and Lab-encode one sample.

        Args:
            index (int): Index into ``self.paths``.

        Returns:
            dict: ``lq`` (L channel tensor), ``gt`` (ab channels tensor),
            ``target_a`` / ``target_b`` (LongTensors of quantised a / b
            values, see ``ab2int``), and ``lq_path`` / ``gt_path`` (both set
            to the path of the image actually read).

        Raises:
            RuntimeError: If the gt image cannot be read after all retries.
        """
        if self.file_client is None:
            # Pop 'type' from a copy so the shared opt dict is not mutated.
            backend_opt = dict(self.io_backend_opt)
            self.file_client = FileClient(backend_opt.pop('type'), **backend_opt)

        # -------------------------------- Load gt images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_size = self.opt['gt_size']
        gt_path, img_bytes = self._read_gt_bytes(index)
        img_gt = imfrombytes(img_bytes, float32=True)
        img_gt = cv2.resize(img_gt, (gt_size, gt_size))  # TODO: is a direct resize the best approach?

        # -------------------------------- (Optional) CutMix & FMix -------------------------------- #
        if self.do_fmix and np.random.uniform(0., 1., size=1)[0] > self.fmix_p:
            # Sample the mask at the real training size rather than the
            # hard-coded (256, 256) in self.fmix_params, so gt_size != 256 works.
            fmix_params = dict(self.fmix_params, shape=(gt_size, gt_size))
            _, mask = sample_mask(**fmix_params)
            fmix_img = self._load_resized(self.paths[self._random_index()], gt_size)
            # Assumes sample_mask returns a (1, H, W) mask for shape=(H, W);
            # move the singleton axis last so it broadcasts over channels.
            mask = mask.transpose(1, 2, 0)
            img_gt = (mask * img_gt + (1. - mask) * fmix_img).astype(np.float32)

        if self.do_cutmix and np.random.uniform(0., 1., size=1)[0] > self.cutmix_p:
            cmix_img = self._load_resized(self.paths[self._random_index()], gt_size)
            alpha = self.cutmix_params['alpha']
            lam = np.clip(np.random.beta(alpha, alpha), 0.3, 0.4)
            bbx1, bby1, bbx2, bby2 = rand_bbox(cmix_img.shape[:2], lam)
            # Images are (h, w, c): the box spans the two spatial axes and all
            # channels (the previous ``[:, x1:x2, y1:y2]`` sliced channels).
            img_gt[bbx1:bbx2, bby1:bby2] = cmix_img[bbx1:bbx2, bby1:bby2]

        # ----------------------------- Lab conversion, to tensor ----------------------------- #
        # BGR -> RGB, then split into L (input) and ab (target).
        img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2RGB)
        img_l, img_ab = rgb2lab(img_gt)
        target_a, target_b = self.ab2int(img_ab)

        # numpy to tensor
        img_l, img_ab = img2tensor([img_l, img_ab], bgr2rgb=False, float32=True)
        target_a, target_b = torch.LongTensor(target_a), torch.LongTensor(target_b)

        return {
            'lq': img_l,
            'gt': img_ab,
            'target_a': target_a,
            'target_b': target_b,
            'lq_path': gt_path,
            'gt_path': gt_path,
        }

    def _random_index(self):
        """Return a uniformly random valid index into ``self.paths``."""
        # randrange excludes the upper bound; randint(0, len) could overflow.
        return random.randrange(len(self))

    def _load_resized(self, path, gt_size):
        """Read ``path`` via the file client and resize to (gt_size, gt_size)."""
        img_bytes = self.file_client.get(path, 'gt')
        img = imfrombytes(img_bytes, float32=True)
        return cv2.resize(img, (gt_size, gt_size))

    def _read_gt_bytes(self, index, retry=3):
        """Read the raw bytes of ``self.paths[index]``, retrying on errors.

        This avoids errors caused by high latency when reading files. After
        each failure except the last, a different random file is chosen and
        the loop sleeps 1 s, so the returned path may differ from
        ``self.paths[index]``.

        Args:
            index (int): Index of the preferred file.
            retry (int): Total number of read attempts.

        Returns:
            tuple[str, bytes]: The path that was read and its raw bytes.

        Raises:
            RuntimeError: If every attempt fails.
        """
        gt_path = self.paths[index]
        last_error = None
        for remaining in range(retry - 1, -1, -1):
            try:
                img_bytes = self.file_client.get(gt_path, 'gt')
            except Exception as e:  # backend-specific errors; deliberately broad
                last_error = e
                logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {remaining}')
                if remaining > 0:
                    # change another file to read
                    gt_path = self.paths[self._random_index()]
                    time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                return gt_path, img_bytes
        raise RuntimeError(f'Failed to read an image after {retry} attempts; last tried: {gt_path}') from last_error

    def ab2int(self, img_ab):
        """Quantise the a and b channels onto the ``interval_ab`` grid.

        Args:
            img_ab (np.ndarray): Array indexed as ``[:, :, 0]`` (a) and
                ``[:, :, 1]`` (b).

        Returns:
            tuple[np.ndarray, np.ndarray]: ``round((x - min_ab) / interval_ab)``
            for a and for b, as float arrays.
        """
        img_a, img_b = img_ab[:, :, 0], img_ab[:, :, 1]
        int_a = (img_a - self.min_ab) / self.interval_ab
        int_b = (img_b - self.min_ab) / self.interval_ab
        return np.round(int_a), np.round(int_b)

    def __len__(self):
        """Return the number of paths loaded from the meta-info files."""
        return len(self.paths)
def rand_bbox(size, lam):
    """Sample a random CutMix bounding box.

    Each side of the box is ``sqrt(1 - lam)`` times the matching image side.
    The box is centred on a uniformly random pixel and clipped to the image,
    so near the border it can be smaller than requested.

    Args:
        size (tuple): Image size, e.g. ``(256, 256)``. ``bbx1`` / ``bbx2``
            index along ``size[0]``; ``bby1`` / ``bby2`` along ``size[1]``.
        lam (float): Cut ratio in [0, 1]. The box covers roughly
            ``1 - lam`` of the image area, so a larger ``lam`` gives a
            smaller box.

    Returns:
        tuple[int, int, int, int]: ``(bbx1, bby1, bbx2, bby2)``, the top-left
        and bottom-right corners, with ``0 <= bbx1 <= bbx2 <= size[0]`` and
        ``0 <= bby1 <= bby2 <= size[1]``.
    """
    size_x, size_y = size[0], size[1]
    cut_rat = np.sqrt(1. - lam)
    # Builtin int(): ``np.int`` was deprecated in NumPy 1.20 and removed in 1.24.
    cut_x = int(size_x * cut_rat)
    cut_y = int(size_y * cut_rat)
    # Uniformly sample the box centre.
    cx = np.random.randint(size_x)
    cy = np.random.randint(size_y)
    # Clip both corners to the image bounds.
    bbx1 = int(np.clip(cx - cut_x // 2, 0, size_x))
    bby1 = int(np.clip(cy - cut_y // 2, 0, size_y))
    bbx2 = int(np.clip(cx + cut_x // 2, 0, size_x))
    bby2 = int(np.clip(cy + cut_y // 2, 0, size_y))
    return bbx1, bby1, bbx2, bby2