Spaces:
Sleeping
Sleeping
File size: 4,548 Bytes
c608946 9cde3b4 c608946 9cde3b4 c608946 9cde3b4 c608946 9cde3b4 c608946 9cde3b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
from typing import Dict
import numpy as np
import torch
import kornia.augmentation as K
from kornia.geometry.transform import warp_perspective
# Adapted from Kornia
class GeometricSequential:
    """Compose random Kornia geometric transforms into a single warp.

    Every transform is tried in order and kept with its own probability
    ``t.p``. The 3x3 matrices of the kept transforms are multiplied together,
    and the input is resampled once with ``warp_perspective`` using that product.
    """

    def __init__(self, *transforms, align_corners=True) -> None:
        """Store the transforms (tried in the given order) and the
        ``align_corners`` flag forwarded to ``warp_perspective``."""
        self.align_corners = align_corners
        self.transforms = transforms

    def __call__(self, x, mode="bilinear"):
        """Randomly warp ``x`` and return the matrix that was used.

        Args:
            x: 4-D batch tensor, unpacked as ``(b, c, h, w)``.
            mode: interpolation mode forwarded to ``warp_perspective``.

        Returns:
            ``(warped, M)``. ``warped`` has the same spatial size ``(h, w)`` as
            ``x``. ``M`` starts as a batch of identity matrices and accumulates
            each kept transform by right-multiplication.
        """
        batch, channels, height, width = x.shape
        shape = (batch, channels, height, width)
        # NOTE(review): the identity is built in torch's default dtype, not
        # x.dtype; a non-float32 input may hit a dtype mismatch in matmul,
        # depending on what compute_transformation returns. Confirm.
        homography = torch.eye(3, device=x.device).unsqueeze(0).expand(batch, 3, 3)
        for transform in self.transforms:
            # One NumPy draw per transform decides for the whole batch at once.
            if not np.random.rand() < transform.p:
                continue
            params = transform.generate_parameters(shape)
            step = transform.compute_transformation(x, params, None)
            homography = homography.matmul(step)
        warped = warp_perspective(
            x,
            homography,
            dsize=(height, width),
            mode=mode,
            align_corners=self.align_corners,
        )
        return warped, homography

    def apply_transform(self, x, M, mode="bilinear"):
        """Warp ``x`` with a caller-supplied matrix ``M``, for example one
        returned by ``__call__``, keeping the spatial size of ``x``.

        ``x`` must be 4-D; it is unpacked as ``(b, c, h, w)``.
        """
        _, _, height, width = x.shape
        return warp_perspective(
            x, M, mode=mode, dsize=(height, width), align_corners=self.align_corners
        )
class RandomPerspective(K.RandomPerspective):
    """Kornia ``RandomPerspective`` that uses a local, simplified corner sampler.

    ``generate_parameters`` routes through ``random_perspective_generator``
    defined below instead of Kornia's built-in generator.
    """

    def generate_parameters(self, batch_shape: torch.Size) -> Dict[str, torch.Tensor]:
        """Sample perspective corner parameters for a batch of ``batch_shape``.

        Args:
            batch_shape: input batch shape. Only ``batch_shape[0]`` (batch size)
                and the last two entries (height, width) are read.

        Returns:
            Dict with ``start_points`` and ``end_points``, each ``(B, 4, 2)``.
        """
        # NOTE(review): reads both ``_device``/``_dtype`` and ``device``/``dtype``;
        # which of these attributes exist depends on the installed Kornia
        # version. Confirm.
        distortion_scale = torch.as_tensor(
            self.distortion_scale, device=self._device, dtype=self._dtype
        )
        return self.random_perspective_generator(
            batch_shape[0],
            batch_shape[-2],
            batch_shape[-1],
            distortion_scale,
            self.same_on_batch,
            self.device,
            self.dtype,
        )

    def random_perspective_generator(
        self,
        batch_size: int,
        height: int,
        width: int,
        distortion_scale: torch.Tensor,
        same_on_batch: bool = False,
        device: torch.device = torch.device("cpu"),
        dtype: torch.dtype = torch.float32,
    ) -> Dict[str, torch.Tensor]:
        r"""Get parameters for ``perspective`` for a random perspective transform.

        Each of the four image corners is moved independently. Along x it shifts
        by a value drawn uniformly from ``[-s * width / 2, s * width / 2)``.
        Along y it shifts by a value from ``[-s * height / 2, s * height / 2)``.
        Here ``s`` is ``distortion_scale``.

        Args:
            batch_size (int): the tensor batch size.
            height (int): height of the image.
            width (int): width of the image.
            distortion_scale (torch.Tensor): 0-dim tensor in [0, 1] controlling
                the degree of distortion.
            same_on_batch (bool): if True, one set of corner offsets is drawn and
                shared by every sample in the batch. Default: False.
            device (torch.device): accepted for signature compatibility with
                Kornia's generator but not used. Tensors are created on
                ``distortion_scale``'s device.
            dtype (torch.dtype): accepted for signature compatibility but not
                used. Tensors use ``distortion_scale``'s dtype.

        Returns:
            params Dict[str, torch.Tensor]: parameters to be passed for transformation.
                - start_points (torch.Tensor): image corners (x, y) in the order
                  top-left, top-right, bottom-right, bottom-left, shape (B, 4, 2).
                - end_points (torch.Tensor): displaced corners, shape (B, 4, 2).

        Raises:
            AssertionError: if ``distortion_scale`` is not a 0-dim value in
                [0, 1], or ``height``/``width`` are not positive ``int``.

        Note:
            The generated random numbers are not reproducible across different devices and dtypes.
        """
        if not (distortion_scale.dim() == 0 and 0 <= distortion_scale <= 1):
            raise AssertionError(
                f"'distortion_scale' must be a scalar within [0, 1]. Got {distortion_scale}."
            )
        # Exact-type check: bool and NumPy integer scalars are rejected too.
        if not (
            type(height) is int and height > 0 and type(width) is int and width > 0
        ):
            raise AssertionError(
                f"'height' and 'width' must be integers. Got {height}, {width}."
            )
        start_points: torch.Tensor = torch.tensor(
            [[[0.0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]],
            device=distortion_scale.device,
            dtype=distortion_scale.dtype,
        ).expand(batch_size, -1, -1)
        # generate random offset not larger than half of the image
        fx = distortion_scale * width / 2
        fy = distortion_scale * height / 2
        # (1, 1, 2): broadcasts the x-scale onto column 0 and the y-scale onto column 1.
        factor = torch.stack([fx, fy], dim=0).view(-1, 1, 2)
        if same_on_batch:
            # Draw a single (1, 4, 2) offset and share it across the batch.
            offset = (torch.rand_like(start_points[:1]) - 0.5) * 2
            offset = offset.expand_as(start_points)
        else:
            # Independent offsets in [-1, 1) for every sample and corner.
            offset = (torch.rand_like(start_points) - 0.5) * 2
        end_points = start_points + factor * offset
        return dict(start_points=start_points, end_points=end_points)
class RandomErasing:
    """Randomly erase a region of an image and replay it on its depth map.

    Wraps ``K.RandomErasing``. The parameters recorded by the image call are
    passed back in for the depth call, presumably so both tensors are erased in
    the same place.
    """

    def __init__(self, p=0.0, scale=0.0) -> None:
        """
        Args:
            p: probability forwarded to ``K.RandomErasing``. When it is not
                greater than 0, ``__call__`` returns its inputs untouched.
            scale: upper bound of the erased-area scale range. The lower bound
                is fixed at 0.02. NOTE(review): with the default ``scale=0.0``
                the range is (0.02, 0.0), i.e. min > max. This is harmless while
                ``p`` keeps ``__call__`` a no-op, but confirm Kornia accepts it
                at construction.
        """
        self.p = p
        self.scale = scale
        area_range = (0.02, scale)
        self.random_eraser = K.RandomErasing(scale=area_range, p=p)

    def __call__(self, image, depth):
        """Return ``(image, depth)`` with matching erasing applied when ``p > 0``."""
        if not (self.p > 0):
            return image, depth
        erased_image = self.random_eraser(image)
        # Replay the parameters stored from the image call on the depth input.
        shared_params = self.random_eraser._params
        erased_depth = self.random_eraser(depth, params=shared_params)
        return erased_image, erased_depth
|