# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
import torch.nn.utils.spectral_norm as spectral_norm
def get_nonspade_norm_layer(opt, norm_type="instance"):
    """Return a function that wraps a layer with (optional) spectral norm
    followed by a parameter-free normalization layer.

    ``norm_type`` may carry a "spectral" prefix (e.g. "spectralinstance");
    the remainder selects the normalization: "batch", "sync_batch",
    "instance", or "none"/"" for no normalization at all.
    """

    # helper function to get # output channels of the previous layer
    def get_out_channel(layer):
        # Conv layers expose out_channels directly; fall back to the
        # leading weight dimension for anything else (e.g. Linear).
        if hasattr(layer, "out_channels"):
            return layer.out_channels
        return layer.weight.size(0)

    # this function will be returned
    def add_norm_layer(layer):
        # BUG FIX: subnorm_type was previously assigned only inside the
        # spectral branch, so any non-spectral norm_type raised NameError.
        subnorm_type = norm_type
        if subnorm_type.startswith("spectral"):
            layer = spectral_norm(layer)
            subnorm_type = subnorm_type[len("spectral"):]

        if subnorm_type == "none" or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, "bias", None) is not None:
            delattr(layer, "bias")
            layer.register_parameter("bias", None)

        if subnorm_type == "batch":
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == "sync_batch":
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == "instance":
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError("normalization layer %s is not recognized" % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer
class SPADE(nn.Module):
    """Spatially-adaptive denormalization (SPADE) layer.

    Normalizes ``x`` with a parameter-free norm, then modulates it with a
    per-pixel scale (gamma) and bias (beta) predicted from the semantic
    parsing map and/or the degraded input image.

    ``config_text`` has the form "spade<norm><ks>x<ks>" (e.g.
    "spadesyncbatch3x3"): <norm> selects the parameter-free norm
    ("instance" | "syncbatch" | "batch") and <ks> the conv kernel size.
    ``norm_nc`` is the channel count of ``x``; ``label_nc`` the channel
    count of the parsing map; ``opt.no_parsing_map`` drops the parsing map
    from the conditioning input.
    """

    def __init__(self, config_text, norm_nc, label_nc, opt):
        super().__init__()

        assert config_text.startswith("spade")
        # Raw string avoids invalid-escape warnings on modern Python;
        # \d+ also accepts multi-digit kernel sizes (original matched one digit).
        parsed = re.search(r"spade(\D+)(\d+)x\d+", config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))
        self.opt = opt
        if param_free_norm_type == "instance":
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == "syncbatch":
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == "batch":
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError("%s is not a recognized param-free norm type in SPADE" % param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128
        pw = ks // 2  # "same" padding for odd kernel sizes
        # With no_parsing_map the conditioning input is only the 3-channel
        # degraded image; otherwise the parsing map is concatenated to it.
        if self.opt.no_parsing_map:
            self.mlp_shared = nn.Sequential(nn.Conv2d(3, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
        else:
            self.mlp_shared = nn.Sequential(
                nn.Conv2d(label_nc + 3, nhidden, kernel_size=ks, padding=pw), nn.ReLU()
            )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap, degraded_image):
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map,
        # resampled to the spatial size of x.
        segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
        # align_corners=False matches the implicit default behavior and
        # silences the deprecation warning on older torch versions.
        degraded_face = F.interpolate(degraded_image, size=x.size()[2:], mode="bilinear", align_corners=False)
        if self.opt.no_parsing_map:
            actv = self.mlp_shared(degraded_face)
        else:
            actv = self.mlp_shared(torch.cat((segmap, degraded_face), dim=1))
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * (1 + gamma) + beta
        return out