File size: 2,566 Bytes
c583015
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
from torch import nn
from torchvision.models import vgg19
import torchvision
from src.adain import AdaIN

class Model(nn.Module):
    """Encoder / AdaIN / decoder network for style transfer.

    The encoder is the first 21 layers of torchvision's pretrained VGG-19
    feature stack with all parameters frozen and every convolution switched
    to reflection padding. The decoder is a stack of 3x3 reflect-padded
    convolutions with three nearest-neighbour 2x upsampling stages, taking
    512 input channels down to 3 output channels squashed through ``Tanh``.

    Attributes:
        alpha: Blend weight between the content features (``alpha == 0``)
            and the AdaIN output (``alpha == 1``).
        activations: Dict keyed by encoder layer index (1, 6, 11, 20) holding
            the output of that layer from the most recent encoder call; the
            original comment says these are used for the style loss. Because
            ``forward`` encodes the content first and the style second, the
            dict holds the style pass afterwards.
        t: The blended feature tensor fed to the decoder, set by ``forward``.
    """

    def __init__(self, alpha=1.0):
        super().__init__()
        self.alpha = alpha

        # Truncate the pretrained VGG-19 feature extractor to 21 layers.
        backbone = vgg19(weights=torchvision.models.VGG19_Weights.DEFAULT)
        self.encoder = nn.Sequential(*list(backbone.features)[:21])

        # The encoder is a fixed feature extractor: no gradients for its weights.
        self.encoder.requires_grad_(False)

        # Encoder layers whose outputs are captured by forward hooks.
        # NOTE(review): presumably the relu*_1 layers of VGG-19 -- confirm
        # against torchvision's layer layout.
        hooked_indices = (1, 6, 11, 20)
        self.activations = {}
        for index, layer in enumerate(self.encoder.children()):
            # Switch every encoder convolution to reflection padding.
            if isinstance(layer, nn.Conv2d):
                layer.padding_mode = 'reflect'

            if index in hooked_indices:
                layer.register_forward_hook(self._save_activations(index))

        self.AdaIN = AdaIN()

        def conv3x3(in_channels, out_channels):
            # 3x3, stride 1, reflect-padded convolution that keeps spatial size.
            return nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                padding_mode='reflect',
            )

        def upsample2x():
            return nn.Upsample(scale_factor=2.0, mode='nearest')

        # Layer order (and therefore the Sequential indices) is significant:
        # it determines the state-dict keys of the decoder.
        self.decoder = nn.Sequential(
            upsample2x(),
            conv3x3(512, 256),
            nn.ReLU(),

            upsample2x(),
            conv3x3(256, 256),
            nn.ReLU(),
            conv3x3(256, 128),
            nn.ReLU(),

            upsample2x(),
            conv3x3(128, 128),
            nn.ReLU(),
            conv3x3(128, 64),
            nn.ReLU(),

            conv3x3(64, 64),
            nn.ReLU(),
            conv3x3(64, 3),
            nn.Tanh(),
        )

    # https://stackoverflow.com/a/68854535
    def _save_activations(self, name):
        """Return a forward hook that stores a layer's output under ``name``.

        The hook writes into ``self.activations`` each time the hooked module
        runs, replacing whatever was stored under the same key before.
        """
        def _record(_module, _inputs, output):
            self.activations[name] = output
        return _record

    def forward(self, content, style):
        """Stylise ``content`` using the feature statistics of ``style``.

        Both inputs are encoded, combined by AdaIN, blended with the plain
        content features according to ``self.alpha``, and decoded. The
        blended features are kept on ``self.t``.

        Args:
            content: Content input for the encoder (assumed to be an image
                batch shaped ``(N, 3, H, W)`` -- TODO confirm with callers).
            style: Style input for the encoder, same assumption as above.

        Returns:
            The decoder output for the blended features.
        """
        content_features = self.encoder(content)
        style_features = self.encoder(style)

        adain_features = self.AdaIN(content_features, style_features)
        self.t = (1.0 - self.alpha) * content_features + self.alpha * adain_features

        return self.decoder(self.t)