import torch
import torch.nn.functional as F
from torchvision import transforms

from isegm.inference.transforms import (
    AddHorizontalFlip,
    LimitLongestSide,
    SigmoidForPred,
)


class BasePredictor(object):
    """Run a click-based segmentation network on a single input image.

    The predictor keeps the current image (on ``device``) and the last
    prediction. On each call it feeds the network through a chain of
    transforms, built in this order: the optional ``zoom_in`` object,
    ``LimitLongestSide`` (when ``max_size`` is given), ``SigmoidForPred``,
    and ``AddHorizontalFlip`` (when ``with_flip`` is set). After the network
    runs, the transforms are undone in reverse order via ``inv_transform``.
    """

    def __init__(
        self,
        model,
        device,
        net_clicks_limit=None,
        with_flip=False,
        zoom_in=None,
        max_size=None,
        **kwargs
    ):
        """Create a predictor.

        Args:
            model: Either a network, or a tuple ``(net, click_models)``.
                ``click_models`` is an indexable collection of networks that
                ``get_prediction`` switches between based on the click count.
            device: Device that the image, previous prediction and click
                tensors are moved to.
            net_clicks_limit: If not None, the maximum number of clicks
                (per list) handed to the network.
            with_flip: Append an ``AddHorizontalFlip`` transform.
            zoom_in: Optional transform placed first in the chain. It is also
                queried via ``check_possible_recalculation()`` after each
                prediction.
            max_size: If not None, append ``LimitLongestSide(max_size=...)``.
            **kwargs: Accepted and ignored by this class.
        """
        self.with_flip = with_flip
        self.net_clicks_limit = net_clicks_limit
        self.original_image = None
        self.device = device
        self.zoom_in = zoom_in
        self.prev_prediction = None
        self.model_indx = 0
        self.click_models = None
        self.net_state_dict = None  # not read anywhere in this class

        if isinstance(model, tuple):
            self.net, self.click_models = model
        else:
            self.net = model

        self.to_tensor = transforms.ToTensor()

        # Order matters: inverse transforms are applied in reverse order.
        self.transforms = [zoom_in] if zoom_in is not None else []
        if max_size is not None:
            self.transforms.append(LimitLongestSide(max_size=max_size))
        self.transforms.append(SigmoidForPred())
        if with_flip:
            self.transforms.append(AddHorizontalFlip())

    def set_input_image(self, image):
        """Set a new image and reset all per-image state.

        ``image`` is converted with torchvision's ``ToTensor``. It is presumably
        an HxWxC array or a PIL image (confirm with the callers). The tensor is
        moved to ``self.device`` and gets a batch dimension if it is 3-D. Every
        transform is reset. The previous prediction becomes a zero tensor with
        one channel and the image's spatial size.
        """
        image_nd = self.to_tensor(image)

        for transform in self.transforms:
            transform.reset()
        self.original_image = image_nd.to(self.device)
        if len(self.original_image.shape) == 3:
            self.original_image = self.original_image.unsqueeze(0)
        self.prev_prediction = torch.zeros_like(self.original_image[:, :1, :, :])

    def get_prediction(self, clicker, prev_mask=None):
        """Predict a mask for the clicks currently held by ``clicker``.

        Args:
            clicker: Object providing ``get_clicks()`` and ``click_indx_offset``.
            prev_mask: Optional previous mask to feed the network. It is used
                only when the network has a truthy ``with_prev_mask``. Defaults
                to the last prediction stored by this predictor.

        Returns:
            ``prediction.cpu().numpy()[0, 0]``, i.e. the first batch item and
            first channel, after all inverse transforms have been applied.
        """
        clicks_list = clicker.get_clicks()

        if self.click_models is not None:
            # Select the model by total click count, capped at the last model.
            # NOTE(review): with zero clicks and a zero offset this index is -1,
            # i.e. the LAST model is selected. Confirm that is intended.
            model_indx = (
                min(
                    clicker.click_indx_offset + len(clicks_list),
                    len(self.click_models),
                )
                - 1
            )
            if model_indx != self.model_indx:
                self.model_indx = model_indx
                self.net = self.click_models[model_indx]

        input_image = self.original_image
        if prev_mask is None:
            prev_mask = self.prev_prediction
        if hasattr(self.net, "with_prev_mask") and self.net.with_prev_mask:
            # The previous mask becomes an extra input channel.
            input_image = torch.cat((input_image, prev_mask), dim=1)
        image_nd, clicks_lists, is_image_changed = self.apply_transforms(
            input_image, [clicks_list]
        )

        pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed)
        # Resize the network output back to the transformed input's spatial size.
        prediction = F.interpolate(
            pred_logits, mode="bilinear", align_corners=True, size=image_nd.size()[2:]
        )

        for t in reversed(self.transforms):
            prediction = t.inv_transform(prediction)

        if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
            # Forward the same prev_mask so the recomputed pass sees the same
            # previous-mask input the caller asked for. If the caller passed
            # None, this is self.prev_prediction, which has not changed yet.
            return self.get_prediction(clicker, prev_mask)

        self.prev_prediction = prediction
        return prediction.cpu().numpy()[0, 0]

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        """Run ``self.net`` and return its ``"instances"`` output.

        ``is_image_changed`` is not used by this base implementation.
        """
        points_nd = self.get_points_nd(clicks_lists)
        return self.net(image_nd, points_nd)["instances"]

    def _get_transform_states(self):
        """Return the ``get_state()`` of every transform, in chain order."""
        return [x.get_state() for x in self.transforms]

    def _set_transform_states(self, states):
        """Restore transform states previously returned by ``_get_transform_states``."""
        assert len(states) == len(self.transforms)
        for state, transform in zip(states, self.transforms):
            transform.set_state(state)

    def apply_transforms(self, image_nd, clicks_lists):
        """Apply every transform in order to the image and click lists.

        Returns:
            ``(image_nd, clicks_lists, is_image_changed)``. The flag is the OR
            of each transform's ``image_changed`` attribute.
        """
        is_image_changed = False
        for t in self.transforms:
            image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
            is_image_changed |= t.image_changed

        return image_nd, clicks_lists, is_image_changed

    def get_points_nd(self, clicks_lists):
        """Pack click lists into a padded tensor on ``self.device``.

        For each click list, the row holds ``num_max_points`` positive entries
        followed by ``num_max_points`` negative entries. Each entry is a
        click's ``coords_and_indx``, and missing slots are padded with
        ``(-1, -1, -1)``. ``num_max_points`` is the largest positive or
        negative count over all lists, capped by ``net_clicks_limit`` when that
        is set, and never less than 1. Each list is also truncated to its first
        ``net_clicks_limit`` clicks.
        """
        total_clicks = []
        num_pos_clicks = [
            sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists
        ]
        num_neg_clicks = [
            len(clicks_list) - num_pos
            for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
        ]
        # default=0 avoids ValueError on empty input; the floor below makes it 1.
        num_max_points = max(num_pos_clicks + num_neg_clicks, default=0)
        if self.net_clicks_limit is not None:
            num_max_points = min(self.net_clicks_limit, num_max_points)
        num_max_points = max(1, num_max_points)

        for clicks_list in clicks_lists:
            clicks_list = clicks_list[: self.net_clicks_limit]
            pos_clicks = [
                click.coords_and_indx for click in clicks_list if click.is_positive
            ]
            pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [
                (-1, -1, -1)
            ]

            neg_clicks = [
                click.coords_and_indx for click in clicks_list if not click.is_positive
            ]
            neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [
                (-1, -1, -1)
            ]
            total_clicks.append(pos_clicks + neg_clicks)

        return torch.tensor(total_clicks, device=self.device)

    def get_states(self):
        """Snapshot the transform states and a clone of the previous prediction."""
        return {
            "transform_states": self._get_transform_states(),
            "prev_prediction": self.prev_prediction.clone(),
        }

    def set_states(self, states):
        """Restore a snapshot from ``get_states``.

        The stored tensor is assigned directly, without cloning.
        """
        self._set_transform_states(states["transform_states"])
        self.prev_prediction = states["prev_prediction"]