import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

def upsample_masks(masks, size, thresh=0.5):
    """Resize a (possibly batched) binary mask tensor to spatial size (H, W)."""
    shape = masks.shape
    dtype = masks.dtype
    h, w = shape[-2:]
    H, W = size
    if (H == h) and (W == w):
        return masks
    elif (H < h) and (W < w):
        # Downsampling by an integer factor: strided subsampling.
        s = (h // H, w // W)
        return masks[..., ::s[0], ::s[1]]

    # Upsampling: tile each mask entry into an (H // h, W // w) block.
    masks = masks.unsqueeze(-2).unsqueeze(-1)
    masks = masks.repeat(*([1] * (len(shape) - 2)), 1, H // h, 1, W // w)
    if ((H % h) == 0) and ((W % w) == 0):
        masks = masks.view(*shape[:-2], H, W)
    else:
        # Non-integer scale factor: interpolate (Resize needs a float input),
        # then re-binarize at `thresh` and restore the original dtype.
        _H = np.prod(masks.shape[-4:-2])
        _W = np.prod(masks.shape[-2:])
        masks = transforms.Resize(size)(masks.view(-1, 1, _H, _W).float()) > thresh
        masks = masks.view(*shape[:-2], H, W).to(dtype)
    return masks

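# A minimal usage sketch (the 8x8 patch grid and 224x224 output are
# illustrative assumptions, not values used elsewhere in this module):
#
#   patch_mask = torch.rand(2, 8, 8) > 0.5               # (B, h, w) patch-level mask
#   pixel_mask = upsample_masks(patch_mask, (224, 224))  # -> (B, 224, 224)
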
def partition_masks(masks, num_samples=2, leave_one_out=False):
    """Split the visible (0) entries of each mask into `num_samples` disjoint masks."""
    B = masks.shape[0]
    S = num_samples
    masks = masks.view(B, -1)
    partitioned = [torch.ones_like(masks) for _ in range(S)]
    for b in range(B):
        vis_inds = torch.where(~masks[b])[0]
        vis_inds = vis_inds[torch.randperm(vis_inds.size(0))]
        if leave_one_out:
            # Each partition reveals all visible patches *except* its own share.
            for s in range(S):
                partitioned[s][b][vis_inds] = 0
                partitioned[s][b][vis_inds[s::S]] = 1
        else:
            # Each partition reveals only its own share of the visible patches.
            for s in range(S):
                partitioned[s][b][vis_inds[s::S]] = 0
    return partitioned

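# A minimal usage sketch (assumes the convention used throughout this module:
# True/1 = masked, False/0 = visible; shapes are illustrative):
#
#   masks = torch.rand(4, 64) > 0.75          # ~48 visible patches per sample
#   part_a, part_b = partition_masks(masks, num_samples=2)
#   # part_a and part_b each reveal a disjoint half of the visible patches.
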
class RectangularizeMasks(nn.Module):
    """Ensure all masks in a batch have the same number of masked (1) and visible (0) entries."""

    def __init__(self, truncation_mode='min'):
        super().__init__()
        self._mode = truncation_mode
        assert self._mode in ['min', 'max', 'mean', 'full', 'none', None], (self._mode)

    def set_mode(self, mode):
        self._mode = mode

    def __call__(self, masks):
        if self._mode in ['none', None]:
            return masks

        assert isinstance(masks, torch.Tensor), type(masks)
        if self._mode == 'full':
            return torch.ones_like(masks)

        shape = masks.shape
        masks = masks.flatten(1)
        B, N = masks.shape
        num_masked = masks.float().sum(-1)
        # Target number of masked entries per example, per the truncation mode.
        M = {
            'min': torch.amin, 'max': torch.amax, 'mean': torch.mean
        }[self._mode](num_masked).long()

        num_changes = num_masked.long() - M
        for b in range(B):
            nc = int(num_changes[b])
            if nc > 0:
                # Too many masked entries: unmask `nc` randomly chosen ones.
                inds = torch.where(masks[b])[0]
                inds = inds[torch.randperm(inds.size(0))[:nc].to(inds.device)]
                masks[b, inds] = 0
            elif nc < 0:
                # Too few masked entries: mask `-nc` randomly chosen visible ones.
                inds = torch.where(~masks[b])[0]
                inds = inds[torch.randperm(inds.size(0))[:-nc].to(inds.device)]
                masks[b, inds] = 1
        if list(masks.shape) != list(shape):
            masks = masks.view(*shape)

        return masks

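# A minimal usage sketch: force every mask in a batch down to the smallest
# masked count seen in that batch (mode='min'), so visible-patch counts line
# up for rectangular batching (shapes are illustrative assumptions):
#
#   rect = RectangularizeMasks(truncation_mode='min')
#   masks = torch.rand(4, 196) > 0.9
#   masks = rect(masks)                  # equal number of 1s in every row
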
class UniformMaskingGenerator(object):
    def __init__(self, input_size, mask_ratio, seed=None, clumping_factor=1, randomize_num_visible=False):
        self.frames = None
        # input_size may be an int (square), (height, width), or (frames, height, width).
        # Check the int case first, since len() on an int would raise a TypeError.
        if isinstance(input_size, int):
            self.height = self.width = input_size
        elif len(input_size) == 1:
            self.height = self.width = input_size[0]
        elif len(input_size) == 2:
            self.height, self.width = input_size
        elif len(input_size) == 3:
            self.frames, self.height, self.width = input_size

        self.clumping_factor = clumping_factor
        self.pad_h = self.height % self.c[0]
        self.pad_w = self.width % self.c[1]
        self.num_patches_per_frame = (self.height // self.c[0]) * (self.width // self.c[1])
        self.mask_ratio = mask_ratio

        self.rng = np.random.RandomState(seed=seed)
        self.randomize_num_visible = randomize_num_visible

    @property
    def num_masks_per_frame(self):
        if not hasattr(self, '_num_masks_per_frame'):
            self._num_masks_per_frame = int(self.mask_ratio * self.num_patches_per_frame)
        return self._num_masks_per_frame

    @num_masks_per_frame.setter
    def num_masks_per_frame(self, val):
        self._num_masks_per_frame = val
        self._mask_ratio = (val / self.num_patches_per_frame)

    @property
    def c(self):
        if isinstance(self.clumping_factor, int):
            return (self.clumping_factor, self.clumping_factor)
        else:
            return self.clumping_factor[:2]

    @property
    def mask_ratio(self):
        return self._mask_ratio

    @mask_ratio.setter
    def mask_ratio(self, val):
        self._mask_ratio = val
        self._num_masks_per_frame = int(self._mask_ratio * self.num_patches_per_frame)

    @property
    def num_visible(self):
        return self.num_patches_per_frame - self.num_masks_per_frame

    @num_visible.setter
    def num_visible(self, val):
        self.num_masks_per_frame = self.num_patches_per_frame - val

    def __repr__(self):
        repr_str = "Mask: total patches per frame {}, mask patches per frame {}, mask ratio {}, randomize num visible? {}".format(
            self.num_patches_per_frame, self.num_masks_per_frame, self.mask_ratio, self.randomize_num_visible
        )
        return repr_str

    def sample_mask_per_frame(self):
        num_masks = self.num_masks_per_frame
        if self.randomize_num_visible:
            # Sample the masked count uniformly between the default count and
            # "everything masked".
            num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1))
        mask = np.hstack([
            np.zeros(self.num_patches_per_frame - num_masks),
            np.ones(num_masks)])
        self.rng.shuffle(mask)
        if max(*self.c) > 1:
            # Clumping: tile each mask entry into a c[0] x c[1] block, then pad
            # (with masked values) back up to the full (height, width) grid,
            # randomly splitting the padding between the two sides.
            mask = mask.reshape(self.height // self.c[0],
                                1,
                                self.width // self.c[1],
                                1)
            mask = np.tile(mask, (1, self.c[0], 1, self.c[1]))
            mask = mask.reshape((self.height - self.pad_h, self.width - self.pad_w))
            _pad_h = self.rng.choice(range(self.pad_h + 1))
            pad_h = (self.pad_h - _pad_h, _pad_h)
            _pad_w = self.rng.choice(range(self.pad_w + 1))
            pad_w = (self.pad_w - _pad_w, _pad_w)
            mask = np.pad(mask,
                          (pad_h, pad_w),
                          constant_values=1
                          ).reshape((self.height, self.width))
        return mask

    def __call__(self, num_frames=None):
        num_frames = (num_frames or self.frames) or 1
        # Sample an independent mask for each frame and flatten to 1D.
        masks = np.stack([self.sample_mask_per_frame() for _ in range(num_frames)]).flatten()
        return masks

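# A minimal usage sketch (grid size and ratio are illustrative assumptions):
#
#   gen = UniformMaskingGenerator(input_size=(14, 14), mask_ratio=0.75)
#   mask = gen()                 # flat numpy array of 196 {0, 1} entries
#   print(gen)                   # 147 of 196 patches masked
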
class TubeMaskingGenerator(UniformMaskingGenerator):
    """Samples one spatial mask and repeats it across all frames ("tube" masking)."""

    def __call__(self, num_frames=None):
        num_frames = (num_frames or self.frames) or 1
        masks = np.tile(self.sample_mask_per_frame(), (num_frames, 1)).flatten()
        return masks

class RotatedTableMaskingGenerator(TubeMaskingGenerator):
    """Leaves the first frames fully visible and tube-masks the remaining ones."""

    def __init__(self, tube_length=None, *args, **kwargs):
        super(RotatedTableMaskingGenerator, self).__init__(*args, **kwargs)
        self.tube_length = tube_length

    def __call__(self, num_frames=None):
        num_frames = (num_frames or self.frames) or 2
        tube_length = self.tube_length or (num_frames - 1)
        table_thickness = num_frames - tube_length
        assert tube_length < num_frames, (tube_length, num_frames)

        tubes = super().__call__(num_frames=tube_length)
        # Fully visible "table top" frames, followed by the masked tubes.
        top = np.zeros(table_thickness * self.height * self.width).astype(tubes.dtype).flatten()
        masks = np.concatenate([top, tubes], 0)
        return masks

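# A minimal usage sketch: with 2 frames and tube_length=1, the first frame is
# fully visible and the second is masked at `mask_ratio` (values illustrative):
#
#   gen = RotatedTableMaskingGenerator(tube_length=1, input_size=(2, 14, 14), mask_ratio=0.9)
#   mask = gen()                 # 392 entries: 196 zeros, then the tube mask
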
class PytorchMaskGeneratorWrapper(nn.Module):
    """PyTorch wrapper for the numpy masking generators above."""

    def __init__(self,
                 mask_generator=TubeMaskingGenerator,
                 *args, **kwargs):
        super().__init__()
        self.mask_generator = mask_generator(*args, **kwargs)

    @property
    def mask_ratio(self):
        return self.mask_generator.mask_ratio

    @mask_ratio.setter
    def mask_ratio(self, value):
        self.mask_generator.mask_ratio = value

    def forward(self, device='cuda', dtype_out=torch.bool, **kwargs):
        # Sample a numpy mask, then move it to the requested device and dtype.
        masks = self.mask_generator(**kwargs)
        masks = torch.tensor(masks).to(device).to(dtype_out)
        return masks

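# A minimal usage sketch (device='cpu' keeps it runnable without a GPU;
# constructor values are illustrative assumptions):
#
#   wrapper = PytorchMaskGeneratorWrapper(mask_generator=TubeMaskingGenerator,
#                                         input_size=(4, 14, 14), mask_ratio=0.75)
#   masks = wrapper(device='cpu')        # -> torch.bool tensor of shape (784,)
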
class MaskingGenerator(nn.Module):
    """PyTorch base class for masking generators."""

    def __init__(self,
                 input_size,
                 mask_ratio,
                 seed=0,
                 visible_frames=0,
                 clumping_factor=1,
                 randomize_num_visible=False,
                 create_on_cpu=True,
                 always_batch=False):
        super().__init__()
        self.frames = None

        # input_size may be an int (square), (height, width), or (frames, height, width).
        # Check the int case first, since len() on an int would raise a TypeError.
        if isinstance(input_size, int):
            self.height = self.width = input_size
        elif len(input_size) == 1:
            self.height = self.width = input_size[0]
        elif len(input_size) == 2:
            self.height, self.width = input_size
        elif len(input_size) == 3:
            self.frames, self.height, self.width = input_size

        self.clumping_factor = clumping_factor
        self.pad_h = self.height % self.c[0]
        self.pad_w = self.width % self.c[1]
        self.num_patches_per_frame = (self.height // self.c[0]) * (self.width // self.c[1])

        self.mask_ratio = mask_ratio
        self.visible_frames = visible_frames
        self.always_batch = always_batch
        self.create_on_cpu = create_on_cpu

        self.rng = np.random.RandomState(seed=seed)
        self._set_torch_seed(seed)

        self.randomize_num_visible = randomize_num_visible

    @property
    def num_masks_per_frame(self):
        if not hasattr(self, '_num_masks_per_frame'):
            self._num_masks_per_frame = int(self.mask_ratio * self.num_patches_per_frame)
        return self._num_masks_per_frame

    @num_masks_per_frame.setter
    def num_masks_per_frame(self, val):
        self._num_masks_per_frame = val
        self._mask_ratio = (val / self.num_patches_per_frame)

    @property
    def c(self):
        if isinstance(self.clumping_factor, int):
            return (self.clumping_factor,) * 2
        else:
            return self.clumping_factor[:2]

    @property
    def mask_ratio(self):
        return self._mask_ratio

    @mask_ratio.setter
    def mask_ratio(self, val):
        self._mask_ratio = val
        self._num_masks_per_frame = int(self._mask_ratio * self.num_patches_per_frame)

    @property
    def num_visible(self):
        return self.num_patches_per_frame - self.num_masks_per_frame

    @num_visible.setter
    def num_visible(self, val):
        self.num_masks_per_frame = self.num_patches_per_frame - val

    def _set_torch_seed(self, seed):
        self.seed = seed
        torch.manual_seed(self.seed)

    def __repr__(self):
        repr_str = ("Class: {}\nMask: total patches per mask {},\n"
                    "mask patches per mask {}, visible patches per mask {}, mask ratio {:0.3f}\n"
                    "randomize num visible? {}").format(
            type(self).__name__, self.num_patches_per_frame,
            self.num_masks_per_frame, self.num_visible, self.mask_ratio,
            self.randomize_num_visible)
        return repr_str

    def sample_mask_per_frame(self, *args, **kwargs):
        num_masks = self.num_masks_per_frame
        if self.randomize_num_visible:
            # Sample the masked count uniformly between the default count and
            # "everything masked".
            num_masks = self.rng.randint(low=num_masks, high=(self.num_patches_per_frame + 1))

        mask = torch.cat([
            torch.zeros([self.num_patches_per_frame - num_masks]),
            torch.ones([num_masks])], 0).bool()
        inds = torch.randperm(mask.size(0)).long()
        mask = mask[inds]

        if max(*self.c) > 1:
            # Clumping: tile each mask entry into a c[0] x c[1] block, then pad
            # (with masked values) back up to the full (height, width) grid,
            # randomly splitting the padding between the two sides.
            mask = mask.view(self.height // self.c[0],
                             1,
                             self.width // self.c[1],
                             1)
            mask = torch.tile(mask, (1, self.c[0], 1, self.c[1]))
            mask = mask.reshape(self.height - self.pad_h, self.width - self.pad_w)
            _pad_h = self.rng.choice(range(self.pad_h + 1))
            pad_h = (self.pad_h - _pad_h, _pad_h)
            _pad_w = self.rng.choice(range(self.pad_w + 1))
            pad_w = (self.pad_w - _pad_w, _pad_w)
            # F.pad takes (left, right, top, bottom) for a 2D tensor.
            mask = F.pad(mask,
                         pad_w + pad_h,
                         mode='constant',
                         value=1)
            mask = mask.reshape(self.height, self.width)

        return mask

    def forward(self, x=None, num_frames=None):
        num_frames = (num_frames or self.frames) or 1
        if isinstance(x, torch.Tensor):
            # Sample one flat mask per batch element.
            batch_size = x.size(0)
            masks = torch.stack([
                torch.cat([self.sample_mask_per_frame() for _ in range(num_frames)], 0).flatten()
                for b in range(batch_size)], 0)
            if not self.create_on_cpu:
                masks = masks.to(x.device)
            if batch_size == 1 and not self.always_batch:
                masks = masks.squeeze(0)
        else:
            batch_size = 1
            masks = torch.cat([self.sample_mask_per_frame() for _ in range(num_frames)], 0).flatten()
            if self.always_batch:
                masks = masks[None]

        if self.visible_frames > 0:
            # Prepend fully visible (all-zero) frames before the sampled masks,
            # matching the batch layout of `masks`.
            vis = torch.zeros(masks.shape[:-1] + (self.height * self.width,),
                              dtype=masks.dtype, device=masks.device)
            masks = torch.cat(([vis] * self.visible_frames) + [masks], -1)

        return masks
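
# A minimal, runnable smoke test for this module (shapes and ratios here are
# illustrative assumptions, not canonical settings):
if __name__ == '__main__':
    gen = MaskingGenerator(input_size=(4, 14, 14), mask_ratio=0.75, always_batch=True)
    x = torch.randn(2, 3, 224, 224)      # dummy batch; only x.size(0) is used
    masks = gen(x)                       # -> (2, 4 * 14 * 14) torch.bool
    print(gen)
    print('mask shape:', masks.shape, 'masked fraction:', masks.float().mean().item())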