Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,476 Bytes
a930e1f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 |
import torch
import torch.nn as nn
from einops.einops import rearrange
import torch.nn.functional as F
def generate_random_masks(batch, patch_size, mask_ratio, generator=None, margins=(0, 0, 0, 0)):
    """Create a random MAE mask for both images of a pair and store them in ``batch``.

    Args:
        batch (dict): must contain 'image0' and 'image1'. It is updated in
            place with the keys 'mae_mask0' and 'mae_mask1'.
        patch_size (int): forwarded to ``_gen_random_mask``.
        mask_ratio (float): forwarded to ``_gen_random_mask``.
        generator (torch.Generator): optional RNG shared by both draws; the
            mask for 'image0' is drawn first, then the one for 'image1'.
        margins: four values (up, bottom, left, right) forwarded to
            ``_gen_random_mask``. The default is an immutable tuple so it can
            never be shared/mutated across calls.
    Returns:
        None. The masks are written into ``batch``.
    """
    # Order matters when a generator is supplied: image0 consumes the RNG first.
    mae_mask0 = _gen_random_mask(batch['image0'], patch_size, mask_ratio, generator, margins=margins)
    mae_mask1 = _gen_random_mask(batch['image1'], patch_size, mask_ratio, generator, margins=margins)
    batch.update({"mae_mask0": mae_mask0, "mae_mask1": mae_mask1})
def _gen_random_mask(image, patch_size, mask_ratio, generator=None, margins=[0, 0, 0, 0]):
""" Random mask generator
Args:
image (torch.Tensor): [N, C, H, W]
patch_size (int)
mask_ratio (float)
generator (torch.Generator): RNG to create the same random masks for validation
margins [float, float, float, float]: unused part for masking (up bottom left right)
Returns:
mask (torch.Tensor): (N, L)
"""
N = image.shape[0]
l = (image.shape[2] // patch_size)
L = l ** 2
len_keep = int(L * (1 - mask_ratio * (1 - sum(margins))))
margins = [int(margin * l) for margin in margins]
noise = torch.rand(N, l, l, device=image.device, generator=generator)
if margins[0] > 0 : noise[:,:margins[0],:] = 0
if margins[1] > 0 : noise[:,-margins[1]:,:] = 0
if margins[2] > 0 : noise[:,:,:margins[2]] = 0
if margins[3] > 0 : noise[:,:,-margins[3]:] = 0
noise = noise.flatten(1)
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
# generate the binary mask: 0 is keep 1 is remove
mask = torch.ones([N, L], device=image.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return mask
def _unfold_windows(image, W_f, scale, kernel_side, stride, padding):
    """Cut one image batch into sliding windows of W_f x W_f cells, each cell scale x scale pixels."""
    windows = F.unfold(image, kernel_size=[kernel_side, kernel_side], stride=stride, padding=padding)
    n, flat_dim, num_windows = windows.shape
    channels = flat_dim // (kernel_side * kernel_side)
    # unfold lays dim 1 out as (c, kernel_row, kernel_col); split each kernel
    # axis into (cell index, pixel inside cell):
    #   [n, c, cell_row, pix_row, cell_col, pix_col, window]
    windows = windows.reshape(n, channels, W_f, scale, W_f, scale, num_windows)
    # Same layout as the einops pattern
    #   'n (c h p w q) l -> n l h w p q c'
    windows = windows.permute(0, 6, 2, 4, 3, 5, 1)
    # Merge the cell grid into one axis and (pix_row, pix_col, c) into the last one.
    return windows.reshape(n, num_windows, W_f**2, scale * scale * channels)


def patchify(data):
    """ Split images into small windows (overlapping whenever the window is larger than the stride)
    Args:
        data (dict):{
            'image0_norm' (torch.Tensor): [N, C, H, W] normalized image,
                falls back to 'image0' when the key is absent
            'image1_norm' (torch.Tensor): [N, C, H, W] normalized image,
                falls back to 'image1' when the key is absent
            'hw0_i', 'hw0_c', 'hw0_f': sequences whose first entries give
                stride = hw0_i[0] // hw0_c[0] and scale = hw0_i[0] // hw0_f[0];
                the same stride/scale are applied to both images
            'W_f' (int): number of cells per window side
        }
    Returns:
        image0 (torch.Tensor): [N, K, W_f**2, -1] (K: num of windows)
        image1 (torch.Tensor): [N, K, W_f**2, -1] (K: num of windows)
        The last axis holds the scale*scale pixels of one cell with the
        channels interleaved last.
    """
    stride = data['hw0_i'][0] // data['hw0_c'][0]
    scale = data['hw0_i'][0] // data['hw0_f'][0]
    W_f = data["W_f"]
    kernel_side = int(W_f * scale)
    # Equals k//2 - 1 for even k and k//2 for odd k.
    padding = (kernel_side - 1) // 2

    image0 = data["image0_norm"] if "image0_norm" in data else data["image0"]
    image1 = data["image1_norm"] if "image1_norm" in data else data["image1"]

    image0 = _unfold_windows(image0, W_f, scale, kernel_side, stride, padding)
    image1 = _unfold_windows(image1, W_f, scale, kernel_side, stride, padding)
    return image0, image1
def get_target(data):
    """Store the patchified image pair in ``data`` as the MAE targets.

    Runs ``patchify(data)`` and writes its two results under the keys
    'target0' and 'target1'. Nothing is returned.
    """
    patches0, patches1 = patchify(data)
    data.update(target0=patches0, target1=patches1)
|