File size: 3,021 Bytes
6dfcb0f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
import numpy as np
import torch
def get_tubes(masks_per_frame, tube_length):
    """Stack ``tube_length`` successively shuffled copies of a mask.

    The first row is a random permutation of ``masks_per_frame``; every
    following row is a random permutation of the row before it. Shuffling
    uses ``torch.randperm`` and therefore torch's global RNG state.

    Args:
        masks_per_frame: torch tensor holding the mask of one frame. It is
            permuted along its first dimension and never modified in place.
            (Assumed to be 1-D, which is how the generator in this file
            calls it.)
        tube_length: number of shuffled rows to produce.

    Returns:
        The rows combined with ``torch.vstack``; for a 1-D input of length
        ``N`` this has shape ``(tube_length, N)``. ``tube_length == 0``
        yields an empty ``(0, N)`` tensor of the input's dtype.

    Raises:
        ValueError: if ``tube_length`` is negative.
    """
    if tube_length < 0:
        raise ValueError(
            "tube_length must be non-negative, got {}".format(tube_length)
        )
    if tube_length == 0:
        # Zero frames requested: return zero rows so callers that concatenate
        # the result with other frames keep the expected total row count.
        return masks_per_frame.new_zeros((0, masks_per_frame.shape[-1]))

    shuffled_frames = []
    current = masks_per_frame
    for _ in range(tube_length):
        # Indexing with an index tensor already yields a fresh tensor, so no
        # explicit clone is needed to keep earlier rows (or the input) intact.
        current = current[torch.randperm(len(current))]
        shuffled_frames.append(current)
    return torch.vstack(shuffled_frames)
class RotatedTableMaskingGenerator:
    """Callable producing a batch of flattened patch masks for a clip.

    For every sample, the first ``frames - tube_length`` frames are all
    zeros. Each of the last ``tube_length`` frames is a shuffled vector
    (via ``get_tubes``) containing ``int(mask_ratio * height * width)`` ones
    and zeros elsewhere. Judging by the ``num_masks`` naming, a one
    presumably marks a masked patch — confirm against the consumer.
    """

    def __init__(self,
                 input_size,
                 mask_ratio,
                 tube_length,
                 batch_size,
                 mask_type='rotated_table',
                 seed=None,
                 randomize_num_visible=False):
        """Store the mask geometry and sampling configuration.

        Args:
            input_size: ``(frames, height, width)`` of the patch grid.
            mask_ratio: fraction of patches per frame set to one. Replaced
                by a freshly sampled value on each call when ``mask_type``
                is ``'rotated_table_magvit'`` or ``'rotated_table_maskvit'``.
            tube_length: number of trailing frames that receive a mask.
            batch_size: number of masks generated per call.
            mask_type: selects how ``mask_ratio`` is chosen on each call.
            seed: stored and shown by ``repr`` only; this class does not
                seed any random number generator with it.
            randomize_num_visible: unsupported; calling the generator with
                this enabled raises ``NotImplementedError``.
        """
        self.batch_size = batch_size
        self.mask_ratio = mask_ratio
        self.tube_length = tube_length
        self.frames, self.height, self.width = input_size
        self.num_patches_per_frame = self.height * self.width
        self.total_patches = self.frames * self.num_patches_per_frame
        self.seed = seed
        self.randomize_num_visible = randomize_num_visible
        self.mask_type = mask_type

    def __repr__(self):
        """Return a one-line summary of the generator configuration."""
        # NOTE(review): the label reads "Inverted Table Mask" although the
        # class is RotatedTableMaskingGenerator; kept verbatim in case logs or
        # tooling match on this text.
        repr_str = "Inverted Table Mask: total patches {}, tube length {}, randomize num visible? {}, seed {}".format(
            self.total_patches, self.tube_length, self.randomize_num_visible, self.seed
        )
        return repr_str

    def __call__(self, m=None):
        """Generate one batch of masks.

        Args:
            m: unused; accepted only for interface compatibility.

        Returns:
            Tensor of shape ``(batch_size, frames * height * width)`` holding
            zeros and ones.

        Raises:
            NotImplementedError: if ``randomize_num_visible`` is enabled.
        """
        if self.randomize_num_visible:
            # The original guard was `assert "<message>"`, which can never
            # fail, followed by a use of the never-assigned `self.rng`.
            # Fail explicitly, and before any state is touched.
            raise NotImplementedError("Randomize num visible Not implemented")

        # Some mask types draw a new ratio on every call (numpy global RNG).
        # The sampled value is written back to `self.mask_ratio`, as before.
        if self.mask_type == 'rotated_table_magvit':
            self.mask_ratio = np.random.uniform(low=0.0, high=1)
            self.mask_ratio = np.cos(self.mask_ratio * np.pi / 2)
        elif self.mask_type == 'rotated_table_maskvit':
            self.mask_ratio = np.random.uniform(low=0.5, high=1)

        patches = self.num_patches_per_frame
        # These do not vary across the batch, so compute them once.
        self.num_masks_per_frame = max(0, int(self.mask_ratio * patches))
        self.total_masks = self.tube_length * self.num_masks_per_frame
        num_masks = self.num_masks_per_frame

        # Unshuffled template: zeros first, then `num_masks` ones. With
        # num_masks == 0 the ones part is empty, so no special case is needed.
        mask_per_frame = torch.hstack([
            torch.zeros(patches - num_masks),
            torch.ones(num_masks),
        ])

        all_masks = []
        for _ in range(self.batch_size):
            # Each sample gets its own shuffles (torch global RNG).
            tubes = get_tubes(mask_per_frame, self.tube_length)
            # Leading frames carry no ones at all.
            top = torch.zeros(
                (self.frames - self.tube_length, patches), dtype=tubes.dtype
            )
            all_masks.append(torch.cat([top, tubes]).flatten())
        return torch.stack(all_masks)
|