""" | |
Copyright (c) 2024-present Naver Cloud Corp. | |
This source code is based on code from the Segment Anything Model (SAM) | |
(https://github.com/facebookresearch/segment-anything). | |
This source code is licensed under the license found in the | |
LICENSE file in the root directory of this source tree. | |
""" | |
import torch
from torch import nn
from torch.nn import functional as F

from typing import Any, Dict, List

def gaussian(sigma=6):
    """
    Generate a square 2D Gaussian kernel of side (6 * sigma + 3), with a
    peak value of 1.0 at the kernel center (3 * sigma + 1, 3 * sigma + 1).
    """
    size = 6 * sigma + 3
    x = torch.arange(0, size, 1)
    y = x[:, None]
    x0, y0 = 3 * sigma + 1, 3 * sigma + 1
    g = torch.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
    return g
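
# Illustrative usage of gaussian() (an editorial sketch, not part of the
# original file):
#     g = gaussian(sigma=1)      # 9x9 kernel
#     g.shape                    # torch.Size([9, 9])
#     g[4, 4]                    # tensor(1.), the peak at the center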

class Zim(nn.Module):
    def __init__(
        self,
        encoder,
        decoder,
        *,
        image_size: int = 1024,
        pixel_mean: List[float] = [123.675, 116.28, 103.53],
        pixel_std: List[float] = [58.395, 57.12, 57.375],
    ) -> None:
        """
        ZIM predicts object masks from an image and input prompts.

        Arguments:
          encoder: The backbone used to encode the image into image
            embeddings that allow for efficient mask prediction.
          decoder: Predicts masks from the image embeddings and given prompts.
          image_size (int): Side length of the square input image.
          pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
          pixel_std (list(float)): Std values for normalizing pixels in the input image.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.output_activation = nn.Sigmoid()
        self.image_size = image_size

        self.register_buffer(
            "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
        )
        self.register_buffer(
            "pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False
        )

        self.mask_threshold: float = 0.5
        self.image_format: str = "RGB"
        self.num_mask_tokens = decoder.num_mask_tokens

        # Prompt-aware masked attention operates on a 64x64 grid, i.e.
        # image_size / encode_stride with the defaults (1024 / 16 = 64).
        self.encode_stride = 16
        self.encode_kernel = 21
        self.attn_mask_size = 64
        self.g = gaussian(self.encode_kernel)

        self.output_conv = nn.Conv2d(
            self.num_mask_tokens,
            self.num_mask_tokens,
            kernel_size=1, stride=1, padding=0,
        )

    @property
    def device(self) -> Any:
        return self.pixel_mean.device

    def cuda(self, device_id=None):
        if isinstance(device_id, torch.device):
            device_id = device_id.index
        if device_id is None:
            device_id = 0

        device = torch.device(f"cuda:{device_id}")
        super().cuda(device)
        self.encoder.cuda(device_id)
        self.decoder.cuda(device_id)
        return self

    def postprocess_masks(
        self, masks: torch.Tensor, input_size: List[int], original_size: List[int]
    ) -> torch.Tensor:
        """
        Remove padding and upscale masks to the original image size.

        Arguments:
          masks (torch.Tensor): Batched masks from the decoder,
            in BxCxHxW format.
          input_size (tuple(int, int)): The size of the image input to the
            model, in (H, W) format. Used to remove padding.
          original_size (tuple(int, int)): The original size of the image
            before resizing for input to the model, in (H, W) format.

        Returns:
          (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
            is given by original_size.
        """
        masks = F.interpolate(
            masks,
            (self.image_size, self.image_size),
            mode="bilinear",
            align_corners=False,
        )
        masks = masks[..., : input_size[0], : input_size[1]]
        masks = F.interpolate(
            masks, original_size, mode="bilinear", align_corners=False
        )
        return masks
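
    # Illustrative usage (an editorial sketch, not part of the original file;
    # `model` is an assumed Zim instance with image_size=1024, and the shapes
    # are invented for the example):
    #     low_res = torch.zeros(1, 4, 256, 256)   # decoder output
    #     full = model.postprocess_masks(
    #         low_res, input_size=(1024, 683), original_size=(1500, 1000)
    #     )
    #     full.shape                               # (1, 4, 1500, 1000)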

    def preprocess(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize pixel values and pad to a square input."""
        # Normalize colors
        x = (x - self.pixel_mean) / self.pixel_std

        # Pad the right and bottom edges to a square of side image_size
        h, w = x.shape[-2:]
        padh = self.image_size - h
        padw = self.image_size - w
        x = F.pad(x, (0, padw, 0, padh))
        return x
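
    # Illustrative usage (an editorial sketch, not part of the original file;
    # `model` is an assumed Zim instance with image_size=1024):
    #     x = torch.zeros(1, 3, 1024, 683)   # resized image, HxW <= 1024
    #     model.preprocess(x).shape          # (1, 3, 1024, 1024)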

    def bbox_attn_mask(self, boxes):
        """Prompt-aware Masked Attention: box prompt (binary attn mask)."""
        bs = boxes.shape[0]
        attn_mask = torch.zeros(
            (bs, self.attn_mask_size, self.attn_mask_size), device=boxes.device
        )
        for n in range(bs):
            xmin, ymin, xmax, ymax = boxes[n]
            xmin, xmax = min(xmin, xmax), max(xmin, xmax)
            ymin, ymax = min(ymin, ymax), max(ymin, ymax)

            # Map box corners from input-image coordinates to the 64x64 grid.
            xmin, xmax = int(xmin / self.encode_stride), int(xmax / self.encode_stride)
            ymin, ymax = int(ymin / self.encode_stride), int(ymax / self.encode_stride)

            # Clamp to the grid boundary.
            xmin, ymin = max(0, xmin), max(0, ymin)
            xmax = min(self.attn_mask_size, xmax + 1)
            ymax = min(self.attn_mask_size, ymax + 1)

            attn_mask[n, ymin:ymax, xmin:xmax] = 1

        return attn_mask
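
    # Illustrative usage (an editorial sketch, not part of the original file;
    # `model` is an assumed Zim instance, boxes are XYXY in input-image
    # coordinates at the 1024 scale):
    #     boxes = torch.tensor([[160., 320., 480., 640.]])
    #     m = model.bbox_attn_mask(boxes)    # (1, 64, 64)
    #     (m[0, 20:41, 10:31] == 1).all()    # the box footprint on the grid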

    def point_attn_mask(self, point_coords):
        """Prompt-aware Masked Attention: point prompt (soft attn mask)."""
        bs = point_coords.shape[0]
        attn_mask = torch.zeros(
            (bs, self.attn_mask_size, self.attn_mask_size), device=point_coords.device
        )

        if self.g.device != point_coords.device:
            self.g = self.g.to(point_coords.device)

        for n in range(bs):
            for point in point_coords[n]:
                # Map the point from input-image coordinates to the 64x64 grid.
                x = int(point[0].item() / self.encode_stride)
                y = int(point[1].item() / self.encode_stride)

                # Skip points outside the image boundary.
                if x < 0 or y < 0 or x >= self.attn_mask_size or y >= self.attn_mask_size:
                    continue

                # Upper-left and bottom-right corners of the Gaussian stamp.
                ul = int(round(x - 3 * self.encode_kernel - 1)), int(round(y - 3 * self.encode_kernel - 1))
                br = int(round(x + 3 * self.encode_kernel + 2)), int(round(y + 3 * self.encode_kernel + 2))

                # Clip the stamp against the grid: (a, b) and (c, d) index into
                # the kernel; (aa, bb) and (cc, dd) index into the attention mask.
                c, d = int(max(0, -ul[0])), int(min(br[0], self.attn_mask_size) - ul[0])
                a, b = int(max(0, -ul[1])), int(min(br[1], self.attn_mask_size) - ul[1])
                cc, dd = int(max(0, ul[0])), int(min(br[0], self.attn_mask_size))
                aa, bb = int(max(0, ul[1])), int(min(br[1], self.attn_mask_size))

                # Keep the per-cell maximum so overlapping stamps don't sum.
                attn_mask[n, aa:bb, cc:dd] = torch.maximum(
                    attn_mask[n, aa:bb, cc:dd], self.g[a:b, c:d]
                )

        return attn_mask
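
    # Illustrative usage (an editorial sketch, not part of the original file;
    # `model` is an assumed Zim instance, point_coords is (B, N, 2) in
    # input-image coordinates at the 1024 scale):
    #     pts = torch.tensor([[[512., 512.]]])
    #     m = model.point_attn_mask(pts)     # (1, 64, 64) soft mask
    #     m[0, 32, 32]                       # tensor(1.), Gaussian peak at the point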