# StyleTransferModel_128.py
import torch
import torch.nn as nn
import torch.nn.functional as F
class StyleTransferModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Encoder for the target face. Reflection padding for the 7x7 stem conv
        # is applied functionally in forward() rather than as a module here.
        self.target_encoder = nn.Sequential(
            nn.Conv2d(3, 128, kernel_size=7, stride=1, padding=0),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 1024, kernel_size=3, stride=2, padding=1),
            nn.LeakyReLU(0.2),
        )
        # Style blocks: residual blocks modulated by the source style vector
        self.style_blocks = nn.ModuleList([
            StyleBlock(1024, 1024, blockIndex) for blockIndex in range(6)
        ])
        # Decoder (the x2 upsampling steps are applied functionally in forward(),
        # between the stages below)
        self.decoder = nn.Sequential(
            nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
        )
        self.decoderPart1 = nn.Sequential(
            nn.Conv2d(512, 256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(0.2),
        )
        self.decoderPart2 = nn.Sequential(
            nn.Conv2d(128, 3, kernel_size=7, stride=1, padding=0),
            nn.Tanh(),
        )
    def forward(self, target, source):
        # Encode the target face (reflect-pad 3 to offset the 7x7 stem conv)
        target = F.pad(target, pad=(3, 3, 3, 3), mode='reflect')
        target_features = self.target_encoder(target)
        # Apply the style blocks, conditioning on the source style vector
        x = target_features
        for style_block in self.style_blocks:
            x = style_block(x, source)
        # Decode with two x2 bilinear upsampling stages. F.interpolate replaces
        # the deprecated F.upsample; align_corners=False is the usual choice
        # for ONNX compatibility.
        x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False)
        output = self.decoder(x)
        output = F.interpolate(output, scale_factor=2, mode='bilinear', align_corners=False)
        output = self.decoderPart1(output)
        # Reflect-pad 3 to offset the final 7x7 conv
        output = F.pad(output, pad=(3, 3, 3, 3), mode='reflect')
        output = self.decoderPart2(output)
        # Map the tanh output from [-1, 1] to [0, 1]
        return (output + 1) / 2
class StyleBlock(nn.Module):
    def __init__(self, in_channels, out_channels, blockIndex):
        super().__init__()
        # Padding is applied functionally (reflect) in forward(), hence padding=0
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=0)
        # Each linear maps the 512-dim style vector to 2048 values:
        # 1024 per-channel scales followed by 1024 per-channel biases
        self.style1 = nn.Linear(512, 2048)
        self.style2 = nn.Linear(512, 2048)
        # A plain list is fine here: style1/style2 are already registered as submodules
        self.style = [self.style1, self.style2]
        self.blockIndex = blockIndex

    def normalizeConvRMS(self, conv):
        # Per-channel spatial standardization (instance norm without affine):
        # center over the spatial dims, then divide by the RMS of the centered
        # activations (i.e. the standard deviation)
        x = conv - torch.mean(conv, dim=[2, 3], keepdim=True)
        meanSquaredX = torch.mean(x * x, dim=[2, 3], keepdim=True)
        rms = torch.sqrt(meanSquaredX + 1e-8)
        return x / rms
    def forward(self, residual, style):
        # Project the style vector into per-channel (scale, bias) pairs, one
        # pair per conv, shaped (B, 1024, 1, 1) for broadcasting
        style1024 = []
        for index in range(2):
            s = self.style[index](style)     # (B, 2048)
            s = s.unsqueeze(2).unsqueeze(3)  # (B, 2048, 1, 1)
            scale = s[:, :1024, :, :]
            bias = s[:, 1024:, :, :]
            style1024.append([scale, bias])
        # First modulated conv: reflect-pad, conv, standardize, scale/shift, ReLU
        conv1 = self.normalizeConvRMS(self.conv1(F.pad(residual, pad=(1, 1, 1, 1), mode='reflect')))
        out = F.relu(conv1 * style1024[0][0] + style1024[0][1])
        # Second modulated conv (no activation before the residual add)
        out = F.pad(out, pad=(1, 1, 1, 1), mode='reflect')
        conv2 = self.normalizeConvRMS(self.conv2(out))
        out = conv2 * style1024[1][0] + style1024[1][1]
        return residual + out
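

# --- Usage sketch (assumed shapes, inferred from the layer definitions) ---
# The 7x7 stem with reflect-pad 3, two stride-2 downsamples, and two x2
# upsamples imply a 128x128 RGB target (consistent with the "_128" filename)
# and a 512-dim source style vector; batch size 1 is arbitrary.
if __name__ == "__main__":
    model = StyleTransferModel().eval()
    target = torch.randn(1, 3, 128, 128)  # target face image
    source = torch.randn(1, 512)          # source identity/style embedding
    with torch.no_grad():
        out = model(target, source)
    print(out.shape)  # expected: torch.Size([1, 3, 128, 128])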