File size: 1,620 Bytes
14ce5a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import torch.nn.functional as F


def add_perturbation(z, z_q, z_channels, codebook_norm, codebook, alpha, beta, delta):
    # reshape z -> (batch, height * width, channel) and flatten
    z = torch.einsum("b c h w -> b h w c", z).contiguous()
    z_flattened = z.view(-1, z_channels)

    if codebook_norm:
        z = F.normalize(z, p=2, dim=-1)
        z_flattened = F.normalize(z_flattened, p=2, dim=-1)
        embedding = F.normalize(codebook.weight, p=2, dim=-1)
    else:
        embedding = codebook.weight

    d = (
        torch.sum(z_flattened**2, dim=1, keepdim=True)
        + torch.sum(embedding**2, dim=1)
        - 2
        * torch.einsum("bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding))
    )

    _, min_encoding_indices = torch.topk(d, delta, dim=1, largest=False)
    random_prob = torch.rand(min_encoding_indices.shape[0], device=d.device)
    random_idx = torch.randint(0, delta, random_prob.shape, device=d.device)
    random_idx = torch.where(random_prob > alpha, 0, random_idx)
    min_encoding_indices = min_encoding_indices[
        torch.arange(min_encoding_indices.size(0)), random_idx
    ]

    perturbed_z_q = codebook(min_encoding_indices).view(z.shape)
    if codebook_norm:
        perturbed_z_q = F.normalize(perturbed_z_q, p=2, dim=-1)
    perturbed_z_q = z + (perturbed_z_q - z).detach()
    perturbed_z_q = torch.einsum("b h w c -> b c h w", perturbed_z_q)

    mask = torch.arange(z.shape[0], device=perturbed_z_q.device) < int(
        z.shape[0] * beta
    )
    mask = mask[:, None, None, None]

    return torch.where(mask, perturbed_z_q, z_q)