Andy1621 committed on
Commit 298eceb
1 Parent(s): c2b0268

Delete transform.py

Files changed (1)
  1. transform.py +0 -443
transform.py DELETED
@@ -1,443 +0,0 @@
- import torchvision
- import random
- from PIL import Image, ImageOps
- import numpy as np
- import numbers
- import math
- import torch
-
-
- class GroupRandomCrop(object):
-     def __init__(self, size):
-         if isinstance(size, numbers.Number):
-             self.size = (int(size), int(size))
-         else:
-             self.size = size
-
-     def __call__(self, img_group):
-
-         w, h = img_group[0].size
-         th, tw = self.size
-
-         out_images = list()
-
-         x1 = random.randint(0, w - tw)
-         y1 = random.randint(0, h - th)
-
-         for img in img_group:
-             assert(img.size[0] == w and img.size[1] == h)
-             if w == tw and h == th:
-                 out_images.append(img)
-             else:
-                 out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
-
-         return out_images
-
-
- class MultiGroupRandomCrop(object):
-     def __init__(self, size, groups=1):
-         if isinstance(size, numbers.Number):
-             self.size = (int(size), int(size))
-         else:
-             self.size = size
-         self.groups = groups
-
-     def __call__(self, img_group):
-
-         w, h = img_group[0].size
-         th, tw = self.size
-
-         out_images = list()
-
-         for i in range(self.groups):
-             x1 = random.randint(0, w - tw)
-             y1 = random.randint(0, h - th)
-
-             for img in img_group:
-                 assert(img.size[0] == w and img.size[1] == h)
-                 if w == tw and h == th:
-                     out_images.append(img)
-                 else:
-                     out_images.append(img.crop((x1, y1, x1 + tw, y1 + th)))
-
-         return out_images
-
-
- class GroupCenterCrop(object):
-     def __init__(self, size):
-         self.worker = torchvision.transforms.CenterCrop(size)
-
-     def __call__(self, img_group):
-         return [self.worker(img) for img in img_group]
-
-
- class GroupRandomHorizontalFlip(object):
-     """Randomly horizontally flips the given PIL.Image with a probability of 0.5
-     """
-
-     def __init__(self, is_flow=False):
-         self.is_flow = is_flow
-
-     def __call__(self, img_group, is_flow=False):
-         v = random.random()
-         if v < 0.5:
-             ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
-             if self.is_flow:
-                 for i in range(0, len(ret), 2):
-                     # invert flow pixel values when flipping
-                     ret[i] = ImageOps.invert(ret[i])
-             return ret
-         else:
-             return img_group
-
-
- class GroupNormalize(object):
-     def __init__(self, mean, std):
-         self.mean = mean
-         self.std = std
-
-     def __call__(self, tensor):
-         rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
-         rep_std = self.std * (tensor.size()[0] // len(self.std))
-
-         # TODO: make efficient
-         for t, m, s in zip(tensor, rep_mean, rep_std):
-             t.sub_(m).div_(s)
-
-         return tensor
-
-
- class GroupScale(object):
-     """ Rescales the input PIL.Image to the given 'size'.
-     'size' will be the size of the smaller edge.
-     For example, if height > width, then image will be
-     rescaled to (size * height / width, size)
-     size: size of the smaller edge
-     interpolation: Default: PIL.Image.BILINEAR
-     """
-
-     def __init__(self, size, interpolation=Image.BILINEAR):
-         self.worker = torchvision.transforms.Resize(size, interpolation)
-
-     def __call__(self, img_group):
-         return [self.worker(img) for img in img_group]
-
-
- class GroupOverSample(object):
-     def __init__(self, crop_size, scale_size=None, flip=True):
-         self.crop_size = crop_size if not isinstance(
-             crop_size, int) else (crop_size, crop_size)
-
-         if scale_size is not None:
-             self.scale_worker = GroupScale(scale_size)
-         else:
-             self.scale_worker = None
-         self.flip = flip
-
-     def __call__(self, img_group):
-
-         if self.scale_worker is not None:
-             img_group = self.scale_worker(img_group)
-
-         image_w, image_h = img_group[0].size
-         crop_w, crop_h = self.crop_size
-
-         offsets = GroupMultiScaleCrop.fill_fix_offset(
-             False, image_w, image_h, crop_w, crop_h)
-         oversample_group = list()
-         for o_w, o_h in offsets:
-             normal_group = list()
-             flip_group = list()
-             for i, img in enumerate(img_group):
-                 crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
-                 normal_group.append(crop)
-                 flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
-
-                 if img.mode == 'L' and i % 2 == 0:
-                     flip_group.append(ImageOps.invert(flip_crop))
-                 else:
-                     flip_group.append(flip_crop)
-
-             oversample_group.extend(normal_group)
-             if self.flip:
-                 oversample_group.extend(flip_group)
-         return oversample_group
-
-
- class GroupFullResSample(object):
-     def __init__(self, crop_size, scale_size=None, flip=True):
-         self.crop_size = crop_size if not isinstance(
-             crop_size, int) else (crop_size, crop_size)
-
-         if scale_size is not None:
-             self.scale_worker = GroupScale(scale_size)
-         else:
-             self.scale_worker = None
-         self.flip = flip
-
-     def __call__(self, img_group):
-
-         if self.scale_worker is not None:
-             img_group = self.scale_worker(img_group)
-
-         image_w, image_h = img_group[0].size
-         crop_w, crop_h = self.crop_size
-
-         w_step = (image_w - crop_w) // 4
-         h_step = (image_h - crop_h) // 4
-
-         offsets = list()
-         offsets.append((0 * w_step, 2 * h_step))  # left
-         offsets.append((4 * w_step, 2 * h_step))  # right
-         offsets.append((2 * w_step, 2 * h_step))  # center
-
-         oversample_group = list()
-         for o_w, o_h in offsets:
-             normal_group = list()
-             flip_group = list()
-             for i, img in enumerate(img_group):
-                 crop = img.crop((o_w, o_h, o_w + crop_w, o_h + crop_h))
-                 normal_group.append(crop)
-                 if self.flip:
-                     flip_crop = crop.copy().transpose(Image.FLIP_LEFT_RIGHT)
-
-                     if img.mode == 'L' and i % 2 == 0:
-                         flip_group.append(ImageOps.invert(flip_crop))
-                     else:
-                         flip_group.append(flip_crop)
-
-             oversample_group.extend(normal_group)
-             oversample_group.extend(flip_group)
-         return oversample_group
-
-
- class GroupMultiScaleCrop(object):
-
-     def __init__(self, input_size, scales=None, max_distort=1,
-                  fix_crop=True, more_fix_crop=True):
-         self.scales = scales if scales is not None else [1, .875, .75, .66]
-         self.max_distort = max_distort
-         self.fix_crop = fix_crop
-         self.more_fix_crop = more_fix_crop
-         self.input_size = input_size if not isinstance(input_size, int) else [
-             input_size, input_size]
-         self.interpolation = Image.BILINEAR
-
-     def __call__(self, img_group):
-
-         im_size = img_group[0].size
-
-         crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
-         crop_img_group = [
-             img.crop(
-                 (offset_w,
-                  offset_h,
-                  offset_w +
-                  crop_w,
-                  offset_h +
-                  crop_h)) for img in img_group]
-         ret_img_group = [img.resize((self.input_size[0], self.input_size[1]), self.interpolation)
-                          for img in crop_img_group]
-         return ret_img_group
-
-     def _sample_crop_size(self, im_size):
-         image_w, image_h = im_size[0], im_size[1]
-
-         # find a crop size
-         base_size = min(image_w, image_h)
-         crop_sizes = [int(base_size * x) for x in self.scales]
-         crop_h = [
-             self.input_size[1] if abs(
-                 x - self.input_size[1]) < 3 else x for x in crop_sizes]
-         crop_w = [
-             self.input_size[0] if abs(
-                 x - self.input_size[0]) < 3 else x for x in crop_sizes]
-
-         pairs = []
-         for i, h in enumerate(crop_h):
-             for j, w in enumerate(crop_w):
-                 if abs(i - j) <= self.max_distort:
-                     pairs.append((w, h))
-
-         crop_pair = random.choice(pairs)
-         if not self.fix_crop:
-             w_offset = random.randint(0, image_w - crop_pair[0])
-             h_offset = random.randint(0, image_h - crop_pair[1])
-         else:
-             w_offset, h_offset = self._sample_fix_offset(
-                 image_w, image_h, crop_pair[0], crop_pair[1])
-
-         return crop_pair[0], crop_pair[1], w_offset, h_offset
-
-     def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
-         offsets = self.fill_fix_offset(
-             self.more_fix_crop, image_w, image_h, crop_w, crop_h)
-         return random.choice(offsets)
-
-     @staticmethod
-     def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
-         w_step = (image_w - crop_w) // 4
-         h_step = (image_h - crop_h) // 4
-
-         ret = list()
-         ret.append((0, 0))  # upper left
-         ret.append((4 * w_step, 0))  # upper right
-         ret.append((0, 4 * h_step))  # lower left
-         ret.append((4 * w_step, 4 * h_step))  # lower right
-         ret.append((2 * w_step, 2 * h_step))  # center
-
-         if more_fix_crop:
-             ret.append((0, 2 * h_step))  # center left
-             ret.append((4 * w_step, 2 * h_step))  # center right
-             ret.append((2 * w_step, 4 * h_step))  # lower center
-             ret.append((2 * w_step, 0 * h_step))  # upper center
-
-             ret.append((1 * w_step, 1 * h_step))  # upper left quarter
-             ret.append((3 * w_step, 1 * h_step))  # upper right quarter
-             ret.append((1 * w_step, 3 * h_step))  # lower left quarter
-             ret.append((3 * w_step, 3 * h_step))  # lower right quarter
-
-         return ret
-
-
- class GroupRandomSizedCrop(object):
-     """Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
-     and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
-     This is popularly used to train the Inception networks
-     size: size of the smaller edge
-     interpolation: Default: PIL.Image.BILINEAR
-     """
-
-     def __init__(self, size, interpolation=Image.BILINEAR):
-         self.size = size
-         self.interpolation = interpolation
-
-     def __call__(self, img_group):
-         for attempt in range(10):
-             area = img_group[0].size[0] * img_group[0].size[1]
-             target_area = random.uniform(0.08, 1.0) * area
-             aspect_ratio = random.uniform(3. / 4, 4. / 3)
-
-             w = int(round(math.sqrt(target_area * aspect_ratio)))
-             h = int(round(math.sqrt(target_area / aspect_ratio)))
-
-             if random.random() < 0.5:
-                 w, h = h, w
-
-             if w <= img_group[0].size[0] and h <= img_group[0].size[1]:
-                 x1 = random.randint(0, img_group[0].size[0] - w)
-                 y1 = random.randint(0, img_group[0].size[1] - h)
-                 found = True
-                 break
-         else:
-             found = False
-             x1 = 0
-             y1 = 0
-
-         if found:
-             out_group = list()
-             for img in img_group:
-                 img = img.crop((x1, y1, x1 + w, y1 + h))
-                 assert(img.size == (w, h))
-                 out_group.append(
-                     img.resize(
-                         (self.size, self.size), self.interpolation))
-             return out_group
-         else:
-             # Fallback
-             scale = GroupScale(self.size, interpolation=self.interpolation)
-             crop = GroupRandomCrop(self.size)
-             return crop(scale(img_group))
-
-
- class ConvertDataFormat(object):
-     def __init__(self, model_type):
-         self.model_type = model_type
-
-     def __call__(self, images):
-         if self.model_type == '2D':
-             return images
-         tc, h, w = images.size()
-         t = tc // 3
-         images = images.view(t, 3, h, w)
-         images = images.permute(1, 0, 2, 3)
-         return images
-
-
- class Stack(object):
-
-     def __init__(self, roll=False):
-         self.roll = roll
-
-     def __call__(self, img_group):
-         if img_group[0].mode == 'L':
-             return np.concatenate([np.expand_dims(x, 2)
-                                    for x in img_group], axis=2)
-         elif img_group[0].mode == 'RGB':
-             if self.roll:
-                 return np.concatenate([np.array(x)[:, :, ::-1]
-                                        for x in img_group], axis=2)
-             else:
-                 # print(np.concatenate(img_group, axis=2).shape)
-                 # print(img_group[0].shape)
-                 return np.concatenate(img_group, axis=2)
-
-
- class ToTorchFormatTensor(object):
-     """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
-     to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
-
-     def __init__(self, div=True):
-         self.div = div
-
-     def __call__(self, pic):
-         if isinstance(pic, np.ndarray):
-             # handle numpy array
-             img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
-         else:
-             # handle PIL Image
-             img = torch.ByteTensor(
-                 torch.ByteStorage.from_buffer(
-                     pic.tobytes()))
-             img = img.view(pic.size[1], pic.size[0], len(pic.mode))
-             # put it from HWC to CHW format
-             # yikes, this transpose takes 80% of the loading time/CPU
-             img = img.transpose(0, 1).transpose(0, 2).contiguous()
-         return img.float().div(255) if self.div else img.float()
-
-
- class IdentityTransform(object):
-
-     def __call__(self, data):
-         return data
-
-
- if __name__ == "__main__":
-     trans = torchvision.transforms.Compose([
-         GroupScale(256),
-         GroupRandomCrop(224),
-         Stack(),
-         ToTorchFormatTensor(),
-         GroupNormalize(
-             mean=[.485, .456, .406],
-             std=[.229, .224, .225]
-         )]
-     )
-
-     im = Image.open('../tensorflow-model-zoo.torch/lena_299.png')
-
-     color_group = [im] * 3
-     rst = trans(color_group)
-
-     gray_group = [im.convert('L')] * 9
-     gray_rst = trans(gray_group)
-
-     trans2 = torchvision.transforms.Compose([
-         GroupRandomSizedCrop(256),
-         Stack(),
-         ToTorchFormatTensor(),
-         GroupNormalize(
-             mean=[.485, .456, .406],
-             std=[.229, .224, .225])
-     ])
-     print(trans2(color_group))
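
For reference, a minimal usage sketch of the group transforms removed in this commit, showing the tensor layout they produce. The frame size, frame count, and the '3D' model type below are illustrative assumptions, not part of the commit, and the import assumes a local copy of the deleted transform.py.

# Sketch only: assumes a local copy of the deleted transform.py is importable;
# the dummy frames and their size are made up for illustration.
import torchvision
from PIL import Image
from transform import (GroupScale, GroupRandomCrop, Stack,
                       ToTorchFormatTensor, GroupNormalize, ConvertDataFormat)

frames = [Image.new('RGB', (340, 256)) for _ in range(8)]  # 8 dummy RGB frames

pipeline = torchvision.transforms.Compose([
    GroupScale(256),        # shorter edge of every frame -> 256
    GroupRandomCrop(224),   # one shared 224x224 crop applied to all frames
    Stack(),                # concatenate along channels -> H x W x (3*T) ndarray
    ToTorchFormatTensor(),  # -> float tensor of shape (3*T, H, W) in [0, 1]
    GroupNormalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
])

clip = pipeline(frames)
print(clip.shape)                            # torch.Size([24, 224, 224])
print(ConvertDataFormat('3D')(clip).shape)   # torch.Size([3, 8, 224, 224])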