import torch
import torch.nn.functional as F
from torchvision import transforms

from isegm.inference.transforms import (AddHorizontalFlip, LimitLongestSide,
                                        SigmoidForPred)


class BasePredictor(object):
    """Runs click-based interactive segmentation inference for an isegm model."""

    def __init__(
        self,
        model,
        device,
        net_clicks_limit=None,
        with_flip=False,
        zoom_in=None,
        max_size=None,
        **kwargs
    ):
        self.with_flip = with_flip
        self.net_clicks_limit = net_clicks_limit
        self.original_image = None
        self.device = device
        self.zoom_in = zoom_in
        self.prev_prediction = None
        self.model_indx = 0
        self.click_models = None
        self.net_state_dict = None

        # A tuple bundles a base network with per-click specialist models;
        # otherwise a single network handles every click count.
        if isinstance(model, tuple):
            self.net, self.click_models = model
        else:
            self.net = model

        self.to_tensor = transforms.ToTensor()

        # Transform pipeline: applied to the input before inference and
        # inverted (in reverse order) on the predicted logits afterwards.
        self.transforms = [zoom_in] if zoom_in is not None else []
        if max_size is not None:
            self.transforms.append(LimitLongestSide(max_size=max_size))
        self.transforms.append(SigmoidForPred())
        if with_flip:
            self.transforms.append(AddHorizontalFlip())
    def set_input_image(self, image):
        image_nd = self.to_tensor(image)
        for transform in self.transforms:
            transform.reset()
        self.original_image = image_nd.to(self.device)
        if len(self.original_image.shape) == 3:
            self.original_image = self.original_image.unsqueeze(0)  # add batch dim
        # Start from an empty mask: nothing is selected before the first click.
        self.prev_prediction = torch.zeros_like(self.original_image[:, :1, :, :])
    def get_prediction(self, clicker, prev_mask=None):
        clicks_list = clicker.get_clicks()

        # With an ensemble of click-specific models, switch to the one
        # trained for the current number of clicks.
        if self.click_models is not None:
            model_indx = (
                min(
                    clicker.click_indx_offset + len(clicks_list), len(self.click_models)
                )
                - 1
            )
            if model_indx != self.model_indx:
                self.model_indx = model_indx
                self.net = self.click_models[model_indx]

        input_image = self.original_image
        if prev_mask is None:
            prev_mask = self.prev_prediction
        # Networks trained with mask guidance take the previous prediction
        # as an extra input channel.
        if hasattr(self.net, "with_prev_mask") and self.net.with_prev_mask:
            input_image = torch.cat((input_image, prev_mask), dim=1)

        image_nd, clicks_lists, is_image_changed = self.apply_transforms(
            input_image, [clicks_list]
        )

        pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed)
        prediction = F.interpolate(
            pred_logits, mode="bilinear", align_corners=True, size=image_nd.size()[2:]
        )

        # Undo the input transforms on the prediction, in reverse order.
        for t in reversed(self.transforms):
            prediction = t.inv_transform(prediction)

        # If the zoom-in crop should be recomputed for the new clicks,
        # rerun inference on the updated crop.
        if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
            return self.get_prediction(clicker)

        self.prev_prediction = prediction
        return prediction.cpu().numpy()[0, 0]
    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        points_nd = self.get_points_nd(clicks_lists)
        return self.net(image_nd, points_nd)["instances"]

    def _get_transform_states(self):
        return [x.get_state() for x in self.transforms]

    def _set_transform_states(self, states):
        assert len(states) == len(self.transforms)
        for state, transform in zip(states, self.transforms):
            transform.set_state(state)
    def apply_transforms(self, image_nd, clicks_lists):
        is_image_changed = False
        for t in self.transforms:
            image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
            is_image_changed |= t.image_changed
        return image_nd, clicks_lists, is_image_changed
    def get_points_nd(self, clicks_lists):
        total_clicks = []
        num_pos_clicks = [
            sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists
        ]
        num_neg_clicks = [
            len(clicks_list) - num_pos
            for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
        ]
        # Pad every batch entry to the same number of points so the clicks
        # can be stacked into a single tensor.
        num_max_points = max(num_pos_clicks + num_neg_clicks)
        if self.net_clicks_limit is not None:
            num_max_points = min(self.net_clicks_limit, num_max_points)
        num_max_points = max(1, num_max_points)

        for clicks_list in clicks_lists:
            clicks_list = clicks_list[: self.net_clicks_limit]
            pos_clicks = [
                click.coords_and_indx for click in clicks_list if click.is_positive
            ]
            # (-1, -1, -1) marks an empty point slot.
            pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [
                (-1, -1, -1)
            ]

            neg_clicks = [
                click.coords_and_indx for click in clicks_list if not click.is_positive
            ]
            neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [
                (-1, -1, -1)
            ]
            total_clicks.append(pos_clicks + neg_clicks)

        return torch.tensor(total_clicks, device=self.device)
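
    # Illustration of the layout produced above (example values, assuming
    # coords_and_indx yields (row, col, click_index)): with two positive
    # clicks, one negative click, and num_max_points == 2, a batch entry is
    #   [(y0, x0, 0), (y1, x1, 1),     # positive slots
    #    (y2, x2, 2), (-1, -1, -1)]    # negative slots, padded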
    def get_states(self):
        return {
            "transform_states": self._get_transform_states(),
            "prev_prediction": self.prev_prediction.clone(),
        }

    def set_states(self, states):
        self._set_transform_states(states["transform_states"])
        self.prev_prediction = states["prev_prediction"]
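

# Usage sketch (an assumption, not part of the original file): BasePredictor
# is typically built from a trained isegm checkpoint and driven by a Clicker
# that accumulates user clicks. `utils.load_is_model`, `Clicker`, and `Click`
# follow the isegm repo layout; the file paths below are placeholders.
if __name__ == "__main__":
    import cv2
    from isegm.inference import utils
    from isegm.inference.clicker import Clicker, Click

    model = utils.load_is_model("checkpoint.pth", device="cuda")
    predictor = BasePredictor(model, device="cuda", with_flip=True)

    image = cv2.cvtColor(cv2.imread("input.jpg"), cv2.COLOR_BGR2RGB)
    predictor.set_input_image(image)  # HxWx3 uint8 array

    clicker = Clicker()
    clicker.add_click(Click(is_positive=True, coords=(120, 240)))  # (row, col)
    prob_map = predictor.get_prediction(clicker)  # HxW float probabilities
    mask = prob_map > 0.5
    print(mask.sum(), "foreground pixels")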