File size: 1,570 Bytes
2e34814
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from typing import List, Optional

import torch
from torch import nn
from torch.nn.functional import (
    smooth_l1_loss,
)


def flatten_CHW(im: torch.Tensor) -> torch.Tensor:
    """
    Collapse every dimension after the first into a single one.

    (B, C, H, W) -> (B, C * H * W)
    """
    batch_size = im.size(0)
    flat = im.reshape(batch_size, -1)
    return flat


def stddev(x: torch.Tensor) -> torch.Tensor:
    """
    Root-mean-square of ``x`` along its last dimension.

    No mean is subtracted here, so the result equals the standard
    deviation only when the input is already zero-mean.

    x: (B, -1), assumed to be mean-normalized
    Returns:
        stddev: (B)
    """
    mean_square = (x * x).mean(dim=-1)
    return mean_square.sqrt()


def gram_matrix(input_: torch.Tensor) -> torch.Tensor:
    """
    Batched Gram matrix of a feature map, normalized by its size.

    input_: (B, C, ...) -- every dimension after the first two is flattened
        into N positions per channel
    Returns:
        gram: (B, C, C), i.e. features @ features^T divided by (C * N)
    """
    B, C = input_.shape[:2]
    # reshape (rather than view) so that non-contiguous inputs, e.g. permuted
    # or channels_last tensors, are accepted instead of raising RuntimeError.
    features = input_.reshape(B, C, -1)
    N = features.shape[-1]
    G = torch.bmm(features, features.transpose(1, 2))  # (B, C, C)
    return G.div(C * N)


class ColorTransferLoss(nn.Module):
    """Penalize the gram matrix difference between StyleGAN2's ToRGB outputs.

    The reference Gram matrices are computed once, from ``init_rgbs``, when
    the module is constructed. ``forward`` then compares the Gram matrices of
    new outputs against those references with a smooth L1 loss, level by level.

    NOTE: ``self.stds`` and ``self.grams`` are plain Python lists, not
    registered buffers, so ``module.to(...)`` / ``module.cuda()`` does not move
    them; they stay on whatever device ``init_rgbs`` lived on.
    """
    def __init__(
        self,
        init_rgbs,
        scale_rgb: bool = False
    ):
        """
        init_rgbs: iterable of reference tensors, one per level, batch-first
            (e.g. (B, C, H, W))
        scale_rgb: if True, divide each level by the per-sample RMS of its
            reference tensor before taking the Gram matrix; the same reference
            scale is re-used for the inputs given to ``forward``
        """
        super().__init__()

        with torch.no_grad():
            init_feats = [x.detach() for x in init_rgbs]
            if scale_rgb:
                # stddev() returns shape (B,). Give it trailing singleton dims,
                # (B, 1, ..., 1), so the division below scales each sample
                # instead of broadcasting against the last (width) axis.
                self.stds = [
                    stddev(flatten_CHW(rgb)).reshape((-1,) + (1,) * (rgb.dim() - 1))
                    for rgb in init_feats
                ]
            else:
                self.stds = [1] * len(init_feats)  # scalar: no rescaling
            self.grams = [gram_matrix(rgb / std) for rgb, std in zip(init_feats, self.stds)]

    def forward(self, rgbs: List[torch.Tensor], level: Optional[int] = None):
        """
        rgbs: per-level tensors, in the same order as ``init_rgbs``
        level: number of leading levels to include; None means every level
            that has a reference Gram matrix
        Returns:
            sum over the included levels of smooth_l1_loss(gram(rgb), reference);
            the plain int 0 when no level is included
        """
        if level is None:
            level = len(self.grams)

        loss = 0
        # self.stds and self.grams have the same length, so zipping all three
        # pairs each input with the reference built from the same level.
        for rgb, std, target in zip(rgbs[:level], self.stds[:level], self.grams[:level]):
            loss = loss + smooth_l1_loss(gram_matrix(rgb / std), target)

        return loss