File size: 3,476 Bytes
a930e1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch 
import torch.nn as nn
from einops.einops import rearrange
import torch.nn.functional as F

def generate_random_masks(batch, patch_size, mask_ratio, generator=None, margins=[0,0,0,0]):
    mae_mask0 = _gen_random_mask(batch['image0'], patch_size, mask_ratio, generator, margins=margins)
    mae_mask1 = _gen_random_mask(batch['image1'], patch_size, mask_ratio, generator, margins=margins)
    batch.update({"mae_mask0" : mae_mask0, "mae_mask1": mae_mask1})

def _gen_random_mask(image, patch_size, mask_ratio, generator=None, margins=[0, 0, 0, 0]): 
    """ Random mask generator

    Args:

        image (torch.Tensor): [N, C, H, W]

        patch_size (int)

        mask_ratio (float)

        generator (torch.Generator): RNG to create the same random masks for validation  

        margins [float, float, float, float]: unused part for masking (up bottom left right)

    Returns:

        mask (torch.Tensor): (N, L)

    """ 
    N = image.shape[0]
    l = (image.shape[2] // patch_size)
    L = l ** 2
    len_keep = int(L *  (1 - mask_ratio * (1 - sum(margins))))

    margins = [int(margin * l) for margin in margins]

    noise = torch.rand(N, l, l, device=image.device, generator=generator)
    if margins[0] > 0 : noise[:,:margins[0],:] = 0
    if margins[1] > 0 : noise[:,-margins[1]:,:] = 0
    if margins[2] > 0 : noise[:,:,:margins[2]] = 0
    if margins[3] > 0 : noise[:,:,-margins[3]:] = 0
    noise = noise.flatten(1)

    # sort noise for each sample
    ids_shuffle = torch.argsort(noise, dim=1)
    ids_restore = torch.argsort(ids_shuffle, dim=1)

    # generate the binary mask: 0 is keep 1 is remove
    mask = torch.ones([N, L], device=image.device)
    mask[:, :len_keep] = 0
    # unshuffle to get the binary mask
    mask = torch.gather(mask, dim=1, index=ids_restore)
    return mask

def patchify(data):
    """ Split images into small overlapped patches

    Args:

        data (dict):{

            'image0_norm' (torch.Tensor): [N, C, H, W] normalized image,

            'image1_norm' (torch.Tensor): [N, C, H, W] normalized image,

    Returns:

        image0 (torch.Tensor): [N, K, W_f**2, -1] (K: num of windows)

        image1 (torch.Tensor): [N, K, W_f**2, -1] (K: num of windows)

    """
    stride = data['hw0_i'][0] // data['hw0_c'][0]
    scale = data['hw0_i'][0] // data['hw0_f'][0]
    W_f = data["W_f"]
    kernel_size = [int(W_f*scale), int(W_f*scale)]
    padding = kernel_size[0]//2 -1 if kernel_size[0] % 2 == 0 else kernel_size[0]//2

    image0 = data["image0_norm"] if "image0_norm" in data else data["image0"]
    image1 = data["image1_norm"] if "image1_norm" in data else data["image1"]
    
    image0 = F.unfold(image0, kernel_size=kernel_size, stride=stride, padding=padding)
    image0 = rearrange(image0, 'n (c h p w q) l -> n l h w p q c', h=W_f, w=W_f, p=scale, q=scale)
    image0 = image0.flatten(4)
    image0 = image0.reshape(*image0.shape[:2], W_f**2, -1)
    
    image1 = F.unfold(image1, kernel_size=kernel_size, stride=stride, padding=padding)
    image1 = rearrange(image1, 'n (c h p w q) l -> n l h w p q c', h=W_f, w=W_f, p=scale, q=scale)
    image1 = image1.flatten(4)
    image1 = image1.reshape(*image1.shape[:2], W_f**2, -1)

    return image0, image1

def get_target(data):
    """Create target patches for mae"""
    target0, target1 = patchify(data)
    data.update({"target0":target0, "target1":target1})