''' not exactly the same as the official repo but the results are good '''
import sys
import os

sys.path.append(os.getcwd())

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from nets.layers import SeqEncoder1D, SeqTranslator1D

""" from https://github.com/ai4r/Gesture-Generation-from-Trimodal-Context.git """


class Conv2d_tf(nn.Conv2d):
    """
    Conv2d with the padding behavior from TF
    from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
    """

    def __init__(self, *args, **kwargs):
        super(Conv2d_tf, self).__init__(*args, **kwargs)
        self.padding = kwargs.get("padding", "SAME")

    def _compute_padding(self, input, dim):
        input_size = input.size(dim + 2)
        filter_size = self.weight.size(dim + 2)
        effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
        out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
        total_padding = max(
            0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
        )
        additional_padding = int(total_padding % 2 != 0)

        return additional_padding, total_padding

    def forward(self, input):
        if self.padding == "VALID":
            return F.conv2d(
                input,
                self.weight,
                self.bias,
                self.stride,
                padding=0,
                dilation=self.dilation,
                groups=self.groups,
            )
        rows_odd, padding_rows = self._compute_padding(input, dim=0)
        cols_odd, padding_cols = self._compute_padding(input, dim=1)
        if rows_odd or cols_odd:
            input = F.pad(input, [0, cols_odd, 0, rows_odd])

        return F.conv2d(
            input,
            self.weight,
            self.bias,
            self.stride,
            padding=(padding_rows // 2, padding_cols // 2),
            dilation=self.dilation,
            groups=self.groups,
        )


class Conv1d_tf(nn.Conv1d):
    """
    Conv1d with the padding behavior from TF
    modified from https://github.com/mlperf/inference/blob/482f6a3beb7af2fb0bd2d91d6185d5e71c22c55f/others/edge/object_detection/ssd_mobilenet/pytorch/utils.py
    """

    def __init__(self, *args, **kwargs):
        super(Conv1d_tf, self).__init__(*args, **kwargs)
        self.padding = kwargs.get("padding")

    def _compute_padding(self, input, dim):
        input_size = input.size(dim + 2)
        filter_size = self.weight.size(dim + 2)
        effective_filter_size = (filter_size - 1) * self.dilation[dim] + 1
        out_size = (input_size + self.stride[dim] - 1) // self.stride[dim]
        total_padding = max(
            0, (out_size - 1) * self.stride[dim] + effective_filter_size - input_size
        )
        additional_padding = int(total_padding % 2 != 0)

        return additional_padding, total_padding

    def forward(self, input):
        # if self.padding == "valid":
        #     return F.conv1d(
        #         input,
        #         self.weight,
        #         self.bias,
        #         self.stride,
        #         padding=0,
        #         dilation=self.dilation,
        #         groups=self.groups,
        #     )
        rows_odd, padding_rows = self._compute_padding(input, dim=0)
        if rows_odd:
            input = F.pad(input, [0, rows_odd])

        return F.conv1d(
            input,
            self.weight,
            self.bias,
            self.stride,
            padding=(padding_rows // 2),
            dilation=self.dilation,
            groups=self.groups,
        )


def ConvNormRelu(in_channels, out_channels, type='1d', downsample=False, k=None, s=None,
                 padding='valid', groups=1, nonlinear='lrelu', bn='bn'):
    if k is None and s is None:
        if not downsample:
            k = 3
            s = 1
            padding = 'same'
        else:
            k = 4
            s = 2
            padding = 'valid'

    if type == '1d':
        conv_block = Conv1d_tf(in_channels, out_channels, kernel_size=k, stride=s,
                               padding=padding, groups=groups)
        norm_block = nn.BatchNorm1d(out_channels)
    elif type == '2d':
        conv_block = Conv2d_tf(in_channels, out_channels, kernel_size=k, stride=s,
                               padding=padding, groups=groups)
        norm_block = nn.BatchNorm2d(out_channels)
    else:
        assert False

    if bn != 'bn':
        if bn == 'gn':
            norm_block = nn.GroupNorm(1, out_channels)
        elif bn == 'ln':
            norm_block = nn.LayerNorm(out_channels)
        else:
            norm_block = nn.Identity()

    if nonlinear == 'lrelu':
        nlinear = nn.LeakyReLU(0.2, True)
    elif nonlinear == 'tanh':
        nlinear = nn.Tanh()
    elif nonlinear == 'none':
        nlinear = nn.Identity()

    return nn.Sequential(
        conv_block,
        norm_block,
        nlinear
    )
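

# Illustrative sketch, not part of the original repo: with the TF-style "SAME"
# padding above, a stride-s convolution returns ceil(in_len / s) frames, so the
# downsample block (k=4, s=2) halves odd-length sequences cleanly and the
# non-downsample block (k=3, s=1) keeps the length. The channel sizes and the
# sequence length below are arbitrary example values; nothing in the model
# calls this helper.
def _demo_conv_norm_relu_shapes():
    x = torch.randn(2, 8, 15)                           # (batch, channels, time)
    keep = ConvNormRelu(8, 16, '1d', downsample=False)  # k=3, s=1 -> length kept
    half = ConvNormRelu(8, 16, '1d', downsample=True)   # k=4, s=2 -> ceil(15 / 2)
    assert keep(x).shape == (2, 16, 15)
    assert half(x).shape == (2, 16, 8)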
class UnetUp(nn.Module):
    def __init__(self, in_ch, out_ch):
        super(UnetUp, self).__init__()
        self.conv = ConvNormRelu(in_ch, out_ch)

    def forward(self, x1, x2):
        # x1 = torch.repeat_interleave(x1, 2, dim=2)
        # x1 = x1[:, :, :x2.shape[2]]
        # upsample x1 to the temporal length of the skip connection, then fuse by addition
        x1 = torch.nn.functional.interpolate(x1, size=x2.shape[2], mode='linear')
        x = x1 + x2
        x = self.conv(x)
        return x


class UNet(nn.Module):
    def __init__(self, input_dim, dim):
        super(UNet, self).__init__()
        # dim = 512
        self.down1 = nn.Sequential(
            ConvNormRelu(input_dim, input_dim, '1d', False),
            ConvNormRelu(input_dim, dim, '1d', False),
            ConvNormRelu(dim, dim, '1d', False)
        )
        self.gru = nn.GRU(dim, dim, 1, batch_first=True)
        self.down2 = ConvNormRelu(dim, dim, '1d', True)
        self.down3 = ConvNormRelu(dim, dim, '1d', True)
        self.down4 = ConvNormRelu(dim, dim, '1d', True)
        self.down5 = ConvNormRelu(dim, dim, '1d', True)
        self.down6 = ConvNormRelu(dim, dim, '1d', True)
        self.up1 = UnetUp(dim, dim)
        self.up2 = UnetUp(dim, dim)
        self.up3 = UnetUp(dim, dim)
        self.up4 = UnetUp(dim, dim)
        self.up5 = UnetUp(dim, dim)

    def forward(self, x1, pre_pose=None, w_pre=False):
        x2_0 = self.down1(x1)
        if w_pre:
            # seed the first frame of the encoding with the previous pose via the GRU
            i = 1
            x2_pre = self.gru(x2_0[:, :, 0:i].permute(0, 2, 1),
                              pre_pose[:, :, -1:].permute(2, 0, 1).contiguous())[0].permute(0, 2, 1)
            x2 = torch.cat([x2_pre, x2_0[:, :, i:]], dim=-1)
            # x2 = torch.cat([pre_pose, x2_0], dim=2)  # [B, 512, 15]
        else:
            # x2 = self.gru(x2_0.transpose(1, 2))[0].transpose(1, 2)
            x2 = x2_0
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x6 = self.down5(x5)
        x7 = self.down6(x6)
        x = self.up1(x7, x6)
        x = self.up2(x, x5)
        x = self.up3(x, x4)
        x = self.up4(x, x3)
        x = self.up5(x, x2)  # [B, 512, 15]
        return x, x2_0
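

# Illustrative sketch, not part of the original repo: the 1D U-Net above
# downsamples five times and upsamples back with additive skip connections, so
# its output keeps the temporal length of the input; it also returns the
# encoding x2_0 taken before the downsampling path. The sizes below
# (192 channels, 34 frames) are arbitrary example values.
def _demo_unet_shapes():
    net = UNet(input_dim=192, dim=256)
    x = torch.randn(2, 192, 34)        # (batch, feature, time)
    y, x2_0 = net(x)
    assert y.shape == (2, 256, 34)     # fused decoder features, same length as input
    assert x2_0.shape == (2, 256, 34)  # encoding produced by down1, before the GRU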
class AudioEncoder(nn.Module):
    def __init__(self, n_frames, template_length, pose=False, common_dim=512):
        super().__init__()
        self.n_frames = n_frames
        self.pose = pose
        self.step = 0
        self.weight = 0
        if self.pose:
            # self.first_net = nn.Sequential(
            #     ConvNormRelu(1, 64, '2d', False),
            #     ConvNormRelu(64, 64, '2d', True),
            #     ConvNormRelu(64, 128, '2d', False),
            #     ConvNormRelu(128, 128, '2d', True),
            #     ConvNormRelu(128, 256, '2d', False),
            #     ConvNormRelu(256, 256, '2d', True),
            #     ConvNormRelu(256, 256, '2d', False),
            #     ConvNormRelu(256, 256, '2d', False, padding='VALID')
            # )
            # decoder_layer = nn.TransformerDecoderLayer(d_model=args.feature_dim, nhead=4,
            #                                            dim_feedforward=2 * args.feature_dim, batch_first=True)
            # a = nn.TransformerDecoder
            self.first_net = SeqTranslator1D(256, 256,
                                             min_layers_num=4,
                                             residual=True
                                             )
            self.dropout_0 = nn.Dropout(0.1)
            self.mu_fc = nn.Conv1d(256, 128, 1, 1)
            self.var_fc = nn.Conv1d(256, 128, 1, 1)
            self.trans_motion = SeqTranslator1D(common_dim, common_dim,
                                                kernel_size=1,
                                                stride=1,
                                                min_layers_num=3,
                                                residual=True
                                                )
            # self.att = nn.MultiheadAttention(64 + template_length, 4, dropout=0.1)
            self.unet = UNet(128 + template_length, common_dim)
        else:
            self.first_net = SeqTranslator1D(256, 256,
                                             min_layers_num=4,
                                             residual=True
                                             )
            self.dropout_0 = nn.Dropout(0.1)
            # self.att = nn.MultiheadAttention(256, 4, dropout=0.1)
            self.unet = UNet(256, 256)
        self.dropout_1 = nn.Dropout(0.0)

    def forward(self, spectrogram, time_steps=None, template=None, pre_pose=None, w_pre=False):
        self.step = self.step + 1

        if self.pose:
            # pose branch: sample a 128-dim audio latent and fuse it with the template
            spect = spectrogram.transpose(1, 2)
            if w_pre:
                spect = spect[:, :, :]
            out = self.first_net(spect)
            out = self.dropout_0(out)
            mu = self.mu_fc(out)
            var = self.var_fc(out)
            audio = self.__reparam(mu, var)
            # audio = out

            # template = self.trans_motion(template)
            x1 = torch.cat([audio, template], dim=1)  # .permute(2,0,1)
            # x1 = out
            # x1, _ = self.att(x1, x1, x1)
            # x1 = x1.permute(1,2,0)
            x1, x2_0 = self.unet(x1, pre_pose=pre_pose, w_pre=w_pre)
        else:
            # speech branch: deterministic features, no latent sampling
            spectrogram = spectrogram.transpose(1, 2)
            x1 = self.first_net(spectrogram)  # .permute(2,0,1)
            # out, _ = self.att(out, out, out)
            # out = out.permute(1, 2, 0)
            x1 = self.dropout_0(x1)
            x1, x2_0 = self.unet(x1)
            x1 = self.dropout_1(x1)
            mu = None
            var = None

        return x1, (mu, var), x2_0

    def __reparam(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)  # was randn_like(std, device='cuda'); sample on std's device
        z = eps * std + mu
        return z


class Generator(nn.Module):
    def __init__(self,
                 n_poses,
                 pose_dim,
                 pose,
                 n_pre_poses,
                 each_dim: list,
                 dim_list: list,
                 use_template=False,
                 template_length=0,
                 training=False,
                 device=None,
                 separate=False,
                 expression=False
                 ):
        super().__init__()

        self.use_template = use_template
        self.template_length = template_length
        self.training = training
        self.device = device
        self.separate = separate
        self.pose = pose
        self.decoderf = True
        self.expression = expression
        common_dim = 256

        if self.use_template:
            assert template_length > 0
            # self.KLLoss = KLLoss(kl_tolerance=self.config.Train.weights.kl_tolerance).to(self.device)
            # self.pose_encoder = SeqEncoder1D(
            #     C_in=pose_dim,
            #     C_out=512,
            #     T_in=n_poses,
            #     min_layer_nums=6
            # )
            self.pose_encoder = SeqTranslator1D(pose_dim - 50, common_dim,
                                                # kernel_size=1,
                                                # stride=1,
                                                min_layers_num=3,
                                                residual=True
                                                )
            self.mu_fc = nn.Conv1d(common_dim, template_length, kernel_size=1, stride=1)
            self.var_fc = nn.Conv1d(common_dim, template_length, kernel_size=1, stride=1)
        else:
            self.template_length = 0

        self.gen_length = n_poses

        self.audio_encoder = AudioEncoder(n_poses, template_length, True, common_dim)
        self.speech_encoder = AudioEncoder(n_poses, template_length, False)
        # self.pre_pose_encoder = SeqEncoder1D(
        #     C_in=pose_dim,
        #     C_out=128,
        #     T_in=15,
        #     min_layer_nums=3
        # )
        # self.pmu_fc = nn.Linear(128, 64)
        # self.pvar_fc = nn.Linear(128, 64)
        self.pre_pose_encoder = SeqTranslator1D(pose_dim - 50, common_dim,
                                                min_layers_num=5,
                                                residual=True
                                                )

        self.decoder_in = 256 + 64
        self.dim_list = dim_list

        if self.separate:
            self.decoder = nn.ModuleList()
            self.final_out = nn.ModuleList()

            self.decoder.append(nn.Sequential(
                ConvNormRelu(256, 64),
                ConvNormRelu(64, 64),
                ConvNormRelu(64, 64),
            ))
            self.final_out.append(nn.Conv1d(64, each_dim[0], 1, 1))

            self.decoder.append(nn.Sequential(
                ConvNormRelu(common_dim, common_dim),
                ConvNormRelu(common_dim, common_dim),
                ConvNormRelu(common_dim, common_dim),
            ))
            self.final_out.append(nn.Conv1d(common_dim, each_dim[1], 1, 1))

            self.decoder.append(nn.Sequential(
                ConvNormRelu(common_dim, common_dim),
                ConvNormRelu(common_dim, common_dim),
                ConvNormRelu(common_dim, common_dim),
            ))
            self.final_out.append(nn.Conv1d(common_dim, each_dim[2], 1, 1))

            if self.expression:
                self.decoder.append(nn.Sequential(
                    ConvNormRelu(256, 256),
                    ConvNormRelu(256, 256),
                    ConvNormRelu(256, 256),
                ))
                self.final_out.append(nn.Conv1d(256, each_dim[3], 1, 1))
        else:
            self.decoder = nn.Sequential(
                ConvNormRelu(self.decoder_in, 512),
                ConvNormRelu(512, 512),
                ConvNormRelu(512, 512),
                ConvNormRelu(512, 512),
                ConvNormRelu(512, 512),
                ConvNormRelu(512, 512),
            )
            self.final_out = nn.Conv1d(512, pose_dim, 1, 1)

    def __reparam(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std, device=self.device)
        z = eps * std + mu
        return z

    def forward(self, in_spec, pre_poses, gt_poses, template=None, time_steps=None, w_pre=False, norm=True):
        if time_steps is not None:
            self.gen_length = time_steps

        pre_pose = None  # default so branches that never set it do not raise NameError
        if self.use_template:
            if self.training:
                # training: reparameterize the template from the ground-truth pose encoding
                if w_pre:
                    in_spec = in_spec[:, 15:, :]
                    pre_pose = self.pre_pose_encoder(gt_poses[:, 14:15, :-50].permute(0, 2, 1))
                    pose_enc = self.pose_encoder(gt_poses[:, 15:, :-50].permute(0, 2, 1))
                    mu = self.mu_fc(pose_enc)
                    var = self.var_fc(pose_enc)
                    template = self.__reparam(mu, var)
                else:
                    pre_pose = None
                    pose_enc = self.pose_encoder(gt_poses[:, :, :-50].permute(0, 2, 1))
                    mu = self.mu_fc(pose_enc)
                    var = self.var_fc(pose_enc)
                    template = self.__reparam(mu, var)
            elif pre_poses is not None:
                # inference: draw the template from the prior, optionally conditioning on the last seen pose
                if w_pre:
                    pre_pose = pre_poses[:, -1:, :-50]
                    if norm:
                        pre_pose = pre_pose.reshape(1, -1, 55, 5)
                        pre_pose = torch.cat([F.normalize(pre_pose[..., :3], dim=-1),
                                              F.normalize(pre_pose[..., 3:5], dim=-1)],
                                             dim=-1).reshape(1, -1, 275)
                    pre_pose = self.pre_pose_encoder(pre_pose.permute(0, 2, 1))
                    template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(
                        in_spec.device)
                else:
                    pre_pose = None
                    template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(
                        in_spec.device)
            elif gt_poses is not None:
                template = self.pre_pose_encoder(gt_poses[:, :, :-50].permute(0, 2, 1))
            elif template is None:
                pre_pose = None
                template = torch.randn([in_spec.shape[0], self.template_length, self.gen_length]).to(
                    in_spec.device)
        else:
            template = None
            mu = None
            var = None

        a_t_f, (mu2, var2), x2_0 = self.audio_encoder(in_spec, time_steps=time_steps, template=template,
                                                      pre_pose=pre_pose, w_pre=w_pre)
        s_f, _, _ = self.speech_encoder(in_spec, time_steps=time_steps)

        out = []

        if self.separate:
            # decoders 0 and 3 read the speech-only features; the others read the audio+template features
            for i in range(len(self.decoder)):
                if i == 0 or i == 3:
                    mid = self.decoder[i](s_f)
                else:
                    mid = self.decoder[i](a_t_f)
                mid = self.final_out[i](mid)
                out.append(mid)
            out = torch.cat(out, dim=1)
        else:
            out = self.decoder(a_t_f)
            out = self.final_out(out)

        out = out.transpose(1, 2)

        if self.training:
            if w_pre:
                return out, template, mu, var, (mu2, var2, x2_0, pre_pose)
            else:
                return out, template, mu, var, (mu2, var2, None, None)
        else:
            return out
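

# Illustrative sketch, not part of the original repo: both AudioEncoder and
# Generator sample their latent with the reparameterization trick implemented
# in the private __reparam methods above, i.e. z = mu + eps * exp(0.5 * log_var)
# with eps ~ N(0, I), so gradients flow through mu / log_var while the sample
# stays stochastic. The tensor sizes below are arbitrary example values.
def _demo_reparam_trick():
    mu = torch.zeros(2, 64, 15)
    log_var = torch.zeros(2, 64, 15)         # log-variance 0 -> std 1
    eps = torch.randn_like(mu)
    z = eps * torch.exp(0.5 * log_var) + mu  # same formula as __reparam
    assert z.shape == mu.shape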
class Discriminator(nn.Module):
    def __init__(self, pose_dim, pose):
        super().__init__()
        self.net = nn.Sequential(
            Conv1d_tf(pose_dim, 64, kernel_size=4, stride=2, padding='SAME'),
            nn.LeakyReLU(0.2, True),
            ConvNormRelu(64, 128, '1d', True),
            ConvNormRelu(128, 256, '1d', k=4, s=1),
            Conv1d_tf(256, 1, kernel_size=4, stride=1, padding='SAME'),
        )

    def forward(self, x):
        x = x.transpose(1, 2)
        out = self.net(x)
        return out


def main():
    d = Discriminator(275, 55)
    x = torch.randn([8, 60, 275])
    result = d(x)


if __name__ == "__main__":
    main()
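

# Illustrative sketch, not part of the original repo: the discriminator is fully
# convolutional along time, so it returns a score map over overlapping windows
# of the motion rather than a single value per clip (a 1D patch discriminator).
# With the shapes used in main(), the two stride-2 layers leave one score per
# 4 frames.
def _demo_patch_discriminator_scores():
    d = Discriminator(275, 55)
    fake = torch.randn(8, 60, 275)     # (batch, frames, pose channels), as in main()
    scores = d(fake)
    assert scores.shape == (8, 1, 15)  # 60 frames / (2 * 2 downsampling) = 15 patch scores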