|
from typing import List |
|
|
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from torchvision.transforms import functional as TF |
|
|
|
|
|
class AugmentationComposer:
    """Composes several transforms, then letterboxes and tensorizes the result.

    Transforms that need dataset access (e.g. Mosaic, MixUp) receive this
    composer as their parent via ``set_parent``.
    """

    def __init__(self, transforms, image_size: List[int] = (640, 640), base_size: int = 640):
        """
        Args:
            transforms: Sequence of callables mapping (image, boxes) -> (image, boxes).
            image_size: Target (width, height) for the final PadAndResize step.
                NOTE: default changed from a mutable list to a tuple — a shared
                list default could be mutated across instances.
            base_size: Base cell size consumed by transforms such as Mosaic.
        """
        self.transforms = transforms

        self.pad_resize = PadAndResize(image_size)
        self.base_size = base_size

        for transform in self.transforms:
            if hasattr(transform, "set_parent"):
                transform.set_parent(self)

    def __call__(self, image, boxes=None):
        """Apply all transforms, pad/resize, and convert the image to a tensor.

        Args:
            image (PIL.Image.Image): Input image.
            boxes (torch.Tensor | None): (N, 5) boxes as
                (class, xmin, ymin, xmax, ymax), normalized to [0, 1].
                Defaults to an empty (0, 5) tensor.

        Returns:
            tuple: (image tensor, adjusted boxes, reverse-transform tensor).
        """
        # A fresh tensor per call: the previous `boxes=torch.zeros(0, 5)` default
        # was created once at definition time and shared; downstream transforms
        # write into `boxes` in place, so a shared default is unsafe.
        if boxes is None:
            boxes = torch.zeros(0, 5)
        for transform in self.transforms:
            image, boxes = transform(image, boxes)
        image, boxes, rev_tensor = self.pad_resize(image, boxes)
        image = TF.to_tensor(image)
        return image, boxes, rev_tensor
|
|
|
|
|
class RemoveOutliers:
    """Drops degenerate bounding boxes: non-positive extent or near-zero area."""

    def __init__(self, min_box_area=1e-8):
        """
        Args:
            min_box_area (float): Minimum area for a box to be kept, as a fraction
                of the image area.
        """
        self.min_box_area = min_box_area

    def __call__(self, image, boxes):
        """Filter invalid boxes; the image passes through untouched.

        Args:
            image (PIL.Image): The input image (returned unchanged).
            boxes (torch.Tensor): (N, 5) boxes as (class, xmin, ymin, xmax, ymax),
                in normalized coordinates.

        Returns:
            PIL.Image: The input image (unchanged).
            torch.Tensor: Boxes restricted to rows with positive width/height and
                area above the threshold.
        """
        widths = boxes[:, 3] - boxes[:, 1]
        heights = boxes[:, 4] - boxes[:, 2]
        keep = (widths * heights > self.min_box_area) & (widths > 0) & (heights > 0)
        return image, boxes[keep]
|
|
|
|
|
class PadAndResize:
    """Letterboxes an image to a fixed size, remapping normalized boxes to match."""

    def __init__(self, image_size, background_color=(114, 114, 114)):
        """
        Args:
            image_size: Target (width, height) pair.
            background_color: RGB fill used for the padded border.
        """
        self.target_width, self.target_height = image_size
        self.background_color = background_color

    def set_size(self, image_size: List[int]):
        """Update the target (width, height)."""
        self.target_width, self.target_height = image_size

    def __call__(self, image: Image.Image, boxes):
        """Resize with preserved aspect ratio, center-pad, and remap the boxes.

        Args:
            image (PIL.Image.Image): Input image. (Annotation fixed: `Image` is
                the PIL module; the class is `Image.Image`.)
            boxes (torch.Tensor): (N, 5) boxes as (class, xmin, ymin, xmax, ymax),
                normalized to [0, 1]. Modified in place.

        Returns:
            PIL.Image.Image: The letterboxed image.
            torch.Tensor: Boxes re-normalized to the padded frame.
            torch.Tensor: [scale, pad_left, pad_top, pad_left, pad_top] — enough
                to reverse the transform downstream.
        """
        img_width, img_height = image.size
        # Uniform scale that fits the image inside the target without distortion.
        scale = min(self.target_width / img_width, self.target_height / img_height)
        new_width, new_height = int(img_width * scale), int(img_height * scale)

        resized_image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

        # Center the resized image on the target canvas.
        pad_left = (self.target_width - new_width) // 2
        pad_top = (self.target_height - new_height) // 2
        padded_image = Image.new("RGB", (self.target_width, self.target_height), self.background_color)
        padded_image.paste(resized_image, (pad_left, pad_top))

        # Map normalized coords through resize + padding, back into [0, 1].
        boxes[:, [1, 3]] = (boxes[:, [1, 3]] * new_width + pad_left) / self.target_width
        boxes[:, [2, 4]] = (boxes[:, [2, 4]] * new_height + pad_top) / self.target_height

        transform_info = torch.tensor([scale, pad_left, pad_top, pad_left, pad_top])
        return padded_image, boxes, transform_info
|
|
|
|
|
class HorizontalFlip:
    """Mirrors the image left-right with probability `prob`, updating boxes to match."""

    def __init__(self, prob=0.5):
        """
        Args:
            prob (float): Probability of applying the flip.
        """
        self.prob = prob

    def __call__(self, image, boxes):
        """Return (possibly flipped image, adjusted boxes); boxes mutated in place."""
        if torch.rand(1) >= self.prob:
            return image, boxes
        flipped = TF.hflip(image)
        # Mirror normalized x-coords: new (xmin, xmax) = 1 - old (xmax, xmin).
        boxes[:, [1, 3]] = 1 - boxes[:, [3, 1]]
        return flipped, boxes
|
|
|
|
|
class VerticalFlip:
    """Mirrors the image top-bottom with probability `prob`, updating boxes to match."""

    def __init__(self, prob=0.5):
        """
        Args:
            prob (float): Probability of applying the flip.
        """
        self.prob = prob

    def __call__(self, image, boxes):
        """Return (possibly flipped image, adjusted boxes); boxes mutated in place."""
        if torch.rand(1) >= self.prob:
            return image, boxes
        flipped = TF.vflip(image)
        # Mirror normalized y-coords: new (ymin, ymax) = 1 - old (ymax, ymin).
        boxes[:, [2, 4]] = 1 - boxes[:, [4, 2]]
        return flipped, boxes
|
|
|
|
|
class Mosaic:
    """Stitches four samples into a 2x2 mosaic (with probability `prob`) and merges their boxes."""

    def __init__(self, prob=0.5):
        """
        Args:
            prob (float): Probability of applying the mosaic.
        """
        self.prob = prob
        self.parent = None

    def set_parent(self, parent):
        """Attach the composer/dataset that supplies `base_size` and `get_more_data`."""
        self.parent = parent

    def __call__(self, image, boxes):
        """Return a mosaic of this sample plus three more, or the input unchanged."""
        if torch.rand(1) >= self.prob:
            return image, boxes

        assert self.parent is not None, "Parent is not set. Mosaic cannot retrieve image size."

        base = self.parent.base_size
        samples = [(image, boxes)] + self.parent.get_more_data(3)

        canvas = Image.new("RGB", (2 * base, 2 * base), (114, 114, 114))
        # Each tile is anchored relative to the canvas center; the offset vector
        # picks which quadrant the tile's top-left corner lands in.
        offsets = np.array([(-1, -1), (0, -1), (-1, 0), (0, 0)])
        center = np.array([base, base])
        merged = []

        for (tile, tile_boxes), offset in zip(samples, offsets):
            tile_w, tile_h = tile.size
            left, top = tuple(center + offset * np.array([tile_w, tile_h]))

            canvas.paste(tile, (left, top))
            # Re-normalize each corner into the (2*base)-sized mosaic frame.
            x_min = (tile_boxes[:, 1] * tile_w + left) / (2 * base)
            x_max = (tile_boxes[:, 3] * tile_w + left) / (2 * base)
            y_min = (tile_boxes[:, 2] * tile_h + top) / (2 * base)
            y_max = (tile_boxes[:, 4] * tile_h + top) / (2 * base)

            merged.append(torch.stack([tile_boxes[:, 0], x_min, y_min, x_max, y_max], dim=1))

        all_boxes = torch.cat(merged, dim=0)
        # Shrink the 2x mosaic back to the base size; normalized boxes need no change.
        return canvas.resize((base, base)), all_boxes
|
|
|
|
|
class MixUp:
    """Blends two samples (with probability `prob`) and concatenates their boxes."""

    def __init__(self, prob=0.5, alpha=1.0):
        """
        Args:
            prob (float): Probability of applying MixUp.
            alpha (float): Beta-distribution parameter for the mixing coefficient.
        """
        self.alpha = alpha
        self.prob = prob
        self.parent = None

    def set_parent(self, parent):
        """Set the parent dataset object for accessing dataset methods."""
        self.parent = parent

    def __call__(self, image, boxes):
        """Return the blended image and combined boxes, or the input unchanged."""
        if torch.rand(1) >= self.prob:
            return image, boxes

        assert self.parent is not None, "Parent is not set. MixUp cannot retrieve additional data."

        other_image, other_boxes = self.parent.get_more_data()[0]

        # Mixing weight; alpha <= 0 degenerates to an even 50/50 blend.
        lam = np.random.beta(self.alpha, self.alpha) if self.alpha > 0 else 0.5

        blended = lam * TF.to_tensor(image) + (1 - lam) * TF.to_tensor(other_image)
        return TF.to_pil_image(blended), torch.cat((boxes, other_boxes))
|
|
|
|
|
class RandomCrop:
    """Randomly crops the image to half its size, re-normalizing boxes to the crop."""

    def __init__(self, prob=0.5):
        """
        Args:
            prob (float): Probability of applying the crop.
        """
        self.prob = prob

    def __call__(self, image, boxes):
        """Return (possibly cropped image, adjusted boxes); boxes mutated in place."""
        if torch.rand(1) >= self.prob:
            return image, boxes

        full_w, full_h = image.size
        crop_h, crop_w = full_h // 2, full_w // 2
        top = torch.randint(0, full_h - crop_h + 1, (1,)).item()
        left = torch.randint(0, full_w - crop_w + 1, (1,)).item()

        image = TF.crop(image, top, left, crop_h, crop_w)

        # Denormalize to pixel coords and shift into the crop frame...
        boxes[:, [1, 3]] = boxes[:, [1, 3]] * full_w - left
        boxes[:, [2, 4]] = boxes[:, [2, 4]] * full_h - top
        # ...then clip to the crop and renormalize back to [0, 1].
        boxes[:, [1, 3]] = boxes[:, [1, 3]].clamp(0, crop_w) / crop_w
        boxes[:, [2, 4]] = boxes[:, [2, 4]].clamp(0, crop_h) / crop_h

        return image, boxes
|
|