File size: 13,048 Bytes
8e12b4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
import inspect
import random
from copy import deepcopy

import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from PIL import Image
from torchvision.transforms import CenterCrop, Normalize, RandomCrop, RandomHorizontalFlip, Resize
from torchvision.transforms.functional import InterpolationMode

from mixofshow.utils.registry import TRANSFORM_REGISTRY


def build_transform(opt):
    """Build a transform from an options dict.

    The options are deep-copied before use, so the caller's ``opt`` is not
    modified.

    Args:
        opt (dict): Configuration. Must contain the key ``type``, which is
            looked up in ``TRANSFORM_REGISTRY``; every remaining entry is
            forwarded as a keyword argument to whatever that lookup returns.

    Returns:
        The object obtained by calling the registered entry with the
        remaining options.
    """
    params = deepcopy(opt)
    transform_cls = TRANSFORM_REGISTRY.get(params.pop('type'))
    return transform_cls(**params)


# Register the transforms imported from torchvision so that build_transform()
# can retrieve them through TRANSFORM_REGISTRY.get(...).
# NOTE(review): presumably each one is keyed by its class name — confirm
# against the registry implementation.
TRANSFORM_REGISTRY.register(Normalize)
TRANSFORM_REGISTRY.register(Resize)
TRANSFORM_REGISTRY.register(RandomHorizontalFlip)
TRANSFORM_REGISTRY.register(CenterCrop)
TRANSFORM_REGISTRY.register(RandomCrop)


@TRANSFORM_REGISTRY.register()
class BILINEARResize(Resize):
    """``Resize`` with the interpolation mode pinned to bilinear.

    Args:
        size: Target size, passed unchanged to ``Resize``.
    """

    def __init__(self, size):
        mode = InterpolationMode.BILINEAR
        super().__init__(size, interpolation=mode)


@TRANSFORM_REGISTRY.register()
class PairRandomCrop(nn.Module):
    """Cut the same randomly positioned window out of an image and its mask.

    Args:
        size (int | sequence): An int gives a square crop; anything else is
            unpacked as ``(height, width)``.
    """

    def __init__(self, size):
        super().__init__()
        crop_hw = (size, size) if isinstance(size, int) else size
        self.height, self.width = crop_hw

    def forward(self, img, **kwargs):
        """Crop ``img`` and ``kwargs['mask']`` with one shared random window.

        Args:
            img: Image whose ``size`` attribute unpacks as ``(width, height)``.
            **kwargs: Must contain ``mask`` with the same size as ``img``.

        Returns:
            tuple: ``(cropped_img, kwargs)`` with ``kwargs['mask']`` replaced
            by the cropped mask.

        Raises:
            AssertionError: If the image is smaller than the crop, or the
                image and mask sizes differ.
        """
        img_w, img_h = img.size
        mask = kwargs['mask']
        mask_w, mask_h = mask.size

        assert img_h >= self.height and img_h == mask_h
        assert img_w >= self.width and img_w == mask_w

        # Horizontal offset is drawn before the vertical one.
        left = random.randint(0, img_w - self.width)
        top = random.randint(0, img_h - self.height)
        window = (top, left, self.height, self.width)

        cropped = F.crop(img, *window)
        kwargs['mask'] = F.crop(mask, *window)
        return cropped, kwargs


@TRANSFORM_REGISTRY.register()
class ToTensor(nn.Module):
    """Module wrapper around ``torchvision.transforms.functional.to_tensor``."""

    def __init__(self) -> None:
        super().__init__()

    def forward(self, pic):
        """Return ``F.to_tensor(pic)``."""
        tensor = F.to_tensor(pic)
        return tensor

    def __repr__(self) -> str:
        return type(self).__name__ + '()'


@TRANSFORM_REGISTRY.register()
class PairRandomHorizontalFlip(torch.nn.Module):
    """Horizontally flip an image and its mask together with probability ``p``.

    Args:
        p (float): Threshold compared against one ``torch.rand(1)`` draw.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img, **kwargs):
        """Flip both ``img`` and ``kwargs['mask']``, or neither.

        Returns:
            tuple: ``(img, kwargs)``; when the flip fires, ``kwargs['mask']``
            is replaced by its flipped version.
        """
        flip = torch.rand(1) < self.p
        if not flip:
            return img, kwargs

        # The mask is flipped first, so a missing mask fails before the image
        # is touched.
        kwargs['mask'] = F.hflip(kwargs['mask'])
        flipped = F.hflip(img)
        return flipped, kwargs


@TRANSFORM_REGISTRY.register()
class PairResize(nn.Module):
    """Apply one ``Resize`` instance to both an image and its mask.

    Args:
        size: Target size, forwarded to ``Resize``.
    """

    def __init__(self, size):
        super().__init__()
        self.resize = Resize(size=size)

    def forward(self, img, **kwargs):
        """Resize ``kwargs['mask']`` and then ``img``.

        Returns:
            tuple: ``(resized_img, kwargs)`` with ``kwargs['mask']`` replaced
            by the resized mask.
        """
        resized_mask = self.resize(kwargs['mask'])
        kwargs['mask'] = resized_mask
        return self.resize(img), kwargs


class PairCompose(nn.Module):
    """Chain transforms, threading extra keyword data through the pair-aware ones.

    Args:
        transforms: Iterable of transforms. Each must expose a ``forward``
            method and be callable.
    """

    def __init__(self, transforms):
        super().__init__()
        self.transforms = transforms

    def __call__(self, img, **kwargs):
        """Run every transform in order.

        A transform whose bound ``forward`` takes exactly one parameter is
        called as ``t(img)``; any other is called as ``t(img, **kwargs)`` and
        must return an ``(img, kwargs)`` tuple.

        Returns:
            tuple: The final ``(img, kwargs)``.
        """
        for transform in self.transforms:
            # Signature of the bound method, so ``self`` is not counted.
            n_params = len(inspect.signature(transform.forward).parameters)
            if n_params != 1:
                img, kwargs = transform(img, **kwargs)
                continue
            img = transform(img)
        return img, kwargs

    def __repr__(self) -> str:
        body = ''.join(f'\n    {t}' for t in self.transforms)
        return self.__class__.__name__ + '(' + body + '\n)'


@TRANSFORM_REGISTRY.register()
class HumanResizeCropFinalV3(nn.Module):
    """Resize an image (and optional mask), optionally crop it, then paste it
    at a random position on a zero-filled ``size`` x ``size`` canvas.

    Args:
        size (int): Edge length of the square output canvas; also the target
            used by the resize steps.
        crop_p (float): Probability of taking the crop branch in step 2.
    """

    def __init__(self, size, crop_p=0.5):
        super().__init__()
        self.size = size
        self.crop_p = crop_p
        self.random_crop = RandomCrop(size=size)
        self.paired_random_crop = PairRandomCrop(size=size)

    def forward(self, img, **kwargs):
        """Apply the transform.

        Args:
            img: Input image. Assumed to be a 3-channel PIL image (``img.size``
                is read as ``(width, height)`` and the pixels are copied into a
                ``(size, size, 3)`` uint8 canvas) — TODO confirm with callers.
            **kwargs: May contain ``mask``, which is transformed together with
                ``img``. Presumably an 8-bit single-channel PIL image, since it
                is divided by 255 — verify against the dataset.

        Returns:
            tuple: ``(img, kwargs)``. ``img`` is a PIL image of the canvas.
            ``kwargs['img_mask']`` is a ``(size // 8, size // 8)`` tensor that
            is 1 where the canvas holds image content and 0 elsewhere.
            ``kwargs['mask']`` (only when a mask was passed in) is the pasted
            mask at the same reduced resolution.
        """
        has_mask = 'mask' in kwargs

        # step 1: resize so the short edge equals ``self.size``
        img = F.resize(img, size=self.size)
        if has_mask:
            kwargs['mask'] = F.resize(kwargs['mask'], size=self.size)

        # step 2: random crop
        width, height = img.size
        if random.random() < self.crop_p:
            if height > width:
                # Portrait image: keep the top part, with a random amount of
                # extra height below a square.
                crop_pos = random.randint(0, height - width)
                img = F.crop(img, 0, 0, width + crop_pos, width)
                if has_mask:
                    kwargs['mask'] = F.crop(kwargs['mask'], 0, 0, width + crop_pos, width)
            elif has_mask:
                img, kwargs = self.paired_random_crop(img, **kwargs)
            else:
                img = self.random_crop(img)

        # step 3: long edge resize. ``size - 1`` is used for the short edge so
        # that ``max_size`` (which caps the long edge) can be ``self.size``.
        img = F.resize(img, size=self.size - 1, max_size=self.size)
        if has_mask:
            kwargs['mask'] = F.resize(kwargs['mask'], size=self.size - 1, max_size=self.size)

        new_width, new_height = img.size

        img = np.array(img)
        if has_mask:
            kwargs['mask'] = np.array(kwargs['mask']) / 255
            # Image and mask were resized independently; clamp to the common
            # extent so the pasted regions match.
            new_width = min(new_width, kwargs['mask'].shape[1])
            new_height = min(new_height, kwargs['mask'].shape[0])

        # step 4: random placement on the canvas. The bounds follow
        # ``self.size`` (previously a hard-coded 512) so that any canvas size
        # works.
        start_y = random.randint(0, self.size - new_height)
        start_x = random.randint(0, self.size - new_width)

        res_img = np.zeros((self.size, self.size, 3), dtype=np.uint8)
        res_mask = np.zeros((self.size, self.size))
        res_img_mask = np.zeros((self.size, self.size))

        res_img[start_y:start_y + new_height, start_x:start_x + new_width, :] = img[:new_height, :new_width]
        if has_mask:
            res_mask[start_y:start_y + new_height, start_x:start_x + new_width] = kwargs['mask'][:new_height, :new_width]
            kwargs['mask'] = res_mask

        res_img_mask[start_y:start_y + new_height, start_x:start_x + new_width] = 1
        kwargs['img_mask'] = res_img_mask

        img = Image.fromarray(res_img)

        # Downsample the masks by a factor of 8 (presumably the latent
        # resolution of the downstream model — confirm). ``interpolation``
        # must be passed by keyword: the third positional parameter of
        # ``cv2.resize`` is ``dst``, so a positional flag is not applied.
        small_size = (self.size // 8, self.size // 8)
        if has_mask:
            kwargs['mask'] = cv2.resize(kwargs['mask'], small_size, interpolation=cv2.INTER_NEAREST)
            kwargs['mask'] = torch.from_numpy(kwargs['mask'])
        kwargs['img_mask'] = cv2.resize(kwargs['img_mask'], small_size, interpolation=cv2.INTER_NEAREST)
        kwargs['img_mask'] = torch.from_numpy(kwargs['img_mask'])
        return img, kwargs


@TRANSFORM_REGISTRY.register()
class ResizeFillMaskNew(nn.Module):
    """Resize, crop or shrink, randomly rescale, then paste an image (and
    optional mask) at a random position on a zero-filled square canvas.

    Args:
        size (int): Edge length of the square output canvas; also the target
            used by the resize steps.
        crop_p (float): Probability of cropping instead of long-edge resizing
            in step 2.
        scale_ratio (sequence): Two values unpacked into ``random.uniform`` to
            draw the scale factor of step 3. The scaled image must still fit
            on the canvas, otherwise ``random.randint`` raises ``ValueError``
            in step 4.
    """

    def __init__(self, size, crop_p, scale_ratio):
        super().__init__()
        self.size = size
        self.crop_p = crop_p
        self.scale_ratio = scale_ratio
        self.random_crop = RandomCrop(size=size)
        self.paired_random_crop = PairRandomCrop(size=size)

    def forward(self, img, **kwargs):
        """Apply the transform.

        Args:
            img: Input image. Assumed to be a 3-channel PIL image (``img.size``
                is read as ``(width, height)`` and the pixels are copied into a
                ``(size, size, 3)`` uint8 canvas) — TODO confirm with callers.
            **kwargs: May contain ``mask``, which is transformed together with
                ``img``. Presumably an 8-bit single-channel PIL image, since it
                is divided by 255 — verify against the dataset.

        Returns:
            tuple: ``(img, kwargs)``. ``img`` is a PIL image of the canvas.
            ``kwargs['img_mask']`` is a ``(size // 8, size // 8)`` tensor that
            is 1 where the canvas holds image content and 0 elsewhere.
            ``kwargs['mask']`` (only when a mask was passed in) is the pasted
            mask at the same reduced resolution.
        """
        has_mask = 'mask' in kwargs

        # step 1: resize so the short edge equals ``self.size``
        img = F.resize(img, size=self.size)
        if has_mask:
            kwargs['mask'] = F.resize(kwargs['mask'], size=self.size)

        # step 2: either crop a ``size`` x ``size`` window, or shrink the
        # whole image so that its long edge fits
        if random.random() < self.crop_p:
            if has_mask:
                img, kwargs = self.paired_random_crop(img, **kwargs)
            else:
                img = self.random_crop(img)
        else:
            # ``size - 1`` is used for the short edge so that ``max_size``
            # (which caps the long edge) can be ``self.size``.
            img = F.resize(img, size=self.size - 1, max_size=self.size)
            if has_mask:
                kwargs['mask'] = F.resize(kwargs['mask'], size=self.size - 1, max_size=self.size)

        # step 3: random uniform rescale (one factor for both axes)
        width, height = img.size
        ratio = random.uniform(*self.scale_ratio)
        scaled_hw = (int(height * ratio), int(width * ratio))

        img = F.resize(img, size=scaled_hw)
        if has_mask:
            # The mask is forced to the image's exact size; nearest-neighbour
            # is selected through the enum rather than the integer code 0.
            kwargs['mask'] = F.resize(kwargs['mask'], size=scaled_hw, interpolation=InterpolationMode.NEAREST)

        # step 4: random placement on the canvas. The bounds follow
        # ``self.size`` (previously a hard-coded 512) so that any canvas size
        # works.
        new_width, new_height = img.size

        img = np.array(img)
        if has_mask:
            kwargs['mask'] = np.array(kwargs['mask']) / 255

        start_y = random.randint(0, self.size - new_height)
        start_x = random.randint(0, self.size - new_width)

        res_img = np.zeros((self.size, self.size, 3), dtype=np.uint8)
        res_mask = np.zeros((self.size, self.size))
        res_img_mask = np.zeros((self.size, self.size))

        res_img[start_y:start_y + new_height, start_x:start_x + new_width, :] = img
        if has_mask:
            res_mask[start_y:start_y + new_height, start_x:start_x + new_width] = kwargs['mask']
            kwargs['mask'] = res_mask

        res_img_mask[start_y:start_y + new_height, start_x:start_x + new_width] = 1
        kwargs['img_mask'] = res_img_mask

        img = Image.fromarray(res_img)

        # Downsample the masks by a factor of 8 (presumably the latent
        # resolution of the downstream model — confirm). ``interpolation``
        # must be passed by keyword: the third positional parameter of
        # ``cv2.resize`` is ``dst``, so a positional flag is not applied.
        small_size = (self.size // 8, self.size // 8)
        if has_mask:
            kwargs['mask'] = cv2.resize(kwargs['mask'], small_size, interpolation=cv2.INTER_NEAREST)
            kwargs['mask'] = torch.from_numpy(kwargs['mask'])
        kwargs['img_mask'] = cv2.resize(kwargs['img_mask'], small_size, interpolation=cv2.INTER_NEAREST)
        kwargs['img_mask'] = torch.from_numpy(kwargs['img_mask'])

        return img, kwargs


@TRANSFORM_REGISTRY.register()
class ShuffleCaption(nn.Module):
    """Shuffle the comma-separated tokens of a caption, keeping a fixed prefix.

    Args:
        keep_token_num (int): Number of leading tokens left in place; values
            that are not greater than 0 leave every token free to move.
    """

    def __init__(self, keep_token_num):
        super().__init__()
        self.keep_token_num = keep_token_num

    def forward(self, img, **kwargs):
        """Rewrite ``kwargs['prompts']`` with its tokens shuffled.

        The caption is split on ``','``, each token is stripped, the tokens
        after the kept prefix are shuffled with ``random.shuffle``, and
        everything is re-joined with ``', '``.

        Returns:
            tuple: ``(img, kwargs)`` with ``img`` untouched.
        """
        caption = kwargs['prompts'].strip()
        tokens = [piece.strip() for piece in caption.split(',')]

        keep = self.keep_token_num if self.keep_token_num > 0 else 0
        head, tail = tokens[:keep], tokens[keep:]

        random.shuffle(tail)
        kwargs['prompts'] = ', '.join(head + tail)
        return img, kwargs


@TRANSFORM_REGISTRY.register()
class EnhanceText(nn.Module):
    """Wrap a concept token in a randomly chosen sentence template.

    Args:
        enhance_type (str): Which template set to use: ``'object'``,
            ``'style'`` or ``'human'``.

    Raises:
        NotImplementedError: If ``enhance_type`` is none of the above.
    """

    def __init__(self, enhance_type='object'):
        super().__init__()
        # Some templates occur more than once within a list; the duplicates
        # are kept because they weight the random choice.
        style_templates = [
            'a painting in the style of {}',
            'a rendering in the style of {}',
            'a cropped painting in the style of {}',
            'the painting in the style of {}',
            'a clean painting in the style of {}',
            'a dirty painting in the style of {}',
            'a dark painting in the style of {}',
            'a picture in the style of {}',
            'a cool painting in the style of {}',
            'a close-up painting in the style of {}',
            'a bright painting in the style of {}',
            'a cropped painting in the style of {}',
            'a good painting in the style of {}',
            'a close-up painting in the style of {}',
            'a rendition in the style of {}',
            'a nice painting in the style of {}',
            'a small painting in the style of {}',
            'a weird painting in the style of {}',
            'a large painting in the style of {}',
        ]

        object_templates = [
            'a photo of a {}',
            'a rendering of a {}',
            'a cropped photo of the {}',
            'the photo of a {}',
            'a photo of a clean {}',
            'a photo of a dirty {}',
            'a dark photo of the {}',
            'a photo of my {}',
            'a photo of the cool {}',
            'a close-up photo of a {}',
            'a bright photo of the {}',
            'a cropped photo of a {}',
            'a photo of the {}',
            'a good photo of the {}',
            'a photo of one {}',
            'a close-up photo of the {}',
            'a rendition of the {}',
            'a photo of the clean {}',
            'a rendition of a {}',
            'a photo of a nice {}',
            'a good photo of a {}',
            'a photo of the nice {}',
            'a photo of the small {}',
            'a photo of the weird {}',
            'a photo of the large {}',
            'a photo of a cool {}',
            'a photo of a small {}',
        ]

        human_templates = [
            'a photo of a {}',
            'a photo of one {}',
            'a photo of the {}',
            'the photo of a {}',
            'a rendering of a {}',
            'a rendition of the {}',
            'a rendition of a {}',
            'a cropped photo of the {}',
            'a cropped photo of a {}',
            'a bad photo of the {}',
            'a bad photo of a {}',
            'a photo of a weird {}',
            'a weird photo of a {}',
            'a bright photo of the {}',
            'a good photo of the {}',
            'a photo of a nice {}',
            'a good photo of a {}',
            'a photo of a cool {}',
            'a bright photo of the {}',
        ]

        candidates = (
            ('object', object_templates),
            ('style', style_templates),
            ('human', human_templates),
        )
        for type_name, templates in candidates:
            if enhance_type == type_name:
                self.templates = templates
                break
        else:
            raise NotImplementedError

    def forward(self, img, **kwargs):
        """Replace ``kwargs['prompts']`` with a template filled by its stripped value.

        Returns:
            tuple: ``(img, kwargs)`` with ``img`` untouched.
        """
        token = kwargs['prompts'].strip()
        template = random.choice(self.templates)
        kwargs['prompts'] = template.format(token)
        return img, kwargs