diff --git a/.gitattributes b/.gitattributes index 8d5d340f353fb590c02bfb6133ce27f0159de99c..845b09d218bb235ace57889110bbf2b039d71095 100644 --- a/.gitattributes +++ b/.gitattributes @@ -34,3 +34,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text temp/temp/temp.wav filter=lfs diff=lfs merge=lfs -text +checkpoints/BFM/01_MorphableModel.mat filter=lfs diff=lfs merge=lfs -text +checkpoints/BFM/BFM_model_front.mat filter=lfs diff=lfs merge=lfs -text +checkpoints/shape_predictor_68_face_landmarks.dat filter=lfs diff=lfs merge=lfs -text diff --git a/checkpoints/30_net_gen.pth b/checkpoints/30_net_gen.pth new file mode 100644 index 0000000000000000000000000000000000000000..a08303a88d8cfe1288d97c4af9256075a724ca3e --- /dev/null +++ b/checkpoints/30_net_gen.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4db83e1727128e2c5de27bc80d2929586535e04a709af45016a63e7cf7c46b0c +size 33877439 diff --git a/checkpoints/BFM/.gitkeep b/checkpoints/BFM/.gitkeep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/checkpoints/BFM/01_MorphableModel.mat b/checkpoints/BFM/01_MorphableModel.mat new file mode 100644 index 0000000000000000000000000000000000000000..f251485b55d35adac0ad4f1622a47d7a39a1502c --- /dev/null +++ b/checkpoints/BFM/01_MorphableModel.mat @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:37b1f0742db356a3b1568a8365a06f5b0fe0ab687ac1c3068c803666cbd4d8e2 +size 240875364 diff --git a/checkpoints/BFM/BFM_exp_idx.mat b/checkpoints/BFM/BFM_exp_idx.mat new file mode 100644 index 0000000000000000000000000000000000000000..1146e4e9c3bef303a497383aa7974c014fe945c7 Binary files /dev/null and b/checkpoints/BFM/BFM_exp_idx.mat differ diff --git a/checkpoints/BFM/BFM_front_idx.mat b/checkpoints/BFM/BFM_front_idx.mat new file mode 100644 index 0000000000000000000000000000000000000000..b9d7b0953dd1dc5b1e28144610485409ac321f9b Binary files /dev/null and b/checkpoints/BFM/BFM_front_idx.mat differ diff --git a/checkpoints/BFM/BFM_model_front.mat b/checkpoints/BFM/BFM_model_front.mat new file mode 100644 index 0000000000000000000000000000000000000000..2926e5f317244023be2421b17dbb0e97d97ce9e6 --- /dev/null +++ b/checkpoints/BFM/BFM_model_front.mat @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ae3ff544aba3246c5f2c117f2be76fa44a7b76145326aae0bbfbfb564d4f82af +size 127170280 diff --git a/checkpoints/BFM/Exp_Pca.bin b/checkpoints/BFM/Exp_Pca.bin new file mode 100644 index 0000000000000000000000000000000000000000..3c1785e6abc52b13e54a573f9f3ebc099915b1e0 --- /dev/null +++ b/checkpoints/BFM/Exp_Pca.bin @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e7f31380e6cbdaf2aeec698db220bac4f221946e4d551d88c092d47ec49b1726 +size 51086404 diff --git a/checkpoints/BFM/facemodel_info.mat b/checkpoints/BFM/facemodel_info.mat new file mode 100644 index 0000000000000000000000000000000000000000..3e516ec7297fa3248098f49ecea10579f4831c0a Binary files /dev/null and b/checkpoints/BFM/facemodel_info.mat differ diff --git a/checkpoints/BFM/select_vertex_id.mat b/checkpoints/BFM/select_vertex_id.mat new file mode 100644 index 0000000000000000000000000000000000000000..5b8b220093d93b133acc94ffed159f31a74854cd Binary files /dev/null and b/checkpoints/BFM/select_vertex_id.mat differ diff --git a/checkpoints/BFM/similarity_Lm3D_all.mat b/checkpoints/BFM/similarity_Lm3D_all.mat new file mode 100644 index 0000000000000000000000000000000000000000..a0e23588302bc71fc899eef53ff06df5f4df4c1d Binary files /dev/null and b/checkpoints/BFM/similarity_Lm3D_all.mat differ diff --git a/checkpoints/BFM/std_exp.txt b/checkpoints/BFM/std_exp.txt new file mode 100644 index 0000000000000000000000000000000000000000..767b8de4ea1ca78b6f22b98ff2dee4fa345500bb --- /dev/null +++ b/checkpoints/BFM/std_exp.txt @@ -0,0 +1 @@ +453980 257264 263068 211890 135873 184721 47055.6 72732 62787.4 106226 56708.5 51439.8 34887.1 44378.7 51813.4 31030.7 23354.9 23128.1 19400 21827.6 22767.7 22057.4 19894.3 16172.8 17142.7 10035.3 14727.5 12972.5 10763.8 8953.93 8682.62 8941.81 6342.3 5205.3 7065.65 6083.35 6678.88 4666.63 5082.89 5134.76 4908.16 3964.93 3739.95 3180.09 2470.45 1866.62 1624.71 2423.74 1668.53 1471.65 1194.52 782.102 815.044 835.782 834.937 744.496 575.146 633.76 705.685 753.409 620.306 673.326 766.189 619.866 559.93 357.264 396.472 556.849 455.048 460.592 400.735 326.702 279.428 291.535 326.584 305.664 287.816 283.642 276.19 \ No newline at end of file diff --git a/checkpoints/DNet.pt b/checkpoints/DNet.pt new file mode 100644 index 0000000000000000000000000000000000000000..f5258b8314f176fb9d5646d9c2a955e08180610a --- /dev/null +++ b/checkpoints/DNet.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:41220d2973c0ba2eab6e8f17ed00711aef5a0d76d19808f885dc0e3251df2e80 +size 180424655 diff --git a/checkpoints/ENet.pth b/checkpoints/ENet.pth new file mode 100644 index 0000000000000000000000000000000000000000..783f421cd2ebc35ca938493c12744018b83f4033 --- /dev/null +++ b/checkpoints/ENet.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:967ee3ed857619cedd92b6407dc8a124cbfe763cc11cad58316fe21271a8928f +size 573261168 diff --git a/checkpoints/GFPGANv1.3.pth b/checkpoints/GFPGANv1.3.pth new file mode 100644 index 0000000000000000000000000000000000000000..1da748a3ef84ff85dd2c77c836f222aae22b007e --- /dev/null +++ b/checkpoints/GFPGANv1.3.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c953a88f2727c85c3d9ae72e2bd4846bbaf59fe6972ad94130e23e7017524a70 +size 348632874 diff --git a/checkpoints/GPEN-BFR-512.pth b/checkpoints/GPEN-BFR-512.pth new file mode 100644 index 0000000000000000000000000000000000000000..2287dbb4a09d881a933fcda63ef61f42da9eb5ba --- /dev/null +++ b/checkpoints/GPEN-BFR-512.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1002c41add95b0decad69604d80455576f7187dd99ca16bd611bcfd44c10b51 +size 284085738 diff --git a/checkpoints/LNet.pth b/checkpoints/LNet.pth new file mode 100644 index 0000000000000000000000000000000000000000..63d1c81336b6c997e59ce2cf18a40140e92910d1 --- /dev/null +++ b/checkpoints/LNet.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7ae06fef0454c421b828cc53e8d4b9c92d990867a858ea7bb9661ab6cf6ab774 +size 1534697728 diff --git a/checkpoints/ParseNet-latest.pth b/checkpoints/ParseNet-latest.pth new file mode 100644 index 0000000000000000000000000000000000000000..1ac2efc50360a79c9905dbac57d9d99cbfbe863c --- /dev/null +++ b/checkpoints/ParseNet-latest.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3d558d8d0e42c20224f13cf5a29c79eba2d59913419f945545d8cf7b72920de2 +size 85331193 diff --git a/checkpoints/RetinaFace-R50.pth b/checkpoints/RetinaFace-R50.pth new file mode 100644 index 0000000000000000000000000000000000000000..16546738ce0a00a9fd47585e0fc52744d31cc117 --- /dev/null +++ b/checkpoints/RetinaFace-R50.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d1de9c2944f2ccddca5f5e010ea5ae64a39845a86311af6fdf30841b0a5a16d +size 109497761 diff --git a/checkpoints/expression.mat b/checkpoints/expression.mat new file mode 100644 index 0000000000000000000000000000000000000000..a337fa190e0da473d9c67580042779bdce352e94 Binary files /dev/null and b/checkpoints/expression.mat differ diff --git a/checkpoints/face3d_pretrain_epoch_20.pth b/checkpoints/face3d_pretrain_epoch_20.pth new file mode 100644 index 0000000000000000000000000000000000000000..97ebd6753f7ca4bcd39d3b82e7109b66a2dbc1fb --- /dev/null +++ b/checkpoints/face3d_pretrain_epoch_20.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d17a6b23457b521801baae583cb6a58f7238fe6721fc3d65d76407460e9149b +size 288860037 diff --git a/checkpoints/shape_predictor_68_face_landmarks.dat b/checkpoints/shape_predictor_68_face_landmarks.dat new file mode 100644 index 0000000000000000000000000000000000000000..1e5da4f9a556bec8582e6c55b89b3e6bfdd60021 --- /dev/null +++ b/checkpoints/shape_predictor_68_face_landmarks.dat @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fbdc2cb80eb9aa7a758672cbfdda32ba6300efe9b6e6c7a299ff7e736b11b92f +size 99693937 diff --git a/models/DNet.py b/models/DNet.py new file mode 100644 index 0000000000000000000000000000000000000000..085b2dcd59deb699af198b180d77ff80a81746d6 --- /dev/null +++ b/models/DNet.py @@ -0,0 +1,118 @@ +# TODO +import functools +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from utils import flow_util +from models.base_blocks import LayerNorm2d, ADAINHourglass, FineEncoder, FineDecoder + +# DNet +class DNet(nn.Module): + def __init__(self): + super(DNet, self).__init__() + self.mapping_net = MappingNet() + self.warpping_net = WarpingNet() + self.editing_net = EditingNet() + + def forward(self, input_image, driving_source, stage=None): + if stage == 'warp': + descriptor = self.mapping_net(driving_source) + output = self.warpping_net(input_image, descriptor) + else: + descriptor = self.mapping_net(driving_source) + output = self.warpping_net(input_image, descriptor) + output['fake_image'] = self.editing_net(input_image, output['warp_image'], descriptor) + return output + +class MappingNet(nn.Module): + def __init__(self, coeff_nc=73, descriptor_nc=256, layer=3): + super( MappingNet, self).__init__() + + self.layer = layer + nonlinearity = nn.LeakyReLU(0.1) + + self.first = nn.Sequential( + torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True)) + + for i in range(layer): + net = nn.Sequential(nonlinearity, + torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3)) + setattr(self, 'encoder' + str(i), net) + + self.pooling = nn.AdaptiveAvgPool1d(1) + self.output_nc = descriptor_nc + + def forward(self, input_3dmm): + out = self.first(input_3dmm) + for i in range(self.layer): + model = getattr(self, 'encoder' + str(i)) + out = model(out) + out[:,:,3:-3] + out = self.pooling(out) + return out + +class WarpingNet(nn.Module): + def __init__( + self, + image_nc=3, + descriptor_nc=256, + base_nc=32, + max_nc=256, + encoder_layer=5, + decoder_layer=3, + use_spect=False + ): + super( WarpingNet, self).__init__() + + nonlinearity = nn.LeakyReLU(0.1) + norm_layer = functools.partial(LayerNorm2d, affine=True) + kwargs = {'nonlinearity':nonlinearity, 'use_spect':use_spect} + + self.descriptor_nc = descriptor_nc + self.hourglass = ADAINHourglass(image_nc, self.descriptor_nc, base_nc, + max_nc, encoder_layer, decoder_layer, **kwargs) + + self.flow_out = nn.Sequential(norm_layer(self.hourglass.output_nc), + nonlinearity, + nn.Conv2d(self.hourglass.output_nc, 2, kernel_size=7, stride=1, padding=3)) + + self.pool = nn.AdaptiveAvgPool2d(1) + + def forward(self, input_image, descriptor): + final_output={} + output = self.hourglass(input_image, descriptor) + final_output['flow_field'] = self.flow_out(output) + + deformation = flow_util.convert_flow_to_deformation(final_output['flow_field']) + final_output['warp_image'] = flow_util.warp_image(input_image, deformation) + return final_output + + +class EditingNet(nn.Module): + def __init__( + self, + image_nc=3, + descriptor_nc=256, + layer=3, + base_nc=64, + max_nc=256, + num_res_blocks=2, + use_spect=False): + super(EditingNet, self).__init__() + + nonlinearity = nn.LeakyReLU(0.1) + norm_layer = functools.partial(LayerNorm2d, affine=True) + kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect} + self.descriptor_nc = descriptor_nc + + # encoder part + self.encoder = FineEncoder(image_nc*2, base_nc, max_nc, layer, **kwargs) + self.decoder = FineDecoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs) + + def forward(self, input_image, warp_image, descriptor): + x = torch.cat([input_image, warp_image], 1) + x = self.encoder(x) + gen_image = self.decoder(x, descriptor) + return gen_image diff --git a/models/ENet.py b/models/ENet.py new file mode 100644 index 0000000000000000000000000000000000000000..4df10d662122f6acb20ecfabe3b0d069144b8d18 --- /dev/null +++ b/models/ENet.py @@ -0,0 +1,139 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from models.base_blocks import ResBlock, StyleConv, ToRGB + + +class ENet(nn.Module): + def __init__( + self, + num_style_feat=512, + lnet=None, + concat=False + ): + super(ENet, self).__init__() + + self.low_res = lnet + for param in self.low_res.parameters(): + param.requires_grad = False + + channel_multiplier, narrow = 2, 1 + channels = { + '4': int(512 * narrow), + '8': int(512 * narrow), + '16': int(512 * narrow), + '32': int(512 * narrow), + '64': int(256 * channel_multiplier * narrow), + '128': int(128 * channel_multiplier * narrow), + '256': int(64 * channel_multiplier * narrow), + '512': int(32 * channel_multiplier * narrow), + '1024': int(16 * channel_multiplier * narrow) + } + + self.log_size = 8 + first_out_size = 128 + self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1) # 256 -> 128 + + # downsample + in_channels = channels[f'{first_out_size}'] + self.conv_body_down = nn.ModuleList() + for i in range(8, 2, -1): + out_channels = channels[f'{2**(i - 1)}'] + self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down')) + in_channels = out_channels + + self.num_style_feat = num_style_feat + linear_out_channel = num_style_feat + self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel) + self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1) + + self.style_convs = nn.ModuleList() + self.to_rgbs = nn.ModuleList() + self.noises = nn.Module() + + self.concat = concat + if concat: + in_channels = 3 + 32 # channels['64'] + else: + in_channels = 3 + + for i in range(7, 9): # 128, 256 + out_channels = channels[f'{2**i}'] # + self.style_convs.append( + StyleConv( + in_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode='upsample')) + self.style_convs.append( + StyleConv( + out_channels, + out_channels, + kernel_size=3, + num_style_feat=num_style_feat, + demodulate=True, + sample_mode=None)) + self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True)) + in_channels = out_channels + + def forward(self, audio_sequences, face_sequences, gt_sequences): + B = audio_sequences.size(0) + input_dim_size = len(face_sequences.size()) + inp, ref = torch.split(face_sequences,3,dim=1) + + if input_dim_size > 4: + audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) + inp = torch.cat([inp[:, :, i] for i in range(inp.size(2))], dim=0) + ref = torch.cat([ref[:, :, i] for i in range(ref.size(2))], dim=0) + gt_sequences = torch.cat([gt_sequences[:, :, i] for i in range(gt_sequences.size(2))], dim=0) + + # get the global style + feat = F.leaky_relu_(self.conv_body_first(F.interpolate(ref, size=(256,256), mode='bilinear')), negative_slope=0.2) + for i in range(self.log_size - 2): + feat = self.conv_body_down[i](feat) + feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2) + + # style code + style_code = self.final_linear(feat.reshape(feat.size(0), -1)) + style_code = style_code.reshape(style_code.size(0), -1, self.num_style_feat) + + LNet_input = torch.cat([inp, gt_sequences], dim=1) + LNet_input = F.interpolate(LNet_input, size=(96,96), mode='bilinear') + + if self.concat: + low_res_img, low_res_feat = self.low_res(audio_sequences, LNet_input) + low_res_img.detach() + low_res_feat.detach() + out = torch.cat([low_res_img, low_res_feat], dim=1) + + else: + low_res_img = self.low_res(audio_sequences, LNet_input) + low_res_img.detach() + # 96 x 96 + out = low_res_img + + p2d = (2,2,2,2) + out = F.pad(out, p2d, "reflect", 0) + skip = out + + for conv1, conv2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], self.to_rgbs): + out = conv1(out, style_code) # 96, 192, 384 + out = conv2(out, style_code) + skip = to_rgb(out, style_code, skip) + _outputs = skip + + # remove padding + _outputs = _outputs[:,:,8:-8,8:-8] + + if input_dim_size > 4: + _outputs = torch.split(_outputs, B, dim=0) + outputs = torch.stack(_outputs, dim=2) + low_res_img = F.interpolate(low_res_img, outputs.size()[3:]) + low_res_img = torch.split(low_res_img, B, dim=0) + low_res_img = torch.stack(low_res_img, dim=2) + else: + outputs = _outputs + return outputs, low_res_img \ No newline at end of file diff --git a/models/LNet.py b/models/LNet.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b36d764d9f5a10c6834868dec6f11fc5bb1d3c --- /dev/null +++ b/models/LNet.py @@ -0,0 +1,139 @@ +import functools +import torch +import torch.nn as nn + +from models.transformer import RETURNX, Transformer +from models.base_blocks import Conv2d, LayerNorm2d, FirstBlock2d, DownBlock2d, UpBlock2d, \ + FFCADAINResBlocks, Jump, FinalBlock2d + + +class Visual_Encoder(nn.Module): + def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): + super(Visual_Encoder, self).__init__() + self.layers = layers + self.first_inp = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect) + self.first_ref = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect) + for i in range(layers): + in_channels = min(ngf*(2**i), img_f) + out_channels = min(ngf*(2**(i+1)), img_f) + model_ref = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect) + model_inp = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect) + if i < 2: + ca_layer = RETURNX() + else: + ca_layer = Transformer(2**(i+1) * ngf,2,4,ngf,ngf*4) + setattr(self, 'ca' + str(i), ca_layer) + setattr(self, 'ref_down' + str(i), model_ref) + setattr(self, 'inp_down' + str(i), model_inp) + self.output_nc = out_channels * 2 + + def forward(self, maskGT, ref): + x_maskGT, x_ref = self.first_inp(maskGT), self.first_ref(ref) + out=[x_maskGT] + for i in range(self.layers): + model_ref = getattr(self, 'ref_down'+str(i)) + model_inp = getattr(self, 'inp_down'+str(i)) + ca_layer = getattr(self, 'ca'+str(i)) + x_maskGT, x_ref = model_inp(x_maskGT), model_ref(x_ref) + x_maskGT = ca_layer(x_maskGT, x_ref) + if i < self.layers - 1: + out.append(x_maskGT) + else: + out.append(torch.cat([x_maskGT, x_ref], dim=1)) # concat ref features ! + return out + + +class Decoder(nn.Module): + def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): + super(Decoder, self).__init__() + self.layers = layers + for i in range(layers)[::-1]: + if i == layers-1: + in_channels = ngf*(2**(i+1)) * 2 + else: + in_channels = min(ngf*(2**(i+1)), img_f) + out_channels = min(ngf*(2**i), img_f) + up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect) + res = FFCADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect) + jump = Jump(out_channels, norm_layer, nonlinearity, use_spect) + + setattr(self, 'up' + str(i), up) + setattr(self, 'res' + str(i), res) + setattr(self, 'jump' + str(i), jump) + + self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'sigmoid') + self.output_nc = out_channels + + def forward(self, x, z): + out = x.pop() + for i in range(self.layers)[::-1]: + res_model = getattr(self, 'res' + str(i)) + up_model = getattr(self, 'up' + str(i)) + jump_model = getattr(self, 'jump' + str(i)) + out = res_model(out, z) + out = up_model(out) + out = jump_model(x.pop()) + out + out_image = self.final(out) + return out_image + + +class LNet(nn.Module): + def __init__( + self, + image_nc=3, + descriptor_nc=512, + layer=3, + base_nc=64, + max_nc=512, + num_res_blocks=9, + use_spect=True, + encoder=Visual_Encoder, + decoder=Decoder + ): + super(LNet, self).__init__() + + nonlinearity = nn.LeakyReLU(0.1) + norm_layer = functools.partial(LayerNorm2d, affine=True) + kwargs = {'norm_layer':norm_layer, 'nonlinearity':nonlinearity, 'use_spect':use_spect} + self.descriptor_nc = descriptor_nc + + self.encoder = encoder(image_nc, base_nc, max_nc, layer, **kwargs) + self.decoder = decoder(image_nc, self.descriptor_nc, base_nc, max_nc, layer, num_res_blocks, **kwargs) + self.audio_encoder = nn.Sequential( + Conv2d(1, 32, kernel_size=3, stride=1, padding=1), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(64, 128, kernel_size=3, stride=3, padding=1), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1), + Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True), + + Conv2d(256, 512, kernel_size=3, stride=1, padding=0), + Conv2d(512, descriptor_nc, kernel_size=1, stride=1, padding=0), + ) + + def forward(self, audio_sequences, face_sequences): + B = audio_sequences.size(0) + input_dim_size = len(face_sequences.size()) + if input_dim_size > 4: + audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0) + face_sequences = torch.cat([face_sequences[:, :, i] for i in range(face_sequences.size(2))], dim=0) + cropped, ref = torch.split(face_sequences, 3, dim=1) + + vis_feat = self.encoder(cropped, ref) + audio_feat = self.audio_encoder(audio_sequences) + _outputs = self.decoder(vis_feat, audio_feat) + + if input_dim_size > 4: + _outputs = torch.split(_outputs, B, dim=0) + outputs = torch.stack(_outputs, dim=2) + else: + outputs = _outputs + return outputs \ No newline at end of file diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eff2fce862335d2d340a2b7e9f27c4192bcf4df4 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,37 @@ +import torch +from models.DNet import DNet +from models.LNet import LNet +from models.ENet import ENet + + +def _load(checkpoint_path): + map_location=None if torch.cuda.is_available() else torch.device('cpu') + checkpoint = torch.load(checkpoint_path, map_location=map_location) + return checkpoint + +def load_checkpoint(path, model): + print("Load checkpoint from: {}".format(path)) + checkpoint = _load(path) + s = checkpoint["state_dict"] if 'arcface' not in path else checkpoint + new_s = {} + for k, v in s.items(): + if 'low_res' in k: + continue + else: + new_s[k.replace('module.', '')] = v + model.load_state_dict(new_s, strict=False) + return model + +def load_network(args): + L_net = LNet() + L_net = load_checkpoint(args.LNet_path, L_net) + E_net = ENet(lnet=L_net) + model = load_checkpoint(args.ENet_path, E_net) + return model.eval() + +def load_DNet(args): + D_Net = DNet() + print("Load checkpoint from: {}".format(args.DNet_path)) + checkpoint = torch.load(args.DNet_path, map_location=lambda storage, loc: storage) + D_Net.load_state_dict(checkpoint['net_G_ema'], strict=False) + return D_Net.eval() \ No newline at end of file diff --git a/models/__pycache__/DNet.cpython-37.pyc b/models/__pycache__/DNet.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25c194b139627660ec53bdbe4dee4efe5846b116 Binary files /dev/null and b/models/__pycache__/DNet.cpython-37.pyc differ diff --git a/models/__pycache__/DNet.cpython-38.pyc b/models/__pycache__/DNet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1f9d4b5c4a60435e91ac3a852106516c42a290c Binary files /dev/null and b/models/__pycache__/DNet.cpython-38.pyc differ diff --git a/models/__pycache__/DNet.cpython-39.pyc b/models/__pycache__/DNet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd23fddf709d6347a86306b05d50d3f5b4cbc045 Binary files /dev/null and b/models/__pycache__/DNet.cpython-39.pyc differ diff --git a/models/__pycache__/ENet.cpython-37.pyc b/models/__pycache__/ENet.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..882c4b25046106f2d769ddc9b9c172ee04eb8094 Binary files /dev/null and b/models/__pycache__/ENet.cpython-37.pyc differ diff --git a/models/__pycache__/ENet.cpython-38.pyc b/models/__pycache__/ENet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..292e109d77b21df2f3c3ed5731e2539dac8dab41 Binary files /dev/null and b/models/__pycache__/ENet.cpython-38.pyc differ diff --git a/models/__pycache__/ENet.cpython-39.pyc b/models/__pycache__/ENet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13a8058ab52ea30d80e40794322b5a4c48ff60bf Binary files /dev/null and b/models/__pycache__/ENet.cpython-39.pyc differ diff --git a/models/__pycache__/LNet.cpython-37.pyc b/models/__pycache__/LNet.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3aa0fcfdefd3bdb2cbec28f40b9f90d2ffe3e7f7 Binary files /dev/null and b/models/__pycache__/LNet.cpython-37.pyc differ diff --git a/models/__pycache__/LNet.cpython-38.pyc b/models/__pycache__/LNet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8cdc662c1f28f8da45e9f375dd9dec9d2ec7e792 Binary files /dev/null and b/models/__pycache__/LNet.cpython-38.pyc differ diff --git a/models/__pycache__/LNet.cpython-39.pyc b/models/__pycache__/LNet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9cb704b199562e06bd35a02013c3bf6ee9b954b9 Binary files /dev/null and b/models/__pycache__/LNet.cpython-39.pyc differ diff --git a/models/__pycache__/__init__.cpython-37.pyc b/models/__pycache__/__init__.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..606f7e79e9a10e2e6835d8c6073c8ee5c6261172 Binary files /dev/null and b/models/__pycache__/__init__.cpython-37.pyc differ diff --git a/models/__pycache__/__init__.cpython-38.pyc b/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a68d6d50861b1777c180542e473b85b6204b803 Binary files /dev/null and b/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/models/__pycache__/__init__.cpython-39.pyc b/models/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..817187f95dc9fadb6bba7ad1534e43719c0c14e6 Binary files /dev/null and b/models/__pycache__/__init__.cpython-39.pyc differ diff --git a/models/__pycache__/base_blocks.cpython-37.pyc b/models/__pycache__/base_blocks.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0519acc29e2bb75456c2992b3e4cacf36d8dda30 Binary files /dev/null and b/models/__pycache__/base_blocks.cpython-37.pyc differ diff --git a/models/__pycache__/base_blocks.cpython-38.pyc b/models/__pycache__/base_blocks.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55bd058d571fa33eeffbf9104faa9811d7ec94b8 Binary files /dev/null and b/models/__pycache__/base_blocks.cpython-38.pyc differ diff --git a/models/__pycache__/base_blocks.cpython-39.pyc b/models/__pycache__/base_blocks.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98d4195714636c9feb0806cb08e59e9b25cd472b Binary files /dev/null and b/models/__pycache__/base_blocks.cpython-39.pyc differ diff --git a/models/__pycache__/ffc.cpython-37.pyc b/models/__pycache__/ffc.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8d39f875acb80794e558da939e8d630357c4c3fb Binary files /dev/null and b/models/__pycache__/ffc.cpython-37.pyc differ diff --git a/models/__pycache__/ffc.cpython-38.pyc b/models/__pycache__/ffc.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6af09b56f3b0f13e31b9104dd9717c7a56f32da6 Binary files /dev/null and b/models/__pycache__/ffc.cpython-38.pyc differ diff --git a/models/__pycache__/ffc.cpython-39.pyc b/models/__pycache__/ffc.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1afa426dbac94ce9af1a70b2324ab2ec32498217 Binary files /dev/null and b/models/__pycache__/ffc.cpython-39.pyc differ diff --git a/models/__pycache__/transformer.cpython-37.pyc b/models/__pycache__/transformer.cpython-37.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28a93671b418118939c7868a3bc502b2c51976a0 Binary files /dev/null and b/models/__pycache__/transformer.cpython-37.pyc differ diff --git a/models/__pycache__/transformer.cpython-38.pyc b/models/__pycache__/transformer.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b24ecbb14cf7a5269cd8dce05fa2b14cee2ee98 Binary files /dev/null and b/models/__pycache__/transformer.cpython-38.pyc differ diff --git a/models/__pycache__/transformer.cpython-39.pyc b/models/__pycache__/transformer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d17d953352d8d4c2f0d47a0569c81991a48028b Binary files /dev/null and b/models/__pycache__/transformer.cpython-39.pyc differ diff --git a/models/base_blocks.py b/models/base_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..2d453540b5b9701b23574aea175890de74f06b51 --- /dev/null +++ b/models/base_blocks.py @@ -0,0 +1,554 @@ +import math +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.utils.spectral_norm import spectral_norm as SpectralNorm + +from models.ffc import FFC +from basicsr.archs.arch_util import default_init_weights + + +class Conv2d(nn.Module): + def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_block = nn.Sequential( + nn.Conv2d(cin, cout, kernel_size, stride, padding), + nn.BatchNorm2d(cout) + ) + self.act = nn.ReLU() + self.residual = residual + + def forward(self, x): + out = self.conv_block(x) + if self.residual: + out += x + return self.act(out) + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels, mode='down'): + super(ResBlock, self).__init__() + self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1) + self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1) + self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False) + if mode == 'down': + self.scale_factor = 0.5 + elif mode == 'up': + self.scale_factor = 2 + + def forward(self, x): + out = F.leaky_relu_(self.conv1(x), negative_slope=0.2) + # upsample/downsample + out = F.interpolate(out, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) + out = F.leaky_relu_(self.conv2(out), negative_slope=0.2) + # skip + x = F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False) + skip = self.skip(x) + out = out + skip + return out + + +class LayerNorm2d(nn.Module): + def __init__(self, n_out, affine=True): + super(LayerNorm2d, self).__init__() + self.n_out = n_out + self.affine = affine + + if self.affine: + self.weight = nn.Parameter(torch.ones(n_out, 1, 1)) + self.bias = nn.Parameter(torch.zeros(n_out, 1, 1)) + + def forward(self, x): + normalized_shape = x.size()[1:] + if self.affine: + return F.layer_norm(x, normalized_shape, \ + self.weight.expand(normalized_shape), + self.bias.expand(normalized_shape)) + else: + return F.layer_norm(x, normalized_shape) + + +def spectral_norm(module, use_spect=True): + if use_spect: + return SpectralNorm(module) + else: + return module + + +class FirstBlock2d(nn.Module): + def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): + super(FirstBlock2d, self).__init__() + kwargs = {'kernel_size': 7, 'stride': 1, 'padding': 3} + conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) + + if type(norm_layer) == type(None): + self.model = nn.Sequential(conv, nonlinearity) + else: + self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity) + + def forward(self, x): + out = self.model(x) + return out + + +class DownBlock2d(nn.Module): + def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): + super(DownBlock2d, self).__init__() + kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1} + conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) + pool = nn.AvgPool2d(kernel_size=(2, 2)) + + if type(norm_layer) == type(None): + self.model = nn.Sequential(conv, nonlinearity, pool) + else: + self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity, pool) + + def forward(self, x): + out = self.model(x) + return out + + +class UpBlock2d(nn.Module): + def __init__(self, input_nc, output_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): + super(UpBlock2d, self).__init__() + kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1} + conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) + if type(norm_layer) == type(None): + self.model = nn.Sequential(conv, nonlinearity) + else: + self.model = nn.Sequential(conv, norm_layer(output_nc), nonlinearity) + + def forward(self, x): + out = self.model(F.interpolate(x, scale_factor=2)) + return out + + +class ADAIN(nn.Module): + def __init__(self, norm_nc, feature_nc): + super().__init__() + + self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False) + + nhidden = 128 + use_bias=True + + self.mlp_shared = nn.Sequential( + nn.Linear(feature_nc, nhidden, bias=use_bias), + nn.ReLU() + ) + self.mlp_gamma = nn.Linear(nhidden, norm_nc, bias=use_bias) + self.mlp_beta = nn.Linear(nhidden, norm_nc, bias=use_bias) + + def forward(self, x, feature): + + # Part 1. generate parameter-free normalized activations + normalized = self.param_free_norm(x) + # Part 2. produce scaling and bias conditioned on feature + feature = feature.view(feature.size(0), -1) + actv = self.mlp_shared(feature) + gamma = self.mlp_gamma(actv) + beta = self.mlp_beta(actv) + + # apply scale and bias + gamma = gamma.view(*gamma.size()[:2], 1,1) + beta = beta.view(*beta.size()[:2], 1,1) + out = normalized * (1 + gamma) + beta + return out + + +class FineADAINResBlock2d(nn.Module): + """ + Define an Residual block for different types + """ + def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): + super(FineADAINResBlock2d, self).__init__() + kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1} + self.conv1 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect) + self.conv2 = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect) + self.norm1 = ADAIN(input_nc, feature_nc) + self.norm2 = ADAIN(input_nc, feature_nc) + self.actvn = nonlinearity + + def forward(self, x, z): + dx = self.actvn(self.norm1(self.conv1(x), z)) + dx = self.norm2(self.conv2(x), z) + out = dx + x + return out + + +class FineADAINResBlocks(nn.Module): + def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): + super(FineADAINResBlocks, self).__init__() + self.num_block = num_block + for i in range(num_block): + model = FineADAINResBlock2d(input_nc, feature_nc, norm_layer, nonlinearity, use_spect) + setattr(self, 'res'+str(i), model) + + def forward(self, x, z): + for i in range(self.num_block): + model = getattr(self, 'res'+str(i)) + x = model(x, z) + return x + + +class ADAINEncoderBlock(nn.Module): + def __init__(self, input_nc, output_nc, feature_nc, nonlinearity=nn.LeakyReLU(), use_spect=False): + super(ADAINEncoderBlock, self).__init__() + kwargs_down = {'kernel_size': 4, 'stride': 2, 'padding': 1} + kwargs_fine = {'kernel_size': 3, 'stride': 1, 'padding': 1} + + self.conv_0 = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_down), use_spect) + self.conv_1 = spectral_norm(nn.Conv2d(output_nc, output_nc, **kwargs_fine), use_spect) + + + self.norm_0 = ADAIN(input_nc, feature_nc) + self.norm_1 = ADAIN(output_nc, feature_nc) + self.actvn = nonlinearity + + def forward(self, x, z): + x = self.conv_0(self.actvn(self.norm_0(x, z))) + x = self.conv_1(self.actvn(self.norm_1(x, z))) + return x + + +class ADAINDecoderBlock(nn.Module): + def __init__(self, input_nc, output_nc, hidden_nc, feature_nc, use_transpose=True, nonlinearity=nn.LeakyReLU(), use_spect=False): + super(ADAINDecoderBlock, self).__init__() + # Attributes + self.actvn = nonlinearity + hidden_nc = min(input_nc, output_nc) if hidden_nc is None else hidden_nc + + kwargs_fine = {'kernel_size':3, 'stride':1, 'padding':1} + if use_transpose: + kwargs_up = {'kernel_size':3, 'stride':2, 'padding':1, 'output_padding':1} + else: + kwargs_up = {'kernel_size':3, 'stride':1, 'padding':1} + + # create conv layers + self.conv_0 = spectral_norm(nn.Conv2d(input_nc, hidden_nc, **kwargs_fine), use_spect) + if use_transpose: + self.conv_1 = spectral_norm(nn.ConvTranspose2d(hidden_nc, output_nc, **kwargs_up), use_spect) + self.conv_s = spectral_norm(nn.ConvTranspose2d(input_nc, output_nc, **kwargs_up), use_spect) + else: + self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(hidden_nc, output_nc, **kwargs_up), use_spect), + nn.Upsample(scale_factor=2)) + self.conv_s = nn.Sequential(spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs_up), use_spect), + nn.Upsample(scale_factor=2)) + # define normalization layers + self.norm_0 = ADAIN(input_nc, feature_nc) + self.norm_1 = ADAIN(hidden_nc, feature_nc) + self.norm_s = ADAIN(input_nc, feature_nc) + + def forward(self, x, z): + x_s = self.shortcut(x, z) + dx = self.conv_0(self.actvn(self.norm_0(x, z))) + dx = self.conv_1(self.actvn(self.norm_1(dx, z))) + out = x_s + dx + return out + + def shortcut(self, x, z): + x_s = self.conv_s(self.actvn(self.norm_s(x, z))) + return x_s + + +class FineEncoder(nn.Module): + """docstring for Encoder""" + def __init__(self, image_nc, ngf, img_f, layers, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): + super(FineEncoder, self).__init__() + self.layers = layers + self.first = FirstBlock2d(image_nc, ngf, norm_layer, nonlinearity, use_spect) + for i in range(layers): + in_channels = min(ngf*(2**i), img_f) + out_channels = min(ngf*(2**(i+1)), img_f) + model = DownBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect) + setattr(self, 'down' + str(i), model) + self.output_nc = out_channels + + def forward(self, x): + x = self.first(x) + out=[x] + for i in range(self.layers): + model = getattr(self, 'down'+str(i)) + x = model(x) + out.append(x) + return out + + +class FineDecoder(nn.Module): + """docstring for FineDecoder""" + def __init__(self, image_nc, feature_nc, ngf, img_f, layers, num_block, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): + super(FineDecoder, self).__init__() + self.layers = layers + for i in range(layers)[::-1]: + in_channels = min(ngf*(2**(i+1)), img_f) + out_channels = min(ngf*(2**i), img_f) + up = UpBlock2d(in_channels, out_channels, norm_layer, nonlinearity, use_spect) + res = FineADAINResBlocks(num_block, in_channels, feature_nc, norm_layer, nonlinearity, use_spect) + jump = Jump(out_channels, norm_layer, nonlinearity, use_spect) + setattr(self, 'up' + str(i), up) + setattr(self, 'res' + str(i), res) + setattr(self, 'jump' + str(i), jump) + self.final = FinalBlock2d(out_channels, image_nc, use_spect, 'tanh') + self.output_nc = out_channels + + def forward(self, x, z): + out = x.pop() + for i in range(self.layers)[::-1]: + res_model = getattr(self, 'res' + str(i)) + up_model = getattr(self, 'up' + str(i)) + jump_model = getattr(self, 'jump' + str(i)) + out = res_model(out, z) + out = up_model(out) + out = jump_model(x.pop()) + out + out_image = self.final(out) + return out_image + + +class ADAINEncoder(nn.Module): + def __init__(self, image_nc, pose_nc, ngf, img_f, layers, nonlinearity=nn.LeakyReLU(), use_spect=False): + super(ADAINEncoder, self).__init__() + self.layers = layers + self.input_layer = nn.Conv2d(image_nc, ngf, kernel_size=7, stride=1, padding=3) + for i in range(layers): + in_channels = min(ngf * (2**i), img_f) + out_channels = min(ngf *(2**(i+1)), img_f) + model = ADAINEncoderBlock(in_channels, out_channels, pose_nc, nonlinearity, use_spect) + setattr(self, 'encoder' + str(i), model) + self.output_nc = out_channels + + def forward(self, x, z): + out = self.input_layer(x) + out_list = [out] + for i in range(self.layers): + model = getattr(self, 'encoder' + str(i)) + out = model(out, z) + out_list.append(out) + return out_list + + +class ADAINDecoder(nn.Module): + """docstring for ADAINDecoder""" + def __init__(self, pose_nc, ngf, img_f, encoder_layers, decoder_layers, skip_connect=True, + nonlinearity=nn.LeakyReLU(), use_spect=False): + + super(ADAINDecoder, self).__init__() + self.encoder_layers = encoder_layers + self.decoder_layers = decoder_layers + self.skip_connect = skip_connect + use_transpose = True + for i in range(encoder_layers-decoder_layers, encoder_layers)[::-1]: + in_channels = min(ngf * (2**(i+1)), img_f) + in_channels = in_channels*2 if i != (encoder_layers-1) and self.skip_connect else in_channels + out_channels = min(ngf * (2**i), img_f) + model = ADAINDecoderBlock(in_channels, out_channels, out_channels, pose_nc, use_transpose, nonlinearity, use_spect) + setattr(self, 'decoder' + str(i), model) + self.output_nc = out_channels*2 if self.skip_connect else out_channels + + def forward(self, x, z): + out = x.pop() if self.skip_connect else x + for i in range(self.encoder_layers-self.decoder_layers, self.encoder_layers)[::-1]: + model = getattr(self, 'decoder' + str(i)) + out = model(out, z) + out = torch.cat([out, x.pop()], 1) if self.skip_connect else out + return out + + +class ADAINHourglass(nn.Module): + def __init__(self, image_nc, pose_nc, ngf, img_f, encoder_layers, decoder_layers, nonlinearity, use_spect): + super(ADAINHourglass, self).__init__() + self.encoder = ADAINEncoder(image_nc, pose_nc, ngf, img_f, encoder_layers, nonlinearity, use_spect) + self.decoder = ADAINDecoder(pose_nc, ngf, img_f, encoder_layers, decoder_layers, True, nonlinearity, use_spect) + self.output_nc = self.decoder.output_nc + + def forward(self, x, z): + return self.decoder(self.encoder(x, z), z) + + +class FineADAINLama(nn.Module): + def __init__(self, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): + super(FineADAINLama, self).__init__() + kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1} + self.actvn = nonlinearity + ratio_gin = 0.75 + ratio_gout = 0.75 + self.ffc = FFC(input_nc, input_nc, 3, + ratio_gin, ratio_gout, 1, 1, 1, + 1, False, False, padding_type='reflect') + global_channels = int(input_nc * ratio_gout) + self.bn_l = ADAIN(input_nc - global_channels, feature_nc) + self.bn_g = ADAIN(global_channels, feature_nc) + + def forward(self, x, z): + x_l, x_g = self.ffc(x) + x_l = self.actvn(self.bn_l(x_l,z)) + x_g = self.actvn(self.bn_g(x_g,z)) + return x_l, x_g + + +class FFCResnetBlock(nn.Module): + def __init__(self, dim, feature_dim, padding_type='reflect', norm_layer=BatchNorm2d, activation_layer=nn.ReLU, dilation=1, + spatial_transform_kwargs=None, inline=False, **conv_kwargs): + super().__init__() + self.conv1 = FineADAINLama(dim, feature_dim, **conv_kwargs) + self.conv2 = FineADAINLama(dim, feature_dim, **conv_kwargs) + self.inline = True + + def forward(self, x, z): + if self.inline: + x_l, x_g = x[:, :-self.conv1.ffc.global_in_num], x[:, -self.conv1.ffc.global_in_num:] + else: + x_l, x_g = x if type(x) is tuple else (x, 0) + + id_l, id_g = x_l, x_g + x_l, x_g = self.conv1((x_l, x_g), z) + x_l, x_g = self.conv2((x_l, x_g), z) + + x_l, x_g = id_l + x_l, id_g + x_g + out = x_l, x_g + if self.inline: + out = torch.cat(out, dim=1) + return out + + +class FFCADAINResBlocks(nn.Module): + def __init__(self, num_block, input_nc, feature_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): + super(FFCADAINResBlocks, self).__init__() + self.num_block = num_block + for i in range(num_block): + model = FFCResnetBlock(input_nc, feature_nc, norm_layer, nonlinearity, use_spect) + setattr(self, 'res'+str(i), model) + + def forward(self, x, z): + for i in range(self.num_block): + model = getattr(self, 'res'+str(i)) + x = model(x, z) + return x + + +class Jump(nn.Module): + def __init__(self, input_nc, norm_layer=nn.BatchNorm2d, nonlinearity=nn.LeakyReLU(), use_spect=False): + super(Jump, self).__init__() + kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1} + conv = spectral_norm(nn.Conv2d(input_nc, input_nc, **kwargs), use_spect) + if type(norm_layer) == type(None): + self.model = nn.Sequential(conv, nonlinearity) + else: + self.model = nn.Sequential(conv, norm_layer(input_nc), nonlinearity) + + def forward(self, x): + out = self.model(x) + return out + + +class FinalBlock2d(nn.Module): + def __init__(self, input_nc, output_nc, use_spect=False, tanh_or_sigmoid='tanh'): + super(FinalBlock2d, self).__init__() + kwargs = {'kernel_size': 7, 'stride': 1, 'padding':3} + conv = spectral_norm(nn.Conv2d(input_nc, output_nc, **kwargs), use_spect) + if tanh_or_sigmoid == 'sigmoid': + out_nonlinearity = nn.Sigmoid() + else: + out_nonlinearity = nn.Tanh() + self.model = nn.Sequential(conv, out_nonlinearity) + + def forward(self, x): + out = self.model(x) + return out + + +class ModulatedConv2d(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + num_style_feat, + demodulate=True, + sample_mode=None, + eps=1e-8): + super(ModulatedConv2d, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.demodulate = demodulate + self.sample_mode = sample_mode + self.eps = eps + + # modulation inside each modulated conv + self.modulation = nn.Linear(num_style_feat, in_channels, bias=True) + # initialization + default_init_weights(self.modulation, scale=1, bias_fill=1, a=0, mode='fan_in', nonlinearity='linear') + + self.weight = nn.Parameter( + torch.randn(1, out_channels, in_channels, kernel_size, kernel_size) / + math.sqrt(in_channels * kernel_size**2)) + self.padding = kernel_size // 2 + + def forward(self, x, style): + b, c, h, w = x.shape + style = self.modulation(style).view(b, 1, c, 1, 1) + weight = self.weight * style + + if self.demodulate: + demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps) + weight = weight * demod.view(b, self.out_channels, 1, 1, 1) + + weight = weight.view(b * self.out_channels, c, self.kernel_size, self.kernel_size) + + # upsample or downsample if necessary + if self.sample_mode == 'upsample': + x = F.interpolate(x, scale_factor=2, mode='bilinear', align_corners=False) + elif self.sample_mode == 'downsample': + x = F.interpolate(x, scale_factor=0.5, mode='bilinear', align_corners=False) + + b, c, h, w = x.shape + x = x.view(1, b * c, h, w) + out = F.conv2d(x, weight, padding=self.padding, groups=b) + out = out.view(b, self.out_channels, *out.shape[2:4]) + return out + + def __repr__(self): + return (f'{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, ' + f'kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})') + + +class StyleConv(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, num_style_feat, demodulate=True, sample_mode=None): + super(StyleConv, self).__init__() + self.modulated_conv = ModulatedConv2d( + in_channels, out_channels, kernel_size, num_style_feat, demodulate=demodulate, sample_mode=sample_mode) + self.weight = nn.Parameter(torch.zeros(1)) # for noise injection + self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1)) + self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True) + + def forward(self, x, style, noise=None): + # modulate + out = self.modulated_conv(x, style) * 2**0.5 # for conversion + # noise injection + if noise is None: + b, _, h, w = out.shape + noise = out.new_empty(b, 1, h, w).normal_() + out = out + self.weight * noise + # add bias + out = out + self.bias + # activation + out = self.activate(out) + return out + + +class ToRGB(nn.Module): + def __init__(self, in_channels, num_style_feat, upsample=True): + super(ToRGB, self).__init__() + self.upsample = upsample + self.modulated_conv = ModulatedConv2d( + in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None) + self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1)) + + def forward(self, x, style, skip=None): + out = self.modulated_conv(x, style) + out = out + self.bias + if skip is not None: + if self.upsample: + skip = F.interpolate(skip, scale_factor=2, mode='bilinear', align_corners=False) + out = out + skip + return out \ No newline at end of file diff --git a/models/ffc.py b/models/ffc.py new file mode 100644 index 0000000000000000000000000000000000000000..89a5c4c09dc5f3e739a3ee9446225a738e0de97a --- /dev/null +++ b/models/ffc.py @@ -0,0 +1,233 @@ +# Fast Fourier Convolution NeurIPS 2020 +# original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py +# paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf + +import torch +import torch.nn as nn +import torch.nn.functional as F +# from models.modules.squeeze_excitation import SELayer +import torch.fft + +class SELayer(nn.Module): + def __init__(self, channel, reduction=16): + super(SELayer, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.fc = nn.Sequential( + nn.Linear(channel, channel // reduction, bias=False), + nn.ReLU(inplace=True), + nn.Linear(channel // reduction, channel, bias=False), + nn.Sigmoid() + ) + + def forward(self, x): + b, c, _, _ = x.size() + y = self.avg_pool(x).view(b, c) + y = self.fc(y).view(b, c, 1, 1) + res = x * y.expand_as(x) + return res + + +class FFCSE_block(nn.Module): + def __init__(self, channels, ratio_g): + super(FFCSE_block, self).__init__() + in_cg = int(channels * ratio_g) + in_cl = channels - in_cg + r = 16 + + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.conv1 = nn.Conv2d(channels, channels // r, + kernel_size=1, bias=True) + self.relu1 = nn.ReLU(inplace=True) + self.conv_a2l = None if in_cl == 0 else nn.Conv2d( + channels // r, in_cl, kernel_size=1, bias=True) + self.conv_a2g = None if in_cg == 0 else nn.Conv2d( + channels // r, in_cg, kernel_size=1, bias=True) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + x = x if type(x) is tuple else (x, 0) + id_l, id_g = x + + x = id_l if type(id_g) is int else torch.cat([id_l, id_g], dim=1) + x = self.avgpool(x) + x = self.relu1(self.conv1(x)) + + x_l = 0 if self.conv_a2l is None else id_l * \ + self.sigmoid(self.conv_a2l(x)) + x_g = 0 if self.conv_a2g is None else id_g * \ + self.sigmoid(self.conv_a2g(x)) + return x_l, x_g + + +class FourierUnit(nn.Module): + + def __init__(self, in_channels, out_channels, groups=1, spatial_scale_factor=None, spatial_scale_mode='bilinear', + spectral_pos_encoding=False, use_se=False, se_kwargs=None, ffc3d=False, fft_norm='ortho'): + # bn_layer not used + super(FourierUnit, self).__init__() + self.groups = groups + + self.conv_layer = torch.nn.Conv2d(in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0), + out_channels=out_channels * 2, + kernel_size=1, stride=1, padding=0, groups=self.groups, bias=False) + self.bn = torch.nn.BatchNorm2d(out_channels * 2) + self.relu = torch.nn.ReLU(inplace=True) + + # squeeze and excitation block + self.use_se = use_se + if use_se: + if se_kwargs is None: + se_kwargs = {} + self.se = SELayer(self.conv_layer.in_channels, **se_kwargs) + + self.spatial_scale_factor = spatial_scale_factor + self.spatial_scale_mode = spatial_scale_mode + self.spectral_pos_encoding = spectral_pos_encoding + self.ffc3d = ffc3d + self.fft_norm = fft_norm + + def forward(self, x): + batch = x.shape[0] + + if self.spatial_scale_factor is not None: + orig_size = x.shape[-2:] + x = F.interpolate(x, scale_factor=self.spatial_scale_factor, mode=self.spatial_scale_mode, align_corners=False) + + r_size = x.size() + # (batch, c, h, w/2+1, 2) + fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1) + ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm) + ffted = torch.stack((ffted.real, ffted.imag), dim=-1) + ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1) + ffted = ffted.view((batch, -1,) + ffted.size()[3:]) + + if self.spectral_pos_encoding: + height, width = ffted.shape[-2:] + coords_vert = torch.linspace(0, 1, height)[None, None, :, None].expand(batch, 1, height, width).to(ffted) + coords_hor = torch.linspace(0, 1, width)[None, None, None, :].expand(batch, 1, height, width).to(ffted) + ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1) + + if self.use_se: + ffted = self.se(ffted) + + ffted = self.conv_layer(ffted) # (batch, c*2, h, w/2+1) + ffted = self.relu(self.bn(ffted)) + + ffted = ffted.view((batch, -1, 2,) + ffted.size()[2:]).permute( + 0, 1, 3, 4, 2).contiguous() # (batch,c, t, h, w/2+1, 2) + ffted = torch.complex(ffted[..., 0], ffted[..., 1]) + + ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:] + output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm) + + if self.spatial_scale_factor is not None: + output = F.interpolate(output, size=orig_size, mode=self.spatial_scale_mode, align_corners=False) + + return output + + +class SpectralTransform(nn.Module): + def __init__(self, in_channels, out_channels, stride=1, groups=1, enable_lfu=True, **fu_kwargs): + # bn_layer not used + super(SpectralTransform, self).__init__() + self.enable_lfu = enable_lfu + if stride == 2: + self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2) + else: + self.downsample = nn.Identity() + + self.stride = stride + self.conv1 = nn.Sequential( + nn.Conv2d(in_channels, out_channels // + 2, kernel_size=1, groups=groups, bias=False), + nn.BatchNorm2d(out_channels // 2), + nn.ReLU(inplace=True) + ) + self.fu = FourierUnit( + out_channels // 2, out_channels // 2, groups, **fu_kwargs) + if self.enable_lfu: + self.lfu = FourierUnit( + out_channels // 2, out_channels // 2, groups) + self.conv2 = torch.nn.Conv2d( + out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False) + + def forward(self, x): + x = self.downsample(x) + x = self.conv1(x) + output = self.fu(x) + + if self.enable_lfu: + n, c, h, w = x.shape + split_no = 2 + split_s = h // split_no + xs = torch.cat(torch.split( + x[:, :c // 4], split_s, dim=-2), dim=1).contiguous() + xs = torch.cat(torch.split(xs, split_s, dim=-1), + dim=1).contiguous() + xs = self.lfu(xs) + xs = xs.repeat(1, 1, split_no, split_no).contiguous() + else: + xs = 0 + + output = self.conv2(x + output + xs) + return output + + +class FFC(nn.Module): + + def __init__(self, in_channels, out_channels, kernel_size, + ratio_gin, ratio_gout, stride=1, padding=0, + dilation=1, groups=1, bias=False, enable_lfu=True, + padding_type='reflect', gated=False, **spectral_kwargs): + super(FFC, self).__init__() + + assert stride == 1 or stride == 2, "Stride should be 1 or 2." + self.stride = stride + + in_cg = int(in_channels * ratio_gin) + in_cl = in_channels - in_cg + out_cg = int(out_channels * ratio_gout) + out_cl = out_channels - out_cg + + self.ratio_gin = ratio_gin + self.ratio_gout = ratio_gout + self.global_in_num = in_cg + + module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d + self.convl2l = module(in_cl, out_cl, kernel_size, + stride, padding, dilation, groups, bias, padding_mode=padding_type) + module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d + self.convl2g = module(in_cl, out_cg, kernel_size, + stride, padding, dilation, groups, bias, padding_mode=padding_type) + module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d + self.convg2l = module(in_cg, out_cl, kernel_size, + stride, padding, dilation, groups, bias, padding_mode=padding_type) + module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform + self.convg2g = module( + in_cg, out_cg, stride, 1 if groups == 1 else groups // 2, enable_lfu, **spectral_kwargs) + + self.gated = gated + module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d + self.gate = module(in_channels, 2, 1) + + def forward(self, x): + x_l, x_g = x if type(x) is tuple else (x, 0) + out_xl, out_xg = 0, 0 + + if self.gated: + total_input_parts = [x_l] + if torch.is_tensor(x_g): + total_input_parts.append(x_g) + total_input = torch.cat(total_input_parts, dim=1) + + gates = torch.sigmoid(self.gate(total_input)) + g2l_gate, l2g_gate = gates.chunk(2, dim=1) + else: + g2l_gate, l2g_gate = 1, 1 + + if self.ratio_gout != 1: + out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate + if self.ratio_gout != 0: + out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g) + + return out_xl, out_xg \ No newline at end of file diff --git a/models/transformer.py b/models/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..cf582c760432b6ad6cf8988e5155814b28b33107 --- /dev/null +++ b/models/transformer.py @@ -0,0 +1,119 @@ +import torch +from torch import nn + +from einops import rearrange + +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + + +class GELU(nn.Module): + def __init__(self): + super(GELU, self).__init__() + def forward(self, x): + return 0.5*x*(1+F.tanh(np.sqrt(2/np.pi)*(x+0.044715*torch.pow(x,3)))) + +# helpers + +def pair(t): + return t if isinstance(t, tuple) else (t, t) + +# classes + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + +class DualPreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.normx = nn.LayerNorm(dim) + self.normy = nn.LayerNorm(dim) + self.fn = fn + def forward(self, x, y, **kwargs): + return self.fn(self.normx(x), self.normy(y), **kwargs) + +class FeedForward(nn.Module): + def __init__(self, dim, hidden_dim, dropout = 0.): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout) + ) + def forward(self, x): + return self.net(x) + +class Attention(nn.Module): + def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head ** -0.5 + + self.attend = nn.Softmax(dim = -1) + + self.to_q = nn.Linear(dim, inner_dim, bias = False) + self.to_k = nn.Linear(dim, inner_dim, bias = False) + self.to_v = nn.Linear(dim, inner_dim, bias = False) + + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, dim), + nn.Dropout(dropout) + ) if project_out else nn.Identity() + + def forward(self, x, y): + # qk = self.to_qk(x).chunk(2, dim = -1) # + q = rearrange(self.to_q(x), 'b n (h d) -> b h n d', h = self.heads) # q,k from the zero feature + k = rearrange(self.to_k(x), 'b n (h d) -> b h n d', h = self.heads) # v from the reference features + v = rearrange(self.to_v(y), 'b n (h d) -> b h n d', h = self.heads) + + dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale + + attn = self.attend(dots) + + out = torch.matmul(attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + return self.to_out(out) + +class Transformer(nn.Module): + def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): + super().__init__() + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append(nn.ModuleList([ + DualPreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)) + ])) + + + def forward(self, x, y): # x is the cropped, y is the foreign reference + bs,c,h,w = x.size() + + # img to embedding + x = x.view(bs,c,-1).permute(0,2,1) + y = y.view(bs,c,-1).permute(0,2,1) + + for attn, ff in self.layers: + x = attn(x, y) + x + x = ff(x) + x + + x = x.view(bs,h,w,c).permute(0,3,1,2) + return x + +class RETURNX(nn.Module): + def __init__(self,): + super().__init__() + + def forward(self, x, y): # x is the cropped, y is the foreign reference + return x \ No newline at end of file diff --git a/results/1.mp4 b/results/1.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..95b894173405e00e97fc82a8e878589bdb3437d3 Binary files /dev/null and b/results/1.mp4 differ