import numpy as np
import torch
import torch.nn.functional as F
from scipy.optimize import fmin_l_bfgs_b

from .base import BasePredictor


class BRSBasePredictor(BasePredictor):
    """Common state and helpers for predictors that refine a prediction by
    optimizing auxiliary parameters with L-BFGS-B.

    Attributes set here:
        optimize_after_n_clicks: optimization only runs once the number of
            clicks exceeds this value.
        opt_functor: callable objective handed to ``fmin_l_bfgs_b``; it must
            also provide ``init_click``, ``optimizer_params``,
            ``best_prediction`` and ``unpack_opt_params`` (all used by the
            subclasses below).
        opt_data: current value of the optimized parameters, or ``None``.
        input_data: cached network features, or ``None``.
    """

    def __init__(self, model, device, opt_functor, optimize_after_n_clicks=1, **kwargs):
        super().__init__(model, device, **kwargs)
        self.optimize_after_n_clicks = optimize_after_n_clicks
        self.opt_functor = opt_functor

        self.opt_data = None
        self.input_data = None

    def set_input_image(self, image):
        """Set a new image and drop everything cached for the previous one."""
        super().set_input_image(image)
        self.opt_data = None
        self.input_data = None

    def _get_clicks_maps_nd(self, clicks_lists, image_shape, radius=1):
        """Rasterize clicks into positive and negative square-marker maps.

        Args:
            clicks_lists: one list of clicks per batch item; every click
                exposes ``coords`` as ``(y, x)`` and a boolean ``is_positive``.
            image_shape: spatial shape ``(height, width)`` of the maps.
            radius: half side of the square drawn around each click; the
                square covers ``2 * radius + 1`` pixels per axis before
                clipping to the image.

        Returns:
            ``(pos_clicks_map, neg_clicks_map)``: float32 tensors on
            ``self.device`` of shape ``(len(clicks_lists), 1, *image_shape)``
            holding 1 inside the click squares and 0 elsewhere.
        """
        maps_shape = (len(clicks_lists), 1) + tuple(image_shape)
        pos_clicks_map = np.zeros(maps_shape, dtype=np.float32)
        neg_clicks_map = np.zeros(maps_shape, dtype=np.float32)

        for list_indx, clicks_list in enumerate(clicks_lists):
            for click in clicks_list:
                y, x = click.coords
                y, x = int(round(y)), int(round(x))
                # Clamp every bound to zero. NumPy treats a negative slice
                # bound as an offset from the END of the axis, so without the
                # clamp a click closer than `radius` to the top/left border
                # produced an empty (or misplaced) marker. Bounds past the
                # bottom/right border are already truncated by slicing.
                y1, x1 = max(y - radius, 0), max(x - radius, 0)
                y2, x2 = max(y + radius + 1, 0), max(x + radius + 1, 0)

                target_map = pos_clicks_map if click.is_positive else neg_clicks_map
                target_map[list_indx, 0, y1:y2, x1:x2] = 1.0

        with torch.no_grad():
            pos_clicks_map = torch.from_numpy(pos_clicks_map).to(self.device)
            neg_clicks_map = torch.from_numpy(neg_clicks_map).to(self.device)

        return pos_clicks_map, neg_clicks_map

    def get_states(self):
        """Return a snapshot with the transform states and ``opt_data``.

        ``input_data`` is not part of the snapshot.
        """
        return {
            "transform_states": self._get_transform_states(),
            "opt_data": self.opt_data,
        }

    def set_states(self, states):
        """Restore a snapshot previously produced by ``get_states``."""
        self._set_transform_states(states["transform_states"])
        self.opt_data = states["opt_data"]


class FeatureBRSPredictor(BRSBasePredictor):
    """Predictor that optimizes a scale and a bias applied to cached
    intermediate features of the network.

    ``insertion_mode`` selects which features are cached and modulated:
    ``"after_deeplab"``, ``"after_c4"`` or ``"after_aspp"``. Any other value
    raises ``NotImplementedError`` in the constructor.
    """

    def __init__(
        self, model, device, opt_functor, insertion_mode="after_deeplab", **kwargs
    ):
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.insertion_mode = insertion_mode
        # Skip-connection features, only stored for the "after_c4" mode.
        self._c1_features = None

        if self.insertion_mode == "after_deeplab":
            self.num_channels = model.feature_extractor.ch
        elif self.insertion_mode == "after_c4":
            self.num_channels = model.feature_extractor.aspp_in_channels
        elif self.insertion_mode == "after_aspp":
            # NOTE(review): the extra 32 presumably matches the channel count
            # of the projected c1 features concatenated in _get_head_input;
            # confirm against the feature extractor.
            self.num_channels = model.feature_extractor.ch + 32
        else:
            raise NotImplementedError

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        """Return logits for ``image_nd``, refined by optimization when enough
        clicks are available.

        Args:
            image_nd: input batch; dimension 0 is the batch (twice the logical
                batch when ``self.with_flip`` is set) and dimensions 2+ are
                the spatial size used for the click maps and output logits.
            clicks_lists: one list of clicks per batch item; the click count
                is taken from the first list.
            is_image_changed: forces the cached features to be recomputed.

        Returns:
            ``opt_functor.best_prediction`` when it is not ``None``, otherwise
            logits computed from the current ``opt_data``.
        """
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])

        num_clicks = len(clicks_lists[0])
        bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]

        # (Re)allocate the flat parameter vector when missing or when its
        # length no longer corresponds to the current batch size.
        if (
            self.opt_data is None
            or self.opt_data.shape[0] // (2 * self.num_channels) != bs
        ):
            self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)

        # Cached features are refreshed while the click count is within the
        # network limit, when the image changed, or when nothing is cached.
        if (
            num_clicks <= self.net_clicks_limit
            or is_image_changed
            or self.input_data is None
        ):
            self.input_data = self._get_head_input(image_nd, points_nd)

        def get_prediction_logits(scale, bias):
            """Modulate the cached features and run the remaining layers."""
            scale = scale.view(bs, -1, 1, 1)
            bias = bias.view(bs, -1, 1, 1)
            if self.with_flip:
                scale = scale.repeat(2, 1, 1, 1)
                bias = bias.repeat(2, 1, 1, 1)

            scaled_backbone_features = self.input_data * scale
            scaled_backbone_features = scaled_backbone_features + bias
            if self.insertion_mode == "after_c4":
                x = self.net.feature_extractor.aspp(scaled_backbone_features)
                x = F.interpolate(
                    x,
                    mode="bilinear",
                    size=self._c1_features.size()[2:],
                    align_corners=True,
                )
                x = torch.cat((x, self._c1_features), dim=1)
                scaled_backbone_features = self.net.feature_extractor.head(x)
            elif self.insertion_mode == "after_aspp":
                scaled_backbone_features = self.net.feature_extractor.head(
                    scaled_backbone_features
                )

            pred_logits = self.net.head(scaled_backbone_features)
            pred_logits = F.interpolate(
                pred_logits,
                size=image_nd.size()[2:],
                mode="bilinear",
                align_corners=True,
            )
            return pred_logits

        self.opt_functor.init_click(
            get_prediction_logits, pos_mask, neg_mask, self.device
        )
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(
                func=self.opt_functor,
                x0=self.opt_data,
                **self.opt_functor.optimizer_params
            )
            # First element of the result is the estimated minimizer; it is
            # kept as the warm start for the next click.
            self.opt_data = opt_result[0]

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                # NOTE(review): after a previous optimization opt_data is the
                # array returned by SciPy (presumably float64) - confirm that
                # unpack_opt_params / the network accept that dtype here.
                opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
                opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits

    def _get_head_input(self, image_nd, points):
        """Compute, without gradients, the features that will be modulated.

        For ``"after_c4"`` the projected c1 features are additionally stored
        in ``self._c1_features`` because the prediction closure needs them.
        """
        with torch.no_grad():
            image_nd, prev_mask = self.net.prepare_input(image_nd)
            coord_features = self.net.get_coord_features(image_nd, prev_mask, points)

            # If neither branch applies, `x` and `additional_features` stay
            # unassigned and the code below fails with a NameError subclass.
            if self.net.rgb_conv is not None:
                x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
                additional_features = None
            elif hasattr(self.net, "maps_transform"):
                x = image_nd
                additional_features = self.net.maps_transform(coord_features)

            if self.insertion_mode == "after_c4" or self.insertion_mode == "after_aspp":
                c1, _, _, c4 = self.net.feature_extractor.backbone(
                    x, additional_features
                )
                c1 = self.net.feature_extractor.skip_project(c1)

                if self.insertion_mode == "after_aspp":
                    x = self.net.feature_extractor.aspp(c4)
                    x = F.interpolate(
                        x, size=c1.size()[2:], mode="bilinear", align_corners=True
                    )
                    x = torch.cat((x, c1), dim=1)
                    backbone_features = x
                else:
                    backbone_features = c4
                    self._c1_features = c1
            else:
                backbone_features = self.net.feature_extractor(x, additional_features)[
                    0
                ]

        return backbone_features


class HRNetFeatureBRSPredictor(BRSBasePredictor):
    """Feature-modulating predictor for an HRNet-style feature extractor.

    ``insertion_mode`` is ``"A"`` (modulate the output of
    ``compute_hrnet_feats``) or ``"C"`` (modulate the output of
    ``ocr_distri_head``). Any other value raises ``NotImplementedError``.
    """

    def __init__(self, model, device, opt_functor, insertion_mode="A", **kwargs):
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.insertion_mode = insertion_mode
        self._c1_features = None

        if self.insertion_mode == "A":
            self.num_channels = sum(
                k * model.feature_extractor.width for k in [1, 2, 4, 8]
            )
        elif self.insertion_mode == "C":
            self.num_channels = 2 * model.feature_extractor.ocr_width
        else:
            raise NotImplementedError

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        """Return logits for ``image_nd``, refined by optimization when enough
        clicks are available.

        Args:
            image_nd: input batch; dimension 0 is the batch (twice the logical
                batch when ``self.with_flip`` is set) and dimensions 2+ are
                the spatial size used for the click maps and output logits.
            clicks_lists: one list of clicks per batch item; the click count
                is taken from the first list.
            is_image_changed: forces the cached features to be recomputed.

        Returns:
            ``opt_functor.best_prediction`` when it is not ``None``, otherwise
            logits computed from the current ``opt_data``.
        """
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])

        num_clicks = len(clicks_lists[0])
        bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]

        # (Re)allocate the flat parameter vector when missing or when its
        # length no longer corresponds to the current batch size.
        if (
            self.opt_data is None
            or self.opt_data.shape[0] // (2 * self.num_channels) != bs
        ):
            self.opt_data = np.zeros((bs * 2 * self.num_channels), dtype=np.float32)

        if (
            num_clicks <= self.net_clicks_limit
            or is_image_changed
            or self.input_data is None
        ):
            self.input_data = self._get_head_input(image_nd, points_nd)

        def get_prediction_logits(scale, bias):
            """Modulate the cached features and run the remaining layers."""
            scale = scale.view(bs, -1, 1, 1)
            bias = bias.view(bs, -1, 1, 1)
            if self.with_flip:
                scale = scale.repeat(2, 1, 1, 1)
                bias = bias.repeat(2, 1, 1, 1)

            scaled_backbone_features = self.input_data * scale
            scaled_backbone_features = scaled_backbone_features + bias
            if self.insertion_mode == "A":
                if self.net.feature_extractor.ocr_width > 0:
                    out_aux = self.net.feature_extractor.aux_head(
                        scaled_backbone_features
                    )
                    feats = self.net.feature_extractor.conv3x3_ocr(
                        scaled_backbone_features
                    )

                    context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
                    feats = self.net.feature_extractor.ocr_distri_head(feats, context)
                else:
                    feats = scaled_backbone_features
                pred_logits = self.net.feature_extractor.cls_head(feats)
            elif self.insertion_mode == "C":
                pred_logits = self.net.feature_extractor.cls_head(
                    scaled_backbone_features
                )
            else:
                raise NotImplementedError

            pred_logits = F.interpolate(
                pred_logits,
                size=image_nd.size()[2:],
                mode="bilinear",
                align_corners=True,
            )
            return pred_logits

        self.opt_functor.init_click(
            get_prediction_logits, pos_mask, neg_mask, self.device
        )
        if num_clicks > self.optimize_after_n_clicks:
            opt_result = fmin_l_bfgs_b(
                func=self.opt_functor,
                x0=self.opt_data,
                **self.opt_functor.optimizer_params
            )
            # Keep the minimizer as the warm start for the next click.
            self.opt_data = opt_result[0]

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                opt_data_nd = torch.from_numpy(self.opt_data).to(self.device)
                opt_vars, _ = self.opt_functor.unpack_opt_params(opt_data_nd)
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits

    def _get_head_input(self, image_nd, points):
        """Compute, without gradients, the features that will be modulated."""
        with torch.no_grad():
            image_nd, prev_mask = self.net.prepare_input(image_nd)
            coord_features = self.net.get_coord_features(image_nd, prev_mask, points)

            # If neither branch applies, `x` and `additional_features` stay
            # unassigned and the call below fails with a NameError subclass.
            if self.net.rgb_conv is not None:
                x = self.net.rgb_conv(torch.cat((image_nd, coord_features), dim=1))
                additional_features = None
            elif hasattr(self.net, "maps_transform"):
                x = image_nd
                additional_features = self.net.maps_transform(coord_features)

            feats = self.net.feature_extractor.compute_hrnet_feats(
                x, additional_features
            )

            if self.insertion_mode == "A":
                backbone_features = feats
            elif self.insertion_mode == "C":
                out_aux = self.net.feature_extractor.aux_head(feats)
                feats = self.net.feature_extractor.conv3x3_ocr(feats)

                context = self.net.feature_extractor.ocr_gather_head(feats, out_aux)
                backbone_features = self.net.feature_extractor.ocr_distri_head(
                    feats, context
                )
            else:
                raise NotImplementedError

        return backbone_features


class InputBRSPredictor(BRSBasePredictor):
    """Predictor that optimizes an additive bias on the network input.

    ``optimize_target`` selects what the bias is added to: ``"rgb"`` (the
    prepared image), ``"dmaps"`` (the coordinate features) or ``"all"`` (the
    output of ``rgb_conv``, only when the network has one).
    """

    def __init__(self, model, device, opt_functor, optimize_target="rgb", **kwargs):
        super().__init__(model, device, opt_functor=opt_functor, **kwargs)
        self.optimize_target = optimize_target

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        """Return logits for ``image_nd``, refined by optimization when enough
        clicks are available.

        Args:
            image_nd: 4-D input batch; dimension 0 is the batch (twice the
                logical batch when ``self.with_flip`` is set), dimensions 2
                and 3 are the spatial size.
            clicks_lists: one list of clicks per batch item; the click count
                is taken from the first list.
            is_image_changed: forces ``opt_data`` to be reset to zeros.

        Returns:
            ``opt_functor.best_prediction`` when it is not ``None``, otherwise
            logits computed from the current ``opt_data``.
        """
        points_nd = self.get_points_nd(clicks_lists)
        pos_mask, neg_mask = self._get_clicks_maps_nd(clicks_lists, image_nd.shape[2:])
        num_clicks = len(clicks_lists[0])

        if self.opt_data is None or is_image_changed:
            if self.optimize_target == "dmaps":
                # With a previous-mask channel the bias skips channel 0 (see
                # the `dmaps[:, 1:]` update below), hence one channel fewer.
                opt_channels = (
                    self.net.coord_feature_ch - 1
                    if self.net.with_prev_mask
                    else self.net.coord_feature_ch
                )
            else:
                opt_channels = 3
            bs = image_nd.shape[0] // 2 if self.with_flip else image_nd.shape[0]
            self.opt_data = torch.zeros(
                (bs, opt_channels, image_nd.shape[2], image_nd.shape[3]),
                device=self.device,
                dtype=torch.float32,
            )

        def get_prediction_logits(opt_bias):
            """Run the whole network with ``opt_bias`` added to the target."""
            input_image, prev_mask = self.net.prepare_input(image_nd)
            dmaps = self.net.get_coord_features(input_image, prev_mask, points_nd)

            if self.optimize_target == "rgb":
                input_image = input_image + opt_bias
            elif self.optimize_target == "dmaps":
                if self.net.with_prev_mask:
                    dmaps[:, 1:, :, :] = dmaps[:, 1:, :, :] + opt_bias
                else:
                    dmaps = dmaps + opt_bias

            # If neither branch applies, `x` and `coord_features` stay
            # unassigned and the call below fails with a NameError subclass.
            if self.net.rgb_conv is not None:
                x = self.net.rgb_conv(torch.cat((input_image, dmaps), dim=1))
                if self.optimize_target == "all":
                    x = x + opt_bias
                coord_features = None
            elif hasattr(self.net, "maps_transform"):
                x = input_image
                coord_features = self.net.maps_transform(dmaps)

            pred_logits = self.net.backbone_forward(x, coord_features=coord_features)[
                "instances"
            ]
            pred_logits = F.interpolate(
                pred_logits,
                size=image_nd.size()[2:],
                mode="bilinear",
                align_corners=True,
            )

            return pred_logits

        self.opt_functor.init_click(
            get_prediction_logits,
            pos_mask,
            neg_mask,
            self.device,
            shape=self.opt_data.shape,
        )
        if num_clicks > self.optimize_after_n_clicks:
            # SciPy works on a flat NumPy vector, so the tensor is flattened
            # for the call and the result is reshaped back afterwards.
            opt_result = fmin_l_bfgs_b(
                func=self.opt_functor,
                x0=self.opt_data.cpu().numpy().ravel(),
                **self.opt_functor.optimizer_params
            )

            self.opt_data = (
                torch.from_numpy(opt_result[0])
                .view(self.opt_data.shape)
                .to(self.device)
            )

        with torch.no_grad():
            if self.opt_functor.best_prediction is not None:
                opt_pred_logits = self.opt_functor.best_prediction
            else:
                opt_vars, _ = self.opt_functor.unpack_opt_params(self.opt_data)
                opt_pred_logits = get_prediction_logits(*opt_vars)

        return opt_pred_logits