import torch
import torch.nn as nn
import numpy as np
from isegm.model.ops import DistMaps, ScaleLayer, BatchImageNormalize
from isegm.model.modifiers import LRMult


class ISModel(nn.Module):
    """Base class for click-based interactive segmentation models.

    User clicks (and optionally the previous mask) are encoded as extra feature
    maps, which are either fused with the RGB input through a small 1x1-conv
    block or passed to the backbone as a separate conditioning tensor.
    """

    def __init__(self, use_rgb_conv=True, with_aux_output=False,
                 norm_radius=260, use_disks=False, cpu_dist_maps=False,
                 clicks_groups=None, with_prev_mask=False, use_leaky_relu=False,
                 binary_prev_mask=False, conv_extend=False, norm_layer=nn.BatchNorm2d,
                 norm_mean_std=([.485, .456, .406], [.229, .224, .225])):
        super().__init__()
        self.with_aux_output = with_aux_output
        self.clicks_groups = clicks_groups
        self.with_prev_mask = with_prev_mask
        self.binary_prev_mask = binary_prev_mask
        self.normalization = BatchImageNormalize(norm_mean_std[0], norm_mean_std[1])

        # Two coordinate channels (positive / negative clicks) per clicks group,
        # plus an optional channel for the previous mask.
        self.coord_feature_ch = 2
        if clicks_groups is not None:
            self.coord_feature_ch *= len(clicks_groups)

        if self.with_prev_mask:
            self.coord_feature_ch += 1

        if use_rgb_conv:
            # Fuse the RGB image and click maps back into a 3-channel tensor for the backbone.
            rgb_conv_layers = [
                nn.Conv2d(in_channels=3 + self.coord_feature_ch, out_channels=6 + self.coord_feature_ch, kernel_size=1),
                norm_layer(6 + self.coord_feature_ch),
                nn.LeakyReLU(negative_slope=0.2) if use_leaky_relu else nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=6 + self.coord_feature_ch, out_channels=3, kernel_size=1)
            ]
            self.rgb_conv = nn.Sequential(*rgb_conv_layers)
        elif conv_extend:
            self.rgb_conv = None
            self.maps_transform = nn.Conv2d(in_channels=self.coord_feature_ch, out_channels=64,
                                            kernel_size=3, stride=2, padding=1)
            self.maps_transform.apply(LRMult(0.1))
        else:
            self.rgb_conv = None
            mt_layers = [
                nn.Conv2d(in_channels=self.coord_feature_ch, out_channels=16, kernel_size=1),
                nn.LeakyReLU(negative_slope=0.2) if use_leaky_relu else nn.ReLU(inplace=True),
                nn.Conv2d(in_channels=16, out_channels=64, kernel_size=3, stride=2, padding=1),
                ScaleLayer(init_value=0.05, lr_mult=1)
            ]
            self.maps_transform = nn.Sequential(*mt_layers)

        if self.clicks_groups is not None:
            self.dist_maps = nn.ModuleList()
            for click_radius in self.clicks_groups:
                self.dist_maps.append(DistMaps(norm_radius=click_radius, spatial_scale=1.0,
                                               cpu_mode=cpu_dist_maps, use_disks=use_disks))
        else:
            self.dist_maps = DistMaps(norm_radius=norm_radius, spatial_scale=1.0,
                                      cpu_mode=cpu_dist_maps, use_disks=use_disks)

    def forward(self, image, points):
        image, prev_mask = self.prepare_input(image)
        coord_features = self.get_coord_features(image, prev_mask, points)

        if self.rgb_conv is not None:
            x = self.rgb_conv(torch.cat((image, coord_features), dim=1))
            outputs = self.backbone_forward(x)
        else:
            coord_features = self.maps_transform(coord_features)
            outputs = self.backbone_forward(image, coord_features)

        # Upsample the predictions back to the input resolution.
        outputs['instances'] = nn.functional.interpolate(outputs['instances'], size=image.size()[2:],
                                                         mode='bilinear', align_corners=True)
        if self.with_aux_output:
            outputs['instances_aux'] = nn.functional.interpolate(outputs['instances_aux'], size=image.size()[2:],
                                                                 mode='bilinear', align_corners=True)
        return outputs

    def prepare_input(self, image):
        prev_mask = None
        if self.with_prev_mask:
            # The previous mask is stacked onto the image as a fourth channel.
            prev_mask = image[:, 3:, :, :]
            image = image[:, :3, :, :]
            if self.binary_prev_mask:
                prev_mask = (prev_mask > 0.5).float()

        image = self.normalization(image)
        return image, prev_mask

    def backbone_forward(self, image, coord_features=None):
        raise NotImplementedError

    def get_coord_features(self, image, prev_mask, points):
        if self.clicks_groups is not None:
            points_groups = split_points_by_order(points, groups=(2,) + (1,) * (len(self.clicks_groups) - 2) + (-1,))
            coord_features = [dist_map(image, pg) for dist_map, pg in zip(self.dist_maps, points_groups)]
            coord_features = torch.cat(coord_features, dim=1)
        else:
            coord_features = self.dist_maps(image, points)

        if prev_mask is not None:
            coord_features = torch.cat((prev_mask, coord_features), dim=1)

        return coord_features


def split_points_by_order(tpoints: torch.Tensor, groups):
    points = tpoints.cpu().numpy()
    num_groups = len(groups)
    bs = points.shape[0]
    num_points = points.shape[1] // 2

    # A non-positive group size means "take all remaining points".
    groups = [x if x > 0 else num_points for x in groups]
    group_points = [np.full((bs, 2 * x, 3), -1, dtype=np.float32)
                    for x in groups]

    # np.int was removed in NumPy 1.24; use a concrete integer dtype instead.
    last_point_indx_group = np.zeros((bs, num_groups, 2), dtype=np.int32)
    for group_indx, group_size in enumerate(groups):
        last_point_indx_group[:, group_indx, 1] = group_size

    for bindx in range(bs):
        for pindx in range(2 * num_points):
            point = points[bindx, pindx, :]
            group_id = int(point[2])
            if group_id < 0:
                continue

            is_negative = int(pindx >= num_points)
            if group_id >= num_groups or (group_id == 0 and is_negative):  # disable negative first click
                group_id = num_groups - 1

            new_point_indx = last_point_indx_group[bindx, group_id, is_negative]
            last_point_indx_group[bindx, group_id, is_negative] += 1

            group_points[group_id][bindx, new_point_indx, :] = point

    group_points = [torch.tensor(x, dtype=tpoints.dtype, device=tpoints.device)
                    for x in group_points]

    return group_points
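

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): `_DummyISModel` and its
    # 1x1-conv head are hypothetical stand-ins for a real backbone, shown only to
    # illustrate the click-encoding format and the forward() contract. It assumes the
    # DistMaps implementation in isegm.model.ops accepts clicks in this packed layout.
    class _DummyISModel(ISModel):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.head = nn.Conv2d(3, 1, kernel_size=1)

        def backbone_forward(self, image, coord_features=None):
            # A real subclass would run a segmentation backbone here.
            return {'instances': self.head(image)}

    model = _DummyISModel(use_rgb_conv=True, use_disks=True, norm_radius=5).eval()
    image = torch.rand(1, 3, 64, 64)

    # Clicks are packed as (batch, 2 * max_clicks, 3) rows of (y, x, click_index):
    # the first half holds positive clicks, the second half negative ones, and
    # unused slots are filled with -1.
    points = torch.full((1, 4, 3), -1.0)
    points[0, 0] = torch.tensor([32.0, 32.0, 0.0])  # one positive click at the center

    with torch.no_grad():
        output = model(image, points)
    print(output['instances'].shape)  # expected: torch.Size([1, 1, 64, 64])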