WindVChen commited on
Commit
033bd8b
Β·
1 Parent(s): 6710c89
This view is limited to 50 files because it contains too many changes. Β  See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -0
  2. assets/demo.gif +3 -0
  3. assets/metrics.png +0 -0
  4. assets/network.png +0 -0
  5. assets/title_any_image.gif +0 -0
  6. assets/title_harmon.gif +0 -0
  7. assets/title_you_want.gif +0 -0
  8. assets/visualizations.png +0 -0
  9. assets/visualizations2.png +3 -0
  10. datasets/__init__.py +0 -0
  11. datasets/__pycache__/__init__.cpython-38.pyc +0 -0
  12. datasets/__pycache__/build_INR_dataset.cpython-38.pyc +0 -0
  13. datasets/__pycache__/build_dataset.cpython-38.pyc +0 -0
  14. datasets/build_INR_dataset.py +36 -0
  15. datasets/build_dataset.py +371 -0
  16. demo/demo_2k_composite.jpg +0 -0
  17. demo/demo_2k_mask.jpg +0 -0
  18. demo/demo_2k_real.jpg +0 -0
  19. demo/demo_6k_composite.jpg +3 -0
  20. demo/demo_6k_mask.jpg +0 -0
  21. demo/demo_6k_real.jpg +3 -0
  22. model/__init__.py +0 -0
  23. model/__pycache__/__init__.cpython-38.pyc +0 -0
  24. model/__pycache__/backbone.cpython-38.pyc +0 -0
  25. model/__pycache__/build_model.cpython-38.pyc +0 -0
  26. model/__pycache__/lut_transformation_net.cpython-38.pyc +0 -0
  27. model/backbone.py +79 -0
  28. model/base/__init__.py +0 -0
  29. model/base/__pycache__/__init__.cpython-38.pyc +0 -0
  30. model/base/__pycache__/basic_blocks.cpython-38.pyc +0 -0
  31. model/base/__pycache__/conv_autoencoder.cpython-38.pyc +0 -0
  32. model/base/__pycache__/ih_model.cpython-38.pyc +0 -0
  33. model/base/__pycache__/ops.cpython-38.pyc +0 -0
  34. model/base/basic_blocks.py +366 -0
  35. model/base/conv_autoencoder.py +519 -0
  36. model/base/ih_model.py +88 -0
  37. model/base/ops.py +397 -0
  38. model/build_model.py +24 -0
  39. model/hrnetv2/__init__.py +0 -0
  40. model/hrnetv2/__pycache__/__init__.cpython-38.pyc +0 -0
  41. model/hrnetv2/__pycache__/hrnet_ocr.cpython-38.pyc +0 -0
  42. model/hrnetv2/__pycache__/modifiers.cpython-38.pyc +0 -0
  43. model/hrnetv2/__pycache__/ocr.cpython-38.pyc +0 -0
  44. model/hrnetv2/__pycache__/resnetv1b.cpython-38.pyc +0 -0
  45. model/hrnetv2/hrnet_ocr.py +400 -0
  46. model/hrnetv2/modifiers.py +11 -0
  47. model/hrnetv2/ocr.py +140 -0
  48. model/hrnetv2/resnetv1b.py +276 -0
  49. model/lut_transformation_net.py +65 -0
  50. pretrained_models/Resolution_1024_HAdobe5K.pth +3 -0
.gitattributes CHANGED
@@ -33,3 +33,7 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/demo.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/visualizations2.png filter=lfs diff=lfs merge=lfs -text
38
+ demo/demo_6k_composite.jpg filter=lfs diff=lfs merge=lfs -text
39
+ demo/demo_6k_real.jpg filter=lfs diff=lfs merge=lfs -text
assets/demo.gif ADDED

Git LFS Details

  • SHA256: c5f136d5335252050ca723e0360a767ebc5d94fd87d6d372221575769d6528a7
  • Pointer size: 132 Bytes
  • Size of remote file: 1.73 MB
assets/metrics.png ADDED
assets/network.png ADDED
assets/title_any_image.gif ADDED
assets/title_harmon.gif ADDED
assets/title_you_want.gif ADDED
assets/visualizations.png ADDED
assets/visualizations2.png ADDED

Git LFS Details

  • SHA256: 0fa5f4c202818ab94d6faf57055a323285e169a33ccfd59200bc93a8d597a4a4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.67 MB
datasets/__init__.py ADDED
File without changes
datasets/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (176 Bytes). View file
 
datasets/__pycache__/build_INR_dataset.cpython-38.pyc ADDED
Binary file (1.31 kB). View file
 
datasets/__pycache__/build_dataset.cpython-38.pyc ADDED
Binary file (6.96 kB). View file
 
datasets/build_INR_dataset.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils import misc
2
+ from albumentations import Resize
3
+
4
+
5
+ class Implicit2DGenerator(object):
6
+ def __init__(self, opt, mode):
7
+ if mode == 'Train':
8
+ sidelength = opt.INR_input_size
9
+ elif mode == 'Val':
10
+ sidelength = opt.input_size
11
+ else:
12
+ raise NotImplementedError
13
+
14
+ self.mode = mode
15
+
16
+ self.size = sidelength
17
+
18
+ if isinstance(sidelength, int):
19
+ sidelength = (sidelength, sidelength)
20
+
21
+ self.mgrid = misc.get_mgrid(sidelength)
22
+
23
+ self.transform = Resize(self.size, self.size)
24
+
25
+ def generator(self, torch_transforms, composite_image, real_image, mask):
26
+ composite_image = torch_transforms(self.transform(image=composite_image)['image'])
27
+ real_image = torch_transforms(self.transform(image=real_image)['image'])
28
+
29
+ fg_INR_RGB = composite_image.permute(1, 2, 0).contiguous().view(-1, 3)
30
+ fg_transfer_INR_RGB = real_image.permute(1, 2, 0).contiguous().view(-1, 3)
31
+ bg_INR_RGB = real_image.permute(1, 2, 0).contiguous().view(-1, 3)
32
+
33
+ fg_INR_coordinates = self.mgrid
34
+ bg_INR_coordinates = self.mgrid
35
+
36
+ return fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB
datasets/build_dataset.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import cv2
3
+ import numpy as np
4
+ import torchvision
5
+ import os
6
+ import random
7
+
8
+ from utils.misc import prepare_cooridinate_input, customRandomCrop
9
+
10
+ from datasets.build_INR_dataset import Implicit2DGenerator
11
+ import albumentations
12
+ from albumentations import Resize, RandomResizedCrop, HorizontalFlip
13
+ from torch.utils.data import DataLoader
14
+
15
+
16
+ class dataset_generator(torch.utils.data.Dataset):
17
+ def __init__(self, dataset_txt, alb_transforms, torch_transforms, opt, area_keep_thresh=0.2, mode='Train'):
18
+ super().__init__()
19
+
20
+ self.opt = opt
21
+ self.root_path = opt.dataset_path
22
+ self.mode = mode
23
+
24
+ self.alb_transforms = alb_transforms
25
+ self.torch_transforms = torch_transforms
26
+ self.kp_t = area_keep_thresh
27
+
28
+ with open(dataset_txt, 'r') as f:
29
+ self.dataset_samples = [os.path.join(self.root_path, x.strip()) for x in f.readlines()]
30
+
31
+ self.INR_dataset = Implicit2DGenerator(opt, self.mode)
32
+
33
+ def __len__(self):
34
+ return len(self.dataset_samples)
35
+
36
+ def __getitem__(self, idx):
37
+ composite_image = self.dataset_samples[idx]
38
+
39
+ if self.opt.hr_train:
40
+ if self.opt.isFullRes:
41
+ "Since in dataset preprocessing, we resize the image in HAdobe5k to a lower resolution for " \
42
+ "quick loading, we need to change the path here to that of the original resolution of HAdobe5k " \
43
+ "if `opt.isFullRes` is set to True."
44
+ composite_image = composite_image.replace("HAdobe5k", "HAdobe5kori")
45
+
46
+ real_image = '_'.join(composite_image.split('_')[:2]).replace("composite_images", "real_images") + '.jpg'
47
+ mask = '_'.join(composite_image.split('_')[:-1]).replace("composite_images", "masks") + '.png'
48
+
49
+ composite_image = cv2.imread(composite_image)
50
+ composite_image = cv2.cvtColor(composite_image, cv2.COLOR_BGR2RGB)
51
+
52
+ real_image = cv2.imread(real_image)
53
+ real_image = cv2.cvtColor(real_image, cv2.COLOR_BGR2RGB)
54
+
55
+ mask = cv2.imread(mask)
56
+ mask = mask[:, :, 0].astype(np.float32) / 255.
57
+
58
+ """
59
+ If set `opt.hr_train` to True:
60
+
61
+ Apply multi resolution crop for HR image train. Specifically, for 1024/2048 `input_size` (not fullres),
62
+ the training phase is first to RandomResizeCrop 1024/2048 `input_size`, then to random crop a `base_size`
63
+ patch to feed in multiINR process. For inference, just resize it.
64
+
65
+ While for fullres, the RandomResizeCrop is removed and just do a random crop. For inference, just keep the size.
66
+
67
+ BTW, we implement LR and HR mixing train. I.e., the following `random.random() < 0.5`
68
+ """
69
+ if self.opt.hr_train:
70
+ if self.mode == 'Train' and self.opt.isFullRes:
71
+ if random.random() < 0.5: # LR mix training
72
+ mixTransform = albumentations.Compose(
73
+ [
74
+ RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
75
+ HorizontalFlip()],
76
+ additional_targets={'real_image': 'image', 'object_mask': 'image'}
77
+ )
78
+ origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
79
+ origin_bg_ratio = 1 - origin_fg_ratio
80
+
81
+ "Ensure fg and bg not disappear after transformation"
82
+ valid_augmentation = False
83
+ transform_out = None
84
+ time = 0
85
+ while not valid_augmentation:
86
+ time += 1
87
+ # There are some extreme ratio pics, this code is to avoid being hindered by them.
88
+ if time == 20:
89
+ tmp_transform = albumentations.Compose(
90
+ [Resize(self.opt.base_size, self.opt.base_size)],
91
+ additional_targets={'real_image': 'image',
92
+ 'object_mask': 'image'})
93
+ transform_out = tmp_transform(image=composite_image, real_image=real_image,
94
+ object_mask=mask)
95
+ valid_augmentation = True
96
+ else:
97
+ transform_out = mixTransform(image=composite_image, real_image=real_image,
98
+ object_mask=mask)
99
+ valid_augmentation = check_augmented_sample(transform_out['object_mask'],
100
+ origin_fg_ratio,
101
+ origin_bg_ratio,
102
+ self.kp_t)
103
+ composite_image = transform_out['image']
104
+ real_image = transform_out['real_image']
105
+ mask = transform_out['object_mask']
106
+ else: # Padding to ensure that the original resolution can be divided by 4. This is for pixel-aligned crop.
107
+ if real_image.shape[0] < 256:
108
+ bottom_pad = 256 - real_image.shape[0]
109
+ else:
110
+ bottom_pad = (4 - real_image.shape[0] % 4) % 4
111
+ if real_image.shape[1] < 256:
112
+ right_pad = 256 - real_image.shape[1]
113
+ else:
114
+ right_pad = (4 - real_image.shape[1] % 4) % 4
115
+ composite_image = cv2.copyMakeBorder(composite_image, 0, bottom_pad, 0, right_pad,
116
+ cv2.BORDER_REPLICATE)
117
+ real_image = cv2.copyMakeBorder(real_image, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
118
+ mask = cv2.copyMakeBorder(mask, 0, bottom_pad, 0, right_pad, cv2.BORDER_REPLICATE)
119
+
120
+ origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
121
+ origin_bg_ratio = 1 - origin_fg_ratio
122
+
123
+ "Ensure fg and bg not disappear after transformation"
124
+ valid_augmentation = False
125
+ transform_out = None
126
+ time = 0
127
+
128
+ if self.opt.hr_train:
129
+ if self.mode == 'Train':
130
+ if not self.opt.isFullRes:
131
+ if random.random() < 0.5: # LR mix training
132
+ mixTransform = albumentations.Compose(
133
+ [
134
+ RandomResizedCrop(self.opt.base_size, self.opt.base_size, scale=(0.5, 1.0)),
135
+ HorizontalFlip()],
136
+ additional_targets={'real_image': 'image', 'object_mask': 'image'}
137
+ )
138
+ while not valid_augmentation:
139
+ time += 1
140
+ # There are some extreme ratio pics, this code is to avoid being hindered by them.
141
+ if time == 20:
142
+ tmp_transform = albumentations.Compose(
143
+ [Resize(self.opt.base_size, self.opt.base_size)],
144
+ additional_targets={'real_image': 'image',
145
+ 'object_mask': 'image'})
146
+ transform_out = tmp_transform(image=composite_image, real_image=real_image,
147
+ object_mask=mask)
148
+ valid_augmentation = True
149
+ else:
150
+ transform_out = mixTransform(image=composite_image, real_image=real_image,
151
+ object_mask=mask)
152
+ valid_augmentation = check_augmented_sample(transform_out['object_mask'],
153
+ origin_fg_ratio,
154
+ origin_bg_ratio,
155
+ self.kp_t)
156
+ else:
157
+ while not valid_augmentation:
158
+ time += 1
159
+ # There are some extreme ratio pics, this code is to avoid being hindered by them.
160
+ if time == 20:
161
+ tmp_transform = albumentations.Compose(
162
+ [Resize(self.opt.input_size, self.opt.input_size)],
163
+ additional_targets={'real_image': 'image',
164
+ 'object_mask': 'image'})
165
+ transform_out = tmp_transform(image=composite_image, real_image=real_image,
166
+ object_mask=mask)
167
+ valid_augmentation = True
168
+ else:
169
+ transform_out = self.alb_transforms(image=composite_image, real_image=real_image,
170
+ object_mask=mask)
171
+ valid_augmentation = check_augmented_sample(transform_out['object_mask'],
172
+ origin_fg_ratio,
173
+ origin_bg_ratio,
174
+ self.kp_t)
175
+ composite_image = transform_out['image']
176
+ real_image = transform_out['real_image']
177
+ mask = transform_out['object_mask']
178
+
179
+ origin_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
180
+
181
+ full_coord = prepare_cooridinate_input(mask).transpose(1, 2, 0)
182
+
183
+ tmp_transform = albumentations.Compose([Resize(self.opt.base_size, self.opt.base_size)],
184
+ additional_targets={'real_image': 'image',
185
+ 'object_mask': 'image'})
186
+ transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
187
+ compos_list = [self.torch_transforms(transform_out['image'])]
188
+ real_list = [self.torch_transforms(transform_out['real_image'])]
189
+ mask_list = [
190
+ torchvision.transforms.ToTensor()(transform_out['object_mask'][..., np.newaxis].astype(np.float32))]
191
+ coord_map_list = []
192
+
193
+ valid_augmentation = False
194
+ while not valid_augmentation:
195
+ # RSC strategy. To crop different resolutions.
196
+ transform_out, c_h, c_w = customRandomCrop([composite_image, real_image, mask, full_coord],
197
+ self.opt.base_size, self.opt.base_size)
198
+ valid_augmentation = check_hr_crop_sample(transform_out[2], origin_fg_ratio)
199
+
200
+ compos_list.append(self.torch_transforms(transform_out[0]))
201
+ real_list.append(self.torch_transforms(transform_out[1]))
202
+ mask_list.append(
203
+ torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32)))
204
+ coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
205
+ coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
206
+ for n in range(2):
207
+ tmp_comp = cv2.resize(composite_image, (
208
+ composite_image.shape[1] // 2 ** (n + 1), composite_image.shape[0] // 2 ** (n + 1)))
209
+ tmp_real = cv2.resize(real_image,
210
+ (real_image.shape[1] // 2 ** (n + 1), real_image.shape[0] // 2 ** (n + 1)))
211
+ tmp_mask = cv2.resize(mask, (mask.shape[1] // 2 ** (n + 1), mask.shape[0] // 2 ** (n + 1)))
212
+ tmp_coord = prepare_cooridinate_input(tmp_mask).transpose(1, 2, 0)
213
+
214
+ transform_out, c_h, c_w = customRandomCrop([tmp_comp, tmp_real, tmp_mask, tmp_coord],
215
+ self.opt.base_size // 2 ** (n + 1),
216
+ self.opt.base_size // 2 ** (n + 1), c_h, c_w)
217
+ compos_list.append(self.torch_transforms(transform_out[0]))
218
+ real_list.append(self.torch_transforms(transform_out[1]))
219
+ mask_list.append(
220
+ torchvision.transforms.ToTensor()(transform_out[2][..., np.newaxis].astype(np.float32)))
221
+ coord_map_list.append(torchvision.transforms.ToTensor()(transform_out[3]))
222
+ out_comp = compos_list
223
+ out_real = real_list
224
+ out_mask = mask_list
225
+ out_coord = coord_map_list
226
+
227
+ fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
228
+ self.torch_transforms, transform_out[0], transform_out[1], mask)
229
+
230
+ return {
231
+ 'file_path': self.dataset_samples[idx],
232
+ 'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
233
+ 'composite_image': out_comp,
234
+ 'real_image': out_real,
235
+ 'mask': out_mask,
236
+ 'coordinate_map': out_coord,
237
+ 'composite_image0': out_comp[0],
238
+ 'real_image0': out_real[0],
239
+ 'mask0': out_mask[0],
240
+ 'coordinate_map0': out_coord[0],
241
+ 'composite_image1': out_comp[1],
242
+ 'real_image1': out_real[1],
243
+ 'mask1': out_mask[1],
244
+ 'coordinate_map1': out_coord[1],
245
+ 'composite_image2': out_comp[2],
246
+ 'real_image2': out_real[2],
247
+ 'mask2': out_mask[2],
248
+ 'coordinate_map2': out_coord[2],
249
+ 'composite_image3': out_comp[3],
250
+ 'real_image3': out_real[3],
251
+ 'mask3': out_mask[3],
252
+ 'coordinate_map3': out_coord[3],
253
+ 'fg_INR_coordinates': fg_INR_coordinates,
254
+ 'bg_INR_coordinates': bg_INR_coordinates,
255
+ 'fg_INR_RGB': fg_INR_RGB,
256
+ 'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
257
+ 'bg_INR_RGB': bg_INR_RGB
258
+ }
259
+ else:
260
+ if not self.opt.isFullRes:
261
+ tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
262
+ additional_targets={'real_image': 'image',
263
+ 'object_mask': 'image'})
264
+ transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
265
+
266
+ coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])
267
+
268
+ "Generate INR dataset."
269
+ mask = (torchvision.transforms.ToTensor()(
270
+ transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
271
+ mask = np.bool_(mask.numpy())
272
+
273
+ fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
274
+ self.torch_transforms, transform_out['image'], transform_out['real_image'], mask)
275
+
276
+ return {
277
+ 'file_path': self.dataset_samples[idx],
278
+ 'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
279
+ 'composite_image': self.torch_transforms(transform_out['image']),
280
+ 'real_image': self.torch_transforms(transform_out['real_image']),
281
+ 'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
282
+ # Can automatically transfer to Tensor.
283
+ 'coordinate_map': coordinate_map,
284
+ 'fg_INR_coordinates': fg_INR_coordinates,
285
+ 'bg_INR_coordinates': bg_INR_coordinates,
286
+ 'fg_INR_RGB': fg_INR_RGB,
287
+ 'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
288
+ 'bg_INR_RGB': bg_INR_RGB
289
+ }
290
+ else:
291
+ coordinate_map = prepare_cooridinate_input(mask)
292
+
293
+ "Generate INR dataset."
294
+ mask_tmp = (torchvision.transforms.ToTensor()(mask).squeeze() > 100 / 255.).view(-1)
295
+ mask_tmp = np.bool_(mask_tmp.numpy())
296
+
297
+ fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
298
+ self.torch_transforms, composite_image, real_image, mask_tmp)
299
+
300
+ return {
301
+ 'file_path': self.dataset_samples[idx],
302
+ 'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
303
+ 'composite_image': self.torch_transforms(composite_image),
304
+ 'real_image': self.torch_transforms(real_image),
305
+ 'mask': mask[np.newaxis, ...].astype(np.float32),
306
+ # Can automatically transfer to Tensor.
307
+ 'coordinate_map': coordinate_map,
308
+ 'fg_INR_coordinates': fg_INR_coordinates,
309
+ 'bg_INR_coordinates': bg_INR_coordinates,
310
+ 'fg_INR_RGB': fg_INR_RGB,
311
+ 'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
312
+ 'bg_INR_RGB': bg_INR_RGB
313
+ }
314
+
315
+ while not valid_augmentation:
316
+ time += 1
317
+ # There are some extreme ratio pics, this code is to avoid being hindered by them.
318
+ if time == 20:
319
+ tmp_transform = albumentations.Compose([Resize(self.opt.input_size, self.opt.input_size)],
320
+ additional_targets={'real_image': 'image',
321
+ 'object_mask': 'image'})
322
+ transform_out = tmp_transform(image=composite_image, real_image=real_image, object_mask=mask)
323
+ valid_augmentation = True
324
+ else:
325
+ transform_out = self.alb_transforms(image=composite_image, real_image=real_image, object_mask=mask)
326
+ valid_augmentation = check_augmented_sample(transform_out['object_mask'], origin_fg_ratio,
327
+ origin_bg_ratio,
328
+ self.kp_t)
329
+
330
+ coordinate_map = prepare_cooridinate_input(transform_out['object_mask'])
331
+
332
+ "Generate INR dataset."
333
+ mask = (torchvision.transforms.ToTensor()(transform_out['object_mask']).squeeze() > 100 / 255.).view(-1)
334
+ mask = np.bool_(mask.numpy())
335
+
336
+ fg_INR_coordinates, bg_INR_coordinates, fg_INR_RGB, fg_transfer_INR_RGB, bg_INR_RGB = self.INR_dataset.generator(
337
+ self.torch_transforms, transform_out['image'], transform_out['real_image'], mask)
338
+
339
+ return {
340
+ 'file_path': self.dataset_samples[idx],
341
+ 'category': self.dataset_samples[idx].split("\\")[-1].split("/")[0],
342
+ 'composite_image': self.torch_transforms(transform_out['image']),
343
+ 'real_image': self.torch_transforms(transform_out['real_image']),
344
+ 'mask': transform_out['object_mask'][np.newaxis, ...].astype(np.float32),
345
+ # Can automatically transfer to Tensor.
346
+ 'coordinate_map': coordinate_map,
347
+ 'fg_INR_coordinates': fg_INR_coordinates,
348
+ 'bg_INR_coordinates': bg_INR_coordinates,
349
+ 'fg_INR_RGB': fg_INR_RGB,
350
+ 'fg_transfer_INR_RGB': fg_transfer_INR_RGB,
351
+ 'bg_INR_RGB': bg_INR_RGB
352
+ }
353
+
354
+
355
+ def check_augmented_sample(mask, origin_fg_ratio, origin_bg_ratio, area_keep_thresh):
356
+ current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
357
+ current_bg_ratio = 1 - current_fg_ratio
358
+
359
+ if current_fg_ratio < origin_fg_ratio * area_keep_thresh or current_bg_ratio < origin_bg_ratio * area_keep_thresh:
360
+ return False
361
+
362
+ return True
363
+
364
+
365
+ def check_hr_crop_sample(mask, origin_fg_ratio):
366
+ current_fg_ratio = mask.sum() / (mask.shape[0] * mask.shape[1])
367
+
368
+ if current_fg_ratio < 0.8 * origin_fg_ratio:
369
+ return False
370
+
371
+ return True
demo/demo_2k_composite.jpg ADDED
demo/demo_2k_mask.jpg ADDED
demo/demo_2k_real.jpg ADDED
demo/demo_6k_composite.jpg ADDED

Git LFS Details

  • SHA256: 910f8a9787c7b2dd739c89a56f2cd64fa67be9a257ea17963b656f63e1ad2250
  • Pointer size: 132 Bytes
  • Size of remote file: 5.88 MB
demo/demo_6k_mask.jpg ADDED
demo/demo_6k_real.jpg ADDED

Git LFS Details

  • SHA256: 5dd69a2a79388378e43079a0ca0dbb7e3e9c86822526083d54f315a3f1a48647
  • Pointer size: 132 Bytes
  • Size of remote file: 6.1 MB
model/__init__.py ADDED
File without changes
model/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (173 Bytes). View file
 
model/__pycache__/backbone.cpython-38.pyc ADDED
Binary file (2.96 kB). View file
 
model/__pycache__/build_model.cpython-38.pyc ADDED
Binary file (1.03 kB). View file
 
model/__pycache__/lut_transformation_net.cpython-38.pyc ADDED
Binary file (2.43 kB). View file
 
model/backbone.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ from .hrnetv2.hrnet_ocr import HighResolutionNet
4
+ from .hrnetv2.modifiers import LRMult
5
+ from .base.basic_blocks import MaxPoolDownSize
6
+ from .base.ih_model import IHModelWithBackbone, DeepImageHarmonization
7
+
8
+
9
+ def build_backbone(name, opt):
10
+ return eval(name)(opt)
11
+
12
+
13
+ class baseline(IHModelWithBackbone):
14
+ def __init__(self, opt, ocr=64):
15
+ base_config = {'model': DeepImageHarmonization,
16
+ 'params': {'depth': 7, 'batchnorm_from': 2, 'image_fusion': True, 'opt': opt}}
17
+
18
+ params = base_config['params']
19
+
20
+ backbone = HRNetV2(opt, ocr=ocr)
21
+
22
+ params.update(dict(
23
+ backbone_from=2,
24
+ backbone_channels=backbone.output_channels,
25
+ backbone_mode='cat',
26
+ opt=opt
27
+ ))
28
+ base_model = base_config['model'](**params)
29
+
30
+ super(baseline, self).__init__(base_model, backbone, False, 'sum', opt=opt)
31
+
32
+
33
+ class HRNetV2(nn.Module):
34
+ def __init__(
35
+ self, opt,
36
+ cat_outputs=True,
37
+ pyramid_channels=-1, pyramid_depth=4,
38
+ width=18, ocr=128, small=False,
39
+ lr_mult=0.1, pretained=True
40
+ ):
41
+ super(HRNetV2, self).__init__()
42
+ self.opt = opt
43
+ self.cat_outputs = cat_outputs
44
+ self.ocr_on = ocr > 0 and cat_outputs
45
+ self.pyramid_on = pyramid_channels > 0 and cat_outputs
46
+
47
+ self.hrnet = HighResolutionNet(width, 2, ocr_width=ocr, small=small, opt=opt)
48
+ self.hrnet.apply(LRMult(lr_mult))
49
+ if self.ocr_on:
50
+ self.hrnet.ocr_distri_head.apply(LRMult(1.0))
51
+ self.hrnet.ocr_gather_head.apply(LRMult(1.0))
52
+ self.hrnet.conv3x3_ocr.apply(LRMult(1.0))
53
+
54
+ hrnet_cat_channels = [width * 2 ** i for i in range(4)]
55
+ if self.pyramid_on:
56
+ self.output_channels = [pyramid_channels] * 4
57
+ elif self.ocr_on:
58
+ self.output_channels = [ocr * 2]
59
+ elif self.cat_outputs:
60
+ self.output_channels = [sum(hrnet_cat_channels)]
61
+ else:
62
+ self.output_channels = hrnet_cat_channels
63
+
64
+ if self.pyramid_on:
65
+ downsize_in_channels = ocr * 2 if self.ocr_on else sum(hrnet_cat_channels)
66
+ self.downsize = MaxPoolDownSize(downsize_in_channels, pyramid_channels, pyramid_channels, pyramid_depth)
67
+
68
+ if pretained:
69
+ self.load_pretrained_weights(
70
+ r".\pretrained_models/hrnetv2_w18_imagenet_pretrained.pth")
71
+
72
+ self.output_resolution = (opt.input_size // 8) ** 2
73
+
74
+ def forward(self, image, mask, mask_features=None):
75
+ outputs = list(self.hrnet(image, mask, mask_features))
76
+ return outputs
77
+
78
+ def load_pretrained_weights(self, pretrained_path):
79
+ self.hrnet.load_pretrained_weights(pretrained_path)
model/base/__init__.py ADDED
File without changes
model/base/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (178 Bytes). View file
 
model/base/__pycache__/basic_blocks.cpython-38.pyc ADDED
Binary file (10.1 kB). View file
 
model/base/__pycache__/conv_autoencoder.cpython-38.pyc ADDED
Binary file (13.8 kB). View file
 
model/base/__pycache__/ih_model.cpython-38.pyc ADDED
Binary file (3.22 kB). View file
 
model/base/__pycache__/ops.cpython-38.pyc ADDED
Binary file (14 kB). View file
 
model/base/basic_blocks.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn as nn
3
+ import numpy as np
4
+
5
+
6
+ def hyper_weight_init(m, in_features_main_net, activation):
7
+ if hasattr(m, 'weight'):
8
+ nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
9
+ m.weight.data = m.weight.data / 1.e2
10
+
11
+ if hasattr(m, 'bias'):
12
+ with torch.no_grad():
13
+ if activation == 'sine':
14
+ m.bias.uniform_(-np.sqrt(6 / in_features_main_net) / 30, np.sqrt(6 / in_features_main_net) / 30)
15
+ elif activation == 'leakyrelu_pe':
16
+ m.bias.uniform_(-np.sqrt(6 / in_features_main_net), np.sqrt(6 / in_features_main_net))
17
+ else:
18
+ raise NotImplementedError
19
+
20
+
21
+ class ConvBlock(nn.Module):
22
+ def __init__(
23
+ self,
24
+ in_channels, out_channels,
25
+ kernel_size=4, stride=2, padding=1,
26
+ norm_layer=nn.BatchNorm2d, activation=nn.ELU,
27
+ bias=True,
28
+ ):
29
+ super(ConvBlock, self).__init__()
30
+ self.block = nn.Sequential(
31
+ nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias),
32
+ norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
33
+ activation(),
34
+ )
35
+
36
+ def forward(self, x):
37
+ return self.block(x)
38
+
39
+
40
+ class MaxPoolDownSize(nn.Module):
41
+ def __init__(self, in_channels, mid_channels, out_channels, depth):
42
+ super(MaxPoolDownSize, self).__init__()
43
+ self.depth = depth
44
+ self.reduce_conv = ConvBlock(in_channels, mid_channels, kernel_size=1, stride=1, padding=0)
45
+ self.convs = nn.ModuleList([
46
+ ConvBlock(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
47
+ for conv_i in range(depth)
48
+ ])
49
+ self.pool2d = nn.MaxPool2d(kernel_size=2)
50
+
51
+ def forward(self, x):
52
+ outputs = []
53
+
54
+ output = self.reduce_conv(x)
55
+
56
+ for conv_i, conv in enumerate(self.convs):
57
+ output = output if conv_i == 0 else self.pool2d(output)
58
+ outputs.append(conv(output))
59
+
60
+ return outputs
61
+
62
+
63
+ class convParams(nn.Module):
64
+ def __init__(self, input_dim, INR_in_out, opt, hidden_mlp_num, hidden_dim=512, toRGB=False):
65
+ super(convParams, self).__init__()
66
+ self.INR_in_out = INR_in_out
67
+ self.cont_split_weight = []
68
+ self.cont_split_bias = []
69
+ self.hidden_mlp_num = hidden_mlp_num
70
+ self.param_factorize_dim = opt.param_factorize_dim
71
+ output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num, toRGB)
72
+ self.output_dim = output_dim
73
+ self.toRGB = toRGB
74
+ self.cont_extraction_net = nn.Sequential(
75
+ nn.Conv2d(input_dim, hidden_dim, kernel_size=3, stride=2, padding=1, bias=False),
76
+ # nn.BatchNorm2d(hidden_dim),
77
+ nn.ReLU(inplace=True),
78
+ nn.Conv2d(hidden_dim, hidden_dim, kernel_size=3, stride=1, padding=1, bias=False),
79
+ # nn.BatchNorm2d(hidden_dim),
80
+ nn.ReLU(inplace=True),
81
+ nn.Conv2d(hidden_dim, output_dim, kernel_size=1, stride=1, padding=0, bias=True),
82
+ )
83
+
84
+ self.cont_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation))
85
+
86
+ self.basic_params = nn.ParameterList()
87
+ if opt.param_factorize_dim > 0:
88
+ for id in range(self.hidden_mlp_num + 1):
89
+ if id == 0:
90
+ inp, outp = self.INR_in_out[0], self.INR_in_out[1]
91
+ else:
92
+ inp, outp = self.INR_in_out[1], self.INR_in_out[1]
93
+ self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, inp, outp)))
94
+
95
+ if toRGB:
96
+ self.basic_params.append(nn.Parameter(torch.randn(1, 1, 1, self.INR_in_out[1], 3)))
97
+
98
+ def forward(self, feat, outMore=False):
99
+ cont_params = self.cont_extraction_net(feat)
100
+ out_mlp = self.to_mlp(cont_params)
101
+ if outMore:
102
+ return out_mlp, cont_params
103
+ return out_mlp
104
+
105
+ def cal_params_num(self, INR_in_out, hidden_mlp_num, toRGB=False):
106
+ cont_params = 0
107
+ start = 0
108
+ if self.param_factorize_dim == -1:
109
+ cont_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1]
110
+ self.cont_split_weight.append([start, cont_params - INR_in_out[1]])
111
+ self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
112
+ start = cont_params
113
+
114
+ for id in range(hidden_mlp_num):
115
+ cont_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1]
116
+ self.cont_split_weight.append([start, cont_params - INR_in_out[1]])
117
+ self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
118
+ start = cont_params
119
+
120
+ if toRGB:
121
+ cont_params += INR_in_out[1] * 3 + 3
122
+ self.cont_split_weight.append([start, cont_params - 3])
123
+ self.cont_split_bias.append([cont_params - 3, cont_params])
124
+
125
+ elif self.param_factorize_dim > 0:
126
+ cont_params += INR_in_out[0] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
127
+ INR_in_out[1]
128
+ self.cont_split_weight.append(
129
+ [start, start + INR_in_out[0] * self.param_factorize_dim, cont_params - INR_in_out[1]])
130
+ self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
131
+ start = cont_params
132
+
133
+ for id in range(hidden_mlp_num):
134
+ cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
135
+ INR_in_out[1]
136
+ self.cont_split_weight.append(
137
+ [start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - INR_in_out[1]])
138
+ self.cont_split_bias.append([cont_params - INR_in_out[1], cont_params])
139
+ start = cont_params
140
+
141
+ if toRGB:
142
+ cont_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3
143
+ self.cont_split_weight.append(
144
+ [start, start + INR_in_out[1] * self.param_factorize_dim, cont_params - 3])
145
+ self.cont_split_bias.append([cont_params - 3, cont_params])
146
+
147
+ return cont_params
148
+
149
+ def to_mlp(self, params):
150
+ all_weight_bias = []
151
+ if self.param_factorize_dim == -1:
152
+ for id in range(self.hidden_mlp_num + 1):
153
+ if id == 0:
154
+ inp, outp = self.INR_in_out[0], self.INR_in_out[1]
155
+ else:
156
+ inp, outp = self.INR_in_out[1], self.INR_in_out[1]
157
+ weight = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :]
158
+ weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:],
159
+ inp, outp)
160
+
161
+ bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :]
162
+ bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
163
+ all_weight_bias.append([weight, bias])
164
+
165
+ if self.toRGB:
166
+ inp, outp = self.INR_in_out[1], 3
167
+ weight = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :]
168
+ weight = weight.permute(0, 2, 3, 1).contiguous().view(weight.shape[0], *weight.shape[2:],
169
+ inp, outp)
170
+
171
+ bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :]
172
+ bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
173
+ all_weight_bias.append([weight, bias])
174
+
175
+ return all_weight_bias
176
+
177
+ else:
178
+ for id in range(self.hidden_mlp_num + 1):
179
+ if id == 0:
180
+ inp, outp = self.INR_in_out[0], self.INR_in_out[1]
181
+ else:
182
+ inp, outp = self.INR_in_out[1], self.INR_in_out[1]
183
+ weight1 = params[:, self.cont_split_weight[id][0]:self.cont_split_weight[id][1], :, :]
184
+ weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(weight1.shape[0], *weight1.shape[2:],
185
+ inp, self.param_factorize_dim)
186
+
187
+ weight2 = params[:, self.cont_split_weight[id][1]:self.cont_split_weight[id][2], :, :]
188
+ weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(weight2.shape[0], *weight2.shape[2:],
189
+ self.param_factorize_dim, outp)
190
+
191
+ bias = params[:, self.cont_split_bias[id][0]:self.cont_split_bias[id][1], :, :]
192
+ bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
193
+
194
+ all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])
195
+
196
+ if self.toRGB:
197
+ inp, outp = self.INR_in_out[1], 3
198
+ weight1 = params[:, self.cont_split_weight[-1][0]:self.cont_split_weight[-1][1], :, :]
199
+ weight1 = weight1.permute(0, 2, 3, 1).contiguous().view(weight1.shape[0], *weight1.shape[2:],
200
+ inp, self.param_factorize_dim)
201
+
202
+ weight2 = params[:, self.cont_split_weight[-1][1]:self.cont_split_weight[-1][2], :, :]
203
+ weight2 = weight2.permute(0, 2, 3, 1).contiguous().view(weight2.shape[0], *weight2.shape[2:],
204
+ self.param_factorize_dim, outp)
205
+
206
+ bias = params[:, self.cont_split_bias[-1][0]:self.cont_split_bias[-1][1], :, :]
207
+ bias = bias.permute(0, 2, 3, 1).contiguous().view(bias.shape[0], *bias.shape[2:], 1, outp)
208
+
209
+ all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[-1], bias])
210
+
211
+ return all_weight_bias
212
+
213
+
214
+ class lineParams(nn.Module):
215
+ def __init__(self, input_dim, INR_in_out, input_resolution, opt, hidden_mlp_num, toRGB=False,
216
+ hidden_dim=512):
217
+ super(lineParams, self).__init__()
218
+ self.INR_in_out = INR_in_out
219
+ self.app_split_weight = []
220
+ self.app_split_bias = []
221
+ self.toRGB = toRGB
222
+ self.hidden_mlp_num = hidden_mlp_num
223
+ self.param_factorize_dim = opt.param_factorize_dim
224
+ output_dim = self.cal_params_num(INR_in_out, hidden_mlp_num)
225
+ self.output_dim = output_dim
226
+
227
+ self.compress_layer = nn.Sequential(
228
+ nn.Linear(input_resolution, 64, bias=False),
229
+ nn.BatchNorm1d(input_dim),
230
+ nn.ReLU(inplace=True),
231
+ nn.Linear(64, 1, bias=True)
232
+ )
233
+
234
+ self.app_extraction_net = nn.Sequential(
235
+ nn.Linear(input_dim, hidden_dim, bias=False),
236
+ # nn.BatchNorm1d(hidden_dim),
237
+ nn.ReLU(inplace=True),
238
+ nn.Linear(hidden_dim, hidden_dim, bias=False),
239
+ # nn.BatchNorm1d(hidden_dim),
240
+ nn.ReLU(inplace=True),
241
+ nn.Linear(hidden_dim, output_dim, bias=True)
242
+ )
243
+
244
+ self.app_extraction_net[-1].apply(lambda m: hyper_weight_init(m, INR_in_out[0], opt.activation))
245
+
246
+ self.basic_params = nn.ParameterList()
247
+ if opt.param_factorize_dim > 0:
248
+ for id in range(self.hidden_mlp_num + 1):
249
+ if id == 0:
250
+ inp, outp = self.INR_in_out[0], self.INR_in_out[1]
251
+ else:
252
+ inp, outp = self.INR_in_out[1], self.INR_in_out[1]
253
+ self.basic_params.append(nn.Parameter(torch.randn(1, inp, outp)))
254
+ if toRGB:
255
+ self.basic_params.append(nn.Parameter(torch.randn(1, self.INR_in_out[1], 3)))
256
+
257
+ def forward(self, feat):
258
+ app_params = self.app_extraction_net(self.compress_layer(torch.flatten(feat, 2)).squeeze(-1))
259
+ out_mlp = self.to_mlp(app_params)
260
+ return out_mlp, app_params
261
+
262
+ def cal_params_num(self, INR_in_out, hidden_mlp_num):
263
+ app_params = 0
264
+ start = 0
265
+ if self.param_factorize_dim == -1:
266
+ app_params += INR_in_out[0] * INR_in_out[1] + INR_in_out[1]
267
+ self.app_split_weight.append([start, app_params - INR_in_out[1]])
268
+ self.app_split_bias.append([app_params - INR_in_out[1], app_params])
269
+ start = app_params
270
+
271
+ for id in range(hidden_mlp_num):
272
+ app_params += INR_in_out[1] * INR_in_out[1] + INR_in_out[1]
273
+ self.app_split_weight.append([start, app_params - INR_in_out[1]])
274
+ self.app_split_bias.append([app_params - INR_in_out[1], app_params])
275
+ start = app_params
276
+
277
+ if self.toRGB:
278
+ app_params += INR_in_out[1] * 3 + 3
279
+ self.app_split_weight.append([start, app_params - 3])
280
+ self.app_split_bias.append([app_params - 3, app_params])
281
+
282
+ elif self.param_factorize_dim > 0:
283
+ app_params += INR_in_out[0] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
284
+ INR_in_out[1]
285
+ self.app_split_weight.append([start, start + INR_in_out[0] * self.param_factorize_dim,
286
+ app_params - INR_in_out[1]])
287
+ self.app_split_bias.append([app_params - INR_in_out[1], app_params])
288
+ start = app_params
289
+
290
+ for id in range(hidden_mlp_num):
291
+ app_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * INR_in_out[1] + \
292
+ INR_in_out[1]
293
+ self.app_split_weight.append(
294
+ [start, start + INR_in_out[1] * self.param_factorize_dim, app_params - INR_in_out[1]])
295
+ self.app_split_bias.append([app_params - INR_in_out[1], app_params])
296
+ start = app_params
297
+
298
+ if self.toRGB:
299
+ app_params += INR_in_out[1] * self.param_factorize_dim + self.param_factorize_dim * 3 + 3
300
+ self.app_split_weight.append([start, start + INR_in_out[1] * self.param_factorize_dim,
301
+ app_params - 3])
302
+ self.app_split_bias.append([app_params - 3, app_params])
303
+
304
+ return app_params
305
+
306
+ def to_mlp(self, params):
307
+ all_weight_bias = []
308
+ if self.param_factorize_dim == -1:
309
+ for id in range(self.hidden_mlp_num + 1):
310
+ if id == 0:
311
+ inp, outp = self.INR_in_out[0], self.INR_in_out[1]
312
+ else:
313
+ inp, outp = self.INR_in_out[1], self.INR_in_out[1]
314
+ weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
315
+ weight = weight.view(weight.shape[0], inp, outp)
316
+
317
+ bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
318
+ bias = bias.view(bias.shape[0], 1, outp)
319
+
320
+ all_weight_bias.append([weight, bias])
321
+
322
+ if self.toRGB:
323
+ id = -1
324
+ inp, outp = self.INR_in_out[1], 3
325
+ weight = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
326
+ weight = weight.view(weight.shape[0], inp, outp)
327
+
328
+ bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
329
+ bias = bias.view(bias.shape[0], 1, outp)
330
+
331
+ all_weight_bias.append([weight, bias])
332
+
333
+ return all_weight_bias
334
+
335
+ else:
336
+ for id in range(self.hidden_mlp_num + 1):
337
+ if id == 0:
338
+ inp, outp = self.INR_in_out[0], self.INR_in_out[1]
339
+ else:
340
+ inp, outp = self.INR_in_out[1], self.INR_in_out[1]
341
+ weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
342
+ weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim)
343
+
344
+ weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]]
345
+ weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp)
346
+
347
+ bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
348
+ bias = bias.view(bias.shape[0], 1, outp)
349
+
350
+ all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])
351
+
352
+ if self.toRGB:
353
+ id = -1
354
+ inp, outp = self.INR_in_out[1], 3
355
+ weight1 = params[:, self.app_split_weight[id][0]:self.app_split_weight[id][1]]
356
+ weight1 = weight1.view(weight1.shape[0], inp, self.param_factorize_dim)
357
+
358
+ weight2 = params[:, self.app_split_weight[id][1]:self.app_split_weight[id][2]]
359
+ weight2 = weight2.view(weight2.shape[0], self.param_factorize_dim, outp)
360
+
361
+ bias = params[:, self.app_split_bias[id][0]:self.app_split_bias[id][1]]
362
+ bias = bias.view(bias.shape[0], 1, outp)
363
+
364
+ all_weight_bias.append([torch.tanh(torch.matmul(weight1, weight2)) * self.basic_params[id], bias])
365
+
366
+ return all_weight_bias
model/base/conv_autoencoder.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ from torch import nn as nn
4
+ import torch.nn.functional as F
5
+ import numpy as np
6
+ import math
7
+
8
+ from .basic_blocks import ConvBlock, lineParams, convParams
9
+ from .ops import MaskedChannelAttention, FeaturesConnector
10
+ from .ops import PosEncodingNeRF, INRGAN_embed, RandomFourier, CIPS_embed
11
+ from utils import misc
12
+ from utils.misc import lin2img
13
+ from ..lut_transformation_net import build_lut_transform
14
+
15
+
16
+ class Sine(nn.Module):
17
+ def __init__(self):
18
+ super().__init__()
19
+
20
+ def forward(self, input):
21
+ return torch.sin(30 * input)
22
+
23
+
24
+ class Leaky_relu(nn.Module):
25
+ def __init__(self):
26
+ super().__init__()
27
+
28
+ def forward(self, input):
29
+ return torch.nn.functional.leaky_relu(input, 0.01, inplace=True)
30
+
31
+
32
+ def select_activation(type):
33
+ if type == 'sine':
34
+ return Sine()
35
+ elif type == 'leakyrelu_pe':
36
+ return Leaky_relu()
37
+ else:
38
+ raise NotImplementedError
39
+
40
+
41
+ class ConvEncoder(nn.Module):
42
+ def __init__(
43
+ self,
44
+ depth, ch,
45
+ norm_layer, batchnorm_from, max_channels,
46
+ backbone_from, backbone_channels=None, backbone_mode='', INRDecode=False
47
+ ):
48
+ super(ConvEncoder, self).__init__()
49
+ self.depth = depth
50
+ self.INRDecode = INRDecode
51
+ self.backbone_from = backbone_from
52
+ backbone_channels = [] if backbone_channels is None else backbone_channels[::-1]
53
+
54
+ in_channels = 4
55
+ out_channels = ch
56
+
57
+ self.block0 = ConvBlock(in_channels, out_channels, norm_layer=norm_layer if batchnorm_from == 0 else None)
58
+ self.block1 = ConvBlock(out_channels, out_channels, norm_layer=norm_layer if 0 <= batchnorm_from <= 1 else None)
59
+ self.blocks_channels = [out_channels, out_channels]
60
+
61
+ self.blocks_connected = nn.ModuleDict()
62
+ self.connectors = nn.ModuleDict()
63
+ for block_i in range(2, depth):
64
+ if block_i % 2:
65
+ in_channels = out_channels
66
+ else:
67
+ in_channels, out_channels = out_channels, min(2 * out_channels, max_channels)
68
+
69
+ if 0 <= backbone_from <= block_i and len(backbone_channels):
70
+ if INRDecode:
71
+ self.blocks_connected[f'block{block_i}_decode'] = ConvBlock(
72
+ in_channels, out_channels,
73
+ norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None,
74
+ padding=int(block_i < depth - 1)
75
+ )
76
+ self.blocks_channels += [out_channels]
77
+ stage_channels = backbone_channels.pop()
78
+ connector = FeaturesConnector(backbone_mode, in_channels, stage_channels, in_channels)
79
+ self.connectors[f'connector{block_i}'] = connector
80
+ in_channels = connector.output_channels
81
+
82
+ self.blocks_connected[f'block{block_i}'] = ConvBlock(
83
+ in_channels, out_channels,
84
+ norm_layer=norm_layer if 0 <= batchnorm_from <= block_i else None,
85
+ padding=int(block_i < depth - 1)
86
+ )
87
+ self.blocks_channels += [out_channels]
88
+
89
+ def forward(self, x, backbone_features):
90
+ backbone_features = [] if backbone_features is None else backbone_features[::-1]
91
+
92
+ outputs = [self.block0(x)]
93
+ outputs += [self.block1(outputs[-1])]
94
+
95
+ for block_i in range(2, self.depth):
96
+ output = outputs[-1]
97
+ connector_name = f'connector{block_i}'
98
+ if connector_name in self.connectors:
99
+ if self.INRDecode:
100
+ block = self.blocks_connected[f'block{block_i}_decode']
101
+ outputs += [block(output)]
102
+
103
+ stage_features = backbone_features.pop()
104
+ connector = self.connectors[connector_name]
105
+ output = connector(output, stage_features)
106
+ block = self.blocks_connected[f'block{block_i}']
107
+ outputs += [block(output)]
108
+
109
+ return outputs[::-1]
110
+
111
+
112
+ class DeconvDecoder(nn.Module):
113
+ def __init__(self, depth, encoder_blocks_channels, norm_layer, attend_from=-1, image_fusion=False):
114
+ super(DeconvDecoder, self).__init__()
115
+ self.image_fusion = image_fusion
116
+ self.deconv_blocks = nn.ModuleList()
117
+
118
+ in_channels = encoder_blocks_channels.pop()
119
+ out_channels = in_channels
120
+ for d in range(depth):
121
+ out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2
122
+ self.deconv_blocks.append(SEDeconvBlock(
123
+ in_channels, out_channels,
124
+ norm_layer=norm_layer,
125
+ padding=0 if d == 0 else 1,
126
+ with_se=0 <= attend_from <= d
127
+ ))
128
+ in_channels = out_channels
129
+
130
+ if self.image_fusion:
131
+ self.conv_attention = nn.Conv2d(out_channels, 1, kernel_size=1)
132
+ self.to_rgb = nn.Conv2d(out_channels, 3, kernel_size=1)
133
+
134
+ def forward(self, encoder_outputs, image, mask=None):
135
+ output = encoder_outputs[0]
136
+ for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
137
+ output = block(output, mask)
138
+ output = output + skip_output
139
+ output = self.deconv_blocks[-1](output, mask)
140
+
141
+ if self.image_fusion:
142
+ attention_map = torch.sigmoid(3.0 * self.conv_attention(output))
143
+ output = attention_map * image + (1.0 - attention_map) * self.to_rgb(output)
144
+ else:
145
+ output = self.to_rgb(output)
146
+
147
+ return output
148
+
149
+
150
+ class SEDeconvBlock(nn.Module):
151
+ def __init__(
152
+ self,
153
+ in_channels, out_channels,
154
+ kernel_size=4, stride=2, padding=1,
155
+ norm_layer=nn.BatchNorm2d, activation=nn.ELU,
156
+ with_se=False
157
+ ):
158
+ super(SEDeconvBlock, self).__init__()
159
+ self.with_se = with_se
160
+ self.block = nn.Sequential(
161
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding),
162
+ norm_layer(out_channels) if norm_layer is not None else nn.Identity(),
163
+ activation(),
164
+ )
165
+ if self.with_se:
166
+ self.se = MaskedChannelAttention(out_channels)
167
+
168
+ def forward(self, x, mask=None):
169
+ out = self.block(x)
170
+ if self.with_se:
171
+ out = self.se(out, mask)
172
+ return out
173
+
174
+
175
+ class INRDecoder(nn.Module):
176
+ def __init__(self, depth, encoder_blocks_channels, norm_layer, opt, attend_from):
177
+ super(INRDecoder, self).__init__()
178
+ self.INR_encoding = None
179
+ if opt.embedding_type == "PosEncodingNeRF":
180
+ self.INR_encoding = PosEncodingNeRF(in_features=2, sidelength=opt.input_size)
181
+ elif opt.embedding_type == "RandomFourier":
182
+ self.INR_encoding = RandomFourier(std_scale=10, embedding_length=64, device=opt.device)
183
+ elif opt.embedding_type == "CIPS_embed":
184
+ self.INR_encoding = CIPS_embed(size=opt.base_size, embedding_length=32)
185
+ elif opt.embedding_type == "INRGAN_embed":
186
+ self.INR_encoding = INRGAN_embed(resolution=opt.INR_input_size)
187
+ else:
188
+ raise NotImplementedError
189
+ encoder_blocks_channels = encoder_blocks_channels[::-1]
190
+ max_hidden_mlp_num = attend_from + 1
191
+ self.opt = opt
192
+ self.max_hidden_mlp_num = max_hidden_mlp_num
193
+ self.content_mlp_blocks = nn.ModuleDict()
194
+ for n in range(max_hidden_mlp_num):
195
+ if n != max_hidden_mlp_num - 1:
196
+ self.content_mlp_blocks[f"block{n}"] = convParams(encoder_blocks_channels.pop(),
197
+ [self.INR_encoding.out_dim + opt.INR_MLP_dim + (
198
+ 4 if opt.isMoreINRInput else 0), opt.INR_MLP_dim],
199
+ opt, n + 1)
200
+ else:
201
+ self.content_mlp_blocks[f"block{n}"] = convParams(encoder_blocks_channels.pop(),
202
+ [self.INR_encoding.out_dim + (
203
+ 4 if opt.isMoreINRInput else 0), opt.INR_MLP_dim],
204
+ opt, n + 1)
205
+
206
+ self.deconv_blocks = nn.ModuleList()
207
+
208
+ encoder_blocks_channels = encoder_blocks_channels[::-1]
209
+ in_channels = encoder_blocks_channels.pop()
210
+ out_channels = in_channels
211
+ for d in range(depth - attend_from):
212
+ out_channels = encoder_blocks_channels.pop() if len(encoder_blocks_channels) else in_channels // 2
213
+ self.deconv_blocks.append(SEDeconvBlock(
214
+ in_channels, out_channels,
215
+ norm_layer=norm_layer,
216
+ padding=0 if d == 0 else 1,
217
+ with_se=False
218
+ ))
219
+ in_channels = out_channels
220
+
221
+ self.appearance_mlps = lineParams(out_channels, [opt.INR_MLP_dim, opt.INR_MLP_dim],
222
+ (opt.base_size // (2 ** (max_hidden_mlp_num - 1))) ** 2,
223
+ opt, 2, toRGB=True)
224
+
225
+ self.lut_transform = build_lut_transform(self.appearance_mlps.output_dim, opt.LUT_dim,
226
+ None, opt)
227
+
228
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
229
+
230
+ def forward(self, encoder_outputs, image=None, mask=None, coord_samples=None, start_proportion=None):
231
+ """For full resolution, do split."""
232
+ if self.opt.hr_train and not (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt,
233
+ 'split_resolution')) and self.opt.isFullRes:
234
+ return self.forward_fullResInference(encoder_outputs, image=image, mask=mask, coord_samples=coord_samples)
235
+
236
+ encoder_outputs = encoder_outputs[::-1]
237
+ mlp_output = None
238
+ waitToRGB = []
239
+ for n in range(self.max_hidden_mlp_num):
240
+ if not self.opt.hr_train:
241
+ coord = misc.get_mgrid(self.opt.INR_input_size // (2 ** (self.max_hidden_mlp_num - n - 1))) \
242
+ .unsqueeze(0).repeat(encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)
243
+ else:
244
+ if self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution'):
245
+ coord = coord_samples[self.max_hidden_mlp_num - n - 1].permute(0, 2, 3, 1).view(
246
+ encoder_outputs[0].shape[0], -1, 2)
247
+ else:
248
+ coord = misc.get_mgrid(
249
+ self.opt.INR_input_size // (2 ** (self.max_hidden_mlp_num - n - 1))).unsqueeze(0).repeat(
250
+ encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)
251
+
252
+ """Whether to leverage multiple input to INR decoder. See Section 3.4 in the paper."""
253
+ if self.opt.isMoreINRInput:
254
+ if not self.opt.isFullRes or (
255
+ self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
256
+ res_h = res_w = np.sqrt(coord.shape[1]).astype(int)
257
+ else:
258
+ res_h = image.shape[-2] // (2 ** (self.max_hidden_mlp_num - n - 1))
259
+ res_w = image.shape[-1] // (2 ** (self.max_hidden_mlp_num - n - 1))
260
+
261
+ res_image = torchvision.transforms.Resize([res_h, res_w])(image)
262
+ res_mask = torchvision.transforms.Resize([res_h, res_w])(mask)
263
+ coord = torch.cat([self.INR_encoding(coord), res_image.view(*res_image.shape[:2], -1).permute(0, 2, 1),
264
+ res_mask.view(*res_mask.shape[:2], -1).permute(0, 2, 1)], dim=-1)
265
+ else:
266
+ coord = self.INR_encoding(coord)
267
+
268
+ """============ LRIP structure, see Section 3.3 =============="""
269
+
270
+ """Local MLPs."""
271
+ if n == 0:
272
+ mlp_output = self.mlp_process(coord, self.INR_encoding.out_dim + (4 if self.opt.isMoreINRInput else 0),
273
+ self.opt, content_mlp=self.content_mlp_blocks[
274
+ f"block{self.max_hidden_mlp_num - 1 - n}"](
275
+ encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n)), start_proportion=start_proportion)
276
+ waitToRGB.append(mlp_output[1])
277
+ else:
278
+ mlp_output = self.mlp_process(coord, self.opt.INR_MLP_dim + self.INR_encoding.out_dim + (
279
+ 4 if self.opt.isMoreINRInput else 0), self.opt, base_feat=mlp_output[0],
280
+ content_mlp=self.content_mlp_blocks[
281
+ f"block{self.max_hidden_mlp_num - 1 - n}"](
282
+ encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n)),
283
+ start_proportion=start_proportion)
284
+ waitToRGB.append(mlp_output[1])
285
+
286
+ encoder_outputs = encoder_outputs[::-1]
287
+ output = encoder_outputs[0]
288
+ for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
289
+ output = block(output)
290
+ output = output + skip_output
291
+ output = self.deconv_blocks[-1](output)
292
+
293
+ """Global MLPs."""
294
+ app_mlp, app_params = self.appearance_mlps(output)
295
+ harm_out = []
296
+ for id in range(len(waitToRGB)):
297
+ output = self.mlp_process(None, self.opt.INR_MLP_dim, self.opt, base_feat=waitToRGB[id],
298
+ appearance_mlp=app_mlp)
299
+ harm_out.append(output[0])
300
+
301
+ """Optional 3D LUT prediction."""
302
+ fit_lut3d, lut_transform_image = self.lut_transform(image, app_params, None)
303
+
304
+ return harm_out, fit_lut3d, lut_transform_image
305
+
306
+ def mlp_process(self, coorinates, INR_input_dim, opt, base_feat=None, content_mlp=None, appearance_mlp=None,
307
+ resolution=None, start_proportion=None):
308
+
309
+ activation = select_activation(opt.activation)
310
+
311
+ output = None
312
+
313
+ if content_mlp is not None:
314
+ if base_feat is not None:
315
+ coorinates = torch.cat([coorinates, base_feat], dim=2)
316
+ coorinates = lin2img(coorinates, resolution)
317
+
318
+ if hasattr(opt, 'split_resolution'):
319
+ """
320
+ Here we crop the needed MLPs according to the region of the split input patches.
321
+ Note that this only support inferencing square images.
322
+ """
323
+ for idx in range(len(content_mlp)):
324
+ content_mlp[idx][0] = content_mlp[idx][0][:,
325
+ (content_mlp[idx][0].shape[1] * start_proportion[0]).int():(
326
+ content_mlp[idx][0].shape[1] * start_proportion[2]).int(),
327
+ (content_mlp[idx][0].shape[2] * start_proportion[1]).int():(
328
+ content_mlp[idx][0].shape[2] * start_proportion[3]).int(), :,
329
+ :]
330
+ content_mlp[idx][1] = content_mlp[idx][1][:,
331
+ (content_mlp[idx][1].shape[1] * start_proportion[0]).int():(
332
+ content_mlp[idx][1].shape[1] * start_proportion[2]).int(),
333
+ (content_mlp[idx][1].shape[2] * start_proportion[1]).int():(
334
+ content_mlp[idx][1].shape[2] * start_proportion[3]).int(),
335
+ :,
336
+ :]
337
+ k_h = coorinates.shape[2] // content_mlp[0][0].shape[1]
338
+ k_w = coorinates.shape[3] // content_mlp[0][0].shape[1]
339
+ bs = coorinates.shape[0]
340
+ h_lr = w_lr = content_mlp[0][0].shape[1]
341
+ nci = INR_input_dim
342
+
343
+ coorinates = coorinates.unfold(2, k_h, k_h).unfold(3, k_w, k_w)
344
+ coorinates = coorinates.permute(0, 2, 3, 4, 5, 1).contiguous().view(
345
+ bs, h_lr, w_lr, int(k_h * k_w), nci)
346
+
347
+ for id, layer in enumerate(content_mlp):
348
+ if id == 0:
349
+ output = torch.matmul(coorinates, layer[0]) + layer[1]
350
+ output = activation(output)
351
+ else:
352
+ output = torch.matmul(output, layer[0]) + layer[1]
353
+ output = activation(output)
354
+
355
+ output = output.view(bs, h_lr, w_lr, k_h, k_w, opt.INR_MLP_dim).permute(
356
+ 0, 1, 3, 2, 4, 5).contiguous().view(bs, -1, opt.INR_MLP_dim)
357
+
358
+ output_large = self.up(lin2img(output))
359
+
360
+ return output_large.view(bs, -1, opt.INR_MLP_dim), output
361
+
362
+ k_h = coorinates.shape[2] // content_mlp[0][0].shape[1]
363
+ k_w = coorinates.shape[3] // content_mlp[0][0].shape[1]
364
+ bs = coorinates.shape[0]
365
+ h_lr = w_lr = content_mlp[0][0].shape[1]
366
+ nci = INR_input_dim
367
+
368
+ """(evaluation or not HR training) and not fullres evaluation"""
369
+ if (not self.opt.hr_train or not (self.training or hasattr(self.opt, 'split_num'))) and not (
370
+ not (self.training or hasattr(self.opt, 'split_num')) and self.opt.isFullRes and self.opt.hr_train):
371
+ coorinates = coorinates.unfold(2, k_h, k_h).unfold(3, k_w, k_w)
372
+ coorinates = coorinates.permute(0, 2, 3, 4, 5, 1).contiguous().view(
373
+ bs, h_lr, w_lr, int(k_h * k_w), nci)
374
+
375
+ for id, layer in enumerate(content_mlp):
376
+ if id == 0:
377
+ output = torch.matmul(coorinates, layer[0]) + layer[1]
378
+ output = activation(output)
379
+ else:
380
+ output = torch.matmul(output, layer[0]) + layer[1]
381
+ output = activation(output)
382
+
383
+ output = output.view(bs, h_lr, w_lr, k_h, k_w, opt.INR_MLP_dim).permute(
384
+ 0, 1, 3, 2, 4, 5).contiguous().view(bs, -1, opt.INR_MLP_dim)
385
+
386
+ output_large = self.up(lin2img(output))
387
+
388
+ return output_large.view(bs, -1, opt.INR_MLP_dim), output
389
+ else:
390
+ coorinates = coorinates.permute(0, 2, 3, 1)
391
+ for id, layer in enumerate(content_mlp):
392
+ weigt_shape = layer[0].shape
393
+ bias_shape = layer[1].shape
394
+ layer[0] = layer[0].view(*layer[0].shape[:-2], -1).permute(0, 3, 1, 2).contiguous()
395
+ layer[1] = layer[1].view(*layer[1].shape[:-2], -1).permute(0, 3, 1, 2).contiguous()
396
+ layer[0] = F.grid_sample(layer[0], coorinates[..., :2].flip(-1), mode='nearest' if True
397
+ else 'bilinear', padding_mode='border', align_corners=False)
398
+ layer[1] = F.grid_sample(layer[1], coorinates[..., :2].flip(-1), mode='nearest' if True
399
+ else 'bilinear', padding_mode='border', align_corners=False)
400
+ layer[0] = layer[0].permute(0, 2, 3, 1).contiguous().view(*coorinates.shape[:-1], *weigt_shape[-2:])
401
+ layer[1] = layer[1].permute(0, 2, 3, 1).contiguous().view(*coorinates.shape[:-1], *bias_shape[-2:])
402
+
403
+ if id == 0:
404
+ output = torch.matmul(coorinates.unsqueeze(-2), layer[0]) + layer[1]
405
+ output = activation(output)
406
+ else:
407
+ output = torch.matmul(output, layer[0]) + layer[1]
408
+ output = activation(output)
409
+
410
+ output = output.squeeze(-2).view(bs, -1, opt.INR_MLP_dim)
411
+
412
+ output_large = self.up(lin2img(output, resolution))
413
+
414
+ return output_large.view(bs, -1, opt.INR_MLP_dim), output
415
+
416
+ elif appearance_mlp is not None:
417
+ output = base_feat
418
+ genMask = None
419
+ for id, layer in enumerate(appearance_mlp):
420
+ if id != len(appearance_mlp) - 1:
421
+ output = torch.matmul(output, layer[0]) + layer[1]
422
+ output = activation(output)
423
+ else:
424
+ output = torch.matmul(output, layer[0]) + layer[1] # last layer
425
+ if opt.activation == 'leakyrelu_pe':
426
+ output = torch.tanh(output)
427
+ return lin2img(output, resolution), None
428
+
429
+ def forward_fullResInference(self, encoder_outputs, image=None, mask=None, coord_samples=None):
430
+ encoder_outputs = encoder_outputs[::-1]
431
+ mlp_output = None
432
+ res_w = image.shape[-1]
433
+ res_h = image.shape[-2]
434
+ coord = misc.get_mgrid([image.shape[-2], image.shape[-1]]).unsqueeze(0).repeat(
435
+ encoder_outputs[0].shape[0], 1, 1).to(self.opt.device)
436
+
437
+ if self.opt.isMoreINRInput:
438
+ coord = torch.cat(
439
+ [self.INR_encoding(coord, (res_h, res_w)), image.view(*image.shape[:2], -1).permute(0, 2, 1),
440
+ mask.view(*mask.shape[:2], -1).permute(0, 2, 1)], dim=-1)
441
+ else:
442
+ coord = self.INR_encoding(coord, (res_h, res_w))
443
+
444
+ total = coord.clone()
445
+
446
+ interval = 10
447
+ all_intervals = math.ceil(res_h / interval)
448
+ divisible = True
449
+ if res_h / interval != res_h // interval:
450
+ divisible = False
451
+
452
+ for n in range(self.max_hidden_mlp_num):
453
+ accum_mlp_output = []
454
+ for line in range(all_intervals):
455
+ if not divisible and line == all_intervals - 1:
456
+ coord = total[:, line * interval * res_w:, :]
457
+ else:
458
+ coord = total[:, line * interval * res_w: (line + 1) * interval * res_w, :]
459
+ if n == 0:
460
+ accum_mlp_output.append(self.mlp_process(coord,
461
+ self.INR_encoding.out_dim + (
462
+ 4 if self.opt.isMoreINRInput else 0),
463
+ self.opt, content_mlp=self.content_mlp_blocks[
464
+ f"block{self.max_hidden_mlp_num - 1 - n}"](
465
+ encoder_outputs.pop(self.max_hidden_mlp_num - 1 - n) if line == all_intervals - 1 else
466
+ encoder_outputs[self.max_hidden_mlp_num - 1 - n]),
467
+ resolution=(interval,
468
+ res_w) if divisible or line != all_intervals - 1 else (
469
+ res_h - interval * (all_intervals - 1), res_w))[1])
470
+
471
+ else:
472
+ accum_mlp_output.append(self.mlp_process(coord, self.opt.INR_MLP_dim + self.INR_encoding.out_dim + (
473
+ 4 if self.opt.isMoreINRInput else 0), self.opt, base_feat=mlp_output[0][:,
474
+ line * interval * res_w: (
475
+ line + 1) * interval * res_w,
476
+ :]
477
+ if divisible or line != all_intervals - 1 else mlp_output[0][:, line * interval * res_w:, :],
478
+ content_mlp=self.content_mlp_blocks[
479
+ f"block{self.max_hidden_mlp_num - 1 - n}"](
480
+ encoder_outputs.pop(
481
+ self.max_hidden_mlp_num - 1 - n) if line == all_intervals - 1 else
482
+ encoder_outputs[self.max_hidden_mlp_num - 1 - n]),
483
+ resolution=(interval,
484
+ res_w) if divisible or line != all_intervals - 1 else (
485
+ res_h - interval * (all_intervals - 1), res_w))[1])
486
+
487
+ accum_mlp_output = torch.cat(accum_mlp_output, dim=1)
488
+ mlp_output = [accum_mlp_output, accum_mlp_output]
489
+
490
+ encoder_outputs = encoder_outputs[::-1]
491
+ output = encoder_outputs[0]
492
+ for block, skip_output in zip(self.deconv_blocks[:-1], encoder_outputs[1:]):
493
+ output = block(output)
494
+ output = output + skip_output
495
+ output = self.deconv_blocks[-1](output)
496
+
497
+ app_mlp, app_params = self.appearance_mlps(output)
498
+ harm_out = []
499
+
500
+ accum_mlp_output = []
501
+ for line in range(all_intervals):
502
+ if not divisible and line == all_intervals - 1:
503
+ base = mlp_output[1][:, line * interval * res_w:, :]
504
+ else:
505
+ base = mlp_output[1][:, line * interval * res_w: (line + 1) * interval * res_w, :]
506
+
507
+ accum_mlp_output.append(self.mlp_process(None, self.opt.INR_MLP_dim, self.opt, base_feat=base,
508
+ appearance_mlp=app_mlp,
509
+ resolution=(
510
+ interval,
511
+ res_w) if divisible or line != all_intervals - 1 else (
512
+ res_h - interval * (all_intervals - 1), res_w))[0])
513
+
514
+ accum_mlp_output = torch.cat(accum_mlp_output, dim=2)
515
+ harm_out.append(accum_mlp_output)
516
+
517
+ fit_lut3d, lut_transform_image = self.lut_transform(image, app_params, None)
518
+
519
+ return harm_out, fit_lut3d, lut_transform_image
model/base/ih_model.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ import torch.nn as nn
4
+
5
+ from .conv_autoencoder import ConvEncoder, DeconvDecoder, INRDecoder
6
+
7
+ from .ops import ScaleLayer
8
+
9
+
10
+ class IHModelWithBackbone(nn.Module):
11
+ def __init__(
12
+ self,
13
+ model, backbone,
14
+ downsize_backbone_input=False,
15
+ mask_fusion='sum',
16
+ backbone_conv1_channels=64, opt=None
17
+ ):
18
+ super(IHModelWithBackbone, self).__init__()
19
+ self.downsize_backbone_input = downsize_backbone_input
20
+ self.mask_fusion = mask_fusion
21
+
22
+ self.backbone = backbone
23
+ self.model = model
24
+ self.opt = opt
25
+
26
+ self.mask_conv = nn.Sequential(
27
+ nn.Conv2d(1, backbone_conv1_channels, kernel_size=3, stride=2, padding=1, bias=True),
28
+ ScaleLayer(init_value=0.1, lr_mult=1)
29
+ )
30
+
31
+ def forward(self, image, mask, coord=None, start_proportion=None):
32
+ if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
33
+ backbone_image = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image[0])
34
+ backbone_mask = torch.cat(
35
+ (torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0]),
36
+ 1.0 - torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0])), dim=1)
37
+ else:
38
+ backbone_image = torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image)
39
+ backbone_mask = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask),
40
+ 1.0 - torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask)), dim=1)
41
+
42
+ backbone_mask_features = self.mask_conv(backbone_mask[:, :1])
43
+ backbone_features = self.backbone(backbone_image, backbone_mask, backbone_mask_features)
44
+
45
+ output = self.model(image, mask, backbone_features, coord=coord, start_proportion=start_proportion)
46
+ return output
47
+
48
+
49
+ class DeepImageHarmonization(nn.Module):
50
+ def __init__(
51
+ self,
52
+ depth,
53
+ norm_layer=nn.BatchNorm2d, batchnorm_from=0,
54
+ attend_from=-1,
55
+ image_fusion=False,
56
+ ch=64, max_channels=512,
57
+ backbone_from=-1, backbone_channels=None, backbone_mode='', opt=None
58
+ ):
59
+ super(DeepImageHarmonization, self).__init__()
60
+ self.depth = depth
61
+ self.encoder = ConvEncoder(
62
+ depth, ch,
63
+ norm_layer, batchnorm_from, max_channels,
64
+ backbone_from, backbone_channels, backbone_mode, INRDecode=opt.INRDecode
65
+ )
66
+ self.opt = opt
67
+ if opt.INRDecode:
68
+ "See Table 2 in the paper to test with different INR decoders' structures."
69
+ self.decoder = INRDecoder(depth, self.encoder.blocks_channels, norm_layer, opt, backbone_from)
70
+ else:
71
+ "Baseline: https://github.com/SamsungLabs/image_harmonization"
72
+ self.decoder = DeconvDecoder(depth, self.encoder.blocks_channels, norm_layer, attend_from, image_fusion)
73
+
74
+ def forward(self, image, mask, backbone_features=None, coord=None, start_proportion=None):
75
+ if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
76
+ x = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image[0]),
77
+ torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask[0])), dim=1)
78
+ else:
79
+ x = torch.cat((torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(image),
80
+ torchvision.transforms.Resize([self.opt.base_size, self.opt.base_size])(mask)), dim=1)
81
+
82
+ intermediates = self.encoder(x, backbone_features)
83
+
84
+ if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
85
+ output = self.decoder(intermediates, image[1], mask[1], coord_samples=coord, start_proportion=start_proportion)
86
+ else:
87
+ output = self.decoder(intermediates, image, mask)
88
+ return output
model/base/ops.py ADDED
@@ -0,0 +1,397 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn as nn
3
+ import numpy as np
4
+ import math
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class SimpleInputFusion(nn.Module):
9
+ def __init__(self, add_ch=1, rgb_ch=3, ch=8, norm_layer=nn.BatchNorm2d):
10
+ super(SimpleInputFusion, self).__init__()
11
+
12
+ self.fusion_conv = nn.Sequential(
13
+ nn.Conv2d(in_channels=add_ch + rgb_ch, out_channels=ch, kernel_size=1),
14
+ nn.LeakyReLU(negative_slope=0.2),
15
+ norm_layer(ch),
16
+ nn.Conv2d(in_channels=ch, out_channels=rgb_ch, kernel_size=1),
17
+ )
18
+
19
+ def forward(self, image, additional_input):
20
+ return self.fusion_conv(torch.cat((image, additional_input), dim=1))
21
+
22
+
23
+ class MaskedChannelAttention(nn.Module):
24
+ def __init__(self, in_channels, *args, **kwargs):
25
+ super(MaskedChannelAttention, self).__init__()
26
+ self.global_max_pool = MaskedGlobalMaxPool2d()
27
+ self.global_avg_pool = FastGlobalAvgPool2d()
28
+
29
+ intermediate_channels_count = max(in_channels // 16, 8)
30
+ self.attention_transform = nn.Sequential(
31
+ nn.Linear(3 * in_channels, intermediate_channels_count),
32
+ nn.ReLU(inplace=True),
33
+ nn.Linear(intermediate_channels_count, in_channels),
34
+ nn.Sigmoid(),
35
+ )
36
+
37
+ def forward(self, x, mask):
38
+ if mask.shape[2:] != x.shape[:2]:
39
+ mask = nn.functional.interpolate(
40
+ mask, size=x.size()[-2:],
41
+ mode='bilinear', align_corners=True
42
+ )
43
+ pooled_x = torch.cat([
44
+ self.global_max_pool(x, mask),
45
+ self.global_avg_pool(x)
46
+ ], dim=1)
47
+ channel_attention_weights = self.attention_transform(pooled_x)[..., None, None]
48
+
49
+ return channel_attention_weights * x
50
+
51
+
52
+ class MaskedGlobalMaxPool2d(nn.Module):
53
+ def __init__(self):
54
+ super().__init__()
55
+ self.global_max_pool = FastGlobalMaxPool2d()
56
+
57
+ def forward(self, x, mask):
58
+ return torch.cat((
59
+ self.global_max_pool(x * mask),
60
+ self.global_max_pool(x * (1.0 - mask))
61
+ ), dim=1)
62
+
63
+
64
+ class FastGlobalAvgPool2d(nn.Module):
65
+ def __init__(self):
66
+ super(FastGlobalAvgPool2d, self).__init__()
67
+
68
+ def forward(self, x):
69
+ in_size = x.size()
70
+ return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
71
+
72
+
73
+ class FastGlobalMaxPool2d(nn.Module):
74
+ def __init__(self):
75
+ super(FastGlobalMaxPool2d, self).__init__()
76
+
77
+ def forward(self, x):
78
+ in_size = x.size()
79
+ return x.view((in_size[0], in_size[1], -1)).max(dim=2)[0]
80
+
81
+
82
+ class ScaleLayer(nn.Module):
83
+ def __init__(self, init_value=1.0, lr_mult=1):
84
+ super().__init__()
85
+ self.lr_mult = lr_mult
86
+ self.scale = nn.Parameter(
87
+ torch.full((1,), init_value / lr_mult, dtype=torch.float32)
88
+ )
89
+
90
+ def forward(self, x):
91
+ scale = torch.abs(self.scale * self.lr_mult)
92
+ return x * scale
93
+
94
+
95
+ class FeaturesConnector(nn.Module):
96
+ def __init__(self, mode, in_channels, feature_channels, out_channels):
97
+ super(FeaturesConnector, self).__init__()
98
+ self.mode = mode if feature_channels else ''
99
+
100
+ if self.mode == 'catc':
101
+ self.reduce_conv = nn.Conv2d(in_channels + feature_channels, out_channels, kernel_size=1)
102
+ elif self.mode == 'sum':
103
+ self.reduce_conv = nn.Conv2d(feature_channels, out_channels, kernel_size=1)
104
+
105
+ self.output_channels = out_channels if self.mode != 'cat' else in_channels + feature_channels
106
+
107
+ def forward(self, x, features):
108
+ if self.mode == 'cat':
109
+ return torch.cat((x, features), 1)
110
+ if self.mode == 'catc':
111
+ return self.reduce_conv(torch.cat((x, features), 1))
112
+ if self.mode == 'sum':
113
+ return self.reduce_conv(features) + x
114
+ return x
115
+
116
+ def extra_repr(self):
117
+ return self.mode
118
+
119
+
120
+ class PosEncodingNeRF(nn.Module):
121
+ def __init__(self, in_features, sidelength=None, fn_samples=None, use_nyquist=True):
122
+ super().__init__()
123
+
124
+ self.in_features = in_features
125
+
126
+ if self.in_features == 3:
127
+ self.num_frequencies = 10
128
+ elif self.in_features == 2:
129
+ assert sidelength is not None
130
+ if isinstance(sidelength, int):
131
+ sidelength = (sidelength, sidelength)
132
+ self.num_frequencies = 4
133
+ if use_nyquist:
134
+ self.num_frequencies = self.get_num_frequencies_nyquist(min(sidelength[0], sidelength[1]))
135
+ elif self.in_features == 1:
136
+ assert fn_samples is not None
137
+ self.num_frequencies = 4
138
+ if use_nyquist:
139
+ self.num_frequencies = self.get_num_frequencies_nyquist(fn_samples)
140
+
141
+ self.out_dim = in_features + 2 * in_features * self.num_frequencies
142
+
143
+ def get_num_frequencies_nyquist(self, samples):
144
+ nyquist_rate = 1 / (2 * (2 * 1 / samples))
145
+ return int(math.floor(math.log(nyquist_rate, 2)))
146
+
147
+ def forward(self, coords):
148
+ coords = coords.view(coords.shape[0], -1, self.in_features)
149
+
150
+ coords_pos_enc = coords
151
+ for i in range(self.num_frequencies):
152
+ for j in range(self.in_features):
153
+ c = coords[..., j]
154
+
155
+ sin = torch.unsqueeze(torch.sin((2 ** i) * np.pi * c), -1)
156
+ cos = torch.unsqueeze(torch.cos((2 ** i) * np.pi * c), -1)
157
+
158
+ coords_pos_enc = torch.cat((coords_pos_enc, sin, cos), axis=-1)
159
+
160
+ return coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim)
161
+
162
+
163
+ class RandomFourier(nn.Module):
164
+ def __init__(self, std_scale, embedding_length, device):
165
+ super().__init__()
166
+
167
+ self.embed = torch.normal(0, 1, (2, embedding_length)) * std_scale
168
+ self.embed = self.embed.to(device)
169
+
170
+ self.out_dim = embedding_length * 2 + 2
171
+
172
+ def forward(self, coords):
173
+ coords_pos_enc = torch.cat([torch.sin(torch.matmul(2 * np.pi * coords, self.embed)),
174
+ torch.cos(torch.matmul(2 * np.pi * coords, self.embed))], dim=-1)
175
+
176
+ return torch.cat([coords, coords_pos_enc.reshape(coords.shape[0], -1, self.out_dim)], dim=-1)
177
+
178
+
179
+ class CIPS_embed(nn.Module):
180
+ def __init__(self, size, embedding_length):
181
+ super().__init__()
182
+ self.fourier_embed = ConstantInput(size, embedding_length)
183
+ self.predict_embed = Predict_embed(embedding_length)
184
+ self.out_dim = embedding_length * 2 + 2
185
+
186
+ def forward(self, coord, res=None):
187
+ x = self.predict_embed(coord)
188
+ y = self.fourier_embed(x, coord, res)
189
+
190
+ return torch.cat([coord, x, y], dim=-1)
191
+
192
+
193
+ class Predict_embed(nn.Module):
194
+ def __init__(self, embedding_length):
195
+ super(Predict_embed, self).__init__()
196
+ self.ffm = nn.Linear(2, embedding_length, bias=True)
197
+ nn.init.uniform_(self.ffm.weight, -np.sqrt(9 / 2), np.sqrt(9 / 2))
198
+
199
+ def forward(self, x):
200
+ x = self.ffm(x)
201
+ x = torch.sin(x)
202
+ return x
203
+
204
+
205
+ class ConstantInput(nn.Module):
206
+ def __init__(self, size, channel):
207
+ super().__init__()
208
+
209
+ self.input = nn.Parameter(torch.randn(1, size ** 2, channel))
210
+
211
+ def forward(self, input, coord, resolution=None):
212
+ batch = input.shape[0]
213
+ out = self.input.repeat(batch, 1, 1)
214
+
215
+ if coord.shape[1] != self.input.shape[1]:
216
+ x = out.permute(0, 2, 1).contiguous().view(batch, self.input.shape[-1],
217
+ int(self.input.shape[1] ** 0.5), int(self.input.shape[1] ** 0.5))
218
+
219
+ if resolution is None:
220
+ grid = coord.view(coord.shape[0], int(coord.shape[1] ** 0.5), int(coord.shape[1] ** 0.5), coord.shape[-1])
221
+ else:
222
+ grid = coord.view(coord.shape[0], *resolution, coord.shape[-1])
223
+
224
+ out = F.grid_sample(x, grid.flip(-1), mode='bilinear', padding_mode='border', align_corners=True)
225
+
226
+ out = out.permute(0, 2, 3, 1).contiguous().view(batch, -1, self.input.shape[-1])
227
+
228
+ return out
229
+
230
+
231
+ class INRGAN_embed(nn.Module):
232
+ def __init__(self, resolution: int, w_dim=None):
233
+ super().__init__()
234
+
235
+ self.resolution = resolution
236
+ self.res_cfg = {"log_emb_size": 32,
237
+ "random_emb_size": 32,
238
+ "const_emb_size": 64,
239
+ "use_cosine": True}
240
+ self.log_emb_size = self.res_cfg.get('log_emb_size', 0)
241
+ self.random_emb_size = self.res_cfg.get('random_emb_size', 0)
242
+ self.shared_emb_size = self.res_cfg.get('shared_emb_size', 0)
243
+ self.predictable_emb_size = self.res_cfg.get('predictable_emb_size', 0)
244
+ self.const_emb_size = self.res_cfg.get('const_emb_size', 0)
245
+ self.fourier_scale = self.res_cfg.get('fourier_scale', np.sqrt(10))
246
+ self.use_cosine = self.res_cfg.get('use_cosine', False)
247
+
248
+ if self.log_emb_size > 0:
249
+ self.register_buffer('log_basis', generate_logarithmic_basis(
250
+ resolution, self.log_emb_size, use_diagonal=self.res_cfg.get('use_diagonal', False)))
251
+
252
+ if self.random_emb_size > 0:
253
+ self.register_buffer('random_basis', self.sample_w_matrix((2, self.random_emb_size), self.fourier_scale))
254
+
255
+ if self.shared_emb_size > 0:
256
+ self.shared_basis = nn.Parameter(self.sample_w_matrix((2, self.shared_emb_size), self.fourier_scale))
257
+
258
+ if self.predictable_emb_size > 0:
259
+ self.W_size = self.predictable_emb_size * self.cfg.coord_dim
260
+ self.b_size = self.predictable_emb_size
261
+ self.affine = nn.Linear(w_dim, self.W_size + self.b_size)
262
+
263
+ if self.const_emb_size > 0:
264
+ self.const_embs = nn.Parameter(torch.randn(1, resolution ** 2, self.const_emb_size))
265
+
266
+ self.out_dim = self.get_total_dim() + 2
267
+
268
+ def sample_w_matrix(self, shape, scale: float):
269
+ return torch.randn(shape) * scale
270
+
271
+ def get_total_dim(self) -> int:
272
+ total_dim = 0
273
+ if self.log_emb_size > 0:
274
+ total_dim += self.log_basis.shape[0] * (2 if self.use_cosine else 1)
275
+ total_dim += self.random_emb_size * (2 if self.use_cosine else 1)
276
+ total_dim += self.shared_emb_size * (2 if self.use_cosine else 1)
277
+ total_dim += self.predictable_emb_size * (2 if self.use_cosine else 1)
278
+ total_dim += self.const_emb_size
279
+
280
+ return total_dim
281
+
282
+ def forward(self, raw_coords, w=None):
283
+ batch_size, img_size, in_channels = raw_coords.shape
284
+
285
+ raw_embs = []
286
+
287
+ if self.log_emb_size > 0:
288
+ log_bases = self.log_basis.unsqueeze(0).repeat(batch_size, 1, 1).permute(0, 2, 1)
289
+ raw_log_embs = torch.matmul(raw_coords, log_bases)
290
+ raw_embs.append(raw_log_embs)
291
+
292
+ if self.random_emb_size > 0:
293
+ random_bases = self.random_basis.unsqueeze(0).repeat(batch_size, 1, 1)
294
+ raw_random_embs = torch.matmul(raw_coords, random_bases)
295
+ raw_embs.append(raw_random_embs)
296
+
297
+ if self.shared_emb_size > 0:
298
+ shared_bases = self.shared_basis.unsqueeze(0).repeat(batch_size, 1, 1)
299
+ raw_shared_embs = torch.matmul(raw_coords, shared_bases)
300
+ raw_embs.append(raw_shared_embs)
301
+
302
+ if self.predictable_emb_size > 0:
303
+ mod = self.affine(w)
304
+ W = self.fourier_scale * mod[:, :self.W_size]
305
+ W = W.view(batch_size, self.cfg.coord_dim, self.predictable_emb_size)
306
+ bias = mod[:, self.W_size:].view(batch_size, 1, self.predictable_emb_size)
307
+ raw_predictable_embs = (torch.matmul(raw_coords, W) + bias)
308
+ raw_embs.append(raw_predictable_embs)
309
+
310
+ if len(raw_embs) > 0:
311
+ raw_embs = torch.cat(raw_embs, dim=-1)
312
+ raw_embs = raw_embs.contiguous()
313
+ out = raw_embs.sin()
314
+
315
+ if self.use_cosine:
316
+ out = torch.cat([out, raw_embs.cos()], dim=-1)
317
+
318
+ if self.const_emb_size > 0:
319
+ const_embs = self.const_embs.repeat([batch_size, 1, 1])
320
+ const_embs = const_embs
321
+ out = torch.cat([out, const_embs], dim=-1)
322
+
323
+ return torch.cat([raw_coords, out], dim=-1)
324
+
325
+
326
+ def generate_logarithmic_basis(
327
+ resolution,
328
+ max_num_feats,
329
+ remove_lowest_freq: bool = False,
330
+ use_diagonal: bool = True):
331
+ """
332
+ Generates a directional logarithmic basis with the following directions:
333
+ - horizontal
334
+ - vertical
335
+ - main diagonal
336
+ - anti-diagonal
337
+ """
338
+ max_num_feats_per_direction = np.ceil(np.log2(resolution)).astype(int)
339
+ bases = [
340
+ generate_horizontal_basis(max_num_feats_per_direction),
341
+ generate_vertical_basis(max_num_feats_per_direction),
342
+ ]
343
+
344
+ if use_diagonal:
345
+ bases.extend([
346
+ generate_diag_main_basis(max_num_feats_per_direction),
347
+ generate_anti_diag_basis(max_num_feats_per_direction),
348
+ ])
349
+
350
+ if remove_lowest_freq:
351
+ bases = [b[1:] for b in bases]
352
+
353
+ # If we do not fit into `max_num_feats`, then trying to remove the features in the order:
354
+ # 1) anti-diagonal 2) main-diagonal
355
+ # while (max_num_feats_per_direction * len(bases) > max_num_feats) and (len(bases) > 2):
356
+ # bases = bases[:-1]
357
+
358
+ basis = torch.cat(bases, dim=0)
359
+
360
+ # If we still do not fit, then let's remove each second feature,
361
+ # then each third, each forth and so on
362
+ # We cannot drop the whole horizontal or vertical direction since otherwise
363
+ # model won't be able to locate the position
364
+ # (unless the previously computed embeddings encode the position)
365
+ # while basis.shape[0] > max_num_feats:
366
+ # num_exceeding_feats = basis.shape[0] - max_num_feats
367
+ # basis = basis[::2]
368
+
369
+ assert basis.shape[0] <= max_num_feats, \
370
+ f"num_coord_feats > max_num_fixed_coord_feats: {basis.shape, max_num_feats}."
371
+
372
+ return basis
373
+
374
+
375
+ def generate_horizontal_basis(num_feats: int):
376
+ return generate_wavefront_basis(num_feats, [0.0, 1.0], 4.0)
377
+
378
+
379
+ def generate_vertical_basis(num_feats: int):
380
+ return generate_wavefront_basis(num_feats, [1.0, 0.0], 4.0)
381
+
382
+
383
+ def generate_diag_main_basis(num_feats: int):
384
+ return generate_wavefront_basis(num_feats, [-1.0 / np.sqrt(2), 1.0 / np.sqrt(2)], 4.0 * np.sqrt(2))
385
+
386
+
387
+ def generate_anti_diag_basis(num_feats: int):
388
+ return generate_wavefront_basis(num_feats, [1.0 / np.sqrt(2), 1.0 / np.sqrt(2)], 4.0 * np.sqrt(2))
389
+
390
+
391
+ def generate_wavefront_basis(num_feats: int, basis_block, period_length: float):
392
+ period_coef = 2.0 * np.pi / period_length
393
+ basis = torch.tensor([basis_block]).repeat(num_feats, 1) # [num_feats, 2]
394
+ powers = torch.tensor([2]).repeat(num_feats).pow(torch.arange(num_feats)).unsqueeze(1) # [num_feats, 1]
395
+ result = basis * powers * period_coef # [num_feats, 2]
396
+
397
+ return result.float()
model/build_model.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ from .backbone import build_backbone
3
+
4
+
5
+ class build_model(nn.Module):
6
+ def __init__(self, opt):
7
+ super().__init__()
8
+
9
+ self.opt = opt
10
+ self.backbone = build_backbone('baseline', opt)
11
+
12
+ def forward(self, composite_image, mask, fg_INR_coordinates, start_proportion=None):
13
+ if self.opt.INRDecode and self.opt.hr_train and (self.training or hasattr(self.opt, 'split_num') or hasattr(self.opt, 'split_resolution')):
14
+ """
15
+ For HR Training, due to the designed RSC strategy in Section 3.4 in the paper,
16
+ here we need to pass in the coordinates of the cropped regions.
17
+ """
18
+ extracted_features = self.backbone(composite_image, mask, fg_INR_coordinates, start_proportion=start_proportion)
19
+ else:
20
+ extracted_features = self.backbone(composite_image, mask)
21
+
22
+ if self.opt.INRDecode:
23
+ return extracted_features
24
+ return None, None, extracted_features
model/hrnetv2/__init__.py ADDED
File without changes
model/hrnetv2/__pycache__/__init__.cpython-38.pyc ADDED
Binary file (181 Bytes). View file
 
model/hrnetv2/__pycache__/hrnet_ocr.cpython-38.pyc ADDED
Binary file (10.6 kB). View file
 
model/hrnetv2/__pycache__/modifiers.cpython-38.pyc ADDED
Binary file (704 Bytes). View file
 
model/hrnetv2/__pycache__/ocr.cpython-38.pyc ADDED
Binary file (4.54 kB). View file
 
model/hrnetv2/__pycache__/resnetv1b.cpython-38.pyc ADDED
Binary file (7.54 kB). View file
 
model/hrnetv2/hrnet_ocr.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch._utils
7
+
8
+ from .ocr import SpatialOCR_Module, SpatialGather_Module
9
+ from .resnetv1b import BasicBlockV1b, BottleneckV1b
10
+
11
+ relu_inplace = True
12
+
13
+
14
+ class HighResolutionModule(nn.Module):
15
+ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
16
+ num_channels, fuse_method,multi_scale_output=True,
17
+ norm_layer=nn.BatchNorm2d, align_corners=True):
18
+ super(HighResolutionModule, self).__init__()
19
+ self._check_branches(num_branches, num_blocks, num_inchannels, num_channels)
20
+
21
+ self.num_inchannels = num_inchannels
22
+ self.fuse_method = fuse_method
23
+ self.num_branches = num_branches
24
+ self.norm_layer = norm_layer
25
+ self.align_corners = align_corners
26
+
27
+ self.multi_scale_output = multi_scale_output
28
+
29
+ self.branches = self._make_branches(
30
+ num_branches, blocks, num_blocks, num_channels)
31
+ self.fuse_layers = self._make_fuse_layers()
32
+ self.relu = nn.ReLU(inplace=relu_inplace)
33
+
34
+ def _check_branches(self, num_branches, num_blocks, num_inchannels, num_channels):
35
+ if num_branches != len(num_blocks):
36
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format(
37
+ num_branches, len(num_blocks))
38
+ raise ValueError(error_msg)
39
+
40
+ if num_branches != len(num_channels):
41
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format(
42
+ num_branches, len(num_channels))
43
+ raise ValueError(error_msg)
44
+
45
+ if num_branches != len(num_inchannels):
46
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format(
47
+ num_branches, len(num_inchannels))
48
+ raise ValueError(error_msg)
49
+
50
+ def _make_one_branch(self, branch_index, block, num_blocks, num_channels,
51
+ stride=1):
52
+ downsample = None
53
+ if stride != 1 or \
54
+ self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion:
55
+ downsample = nn.Sequential(
56
+ nn.Conv2d(self.num_inchannels[branch_index],
57
+ num_channels[branch_index] * block.expansion,
58
+ kernel_size=1, stride=stride, bias=False),
59
+ self.norm_layer(num_channels[branch_index] * block.expansion),
60
+ )
61
+
62
+ layers = []
63
+ layers.append(block(self.num_inchannels[branch_index],
64
+ num_channels[branch_index], stride,
65
+ downsample=downsample, norm_layer=self.norm_layer))
66
+ self.num_inchannels[branch_index] = \
67
+ num_channels[branch_index] * block.expansion
68
+ for i in range(1, num_blocks[branch_index]):
69
+ layers.append(block(self.num_inchannels[branch_index],
70
+ num_channels[branch_index],
71
+ norm_layer=self.norm_layer))
72
+
73
+ return nn.Sequential(*layers)
74
+
75
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
76
+ branches = []
77
+
78
+ for i in range(num_branches):
79
+ branches.append(
80
+ self._make_one_branch(i, block, num_blocks, num_channels))
81
+
82
+ return nn.ModuleList(branches)
83
+
84
+ def _make_fuse_layers(self):
85
+ if self.num_branches == 1:
86
+ return None
87
+
88
+ num_branches = self.num_branches
89
+ num_inchannels = self.num_inchannels
90
+ fuse_layers = []
91
+ for i in range(num_branches if self.multi_scale_output else 1):
92
+ fuse_layer = []
93
+ for j in range(num_branches):
94
+ if j > i:
95
+ fuse_layer.append(nn.Sequential(
96
+ nn.Conv2d(in_channels=num_inchannels[j],
97
+ out_channels=num_inchannels[i],
98
+ kernel_size=1,
99
+ bias=False),
100
+ self.norm_layer(num_inchannels[i])))
101
+ elif j == i:
102
+ fuse_layer.append(None)
103
+ else:
104
+ conv3x3s = []
105
+ for k in range(i - j):
106
+ if k == i - j - 1:
107
+ num_outchannels_conv3x3 = num_inchannels[i]
108
+ conv3x3s.append(nn.Sequential(
109
+ nn.Conv2d(num_inchannels[j],
110
+ num_outchannels_conv3x3,
111
+ kernel_size=3, stride=2, padding=1, bias=False),
112
+ self.norm_layer(num_outchannels_conv3x3)))
113
+ else:
114
+ num_outchannels_conv3x3 = num_inchannels[j]
115
+ conv3x3s.append(nn.Sequential(
116
+ nn.Conv2d(num_inchannels[j],
117
+ num_outchannels_conv3x3,
118
+ kernel_size=3, stride=2, padding=1, bias=False),
119
+ self.norm_layer(num_outchannels_conv3x3),
120
+ nn.ReLU(inplace=relu_inplace)))
121
+ fuse_layer.append(nn.Sequential(*conv3x3s))
122
+ fuse_layers.append(nn.ModuleList(fuse_layer))
123
+
124
+ return nn.ModuleList(fuse_layers)
125
+
126
+ def get_num_inchannels(self):
127
+ return self.num_inchannels
128
+
129
+ def forward(self, x):
130
+ if self.num_branches == 1:
131
+ return [self.branches[0](x[0])]
132
+
133
+ for i in range(self.num_branches):
134
+ x[i] = self.branches[i](x[i])
135
+
136
+ x_fuse = []
137
+ for i in range(len(self.fuse_layers)):
138
+ y = x[0] if i == 0 else self.fuse_layers[i][0](x[0])
139
+ for j in range(1, self.num_branches):
140
+ if i == j:
141
+ y = y + x[j]
142
+ elif j > i:
143
+ width_output = x[i].shape[-1]
144
+ height_output = x[i].shape[-2]
145
+ y = y + F.interpolate(
146
+ self.fuse_layers[i][j](x[j]),
147
+ size=[height_output, width_output],
148
+ mode='bilinear', align_corners=self.align_corners)
149
+ else:
150
+ y = y + self.fuse_layers[i][j](x[j])
151
+ x_fuse.append(self.relu(y))
152
+
153
+ return x_fuse
154
+
155
+
156
+ class HighResolutionNet(nn.Module):
157
+ def __init__(self, width, num_classes, ocr_width=256, small=False,
158
+ norm_layer=nn.BatchNorm2d, align_corners=True, opt=None):
159
+ super(HighResolutionNet, self).__init__()
160
+ self.opt = opt
161
+ self.norm_layer = norm_layer
162
+ self.width = width
163
+ self.ocr_width = ocr_width
164
+ self.ocr_on = ocr_width > 0
165
+ self.align_corners = align_corners
166
+
167
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
168
+ self.bn1 = norm_layer(64)
169
+ self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
170
+ self.bn2 = norm_layer(64)
171
+ self.relu = nn.ReLU(inplace=relu_inplace)
172
+
173
+ num_blocks = 2 if small else 4
174
+
175
+ stage1_num_channels = 64
176
+ self.layer1 = self._make_layer(BottleneckV1b, 64, stage1_num_channels, blocks=num_blocks)
177
+ stage1_out_channel = BottleneckV1b.expansion * stage1_num_channels
178
+
179
+ self.stage2_num_branches = 2
180
+ num_channels = [width, 2 * width]
181
+ num_inchannels = [
182
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
183
+ self.transition1 = self._make_transition_layer(
184
+ [stage1_out_channel], num_inchannels)
185
+ self.stage2, pre_stage_channels = self._make_stage(
186
+ BasicBlockV1b, num_inchannels=num_inchannels, num_modules=1, num_branches=self.stage2_num_branches,
187
+ num_blocks=2 * [num_blocks], num_channels=num_channels)
188
+
189
+ self.stage3_num_branches = 3
190
+ num_channels = [width, 2 * width, 4 * width]
191
+ num_inchannels = [
192
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
193
+ self.transition2 = self._make_transition_layer(
194
+ pre_stage_channels, num_inchannels)
195
+ self.stage3, pre_stage_channels = self._make_stage(
196
+ BasicBlockV1b, num_inchannels=num_inchannels,
197
+ num_modules=3 if small else 4, num_branches=self.stage3_num_branches,
198
+ num_blocks=3 * [num_blocks], num_channels=num_channels)
199
+
200
+ self.stage4_num_branches = 4
201
+ num_channels = [width, 2 * width, 4 * width, 8 * width]
202
+ num_inchannels = [
203
+ num_channels[i] * BasicBlockV1b.expansion for i in range(len(num_channels))]
204
+ self.transition3 = self._make_transition_layer(
205
+ pre_stage_channels, num_inchannels)
206
+ self.stage4, pre_stage_channels = self._make_stage(
207
+ BasicBlockV1b, num_inchannels=num_inchannels, num_modules=2 if small else 3,
208
+ num_branches=self.stage4_num_branches,
209
+ num_blocks=4 * [num_blocks], num_channels=num_channels)
210
+
211
+ if self.ocr_on:
212
+ last_inp_channels = np.int(np.sum(pre_stage_channels))
213
+ ocr_mid_channels = 2 * ocr_width
214
+ ocr_key_channels = ocr_width
215
+
216
+ self.conv3x3_ocr = nn.Sequential(
217
+ nn.Conv2d(last_inp_channels, ocr_mid_channels,
218
+ kernel_size=3, stride=1, padding=1),
219
+ norm_layer(ocr_mid_channels),
220
+ nn.ReLU(inplace=relu_inplace),
221
+ )
222
+ self.ocr_gather_head = SpatialGather_Module(num_classes)
223
+
224
+ self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels,
225
+ key_channels=ocr_key_channels,
226
+ out_channels=ocr_mid_channels,
227
+ scale=1,
228
+ dropout=0.05,
229
+ norm_layer=norm_layer,
230
+ align_corners=align_corners, opt=opt)
231
+
232
+ def _make_transition_layer(
233
+ self, num_channels_pre_layer, num_channels_cur_layer):
234
+ num_branches_cur = len(num_channels_cur_layer)
235
+ num_branches_pre = len(num_channels_pre_layer)
236
+
237
+ transition_layers = []
238
+ for i in range(num_branches_cur):
239
+ if i < num_branches_pre:
240
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
241
+ transition_layers.append(nn.Sequential(
242
+ nn.Conv2d(num_channels_pre_layer[i],
243
+ num_channels_cur_layer[i],
244
+ kernel_size=3,
245
+ stride=1,
246
+ padding=1,
247
+ bias=False),
248
+ self.norm_layer(num_channels_cur_layer[i]),
249
+ nn.ReLU(inplace=relu_inplace)))
250
+ else:
251
+ transition_layers.append(None)
252
+ else:
253
+ conv3x3s = []
254
+ for j in range(i + 1 - num_branches_pre):
255
+ inchannels = num_channels_pre_layer[-1]
256
+ outchannels = num_channels_cur_layer[i] \
257
+ if j == i - num_branches_pre else inchannels
258
+ conv3x3s.append(nn.Sequential(
259
+ nn.Conv2d(inchannels, outchannels,
260
+ kernel_size=3, stride=2, padding=1, bias=False),
261
+ self.norm_layer(outchannels),
262
+ nn.ReLU(inplace=relu_inplace)))
263
+ transition_layers.append(nn.Sequential(*conv3x3s))
264
+
265
+ return nn.ModuleList(transition_layers)
266
+
267
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
268
+ downsample = None
269
+ if stride != 1 or inplanes != planes * block.expansion:
270
+ downsample = nn.Sequential(
271
+ nn.Conv2d(inplanes, planes * block.expansion,
272
+ kernel_size=1, stride=stride, bias=False),
273
+ self.norm_layer(planes * block.expansion),
274
+ )
275
+
276
+ layers = []
277
+ layers.append(block(inplanes, planes, stride,
278
+ downsample=downsample, norm_layer=self.norm_layer))
279
+ inplanes = planes * block.expansion
280
+ for i in range(1, blocks):
281
+ layers.append(block(inplanes, planes, norm_layer=self.norm_layer))
282
+
283
+ return nn.Sequential(*layers)
284
+
285
+ def _make_stage(self, block, num_inchannels,
286
+ num_modules, num_branches, num_blocks, num_channels,
287
+ fuse_method='SUM',
288
+ multi_scale_output=True):
289
+ modules = []
290
+ for i in range(num_modules):
291
+ # multi_scale_output is only used last module
292
+ if not multi_scale_output and i == num_modules - 1:
293
+ reset_multi_scale_output = False
294
+ else:
295
+ reset_multi_scale_output = True
296
+ modules.append(
297
+ HighResolutionModule(num_branches,
298
+ block,
299
+ num_blocks,
300
+ num_inchannels,
301
+ num_channels,
302
+ fuse_method,
303
+ reset_multi_scale_output,
304
+ norm_layer=self.norm_layer,
305
+ align_corners=self.align_corners)
306
+ )
307
+ num_inchannels = modules[-1].get_num_inchannels()
308
+
309
+ return nn.Sequential(*modules), num_inchannels
310
+
311
+ def forward(self, x, mask=None, additional_features=None):
312
+ hrnet_feats = self.compute_hrnet_feats(x, additional_features)
313
+ if not self.ocr_on:
314
+ return hrnet_feats,
315
+
316
+ ocr_feats = self.conv3x3_ocr(hrnet_feats)
317
+ mask = nn.functional.interpolate(mask, size=ocr_feats.size()[2:], mode='bilinear', align_corners=True)
318
+ context = self.ocr_gather_head(ocr_feats, mask)
319
+ ocr_feats = self.ocr_distri_head(ocr_feats, context)
320
+ return ocr_feats,
321
+
322
+ def compute_hrnet_feats(self, x, additional_features, return_list=False):
323
+ x = self.compute_pre_stage_features(x, additional_features)
324
+ x = self.layer1(x)
325
+
326
+ x_list = []
327
+ for i in range(self.stage2_num_branches):
328
+ if self.transition1[i] is not None:
329
+ x_list.append(self.transition1[i](x))
330
+ else:
331
+ x_list.append(x)
332
+ y_list = self.stage2(x_list)
333
+
334
+ x_list = []
335
+ for i in range(self.stage3_num_branches):
336
+ if self.transition2[i] is not None:
337
+ if i < self.stage2_num_branches:
338
+ x_list.append(self.transition2[i](y_list[i]))
339
+ else:
340
+ x_list.append(self.transition2[i](y_list[-1]))
341
+ else:
342
+ x_list.append(y_list[i])
343
+ y_list = self.stage3(x_list)
344
+
345
+ x_list = []
346
+ for i in range(self.stage4_num_branches):
347
+ if self.transition3[i] is not None:
348
+ if i < self.stage3_num_branches:
349
+ x_list.append(self.transition3[i](y_list[i]))
350
+ else:
351
+ x_list.append(self.transition3[i](y_list[-1]))
352
+ else:
353
+ x_list.append(y_list[i])
354
+ x = self.stage4(x_list)
355
+
356
+ if return_list:
357
+ return x
358
+
359
+ # Upsampling
360
+ x0_h, x0_w = x[0].size(2), x[0].size(3)
361
+ x1 = F.interpolate(x[1], size=(x0_h, x0_w),
362
+ mode='bilinear', align_corners=self.align_corners)
363
+ x2 = F.interpolate(x[2], size=(x0_h, x0_w),
364
+ mode='bilinear', align_corners=self.align_corners)
365
+ x3 = F.interpolate(x[3], size=(x0_h, x0_w),
366
+ mode='bilinear', align_corners=self.align_corners)
367
+
368
+ return torch.cat([x[0], x1, x2, x3], 1)
369
+
370
+ def compute_pre_stage_features(self, x, additional_features):
371
+ x = self.conv1(x)
372
+ x = self.bn1(x)
373
+ x = self.relu(x)
374
+ if additional_features is not None:
375
+ x = x + additional_features
376
+ x = self.conv2(x)
377
+ x = self.bn2(x)
378
+ return self.relu(x)
379
+
380
+ def load_pretrained_weights(self, pretrained_path=''):
381
+ model_dict = self.state_dict()
382
+
383
+ if not os.path.exists(pretrained_path):
384
+ print(f'\nFile "{pretrained_path}" does not exist.')
385
+ print('You need to specify the correct path to the pre-trained weights.\n'
386
+ 'You can download the weights for HRNet from the repository:\n'
387
+ 'https://github.com/HRNet/HRNet-Image-Classification')
388
+ exit(1)
389
+ pretrained_dict = torch.load(pretrained_path, map_location={'cuda:0': 'cpu'})
390
+ pretrained_dict = {k.replace('last_layer', 'aux_head').replace('model.', ''): v for k, v in
391
+ pretrained_dict.items()}
392
+ params_count = len(pretrained_dict)
393
+
394
+ pretrained_dict = {k: v for k, v in pretrained_dict.items()
395
+ if k in model_dict.keys()}
396
+
397
+ print(f'Loaded {len(pretrained_dict)} of {params_count} pretrained parameters for HRNet')
398
+
399
+ model_dict.update(pretrained_dict)
400
+ self.load_state_dict(model_dict)
model/hrnetv2/modifiers.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ class LRMult(object):
4
+ def __init__(self, lr_mult=1.):
5
+ self.lr_mult = lr_mult
6
+
7
+ def __call__(self, m):
8
+ if getattr(m, 'weight', None) is not None:
9
+ m.weight.lr_mult = self.lr_mult
10
+ if getattr(m, 'bias', None) is not None:
11
+ m.bias.lr_mult = self.lr_mult
model/hrnetv2/ocr.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch._utils
4
+ import torch.nn.functional as F
5
+
6
+
7
+ class SpatialGather_Module(nn.Module):
8
+ """
9
+ Aggregate the context features according to the initial
10
+ predicted probability distribution.
11
+ Employ the soft-weighted method to aggregate the context.
12
+ """
13
+
14
+ def __init__(self, cls_num=0, scale=1):
15
+ super(SpatialGather_Module, self).__init__()
16
+ self.cls_num = cls_num
17
+ self.scale = scale
18
+
19
+ def forward(self, feats, probs):
20
+ batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3)
21
+ probs = probs.view(batch_size, c, -1)
22
+ feats = feats.view(batch_size, feats.size(1), -1)
23
+ feats = feats.permute(0, 2, 1) # batch x hw x c
24
+ probs = F.softmax(self.scale * probs, dim=2) # batch x k x hw
25
+ ocr_context = torch.matmul(probs, feats) \
26
+ .permute(0, 2, 1).unsqueeze(3).contiguous() # batch x k x c
27
+ return ocr_context
28
+
29
+
30
+ class SpatialOCR_Module(nn.Module):
31
+ """
32
+ Implementation of the OCR module:
33
+ We aggregate the global object representation to update the representation for each pixel.
34
+ """
35
+
36
+ def __init__(self,
37
+ in_channels,
38
+ key_channels,
39
+ out_channels,
40
+ scale=1,
41
+ dropout=0.1,
42
+ norm_layer=nn.BatchNorm2d,
43
+ align_corners=True, opt=None):
44
+ super(SpatialOCR_Module, self).__init__()
45
+ self.object_context_block = ObjectAttentionBlock2D(in_channels, key_channels, scale,
46
+ norm_layer, align_corners)
47
+ _in_channels = 2 * in_channels
48
+ self.conv_bn_dropout = nn.Sequential(
49
+ nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False),
50
+ nn.Sequential(norm_layer(out_channels), nn.ReLU(inplace=True)),
51
+ nn.Dropout2d(dropout)
52
+ )
53
+
54
+ def forward(self, feats, proxy_feats):
55
+ context = self.object_context_block(feats, proxy_feats)
56
+
57
+ output = self.conv_bn_dropout(torch.cat([context, feats], 1))
58
+
59
+ return output
60
+
61
+
62
+ class ObjectAttentionBlock2D(nn.Module):
63
+ '''
64
+ The basic implementation for object context block
65
+ Input:
66
+ N X C X H X W
67
+ Parameters:
68
+ in_channels : the dimension of the input feature map
69
+ key_channels : the dimension after the key/query transform
70
+ scale : choose the scale to downsample the input feature maps (save memory cost)
71
+ bn_type : specify the bn type
72
+ Return:
73
+ N X C X H X W
74
+ '''
75
+
76
+ def __init__(self,
77
+ in_channels,
78
+ key_channels,
79
+ scale=1,
80
+ norm_layer=nn.BatchNorm2d,
81
+ align_corners=True):
82
+ super(ObjectAttentionBlock2D, self).__init__()
83
+ self.scale = scale
84
+ self.in_channels = in_channels
85
+ self.key_channels = key_channels
86
+ self.align_corners = align_corners
87
+
88
+ self.pool = nn.MaxPool2d(kernel_size=(scale, scale))
89
+ self.f_pixel = nn.Sequential(
90
+ nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
91
+ kernel_size=1, stride=1, padding=0, bias=False),
92
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
93
+ nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
94
+ kernel_size=1, stride=1, padding=0, bias=False),
95
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
96
+ )
97
+ self.f_object = nn.Sequential(
98
+ nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
99
+ kernel_size=1, stride=1, padding=0, bias=False),
100
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True)),
101
+ nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels,
102
+ kernel_size=1, stride=1, padding=0, bias=False),
103
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
104
+ )
105
+ self.f_down = nn.Sequential(
106
+ nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels,
107
+ kernel_size=1, stride=1, padding=0, bias=False),
108
+ nn.Sequential(norm_layer(self.key_channels), nn.ReLU(inplace=True))
109
+ )
110
+ self.f_up = nn.Sequential(
111
+ nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels,
112
+ kernel_size=1, stride=1, padding=0, bias=False),
113
+ nn.Sequential(norm_layer(self.in_channels), nn.ReLU(inplace=True))
114
+ )
115
+
116
+ def forward(self, x, proxy):
117
+ batch_size, h, w = x.size(0), x.size(2), x.size(3)
118
+ if self.scale > 1:
119
+ x = self.pool(x)
120
+
121
+ query = self.f_pixel(x).view(batch_size, self.key_channels, -1)
122
+ query = query.permute(0, 2, 1)
123
+ key = self.f_object(proxy).view(batch_size, self.key_channels, -1)
124
+ value = self.f_down(proxy).view(batch_size, self.key_channels, -1)
125
+ value = value.permute(0, 2, 1)
126
+
127
+ sim_map = torch.matmul(query, key)
128
+ sim_map = (self.key_channels ** -.5) * sim_map
129
+ sim_map = F.softmax(sim_map, dim=-1)
130
+
131
+ # add bg context ...
132
+ context = torch.matmul(sim_map, value)
133
+ context = context.permute(0, 2, 1).contiguous()
134
+ context = context.view(batch_size, self.key_channels, *x.size()[2:])
135
+ context = self.f_up(context)
136
+ if self.scale > 1:
137
+ context = F.interpolate(input=context, size=(h, w),
138
+ mode='bilinear', align_corners=self.align_corners)
139
+
140
+ return context
model/hrnetv2/resnetv1b.py ADDED
@@ -0,0 +1,276 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ GLUON_RESNET_TORCH_HUB = 'rwightman/pytorch-pretrained-gluonresnet'
4
+
5
+
6
+ class BasicBlockV1b(nn.Module):
7
+ expansion = 1
8
+
9
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
10
+ previous_dilation=1, norm_layer=nn.BatchNorm2d):
11
+ super(BasicBlockV1b, self).__init__()
12
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride,
13
+ padding=dilation, dilation=dilation, bias=False)
14
+ self.bn1 = norm_layer(planes)
15
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1,
16
+ padding=previous_dilation, dilation=previous_dilation, bias=False)
17
+ self.bn2 = norm_layer(planes)
18
+
19
+ self.relu = nn.ReLU(inplace=True)
20
+ self.downsample = downsample
21
+ self.stride = stride
22
+
23
+ def forward(self, x):
24
+ residual = x
25
+
26
+ out = self.conv1(x)
27
+ out = self.bn1(out)
28
+ out = self.relu(out)
29
+
30
+ out = self.conv2(out)
31
+ out = self.bn2(out)
32
+
33
+ if self.downsample is not None:
34
+ residual = self.downsample(x)
35
+
36
+ out = out + residual
37
+ out = self.relu(out)
38
+
39
+ return out
40
+
41
+
42
+ class BottleneckV1b(nn.Module):
43
+ expansion = 4
44
+
45
+ def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None,
46
+ previous_dilation=1, norm_layer=nn.BatchNorm2d):
47
+ super(BottleneckV1b, self).__init__()
48
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
49
+ self.bn1 = norm_layer(planes)
50
+
51
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
52
+ padding=dilation, dilation=dilation, bias=False)
53
+ self.bn2 = norm_layer(planes)
54
+
55
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
56
+ self.bn3 = norm_layer(planes * self.expansion)
57
+
58
+ self.relu = nn.ReLU(inplace=True)
59
+ self.downsample = downsample
60
+ self.stride = stride
61
+
62
+ def forward(self, x):
63
+ residual = x
64
+
65
+ out = self.conv1(x)
66
+ out = self.bn1(out)
67
+ out = self.relu(out)
68
+
69
+ out = self.conv2(out)
70
+ out = self.bn2(out)
71
+ out = self.relu(out)
72
+
73
+ out = self.conv3(out)
74
+ out = self.bn3(out)
75
+
76
+ if self.downsample is not None:
77
+ residual = self.downsample(x)
78
+
79
+ out = out + residual
80
+ out = self.relu(out)
81
+
82
+ return out
83
+
84
+
85
+ class ResNetV1b(nn.Module):
86
+ """ Pre-trained ResNetV1b Model, which produces the strides of 8 featuremaps at conv5.
87
+
88
+ Parameters
89
+ ----------
90
+ block : Block
91
+ Class for the residual block. Options are BasicBlockV1, BottleneckV1.
92
+ layers : list of int
93
+ Numbers of layers in each block
94
+ classes : int, default 1000
95
+ Number of classification classes.
96
+ dilated : bool, default False
97
+ Applying dilation strategy to pretrained ResNet yielding a stride-8 model,
98
+ typically used in Semantic Segmentation.
99
+ norm_layer : object
100
+ Normalization layer used (default: :class:`nn.BatchNorm2d`)
101
+ deep_stem : bool, default False
102
+ Whether to replace the 7x7 conv1 with 3 3x3 convolution layers.
103
+ avg_down : bool, default False
104
+ Whether to use average pooling for projection skip connection between stages/downsample.
105
+ final_drop : float, default 0.0
106
+ Dropout ratio before the final classification layer.
107
+
108
+ Reference:
109
+ - He, Kaiming, et al. "Deep residual learning for image recognition."
110
+ Proceedings of the IEEE conference on computer vision and pattern recognition. 2016.
111
+
112
+ - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions."
113
+ """
114
+ def __init__(self, block, layers, classes=1000, dilated=True, deep_stem=False, stem_width=32,
115
+ avg_down=False, final_drop=0.0, norm_layer=nn.BatchNorm2d):
116
+ self.inplanes = stem_width*2 if deep_stem else 64
117
+ super(ResNetV1b, self).__init__()
118
+ if not deep_stem:
119
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
120
+ else:
121
+ self.conv1 = nn.Sequential(
122
+ nn.Conv2d(3, stem_width, kernel_size=3, stride=2, padding=1, bias=False),
123
+ norm_layer(stem_width),
124
+ nn.ReLU(True),
125
+ nn.Conv2d(stem_width, stem_width, kernel_size=3, stride=1, padding=1, bias=False),
126
+ norm_layer(stem_width),
127
+ nn.ReLU(True),
128
+ nn.Conv2d(stem_width, 2*stem_width, kernel_size=3, stride=1, padding=1, bias=False)
129
+ )
130
+ self.bn1 = norm_layer(self.inplanes)
131
+ self.relu = nn.ReLU(True)
132
+ self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
133
+ self.layer1 = self._make_layer(block, 64, layers[0], avg_down=avg_down,
134
+ norm_layer=norm_layer)
135
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2, avg_down=avg_down,
136
+ norm_layer=norm_layer)
137
+ if dilated:
138
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=1, dilation=2,
139
+ avg_down=avg_down, norm_layer=norm_layer)
140
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=1, dilation=4,
141
+ avg_down=avg_down, norm_layer=norm_layer)
142
+ else:
143
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
144
+ avg_down=avg_down, norm_layer=norm_layer)
145
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
146
+ avg_down=avg_down, norm_layer=norm_layer)
147
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
148
+ self.drop = None
149
+ if final_drop > 0.0:
150
+ self.drop = nn.Dropout(final_drop)
151
+ self.fc = nn.Linear(512 * block.expansion, classes)
152
+
153
+ def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
154
+ avg_down=False, norm_layer=nn.BatchNorm2d):
155
+ downsample = None
156
+ if stride != 1 or self.inplanes != planes * block.expansion:
157
+ downsample = []
158
+ if avg_down:
159
+ if dilation == 1:
160
+ downsample.append(
161
+ nn.AvgPool2d(kernel_size=stride, stride=stride, ceil_mode=True, count_include_pad=False)
162
+ )
163
+ else:
164
+ downsample.append(
165
+ nn.AvgPool2d(kernel_size=1, stride=1, ceil_mode=True, count_include_pad=False)
166
+ )
167
+ downsample.extend([
168
+ nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
169
+ kernel_size=1, stride=1, bias=False),
170
+ norm_layer(planes * block.expansion)
171
+ ])
172
+ downsample = nn.Sequential(*downsample)
173
+ else:
174
+ downsample = nn.Sequential(
175
+ nn.Conv2d(self.inplanes, out_channels=planes * block.expansion,
176
+ kernel_size=1, stride=stride, bias=False),
177
+ norm_layer(planes * block.expansion)
178
+ )
179
+
180
+ layers = []
181
+ if dilation in (1, 2):
182
+ layers.append(block(self.inplanes, planes, stride, dilation=1, downsample=downsample,
183
+ previous_dilation=dilation, norm_layer=norm_layer))
184
+ elif dilation == 4:
185
+ layers.append(block(self.inplanes, planes, stride, dilation=2, downsample=downsample,
186
+ previous_dilation=dilation, norm_layer=norm_layer))
187
+ else:
188
+ raise RuntimeError("=> unknown dilation size: {}".format(dilation))
189
+
190
+ self.inplanes = planes * block.expansion
191
+ for _ in range(1, blocks):
192
+ layers.append(block(self.inplanes, planes, dilation=dilation,
193
+ previous_dilation=dilation, norm_layer=norm_layer))
194
+
195
+ return nn.Sequential(*layers)
196
+
197
+ def forward(self, x):
198
+ x = self.conv1(x)
199
+ x = self.bn1(x)
200
+ x = self.relu(x)
201
+ x = self.maxpool(x)
202
+
203
+ x = self.layer1(x)
204
+ x = self.layer2(x)
205
+ x = self.layer3(x)
206
+ x = self.layer4(x)
207
+
208
+ x = self.avgpool(x)
209
+ x = x.view(x.size(0), -1)
210
+ if self.drop is not None:
211
+ x = self.drop(x)
212
+ x = self.fc(x)
213
+
214
+ return x
215
+
216
+
217
+ def _safe_state_dict_filtering(orig_dict, model_dict_keys):
218
+ filtered_orig_dict = {}
219
+ for k, v in orig_dict.items():
220
+ if k in model_dict_keys:
221
+ filtered_orig_dict[k] = v
222
+ else:
223
+ print(f"[ERROR] Failed to load <{k}> in backbone")
224
+ return filtered_orig_dict
225
+
226
+
227
+ def resnet34_v1b(pretrained=False, **kwargs):
228
+ model = ResNetV1b(BasicBlockV1b, [3, 4, 6, 3], **kwargs)
229
+ if pretrained:
230
+ model_dict = model.state_dict()
231
+ filtered_orig_dict = _safe_state_dict_filtering(
232
+ torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet34_v1b', pretrained=True).state_dict(),
233
+ model_dict.keys()
234
+ )
235
+ model_dict.update(filtered_orig_dict)
236
+ model.load_state_dict(model_dict)
237
+ return model
238
+
239
+
240
+ def resnet50_v1s(pretrained=False, **kwargs):
241
+ model = ResNetV1b(BottleneckV1b, [3, 4, 6, 3], deep_stem=True, stem_width=64, **kwargs)
242
+ if pretrained:
243
+ model_dict = model.state_dict()
244
+ filtered_orig_dict = _safe_state_dict_filtering(
245
+ torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet50_v1s', pretrained=True).state_dict(),
246
+ model_dict.keys()
247
+ )
248
+ model_dict.update(filtered_orig_dict)
249
+ model.load_state_dict(model_dict)
250
+ return model
251
+
252
+
253
+ def resnet101_v1s(pretrained=False, **kwargs):
254
+ model = ResNetV1b(BottleneckV1b, [3, 4, 23, 3], deep_stem=True, stem_width=64, **kwargs)
255
+ if pretrained:
256
+ model_dict = model.state_dict()
257
+ filtered_orig_dict = _safe_state_dict_filtering(
258
+ torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet101_v1s', pretrained=True).state_dict(),
259
+ model_dict.keys()
260
+ )
261
+ model_dict.update(filtered_orig_dict)
262
+ model.load_state_dict(model_dict)
263
+ return model
264
+
265
+
266
+ def resnet152_v1s(pretrained=False, **kwargs):
267
+ model = ResNetV1b(BottleneckV1b, [3, 8, 36, 3], deep_stem=True, stem_width=64, **kwargs)
268
+ if pretrained:
269
+ model_dict = model.state_dict()
270
+ filtered_orig_dict = _safe_state_dict_filtering(
271
+ torch.hub.load(GLUON_RESNET_TORCH_HUB, 'gluon_resnet152_v1s', pretrained=True).state_dict(),
272
+ model_dict.keys()
273
+ )
274
+ model_dict.update(filtered_orig_dict)
275
+ model.load_state_dict(model_dict)
276
+ return model
model/lut_transformation_net.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from utils.misc import normalize
6
+
7
+
8
+ class build_lut_transform(nn.Module):
9
+
10
+ def __init__(self, input_dim, lut_dim, input_resolution, opt):
11
+ super().__init__()
12
+
13
+ self.lut_dim = lut_dim
14
+ self.opt = opt
15
+
16
+ # self.compress_layer = nn.Linear(input_resolution, 1)
17
+
18
+ self.transform_layers = nn.Sequential(
19
+ nn.Linear(input_dim, 3 * lut_dim ** 3, bias=True),
20
+ # nn.BatchNorm1d(3 * lut_dim ** 3, affine=False),
21
+ nn.ReLU(inplace=True),
22
+ nn.Linear(3 * lut_dim ** 3, 3 * lut_dim ** 3, bias=True),
23
+ )
24
+ self.transform_layers[-1].apply(lambda m: hyper_weight_init(m))
25
+
26
+ def forward(self, composite_image, fg_appearance_features, bg_appearance_features):
27
+ composite_image = normalize(composite_image, self.opt, 'inv')
28
+
29
+ features = fg_appearance_features
30
+
31
+ lut_params = self.transform_layers(features)
32
+
33
+ fit_3DLUT = lut_params.view(lut_params.shape[0], 3, self.lut_dim, self.lut_dim, self.lut_dim)
34
+
35
+ lut_transform_image = torch.stack(
36
+ [TrilinearInterpolation(lut, image)[0] for lut, image in zip(fit_3DLUT, composite_image)], dim=0)
37
+
38
+ return fit_3DLUT, normalize(lut_transform_image, self.opt)
39
+
40
+
41
+ def TrilinearInterpolation(LUT, img):
42
+ img = (img - 0.5) * 2.
43
+
44
+ img = img.unsqueeze(0).permute(0, 2, 3, 1)[:, None].flip(-1)
45
+
46
+ # Note that the coordinates in the grid_sample are inverse to LUT DHW, i.e., xyz is to WHD not DHW.
47
+ LUT = LUT[None]
48
+
49
+ # grid sample
50
+ result = F.grid_sample(LUT, img, mode='bilinear', padding_mode='border', align_corners=True)
51
+
52
+ # drop added dimensions and permute back
53
+ result = result[:, :, 0]
54
+
55
+ return result
56
+
57
+
58
+ def hyper_weight_init(m):
59
+ if hasattr(m, 'weight'):
60
+ nn.init.kaiming_normal_(m.weight, a=0.0, nonlinearity='relu', mode='fan_in')
61
+ m.weight.data = m.weight.data / 1.e2
62
+
63
+ if hasattr(m, 'bias'):
64
+ with torch.no_grad():
65
+ m.bias.uniform_(0., 1.)
pretrained_models/Resolution_1024_HAdobe5K.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4917e99cc20c2530b6d248d530368929c1784113d20365085b96bbb10860a2f8
3
+ size 477235439