mridulk committed on
Commit
02e3b25
·
1 Parent(s): d759c1a

added files

Browse files
ldm/modules/losses/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
ldm/modules/losses/contperceptual.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5
+
6
+
7
class LPIPSWithDiscriminator(nn.Module):
    """Reconstruction + KL + adversarial loss with a learnable output log-variance.

    ``forward`` is meant to be called once per optimizer: ``optimizer_idx == 0``
    returns the generator-side loss, ``optimizer_idx == 1`` the discriminator
    loss. Any other ``optimizer_idx`` falls through and returns ``None``.
    """

    def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_loss="hinge"):
        """Build the perceptual network, the discriminator and store the weights.

        Args:
            disc_start: global step from which the discriminator factor is applied.
            logvar_init: initial value of the learnable scalar ``logvar``.
            kl_weight: multiplier of the KL term in the generator loss.
            pixelloss_weight: stored as ``pixel_weight``; not read by ``forward``.
            disc_num_layers, disc_in_channels, use_actnorm: forwarded to
                ``NLayerDiscriminator``.
            disc_factor: discriminator factor used once ``disc_start`` is reached.
            disc_weight: multiplier applied to the adaptive discriminator weight.
            perceptual_weight: multiplier of the LPIPS term (skipped when <= 0).
            disc_conditional: whether the discriminator expects ``cond`` concatenated
                on the channel axis.
            disc_loss: ``"hinge"`` or ``"vanilla"``.
        """
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        self.kl_weight = kl_weight
        self.pixel_weight = pixelloss_weight
        self.perceptual_loss = LPIPS().eval()
        self.perceptual_weight = perceptual_weight
        # output log variance (a single learnable scalar)
        self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)

        patch_discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                  n_layers=disc_num_layers,
                                                  use_actnorm=use_actnorm)
        self.discriminator = patch_discriminator.apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        else:
            self.disc_loss = vanilla_d_loss
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Return the detached gradient-norm ratio used to scale the GAN term.

        The ratio ``||d nll / d layer|| / (||d g / d layer|| + 1e-4)`` is clamped
        to ``[0, 1e4]``, detached, and multiplied by ``discriminator_weight``.
        """
        # NOTE(review): the fallback reads ``self.last_layer``, which this class
        # never assigns; it only works if a caller attaches it -- confirm.
        target = last_layer if last_layer is not None else self.last_layer[0]
        nll_grads = torch.autograd.grad(nll_loss, target, retain_graph=True)[0]
        g_grads = torch.autograd.grad(g_loss, target, retain_graph=True)[0]

        ratio = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        ratio = torch.clamp(ratio, 0.0, 1e4).detach()
        return ratio * self.discriminator_weight

    def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train",
                weights=None):
        """Compute ``(loss, log_dict)`` for the optimizer selected by ``optimizer_idx``.

        ``posteriors`` must expose ``kl()``; ``weights`` optionally rescales the
        NLL term element-wise before it is summed. Log keys are prefixed with
        ``split``.
        """
        real = inputs.contiguous()
        fake = reconstructions.contiguous()

        rec_loss = torch.abs(real - fake)
        if self.perceptual_weight > 0:
            p_loss = self.perceptual_loss(real, fake)
            rec_loss = rec_loss + self.perceptual_weight * p_loss

        nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
        weighted_nll_loss = nll_loss if weights is None else weights * nll_loss
        # Sum over everything, then average over the leading (batch) dimension.
        weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
        nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
        kl_loss = posteriors.kl()
        kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]

        # now the GAN part
        if optimizer_idx == 0:
            # generator update
            if cond is None:
                assert not self.disc_conditional
                disc_input = fake
            else:
                assert self.disc_conditional
                disc_input = torch.cat((fake, cond), dim=1)
            logits_fake = self.discriminator(disc_input)
            g_loss = -torch.mean(logits_fake)

            if self.disc_factor > 0.0:
                try:
                    d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
                except RuntimeError:
                    # Gradients are unavailable outside of training (e.g. under no_grad).
                    assert not self.training
                    d_weight = torch.tensor(0.0)
            else:
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss

            log = {
                f"{split}/total_loss": loss.clone().detach().mean(),
                f"{split}/logvar": self.logvar.detach(),
                f"{split}/kl_loss": kl_loss.detach().mean(),
                f"{split}/nll_loss": nll_loss.detach().mean(),
                f"{split}/rec_loss": rec_loss.detach().mean(),
                f"{split}/d_weight": d_weight.detach(),
                f"{split}/disc_factor": torch.tensor(disc_factor),
                f"{split}/g_loss": g_loss.detach().mean(),
            }
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update; both inputs are detached so
            # no gradient reaches the autoencoder.
            real_in = real.detach()
            fake_in = fake.detach()
            if cond is not None:
                real_in = torch.cat((real_in, cond), dim=1)
                fake_in = torch.cat((fake_in, cond), dim=1)
            logits_real = self.discriminator(real_in)
            logits_fake = self.discriminator(fake_in)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {
                f"{split}/disc_loss": d_loss.clone().detach().mean(),
                f"{split}/logits_real": logits_real.detach().mean(),
                f"{split}/logits_fake": logits_fake.detach().mean(),
            }
            return d_loss, log
111
+
ldm/modules/losses/lpips.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ from torchvision import models
6
+ from collections import namedtuple
7
+
8
+ from ldm.util import get_ckpt_path
9
+
10
+
11
class LPIPS(nn.Module):
    """Learned perceptual metric built on frozen VGG16 features.

    All parameters are frozen after the pretrained weights are loaded, so the
    module acts as a fixed distance function.
    """

    # Directory handed to ``get_ckpt_path`` when resolving the checkpoint.
    _CKPT_ROOT = "ldm/modules/autoencoder/lpips"

    def __init__(self, use_dropout=True):
        """Build the scaling layer, VGG16 backbone and five 1x1 linear heads."""
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        """Load the checkpoint ``name`` into this module (non-strict) and report it."""
        ckpt = get_ckpt_path(name, self._CKPT_ROOT)
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        """Construct a model and load the checkpoint ``name``.

        Raises:
            NotImplementedError: if ``name`` is anything other than ``"vgg_lpips"``.
        """
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        # Resolve the checkpoint with the same root as ``load_from_pretrained``;
        # the previous call omitted the root argument.
        ckpt = get_ckpt_path(name, cls._CKPT_ROOT)
        model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        return model

    def forward(self, input, target):
        """Return the summed, spatially averaged per-layer feature distance."""
        outs0 = self.net(self.scaling_layer(input))
        outs1 = self.net(self.scaling_layer(target))
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]

        val = None
        for kk, lin in enumerate(lins):
            # Squared difference of channel-normalized features at layer kk.
            diff = (normalize_tensor(outs0[kk]) - normalize_tensor(outs1[kk])) ** 2
            layer_score = spatial_average(lin.model(diff), keepdim=True)
            # Accumulate out-of-place so no per-layer result is overwritten.
            val = layer_score if val is None else val + layer_score
        return val
55
+
56
+
57
class ScalingLayer(nn.Module):
    """Apply a fixed per-channel shift and scale: ``(inp - shift) / scale``."""

    def __init__(self):
        super().__init__()
        # Three constants laid out as (1, 3, 1, 1) so they broadcast along dim 1.
        shift = torch.Tensor([-.030, -.088, -.188]).view(1, -1, 1, 1)
        scale = torch.Tensor([.458, .448, .450]).view(1, -1, 1, 1)
        self.register_buffer('shift', shift)
        self.register_buffer('scale', scale)

    def forward(self, inp):
        centered = inp - self.shift
        return centered / self.scale
65
+
66
+
67
class NetLinLayer(nn.Module):
    """A bias-free 1x1 convolution, optionally preceded by dropout."""

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super().__init__()
        stages = []
        if use_dropout:
            stages.append(nn.Dropout())
        stages.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*stages)
74
+
75
+
76
class vgg16(torch.nn.Module):
    """torchvision VGG16 ``features`` split into five consecutive slices.

    ``forward`` returns the output of every slice as a ``VggOutputs`` namedtuple.
    """

    def __init__(self, requires_grad=False, pretrained=True):
        super().__init__()
        features = models.vgg16(pretrained=pretrained).features
        # Layer index boundaries of the five slices: [0,4), [4,9), [9,16), [16,23), [23,30).
        boundaries = (0, 4, 9, 16, 23, 30)
        self.N_slices = len(boundaries) - 1
        for slice_no, (start, stop) in enumerate(zip(boundaries, boundaries[1:]), start=1):
            stage = torch.nn.Sequential()
            for layer_idx in range(start, stop):
                # Keep the original layer index as the sub-module name.
                stage.add_module(str(layer_idx), features[layer_idx])
            setattr(self, "slice{}".format(slice_no), stage)
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        activations = []
        h = X
        for stage in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            h = stage(h)
            activations.append(h)
        vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
        return vgg_outputs(*activations)
114
+
115
+
116
def normalize_tensor(x, eps=1e-10):
    """Divide ``x`` by its L2 norm over dim 1; ``eps`` guards against division by zero."""
    squared_sum = torch.sum(x ** 2, dim=1, keepdim=True)
    return x / (torch.sqrt(squared_sum) + eps)
119
+
120
+
121
def spatial_average(x, keepdim=True):
    """Mean of ``x`` over dims 2 and 3."""
    return torch.mean(x, dim=[2, 3], keepdim=keepdim)
123
+
ldm/modules/losses/vqperceptual.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from einops import repeat
5
+
6
+ from ldm.modules.discriminator.model import NLayerDiscriminator, weights_init
7
+ from ldm.modules.losses.lpips import LPIPS
8
+ # from scripts.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9
+
10
+
11
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
    """Hinge discriminator loss where each sample is weighted by ``weights``.

    Per-sample hinge terms are averaged over dims 1-3, then combined as a
    weighted mean (normalized by ``weights.sum()``) across the batch.
    """
    assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
    per_sample_real = F.relu(1. - logits_real).mean(dim=[1, 2, 3])
    per_sample_fake = F.relu(1. + logits_fake).mean(dim=[1, 2, 3])
    weighted_real = (weights * per_sample_real).sum() / weights.sum()
    weighted_fake = (weights * per_sample_fake).sum() / weights.sum()
    return 0.5 * (weighted_real + weighted_fake)
19
+
20
def hinge_d_loss(logits_real, logits_fake):
    """Hinge discriminator loss: mean of ``relu(1 - real)`` and ``relu(1 + fake)``, halved."""
    real_term = F.relu(1. - logits_real).mean()
    fake_term = F.relu(1. + logits_fake).mean()
    return 0.5 * (real_term + fake_term)
25
+
26
+
27
def vanilla_d_loss(logits_real, logits_fake):
    """Non-saturating GAN discriminator loss built from softplus terms."""
    real_term = torch.nn.functional.softplus(-logits_real).mean()
    fake_term = torch.nn.functional.softplus(logits_fake).mean()
    return 0.5 * (real_term + fake_term)
32
+
33
def adopt_weight(weight, global_step, threshold=0, value=0.):
    """Return ``value`` while ``global_step`` is below ``threshold``, else ``weight``."""
    return value if global_step < threshold else weight
37
+
38
+
39
def measure_perplexity(predicted_indices, n_embed):
    """Return ``(perplexity, cluster_use)`` of the predicted codebook indices.

    src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
    When perplexity == num_embeddings all clusters are used exactly equally.
    ``cluster_use`` counts the codes that occur at least once.
    """
    one_hot = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
    usage = one_hot.mean(0)
    # 1e-10 keeps log() finite for unused codes.
    entropy = -torch.sum(usage * torch.log(usage + 1e-10))
    perplexity = torch.exp(entropy)
    cluster_use = (usage > 0).sum()
    return perplexity, cluster_use
47
+
48
def l1(x, y):
    """Element-wise absolute difference."""
    difference = x - y
    return torch.abs(difference)
50
+
51
+
52
def l2(x, y):
    """Element-wise squared difference."""
    difference = x - y
    return torch.pow(difference, 2)
54
+
55
+
56
class VQLPIPSWithDiscriminator(nn.Module):
    """Pixel + LPIPS + codebook + adversarial loss for VQ autoencoders.

    ``forward`` is meant to be called once per optimizer: ``optimizer_idx == 0``
    returns the generator-side loss, ``optimizer_idx == 1`` the discriminator
    loss. Any other ``optimizer_idx`` falls through and returns ``None``.
    """

    def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
                 disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
                 perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
                 disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
                 pixel_loss="l1"):
        """Build the perceptual network, the discriminator and store the weights.

        Args:
            disc_start: global step from which the discriminator factor is applied.
            codebook_weight: multiplier of the codebook (quantization) loss.
            pixelloss_weight: stored as ``pixel_weight``; not read by ``forward``.
            disc_num_layers, disc_in_channels, use_actnorm, disc_ndf: forwarded to
                ``NLayerDiscriminator``.
            disc_factor: discriminator factor used once ``disc_start`` is reached.
            disc_weight: multiplier applied to the adaptive discriminator weight.
            perceptual_weight: multiplier of the LPIPS term (skipped when <= 0).
            disc_conditional: whether the discriminator expects ``cond`` concatenated
                on the channel axis.
            disc_loss: ``"hinge"`` or ``"vanilla"``.
            n_classes: codebook size, required when ``predicted_indices`` is logged.
            perceptual_loss: only ``"lpips"`` is implemented.
            pixel_loss: ``"l1"`` or ``"l2"``.

        Raises:
            ValueError: for a perceptual loss other than ``"lpips"`` or an unknown
                ``disc_loss`` (reachable when assertions are disabled).
        """
        super().__init__()
        assert disc_loss in ["hinge", "vanilla"]
        assert perceptual_loss in ["lpips", "clips", "dists"]
        assert pixel_loss in ["l1", "l2"]
        self.codebook_weight = codebook_weight
        self.pixel_weight = pixelloss_weight
        if perceptual_loss == "lpips":
            print(f"{self.__class__.__name__}: Running with LPIPS.")
            self.perceptual_loss = LPIPS().eval()
        else:
            raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
        self.perceptual_weight = perceptual_weight

        if pixel_loss == "l1":
            self.pixel_loss = l1
        else:
            self.pixel_loss = l2

        self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
                                                 n_layers=disc_num_layers,
                                                 use_actnorm=use_actnorm,
                                                 ndf=disc_ndf
                                                 ).apply(weights_init)
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.n_classes = n_classes

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        """Return the detached gradient-norm ratio used to scale the GAN term.

        The ratio ``||d nll / d layer|| / (||d g / d layer|| + 1e-4)`` is clamped
        to ``[0, 1e4]``, detached, and multiplied by ``discriminator_weight``.
        """
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            # NOTE(review): ``self.last_layer`` is never assigned in this class;
            # this branch only works if a caller attaches it -- confirm.
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
                global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
        """Compute ``(loss, log_dict)`` for the optimizer selected by ``optimizer_idx``.

        A non-tensor ``codebook_loss`` is replaced by a zero tensor on the device
        of ``inputs``. Log keys are prefixed with ``split``; when
        ``predicted_indices`` is given, perplexity and cluster usage are logged too.
        """
        if not torch.is_tensor(codebook_loss):
            codebook_loss = torch.tensor([0.]).to(inputs.device)

        # now the GAN part
        if optimizer_idx == 0:
            rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
            if self.perceptual_weight > 0:
                p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
                rec_loss = rec_loss + self.perceptual_weight * p_loss
            else:
                p_loss = torch.tensor([0.0])

            nll_loss = torch.mean(rec_loss)

            # generator update
            if cond is None:
                assert not self.disc_conditional
                logits_fake = self.discriminator(reconstructions.contiguous())
            else:
                assert self.disc_conditional
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
            g_loss = -torch.mean(logits_fake)

            # NOTE(review): both the adaptive weight and the discriminator factor
            # are forced to zero unless ``disc_conditional`` is set, so the
            # adversarial term vanishes for unconditional discriminators --
            # confirm this is intended.
            try:
                d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) if self.disc_conditional else torch.tensor(0.0)
            except RuntimeError:
                # Gradients are unavailable outside of training (e.g. under no_grad).
                assert not self.training
                d_weight = torch.tensor(0.0)

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) if self.disc_conditional else torch.tensor(0.0)
            loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()

            log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
                   "{}/quant_loss".format(split): codebook_loss.detach().mean(),
                   "{}/nll_loss".format(split): nll_loss.detach().mean(),
                   "{}/rec_loss".format(split): rec_loss.detach().mean(),
                   "{}/p_loss".format(split): p_loss.detach().mean(),
                   "{}/d_weight".format(split): d_weight.detach(),
                   # disc_factor may already be a tensor here; as_tensor avoids
                   # the copy-construct warning torch.tensor() would raise.
                   "{}/disc_factor".format(split): torch.as_tensor(disc_factor),
                   "{}/g_loss".format(split): g_loss.detach().mean(),
                   }
            if predicted_indices is not None:
                assert self.n_classes is not None
                with torch.no_grad():
                    perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
                log[f"{split}/perplexity"] = perplexity
                log[f"{split}/cluster_usage"] = cluster_usage
            return loss, log

        if optimizer_idx == 1:
            # second pass for discriminator update; both inputs are detached so
            # no gradient reaches the autoencoder.
            if cond is None:
                logits_real = self.discriminator(inputs.contiguous().detach())
                logits_fake = self.discriminator(reconstructions.contiguous().detach())
            else:
                logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
                logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))

            disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
            d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)

            log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
                   "{}/logits_real".format(split): logits_real.detach().mean(),
                   "{}/logits_fake".format(split): logits_fake.detach().mean()
                   }
            return d_loss, log
ldm/modules/vqvae/quantize.py ADDED
@@ -0,0 +1,136 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #based on https://github.com/CompVis/taming-transformers
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from einops import rearrange
8
+
9
class VectorQuantizer2(nn.Module):
    """
    Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
    avoids costly matrix multiplications and allows for post-hoc remapping of indices.
    """
    # NOTE: due to a bug the beta term was applied to the wrong term. for
    # backwards compatibility we use the buggy version by default, but you can
    # specify legacy=False to fix it.
    def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
                 sane_index_shape=False, legacy=True):
        """Create the codebook and, optionally, the index remapping table.

        Args:
            n_e: number of codebook entries.
            e_dim: dimensionality of each entry.
            beta: weight of one of the two squared-error terms (see ``legacy``).
            remap: path handed to ``np.load`` holding the used indices, or ``None``.
            unknown_index: ``"random"``, ``"extra"`` or an integer used for indices
                missing from the remap table.
            sane_index_shape: return indices shaped (batch, height, width).
            legacy: keep the historical placement of ``beta`` in the loss.
        """
        super().__init__()
        self.n_e = n_e
        self.e_dim = e_dim
        self.beta = beta
        self.legacy = legacy

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        self.remap = remap
        if self.remap is not None:
            self.register_buffer("used", torch.tensor(np.load(self.remap)))
            self.re_embed = self.used.shape[0]
            self.unknown_index = unknown_index  # "random" or "extra" or integer
            if self.unknown_index == "extra":
                # Reserve one additional index past the used ones for unknowns.
                self.unknown_index = self.re_embed
                self.re_embed = self.re_embed + 1
            print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
                  f"Using {self.unknown_index} for unknown indices.")
        else:
            self.re_embed = n_e

        self.sane_index_shape = sane_index_shape

    def remap_to_used(self, inds):
        """Map codebook indices to their positions in ``self.used``.

        Indices absent from ``self.used`` become a random index in
        ``[0, re_embed)`` or ``self.unknown_index``. ``inds`` needs >= 2 dims.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        match = (inds[:, :, None] == used[None, None, ...]).long()
        new = match.argmax(-1)
        unknown = match.sum(2) < 1
        if self.unknown_index == "random":
            new[unknown] = torch.randint(0, self.re_embed, size=new[unknown].shape).to(device=new.device)
        else:
            new[unknown] = self.unknown_index
        return new.reshape(ishape)

    def unmap_to_all(self, inds):
        """Map positions in ``self.used`` back to codebook indices.

        When an extra token exists, positions past the table are mapped through
        position 0. The input tensor is left unmodified.
        """
        ishape = inds.shape
        assert len(ishape) > 1
        inds = inds.reshape(ishape[0], -1)
        used = self.used.to(inds)
        n_used = self.used.shape[0]
        if self.re_embed > n_used:  # extra token
            # Out-of-place so the caller's tensor (reshape may return a view)
            # is not silently overwritten.
            inds = inds.masked_fill(inds >= n_used, 0)  # simply set to zero
        per_row_table = used.unsqueeze(0).expand(inds.shape[0], -1)
        back = torch.gather(per_row_table, 1, inds)
        return back.reshape(ishape)

    def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
        """Quantize ``z`` to its nearest codebook entries.

        Returns:
            ``(z_q, loss, (perplexity, min_encodings, min_encoding_indices))``
            where ``perplexity`` and ``min_encodings`` are always ``None``.
        """
        assert temp is None or temp == 1.0, "Only for interface compatible with Gumbel"
        assert rescale_logits == False, "Only for interface compatible with Gumbel"
        assert return_logits == False, "Only for interface compatible with Gumbel"
        # reshape z -> (batch, height, width, channel) and flatten
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z_flattened = z.view(-1, self.e_dim)
        # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight ** 2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        perplexity = None
        min_encodings = None

        # compute loss for embedding
        if not self.legacy:
            loss = self.beta * torch.mean((z_q.detach() - z) ** 2) + \
                   torch.mean((z_q - z.detach()) ** 2)
        else:
            loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * \
                   torch.mean((z_q - z.detach()) ** 2)

        # preserve gradients (straight-through: forward value z_q, gradient to z)
        z_q = z + (z_q - z).detach()

        # reshape back to match original input shape
        z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()

        if self.remap is not None:
            min_encoding_indices = min_encoding_indices.reshape(z.shape[0], -1)  # add batch axis
            min_encoding_indices = self.remap_to_used(min_encoding_indices)
            min_encoding_indices = min_encoding_indices.reshape(-1, 1)  # flatten

        if self.sane_index_shape:
            min_encoding_indices = min_encoding_indices.reshape(
                z_q.shape[0], z_q.shape[2], z_q.shape[3])

        return z_q, loss, (perplexity, min_encodings, min_encoding_indices)

    def get_codebook_entry(self, indices, shape):
        """Look up the codebook vectors for ``indices``.

        ``shape`` specifies (batch, height, width, channel); when given, the
        result is reshaped and permuted to (batch, channel, height, width).
        ``shape`` must be provided when a remap table is active.
        """
        if self.remap is not None:
            indices = indices.reshape(shape[0], -1)  # add batch axis
            indices = self.unmap_to_all(indices)
            indices = indices.reshape(-1)  # flatten again

        # get quantized latent vectors
        z_q = self.embedding(indices)

        if shape is not None:
            z_q = z_q.view(shape)
            # reshape back to match original input shape
            z_q = z_q.permute(0, 3, 1, 2).contiguous()

        return z_q

    def get_codebook_entry_index(self, entry):
        """Return ``(index, distance)`` of the codebook row closest to ``entry``."""
        codebook_shape = self.embedding.weight.data.shape

        assert entry.shape[1] == codebook_shape[1]
        distance = torch.norm(self.embedding.weight.data - entry, dim=1)

        nearest = torch.argmin(distance)
        nearest_distance = torch.min(distance)

        return nearest, nearest_distance