# Scraped page header (not code): author "curt-park", commit message "Refactor code", revision 1615d09.
import torch
import torch.nn.functional as F
from torchvision import transforms
from isegm.inference.transforms import (AddHorizontalFlip, LimitLongestSide,
SigmoidForPred)
class BasePredictor(object):
def __init__(
self,
model,
device,
net_clicks_limit=None,
with_flip=False,
zoom_in=None,
max_size=None,
**kwargs
):
self.with_flip = with_flip
self.net_clicks_limit = net_clicks_limit
self.original_image = None
self.device = device
self.zoom_in = zoom_in
self.prev_prediction = None
self.model_indx = 0
self.click_models = None
self.net_state_dict = None
if isinstance(model, tuple):
self.net, self.click_models = model
else:
self.net = model
self.to_tensor = transforms.ToTensor()
self.transforms = [zoom_in] if zoom_in is not None else []
if max_size is not None:
self.transforms.append(LimitLongestSide(max_size=max_size))
self.transforms.append(SigmoidForPred())
if with_flip:
self.transforms.append(AddHorizontalFlip())
def set_input_image(self, image):
image_nd = self.to_tensor(image)
for transform in self.transforms:
transform.reset()
self.original_image = image_nd.to(self.device)
if len(self.original_image.shape) == 3:
self.original_image = self.original_image.unsqueeze(0)
self.prev_prediction = torch.zeros_like(self.original_image[:, :1, :, :])
def get_prediction(self, clicker, prev_mask=None):
clicks_list = clicker.get_clicks()
if self.click_models is not None:
model_indx = (
min(
clicker.click_indx_offset + len(clicks_list), len(self.click_models)
)
- 1
)
if model_indx != self.model_indx:
self.model_indx = model_indx
self.net = self.click_models[model_indx]
input_image = self.original_image
if prev_mask is None:
prev_mask = self.prev_prediction
if hasattr(self.net, "with_prev_mask") and self.net.with_prev_mask:
input_image = torch.cat((input_image, prev_mask), dim=1)
image_nd, clicks_lists, is_image_changed = self.apply_transforms(
input_image, [clicks_list]
)
pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed)
prediction = F.interpolate(
pred_logits, mode="bilinear", align_corners=True, size=image_nd.size()[2:]
)
for t in reversed(self.transforms):
prediction = t.inv_transform(prediction)
if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
return self.get_prediction(clicker)
self.prev_prediction = prediction
return prediction.cpu().numpy()[0, 0]
def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
points_nd = self.get_points_nd(clicks_lists)
return self.net(image_nd, points_nd)["instances"]
def _get_transform_states(self):
return [x.get_state() for x in self.transforms]
def _set_transform_states(self, states):
assert len(states) == len(self.transforms)
for state, transform in zip(states, self.transforms):
transform.set_state(state)
def apply_transforms(self, image_nd, clicks_lists):
is_image_changed = False
for t in self.transforms:
image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
is_image_changed |= t.image_changed
return image_nd, clicks_lists, is_image_changed
def get_points_nd(self, clicks_lists):
total_clicks = []
num_pos_clicks = [
sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists
]
num_neg_clicks = [
len(clicks_list) - num_pos
for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
]
num_max_points = max(num_pos_clicks + num_neg_clicks)
if self.net_clicks_limit is not None:
num_max_points = min(self.net_clicks_limit, num_max_points)
num_max_points = max(1, num_max_points)
for clicks_list in clicks_lists:
clicks_list = clicks_list[: self.net_clicks_limit]
pos_clicks = [
click.coords_and_indx for click in clicks_list if click.is_positive
]
pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [
(-1, -1, -1)
]
neg_clicks = [
click.coords_and_indx for click in clicks_list if not click.is_positive
]
neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [
(-1, -1, -1)
]
total_clicks.append(pos_clicks + neg_clicks)
return torch.tensor(total_clicks, device=self.device)
def get_states(self):
return {
"transform_states": self._get_transform_states(),
"prev_prediction": self.prev_prediction.clone(),
}
def set_states(self, states):
self._set_transform_states(states["transform_states"])
self.prev_prediction = states["prev_prediction"]