Feng Wang commited on
Commit
1b23f0f
·
1 Parent(s): 67f8016

refactor(data): refactor dataset logic.

Browse files
yolox/data/datasets/mosaicdetection.py CHANGED
@@ -13,22 +13,51 @@ from ..data_augment import box_candidates, random_perspective
13
  from .datasets_wrapper import Dataset
14
 
15
 
16
- class MosaicDetection(Dataset):
17
- """Detection dataset wrapper that performs mixup for normal dataset.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- Parameters
20
- ----------
21
- dataset : Pytorch Dataset
22
- Gluon dataset object.
23
- *args : list
24
- Additional arguments for mixup random sampler.
25
- """
26
 
27
  def __init__(
28
  self, dataset, img_size, mosaic=True, preproc=None,
29
  degrees=10.0, translate=0.1, scale=(0.5, 1.5), mscale=(0.5, 1.5),
30
  shear=2.0, perspective=0.0, enable_mixup=True, *args
31
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  super().__init__(img_size, mosaic=mosaic)
33
  self._dataset = dataset
34
  self.preproc = preproc
@@ -38,7 +67,7 @@ class MosaicDetection(Dataset):
38
  self.shear = shear
39
  self.perspective = perspective
40
  self.mixup_scale = mscale
41
- self._mosaic = mosaic
42
  self.enable_mixup = enable_mixup
43
 
44
  def __len__(self):
@@ -46,79 +75,71 @@ class MosaicDetection(Dataset):
46
 
47
  @Dataset.resize_getitem
48
  def __getitem__(self, idx):
49
- if self._mosaic:
50
- labels4 = []
51
  input_dim = self._dataset.input_dim
 
 
52
  # yc, xc = s, s # mosaic center x, y
53
- yc = int(random.uniform(0.5 * input_dim[0], 1.5 * input_dim[0]))
54
- xc = int(random.uniform(0.5 * input_dim[1], 1.5 * input_dim[1]))
55
 
56
  # 3 additional image indices
57
  indices = [idx] + [random.randint(0, len(self._dataset) - 1) for _ in range(3)]
58
 
59
- for i, index in enumerate(indices):
60
  img, _labels, _, _ = self._dataset.pull_item(index)
61
  h0, w0 = img.shape[:2] # orig hw
62
- scale = min(1. * input_dim[0] / h0, 1. * input_dim[1] / w0)
63
- interp = cv2.INTER_LINEAR
64
- img = cv2.resize(img, (int(w0 * scale), int(h0 * scale)), interpolation=interp)
65
- (h, w) = img.shape[:2]
66
-
67
- if i == 0: # top left
68
- # base image with 4 tiles
69
- img4 = np.full(
70
- (input_dim[0] * 2, input_dim[1] * 2, img.shape[2]), 114, dtype=np.uint8
71
- )
72
- # xmin, ymin, xmax, ymax (large image)
73
- x1a, y1a, x2a, y2a = (max(xc - w, 0), max(yc - h, 0), xc, yc,)
74
- # xmin, ymin, xmax, ymax (small image)
75
- x1b, y1b, x2b, y2b = (w - (x2a - x1a), h - (y2a - y1a), w, h,)
76
- elif i == 1: # top right
77
- x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, input_dim[1] * 2), yc
78
- x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
79
- elif i == 2: # bottom left
80
- x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(input_dim[0] * 2, yc + h)
81
- x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
82
- elif i == 3: # bottom right
83
- x1a, y1a, x2a, y2a = xc, yc, min(xc + w, input_dim[1] * 2), min(input_dim[0] * 2, yc + h) # noqa
84
- x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
85
-
86
- img4[y1a:y2a, x1a:x2a] = img[y1b:y2b, x1b:x2b] # img4[ymin:ymax, xmin:xmax]
87
- padw = x1a - x1b
88
- padh = y1a - y1b
89
-
90
- labels = _labels.copy() # [[xmin, ymin, xmax, ymax, label_ind], ... ]
91
- if _labels.size > 0: # Normalized xywh to pixel xyxy format
92
  labels[:, 0] = scale * _labels[:, 0] + padw
93
  labels[:, 1] = scale * _labels[:, 1] + padh
94
  labels[:, 2] = scale * _labels[:, 2] + padw
95
  labels[:, 3] = scale * _labels[:, 3] + padh
96
- labels4.append(labels)
97
-
98
- if len(labels4):
99
- labels4 = np.concatenate(labels4, 0)
100
- np.clip(labels4[:, 0], 0, 2 * input_dim[1], out=labels4[:, 0])
101
- np.clip(labels4[:, 1], 0, 2 * input_dim[0], out=labels4[:, 1])
102
- np.clip(labels4[:, 2], 0, 2 * input_dim[1], out=labels4[:, 2])
103
- np.clip(labels4[:, 3], 0, 2 * input_dim[0], out=labels4[:, 3])
104
-
105
- img4, labels4 = random_perspective(
106
- img4,
107
- labels4,
108
  degrees=self.degrees,
109
  translate=self.translate,
110
  scale=self.scale,
111
  shear=self.shear,
112
  perspective=self.perspective,
113
- border=[-input_dim[0] // 2, -input_dim[1] // 2],
114
  ) # border to remove
115
 
116
  # -----------------------------------------------------------------
117
  # CopyPaste: https://arxiv.org/abs/2012.07177
118
  # -----------------------------------------------------------------
119
- if self.enable_mixup and not len(labels4) == 0:
120
- img4, labels4 = self.mixup(img4, labels4, self.input_dim)
121
- mix_img, padded_labels = self.preproc(img4, labels4, self.input_dim)
122
  img_info = (mix_img.shape[1], mix_img.shape[0])
123
 
124
  return mix_img, padded_labels, img_info, int(idx)
 
13
  from .datasets_wrapper import Dataset
14
 
15
 
16
+ def get_mosaic_coordinate(mosaic_image, mosaic_index, xc, yc, w, h, input_h, input_w):
17
+ # TODO update doc
18
+ # index0 to top left part of image
19
+ if mosaic_index == 0:
20
+ x1, y1, x2, y2 = max(xc - w, 0), max(yc - h, 0), xc, yc
21
+ small_coord = w - (x2 - x1), h - (y2 - y1), w, h
22
+ # index1 to top right part of image
23
+ elif mosaic_index == 1:
24
+ x1, y1, x2, y2 = xc, max(yc - h, 0), min(xc + w, input_w * 2), yc
25
+ small_coord = 0, h - (y2 - y1), min(w, x2 - x1), h
26
+ # index2 to bottom left part of image
27
+ elif mosaic_index == 2:
28
+ x1, y1, x2, y2 = max(xc - w, 0), yc, xc, min(input_h * 2, yc + h)
29
+ small_coord = w - (x2 - x1), 0, w, min(y2 - y1, h)
30
+ # index2 to bottom right part of image
31
+ elif mosaic_index == 3:
32
+ x1, y1, x2, y2 = xc, yc, min(xc + w, input_w * 2), min(input_h * 2, yc + h) # noqa
33
+ small_coord = 0, 0, min(w, x2 - x1), min(y2 - y1, h)
34
+ return (x1, y1, x2, y2), small_coord
35
+
36
 
37
+ class MosaicDetection(Dataset):
38
+ """Detection dataset wrapper that performs mixup for normal dataset."""
 
 
 
 
 
39
 
40
  def __init__(
41
  self, dataset, img_size, mosaic=True, preproc=None,
42
  degrees=10.0, translate=0.1, scale=(0.5, 1.5), mscale=(0.5, 1.5),
43
  shear=2.0, perspective=0.0, enable_mixup=True, *args
44
  ):
45
+ """
46
+
47
+ Args:
48
+ dataset(Dataset) : Pytorch dataset object.
49
+ img_size (tuple):
50
+ mosaic (bool): enable mosaic augmentation or not.
51
+ preproc (func):
52
+ degrees (float):
53
+ translate (float):
54
+ scale (tuple):
55
+ mscale (tuple):
56
+ shear (float):
57
+ perspective (float):
58
+ enable_mixup (bool):
59
+ *args(tuple) : Additional arguments for mixup random sampler.
60
+ """
61
  super().__init__(img_size, mosaic=mosaic)
62
  self._dataset = dataset
63
  self.preproc = preproc
 
67
  self.shear = shear
68
  self.perspective = perspective
69
  self.mixup_scale = mscale
70
+ self.enable_mosaic = mosaic
71
  self.enable_mixup = enable_mixup
72
 
73
  def __len__(self):
 
75
 
76
  @Dataset.resize_getitem
77
  def __getitem__(self, idx):
78
+ if self.enable_mosaic:
79
+ mosaic_labels = []
80
  input_dim = self._dataset.input_dim
81
+ input_h, input_w = input_dim[0], input_dim[1]
82
+
83
  # yc, xc = s, s # mosaic center x, y
84
+ yc = int(random.uniform(0.5 * input_h, 1.5 * input_h))
85
+ xc = int(random.uniform(0.5 * input_w, 1.5 * input_w))
86
 
87
  # 3 additional image indices
88
  indices = [idx] + [random.randint(0, len(self._dataset) - 1) for _ in range(3)]
89
 
90
+ for i_mosaic, index in enumerate(indices):
91
  img, _labels, _, _ = self._dataset.pull_item(index)
92
  h0, w0 = img.shape[:2] # orig hw
93
+ scale = min(1. * input_h / h0, 1. * input_w / w0)
94
+ img = cv2.resize(
95
+ img, (int(w0 * scale), int(h0 * scale)), interpolation=cv2.INTER_LINEAR
96
+ )
97
+ # generate output mosaic image
98
+ (h, w, c) = img.shape[:3]
99
+ if i_mosaic == 0:
100
+ mosaic_img = np.full((input_h * 2, input_w * 2, c), 114, dtype=np.uint8)
101
+
102
+ # suffix l means large image, while s means small image in mosaic aug.
103
+ (l_x1, l_y1, l_x2, l_y2), (s_x1, s_y1, s_x2, s_y2) = get_mosaic_coordinate(
104
+ mosaic_img, i_mosaic, xc, yc, w, h, input_h, input_w
105
+ )
106
+
107
+ mosaic_img[l_y1:l_y2, l_x1:l_x2] = img[s_y1:s_y2, s_x1:s_x2]
108
+ padw, padh = l_x1 - s_x1, l_y1 - s_y1
109
+
110
+ labels = _labels.copy()
111
+ # Normalized xywh to pixel xyxy format
112
+ if _labels.size > 0:
 
 
 
 
 
 
 
 
 
 
113
  labels[:, 0] = scale * _labels[:, 0] + padw
114
  labels[:, 1] = scale * _labels[:, 1] + padh
115
  labels[:, 2] = scale * _labels[:, 2] + padw
116
  labels[:, 3] = scale * _labels[:, 3] + padh
117
+ mosaic_labels.append(labels)
118
+
119
+ if len(mosaic_labels):
120
+ mosaic_labels = np.concatenate(mosaic_labels, 0)
121
+ np.clip(mosaic_labels[:, 0], 0, 2 * input_w, out=mosaic_labels[:, 0])
122
+ np.clip(mosaic_labels[:, 1], 0, 2 * input_h, out=mosaic_labels[:, 1])
123
+ np.clip(mosaic_labels[:, 2], 0, 2 * input_w, out=mosaic_labels[:, 2])
124
+ np.clip(mosaic_labels[:, 3], 0, 2 * input_h, out=mosaic_labels[:, 3])
125
+
126
+ mosaic_img, mosaic_labels = random_perspective(
127
+ mosaic_img,
128
+ mosaic_labels,
129
  degrees=self.degrees,
130
  translate=self.translate,
131
  scale=self.scale,
132
  shear=self.shear,
133
  perspective=self.perspective,
134
+ border=[-input_h // 2, -input_w // 2],
135
  ) # border to remove
136
 
137
  # -----------------------------------------------------------------
138
  # CopyPaste: https://arxiv.org/abs/2012.07177
139
  # -----------------------------------------------------------------
140
+ if self.enable_mixup and not len(mosaic_labels) == 0:
141
+ mosaic_img, mosaic_labels = self.mixup(mosaic_img, mosaic_labels, self.input_dim)
142
+ mix_img, padded_labels = self.preproc(mosaic_img, mosaic_labels, self.input_dim)
143
  img_info = (mix_img.shape[1], mix_img.shape[0])
144
 
145
  return mix_img, padded_labels, img_info, int(idx)