Spaces:
Runtime error
Runtime error
import torch | |
import torch.nn as nn | |
from Network import vgg19, decoder | |
from utils import adaptive_instance_normalization | |
class AdaINNet(nn.Module):
    """AdaIN style-transfer network: frozen VGG-19 encoder + trainable decoder.

    The encoder is the first 22 child modules of ``vgg19(vgg_weight)`` (the
    original code annotates this cut as "relu4_1"). Its parameters are frozen,
    so only the decoder receives gradient updates.

    Args:
        vgg_weight: pretrained VGG-19 weights, passed straight to ``vgg19``.
    """

    def __init__(self, vgg_weight):
        super().__init__()
        self.encoder = vgg19(vgg_weight)
        # Drop every layer after the 22nd child (annotated as relu4_1).
        self.encoder = nn.Sequential(*list(self.encoder.children())[:22])
        # The encoder is a fixed feature extractor: no optimization.
        for parameter in self.encoder.parameters():
            parameter.requires_grad = False
        self.decoder = decoder()
        self.mseloss = nn.MSELoss()

    def _encode_with_intermediate(self, x):
        """Run ``x`` through the encoder, keeping four intermediate outputs.

        Returns:
            list: outputs after encoder children [:3], [3:8], [8:13] and
            [13:], annotated in the original code as relu1_1, relu2_1,
            relu3_1 and relu4_1. The last entry equals ``self.encoder(x)``.
        """
        # NOTE(review): the 3/8/13 split points assume the child layout of
        # Network.vgg19 (not visible here) -- confirm they land on relu*_1.
        relu1_1 = self.encoder[:3](x)
        relu2_1 = self.encoder[3:8](relu1_1)
        relu3_1 = self.encoder[8:13](relu2_1)
        relu4_1 = self.encoder[13:](relu3_1)
        return [relu1_1, relu2_1, relu3_1, relu4_1]

    def _style_loss(self, x, y):
        """Return the AdaIN style loss between two feature maps.

        Sums the MSE between per-channel means and the MSE between per-channel
        standard deviations. Both statistics are reduced over dims 2 and 3
        (presumably H and W of an (N, C, H, W) tensor -- confirm).
        ``torch.std`` is used with its default (unbiased) correction.

        Args:
            x (torch.FloatTensor): features of the generated image.
            y (torch.FloatTensor): features of the style image. Should have
                the same batch and channel sizes as ``x``.

        Returns:
            torch.Tensor: scalar loss.
        """
        mean_loss = self.mseloss(torch.mean(x, dim=[2, 3]), torch.mean(y, dim=[2, 3]))
        std_loss = self.mseloss(torch.std(x, dim=[2, 3]), torch.std(y, dim=[2, 3]))
        return mean_loss + std_loss

    def forward(self, content, style, alpha=1.0):
        """Stylize ``content`` with ``style`` and return the training losses.

        Args:
            content: content image batch fed to the encoder.
            style: style image batch fed to the encoder.
            alpha: content/style trade-off applied in feature space as
                ``alpha * AdaIN(c, s) + (1 - alpha) * c``. 1.0 (the default)
                uses the fully stylized features; 0.0 uses the content
                features unchanged. Typically in [0, 1].

        Returns:
            tuple: ``(content_loss, style_loss)``. ``content_loss`` is the MSE
            between the decoded image's deepest encoder features and the
            (blended) AdaIN target. ``style_loss`` is the sum of
            ``_style_loss`` over the four intermediate feature levels.
        """
        # Generate image features. Style intermediates are computed once
        # and reused for the style loss instead of re-encoding the style.
        content_enc = self.encoder(content)
        style_feats = self._encode_with_intermediate(style)
        style_enc = style_feats[-1]

        # Perform style transfer in feature space, then blend with the content
        # features according to alpha (previously alpha was ignored).
        transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
        transfer_enc = alpha * transfer_enc + (1 - alpha) * content_enc

        # Generate output image and re-encode it for the losses.
        out = self.decoder(transfer_enc)
        out_feats = self._encode_with_intermediate(out)

        # Content loss: deepest output features should match the AdaIN target.
        content_loss = self.mseloss(out_feats[-1], transfer_enc)
        # Style loss: match mean/std statistics at every encoder level.
        style_loss = sum(
            self._style_loss(out_feat, style_feat)
            for out_feat, style_feat in zip(out_feats, style_feats)
        )
        return content_loss, style_loss