# Author: subatomicseer
# Initial commit (7999e5a)
import torch
import torch.nn as nn
from Network import vgg19, decoder
from utils import adaptive_instance_normalization
class AdaINNet(nn.Module):
    """
    AdaIN Style Transfer Network

    A frozen VGG-19 encoder, truncated after relu4_1, followed by a trainable
    decoder. ``forward`` stylizes ``content`` with ``style`` and returns the
    content and style losses.

    Args:
        vgg_weight: pretrained vgg19 weight, passed through to ``vgg19``
    """

    # Index ranges of ``self.encoder`` ending at relu1_1, relu2_1, relu3_1 and
    # relu4_1 respectively. These boundaries assume the layer layout produced
    # by Network.vgg19; the encoder is truncated to the last ``stop`` (22).
    _STAGES = (slice(0, 3), slice(3, 8), slice(8, 13), slice(13, 22))

    def __init__(self, vgg_weight):
        super().__init__()
        vgg = vgg19(vgg_weight)
        # drop layers after relu4_1
        self.encoder = nn.Sequential(*list(vgg.children())[: self._STAGES[-1].stop])
        # The encoder is a fixed feature extractor: only the decoder is trained.
        for parameter in self.encoder.parameters():
            parameter.requires_grad = False
        self.decoder = decoder()
        self.mseloss = nn.MSELoss()

    def _encode_stages(self, x):
        """
        Run x through the encoder once, collecting every stage output.

        Args:
            x (torch.FloatTensor): input accepted by the encoder's first layer
        Return:
            list of the activations at relu1_1, relu2_1, relu3_1 and relu4_1,
            in that order (the last entry equals ``self.encoder(x)``)
        """
        feats = []
        for stage in self._STAGES:
            # Each stage continues from the previous stage's output.
            x = self.encoder[stage](x)
            feats.append(x)
        return feats

    def _style_loss(self, x, y):
        """
        Computes the style loss between two feature maps.

        Channel statistics are taken over dims 2 and 3 (the spatial dims,
        assuming an N x C x H x W layout).

        Args:
            x (torch.FloatTensor): features of the generated image
            y (torch.FloatTensor): features of the style image
        Return:
            MSE between the channel means of x and y plus MSE between their
            channel standard deviations
        """
        return self.mseloss(torch.mean(x, dim=[2, 3]), torch.mean(y, dim=[2, 3])) + \
            self.mseloss(torch.std(x, dim=[2, 3]), torch.std(y, dim=[2, 3]))

    def forward(self, content, style, alpha=1.0):
        """
        Stylize ``content`` with ``style`` and compute the training losses.

        Args:
            content (torch.FloatTensor): content image batch
            style (torch.FloatTensor): style image batch
            alpha (float): content-style trade-off. The AdaIN target becomes
                ``alpha * AdaIN(f(c), f(s)) + (1 - alpha) * f(c)``; the
                default 1.0 uses the pure AdaIN output.
        Return:
            (content_loss, style_loss)
        """
        # relu4_1 features of the content image.
        content_enc = self.encoder(content)
        # Style features at all four stages in a single encoder pass; the
        # last entry (relu4_1) is what AdaIN consumes.
        style_feats = self._encode_stages(style)

        # Perform style transfer on feature space
        transfer_enc = adaptive_instance_normalization(content_enc, style_feats[-1])
        # Honor alpha (previously accepted but ignored). Skipped at the
        # default so the alpha == 1 path is unchanged.
        if alpha != 1.0:
            transfer_enc = alpha * transfer_enc + (1 - alpha) * content_enc
        # Generate output image
        out = self.decoder(transfer_enc)

        # Re-encode the output at the same four stages.
        out_feats = self._encode_stages(out)

        # Content loss: output's relu4_1 features vs. the AdaIN target.
        content_loss = self.mseloss(out_feats[-1], transfer_enc)
        # Style loss: match channel mean/std at relu1_1 .. relu4_1.
        style_loss = sum(
            self._style_loss(out_feat, style_feat)
            for out_feat, style_feat in zip(out_feats, style_feats)
        )
        return content_loss, style_loss