"""
Copyright (c) 2024-present Naver Cloud Corp.

This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
|
|
|
import numpy as np |
|
import torch |
|
from torch.nn import functional as F |
|
from torchvision.transforms.functional import resize, to_pil_image, InterpolationMode |
|
from copy import deepcopy |
|
from typing import Optional, Tuple, List |
|
|
|
class ResizeLongestSide:
    """
    Resizes images to the longest side 'target_length', as well as provides
    methods for resizing coordinates and boxes. Provides methods for
    transforming both numpy array and batched torch tensors.
    """

    def __init__(self, target_length: int) -> None:
        # Length, in pixels, that the longer image side is scaled to.
        self.target_length = target_length

    def _scale_factors(self, original_size: Tuple[int, ...]) -> Tuple[float, float]:
        """
        Return the (x, y) multipliers that map pixel coordinates of an image
        of size ``original_size`` onto the resized image.

        Only the first two entries (H, W) of ``original_size`` are read, so a
        full ``image.shape`` such as (H, W, C) is accepted as well.
        """
        old_h, old_w = original_size[:2]
        new_h, new_w = self.get_preprocess_shape(old_h, old_w, self.target_length)
        return new_w / old_w, new_h / old_h

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Expects a numpy array with shape HxWxC in uint8 format.

        Returns the image resized so that its longest side equals
        ``target_length``.
        """
        target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
        return np.array(resize(to_pil_image(image), target_size))

    def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array of length 2 in the final dimension, ordered
        (x, y). Requires the original image size in (H, W) format.

        Returns a new float array; the input array is not modified.
        """
        scale_x, scale_y = self._scale_factors(original_size)
        # astype() already returns a fresh copy, so the caller's array is safe.
        coords = coords.astype(float)
        coords[..., 0] *= scale_x
        coords[..., 1] *= scale_y
        return coords

    def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Expects a numpy array shape Bx4. Requires the original image size
        in (H, W) format.

        Each box is treated as two (x, y) corner points and rescaled with
        ``apply_coords``; the result has shape Bx4.
        """
        boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
        """
        Expects batched images with shape BxCxHxW and float format. This
        transformation may not exactly match apply_image. apply_image is
        the transformation expected by the model.
        """
        target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
        return F.interpolate(
            image, target_size, mode="bilinear", align_corners=False, antialias=True
        )

    def apply_coords_torch(
        self, coords: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with length 2 in the last dimension, ordered
        (x, y). Requires the original image size in (H, W) format.

        Returns a new ``torch.float`` tensor; the input tensor is not modified.
        """
        scale_x, scale_y = self._scale_factors(original_size)
        # copy=True guarantees a new tensor even when the input is already
        # float, so the in-place scaling below never touches the caller's data.
        coords = coords.to(torch.float, copy=True)
        coords[..., 0] *= scale_x
        coords[..., 1] *= scale_y
        return coords

    def apply_boxes_torch(
        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Expects a torch tensor with shape Bx4. Requires the original image
        size in (H, W) format.

        Each box is treated as two (x, y) corner points and rescaled with
        ``apply_coords_torch``; the result has shape Bx4.
        """
        boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
        return boxes.reshape(-1, 4)

    def apply_mask(self, image: np.ndarray) -> np.ndarray:
        """
        Resizes a mask so that its longest side equals ``target_length``.

        Unlike ``apply_image`` this uses nearest-neighbour interpolation, so
        no new in-between values are introduced into the mask.

        Expects a uint8 numpy array whose first two dimensions are H and W
        (presumably HxW or HxWxC -- whatever ``to_pil_image`` accepts).
        """
        target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
        return np.array(
            resize(to_pil_image(image), target_size, interpolation=InterpolationMode.NEAREST)
        )

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """
        Compute the output size given input size and target long side length.

        The aspect ratio is preserved and both sides are rounded half-up to
        the nearest integer. Returns ``(new_h, new_w)``.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        newh, neww = oldh * scale, oldw * scale
        # int(x + 0.5) rounds half-up for the non-negative sizes used here.
        neww = int(neww + 0.5)
        newh = int(newh + 0.5)
        return (newh, neww)
|
|
|
|
|
def remove_prefix(text: str, prefix: str) -> str:
    """
    Return ``text`` with ``prefix`` stripped from its start.

    If ``text`` does not start with ``prefix`` it is returned unchanged.
    At most one occurrence of the prefix is removed.
    """
    if text.startswith(prefix):
        return text[len(prefix) :]
    return text
|
|
|
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, is_ddp):
        # Truthy when running under torch.distributed; enables synch().
        self.is_ddp = is_ddp
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0.0
        self.avg = 0.0
        self.sum = 0.0
        self.count = 0.0

    def _mean(self):
        """Return sum / count, or 0.0 while nothing has been accumulated."""
        # Guard the empty case explicitly instead of adding an epsilon to the
        # denominator, which would bias every reported average downwards.
        return self.sum / self.count if self.count else 0.0

    def update(self, val, n=1):
        """
        Record a new observation.

        ``val`` is the latest value and ``n`` the weight (e.g. batch size)
        it contributes to the running sum and count.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self._mean()

    def synch(self, device):
        """
        Sum ``sum`` and ``count`` across processes onto rank 0.

        Does nothing when ``is_ddp`` is falsy. Otherwise both values are
        reduced with ``torch.distributed.reduce(dst=0)``; only rank 0 has its
        ``sum``, ``count`` and ``avg`` replaced by the global values, the
        other ranks keep their local statistics.
        """
        if not self.is_ddp:
            return

        _sum = torch.tensor(self.sum).to(device)
        _count = torch.tensor(self.count).to(device)

        torch.distributed.reduce(_sum, dst=0)
        torch.distributed.reduce(_count, dst=0)

        if torch.distributed.get_rank() == 0:
            self.sum = _sum.item()
            self.count = _count.item()
            self.avg = self._mean()
|
|