# Provenance: author curt-park, commit 1615d09 ("Refactor code").
import numpy as np
import torch
import torch.nn as nn
from isegm.model.modifiers import LRMult
from isegm.model.ops import BatchImageNormalize, DistMaps, ScaleLayer
class ISModel(nn.Module):
    """Base class for click-based interactive segmentation models.

    Prepares the input (optional previous-mask channels, image normalization),
    encodes clicks into coordinate feature maps with ``DistMaps``, fuses them
    with the image, and resizes the backbone predictions back to the input's
    spatial size. Subclasses must implement :meth:`backbone_forward`.

    Args:
        use_rgb_conv: fuse ``[image, coord features]`` with a small conv stack
            that outputs 3 channels, then call ``backbone_forward(x)``.
        with_aux_output: also resize ``outputs["instances_aux"]`` in forward.
        norm_radius: ``norm_radius`` for the single ``DistMaps`` used when
            ``clicks_groups`` is None.
        use_disks: forwarded to ``DistMaps(use_disks=...)``.
        cpu_dist_maps: forwarded to ``DistMaps(cpu_mode=...)``.
        clicks_groups: optional sequence of radii; one ``DistMaps`` is built
            per entry and clicks are split into groups by click order.
        with_prev_mask: input channels ``[3:]`` are a previous mask that is
            prepended to the coordinate features.
        use_leaky_relu: use ``LeakyReLU(0.2)`` instead of ``ReLU`` in the
            fusion layers.
        binary_prev_mask: threshold the previous mask at 0.5.
        conv_extend: when ``use_rgb_conv`` is False, transform the coordinate
            features with a single stride-2 3x3 conv (with ``LRMult(0.1)``
            applied) instead of the default two-conv stack.
        norm_layer: normalization layer class used in the rgb fusion stack.
        norm_mean_std: ``(mean, std)`` passed to ``BatchImageNormalize``.
    """

    def __init__(
        self,
        use_rgb_conv=True,
        with_aux_output=False,
        norm_radius=260,
        use_disks=False,
        cpu_dist_maps=False,
        clicks_groups=None,
        with_prev_mask=False,
        use_leaky_relu=False,
        binary_prev_mask=False,
        conv_extend=False,
        norm_layer=nn.BatchNorm2d,
        norm_mean_std=([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ):
        super().__init__()
        self.with_aux_output = with_aux_output
        self.clicks_groups = clicks_groups
        self.with_prev_mask = with_prev_mask
        self.binary_prev_mask = binary_prev_mask
        self.normalization = BatchImageNormalize(norm_mean_std[0], norm_mean_std[1])

        # Channel count of the coordinate features: 2 per DistMaps instance
        # (one instance per clicks group when grouping is on), +1 for the
        # previous mask when it is used.
        self.coord_feature_ch = 2
        if clicks_groups is not None:
            self.coord_feature_ch *= len(clicks_groups)
        if self.with_prev_mask:
            self.coord_feature_ch += 1

        # NOTE: module construction order below determines parameter init
        # order and state_dict keys; keep it stable.
        if use_rgb_conv:
            # Early fusion: [image (3 ch), coord features] -> 3 channels.
            rgb_conv_layers = [
                nn.Conv2d(
                    in_channels=3 + self.coord_feature_ch,
                    out_channels=6 + self.coord_feature_ch,
                    kernel_size=1,
                ),
                norm_layer(6 + self.coord_feature_ch),
                nn.LeakyReLU(negative_slope=0.2)
                if use_leaky_relu
                else nn.ReLU(inplace=True),
                nn.Conv2d(
                    in_channels=6 + self.coord_feature_ch, out_channels=3, kernel_size=1
                ),
            ]
            self.rgb_conv = nn.Sequential(*rgb_conv_layers)
        elif conv_extend:
            self.rgb_conv = None
            # Single strided conv: coord features -> 64 channels, stride 2.
            self.maps_transform = nn.Conv2d(
                in_channels=self.coord_feature_ch,
                out_channels=64,
                kernel_size=3,
                stride=2,
                padding=1,
            )
            self.maps_transform.apply(LRMult(0.1))
        else:
            self.rgb_conv = None
            # 1x1 conv -> activation -> stride-2 3x3 conv to 64 ch -> learnable scale.
            mt_layers = [
                nn.Conv2d(
                    in_channels=self.coord_feature_ch, out_channels=16, kernel_size=1
                ),
                nn.LeakyReLU(negative_slope=0.2)
                if use_leaky_relu
                else nn.ReLU(inplace=True),
                nn.Conv2d(
                    in_channels=16, out_channels=64, kernel_size=3, stride=2, padding=1
                ),
                ScaleLayer(init_value=0.05, lr_mult=1),
            ]
            self.maps_transform = nn.Sequential(*mt_layers)

        if self.clicks_groups is not None:
            # One DistMaps per clicks group, each with its own radius.
            self.dist_maps = nn.ModuleList()
            for click_radius in self.clicks_groups:
                self.dist_maps.append(
                    DistMaps(
                        norm_radius=click_radius,
                        spatial_scale=1.0,
                        cpu_mode=cpu_dist_maps,
                        use_disks=use_disks,
                    )
                )
        else:
            self.dist_maps = DistMaps(
                norm_radius=norm_radius,
                spatial_scale=1.0,
                cpu_mode=cpu_dist_maps,
                use_disks=use_disks,
            )

    def forward(self, image, points):
        """Encode clicks, run the backbone and resize predictions.

        Args:
            image: input batch indexed as ``image[:, c, :, :]`` (rank 4);
                channels ``[:3]`` are the image and, when ``with_prev_mask``
                is set, channels ``[3:]`` the previous mask.
            points: clicks as consumed by ``DistMaps`` (and by
                ``split_points_by_order`` when ``clicks_groups`` is set).

        Returns:
            The dict returned by :meth:`backbone_forward`, with
            ``"instances"`` (and ``"instances_aux"`` if ``with_aux_output``)
            bilinearly resized (``align_corners=True``) to the input's
            spatial size.
        """
        image, prev_mask = self.prepare_input(image)
        coord_features = self.get_coord_features(image, prev_mask, points)

        if self.rgb_conv is not None:
            x = self.rgb_conv(torch.cat((image, coord_features), dim=1))
            outputs = self.backbone_forward(x)
        else:
            coord_features = self.maps_transform(coord_features)
            outputs = self.backbone_forward(image, coord_features)

        resized_keys = ["instances"]
        if self.with_aux_output:
            resized_keys.append("instances_aux")
        for key in resized_keys:
            outputs[key] = nn.functional.interpolate(
                outputs[key],
                size=image.size()[2:],
                mode="bilinear",
                align_corners=True,
            )
        return outputs

    def prepare_input(self, image):
        """Split off the optional previous mask and normalize the image.

        Returns:
            ``(normalized_image, prev_mask)``; ``prev_mask`` is None unless
            ``with_prev_mask`` is set, and is thresholded at 0.5 (as float)
            when ``binary_prev_mask`` is set.
        """
        prev_mask = None
        if self.with_prev_mask:
            prev_mask = image[:, 3:, :, :]
            image = image[:, :3, :, :]
            if self.binary_prev_mask:
                prev_mask = (prev_mask > 0.5).float()

        image = self.normalization(image)
        return image, prev_mask

    def backbone_forward(self, image, coord_features=None):
        """Run the backbone; must be implemented by subclasses.

        Called with the fused 3-channel tensor when ``rgb_conv`` is used,
        otherwise with ``(image, transformed coord features)``. Must return a
        dict containing at least ``"instances"``.
        """
        raise NotImplementedError

    def get_coord_features(self, image, prev_mask, points):
        """Encode clicks (and the optional previous mask) as feature maps.

        Returns:
            Tensor concatenated along dim 1: ``[prev_mask, click maps]`` when a
            previous mask is given, otherwise just the click maps.
        """
        if self.clicks_groups is not None:
            num_groups = len(self.clicks_groups)
            if num_groups == 1:
                # A single group must receive every click. The general pattern
                # below would produce two groups here, and zip() with the one
                # DistMaps would silently drop all clicks of the second group.
                group_sizes = (-1,)
            else:
                # Per-polarity slot counts for split_points_by_order; -1 means
                # "as many slots as there are points" (the catch-all last group).
                group_sizes = (2,) + (1,) * (num_groups - 2) + (-1,)
            points_groups = split_points_by_order(points, groups=group_sizes)
            coord_features = [
                dist_map(image, pg)
                for dist_map, pg in zip(self.dist_maps, points_groups)
            ]
            coord_features = torch.cat(coord_features, dim=1)
        else:
            coord_features = self.dist_maps(image, points)

        if prev_mask is not None:
            coord_features = torch.cat((prev_mask, coord_features), dim=1)

        return coord_features
def split_points_by_order(tpoints: torch.Tensor, groups):
    """Distribute clicks into fixed-size groups according to their click order.

    ``tpoints`` is read as ``(batch, 2 * num_points, 3)``. In each sample the
    first ``num_points`` rows are positive clicks and the remaining rows are
    negative clicks. Column 2 holds the click's order index; rows with a
    negative order index are empty and skipped.

    Args:
        tpoints: click tensor laid out as described above.
        groups: per-group slot counts (per polarity). A non-positive size
            means ``num_points`` slots. A click with order index ``g`` goes to
            group ``g``. Clicks whose order index is past the last group go
            to the last group, and so does a negative click with order
            index 0.

    Returns:
        A list with one tensor per group, each of shape ``(batch, 2 * size, 3)``
        and with the dtype and device of ``tpoints``. Positive clicks fill slots
        ``[0, size)`` and negative clicks fill slots ``[size, 2 * size)``, in
        input order. Unused slots are ``-1``.

    NOTE(review): there is no bounds check. If a group receives more clicks of
    one polarity than it has slots, positives spill into the negative slots
    and an IndexError is raised past ``2 * size``.
    """
    points = tpoints.detach().cpu().numpy()

    num_groups = len(groups)
    bs = points.shape[0]
    num_points = points.shape[1] // 2

    group_sizes = [size if size > 0 else num_points for size in groups]
    group_points = [
        np.full((bs, 2 * size, 3), -1, dtype=np.float32) for size in group_sizes
    ]

    # next_slot[b, g, is_negative] is the next free row in group g for sample b.
    # Negatives start right after the positive slots. (np.int was removed in
    # NumPy 1.24, so an explicit integer dtype is used.)
    next_slot = np.zeros((bs, num_groups, 2), dtype=np.int64)
    for group_indx, group_size in enumerate(group_sizes):
        next_slot[:, group_indx, 1] = group_size

    for bindx in range(bs):
        for pindx in range(2 * num_points):
            point = points[bindx, pindx, :]
            group_id = int(point[2])
            if group_id < 0:
                continue

            is_negative = int(pindx >= num_points)
            # Out-of-range orders go to the last group; a negative first
            # click is also routed there (disables negative first clicks).
            if group_id >= num_groups or (group_id == 0 and is_negative):
                group_id = num_groups - 1

            slot = next_slot[bindx, group_id, is_negative]
            next_slot[bindx, group_id, is_negative] += 1
            group_points[group_id][bindx, slot, :] = point

    return [
        torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device)
        for x in group_points
    ]