Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
from einops.einops import rearrange | |
import torch.nn.functional as F | |
def generate_random_masks(batch, patch_size, mask_ratio, generator=None, margins=None):
    """Draw random MAE masks for both images of a batch and store them in place.

    Args:
        batch (dict): must contain 'image0' and 'image1' tensors ([N, C, H, W]).
            The masks are written back under 'mae_mask0' and 'mae_mask1'.
        patch_size (int): side length of a square patch, in pixels.
        mask_ratio (float): fraction of the non-margin patches to mask.
        generator (torch.Generator, optional): RNG shared by both draws; the mask
            for 'image0' is drawn first, then the one for 'image1'.
        margins (sequence of 4 floats, optional): fractions of the patch grid
            (up, bottom, left, right) that are never masked. Defaults to none.

    Returns:
        None. ``batch`` is updated in place.
    """
    mae_mask0 = _gen_random_mask(batch['image0'], patch_size, mask_ratio, generator, margins=margins)
    mae_mask1 = _gen_random_mask(batch['image1'], patch_size, mask_ratio, generator, margins=margins)
    batch.update({"mae_mask0": mae_mask0, "mae_mask1": mae_mask1})


def _gen_random_mask(image, patch_size, mask_ratio, generator=None, margins=None):
    """Random binary mask over the patch grid of ``image``.

    Patches inside the border ``margins`` are always kept. ``mask_ratio`` is
    applied to the remaining (non-margin) share of the grid.

    Args:
        image (torch.Tensor): [N, C, H, W]. Only its batch size, spatial size and
            device are used.
        patch_size (int): side length of a square patch, in pixels.
        mask_ratio (float): fraction of the non-margin patches to mask.
        generator (torch.Generator, optional): RNG used to draw the noise, e.g. to
            create the same random masks for validation.
        margins (sequence of 4 floats, optional): unused part for masking
            (up, bottom, left, right), as fractions of the grid height/width.
            Defaults to (0, 0, 0, 0).

    Returns:
        mask (torch.Tensor): (N, L) float tensor with
            L = (H // patch_size) * (W // patch_size); 0 is keep, 1 is remove.
    """
    # None instead of a list literal: avoids a shared mutable default.
    if margins is None:
        margins = (0, 0, 0, 0)
    N = image.shape[0]
    # Height and width are handled separately so non-square images get a
    # correctly sized grid; for square images this equals the old l x l grid.
    grid_h = image.shape[-2] // patch_size
    grid_w = image.shape[-1] // patch_size
    L = grid_h * grid_w
    len_keep = int(L * (1 - mask_ratio * (1 - sum(margins))))

    # Convert the margin fractions into patch counts along the matching axis.
    up = int(margins[0] * grid_h)
    bottom = int(margins[1] * grid_h)
    left = int(margins[2] * grid_w)
    right = int(margins[3] * grid_w)

    noise = torch.rand(N, grid_h, grid_w, device=image.device, generator=generator)
    # Zero noise sorts first, so margin patches always land in the kept set.
    # The `> 0` guards matter: a slice like `-0:` would select the whole axis.
    if up > 0:
        noise[:, :up, :] = 0
    if bottom > 0:
        noise[:, -bottom:, :] = 0
    if left > 0:
        noise[:, :, :left] = 0
    if right > 0:
        noise[:, :, -right:] = 0
    noise = noise.flatten(1)

    # sort noise for each sample; ids_restore inverts the permutation
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)
    # generate the binary mask in shuffled order: 0 is keep, 1 is remove
    mask = torch.ones([N, L], device=image.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask in patch (row-major) order
    mask = torch.gather(mask, dim=1, index=ids_restore)
    return mask
def _unfold_windows(image, kernel_size, stride, padding, W_f, scale):
    """Unfold one [N, C, H, W] image into [N, K, W_f**2, scale**2 * C] windows."""
    windows = F.unfold(image, kernel_size=kernel_size, stride=stride, padding=padding)
    # F.unfold lays each column out channel-first, then kernel row, then kernel
    # column; split each kernel axis into (W_f cells, scale pixels per cell)
    # and move the channels last.
    windows = rearrange(windows, 'n (c h p w q) l -> n l h w p q c', h=W_f, w=W_f, p=scale, q=scale)
    # merge (p, q, c) into one feature axis, then (h, w) into W_f**2 cells
    windows = windows.flatten(4)
    return windows.reshape(*windows.shape[:2], W_f**2, -1)


def patchify(data):
    """ Split images into small overlapped patches

    Args:
        data (dict): reads
            'hw0_i', 'hw0_c', 'hw0_f': sizes whose element [0] is used to derive
                the unfold stride (hw0_i[0] // hw0_c[0]) and the per-cell pixel
                scale (hw0_i[0] // hw0_f[0]),
            'W_f': number of cells per window side,
            'image0_norm' / 'image1_norm' (torch.Tensor): [N, C, H, W] normalized
                images; 'image0' / 'image1' are used when the normalized keys
                are absent.
    Returns:
        image0 (torch.Tensor): [N, K, W_f**2, scale**2 * C] (K: num of windows)
        image1 (torch.Tensor): [N, K, W_f**2, scale**2 * C] (K: num of windows)
    """
    # NOTE(review): stride/scale come from the hw0_* entries and are reused for
    # image1, which presumably shares image0's resolution ratios — confirm.
    stride = data['hw0_i'][0] // data['hw0_c'][0]
    scale = data['hw0_i'][0] // data['hw0_f'][0]
    W_f = data["W_f"]
    kernel_size = [int(W_f * scale), int(W_f * scale)]
    # even kernels get one pixel less padding than half the kernel size
    padding = kernel_size[0] // 2 - 1 if kernel_size[0] % 2 == 0 else kernel_size[0] // 2
    image0 = data["image0_norm"] if "image0_norm" in data else data["image0"]
    image1 = data["image1_norm"] if "image1_norm" in data else data["image1"]
    image0 = _unfold_windows(image0, kernel_size, stride, padding, W_f, scale)
    image1 = _unfold_windows(image1, kernel_size, stride, padding, W_f, scale)
    return image0, image1
def get_target(data):
    """Attach MAE reconstruction targets to ``data`` in place.

    Runs ``patchify`` on ``data`` and stores its two outputs under the
    "target0" and "target1" keys. Returns None.
    """
    patches0, patches1 = patchify(data)
    data.update(target0=patches0, target1=patches1)