from collections import OrderedDict
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import swapae.util as util
from swapae.models.networks import BaseNetwork
from swapae.models.networks.stylegan2_layers import ConvLayer, ResBlock, EqualLinear


class BasePatchDiscriminator(BaseNetwork):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument("--netPatchD_scale_capacity", default=4.0, type=float)
        parser.add_argument("--netPatchD_max_nc", default=256 + 128, type=int)
        parser.add_argument("--patch_size", default=128, type=int)
        parser.add_argument("--max_num_tiles", default=8, type=int)
        parser.add_argument("--patch_random_transformation",
                            type=util.str2bool, nargs='?', const=True, default=False)
        return parser

    def __init__(self, opt):
        super().__init__(opt)
        # self.visdom = util.Visualizer(opt)

    def needs_regularization(self):
        return False

    def extract_features(self, patches):
        raise NotImplementedError()

    def discriminate_features(self, feature1, feature2):
        raise NotImplementedError()

    def apply_random_transformation(self, patches):
        B, ntiles, C, H, W = patches.size()
        patches = patches.view(B * ntiles, C, H, W)
        before = patches
        transformer = util.RandomSpatialTransformer(self.opt, B * ntiles)
        patches = transformer.forward_transform(patches, (self.opt.patch_size, self.opt.patch_size))
        # self.visdom.display_current_results({'before': before,
        #                                      'after': patches}, 0, save_result=False)
        return patches.view(B, ntiles, C, H, W)

    def sample_patches_old(self, img, indices):
        # Tile the image into non-overlapping patch_size x patch_size crops starting
        # from a random offset, and keep at most max_num_tiles of them.
        B, C, H, W = img.size()
        s = self.opt.patch_size
        if H % s > 0 or W % s > 0:
            y_offset = torch.randint(H % s, (), device=img.device)
            x_offset = torch.randint(W % s, (), device=img.device)
            img = img[:, :,
                      y_offset:y_offset + s * (H // s),
                      x_offset:x_offset + s * (W // s)]
        img = img.view(B, C, H // s, s, W // s, s)
        ntiles = (H // s) * (W // s)
        tiles = img.permute(0, 2, 4, 1, 3, 5).reshape(B, ntiles, C, s, s)
        if indices is None:
            indices = torch.randperm(ntiles, device=img.device)[:self.opt.max_num_tiles]
            return self.apply_random_transformation(tiles[:, indices]), indices
        else:
            return self.apply_random_transformation(tiles[:, indices])

    def forward(self, real, fake, fake_only=False):
        # NOTE: relies on a `sample_patches` implementation provided by the subclass
        # or elsewhere in the codebase; `sample_patches_old` above is the tile-based variant.
        assert real is not None
        real_patches, patch_ids = self.sample_patches(real, None)
        if fake is None:
            real_patches.requires_grad_()
        real_feat = self.extract_features(real_patches)

        bs = real.size(0)
        if fake is None or not fake_only:
            # Pair each patch feature with a feature rolled along the patch axis,
            # so "real" predictions compare co-occurring patches of the same image.
            pred_real = self.discriminate_features(
                real_feat,
                torch.roll(real_feat, 1, 1))
            pred_real = pred_real.view(bs, -1)

        if fake is not None:
            fake_patches = self.sample_patches(fake, patch_ids)
            # self.visualizer.display_current_results({'real_A': real_patches[0],
            #                                          'real_B': torch.roll(fake_patches, 1, 1)[0]}, 0, False, max_num_images=16)
            fake_feat = self.extract_features(fake_patches)
            pred_fake = self.discriminate_features(
                real_feat,
                torch.roll(fake_feat, 1, 1))
            pred_fake = pred_fake.view(bs, -1)

        if fake is None:
            return pred_real, real_patches
        elif fake_only:
            return pred_fake
        else:
            return pred_real, pred_fake


class StyleGAN2PatchDiscriminator(BasePatchDiscriminator):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        BasePatchDiscriminator.modify_commandline_options(parser, is_train)
        return parser

    def __init__(self, opt):
        super().__init__(opt)
        channel_multiplier = self.opt.netPatchD_scale_capacity
        size = self.opt.patch_size
        channels = {
            4: min(self.opt.netPatchD_max_nc, int(256 * channel_multiplier)),
            8: min(self.opt.netPatchD_max_nc, int(128 * channel_multiplier)),
            16: min(self.opt.netPatchD_max_nc, int(64 * channel_multiplier)),
            32: int(32 * channel_multiplier),
            64: int(16 * channel_multiplier),
            128: int(8 * channel_multiplier),
            256: int(4 * channel_multiplier),
        }

        log_size = int(math.ceil(math.log(size, 2)))
        in_channel = channels[2 ** log_size]
        blur_kernel = [1, 3, 3, 1] if self.opt.use_antialias else [1]

        # Downsampling ResBlocks from patch resolution to 4x4, followed by a
        # non-downsampling ResBlock and an unpadded conv that yields a 2x2 feature map.
        convs = [('0', ConvLayer(3, in_channel, 3))]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            layer_name = str(7 - i) if i <= 6 else "%dx%d" % (2 ** i, 2 ** i)
            convs.append((layer_name, ResBlock(in_channel, out_channel, blur_kernel)))
            in_channel = out_channel

        convs.append(('5', ResBlock(in_channel, self.opt.netPatchD_max_nc * 2, downsample=False)))
        convs.append(('6', ConvLayer(self.opt.netPatchD_max_nc * 2, self.opt.netPatchD_max_nc, 3, pad=0)))

        self.convs = nn.Sequential(OrderedDict(convs))

        out_dim = 1

        pairlinear1 = EqualLinear(channels[4] * 2 * 2, 2048, activation='fused_lrelu')
        pairlinear2 = EqualLinear(2048, 2048, activation='fused_lrelu')
        pairlinear3 = EqualLinear(2048, 1024, activation='fused_lrelu')
        pairlinear4 = EqualLinear(1024, out_dim)
        self.pairlinear = nn.Sequential(pairlinear1, pairlinear2, pairlinear3, pairlinear4)

    def extract_features(self, patches, aggregate=False):
        if patches.ndim == 5:
            B, T, C, H, W = patches.size()
            flattened_patches = patches.flatten(0, 1)
        else:
            # A 4-D batch is treated as one patch per image.
            B, C, H, W = patches.size()
            T = 1
            flattened_patches = patches
        features = self.convs(flattened_patches)
        features = features.view(B, T, features.size(1), features.size(2), features.size(3))
        if aggregate:
            features = features.mean(1, keepdim=True).expand(-1, T, -1, -1, -1)
        return features.flatten(0, 1)

    def extract_layerwise_features(self, image):
        feats = [image]
        for m in self.convs:
            feats.append(m(feats[-1]))
        return feats

    def discriminate_features(self, feature1, feature2=None):
        # This variant scores a single patch feature. The second argument is accepted
        # only so that BasePatchDiscriminator.forward, which passes a rolled feature,
        # keeps working; it is ignored here. The paired version is kept below for reference.
        feature1 = feature1.flatten(1)
        out = self.pairlinear(feature1)
        return out

    """
    def discriminate_features(self, feature1, feature2):
        feature1 = feature1.flatten(1)
        feature2 = feature2.flatten(1)
        out = self.pairlinear(torch.cat([feature1, feature2], dim=1))
        return out
    """


class StyleGAN2COGANPatchDiscriminator(BasePatchDiscriminator):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        BasePatchDiscriminator.modify_commandline_options(parser, is_train)
        return parser

    def __init__(self, opt):
        super().__init__(opt)
        channel_multiplier = self.opt.netPatchD_scale_capacity
        size = self.opt.patch_size
        channels = {
            4: min(self.opt.netPatchD_max_nc, int(256 * channel_multiplier)),
            8: min(self.opt.netPatchD_max_nc, int(128 * channel_multiplier)),
            16: min(self.opt.netPatchD_max_nc, int(64 * channel_multiplier)),
            32: int(32 * channel_multiplier),
            64: int(16 * channel_multiplier),
            128: int(8 * channel_multiplier),
            256: int(4 * channel_multiplier),
        }

        log_size = int(math.ceil(math.log(size, 2)))
        in_channel = channels[2 ** log_size]
        blur_kernel = [1, 3, 3, 1] if self.opt.use_antialias else [1]

        convs = [('0', ConvLayer(3, in_channel, 3))]

        for i in range(log_size, 2, -1):
            out_channel = channels[2 ** (i - 1)]
            layer_name = str(7 - i) if i <= 6 else "%dx%d" % (2 ** i, 2 ** i)
            convs.append((layer_name, ResBlock(in_channel, out_channel, blur_kernel)))
            in_channel = out_channel

        convs.append(('5', ResBlock(in_channel, self.opt.netPatchD_max_nc * 2, downsample=False)))
        convs.append(('6', ConvLayer(self.opt.netPatchD_max_nc * 2, self.opt.netPatchD_max_nc, 3, pad=0)))

        self.convs = nn.Sequential(OrderedDict(convs))
        out_dim = 1

        # Two flattened patch features are concatenated before scoring, hence the
        # extra factor of 2 compared to the single-feature variant above.
        pairlinear1 = EqualLinear(channels[4] * 2 * 2 * 2, 2048, activation='fused_lrelu')
        pairlinear2 = EqualLinear(2048, 2048, activation='fused_lrelu')
        pairlinear3 = EqualLinear(2048, 1024, activation='fused_lrelu')
        pairlinear4 = EqualLinear(1024, out_dim)
        self.pairlinear = nn.Sequential(pairlinear1, pairlinear2, pairlinear3, pairlinear4)

    def extract_features(self, patches, aggregate=False):
        if patches.ndim == 5:
            B, T, C, H, W = patches.size()
            flattened_patches = patches.flatten(0, 1)
        else:
            # A 4-D batch is treated as one patch per image.
            B, C, H, W = patches.size()
            T = 1
            flattened_patches = patches
        features = self.convs(flattened_patches)
        features = features.view(B, T, features.size(1), features.size(2), features.size(3))
        if aggregate:
            features = features.mean(1, keepdim=True).expand(-1, T, -1, -1, -1)
        return features.flatten(0, 1)

    def extract_layerwise_features(self, image):
        feats = [image]
        for m in self.convs:
            feats.append(m(feats[-1]))
        return feats

    def discriminate_features(self, feature1, feature2):
        feature1 = feature1.flatten(1)
        feature2 = feature2.flatten(1)
        out = self.pairlinear(torch.cat([feature1, feature2], dim=1))
        return out
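

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module). It assumes that
    # BaseNetwork.__init__ only needs to store `opt` as `self.opt`, and that the
    # options registered in modify_commandline_options plus `use_antialias` are
    # enough to build the network; adjust for the real option parser if not.
    from argparse import Namespace

    opt = Namespace(
        netPatchD_scale_capacity=4.0,
        netPatchD_max_nc=256 + 128,
        patch_size=128,
        max_num_tiles=8,
        patch_random_transformation=False,
        use_antialias=True,
    )
    netD = StyleGAN2PatchDiscriminator(opt)

    # A toy batch of 2 images with 4 patches each, at the configured patch size.
    patches = torch.randn(2, 4, 3, opt.patch_size, opt.patch_size)
    feat = netD.extract_features(patches)       # (2 * 4, C, 2, 2)
    scores = netD.discriminate_features(feat)   # (2 * 4, 1)
    print(feat.shape, scores.shape)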