from torch import nn
from torchvision.models import vgg19, VGG19_Weights

from src.adain import AdaIN
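
# AdaIN style-transfer network (Huang & Belongie, 2017): a frozen VGG-19
# encoder, an AdaIN layer that re-normalizes feature statistics, and a
# trainable decoder that maps the features back to an image.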
class Model(nn.Module):
    def __init__(self, alpha=1.0):
        super().__init__()
        self.alpha = alpha
        # Encoder: the first 21 layers of a pretrained VGG-19 (up to relu4_1).
        self.encoder = nn.Sequential(*list(vgg19(weights=VGG19_Weights.DEFAULT).features)[:21])
        # The encoder stays frozen; only the decoder is trained.
        for param in self.encoder.parameters():
            param.requires_grad = False
        # Switch the conv layers to reflection padding, and register forward
        # hooks on relu1_1, relu2_1, relu3_1 and relu4_1 (indices 1, 6, 11,
        # 20), whose activations are saved for the style loss.
        self.activations = {}
        for i, module in enumerate(self.encoder.children()):
            if isinstance(module, nn.Conv2d):
                module.padding_mode = 'reflect'
            if i in [1, 6, 11, 20]:
                module.register_forward_hook(self._save_activations(i))
        # AdaIN aligns the channel-wise mean and std of the content features
        # with those of the style features.
        self.AdaIN = AdaIN()
        # Decoder: three nearest-neighbour upsamples take the 512-channel
        # relu4_1 features back to a 3-channel image.
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2.0, mode='nearest'),
            nn.Conv2d(512, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
            nn.ReLU(),
            nn.Upsample(scale_factor=2.0, mode='nearest'),
            nn.Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
            nn.ReLU(),
            nn.Conv2d(256, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
            nn.ReLU(),
            nn.Upsample(scale_factor=2.0, mode='nearest'),
            nn.Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
            nn.ReLU(),
            nn.Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
            nn.ReLU(),
            nn.Conv2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), padding_mode='reflect'),
            nn.Tanh()
        )

    # https://stackoverflow.com/a/68854535
    def _save_activations(self, name):
        """Return a forward hook that stores the layer's output under `name`."""
        def hook(module, input, output):
            self.activations[name] = output
        return hook

    def forward(self, content, style):
        enc_content = self.encoder(content)
        enc_style = self.encoder(style)
        self.t = self.AdaIN(enc_content, enc_style)
        # alpha controls stylization strength: 1.0 uses the full AdaIN output,
        # 0.0 leaves the content features unchanged.
        self.t = (1.0 - self.alpha) * enc_content + self.alpha * self.t
        # t is kept on the module so the training loop can reuse it,
        # e.g. as the target of the content loss.
        out = self.decoder(self.t)
        return out
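

# A minimal usage sketch (an illustrative addition, not part of the training
# code): the input shapes and uniform-random tensors are assumptions; real
# inputs should be preprocessed the same way as VGG-19's training images.
if __name__ == '__main__':
    import torch

    model = Model(alpha=1.0).eval()
    content = torch.rand(1, 3, 256, 256)  # dummy content batch
    style = torch.rand(1, 3, 256, 256)    # dummy style batch
    with torch.no_grad():
        stylized = model(content, style)
    print(stylized.shape)  # torch.Size([1, 3, 256, 256])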