Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch | |
from torch import nn as nn | |
import isegm.model.initializer as initializer | |
def select_activation_function(activation):
    """Resolve an activation specification to an ``nn.Module`` class or instance.

    Args:
        activation: one of
            * a string ``"relu"`` or ``"softplus"`` (case-insensitive), mapped
              to the corresponding ``nn`` class (the class, not an instance);
            * an ``nn.Module`` subclass such as ``nn.ReLU``, returned unchanged;
            * an ``nn.Module`` instance, returned unchanged.

    Returns:
        The ``nn.Module`` class or instance described above.

    Raises:
        ValueError: if ``activation`` is an unknown string or any other type.
    """
    if isinstance(activation, str):
        name = activation.lower()
        if name == "relu":
            return nn.ReLU
        if name == "softplus":
            return nn.Softplus
        raise ValueError(f"Unknown activation type {activation}")

    # The string branch hands back a class, so accept a class passed directly
    # as well as a ready-made module instance.
    is_module_class = isinstance(activation, type) and issubclass(activation, nn.Module)
    if is_module_class or isinstance(activation, nn.Module):
        return activation
    raise ValueError(f"Unknown activation type {activation}")
class BilinearConvTranspose2d(nn.ConvTranspose2d):
    """Bias-free transposed convolution whose weights are set up for upsampling.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        scale: integer upsampling factor, used as the stride.
        groups: blocked-connection groups, forwarded to ``nn.ConvTranspose2d``.

    With kernel size ``2 * scale - scale % 2``, stride ``scale`` and padding
    ``scale // 2`` (equivalently ``ceil((scale - 1) / 2)``), the transposed-conv
    output size ``(H - 1) * stride - 2 * padding + kernel_size`` is exactly
    ``scale * H`` for every integer ``scale >= 1``. For ``scale`` 2 and 3 the
    padding is 1.

    Weights are filled by ``initializer.Bilinear`` (a project helper not shown
    here).
    """

    def __init__(self, in_channels, out_channels, scale, groups=1):
        kernel_size = 2 * scale - scale % 2
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=scale,
            # scale // 2 makes the output exactly scale x the input for any scale;
            # a constant 1 is only correct for scale 2 and 3.
            padding=scale // 2,
            groups=groups,
            bias=False,
        )
        # Set after nn.Module.__init__ so the module machinery is initialised.
        self.scale = scale
        self.apply(
            initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups)
        )
class DistMaps(nn.Module):
    """Encode click points as two-channel distance (or disk) feature maps.

    Args:
        norm_radius: radius, in the units of the incoming point coordinates,
            used to normalise distances (or as the disk radius when
            ``use_disks`` is True).
        spatial_scale: factor applied to point coordinates and to
            ``norm_radius`` to bring them onto the ``rows x cols`` output grid.
        cpu_mode: if True, maps are computed by the Cython helper
            ``isegm.utils.cython.get_dist_maps`` instead of with torch ops.
        use_disks: if True, output binary disks, 1 where the squared raw
            value is ``<= (norm_radius * spatial_scale) ** 2``. Otherwise
            output ``tanh(2 * sqrt(raw))``. In the torch path ``raw`` is the
            squared distance to the nearest valid point divided by
            ``(norm_radius * spatial_scale) ** 2``.
    """

    def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False, use_disks=False):
        super().__init__()
        self.spatial_scale = spatial_scale
        self.norm_radius = norm_radius
        self.cpu_mode = cpu_mode
        self.use_disks = use_disks
        if self.cpu_mode:
            # Imported lazily so the compiled extension is only needed in CPU mode.
            from isegm.utils.cython import get_dist_maps

            self._get_dist_maps = get_dist_maps

    def get_coord_features(self, points, batchsize, rows, cols):
        """Build click feature maps on a ``rows x cols`` grid.

        Args:
            points: click tensor of shape ``(B, 2 * P, C)``. In the torch path:
                * the first two entries of the last axis are (row, col)
                  coordinates; further entries are ignored;
                * a point with both coordinates negative is treated as invalid
                  (presumably padding — confirm with caller);
                * the first ``P`` points of each sample feed output channel 0
                  and the last ``P`` feed channel 1.
            batchsize: number of samples; only used by the CPU path.
            rows: output height.
            cols: output width.

        Returns:
            Float tensor; ``(B, 2, rows, cols)`` in the torch path.
        """
        if self.cpu_mode:
            coords = self._cpu_raw_maps(points, batchsize, rows, cols)
        else:
            coords = self._torch_raw_maps(points, rows, cols)

        if self.use_disks:
            return (coords <= (self.norm_radius * self.spatial_scale) ** 2).float()
        return coords.sqrt_().mul_(2).tanh_()

    def _cpu_raw_maps(self, points, batchsize, rows, cols):
        """Compute per-sample raw maps with the Cython helper and stack them."""
        # The normaliser is the same for every sample, so compute it once.
        norm_delimiter = 1.0 if self.use_disks else self.spatial_scale * self.norm_radius
        # NOTE(review): unlike the torch path, points are not multiplied by
        # spatial_scale here; confirm get_dist_maps expects grid coordinates.
        maps = [
            self._get_dist_maps(points[i].cpu().float().numpy(), rows, cols, norm_delimiter)
            for i in range(batchsize)
        ]
        return torch.from_numpy(np.stack(maps, axis=0)).to(points.device).float()

    def _torch_raw_maps(self, points, rows, cols):
        """Squared (optionally normalised) distance to the nearest valid point."""
        num_points = points.shape[1] // 2
        # reshape rather than view so non-contiguous click tensors work too;
        # only the two coordinate columns are needed.
        points = points.reshape(-1, points.size(2))[:, :2]
        invalid_points = torch.max(points, dim=1)[0] < 0

        row_array = torch.arange(rows, dtype=torch.float32, device=points.device)
        col_array = torch.arange(cols, dtype=torch.float32, device=points.device)
        # Same as torch.meshgrid(row_array, col_array) with "ij" indexing,
        # without depending on the version-specific ``indexing`` keyword.
        coord_rows = row_array.view(-1, 1).expand(rows, cols)
        coord_cols = col_array.view(1, -1).expand(rows, cols)
        coords = (
            torch.stack((coord_rows, coord_cols), dim=0)
            .unsqueeze(0)
            .repeat(points.size(0), 1, 1, 1)
        )

        add_xy = (points * self.spatial_scale).view(points.size(0), points.size(1), 1, 1)
        coords.sub_(add_xy)
        if not self.use_disks:
            coords.div_(self.norm_radius * self.spatial_scale)
        coords.mul_(coords)

        # Collapse (dy^2, dx^2) into one squared-distance channel.
        coords = (coords[:, 0] + coords[:, 1]).unsqueeze(1)
        # Push invalid points far away so they never win the min below.
        coords[invalid_points, :, :, :] = 1e6

        coords = coords.view(-1, num_points, 1, rows, cols)
        coords = coords.min(dim=1)[0]  # -> (bs * 2) x 1 x h x w
        return coords.view(-1, 2, rows, cols)

    def forward(self, x, coords):
        """Return click maps sized to ``x``.

        Only the shape of ``x`` is used: ``x.shape[0]`` as the batch size and
        ``x.shape[2]`` / ``x.shape[3]`` as rows / cols (presumably an NCHW
        batch). ``coords`` is the click tensor described in
        ``get_coord_features``.
        """
        return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3])
class ScaleLayer(nn.Module):
    """Multiply the input by one learnable scalar, used as an absolute value.

    The parameter is stored as ``init_value / lr_mult`` and multiplied by
    ``lr_mult`` again in ``forward``. The effective factor therefore starts at
    ``abs(init_value)``, and the stored value's optimisation step size is
    rescaled (presumably the point of ``lr_mult``).
    """

    def __init__(self, init_value=1.0, lr_mult=1):
        super().__init__()
        self.lr_mult = lr_mult
        stored_value = init_value / lr_mult
        self.scale = nn.Parameter(torch.full((1,), stored_value, dtype=torch.float32))

    def forward(self, x):
        factor = (self.scale * self.lr_mult).abs()
        return x * factor
class BatchImageNormalize:
    """Per-channel ``(x - mean) / std`` normalisation for batched images.

    Args:
        mean: per-channel means (a sequence, tensor or scalar), stored with
            shape ``1 x C x 1 x 1`` so it broadcasts over ``N x C x H x W``.
        std: per-channel standard deviations, same layout as ``mean``.
        dtype: dtype of the stored statistics. Non-floating-point inputs are
            also converted to this dtype before normalising.
    """

    def __init__(self, mean, std, dtype=torch.float):
        # reshape keeps the 1-D case identical to [None, :, None, None] and
        # also accepts a scalar statistic shared by all channels.
        self.mean = torch.as_tensor(mean, dtype=dtype).reshape(1, -1, 1, 1)
        self.std = torch.as_tensor(std, dtype=dtype).reshape(1, -1, 1, 1)

    def __call__(self, tensor):
        """Return a normalised copy of ``tensor``; the input is not modified.

        Floating-point and complex inputs keep their dtype. Other inputs
        (e.g. ``uint8`` images) are first converted to the statistics dtype,
        because the in-place arithmetic cannot store fractional results in an
        integer tensor.
        """
        if tensor.is_floating_point() or tensor.is_complex():
            out = tensor.clone()
        else:
            # copy=True guarantees a new tensor even if the dtypes already match.
            out = tensor.to(self.mean.dtype, copy=True)
        out.sub_(self.mean.to(out.device)).div_(self.std.to(out.device))
        return out