Spaces:
Sleeping
Sleeping
added files
Browse files- ldm/modules/losses/__init__.py +1 -0
- ldm/modules/losses/contperceptual.py +111 -0
- ldm/modules/losses/lpips.py +123 -0
- ldm/modules/losses/vqperceptual.py +185 -0
- ldm/modules/vqvae/quantize.py +136 -0
ldm/modules/losses/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
|
ldm/modules/losses/contperceptual.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
|
5 |
+
|
6 |
+
|
7 |
+
class LPIPSWithDiscriminator(nn.Module):
|
8 |
+
def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
|
9 |
+
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
|
10 |
+
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
|
11 |
+
disc_loss="hinge"):
|
12 |
+
|
13 |
+
super().__init__()
|
14 |
+
assert disc_loss in ["hinge", "vanilla"]
|
15 |
+
self.kl_weight = kl_weight
|
16 |
+
self.pixel_weight = pixelloss_weight
|
17 |
+
self.perceptual_loss = LPIPS().eval()
|
18 |
+
self.perceptual_weight = perceptual_weight
|
19 |
+
# output log variance
|
20 |
+
self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
|
21 |
+
|
22 |
+
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
|
23 |
+
n_layers=disc_num_layers,
|
24 |
+
use_actnorm=use_actnorm
|
25 |
+
).apply(weights_init)
|
26 |
+
self.discriminator_iter_start = disc_start
|
27 |
+
self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
|
28 |
+
self.disc_factor = disc_factor
|
29 |
+
self.discriminator_weight = disc_weight
|
30 |
+
self.disc_conditional = disc_conditional
|
31 |
+
|
32 |
+
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
33 |
+
if last_layer is not None:
|
34 |
+
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
35 |
+
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
36 |
+
else:
|
37 |
+
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
|
38 |
+
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
|
39 |
+
|
40 |
+
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
41 |
+
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
42 |
+
d_weight = d_weight * self.discriminator_weight
|
43 |
+
return d_weight
|
44 |
+
|
45 |
+
def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
|
46 |
+
global_step, last_layer=None, cond=None, split="train",
|
47 |
+
weights=None):
|
48 |
+
rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
49 |
+
if self.perceptual_weight > 0:
|
50 |
+
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
|
51 |
+
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
52 |
+
|
53 |
+
nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
|
54 |
+
weighted_nll_loss = nll_loss
|
55 |
+
if weights is not None:
|
56 |
+
weighted_nll_loss = weights*nll_loss
|
57 |
+
weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
|
58 |
+
nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
59 |
+
kl_loss = posteriors.kl()
|
60 |
+
kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
|
61 |
+
|
62 |
+
# now the GAN part
|
63 |
+
if optimizer_idx == 0:
|
64 |
+
# generator update
|
65 |
+
if cond is None:
|
66 |
+
assert not self.disc_conditional
|
67 |
+
logits_fake = self.discriminator(reconstructions.contiguous())
|
68 |
+
else:
|
69 |
+
assert self.disc_conditional
|
70 |
+
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
|
71 |
+
g_loss = -torch.mean(logits_fake)
|
72 |
+
|
73 |
+
if self.disc_factor > 0.0:
|
74 |
+
try:
|
75 |
+
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
|
76 |
+
except RuntimeError:
|
77 |
+
assert not self.training
|
78 |
+
d_weight = torch.tensor(0.0)
|
79 |
+
else:
|
80 |
+
d_weight = torch.tensor(0.0)
|
81 |
+
|
82 |
+
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
83 |
+
loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
|
84 |
+
|
85 |
+
log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
|
86 |
+
"{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
|
87 |
+
"{}/rec_loss".format(split): rec_loss.detach().mean(),
|
88 |
+
"{}/d_weight".format(split): d_weight.detach(),
|
89 |
+
"{}/disc_factor".format(split): torch.tensor(disc_factor),
|
90 |
+
"{}/g_loss".format(split): g_loss.detach().mean(),
|
91 |
+
}
|
92 |
+
return loss, log
|
93 |
+
|
94 |
+
if optimizer_idx == 1:
|
95 |
+
# second pass for discriminator update
|
96 |
+
if cond is None:
|
97 |
+
logits_real = self.discriminator(inputs.contiguous().detach())
|
98 |
+
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
99 |
+
else:
|
100 |
+
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
|
101 |
+
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
|
102 |
+
|
103 |
+
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
104 |
+
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
|
105 |
+
|
106 |
+
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
|
107 |
+
"{}/logits_real".format(split): logits_real.detach().mean(),
|
108 |
+
"{}/logits_fake".format(split): logits_fake.detach().mean()
|
109 |
+
}
|
110 |
+
return d_loss, log
|
111 |
+
|
ldm/modules/losses/lpips.py
ADDED
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models"""
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
from torchvision import models
|
6 |
+
from collections import namedtuple
|
7 |
+
|
8 |
+
from ldm.util import get_ckpt_path
|
9 |
+
|
10 |
+
|
11 |
+
class LPIPS(nn.Module):
|
12 |
+
# Learned perceptual metric
|
13 |
+
def __init__(self, use_dropout=True):
|
14 |
+
super().__init__()
|
15 |
+
self.scaling_layer = ScalingLayer()
|
16 |
+
self.chns = [64, 128, 256, 512, 512] # vg16 features
|
17 |
+
self.net = vgg16(pretrained=True, requires_grad=False)
|
18 |
+
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
19 |
+
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
20 |
+
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
21 |
+
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
22 |
+
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
23 |
+
self.load_from_pretrained()
|
24 |
+
for param in self.parameters():
|
25 |
+
param.requires_grad = False
|
26 |
+
|
27 |
+
def load_from_pretrained(self, name="vgg_lpips"):
|
28 |
+
ckpt = get_ckpt_path(name, "ldm/modules/autoencoder/lpips")
|
29 |
+
self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
|
30 |
+
print("loaded pretrained LPIPS loss from {}".format(ckpt))
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def from_pretrained(cls, name="vgg_lpips"):
|
34 |
+
if name != "vgg_lpips":
|
35 |
+
raise NotImplementedError
|
36 |
+
model = cls()
|
37 |
+
ckpt = get_ckpt_path(name)
|
38 |
+
model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
|
39 |
+
return model
|
40 |
+
|
41 |
+
def forward(self, input, target):
|
42 |
+
in0_input, in1_input = (self.scaling_layer(input), self.scaling_layer(target))
|
43 |
+
outs0, outs1 = self.net(in0_input), self.net(in1_input)
|
44 |
+
feats0, feats1, diffs = {}, {}, {}
|
45 |
+
lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]
|
46 |
+
for kk in range(len(self.chns)):
|
47 |
+
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
|
48 |
+
diffs[kk] = (feats0[kk] - feats1[kk]) ** 2
|
49 |
+
|
50 |
+
res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) for kk in range(len(self.chns))]
|
51 |
+
val = res[0]
|
52 |
+
for l in range(1, len(self.chns)):
|
53 |
+
val += res[l]
|
54 |
+
return val
|
55 |
+
|
56 |
+
|
57 |
+
class ScalingLayer(nn.Module):
|
58 |
+
def __init__(self):
|
59 |
+
super(ScalingLayer, self).__init__()
|
60 |
+
self.register_buffer('shift', torch.Tensor([-.030, -.088, -.188])[None, :, None, None])
|
61 |
+
self.register_buffer('scale', torch.Tensor([.458, .448, .450])[None, :, None, None])
|
62 |
+
|
63 |
+
def forward(self, inp):
|
64 |
+
return (inp - self.shift) / self.scale
|
65 |
+
|
66 |
+
|
67 |
+
class NetLinLayer(nn.Module):
|
68 |
+
""" A single linear layer which does a 1x1 conv """
|
69 |
+
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
70 |
+
super(NetLinLayer, self).__init__()
|
71 |
+
layers = [nn.Dropout(), ] if (use_dropout) else []
|
72 |
+
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), ]
|
73 |
+
self.model = nn.Sequential(*layers)
|
74 |
+
|
75 |
+
|
76 |
+
class vgg16(torch.nn.Module):
|
77 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
78 |
+
super(vgg16, self).__init__()
|
79 |
+
vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
|
80 |
+
self.slice1 = torch.nn.Sequential()
|
81 |
+
self.slice2 = torch.nn.Sequential()
|
82 |
+
self.slice3 = torch.nn.Sequential()
|
83 |
+
self.slice4 = torch.nn.Sequential()
|
84 |
+
self.slice5 = torch.nn.Sequential()
|
85 |
+
self.N_slices = 5
|
86 |
+
for x in range(4):
|
87 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
88 |
+
for x in range(4, 9):
|
89 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
90 |
+
for x in range(9, 16):
|
91 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
92 |
+
for x in range(16, 23):
|
93 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
94 |
+
for x in range(23, 30):
|
95 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
96 |
+
if not requires_grad:
|
97 |
+
for param in self.parameters():
|
98 |
+
param.requires_grad = False
|
99 |
+
|
100 |
+
def forward(self, X):
|
101 |
+
h = self.slice1(X)
|
102 |
+
h_relu1_2 = h
|
103 |
+
h = self.slice2(h)
|
104 |
+
h_relu2_2 = h
|
105 |
+
h = self.slice3(h)
|
106 |
+
h_relu3_3 = h
|
107 |
+
h = self.slice4(h)
|
108 |
+
h_relu4_3 = h
|
109 |
+
h = self.slice5(h)
|
110 |
+
h_relu5_3 = h
|
111 |
+
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
|
112 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
113 |
+
return out
|
114 |
+
|
115 |
+
|
116 |
+
def normalize_tensor(x,eps=1e-10):
|
117 |
+
norm_factor = torch.sqrt(torch.sum(x**2,dim=1,keepdim=True))
|
118 |
+
return x/(norm_factor+eps)
|
119 |
+
|
120 |
+
|
121 |
+
def spatial_average(x, keepdim=True):
|
122 |
+
return x.mean([2,3],keepdim=keepdim)
|
123 |
+
|
ldm/modules/losses/vqperceptual.py
ADDED
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from einops import repeat
|
5 |
+
|
6 |
+
from ldm.modules.discriminator.model import NLayerDiscriminator, weights_init
|
7 |
+
from ldm.modules.losses.lpips import LPIPS
|
8 |
+
# from scripts.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
|
9 |
+
|
10 |
+
|
11 |
+
def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
|
12 |
+
assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
|
13 |
+
loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
|
14 |
+
loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
|
15 |
+
loss_real = (weights * loss_real).sum() / weights.sum()
|
16 |
+
loss_fake = (weights * loss_fake).sum() / weights.sum()
|
17 |
+
d_loss = 0.5 * (loss_real + loss_fake)
|
18 |
+
return d_loss
|
19 |
+
|
20 |
+
def hinge_d_loss(logits_real, logits_fake):
|
21 |
+
loss_real = torch.mean(F.relu(1. - logits_real))
|
22 |
+
loss_fake = torch.mean(F.relu(1. + logits_fake))
|
23 |
+
d_loss = 0.5 * (loss_real + loss_fake)
|
24 |
+
return d_loss
|
25 |
+
|
26 |
+
|
27 |
+
def vanilla_d_loss(logits_real, logits_fake):
|
28 |
+
d_loss = 0.5 * (
|
29 |
+
torch.mean(torch.nn.functional.softplus(-logits_real)) +
|
30 |
+
torch.mean(torch.nn.functional.softplus(logits_fake)))
|
31 |
+
return d_loss
|
32 |
+
|
33 |
+
def adopt_weight(weight, global_step, threshold=0, value=0.):
|
34 |
+
if global_step < threshold:
|
35 |
+
weight = value
|
36 |
+
return weight
|
37 |
+
|
38 |
+
|
39 |
+
def measure_perplexity(predicted_indices, n_embed):
|
40 |
+
# src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
|
41 |
+
# eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
|
42 |
+
encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
|
43 |
+
avg_probs = encodings.mean(0)
|
44 |
+
perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
|
45 |
+
cluster_use = torch.sum(avg_probs > 0)
|
46 |
+
return perplexity, cluster_use
|
47 |
+
|
48 |
+
def l1(x, y):
|
49 |
+
return torch.abs(x-y)
|
50 |
+
|
51 |
+
|
52 |
+
def l2(x, y):
|
53 |
+
return torch.pow((x-y), 2)
|
54 |
+
|
55 |
+
|
56 |
+
class VQLPIPSWithDiscriminator(nn.Module):
|
57 |
+
def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
|
58 |
+
disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
|
59 |
+
perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
|
60 |
+
disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
|
61 |
+
pixel_loss="l1"):
|
62 |
+
super().__init__()
|
63 |
+
assert disc_loss in ["hinge", "vanilla"]
|
64 |
+
assert perceptual_loss in ["lpips", "clips", "dists"]
|
65 |
+
assert pixel_loss in ["l1", "l2"]
|
66 |
+
self.codebook_weight = codebook_weight
|
67 |
+
self.pixel_weight = pixelloss_weight
|
68 |
+
if perceptual_loss == "lpips":
|
69 |
+
print(f"{self.__class__.__name__}: Running with LPIPS.")
|
70 |
+
self.perceptual_loss = LPIPS().eval()
|
71 |
+
else:
|
72 |
+
raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
|
73 |
+
self.perceptual_weight = perceptual_weight
|
74 |
+
|
75 |
+
if pixel_loss == "l1":
|
76 |
+
self.pixel_loss = l1
|
77 |
+
else:
|
78 |
+
self.pixel_loss = l2
|
79 |
+
|
80 |
+
self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
|
81 |
+
n_layers=disc_num_layers,
|
82 |
+
use_actnorm=use_actnorm,
|
83 |
+
ndf=disc_ndf
|
84 |
+
).apply(weights_init)
|
85 |
+
self.discriminator_iter_start = disc_start
|
86 |
+
if disc_loss == "hinge":
|
87 |
+
self.disc_loss = hinge_d_loss
|
88 |
+
elif disc_loss == "vanilla":
|
89 |
+
self.disc_loss = vanilla_d_loss
|
90 |
+
else:
|
91 |
+
raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
|
92 |
+
print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
|
93 |
+
self.disc_factor = disc_factor
|
94 |
+
self.discriminator_weight = disc_weight
|
95 |
+
self.disc_conditional = disc_conditional
|
96 |
+
self.n_classes = n_classes
|
97 |
+
|
98 |
+
def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
|
99 |
+
if last_layer is not None:
|
100 |
+
nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
|
101 |
+
g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
|
102 |
+
else:
|
103 |
+
nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
|
104 |
+
g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
|
105 |
+
|
106 |
+
d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
|
107 |
+
d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
|
108 |
+
d_weight = d_weight * self.discriminator_weight
|
109 |
+
return d_weight
|
110 |
+
|
111 |
+
def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
|
112 |
+
global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
|
113 |
+
if not torch.is_tensor(codebook_loss):
|
114 |
+
codebook_loss = torch.tensor([0.]).to(inputs.device)
|
115 |
+
|
116 |
+
|
117 |
+
# now the GAN part
|
118 |
+
if optimizer_idx == 0:
|
119 |
+
|
120 |
+
#rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
|
121 |
+
rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
|
122 |
+
if self.perceptual_weight > 0:
|
123 |
+
p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
|
124 |
+
rec_loss = rec_loss + self.perceptual_weight * p_loss
|
125 |
+
else:
|
126 |
+
p_loss = torch.tensor([0.0])
|
127 |
+
|
128 |
+
nll_loss = rec_loss
|
129 |
+
#nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
|
130 |
+
nll_loss = torch.mean(nll_loss)
|
131 |
+
|
132 |
+
# generator update
|
133 |
+
g_loss = torch.tensor(0.0)
|
134 |
+
if cond is None:
|
135 |
+
assert not self.disc_conditional
|
136 |
+
logits_fake = self.discriminator(reconstructions.contiguous())
|
137 |
+
else:
|
138 |
+
assert self.disc_conditional
|
139 |
+
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
|
140 |
+
g_loss = -torch.mean(logits_fake)
|
141 |
+
|
142 |
+
try:
|
143 |
+
# d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
|
144 |
+
d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) if self.disc_conditional else torch.tensor(0.0)
|
145 |
+
except RuntimeError:
|
146 |
+
assert not self.training
|
147 |
+
d_weight = torch.tensor(0.0)
|
148 |
+
|
149 |
+
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) if self.disc_conditional else torch.tensor(0.0)
|
150 |
+
loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
|
151 |
+
|
152 |
+
log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
|
153 |
+
"{}/quant_loss".format(split): codebook_loss.detach().mean(),
|
154 |
+
"{}/nll_loss".format(split): nll_loss.detach().mean(),
|
155 |
+
"{}/rec_loss".format(split): rec_loss.detach().mean(),
|
156 |
+
"{}/p_loss".format(split): p_loss.detach().mean(),
|
157 |
+
"{}/d_weight".format(split): d_weight.detach(),
|
158 |
+
"{}/disc_factor".format(split): torch.tensor(disc_factor),
|
159 |
+
"{}/g_loss".format(split): g_loss.detach().mean(),
|
160 |
+
}
|
161 |
+
if predicted_indices is not None:
|
162 |
+
assert self.n_classes is not None
|
163 |
+
with torch.no_grad():
|
164 |
+
perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
|
165 |
+
log[f"{split}/perplexity"] = perplexity
|
166 |
+
log[f"{split}/cluster_usage"] = cluster_usage
|
167 |
+
return loss, log
|
168 |
+
|
169 |
+
if optimizer_idx == 1:
|
170 |
+
# second pass for discriminator update
|
171 |
+
if cond is None:
|
172 |
+
logits_real = self.discriminator(inputs.contiguous().detach())
|
173 |
+
logits_fake = self.discriminator(reconstructions.contiguous().detach())
|
174 |
+
else:
|
175 |
+
logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
|
176 |
+
logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
|
177 |
+
|
178 |
+
disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
|
179 |
+
d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
|
180 |
+
|
181 |
+
log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
|
182 |
+
"{}/logits_real".format(split): logits_real.detach().mean(),
|
183 |
+
"{}/logits_fake".format(split): logits_fake.detach().mean()
|
184 |
+
}
|
185 |
+
return d_loss, log
|
ldm/modules/vqvae/quantize.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#based on https://github.com/CompVis/taming-transformers
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
import numpy as np
|
7 |
+
from einops import rearrange
|
8 |
+
|
9 |
+
class VectorQuantizer2(nn.Module):
|
10 |
+
"""
|
11 |
+
Improved version over VectorQuantizer, can be used as a drop-in replacement. Mostly
|
12 |
+
avoids costly matrix multiplications and allows for post-hoc remapping of indices.
|
13 |
+
"""
|
14 |
+
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
15 |
+
# backwards compatibility we use the buggy version by default, but you can
|
16 |
+
# specify legacy=False to fix it.
|
17 |
+
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random",
|
18 |
+
sane_index_shape=False, legacy=True):
|
19 |
+
super().__init__()
|
20 |
+
self.n_e = n_e
|
21 |
+
self.e_dim = e_dim
|
22 |
+
self.beta = beta
|
23 |
+
self.legacy = legacy
|
24 |
+
|
25 |
+
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
26 |
+
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
27 |
+
|
28 |
+
self.remap = remap
|
29 |
+
if self.remap is not None:
|
30 |
+
self.register_buffer("used", torch.tensor(np.load(self.remap)))
|
31 |
+
self.re_embed = self.used.shape[0]
|
32 |
+
self.unknown_index = unknown_index # "random" or "extra" or integer
|
33 |
+
if self.unknown_index == "extra":
|
34 |
+
self.unknown_index = self.re_embed
|
35 |
+
self.re_embed = self.re_embed+1
|
36 |
+
print(f"Remapping {self.n_e} indices to {self.re_embed} indices. "
|
37 |
+
f"Using {self.unknown_index} for unknown indices.")
|
38 |
+
else:
|
39 |
+
self.re_embed = n_e
|
40 |
+
|
41 |
+
self.sane_index_shape = sane_index_shape
|
42 |
+
|
43 |
+
def remap_to_used(self, inds):
|
44 |
+
ishape = inds.shape
|
45 |
+
assert len(ishape)>1
|
46 |
+
inds = inds.reshape(ishape[0],-1)
|
47 |
+
used = self.used.to(inds)
|
48 |
+
match = (inds[:,:,None]==used[None,None,...]).long()
|
49 |
+
new = match.argmax(-1)
|
50 |
+
unknown = match.sum(2)<1
|
51 |
+
if self.unknown_index == "random":
|
52 |
+
new[unknown]=torch.randint(0,self.re_embed,size=new[unknown].shape).to(device=new.device)
|
53 |
+
else:
|
54 |
+
new[unknown] = self.unknown_index
|
55 |
+
return new.reshape(ishape)
|
56 |
+
|
57 |
+
def unmap_to_all(self, inds):
|
58 |
+
ishape = inds.shape
|
59 |
+
assert len(ishape)>1
|
60 |
+
inds = inds.reshape(ishape[0],-1)
|
61 |
+
used = self.used.to(inds)
|
62 |
+
if self.re_embed > self.used.shape[0]: # extra token
|
63 |
+
inds[inds>=self.used.shape[0]] = 0 # simply set to zero
|
64 |
+
back=torch.gather(used[None,:][inds.shape[0]*[0],:], 1, inds)
|
65 |
+
return back.reshape(ishape)
|
66 |
+
|
67 |
+
def forward(self, z, temp=None, rescale_logits=False, return_logits=False):
|
68 |
+
assert temp is None or temp==1.0, "Only for interface compatible with Gumbel"
|
69 |
+
assert rescale_logits==False, "Only for interface compatible with Gumbel"
|
70 |
+
assert return_logits==False, "Only for interface compatible with Gumbel"
|
71 |
+
# reshape z -> (batch, height, width, channel) and flatten
|
72 |
+
z = rearrange(z, 'b c h w -> b h w c').contiguous()
|
73 |
+
z_flattened = z.view(-1, self.e_dim)
|
74 |
+
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
75 |
+
|
76 |
+
d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
|
77 |
+
torch.sum(self.embedding.weight**2, dim=1) - 2 * \
|
78 |
+
torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
|
79 |
+
|
80 |
+
min_encoding_indices = torch.argmin(d, dim=1)
|
81 |
+
z_q = self.embedding(min_encoding_indices).view(z.shape)
|
82 |
+
perplexity = None
|
83 |
+
min_encodings = None
|
84 |
+
|
85 |
+
# compute loss for embedding
|
86 |
+
if not self.legacy:
|
87 |
+
loss = self.beta * torch.mean((z_q.detach()-z)**2) + \
|
88 |
+
torch.mean((z_q - z.detach()) ** 2)
|
89 |
+
else:
|
90 |
+
loss = torch.mean((z_q.detach()-z)**2) + self.beta * \
|
91 |
+
torch.mean((z_q - z.detach()) ** 2)
|
92 |
+
|
93 |
+
# preserve gradients
|
94 |
+
z_q = z + (z_q - z).detach()
|
95 |
+
|
96 |
+
# reshape back to match original input shape
|
97 |
+
z_q = rearrange(z_q, 'b h w c -> b c h w').contiguous()
|
98 |
+
|
99 |
+
if self.remap is not None:
|
100 |
+
min_encoding_indices = min_encoding_indices.reshape(z.shape[0],-1) # add batch axis
|
101 |
+
min_encoding_indices = self.remap_to_used(min_encoding_indices)
|
102 |
+
min_encoding_indices = min_encoding_indices.reshape(-1,1) # flatten
|
103 |
+
|
104 |
+
if self.sane_index_shape:
|
105 |
+
min_encoding_indices = min_encoding_indices.reshape(
|
106 |
+
z_q.shape[0], z_q.shape[2], z_q.shape[3])
|
107 |
+
|
108 |
+
return z_q, loss, (perplexity, min_encodings, min_encoding_indices)
|
109 |
+
|
110 |
+
def get_codebook_entry(self, indices, shape):
|
111 |
+
# shape specifying (batch, height, width, channel)
|
112 |
+
if self.remap is not None:
|
113 |
+
indices = indices.reshape(shape[0],-1) # add batch axis
|
114 |
+
indices = self.unmap_to_all(indices)
|
115 |
+
indices = indices.reshape(-1) # flatten again
|
116 |
+
|
117 |
+
# get quantized latent vectors
|
118 |
+
z_q = self.embedding(indices)
|
119 |
+
|
120 |
+
if shape is not None:
|
121 |
+
z_q = z_q.view(shape)
|
122 |
+
# reshape back to match original input shape
|
123 |
+
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
124 |
+
|
125 |
+
return z_q
|
126 |
+
|
127 |
+
def get_codebook_entry_index(self, entry):
|
128 |
+
codebook_shape = self.embedding.weight.data.shape
|
129 |
+
|
130 |
+
assert entry.shape[1]==codebook_shape[1]
|
131 |
+
distance = torch.norm(self.embedding.weight.data - entry, dim=1)
|
132 |
+
|
133 |
+
nearest = torch.argmin(distance)
|
134 |
+
nearest_distance = torch.min(distance)
|
135 |
+
|
136 |
+
return nearest, nearest_distance
|