# MINIMA/third_party/XoFTR/src/utils/pretrain_utils.py
# lsxi77777's picture
# commit message
# a930e1f
import torch
import torch.nn as nn
from einops.einops import rearrange
import torch.nn.functional as F
def generate_random_masks(batch, patch_size, mask_ratio, generator=None, margins=[0,0,0,0]):
    """Draw a random patch mask for each image of the pair and store both in `batch`.

    Args:
        batch (dict): must contain 'image0' and 'image1'; updated in place
            with 'mae_mask0' and 'mae_mask1'.
        patch_size (int): forwarded to `_gen_random_mask`.
        mask_ratio (float): forwarded to `_gen_random_mask`.
        generator (torch.Generator): optional RNG shared by both draws.
        margins [float, float, float, float]: forwarded to `_gen_random_mask`.
    """
    new_entries = {}
    # The mask for image0 is drawn before the one for image1, so a shared
    # generator is consumed in that order.
    for image_key, mask_key in (("image0", "mae_mask0"), ("image1", "mae_mask1")):
        new_entries[mask_key] = _gen_random_mask(
            batch[image_key],
            patch_size,
            mask_ratio,
            generator,
            margins=margins,
        )
    batch.update(new_entries)
def _gen_random_mask(image, patch_size, mask_ratio, generator=None, margins=[0, 0, 0, 0]):
""" Random mask generator
Args:
image (torch.Tensor): [N, C, H, W]
patch_size (int)
mask_ratio (float)
generator (torch.Generator): RNG to create the same random masks for validation
margins [float, float, float, float]: unused part for masking (up bottom left right)
Returns:
mask (torch.Tensor): (N, L)
"""
N = image.shape[0]
l = (image.shape[2] // patch_size)
L = l ** 2
len_keep = int(L * (1 - mask_ratio * (1 - sum(margins))))
margins = [int(margin * l) for margin in margins]
noise = torch.rand(N, l, l, device=image.device, generator=generator)
if margins[0] > 0 : noise[:,:margins[0],:] = 0
if margins[1] > 0 : noise[:,-margins[1]:,:] = 0
if margins[2] > 0 : noise[:,:,:margins[2]] = 0
if margins[3] > 0 : noise[:,:,-margins[3]:] = 0
noise = noise.flatten(1)
# sort noise for each sample
ids_shuffle = torch.argsort(noise, dim=1)
ids_restore = torch.argsort(ids_shuffle, dim=1)
# generate the binary mask: 0 is keep 1 is remove
mask = torch.ones([N, L], device=image.device)
mask[:, :len_keep] = 0
# unshuffle to get the binary mask
mask = torch.gather(mask, dim=1, index=ids_restore)
return mask
def _extract_windows(image, kernel_size, stride, padding, W_f, scale):
    """Unfold one image into sliding windows split into fine-level patches.

    Args:
        image (torch.Tensor): [N, C, H, W]
        kernel_size (list[int]): window size in pixels, [W_f * scale, W_f * scale]
        stride (int): step between neighbouring windows in pixels
        padding (int): zero padding applied on every side before unfolding
        W_f (int): number of patches along one side of a window
        scale (int): side length of one patch in pixels
    Returns:
        torch.Tensor: [N, K, W_f**2, scale*scale*C] (K: num of windows)
    """
    num_samples, channels = image.shape[0], image.shape[1]
    # [N, C * k * k, K]: one column per window, laid out channel-major.
    columns = F.unfold(image, kernel_size=kernel_size, stride=stride, padding=padding)
    num_windows = columns.shape[-1]
    # Split every k x k window into a W_f x W_f grid of scale x scale patches.
    # Same as the einops pattern 'n (c h p w q) l -> n l h w p q c'.
    columns = columns.reshape(num_samples, channels, W_f, scale, W_f, scale, num_windows)
    columns = columns.permute(0, 6, 2, 4, 3, 5, 1)
    # Merge the grid cells into W_f**2 and the pixels of a patch (p, q, c) into one vector.
    return columns.reshape(num_samples, num_windows, W_f ** 2, scale * scale * channels)


def patchify(data):
    """ Split images into small overlapped patches
    Args:
        data (dict):{
            'hw0_i': size of image0; only the first entry is read,
            'hw0_c': size of the coarse level; only the first entry is read,
            'hw0_f': size of the fine level; only the first entry is read,
            'W_f' (int): number of fine-level patches along one window side,
            'image0_norm' (torch.Tensor): [N, C, H, W] normalized image,
            'image1_norm' (torch.Tensor): [N, C, H, W] normalized image,
            'image0' / 'image1' (torch.Tensor): [N, C, H, W] used when the
                corresponding '*_norm' entry is absent
        }
    Returns:
        image0 (torch.Tensor): [N, K, W_f**2, -1] (K: num of windows)
        image1 (torch.Tensor): [N, K, W_f**2, -1] (K: num of windows)
    """
    # Both images share the geometry derived from the image0 sizes.
    stride = data['hw0_i'][0] // data['hw0_c'][0]
    scale = data['hw0_i'][0] // data['hw0_f'][0]
    W_f = data["W_f"]
    window = int(W_f * scale)
    kernel_size = [window, window]
    # Even windows get one pixel less padding than odd ones.
    padding = window // 2 - 1 if window % 2 == 0 else window // 2

    image0 = data["image0_norm"] if "image0_norm" in data else data["image0"]
    image1 = data["image1_norm"] if "image1_norm" in data else data["image1"]

    image0 = _extract_windows(image0, kernel_size, stride, padding, W_f, scale)
    image1 = _extract_windows(image1, kernel_size, stride, padding, W_f, scale)
    return image0, image1
def get_target(data):
    """Create target patches for mae

    Runs `patchify` on `data` and stores its two results back into `data`
    under 'target0' and 'target1'.
    """
    patches0, patches1 = patchify(data)
    data.update(target0=patches0, target1=patches1)