import numpy as np
import torch
import torch.nn as nn

from isegm.model.modifiers import LRMult
from isegm.model.ops import BatchImageNormalize, DistMaps, ScaleLayer


class ISModel(nn.Module):
    def __init__(
        self,
        use_rgb_conv=True,
        with_aux_output=False,
        norm_radius=260,
        use_disks=False,
        cpu_dist_maps=False,
        clicks_groups=None,
        with_prev_mask=False,
        use_leaky_relu=False,
        binary_prev_mask=False,
        conv_extend=False,
        norm_layer=nn.BatchNorm2d,
        norm_mean_std=([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ):
        super().__init__()
        self.with_aux_output = with_aux_output
        self.clicks_groups = clicks_groups
        self.with_prev_mask = with_prev_mask
        self.binary_prev_mask = binary_prev_mask

        self.normalization = BatchImageNormalize(norm_mean_std[0], norm_mean_std[1])

        # Two click-encoding channels (positive/negative), optionally repeated per
        # clicks group, plus one extra channel for the previous mask.
        self.coord_feature_ch = 2
        if clicks_groups is not None:
            self.coord_feature_ch *= len(clicks_groups)

        if self.with_prev_mask:
            self.coord_feature_ch += 1

        if use_rgb_conv:
            # Fuse the click maps with the RGB image before the backbone.
            rgb_conv_layers = [
                nn.Conv2d(
                    in_channels=3 + self.coord_feature_ch,
                    out_channels=6 + self.coord_feature_ch,
                    kernel_size=1,
                ),
                norm_layer(6 + self.coord_feature_ch),
                nn.LeakyReLU(negative_slope=0.2)
                if use_leaky_relu
                else nn.ReLU(inplace=True),
                nn.Conv2d(
                    in_channels=6 + self.coord_feature_ch, out_channels=3, kernel_size=1
                ),
            ]
            self.rgb_conv = nn.Sequential(*rgb_conv_layers)
        elif conv_extend:
            # Pass the click maps to the backbone as a separate 64-channel tensor.
            self.rgb_conv = None
            self.maps_transform = nn.Conv2d(
                in_channels=self.coord_feature_ch,
                out_channels=64,
                kernel_size=3,
                stride=2,
                padding=1,
            )
            self.maps_transform.apply(LRMult(0.1))
        else:
            self.rgb_conv = None
            mt_layers = [
                nn.Conv2d(
                    in_channels=self.coord_feature_ch, out_channels=16, kernel_size=1
                ),
                nn.LeakyReLU(negative_slope=0.2)
                if use_leaky_relu
                else nn.ReLU(inplace=True),
                nn.Conv2d(
                    in_channels=16, out_channels=64, kernel_size=3, stride=2, padding=1
                ),
                ScaleLayer(init_value=0.05, lr_mult=1),
            ]
            self.maps_transform = nn.Sequential(*mt_layers)

        if self.clicks_groups is not None:
            self.dist_maps = nn.ModuleList()
            for click_radius in self.clicks_groups:
                self.dist_maps.append(
                    DistMaps(
                        norm_radius=click_radius,
                        spatial_scale=1.0,
                        cpu_mode=cpu_dist_maps,
                        use_disks=use_disks,
                    )
                )
        else:
            self.dist_maps = DistMaps(
                norm_radius=norm_radius,
                spatial_scale=1.0,
                cpu_mode=cpu_dist_maps,
                use_disks=use_disks,
            )

    def forward(self, image, points):
        image, prev_mask = self.prepare_input(image)
        coord_features = self.get_coord_features(image, prev_mask, points)

        if self.rgb_conv is not None:
            x = self.rgb_conv(torch.cat((image, coord_features), dim=1))
            outputs = self.backbone_forward(x)
        else:
            coord_features = self.maps_transform(coord_features)
            outputs = self.backbone_forward(image, coord_features)

        # Upsample predictions back to the input resolution.
        outputs["instances"] = nn.functional.interpolate(
            outputs["instances"],
            size=image.size()[2:],
            mode="bilinear",
            align_corners=True,
        )
        if self.with_aux_output:
            outputs["instances_aux"] = nn.functional.interpolate(
                outputs["instances_aux"],
                size=image.size()[2:],
                mode="bilinear",
                align_corners=True,
            )

        return outputs

    def prepare_input(self, image):
        prev_mask = None
        if self.with_prev_mask:
            # The previous mask arrives as an extra (fourth) input channel.
            prev_mask = image[:, 3:, :, :]
            image = image[:, :3, :, :]
            if self.binary_prev_mask:
                prev_mask = (prev_mask > 0.5).float()

        image = self.normalization(image)
        return image, prev_mask

    def backbone_forward(self, image, coord_features=None):
        raise NotImplementedError

    def get_coord_features(self, image, prev_mask, points):
        if self.clicks_groups is not None:
            points_groups = split_points_by_order(
                points, groups=(2,) + (1,) * (len(self.clicks_groups) - 2) + (-1,)
            )
            coord_features = [
                dist_map(image, pg)
                for dist_map, pg in zip(self.dist_maps, points_groups)
            ]
            coord_features = torch.cat(coord_features, dim=1)
        else:
            coord_features = self.dist_maps(image, points)

        if prev_mask is not None:
            coord_features = torch.cat((prev_mask, coord_features), dim=1)

        return coord_features


def split_points_by_order(tpoints: torch.Tensor, groups):
    points = tpoints.cpu().numpy()
    num_groups = len(groups)
    bs = points.shape[0]
    num_points = points.shape[1] // 2

    # A non-positive group size means "take all remaining points".
    groups = [x if x > 0 else num_points for x in groups]
    group_points = [np.full((bs, 2 * x, 3), -1, dtype=np.float32) for x in groups]

    # np.int was removed in recent NumPy releases; use an explicit integer dtype.
    last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=np.int32)
    for group_indx, group_size in enumerate(groups):
        last_point_indx_group[:, group_indx, 1] = group_size

    for bindx in range(bs):
        for pindx in range(2 * num_points):
            point = points[bindx, pindx, :]
            group_id = int(point[2])
            if group_id < 0:
                continue

            is_negative = int(pindx >= num_points)
            if group_id >= num_groups or (
                group_id == 0 and is_negative
            ):  # disable negative first click
                group_id = num_groups - 1

            new_point_indx = last_point_indx_group[bindx, group_id, is_negative]
            last_point_indx_group[bindx, group_id, is_negative] += 1

            group_points[group_id][bindx, new_point_indx, :] = point

    group_points = [
        torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device)
        for x in group_points
    ]

    return group_points
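

# --- Usage sketch (illustrative, not part of the original module) -------------
# A minimal example of how a concrete model might subclass ISModel. The toy
# two-layer backbone below is an assumption for demonstration only; the actual
# repository uses real segmentation backbones. The click-tensor layout
# (bs, 2 * n_points, 3) with rows (y, x, click_index), positive clicks in the
# first half, and -1 padding is inferred from split_points_by_order above and
# should be checked against DistMaps in isegm.model.ops.
if __name__ == "__main__":
    class ToyISModel(ISModel):
        def __init__(self, **kwargs):
            super().__init__(use_rgb_conv=True, **kwargs)
            self.backbone = nn.Sequential(
                nn.Conv2d(3, 16, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(16, 1, kernel_size=1),
            )

        def backbone_forward(self, image, coord_features=None):
            # ISModel.forward expects a dict with an "instances" entry.
            return {"instances": self.backbone(image)}

    model = ToyISModel()
    model.eval()

    image = torch.rand(1, 3, 64, 64)
    # One positive click at (y=20, x=30); the negative-click slot is -1 padding.
    points = torch.tensor([[[20.0, 30.0, 0.0], [-1.0, -1.0, -1.0]]])

    with torch.no_grad():
        out = model(image, points)
    print(out["instances"].shape)  # expected: torch.Size([1, 1, 64, 64])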