|
|
|
import numpy as np |
|
from typing import Tuple |
|
import torch |
|
from PIL import Image |
|
from torch.nn import functional as F |
|
|
|
__all__ = ["paste_masks_in_image"]

# Bytes per float32 element. paste_masks_in_image multiplies this by
# N * img_h * img_w to estimate the memory needed to paste N masks on the GPU.
BYTES_PER_FLOAT = 4

# Memory budget (1 GiB) used by paste_masks_in_image to decide how many chunks
# the GPU pasting is split into: num_chunks = ceil(estimate / GPU_MEM_LIMIT).
# NOTE(review): this is a fixed value, not derived from the device's available memory.
GPU_MEM_LIMIT = 1024**3
|
|
|
|
|
def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True):
    """
    Resample every mask into image coordinates according to its box, using a
    single ``F.grid_sample`` call for the whole batch.

    Args:
        masks: N, 1, H, W
        boxes: N, 4, one (x0, y0, x1, y1) row per mask
        img_h, img_w (int): height and width of the target image.
        skip_empty (bool): only paste masks within the region that
            tightly bounds all boxes (grown by one pixel and clipped to the image),
            and return the results for this region only.
            An important optimization for CPU. Ignored when scripting.

    Returns:
        A tuple ``(img_masks, spatial_inds)`` of soft (un-thresholded) values:
        if skip_empty == False (or when scripting), img_masks has shape
        (N, img_h, img_w) and spatial_inds is the empty tuple ``()``;
        if skip_empty == True, img_masks has shape (N, h', w') and spatial_inds is
        ``(slice(y0, y1), slice(x0, x1))``, the slice object for the corresponding
        region of the image.
    """
    # All masks are pasted together by sampling each of them over the same pixel
    # region; this does more arithmetic than pasting one by one but vectorizes.
    device = masks.device

    if skip_empty and not torch.jit.is_scripting():
        # Union of all boxes grown by 1 pixel; the low corner is clamped at 0 and the
        # high corner at (img_w, img_h). These bounds are 0-dim int32 tensors.
        x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
            dtype=torch.int32
        )
        x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
        y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
    else:
        # Evaluate every pixel of the image.
        x0_int, y0_int = 0, 0
        x1_int, y1_int = img_w, img_h
    x0, y0, x1, y1 = torch.split(boxes, 1, dim=1)  # each is N x 1

    N = masks.shape[0]

    # Pixel centers (+0.5) of the region, mapped into each box's normalized [-1, 1]
    # frame consumed by grid_sample(align_corners=False).
    # img_x has shape (N, w) and img_y has shape (N, h) after broadcasting with the boxes.
    # NOTE(review): a box with x1 == x0 or y1 == y0 divides by zero here.
    img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
    img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
    img_x = (img_x - x0) / (x1 - x0) * 2 - 1

    # Sampling grid of shape (N, h, w, 2) holding the (x, y) location per output pixel.
    gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
    gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
    grid = torch.stack([gx, gy], dim=3)

    if not torch.jit.is_scripting():
        # Non-floating masks (e.g. bool/uint8) are converted to float before sampling.
        if not masks.dtype.is_floating_point:
            masks = masks.float()
    img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)

    if skip_empty and not torch.jit.is_scripting():
        return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
    else:
        return img_masks[:, 0], ()
|
|
|
|
|
|
|
@torch.jit.script_if_tracing
def paste_masks_in_image(
    masks: torch.Tensor, boxes: torch.Tensor, image_shape: Tuple[int, int], threshold: float = 0.5
):
    """
    Paste a set of masks that are of a fixed resolution (e.g., 28 x 28) into an image.
    The location, height, and width for pasting each mask is determined by their
    corresponding bounding boxes in boxes.

    Note:
        This is a complicated but more accurate implementation. In actual deployment, it is
        often enough to use a faster but less accurate implementation.
        See :func:`paste_mask_in_image_old` in this file for an alternative implementation.

    Args:
        masks (tensor): Tensor of shape (Bimg, Hmask, Wmask), where Bimg is the number of
            detected object instances in the image and Hmask, Wmask are the mask height and
            mask width of the predicted mask (e.g., Hmask = Wmask = 28). Values are in [0, 1].
            Only square masks (Hmask == Wmask) are supported.
        boxes (Boxes or Tensor): A Boxes of length Bimg or Tensor of shape (Bimg, 4).
            boxes[i] and masks[i] correspond to the same object instance.
        image_shape (tuple): height, width
        threshold (float): A threshold in [0, 1] for converting the (soft) masks to
            binary masks. If negative, no thresholding is applied and the soft masks
            are returned multiplied by 255 as uint8.

    Returns:
        img_masks (Tensor): A tensor of shape (Bimg, Himage, Wimage), where Bimg is the
        number of detected object instances and Himage, Wimage are the image height
        and width. img_masks[i] is the pasted mask for object instance i. The dtype is
        bool when threshold >= 0 and uint8 otherwise, also when Bimg == 0.
    """

    assert masks.shape[-1] == masks.shape[-2], "Only square mask predictions are supported"
    N = len(masks)
    if N == 0:
        # Use the same dtype as the non-empty result so the output type does not
        # depend on the number of instances.
        return masks.new_empty(
            (0,) + image_shape, dtype=torch.bool if threshold >= 0 else torch.uint8
        )
    if not isinstance(boxes, torch.Tensor):
        boxes = boxes.tensor
    device = boxes.device
    assert len(boxes) == N, boxes.shape

    img_h, img_w = image_shape

    # The input is split into chunks of instance indices that are pasted one chunk at a time.
    if device.type == "cpu" or torch.jit.is_scripting():
        # CPU: paste masks one by one; combined with skip_empty=True below, each paste
        # only evaluates the pixels around its own box.
        num_chunks = N
    else:
        # GPU: paste many masks at once, but keep the estimated full-image float buffer
        # (N * img_h * img_w * BYTES_PER_FLOAT bytes) per chunk under GPU_MEM_LIMIT.
        # int() because the shape may be given as tensors during tracing.
        num_chunks = int(np.ceil(N * int(img_h) * int(img_w) * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
        assert (
            num_chunks <= N
        ), "Default GPU_MEM_LIMIT in mask_ops.py is too small; try increasing it"
    chunks = torch.chunk(torch.arange(N, device=device), num_chunks)

    img_masks = torch.zeros(
        N, img_h, img_w, device=device, dtype=torch.bool if threshold >= 0 else torch.uint8
    )
    for inds in chunks:
        masks_chunk, spatial_inds = _do_paste_mask(
            masks[inds, None, :, :], boxes[inds], img_h, img_w, skip_empty=device.type == "cpu"
        )

        if threshold >= 0:
            masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
        else:
            # Negative threshold: keep the soft values (scaled to uint8), e.g. for
            # visualization and debugging.
            masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)

        if torch.jit.is_scripting():
            # Scripting always pastes the full image (_do_paste_mask returns no region).
            img_masks[inds] = masks_chunk
        else:
            # spatial_inds is () for full-image pastes, or the (rows, cols) region slices.
            img_masks[(inds,) + spatial_inds] = masks_chunk
    return img_masks
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def paste_mask_in_image_old(mask, box, img_h, img_w, threshold):
    """
    Paste a single mask in an image.
    This is a per-box implementation of :func:`paste_masks_in_image`.
    This function has larger quantization error due to incorrect pixel
    modeling and is not used any more.

    Args:
        mask (Tensor): A tensor of shape (Hmask, Wmask) storing the mask of a single
            object instance. Values are in [0, 1].
        box (Tensor): A tensor of shape (4, ) storing the x0, y0, x1, y1 box corners
            of the object instance. The corners are cast to int32 (truncated).
        img_h, img_w (int): Image height and width.
        threshold (float): Mask binarization threshold in [0, 1]. If negative, the
            resized soft mask is multiplied by 255 and returned as uint8 instead.

    Returns:
        im_mask (Tensor):
            The resized and binarized object mask pasted into the original
            image plane (a uint8 tensor of shape (img_h, img_w)). It is all zeros
            when the box does not overlap the image (or is degenerate).
    """
    # Continuous -> discrete coordinates by casting to int32. A box such as
    # (x0=0.7, x1=4.3) maps to pixels 0..4, i.e. x1 - x0 + 1 = 5 pixel samples.
    x0, y0, x1, y1 = box.to(dtype=torch.int32).tolist()
    samples_w = x1 - x0 + 1  # number of pixel samples, *not* geometric width
    samples_h = y1 - y0 + 1  # number of pixel samples, *not* geometric height

    # Part of the image covered by the box, clipped to the image bounds.
    im_x0 = max(x0, 0)
    im_x1 = min(x1 + 1, img_w)
    im_y0 = max(y0, 0)
    im_y1 = min(y1 + 1, img_h)

    im_mask = torch.zeros((img_h, img_w), dtype=torch.uint8)
    if im_x1 <= im_x0 or im_y1 <= im_y0:
        # Nothing of the box lies inside the image (or the box is degenerate).
        # Without this guard a negative end index would wrap around in the slices
        # below and the assignment would fail with a shape mismatch.
        return im_mask

    # Resample the mask from its original grid to the samples_w x samples_h grid.
    resized = Image.fromarray(mask.cpu().numpy())
    resized = resized.resize((samples_w, samples_h), resample=Image.BILINEAR)
    resized = np.asarray(resized)

    if threshold >= 0:
        pasted = torch.from_numpy(np.array(resized > threshold, dtype=np.uint8))
    else:
        # For visualization and debugging, return the unthresholded mask as well.
        pasted = torch.from_numpy(resized * 255).to(torch.uint8)

    im_mask[im_y0:im_y1, im_x0:im_x1] = pasted[
        (im_y0 - y0) : (im_y1 - y0), (im_x0 - x0) : (im_x1 - x0)
    ]
    return im_mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def pad_masks(masks, padding):
    """
    Zero-pad a batch of square masks by the same amount on all four sides.

    Args:
        masks (tensor): A tensor of shape (B, M, M) representing B masks.
        padding (int): Number of cells to pad on all sides. May be 0.

    Returns:
        The padded masks, a tensor of shape (B, M + 2 * padding, M + 2 * padding)
        created with ``masks.new_zeros`` (same dtype/device as ``masks``), and the
        scale factor ``(M + 2 * padding) / M`` of the padded size relative to the
        original size.
    """
    B = masks.shape[0]
    M = masks.shape[-1]
    pad2 = 2 * padding
    scale = float(M + pad2) / M
    padded_masks = masks.new_zeros((B, M + pad2, M + pad2))
    # Explicit end indices: `padding:-padding` is the empty slice `0:0` when
    # padding == 0, which would make this assignment fail.
    padded_masks[:, padding : padding + M, padding : padding + M] = masks
    return padded_masks, scale
|
|
|
|
|
def scale_boxes(boxes, scale):
    """
    Grow or shrink boxes about their own centers.

    Args:
        boxes (tensor): A tensor of shape (B, 4) whose rows are the corners
            x0, y0, x1, y1 of B boxes.
        scale (float): Factor applied to each box's width and height.

    Returns:
        A new tensor like ``boxes`` (same shape, dtype and device) holding the
        scaled corners; the centers are unchanged.
    """
    left, top, right, bottom = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]

    # Half extents, scaled; the centers stay fixed.
    half_width = (right - left) * 0.5 * scale
    half_height = (bottom - top) * 0.5 * scale
    center_x = (right + left) * 0.5
    center_y = (bottom + top) * 0.5

    result = torch.zeros_like(boxes)
    result[:, 0], result[:, 2] = center_x - half_width, center_x + half_width
    result[:, 1], result[:, 3] = center_y - half_height, center_y + half_height
    return result
|
|
|
|
|
@torch.jit.script_if_tracing
def _paste_masks_tensor_shape(
    masks: torch.Tensor,
    boxes: torch.Tensor,
    image_shape: Tuple[torch.Tensor, torch.Tensor],
    threshold: float = 0.5,
):
    """
    Variant of paste_masks_in_image that takes the image shape as two tensors.

    While tracing, image sizes may arrive as tensors rather than Python ints. This
    function is scripted under tracing so that the Tensor -> int conversion is
    scripted rather than traced, and then it delegates to paste_masks_in_image.
    """
    height = int(image_shape[0])
    width = int(image_shape[1])
    return paste_masks_in_image(masks, boxes, (height, width), threshold)
|
|