Spaces:
Build error
Build error
import random | |
import torch | |
import torch.nn.functional as F | |
import numpy as np | |
from scipy import ndimage | |
class Point:
    """Sample single-pixel point prompts from a 2-D mask.

    In training mode ``draw`` returns one boolean map with a random number of
    mask pixels switched on.  In evaluation mode it returns a stack of
    cumulative click maps holding positive (+1) and negative (-1) points.
    No dilation is applied: every sampled point is exactly one pixel.
    """

    # ``draw`` treats a mask whose sum is below this value as empty.
    _MIN_MASK_SUM = 10

    def __init__(self, cfg, is_train=True):
        """Read the sampling limits from ``cfg``.

        Args:
            cfg: Nested mapping; ``cfg['STROKE_SAMPLER']['POINT']['NUM_POINTS']``
                caps the number of training points and
                ``cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']`` sets the click
                budget used by ``draw_eval``.
            is_train: When false, ``draw`` delegates to ``draw_eval``.
        """
        self.max_points = cfg['STROKE_SAMPLER']['POINT']['NUM_POINTS']
        self.max_eval = cfg['STROKE_SAMPLER']['EVAL']['MAX_ITER']
        self.is_train = is_train

    def draw(self, mask=None, box=None):
        """Sample point prompts from ``mask``.

        Args:
            mask: 2-D tensor (unpacked as ``h, w``).  Evaluation mode applies
                ``~mask``, so a boolean mask is expected there.
            box: Unused here; forwarded to ``draw_eval``.

        Returns:
            * An all-False boolean tensor shaped like ``mask`` when
              ``mask.sum() < 10`` (checked before the train/eval split).
            * In evaluation mode, whatever ``draw_eval`` returns.
            * Otherwise a boolean ``(h, w)`` tensor with between 1 and
              ``min(self.max_points, mask.sum())`` mask pixels set to True.
        """
        if mask.sum() < self._MIN_MASK_SUM:
            return torch.zeros(mask.shape, dtype=torch.bool, device=mask.device)
        if not self.is_train:
            return self.draw_eval(mask=mask, box=box)

        # Never ask for more points than the mask can provide.
        max_points = min(self.max_points, mask.sum().item())
        num_points = random.randint(1, max_points)

        h, w = mask.shape
        # reshape (not view) so transposed / sliced masks are accepted too.
        flat_mask = mask.reshape(-1)
        candidates = flat_mask.nonzero()[:, 0]
        chosen = candidates[torch.randperm(len(candidates))[:num_points]]

        # Allocate on the mask's device so the index tensor can address it.
        rand_mask = torch.zeros(
            flat_mask.shape, dtype=torch.bool, device=mask.device
        )
        rand_mask[chosen] = True
        return rand_mask.reshape(h, w)

    def draw_eval(self, mask=None, box=None):
        """Build a cumulative sequence of positive and negative clicks.

        Up to ``self.max_eval // 2`` background pixels are sampled as
        negative points and ``min(self.max_eval - neg_num, mask.sum() - 1) + 1``
        mask pixels as positive points.  The first sampled positive point is
        kept as the first click; the remaining clicks are shuffled.

        Args:
            mask: 2-D boolean tensor (``~mask`` is used as the background).
            box: Unused.

        Returns:
            Float tensor of shape ``(num_clicks, h, w)``.  Frame ``i`` holds
            the first ``i + 1`` clicks: +1 at positive points, -1 at negative
            points, 0 elsewhere.
        """
        background = ~mask
        neg_num = min(self.max_eval // 2, background.sum().item())
        pos_num = min(self.max_eval - neg_num, mask.sum().item() - 1) + 1

        h, w = mask.shape
        device = mask.device

        # reshape (not view) so non-contiguous masks are accepted too.
        pos_pool = mask.reshape(-1).nonzero()[:, 0]
        pos_points = pos_pool[torch.randperm(len(pos_pool))[:pos_num]]

        neg_pool = background.reshape(-1).nonzero()[:, 0]
        neg_points = neg_pool[torch.randperm(len(neg_pool))[:neg_num]]

        points = torch.cat([pos_points, neg_points])
        labels = torch.cat([
            torch.ones(len(pos_points), device=device),
            -torch.ones(len(neg_points), device=device),
        ])

        # Index 0 (the first positive point) stays first; shuffle the rest.
        order = torch.cat([
            torch.zeros(1, dtype=torch.long),
            torch.randperm(len(points) - 1) + 1,
        ])
        points = points[order]
        labels = labels[order]

        # Sampled pixels are distinct, so adding one click to a running map
        # and snapshotting it gives the same frames as rebuilding each time.
        running = torch.zeros(h * w, device=device)
        frames = []
        for point, label in zip(points, labels):
            running[point] = label
            frames.append(running.clone())
        return torch.stack(frames).reshape(-1, h, w)

    def __repr__(self):
        """Return the sampler's short name."""
        return 'point'