File size: 6,418 Bytes
aaa2047
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import cv2
import random
import time
import numpy as np
import torch
from torch.utils import data as data

from basicsr.data.transforms import rgb2lab
from basicsr.utils import FileClient, get_root_logger, imfrombytes, img2tensor
from basicsr.utils.registry import DATASET_REGISTRY
from basicsr.data.fmix import sample_mask


@DATASET_REGISTRY.register()
class LabDataset(data.Dataset):
    """Dataset used for Lab colorization.

    Each item is built from one ground-truth image listed in the meta info
    file(s): the image is resized to ``gt_size`` x ``gt_size``, optionally
    mixed with another random image (FMix / CutMix), and split by
    ``rgb2lab`` into an L part (network input) and an ab part (target).
    The ab values are additionally quantized into integer bin indices.

    Keys read from ``opt``: ``io_backend``, ``dataroot_gt``,
    ``meta_info_file`` (a path or a list of paths, one image path per line),
    ``gt_size``, ``do_fmix``, ``fmix_p``, ``do_cutmix``, ``cutmix_p``.
    """

    def __init__(self, opt):
        super(LabDataset, self).__init__()
        self.opt = opt
        # file client (io backend); created lazily on the first __getitem__
        self.file_client = None
        self.io_backend_opt = opt['io_backend']
        self.gt_folder = opt['dataroot_gt']

        meta_info_file = self.opt['meta_info_file']
        assert meta_info_file is not None
        if not isinstance(meta_info_file, list):
            meta_info_file = [meta_info_file]
        self.paths = []
        for meta_info in meta_info_file:
            with open(meta_info, 'r') as fin:
                # Skip blank lines so they do not turn into empty, unreadable paths.
                self.paths.extend(line.strip() for line in fin if line.strip())

        # ab values are quantized into bins of width `interval_ab` over [min_ab, max_ab].
        self.min_ab, self.max_ab = -128, 128
        self.interval_ab = 4
        self.ab_palette = list(range(self.min_ab, self.max_ab + self.interval_ab, self.interval_ab))

        self.do_fmix = opt['do_fmix']
        # 'shape' here is only a default; __getitem__ overrides it with the actual gt_size.
        self.fmix_params = {'alpha': 1., 'decay_power': 3., 'shape': (256, 256), 'max_soft': 0.0, 'reformulate': False}
        self.fmix_p = opt['fmix_p']
        self.do_cutmix = opt['do_cutmix']
        self.cutmix_params = {'alpha': 1.}
        self.cutmix_p = opt['cutmix_p']

    def _random_index(self):
        """Return a uniformly random valid index into ``self.paths``."""
        # random.randint is inclusive on both ends, hence the -1.
        return random.randint(0, len(self) - 1)

    def _load_resized(self, path, size):
        """Read the image at ``path`` through the file client and resize it to (size, size)."""
        img_bytes = self.file_client.get(path, 'gt')
        img = imfrombytes(img_bytes, float32=True)
        return cv2.resize(img, (size, size))

    def __getitem__(self, index):
        """Load, augment and convert one sample.

        Args:
            index (int): Index into the list of image paths.

        Returns:
            dict: ``lq`` (tensor built from the L part), ``gt`` (tensor built
            from the ab part), ``target_a`` / ``target_b`` (LongTensors of
            quantized a / b bin indices), and ``lq_path`` / ``gt_path`` (both
            the path of the image that was actually read).

        Raises:
            RuntimeError: If no image could be read after all retries.
        """
        if self.file_client is None:
            # Work on a copy so the shared option dict keeps its 'type' entry.
            backend_opt = dict(self.io_backend_opt)
            self.file_client = FileClient(backend_opt.pop('type'), **backend_opt)

        # -------------------------------- Load gt images -------------------------------- #
        # Shape: (h, w, c); channel order: BGR; image range: [0, 1], float32.
        gt_path = self.paths[index]
        gt_size = self.opt['gt_size']
        # avoid errors caused by high latency in reading files
        max_attempts = 3
        last_error = None
        for remaining in range(max_attempts - 1, -1, -1):
            try:
                img_bytes = self.file_client.get(gt_path, 'gt')
            except Exception as e:
                last_error = e
                logger = get_root_logger()
                logger.warning(f'File client error: {e}, remaining retry times: {remaining}')
                # change another file to read
                index = self._random_index()
                gt_path = self.paths[index]
                time.sleep(1)  # sleep 1s for occasional server congestion
            else:
                break
        else:
            # Every attempt failed: fail loudly instead of hitting an unbound variable below.
            raise RuntimeError(f'Failed to read an image after {max_attempts} attempts') from last_error
        img_gt = imfrombytes(img_bytes, float32=True)
        img_gt = cv2.resize(img_gt, (gt_size, gt_size))  # TODO: is a plain resize the best option?

        # -------------------------------- (Optional) CutMix & FMix -------------------------------- #
        # NOTE(review): both augmentations fire when the uniform draw EXCEEDS the
        # configured value, i.e. with probability (1 - p) -- confirm this is intended.
        if self.do_fmix and np.random.uniform(0., 1., size=1)[0] > self.fmix_p:
            # The mask must match the resized image, so override the default shape.
            fmix_params = dict(self.fmix_params, shape=(gt_size, gt_size))
            _, mask = sample_mask(**fmix_params)

            fmix_img = self._load_resized(self.paths[self._random_index()], gt_size)

            mask = mask.transpose(1, 2, 0)  # (1, h, w) -> (h, w, 1)
            img_gt = mask * img_gt + (1. - mask) * fmix_img
            img_gt = img_gt.astype(np.float32)

        if self.do_cutmix and np.random.uniform(0., 1., size=1)[0] > self.cutmix_p:
            cmix_img = self._load_resized(self.paths[self._random_index()], gt_size)

            alpha = self.cutmix_params['alpha']
            lam = np.clip(np.random.beta(alpha, alpha), 0.3, 0.4)
            bbx1, bby1, bbx2, bby2 = rand_bbox(cmix_img.shape[:2], lam)

            # Images are (h, w, c): the box spans the two leading (spatial) axes,
            # in the same axis order as the shape[:2] passed to rand_bbox.
            img_gt[bbx1:bbx2, bby1:bby2] = cmix_img[bbx1:bbx2, bby1:bby2]

        # ----------------------------- Get gray lq, to tensor ----------------------------- #
        # convert to gray
        img_gt = cv2.cvtColor(img_gt, cv2.COLOR_BGR2RGB)
        img_l, img_ab = rgb2lab(img_gt)

        target_a, target_b = self.ab2int(img_ab)

        # numpy to tensor
        img_l, img_ab = img2tensor([img_l, img_ab], bgr2rgb=False, float32=True)
        target_a, target_b = torch.LongTensor(target_a), torch.LongTensor(target_b)
        return_d = {
            'lq': img_l,
            'gt': img_ab,
            'target_a': target_a,
            'target_b': target_b,
            'lq_path': gt_path,
            'gt_path': gt_path
        }
        return return_d

    def ab2int(self, img_ab):
        """Quantize ab values into bin indices.

        Args:
            img_ab (ndarray): Array indexed as (h, w, channel) with channel 0 = a
                and channel 1 = b.

        Returns:
            tuple[ndarray, ndarray]: Rounded ``(value - min_ab) / interval_ab``
            for the a and b channels. Values are whole numbers but keep the
            floating dtype produced by ``np.round``.
        """
        img_a, img_b = img_ab[:, :, 0], img_ab[:, :, 1]
        int_a = (img_a - self.min_ab) / self.interval_ab
        int_b = (img_b - self.min_ab) / self.interval_ab

        return np.round(int_a), np.round(int_b)

    def __len__(self):
        """Return the number of image paths collected from the meta info file(s)."""
        return len(self.paths)


def rand_bbox(size, lam):
    """Sample a random CutMix bounding box.

    The box covers roughly a ``(1 - lam)`` fraction of the area: each side is
    ``sqrt(1 - lam)`` times the corresponding extent. It is centred on a
    uniformly sampled point and clipped to the image bounds, so boxes near a
    border come out smaller.

    Args:
        size (tuple): Extents of the two axes the box spans, e.g. (256, 256).
            ``size[0]`` bounds the x coordinates and ``size[1]`` the y
            coordinates of the result.
        lam (float): Mixing ratio in [0, 1]; larger values give smaller boxes.

    Returns:
        tuple: ``(bbx1, bby1, bbx2, bby2)``, the start and end coordinates of
        the box along the first and second axis of ``size``.
    """
    W = size[0]  # extent along the first axis
    H = size[1]  # extent along the second axis
    cut_rat = np.sqrt(1. - lam)  # side-length ratio of the box
    # `np.int` was removed in NumPy 1.24; the builtin int truncates identically.
    cut_w = int(W * cut_rat)  # box extent along the first axis
    cut_h = int(H * cut_rat)  # box extent along the second axis

    # Uniformly sample the box centre.
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)  # start along the first axis
    bby1 = np.clip(cy - cut_h // 2, 0, H)  # start along the second axis
    bbx2 = np.clip(cx + cut_w // 2, 0, W)  # end along the first axis
    bby2 = np.clip(cy + cut_h // 2, 0, H)  # end along the second axis
    return bbx1, bby1, bbx2, bby2