import torch
import torch.nn as nn
import torch.nn.functional as F
from models import iresnet
from lpips.lpips import LPIPS
from pytorch_msssim import SSIM


def adopt_weight(weight, global_step, threshold=0, value=0.):
    # Return `value` (i.e. disable the weight) until `global_step` reaches `threshold`.
    if global_step < threshold:
        weight = value
    return weight


def hinge_d_loss(logits_real, logits_fake):
    # Hinge GAN discriminator loss.
    loss_real = torch.mean(F.relu(1. - logits_real))
    loss_fake = torch.mean(F.relu(1. + logits_fake))
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def mse_d_loss(logits_real, logits_fake):
    # Least-squares (LSGAN) discriminator loss.
    loss_real = torch.mean((logits_real - 1.) ** 2)
    loss_fake = torch.mean(logits_fake ** 2)
    d_loss = 0.5 * (loss_real + loss_fake)
    return d_loss


def vanilla_d_loss(logits_real, logits_fake):
    # Vanilla (non-saturating) GAN discriminator loss via softplus.
    d_loss = 0.5 * (
            torch.mean(F.softplus(-logits_real)) +
            torch.mean(F.softplus(logits_fake)))
    return d_loss


def create_fr_model(model_path, depth="100"):
    # Load a pretrained ArcFace (iresnet) face-recognition model used for identity features.
    model = iresnet(depth)
    model.load_state_dict(torch.load(model_path))
    # model.half()
    return model


def downscale(img: torch.Tensor):
    # Downscale the image by a factor of 8 with bicubic interpolation.
    target_size = img.shape[-1] // 8
    img = F.interpolate(img, size=(target_size, target_size), mode='bicubic', align_corners=False)
    return img


class VQLPIPSWithDiscriminator(nn.Module):
    def __init__(self, disc_start=1000, disc_factor=1.0, disc_weight=1.0,
                 disc_conditional=False, disc_loss="mse", id_loss="mse",
                 fr_model="./models/arcface-r100-glint360k.pth"):
        super().__init__()
        assert disc_loss in ["hinge", "vanilla", "mse"]
        self.loss_name = disc_loss
        self.perceptual_loss = LPIPS().eval()
        self.discriminator_iter_start = disc_start
        if disc_loss == "hinge":
            self.disc_loss = hinge_d_loss
        elif disc_loss == "vanilla":
            self.disc_loss = vanilla_d_loss
        elif disc_loss == "mse":
            self.disc_loss = mse_d_loss
        else:
            raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
        print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
        self.fr_model = create_fr_model(fr_model).eval()
        if id_loss == "mse":
            self.feature_loss = nn.MSELoss()
        elif id_loss == "cosine":
            self.feature_loss = nn.CosineSimilarity()
        else:
            raise ValueError(f"Unknown identity loss '{id_loss}'.")
        self.disc_factor = disc_factor
        self.discriminator_weight = disc_weight
        self.disc_conditional = disc_conditional
        self.ssim_loss = SSIM(data_range=1, size_average=True, channel=3)

    def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
        # Balance the adversarial loss against the reconstruction loss by matching
        # their gradient norms at the last decoder layer (as in VQGAN).
        if last_layer is not None:
            nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
        else:
            nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
            g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]

        d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
        d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
        d_weight = d_weight * self.discriminator_weight
        return d_weight

    def forward(self, im_features, gt_indices, logits, gt_img, image, discriminator, emb_loss,
                epoch, last_layer=None, cond=None, mask=None):
        # Per-pixel L2 reconstruction loss between the generated and ground-truth images.
        rec_loss = (image - gt_img) ** 2

        # Identity loss: cosine distance between the target identity features and the
        # face-recognition embedding of the generated image.
        if epoch >= 0:
            gen_feature = self.fr_model(image)
            feature_loss = torch.mean(1 - torch.cosine_similarity(im_features, gen_feature))
        else:
            feature_loss = 0

        # LPIPS perceptual loss.
        p_loss = self.perceptual_loss(image, gt_img) * 2

        # SSIM expects inputs in [0, 1]; compute in full precision.
        with torch.cuda.amp.autocast(enabled=False):
            ssim_loss = 1 - self.ssim_loss((image.float() + 1) / 2, (gt_img + 1) / 2)

        # Discriminator logits: non-detached for the generator loss, detached for the
        # discriminator loss so its update does not backpropagate into the generator.
        logits_fake = discriminator(image)
        logits_real_d = discriminator(gt_img.detach())
        logits_fake_d = discriminator(image.detach())

        # Token prediction loss (first token skipped); masked L1 when a mask is given.
        if mask is None:
            token_loss = (logits[:, 1:, :] - gt_indices[:, 1:, :])
            token_loss = torch.mean(token_loss ** 2)
        else:
            token_loss = torch.abs(logits[:, 1:, :] - gt_indices[:, 1:, :]) * mask[:, 1:, None]
            token_loss = token_loss.sum() / mask[:, 1:].sum()
        # token_loss = 0

        nll_loss = torch.mean(rec_loss + p_loss) + \
                   ssim_loss + \
                   token_loss + feature_loss + emb_loss

        # Generator update: adversarial loss weighted adaptively and gated until
        # the discriminator start step.
        g_loss = -torch.mean(logits_fake)
        d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
        disc_factor = adopt_weight(self.disc_factor, epoch, threshold=self.discriminator_iter_start)
        ae_loss = nll_loss + d_weight * disc_factor * g_loss

        # Discriminator update (uses the detached logits computed above).
        d_loss = disc_factor * self.disc_loss(logits_real_d, logits_fake_d)
        return ae_loss, d_loss, token_loss, rec_loss, ssim_loss, p_loss, feature_loss
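

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the module). The `generator`,
# `discriminator`, token tensors, embedding size, and checkpoint availability
# referenced below are assumptions made for demonstration only; the real
# training loop supplies its own modules and data.
#
#   criterion = VQLPIPSWithDiscriminator(disc_start=10, disc_loss="mse")
#   image = generator(gt_img)                       # reconstruction in [-1, 1]
#   ae_loss, d_loss, *metrics = criterion(
#       im_features=id_embeddings,                  # (N, 512) target identity features
#       gt_indices=gt_tokens,                       # (N, T, C) token targets
#       logits=pred_tokens,                         # (N, T, C) token predictions
#       gt_img=gt_img, image=image,
#       discriminator=discriminator,
#       emb_loss=codebook_loss,
#       epoch=epoch,
#       last_layer=generator.last_layer.weight,
#   )
#   ae_loss.backward()   # generator / autoencoder optimizer step
#   d_loss.backward()    # discriminator optimizer step (separate optimizer)
# ---------------------------------------------------------------------------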