import typing
import torch
from PIL import Image
import torchvision.transforms.functional as tvf
import torch.nn.functional as F
def crop(image: typing.Union[Image.Image, torch.Tensor], size: typing.Tuple[int, int]) -> Image.Image:
    """
    Takes a `PIL.Image` and center-crops it to `size`, unless a requested
    dimension is larger than the image itself. If so, padding must be
    performed afterwards with `pad`.

    Args:
        image (`PIL.Image`):
            An image to perform cropping on
        size (`tuple` of integers):
            A size to crop to, should be in the form
            of (width, height)

    Returns:
        An augmented `PIL.Image`
    """
    top = (image.size[-2] - size[0]) // 2
    left = (image.size[-1] - size[1]) // 2
    top = max(top, 0)
    left = max(left, 0)
    height = min(top + size[0], image.size[-2])
    width = min(left + size[1], image.size[-1])
    # PIL's crop box is (left, upper, right, lower)
    return image.crop((top, left, height, width))
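
# Illustrative usage (not part of the original module): center-cropping a
# hypothetical 500x375 image to (460, 460) only trims the dimension that is
# large enough; the short side is left untouched for `pad` to fill.
#   img = Image.new("RGB", (500, 375))
#   cropped = crop(img, (460, 460))   # -> 460x375, height unchanged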
def pad(image: Image.Image, size: typing.Tuple[int, int]) -> Image.Image:
    """
    Takes a `PIL.Image` and pads it to `size` with
    zeros.

    Args:
        image (`PIL.Image`):
            An image to perform padding on
        size (`tuple` of integers):
            A size to pad to, should be in the form
            of (width, height)

    Returns:
        An augmented `PIL.Image`
    """
    top = (image.size[-2] - size[0]) // 2
    left = (image.size[-1] - size[1]) // 2
    pad_top = max(-top, 0)
    pad_left = max(-left, 0)
    height, width = (
        max(size[1] - image.size[-2] + top, 0),
        max(size[0] - image.size[-1] + left, 0),
    )
    # torchvision's functional pad takes the padding as [left, top, right, bottom]
    return tvf.pad(
        image,
        [pad_top, pad_left, height, width],
        padding_mode="constant",
    )
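
# Illustrative usage (not part of the original module): `pad` complements
# `crop`, zero-filling any dimension that was smaller than the target, so
# crop followed by pad yields a fixed-size center crop-and-pad:
#   img = Image.new("RGB", (500, 375))
#   img = pad(crop(img, (460, 460)), (460, 460))  # -> exactly 460x460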
def gpu_crop(
    batch: torch.Tensor,
    size: typing.Tuple[int, int],
):
    """
    Crops each image in `batch` to a particular `size`.

    Args:
        batch (`torch.Tensor`):
            A batch of images, should be of shape `NxCxWxH`
        size (`tuple` of integers):
            A size to crop to, should be in the form
            of (width, height)

    Returns:
        A batch of cropped images
    """
    # Build a batch of 2x3 identity affine matrices, one per image
    # (split into multiple lines for clarity)
    affine_matrix = torch.eye(3, device=batch.device).float()
    affine_matrix = affine_matrix.unsqueeze(0)
    affine_matrix = affine_matrix.expand(batch.size(0), 3, 3)
    affine_matrix = affine_matrix.contiguous()[:, :2]
    # Sampling grid of the target size in normalized [-1, 1] coordinates
    coords = F.affine_grid(
        affine_matrix, batch.shape[:2] + size, align_corners=True
    )
    top_range, bottom_range = coords.min(), coords.max()
    zoom = 1 / (bottom_range - top_range).item() * 2
    resizing_limit = min(
        batch.shape[-2] / coords.shape[-2],
        batch.shape[-1] / coords.shape[-1],
    ) / 2
    # When heavily downscaling, first shrink with area interpolation to reduce
    # aliasing, then let grid_sample perform the final resampling
    if resizing_limit > 1 and resizing_limit > zoom:
        batch = F.interpolate(
            batch,
            scale_factor=1 / resizing_limit,
            mode="area",
            recompute_scale_factor=True,
        )
    return F.grid_sample(
        batch, coords, mode="bilinear", padding_mode="reflection", align_corners=True
    )
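
# Minimal smoke-test sketch (not part of the original Space): builds a dummy
# image, runs the PIL-side crop/pad followed by the GPU resize, and checks the
# output shape. The image and target sizes here are arbitrary choices.
if __name__ == "__main__":
    dummy = Image.new("RGB", (500, 375))
    dummy = pad(crop(dummy, (460, 460)), (460, 460))
    batch = tvf.to_tensor(dummy).unsqueeze(0)  # 1 x 3 x 460 x 460
    if torch.cuda.is_available():
        batch = batch.cuda()
    out = gpu_crop(batch, (224, 224))
    print(out.shape)  # expected: torch.Size([1, 3, 224, 224])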