import typing
import torch
from PIL import Image
import torchvision.transforms.functional as tvf
import torch.nn.functional as F
def crop(image:typing.Union[Image.Image, torch.Tensor], size:typing.Tuple[int,int]) -> Image.Image:
""" | |
Takes a `PIL.Image` and crops it `size` unless one | |
dimension is larger than the actual image. Padding | |
must be performed afterwards if so. | |
Args: | |
image (`PIL.Image`): | |
An image to perform cropping on | |
size (`tuple` of integers): | |
A size to crop to, should be in the form | |
of (width, height) | |
Returns: | |
An augmented `PIL.Image` | |
""" | |
    # Offsets of a centered crop box, clamped so the box starts inside the image
    top = (image.size[-2] - size[0]) // 2
    left = (image.size[-1] - size[1]) // 2
    top = max(top, 0)
    left = max(left, 0)
    # Far corner of the box, clamped to the image bounds
    height = min(top + size[0], image.size[-2])
    width = min(left + size[1], image.size[-1])
    # PIL expects the box as (left, upper, right, lower)
    return image.crop((top, left, height, width))
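
# Illustrative sketch, not part of the original code: how `crop` behaves on a
# PIL image. `_example_crop` is a hypothetical helper; the 320x240 input and
# (224, 224) target below are arbitrary example values.
def _example_crop() -> typing.Tuple[int, int]:
    img = Image.new("RGB", (320, 240))
    cropped = crop(img, (224, 224))
    # Both dimensions can satisfy the 224-pixel target here, so this returns
    # (224, 224); a smaller input would return a smaller box and need `pad`.
    return cropped.size
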
def pad(image:Image.Image, size:typing.Tuple[int,int]) -> Image.Image:
""" | |
Takes a `PIL.Image` and pads it to `size` with | |
zeros. | |
Args: | |
image (`PIL.Image`): | |
An image to perform padding on | |
size (`tuple` of integers): | |
A size to pad to, should be in the form | |
of (width, height) | |
Returns: | |
An augmented `PIL.Image` | |
""" | |
    # Offsets of a centered crop; negative values mean padding is needed
    top = (image.size[-2] - size[0]) // 2
    left = (image.size[-1] - size[1]) // 2
    pad_top = max(-top, 0)
    pad_left = max(-left, 0)
    # Padding applied to the remaining two sides
    height, width = (
        max(size[1] - image.size[-2] + top, 0),
        max(size[0] - image.size[-1] + left, 0)
    )
    # torchvision interprets a 4-element padding list as (left, top, right, bottom)
    return tvf.pad(
        image,
        [pad_top, pad_left, height, width],
        padding_mode="constant"
    )
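
# Illustrative sketch, not part of the original code: compose `crop` and `pad`
# so an image smaller than the target in one dimension still comes out at
# exactly the requested size. `_example_crop_then_pad` is a hypothetical
# helper and the 200x240 input is an arbitrary example value.
def _example_crop_then_pad() -> typing.Tuple[int, int]:
    img = Image.new("RGB", (200, 240))
    cropped = crop(img, (224, 224))    # only 200 pixels available in width
    padded = pad(cropped, (224, 224))  # zero-pad the missing 24 pixels
    return padded.size                 # expected: (224, 224)
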
def gpu_crop(
    batch:torch.Tensor,
    size:typing.Tuple[int,int]
) -> torch.Tensor:
""" | |
Crops each image in `batch` to a particular `size`. | |
Args: | |
batch (array of `torch.Tensor`): | |
A batch of images, should be of shape `NxCxWxH` | |
size (`tuple` of integers): | |
A size to pad to, should be in the form | |
of (width, height) | |
Returns: | |
A batch of cropped images | |
""" | |
    # Split into multiple lines for clarity
    # Identity affine transform for every image in the batch, shape (N, 2, 3)
    affine_matrix = torch.eye(3, device=batch.device).float()
    affine_matrix = affine_matrix.unsqueeze(0)
    affine_matrix = affine_matrix.expand(batch.size(0), 3, 3)
    affine_matrix = affine_matrix.contiguous()[:,:2]
    # Sampling grid over the target `size`, in normalized [-1, 1] coordinates
    coords = F.affine_grid(
        affine_matrix, batch.shape[:2] + size, align_corners=True
    )
    top_range, bottom_range = coords.min(), coords.max()
    zoom = 1/(bottom_range - top_range).item()*2
    resizing_limit = min(
        batch.shape[-2]/coords.shape[-2],
        batch.shape[-1]/coords.shape[-1]
    )/2
    # When the input is much larger than the target, downsample first with
    # area interpolation to reduce aliasing before sampling the grid
    if resizing_limit > 1 and resizing_limit > zoom:
        batch = F.interpolate(
            batch,
            scale_factor=1/resizing_limit,
            mode='area',
            recompute_scale_factor=True
        )
    return F.grid_sample(
        batch, coords, mode='bilinear', padding_mode='reflection', align_corners=True
    )
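
# Illustrative sketch, not part of the original code: run `gpu_crop` on a
# random batch to check the output shape. The 4x3x460x460 batch and the
# (224, 224) target are placeholder values.
if __name__ == "__main__":
    dummy_batch = torch.rand(4, 3, 460, 460)
    out = gpu_crop(dummy_batch, (224, 224))
    print(out.shape)  # torch.Size([4, 3, 224, 224])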