# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import re
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
import torch.nn.utils.spectral_norm as spectral_norm


def get_nonspade_norm_layer(opt, norm_type="instance"):
    # helper function to get the number of output channels of the previous layer
    def get_out_channel(layer):
        if hasattr(layer, "out_channels"):
            return getattr(layer, "out_channels")
        return layer.weight.size(0)

    # this function will be returned
    def add_norm_layer(layer):
        nonlocal norm_type
        if norm_type.startswith("spectral"):
            layer = spectral_norm(layer)
            subnorm_type = norm_type[len("spectral") :]
        else:
            # no "spectral" prefix: the whole string names the norm type
            subnorm_type = norm_type

        if subnorm_type == "none" or len(subnorm_type) == 0:
            return layer

        # remove bias in the previous layer, which is meaningless
        # since it has no effect after normalization
        if getattr(layer, "bias", None) is not None:
            delattr(layer, "bias")
            layer.register_parameter("bias", None)

        if subnorm_type == "batch":
            norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == "sync_batch":
            norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
        elif subnorm_type == "instance":
            norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=False)
        else:
            raise ValueError("normalization layer %s is not recognized" % subnorm_type)

        return nn.Sequential(layer, norm_layer)

    return add_norm_layer
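

# Illustrative usage sketch: the "spectralinstance" norm string and the conv
# shapes below are assumptions chosen for demonstration only. The factory
# returns a wrapper that appends the requested normalization after the layer:
#
#   norm_layer = get_nonspade_norm_layer(opt, norm_type="spectralinstance")
#   conv = norm_layer(nn.Conv2d(3, 64, kernel_size=3, padding=1))
#   # conv is now nn.Sequential(spectral_norm(Conv2d(3, 64, ...)), InstanceNorm2d(64))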


class SPADE(nn.Module):
    def __init__(self, config_text, norm_nc, label_nc, opt):
        super().__init__()

        assert config_text.startswith("spade")
        parsed = re.search(r"spade(\D+)(\d)x\d", config_text)
        param_free_norm_type = str(parsed.group(1))
        ks = int(parsed.group(2))
        self.opt = opt

        if param_free_norm_type == "instance":
            self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == "syncbatch":
            self.param_free_norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
        elif param_free_norm_type == "batch":
            self.param_free_norm = nn.BatchNorm2d(norm_nc, affine=False)
        else:
            raise ValueError("%s is not a recognized param-free norm type in SPADE" % param_free_norm_type)

        # The dimension of the intermediate embedding space. Yes, hardcoded.
        nhidden = 128
        pw = ks // 2

        if self.opt.no_parsing_map:
            self.mlp_shared = nn.Sequential(nn.Conv2d(3, nhidden, kernel_size=ks, padding=pw), nn.ReLU())
        else:
            self.mlp_shared = nn.Sequential(
                nn.Conv2d(label_nc + 3, nhidden, kernel_size=ks, padding=pw), nn.ReLU()
            )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap, degraded_image):
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on semantic map
        segmap = F.interpolate(segmap, size=x.size()[2:], mode="nearest")
        degraded_face = F.interpolate(degraded_image, size=x.size()[2:], mode="bilinear")

        if self.opt.no_parsing_map:
            actv = self.mlp_shared(degraded_face)
        else:
            actv = self.mlp_shared(torch.cat((segmap, degraded_face), dim=1))
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        out = normalized * (1 + gamma) + beta

        return out
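

# Illustrative usage sketch: the option values, label_nc, and tensor shapes
# below are assumptions for demonstration only.
#
#   from argparse import Namespace
#   opt = Namespace(no_parsing_map=False)
#   spade = SPADE("spadeinstance3x3", norm_nc=64, label_nc=18, opt=opt)
#   x = torch.randn(1, 64, 32, 32)            # features to be normalized
#   segmap = torch.randn(1, 18, 256, 256)     # parsing map, resized to x inside forward
#   degraded = torch.randn(1, 3, 256, 256)    # degraded input image
#   out = spade(x, segmap, degraded)          # same shape as x: (1, 64, 32, 32)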