File size: 2,346 Bytes
7999e5a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch
import torch.nn as nn
from Network import vgg19, decoder
from utils import adaptive_instance_normalization

class AdaINNet(nn.Module):
    """
    AdaIN Style Transfer Network

    A frozen VGG-19 encoder (truncated after relu4_1) extracts features, the
    content features are re-normalised to the style statistics with
    ``adaptive_instance_normalization``, and a trainable decoder maps the
    result back to image space. ``forward`` returns the training losses.

    Args:
        vgg_weight: pretrained vgg19 weight, forwarded to ``Network.vgg19``
    """

    # Encoder slices ending at relu1_1, relu2_1, relu3_1 and relu4_1.
    # NOTE(review): these indices assume the layer layout produced by
    # ``Network.vgg19``; re-check them if that network definition changes.
    _FEATURE_SLICES = (slice(0, 3), slice(3, 8), slice(8, 13), slice(13, None))

    def __init__(self, vgg_weight):
        super().__init__()
        encoder = vgg19(vgg_weight)

        # drop layers after relu4_1 (keep the first 22 children)
        self.encoder = nn.Sequential(*list(encoder.children())[:22])

        # No optimization for encoder: it is a fixed feature extractor,
        # only the decoder is trained.
        for parameter in self.encoder.parameters():
            parameter.requires_grad = False

        self.decoder = decoder()

        self.mseloss = nn.MSELoss()

    def _encode_with_intermediate(self, x):
        """
        Run ``x`` through the encoder stage by stage, keeping every stage output.

        Args:
            x (torch.FloatTensor): input image tensor

        Return:
            list of 4 tensors: the relu1_1, relu2_1, relu3_1 and relu4_1
            features. The last entry equals ``self.encoder(x)`` because the
            slices partition the whole encoder in order.
        """
        features = []
        for stage in self._FEATURE_SLICES:
            x = self.encoder[stage](x)
            features.append(x)
        return features

    def _style_loss(self, x, y):
        """
        Computes style loss between two feature maps

        Args:
            x (torch.FloatTensor): features of the generated image
            y (torch.FloatTensor): features of the style image

        Return:
            MSE between the means of x and y over dims 2 and 3, plus MSE
            between their standard deviations over dims 2 and 3
        """
        return self.mseloss(torch.mean(x, dim=[2, 3]), torch.mean(y, dim=[2, 3])) + \
            self.mseloss(torch.std(x, dim=[2, 3]), torch.std(y, dim=[2, 3]))

    def forward(self, content, style, alpha=1.0):
        """
        Stylise ``content`` with ``style`` and compute the training losses.

        Args:
            content (torch.FloatTensor): content image tensor
            style (torch.FloatTensor): style image tensor
            alpha (float): content/style trade-off in [0, 1]. The AdaIN
                features are blended as
                ``alpha * adain + (1 - alpha) * content_features``;
                the default 1.0 uses the pure AdaIN features.

        Return:
            tuple (content_loss, style_loss)

        Raises:
            ValueError: if alpha is not within [0, 1]
        """
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"alpha must be in [0, 1], got {alpha}")

        # Generate image features. The style pass keeps its intermediate
        # activations (needed for the style loss) so it runs only once.
        content_enc = self.encoder(content)
        style_feats = self._encode_with_intermediate(style)

        # Perform style transfer on feature space
        transfer_enc = adaptive_instance_normalization(content_enc, style_feats[-1])
        if alpha != 1.0:
            transfer_enc = alpha * transfer_enc + (1 - alpha) * content_enc

        # Generate output image and re-encode it for the losses
        out = self.decoder(transfer_enc)
        out_feats = self._encode_with_intermediate(out)

        # Calculate loss: content loss on relu4_1 against the (blended) AdaIN
        # target, style loss summed over relu1_1..relu4_1.
        content_loss = self.mseloss(out_feats[-1], transfer_enc)
        style_loss = sum(
            self._style_loss(out_f, style_f)
            for out_f, style_f in zip(out_feats, style_feats)
        )

        return content_loss, style_loss