# VTBench/src/vqvaes/xqgan/latent_perturbation.py
import torch
import torch.nn.functional as F


def add_perturbation(z, z_q, z_channels, codebook_norm, codebook, alpha, beta, delta):
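    """Randomly perturb the quantized latents z_q.

    With probability ``alpha`` each latent vector is assigned one of its
    ``delta`` nearest codebook entries (drawn uniformly at random) instead of
    its nearest entry. Only the first ``beta`` fraction of the batch is
    perturbed; the remaining samples keep the original ``z_q``.
    """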
    # reshape z -> (batch, height, width, channel) and flatten
    z = torch.einsum("b c h w -> b h w c", z).contiguous()
    z_flattened = z.view(-1, z_channels)
    if codebook_norm:
        z = F.normalize(z, p=2, dim=-1)
        z_flattened = F.normalize(z_flattened, p=2, dim=-1)
        embedding = F.normalize(codebook.weight, p=2, dim=-1)
    else:
        embedding = codebook.weight
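
    # squared Euclidean distance from each flattened latent to every codebook entry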
    d = (
        torch.sum(z_flattened**2, dim=1, keepdim=True)
        + torch.sum(embedding**2, dim=1)
        - 2
        * torch.einsum("bd,dn->bn", z_flattened, torch.einsum("n d -> d n", embedding))
    )
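
    # indices of the delta nearest codebook entries for each latent vector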
    _, min_encoding_indices = torch.topk(d, delta, dim=1, largest=False)
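
    # with probability alpha, replace the nearest code (column 0) with a uniform
    # draw from the delta nearest codes; otherwise keep the nearest code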
    random_prob = torch.rand(min_encoding_indices.shape[0], device=d.device)
    random_idx = torch.randint(0, delta, random_prob.shape, device=d.device)
    random_idx = torch.where(random_prob > alpha, 0, random_idx)
    min_encoding_indices = min_encoding_indices[
        torch.arange(min_encoding_indices.size(0)), random_idx
    ]
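
    # look up the (possibly perturbed) codebook entries and restore the spatial layout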
    perturbed_z_q = codebook(min_encoding_indices).view(z.shape)
    if codebook_norm:
        perturbed_z_q = F.normalize(perturbed_z_q, p=2, dim=-1)
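
    # straight-through estimator: the forward pass uses the perturbed codes,
    # while gradients flow back to z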
    perturbed_z_q = z + (perturbed_z_q - z).detach()
    perturbed_z_q = torch.einsum("b h w c -> b c h w", perturbed_z_q)
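
    # perturb only the first beta fraction of the batch; the rest keep the original z_q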
    mask = torch.arange(z.shape[0], device=perturbed_z_q.device) < int(
        z.shape[0] * beta
    )
    mask = mask[:, None, None, None]
    return torch.where(mask, perturbed_z_q, z_q)
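

if __name__ == "__main__":
    # Minimal smoke test (not part of the original module): assumes the codebook
    # behaves like an nn.Embedding whose embedding dim equals z_channels, and that
    # z / z_q are (batch, channels, height, width) latent tensors.
    import torch.nn as nn

    b, c, h, w = 4, 8, 2, 2
    codebook = nn.Embedding(16, c)
    z = torch.randn(b, c, h, w)
    z_q = torch.randn(b, c, h, w)
    out = add_perturbation(
        z, z_q, z_channels=c, codebook_norm=True, codebook=codebook,
        alpha=0.5, beta=0.5, delta=3,
    )
    print(out.shape)  # expected: torch.Size([4, 8, 2, 2])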