import torch
from torch import nn
from torch.nn.utils import spectral_norm

from modeling.base import BaseNetwork
from layers.blocks import DestyleResBlock, Destyler, ResBlock


class IFRNet(BaseNetwork):
    """Destyling encoder-decoder conditioned on flattened VGG features.

    A ``Destyler`` maps the flattened ``vgg_feat`` tensor to a code of size
    ``destyler_n_channels``. Six ``DestyleResBlock`` stages (three of them with
    stride 2) encode the input, each receiving its own linear projection of that
    code. A decoder of ``ResBlock``s with three nearest-neighbour 2x upsamplings
    maps the deepest feature map back to a 3-channel output.

    Args:
        base_n_channels: channel width ``C`` of the first stage; later stages use
            ``2C``, ``4C`` and ``8C``.
        destyler_n_channels: size of the code produced by the ``Destyler``.
        destyler_in_features: length of the flattened ``vgg_feat`` given to
            ``forward``. The default 32768 (= 512 * 8 * 8) corresponds to
            torchvision VGG16 ``features`` on a 256x256 input, as in the demo at
            the bottom of this file.
    """

    def __init__(self, base_n_channels, destyler_n_channels, destyler_in_features=32768):
        super().__init__()
        c = base_n_channels
        self.destyler = Destyler(in_features=destyler_in_features, num_features=destyler_n_channels)

        # Each ds_fcN projects the destyler code to twice ds_resN's output
        # channel count (presumably per-channel modulation parameters -- see
        # DestyleResBlock). Registration order is kept for state_dict/optimizer
        # compatibility.
        self.ds_fc1 = nn.Linear(destyler_n_channels, c * 2)
        self.ds_res1 = DestyleResBlock(channels_in=3, channels_out=c, kernel_size=5, stride=1, padding=2)
        self.ds_fc2 = nn.Linear(destyler_n_channels, c * 4)
        self.ds_res2 = DestyleResBlock(channels_in=c, channels_out=c * 2, kernel_size=3, stride=2, padding=1)
        self.ds_fc3 = nn.Linear(destyler_n_channels, c * 4)
        self.ds_res3 = DestyleResBlock(channels_in=c * 2, channels_out=c * 2, kernel_size=3, stride=1, padding=1)
        self.ds_fc4 = nn.Linear(destyler_n_channels, c * 8)
        self.ds_res4 = DestyleResBlock(channels_in=c * 2, channels_out=c * 4, kernel_size=3, stride=2, padding=1)
        self.ds_fc5 = nn.Linear(destyler_n_channels, c * 8)
        self.ds_res5 = DestyleResBlock(channels_in=c * 4, channels_out=c * 4, kernel_size=3, stride=1, padding=1)
        self.ds_fc6 = nn.Linear(destyler_n_channels, c * 16)
        self.ds_res6 = DestyleResBlock(channels_in=c * 4, channels_out=c * 8, kernel_size=3, stride=2, padding=1)

        self.upsample = nn.UpsamplingNearest2d(scale_factor=2.0)
        self.res1 = ResBlock(channels_in=c * 8, channels_out=c * 4, kernel_size=3, stride=1, padding=1)
        self.res2 = ResBlock(channels_in=c * 4, channels_out=c * 4, kernel_size=3, stride=1, padding=1)
        self.res3 = ResBlock(channels_in=c * 4, channels_out=c * 2, kernel_size=3, stride=1, padding=1)
        self.res4 = ResBlock(channels_in=c * 2, channels_out=c * 2, kernel_size=3, stride=1, padding=1)
        self.res5 = ResBlock(channels_in=c * 2, channels_out=c, kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(c, 3, kernel_size=3, stride=1, padding=1)

        self.init_weights(init_type="normal", gain=0.02)

    def _encode(self, x, vgg_feat):
        """Return the six encoder feature maps ``[feat1, ..., feat6]`` for ``x``."""
        # Flatten everything after the batch dim; unlike view() this also
        # accepts non-contiguous tensors.
        code = self.destyler(vgg_feat.flatten(1))
        stages = (
            (self.ds_res1, self.ds_fc1),
            (self.ds_res2, self.ds_fc2),
            (self.ds_res3, self.ds_fc3),
            (self.ds_res4, self.ds_fc4),
            (self.ds_res5, self.ds_fc5),
            (self.ds_res6, self.ds_fc6),
        )
        feats = []
        out = x
        for res_block, fc in stages:
            out = res_block(out, fc(code))
            feats.append(out)
        return feats

    def _decode(self, feat):
        """Map the deepest encoder feature map back to a 3-channel output."""
        out = self.res2(self.res1(self.upsample(feat)))
        out = self.res4(self.res3(self.upsample(out)))
        out = self.res5(self.upsample(out))
        return self.conv1(out)

    def forward(self, x, vgg_feat):
        """Return ``(output, aux)``: the 3-channel output and the deepest encoder map."""
        aux = self._encode(x, vgg_feat)[-1]
        return self._decode(aux), aux


class CIFR_Encoder(IFRNet):
    """IFRNet variant whose ``forward`` also returns every encoder feature map."""

    def __init__(self, base_n_channels, destyler_n_channels, destyler_in_features=32768):
        super().__init__(base_n_channels, destyler_n_channels, destyler_in_features)

    def forward(self, x, vgg_feat):
        """Return ``(output, feats)`` where ``feats`` is ``[feat1, ..., feat6]``."""
        feats = self._encode(x, vgg_feat)
        return self._decode(feats[-1]), feats


class Normalize(nn.Module):
    """Divide ``x`` by its L_p norm (p = ``power``) taken along dim 1."""

    def __init__(self, power=2):
        super().__init__()
        self.power = power

    def forward(self, x):
        norm = x.pow(self.power).sum(1, keepdim=True).pow(1.0 / self.power)
        # Small epsilon avoids division by zero for all-zero slices.
        return x.div(norm + 1e-7)


class PatchSampleF(BaseNetwork):
    """Sample feature vectors at spatial locations of several feature maps.

    One projection head ``mlp_<i>`` is built per entry of the ``feats`` list
    given to ``forward``; their input widths ``C, 2C, 2C, 4C, 4C, 8C`` match the
    six encoder maps returned by ``CIFR_Encoder``.

    Args:
        base_n_channels: base channel width ``C``.
        style_or_content: ``"content"`` keeps sampled vectors as-is; any other
            value replaces them with their normalised gram matrix.
        use_mlp: project sampled vectors through ``mlp_<i>`` before normalising.
        nc: output width of every projection head.
    """

    # Input-width multipliers of mlp_0 .. mlp_5 (relative to base_n_channels).
    _MLP_IN_MULTIPLIERS = (1, 2, 2, 4, 4, 8)

    def __init__(self, base_n_channels, style_or_content, use_mlp=False, nc=256):
        # potential issue: the same patch ids are used for every image in the batch
        super().__init__()
        self.is_content = style_or_content == "content"
        self.l2norm = Normalize(2)
        self.use_mlp = use_mlp
        self.nc = nc
        for idx, mult in enumerate(self._MLP_IN_MULTIPLIERS):
            mlp = nn.Sequential(
                nn.Linear(base_n_channels * mult, self.nc),
                nn.ReLU(),
                nn.Linear(self.nc, self.nc),
            )
            # Heads are placed on the GPU when one exists (callers may feed
            # CUDA features without moving this module); on CPU-only hosts
            # they stay on CPU instead of raising.
            if torch.cuda.is_available():
                mlp = mlp.cuda()
            setattr(self, f"mlp_{idx}", mlp)
        self.init_weights(init_type="normal", gain=0.02)

    @staticmethod
    def gram_matrix(x):
        """Return ``x @ x.T`` divided by ``a * b`` for a 2-D ``x`` of shape (a, b)."""
        a, b = x.size()
        gram = torch.mm(x, x.t())
        return gram.div(a * b)

    def forward(self, feats, num_patches=64, patch_ids=None):
        """Sample, optionally project, and L2-normalise vectors from each map.

        Args:
            feats: sequence of 4-D feature maps, unpacked as (B, C, H, W).
            num_patches: number of spatial locations sampled per map; 0 keeps
                every location.
            patch_ids: optional list of index tensors, one per map, used instead
                of a fresh random permutation (the ``return_ids`` of an earlier
                call has this form).

        Returns:
            ``(return_feats, return_ids)``. For ``num_patches > 0`` each feature
            entry has one row per (image, sampled location), or is the gram
            matrix of those rows in style mode. For ``num_patches == 0`` entries
            are reshaped back to (B, C', H, W) and ids are empty lists.
        """
        return_ids = []
        return_feats = []
        for feat_id, feat in enumerate(feats):
            B, _, H, W = feat.shape
            # (B, C, H, W) -> (B, H*W, C): one row per spatial location.
            feat_reshape = feat.permute(0, 2, 3, 1).flatten(1, 2)
            if num_patches > 0:
                if patch_ids is not None:
                    patch_id = patch_ids[feat_id]
                else:
                    patch_id = torch.randperm(feat_reshape.shape[1], device=feat.device)
                patch_id = patch_id[:int(min(num_patches, patch_id.shape[0]))]
                # Same locations for every image in the batch; rows -> (B * P, C).
                x_sample = feat_reshape[:, patch_id, :].flatten(0, 1)
            else:
                x_sample = feat_reshape
                patch_id = []
            if self.use_mlp:
                x_sample = getattr(self, f"mlp_{feat_id}")(x_sample)
            if not self.is_content:
                # NOTE(review): gram_matrix expects a 2-D input, so the style
                # branch only works with num_patches > 0.
                x_sample = self.gram_matrix(x_sample)
            return_ids.append(patch_id)
            # NOTE(review): with num_patches == 0, x_sample is (B, H*W, C') and
            # Normalize reduces over dim 1, i.e. over locations rather than
            # channels -- confirm this is intended.
            x_sample = self.l2norm(x_sample)
            if num_patches == 0:
                x_sample = x_sample.permute(0, 2, 1).reshape([B, x_sample.shape[-1], H, W])
            return_feats.append(x_sample)
        return return_feats, return_ids


class Flatten(nn.Module):
    """Collapse every dimension after the first (batch) one: (B, ...) -> (B, N)."""

    def forward(self, input):
        batch_size = input.size(0)
        # reshape (unlike view) also handles non-contiguous inputs such as
        # channels_last tensors.
        return input.reshape(batch_size, -1)


class MLP(nn.Module):
    """Small convolutional classifier head with ``out_features`` logits.

    Expects ``8 * base_n_channels`` input channels. The final Linear assumes an
    8x8 map after the two 2x max-pools, i.e. a 32x32 input (presumably the
    ``aux`` output of IFRNet for 256x256 images -- confirm against callers).
    """

    def __init__(self, base_n_channels, out_features=14):
        super().__init__()
        self.aux_classifier = nn.Sequential(
            nn.Conv2d(base_n_channels * 8, base_n_channels * 4, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(2),
            nn.Conv2d(base_n_channels * 4, base_n_channels * 2, kernel_size=3, stride=1, padding=1),
            nn.MaxPool2d(2),
            Flatten(),
            nn.Linear(base_n_channels * 8 * 8 * 2, out_features),
        )

    def forward(self, x):
        return self.aux_classifier(x)


def _conv_out(size, kernel_size, stride, padding):
    """Output length of a Conv2d along one spatial axis (dilation 1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


class Discriminator(BaseNetwork):
    """Spectrally normalised convolutional discriminator with one logit per image.

    Args:
        base_n_channels: channel width ``C`` of the first conv layer.
        img_size: input height/width, either an int (square input) or a
            sequence whose first two entries are ``(H, W)``. Only used to size
            the final ``nn.Linear``; the default 256 gives ``2C * 6 * 6`` features.

    Raises:
        ValueError: if ``img_size`` is too small for the conv stack.
    """

    # (kernel_size, stride, padding) of every conv applied before the Linear
    # head: four in image_to_features with stride 2, one with stride 1, then
    # the conv at the start of features_to_prob.
    _CONV_SPECS = ((5, 2, 2),) * 4 + ((5, 1, 1), (5, 2, 1))

    def __init__(self, base_n_channels, img_size=256):
        super().__init__()
        c = base_n_channels
        self.image_to_features = nn.Sequential(
            spectral_norm(nn.Conv2d(3, c, 5, 2, 2)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(c, 2 * c, 5, 2, 2)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(2 * c, 2 * c, 5, 2, 2)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(2 * c, 4 * c, 5, 2, 2)),
            nn.LeakyReLU(0.2, inplace=True),
            spectral_norm(nn.Conv2d(4 * c, 8 * c, 5, 1, 1)),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.features_to_prob = nn.Sequential(
            spectral_norm(nn.Conv2d(8 * c, 2 * c, 5, 2, 1)),
            Flatten(),
            nn.Linear(self._head_in_features(c, img_size), 1),
        )
        self.init_weights(init_type="normal", gain=0.02)

    @classmethod
    def _head_in_features(cls, base_n_channels, img_size):
        """Flattened size (2C * H' * W') reaching the Linear head for ``img_size``."""
        if isinstance(img_size, int):
            height = width = img_size
        else:
            height, width = img_size[:2]
        for kernel_size, stride, padding in cls._CONV_SPECS:
            height = _conv_out(height, kernel_size, stride, padding)
            width = _conv_out(width, kernel_size, stride, padding)
        if height <= 0 or width <= 0:
            raise ValueError(f"img_size {img_size!r} is too small for the discriminator")
        return 2 * base_n_channels * height * width

    def forward(self, input_data):
        x = self.image_to_features(input_data)
        return self.features_to_prob(x)


class PatchDiscriminator(Discriminator):
    """Fully convolutional variant: a 1x1 conv head yields one logit per location.

    ``forward`` is inherited; the output is flattened to (B, H' * W').
    """

    def __init__(self, base_n_channels):
        super().__init__(base_n_channels)
        # Replace the global Linear head with a per-location 1x1 conv head.
        self.features_to_prob = nn.Sequential(
            spectral_norm(nn.Conv2d(8 * base_n_channels, 1, 1)),
            Flatten(),
        )
        # The parent initialised the head it just lost; initialise the
        # replacement the same way as every other network in this file.
        self.init_weights(init_type="normal", gain=0.02)


if __name__ == '__main__':
    import torchvision

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    ifrnet = CIFR_Encoder(32, 128).to(device)
    x = torch.rand((2, 3, 256, 256)).to(device)
    # NOTE: `pretrained=` is deprecated in torchvision >= 0.13 in favour of `weights=`.
    vgg16 = torchvision.models.vgg16(pretrained=True).features.eval().to(device)
    with torch.no_grad():
        vgg_feat = vgg16(x)
    output, feats = ifrnet(x, vgg_feat)
    print(output.size())
    for i, feat in enumerate(feats):
        print(i, feat.size())

    disc = Discriminator(32).to(device)
    d_out = disc(output)
    print(d_out.size())

    patch_disc = PatchDiscriminator(32).to(device)
    p_d_out = patch_disc(output)
    print(p_d_out.size())