Spaces:
Runtime error
Runtime error
File size: 5,430 Bytes
2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 2cdd41c 1615d09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import torch
import torch.nn.functional as F
from torchvision import transforms
from isegm.inference.transforms import (AddHorizontalFlip, LimitLongestSide,
SigmoidForPred)
class BasePredictor(object):
    """Click-driven predictor wrapping a network and a chain of transforms.

    The predictor stores an input image, converts a clicker's clicks into a
    padded points tensor, runs the network on the (transformed) image and
    maps the network output back through the inverse transforms.

    Attributes are plain instance fields; the ones that carry state between
    calls are ``original_image`` (set by :meth:`set_input_image`),
    ``prev_prediction`` (the last prediction tensor) and ``transforms``
    (each transform keeps its own state, see :meth:`get_states`).
    """

    def __init__(
        self,
        model,
        device,
        net_clicks_limit=None,
        with_flip=False,
        zoom_in=None,
        max_size=None,
        **kwargs
    ):
        """Build the predictor and its transform chain.

        Args:
            model: Either a single network, or a ``(net, click_models)``
                tuple. When ``click_models`` is given, the active network is
                switched depending on the number of clicks (see
                :meth:`get_prediction`).
            device: Device the image and the points tensor are placed on.
            net_clicks_limit: Optional upper bound on the number of clicks
                passed to the network; ``None`` means no limit.
            with_flip: When true, ``AddHorizontalFlip`` is appended to the
                transform chain.
            zoom_in: Optional transform placed first in the chain. It is also
                queried after every prediction through
                ``check_possible_recalculation()``.
            max_size: When given, ``LimitLongestSide(max_size=max_size)`` is
                added to the transform chain.
            **kwargs: Accepted and ignored.
        """
        self.with_flip = with_flip
        self.net_clicks_limit = net_clicks_limit
        self.original_image = None
        self.device = device
        self.zoom_in = zoom_in
        self.prev_prediction = None
        self.model_indx = 0
        self.click_models = None
        self.net_state_dict = None

        if isinstance(model, tuple):
            self.net, self.click_models = model
        else:
            self.net = model

        self.to_tensor = transforms.ToTensor()

        # Order matters: transforms are applied front to back and inverted
        # back to front in get_prediction().
        self.transforms = [zoom_in] if zoom_in is not None else []
        if max_size is not None:
            self.transforms.append(LimitLongestSide(max_size=max_size))
        self.transforms.append(SigmoidForPred())
        if with_flip:
            self.transforms.append(AddHorizontalFlip())

    def set_input_image(self, image):
        """Store a new input image and reset all per-image state.

        Args:
            image: Image accepted by ``torchvision.transforms.ToTensor``.
        """
        image_nd = self.to_tensor(image)
        for transform in self.transforms:
            transform.reset()
        self.original_image = image_nd.to(self.device)
        # Add a leading batch dimension when the tensor has only 3 dims.
        if len(self.original_image.shape) == 3:
            self.original_image = self.original_image.unsqueeze(0)
        # Start from an all-zero previous prediction with a single channel.
        self.prev_prediction = torch.zeros_like(self.original_image[:, :1, :, :])

    def get_prediction(self, clicker, prev_mask=None):
        """Run the network for the clicker's current clicks.

        Args:
            clicker: Object providing ``get_clicks()`` and, when click models
                are in use, ``click_indx_offset``.
            prev_mask: Optional mask tensor concatenated to the image along
                dim 1 when the network has a truthy ``with_prev_mask``
                attribute. Defaults to the previous prediction.

        Returns:
            The prediction for batch element 0, channel 0, as a NumPy array.
        """
        clicks_list = clicker.get_clicks()

        if self.click_models is not None:
            # Pick the model by click count, clamped to the last model.
            model_indx = (
                min(
                    clicker.click_indx_offset + len(clicks_list),
                    len(self.click_models),
                )
                - 1
            )
            if model_indx != self.model_indx:
                self.model_indx = model_indx
                self.net = self.click_models[model_indx]

        input_image = self.original_image
        if prev_mask is None:
            prev_mask = self.prev_prediction
        if hasattr(self.net, "with_prev_mask") and self.net.with_prev_mask:
            input_image = torch.cat((input_image, prev_mask), dim=1)

        image_nd, clicks_lists, is_image_changed = self.apply_transforms(
            input_image, [clicks_list]
        )

        pred_logits = self._get_prediction(image_nd, clicks_lists, is_image_changed)
        # Resize the network output to the spatial size of the network input.
        prediction = F.interpolate(
            pred_logits, mode="bilinear", align_corners=True, size=image_nd.size()[2:]
        )

        for t in reversed(self.transforms):
            prediction = t.inv_transform(prediction)

        if self.zoom_in is not None and self.zoom_in.check_possible_recalculation():
            # Re-run with the same mask: self.prev_prediction has not been
            # updated yet, so forwarding prev_mask keeps the default case
            # unchanged while no longer discarding a caller-supplied mask.
            return self.get_prediction(clicker, prev_mask=prev_mask)

        self.prev_prediction = prediction
        return prediction.cpu().numpy()[0, 0]

    def _get_prediction(self, image_nd, clicks_lists, is_image_changed):
        """Call the network and return its ``"instances"`` output.

        ``is_image_changed`` is not used by this implementation.
        """
        points_nd = self.get_points_nd(clicks_lists)
        return self.net(image_nd, points_nd)["instances"]

    def _get_transform_states(self):
        """Return the state of every transform, in chain order."""
        return [x.get_state() for x in self.transforms]

    def _set_transform_states(self, states):
        """Restore transform states produced by ``_get_transform_states``."""
        assert len(states) == len(self.transforms)
        for state, transform in zip(states, self.transforms):
            transform.set_state(state)

    def apply_transforms(self, image_nd, clicks_lists):
        """Apply every transform in order to the image and the click lists.

        Returns:
            A ``(image_nd, clicks_lists, is_image_changed)`` tuple, where
            ``is_image_changed`` is the OR of each transform's
            ``image_changed`` flag.
        """
        is_image_changed = False
        for t in self.transforms:
            image_nd, clicks_lists = t.transform(image_nd, clicks_lists)
            is_image_changed |= t.image_changed
        return image_nd, clicks_lists, is_image_changed

    def get_points_nd(self, clicks_lists):
        """Convert click lists into a padded points tensor.

        For every click list the positive clicks come first, followed by the
        negative clicks; each group is padded with ``(-1, -1, -1)`` up to a
        common length (at least 1, at most ``net_clicks_limit`` when set).

        Args:
            clicks_lists: Non-empty list of click lists. Each click exposes
                ``is_positive`` and ``coords_and_indx``.

        Returns:
            A tensor on ``self.device`` built from the padded click tuples.
        """
        total_clicks = []
        num_pos_clicks = [
            sum(x.is_positive for x in clicks_list) for clicks_list in clicks_lists
        ]
        num_neg_clicks = [
            len(clicks_list) - num_pos
            for clicks_list, num_pos in zip(clicks_lists, num_pos_clicks)
        ]
        num_max_points = max(num_pos_clicks + num_neg_clicks)
        if self.net_clicks_limit is not None:
            num_max_points = min(self.net_clicks_limit, num_max_points)
        # Always emit at least one (padding) point per group.
        num_max_points = max(1, num_max_points)

        for clicks_list in clicks_lists:
            # Slicing with None keeps the whole list when there is no limit.
            clicks_list = clicks_list[: self.net_clicks_limit]
            pos_clicks = [
                click.coords_and_indx for click in clicks_list if click.is_positive
            ]
            pos_clicks = pos_clicks + (num_max_points - len(pos_clicks)) * [
                (-1, -1, -1)
            ]

            neg_clicks = [
                click.coords_and_indx for click in clicks_list if not click.is_positive
            ]
            neg_clicks = neg_clicks + (num_max_points - len(neg_clicks)) * [
                (-1, -1, -1)
            ]
            total_clicks.append(pos_clicks + neg_clicks)

        return torch.tensor(total_clicks, device=self.device)

    def get_states(self):
        """Return a snapshot of the transform states and previous prediction.

        The previous prediction is cloned so later predictions do not alter
        the snapshot; it is ``None`` when no image has been set yet.
        """
        prev_prediction = self.prev_prediction
        if prev_prediction is not None:
            prev_prediction = prev_prediction.clone()
        return {
            "transform_states": self._get_transform_states(),
            "prev_prediction": prev_prediction,
        }

    def set_states(self, states):
        """Restore a snapshot previously returned by :meth:`get_states`."""
        self._set_transform_states(states["transform_states"])
        self.prev_prediction = states["prev_prediction"]
|