import torch
import torch.nn as nn
import torch.nn.init as init
import torchvision

from deepfillv2.network_module import *


def weights_init(net, init_type="kaiming", init_gain=0.02):
    """Initialize network weights.

    Parameters:
        net (network)     -- network to be initialized
        init_type (str)   -- the name of an initialization method: normal | xavier | kaiming | orthogonal
        init_gain (float) -- scaling factor for normal, xavier and orthogonal.
    """

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, "weight") and classname.find("Conv") != -1:
            if init_type == "normal":
                init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == "xavier":
                init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == "kaiming":
                init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
            elif init_type == "orthogonal":
                init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError(
                    "initialization method [%s] is not implemented" % init_type
                )
        elif classname.find("BatchNorm2d") != -1:
            init.normal_(m.weight.data, 1.0, 0.02)
            init.constant_(m.bias.data, 0.0)
        elif classname.find("Linear") != -1:
            init.normal_(m.weight, 0, 0.01)
            init.constant_(m.bias, 0)

    # Apply the initialization function to every submodule
    net.apply(init_func)
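
# Usage sketch (hedged): `weights_init` is typically applied right after the
# networks are built. `opt` stands for the option namespace used throughout
# this file; its exact fields live in the repository's argument parser.
#
#   gen = GatedGenerator(opt)
#   weights_init(gen, init_type="kaiming", init_gain=0.02)
#   dis = PatchDiscriminator(opt)
#   weights_init(dis, init_type="kaiming", init_gain=0.02)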


# -----------------------------------------------
#                   Generator
# -----------------------------------------------
# Input: masked image + mask
# Output: filled image
class GatedGenerator(nn.Module):
    def __init__(self, opt):
        super(GatedGenerator, self).__init__()
        self.coarse = nn.Sequential(
            # encoder
            GatedConv2d(opt.in_channels, opt.latent_channels, 5, 1, 2, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels, opt.latent_channels * 2, 3, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 4, 3, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            # bottleneck (dilated convolutions enlarge the receptive field)
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 2, dilation=2, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 4, dilation=4, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 8, dilation=8, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 16, dilation=16, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            # decoder
            TransposeGatedConv2d(opt.latent_channels * 4, opt.latent_channels * 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            TransposeGatedConv2d(opt.latent_channels * 2, opt.latent_channels, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels, opt.latent_channels // 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels // 2, opt.out_channels, 3, 1, 1, pad_type=opt.pad_type, activation="none", norm=opt.norm),
            nn.Tanh(),
        )
        self.refine_conv = nn.Sequential(
            GatedConv2d(opt.in_channels, opt.latent_channels, 5, 1, 2, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels, opt.latent_channels, 3, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels, opt.latent_channels * 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 2, 3, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 2, dilation=2, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 4, dilation=4, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 8, dilation=8, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 16, dilation=16, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
        )
        self.refine_atten_1 = nn.Sequential(
            GatedConv2d(opt.in_channels, opt.latent_channels, 5, 1, 2, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels, opt.latent_channels, 3, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels, opt.latent_channels * 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 4, 3, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation="relu", norm=opt.norm),
        )
        self.refine_atten_2 = nn.Sequential(
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
        )
        self.refine_combine = nn.Sequential(
            GatedConv2d(opt.latent_channels * 8, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 4, opt.latent_channels * 4, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            TransposeGatedConv2d(opt.latent_channels * 4, opt.latent_channels * 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels * 2, opt.latent_channels * 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            TransposeGatedConv2d(opt.latent_channels * 2, opt.latent_channels, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels, opt.latent_channels // 2, 3, 1, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm),
            GatedConv2d(opt.latent_channels // 2, opt.out_channels, 3, 1, 1, pad_type=opt.pad_type, activation="none", norm=opt.norm),
            nn.Tanh(),
        )
        self.context_attention = ContextualAttention(
            ksize=3, stride=1, rate=2, fuse_k=3, softmax_scale=10, fuse=True, use_cuda=opt.use_cuda
        )

    def forward(self, img, mask):
        # img: the full input image
        # mask: 1 for the hole (masked) region; 0 for the known region
        # Coarse stage: masked pixels are filled with 1 (white) before the first pass
        first_masked_img = img * (1 - mask) + mask
        first_in = torch.cat((first_masked_img, mask), dim=1)  # in: [B, 4, H, W]
        first_out = self.coarse(first_in)                      # out: [B, 3, H, W]
        first_out = nn.functional.interpolate(
            first_out, (img.shape[2], img.shape[3]), recompute_scale_factor=False
        )
        # Refinement stage: the coarse result fills the holes of the input image
        second_masked_img = img * (1 - mask) + first_out * mask
        second_in = torch.cat([second_masked_img, mask], dim=1)
        refine_conv = self.refine_conv(second_in)
        refine_atten = self.refine_atten_1(second_in)
        mask_s = nn.functional.interpolate(
            mask, (refine_atten.shape[2], refine_atten.shape[3]), recompute_scale_factor=False
        )
        refine_atten = self.context_attention(refine_atten, refine_atten, mask_s)
        refine_atten = self.refine_atten_2(refine_atten)
        second_out = torch.cat([refine_conv, refine_atten], dim=1)
        second_out = self.refine_combine(second_out)
        second_out = nn.functional.interpolate(
            second_out, (img.shape[2], img.shape[3]), recompute_scale_factor=False
        )
        return first_out, second_out
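
# Usage sketch (hedged): one coarse/refine pass. The `opt` field values below
# are illustrative assumptions, not the repository's verified defaults -- check
# the option parser before relying on them.
#
#   import argparse
#   opt = argparse.Namespace(in_channels=4, out_channels=3, latent_channels=64,
#                            pad_type="zero", activation="lrelu", norm="in",
#                            use_cuda=False)
#   gen = GatedGenerator(opt)
#   img = torch.randn(1, 3, 256, 256)       # image, roughly in [-1, 1]
#   mask = torch.zeros(1, 1, 256, 256)
#   mask[:, :, 96:160, 96:160] = 1.0        # 1 marks the hole to fill
#   first_out, second_out = gen(img, mask)  # each [1, 3, 256, 256]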


# -----------------------------------------------
#                 Discriminator
# -----------------------------------------------
# Input: generated image / ground truth and mask
# Output: a patch-based score map (e.g. [B, 1, 8, 8] for a 256 x 256 input)
class PatchDiscriminator(nn.Module):
    def __init__(self, opt):
        super(PatchDiscriminator, self).__init__()
        # Down sampling
        self.block1 = Conv2dLayer(opt.in_channels, opt.latent_channels, 7, 1, 3, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm, sn=True)
        self.block2 = Conv2dLayer(opt.latent_channels, opt.latent_channels * 2, 4, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm, sn=True)
        self.block3 = Conv2dLayer(opt.latent_channels * 2, opt.latent_channels * 4, 4, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm, sn=True)
        self.block4 = Conv2dLayer(opt.latent_channels * 4, opt.latent_channels * 4, 4, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm, sn=True)
        self.block5 = Conv2dLayer(opt.latent_channels * 4, opt.latent_channels * 4, 4, 2, 1, pad_type=opt.pad_type, activation=opt.activation, norm=opt.norm, sn=True)
        self.block6 = Conv2dLayer(opt.latent_channels * 4, 1, 4, 2, 1, pad_type=opt.pad_type, activation="none", norm="none", sn=True)

    def forward(self, img, mask):
        # The input contains 4 channels: the (reconstructed) image concatenated with the mask.
        # Shape comments below assume latent_channels = 64 and a 256 x 256 input.
        x = torch.cat((img, mask), 1)
        x = self.block1(x)  # out: [B, 64, 256, 256]
        x = self.block2(x)  # out: [B, 128, 128, 128]
        x = self.block3(x)  # out: [B, 256, 64, 64]
        x = self.block4(x)  # out: [B, 256, 32, 32]
        x = self.block5(x)  # out: [B, 256, 16, 16]
        x = self.block6(x)  # out: [B, 1, 8, 8]
        return x
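
# Usage sketch (hedged): the discriminator scores an image/mask pair; this
# reuses the illustrative `opt`, `img`, `mask`, and `second_out` from the
# generator sketch above.
#
#   dis = PatchDiscriminator(opt)
#   completed = img * (1 - mask) + second_out * mask
#   scores = dis(completed, mask)  # [1, 1, 8, 8] patch scores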


# -----------------------------------------------
#               Perceptual Network
# -----------------------------------------------
# VGG-16 conv3_3 features (i.e. features[:15] of the pretrained model)
class PerceptualNet(nn.Module):
    def __init__(self):
        super(PerceptualNet, self).__init__()
        block = [torchvision.models.vgg16(pretrained=True).features[:15].eval()]
        # Freeze the VGG feature extractor
        for p in block[0].parameters():
            p.requires_grad = False
        self.block = torch.nn.ModuleList(block)
        self.transform = torch.nn.functional.interpolate
        self.register_buffer(
            "mean", torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        )
        self.register_buffer(
            "std", torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        )

    def forward(self, x):
        # Normalize with ImageNet statistics, resize to the VGG input size,
        # then extract features
        x = (x - self.mean) / self.std
        x = self.transform(
            x, mode="bilinear", size=(224, 224), align_corners=False,
            recompute_scale_factor=False,
        )
        for block in self.block:
            x = block(x)
        return x
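
# Usage sketch (hedged): a typical perceptual-loss term compares VGG features
# of the refined output against the ground truth; `nn.L1Loss` is an
# illustrative choice here, not necessarily what the training script uses.
# Note the pretrained VGG-16 weights are downloaded on first construction.
#
#   perceptual_net = PerceptualNet()
#   feat_fake = perceptual_net(second_out)  # [1, 256, 56, 56] conv3_3 features
#   feat_real = perceptual_net(img)
#   loss_perceptual = nn.L1Loss()(feat_fake, feat_real)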