|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
"""
|
|
from tm2t
|
|
TM2T: Stochastic and Tokenized Modeling for the Reciprocal Generation of 3D Human Motions and Texts
|
|
https://github.com/EricGuo5513/TM2T
|
|
Licensed under: https://github.com/EricGuo5513/TM2T/blob/main/LICENSE
|
|
"""
|
|
|
|
from .quantizer import *
|
|
from .layer import *
|
|
|
|
|
|
class SCFormer(nn.Module):
    """Spatial-then-temporal encoder over per-frame feature vectors.

    ``__init__`` builds a strided ``Conv1d``/``ResBlock`` stack in
    ``self.main``; ``forward`` does not use it and instead calls
    ``self.spatial_transformer_encoder`` and ``self.temporal_cnn_encoder``.

    NOTE(review): neither of those two sub-modules is created in this class,
    so ``forward`` raises ``AttributeError`` unless they are attached from
    outside -- confirm the intended wiring.
    """

    def __init__(self, args):
        """Build the convolutional stack stored in ``self.main``.

        Args:
            args: Config object providing ``vae_layer`` (number of stages),
                ``vae_length`` (channel width of every stage) and
                ``vae_test_dim`` (input channel count).
        """
        # The original called ``super(VQEncoderV3, self).__init__()``, which
        # raises TypeError because SCFormer is not a subclass of VQEncoderV3.
        super().__init__()

        n_down = args.vae_layer
        width = args.vae_length
        # Always at least one entry, so the assert rejects vae_layer < 1.
        channels = [width] + [width] * (n_down - 1)
        in_ch = args.vae_test_dim
        assert len(channels) == n_down

        stages = []
        for out_ch in channels:
            # kernel 4, stride 2, padding 1
            stages.append(nn.Conv1d(in_ch, out_ch, 4, 2, 1))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
            stages.append(ResBlock(out_ch))
            in_ch = out_ch
        self.main = nn.Sequential(*stages)

        self.main.apply(init_weight)

    def forward(self, inputs):
        """Encode along the feature axis, then along the time axis.

        Original author notes (their meaning is not evident from the code):
            face 51 or 106
            hand 30*(15)
            upper body
            lower body
            global 1*3
            max length around 180 --> 450

        Args:
            inputs: 3-D tensor unpacked as ``(bs, t, n)``.

        Returns:
            Tensor of shape ``(bs, ct, cs)`` where ``cs`` / ``ct`` are the
            second dimensions returned by the spatial / temporal encoders.
        """
        bs, t, n = inputs.shape

        # Spatial pass: every frame is an independent row of n features.
        per_frame = inputs.reshape(bs * t, n)
        per_frame = self.spatial_transformer_encoder(per_frame)
        cs = per_frame.shape[1]

        # Temporal pass: every (sample, spatial channel) pair is a row of t steps.
        per_channel = per_frame.reshape(bs, t, cs).permute(0, 2, 1).reshape(bs * cs, t)
        per_channel = self.temporal_cnn_encoder(per_channel)
        ct = per_channel.shape[1]

        return per_channel.reshape(bs, cs, ct).permute(0, 2, 1)
|
|
|
|
|
|
class VQEncoderV3(nn.Module):
    """Convolutional encoder in which every stage uses a stride-2 ``Conv1d``.

    Each of the ``args.vae_layer`` stages is ``Conv1d(kernel=4, stride=2,
    padding=1)`` -> ``LeakyReLU(0.2)`` -> ``ResBlock``, all
    ``args.vae_length`` channels wide.
    """

    def __init__(self, args):
        """Assemble the stages into ``self.main`` and apply ``init_weight``.

        Args:
            args: Config object providing ``vae_layer`` (stage count),
                ``vae_length`` (channel width) and ``vae_test_dim`` (input
                feature size).
        """
        super().__init__()
        n_down = args.vae_layer
        width = args.vae_length
        # Always at least one entry, so the assert rejects vae_layer < 1.
        channels = [width] + [width] * (n_down - 1)
        in_ch = args.vae_test_dim
        assert len(channels) == n_down

        stages = []
        for out_ch in channels:
            stages.append(nn.Conv1d(in_ch, out_ch, 4, 2, 1))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
            stages.append(ResBlock(out_ch))
            in_ch = out_ch
        self.main = nn.Sequential(*stages)

        self.main.apply(init_weight)

    def forward(self, inputs):
        """Run the stack on a ``(batch, length, vae_test_dim)`` tensor.

        The last two axes are swapped on the way in (``Conv1d`` expects
        channels first) and swapped back on the way out.
        """
        channels_first = inputs.permute(0, 2, 1)
        return self.main(channels_first).permute(0, 2, 1)
|
|
|
|
|
|
class VQEncoderV6(nn.Module):
    """Convolutional encoder built only from stride-1 ``Conv1d`` stages.

    Each of the ``args.vae_layer`` stages is ``Conv1d(kernel=3, stride=1,
    padding=1)`` -> ``LeakyReLU(0.2)`` -> ``ResBlock``, all
    ``args.vae_length`` channels wide.
    """

    def __init__(self, args):
        """Assemble the stages into ``self.main`` and apply ``init_weight``.

        Args:
            args: Config object providing ``vae_layer`` (stage count),
                ``vae_length`` (channel width) and ``vae_test_dim`` (input
                feature size).
        """
        super().__init__()
        n_down = args.vae_layer
        width = args.vae_length
        # Always at least one entry, so the assert rejects vae_layer < 1.
        channels = [width] + [width] * (n_down - 1)
        in_ch = args.vae_test_dim
        assert len(channels) == n_down

        stages = []
        for out_ch in channels:
            stages.append(nn.Conv1d(in_ch, out_ch, 3, 1, 1))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
            stages.append(ResBlock(out_ch))
            in_ch = out_ch
        self.main = nn.Sequential(*stages)

        self.main.apply(init_weight)

    def forward(self, inputs):
        """Run the stack on a ``(batch, length, vae_test_dim)`` tensor.

        The last two axes are swapped before and after ``self.main`` because
        ``Conv1d`` expects channels first.
        """
        channels_first = inputs.permute(0, 2, 1)
        return self.main(channels_first).permute(0, 2, 1)
|
|
|
|
|
|
class VQEncoderV4(nn.Module):
    """Convolutional encoder with one stride-2 stage followed by stride-1 stages.

    The first stage uses ``Conv1d(kernel=4, stride=2, padding=1)``; every
    later stage uses ``Conv1d(kernel=3, stride=1, padding=1)``. Each conv is
    followed by ``LeakyReLU(0.2)`` and a ``ResBlock``.
    """

    def __init__(self, args):
        """Assemble the stages into ``self.main`` and apply ``init_weight``.

        Args:
            args: Config object providing ``vae_layer`` (stage count),
                ``vae_length`` (channel width) and ``vae_test_dim`` (input
                feature size).
        """
        super().__init__()
        n_down = args.vae_layer
        width = args.vae_length
        # Always at least one entry, so the assert rejects vae_layer < 1.
        channels = [width] + [width] * (n_down - 1)
        in_ch = args.vae_test_dim
        assert len(channels) == n_down

        stages = []
        for idx, out_ch in enumerate(channels):
            # Only the very first convolution is strided.
            kernel, stride = (4, 2) if idx == 0 else (3, 1)
            stages.append(nn.Conv1d(in_ch, out_ch, kernel, stride, 1))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
            stages.append(ResBlock(out_ch))
            in_ch = out_ch
        self.main = nn.Sequential(*stages)

        self.main.apply(init_weight)

    def forward(self, inputs):
        """Run the stack on a ``(batch, length, vae_test_dim)`` tensor.

        The last two axes are swapped before and after ``self.main`` because
        ``Conv1d`` expects channels first.
        """
        channels_first = inputs.permute(0, 2, 1)
        return self.main(channels_first).permute(0, 2, 1)
|
|
|
|
|
|
class VQEncoderV5(nn.Module):
    """Length-preserving convolutional encoder (all convs are stride 1).

    Each of the ``args.vae_layer`` stages is ``Conv1d(kernel=3, stride=1,
    padding=1)`` -> ``LeakyReLU(0.2)`` -> ``ResBlock``, all
    ``args.vae_length`` channels wide.
    """

    def __init__(self, args):
        """Assemble the stages into ``self.main`` and apply ``init_weight``.

        Args:
            args: Config object providing ``vae_layer`` (stage count),
                ``vae_length`` (channel width) and ``vae_test_dim`` (input
                feature size).
        """
        super().__init__()
        n_down = args.vae_layer
        width = args.vae_length
        # Always at least one entry, so the assert rejects vae_layer < 1.
        channels = [width] + [width] * (n_down - 1)
        in_ch = args.vae_test_dim
        assert len(channels) == n_down

        stages = []
        for out_ch in channels:
            stages.append(nn.Conv1d(in_ch, out_ch, 3, 1, 1))
            stages.append(nn.LeakyReLU(0.2, inplace=True))
            stages.append(ResBlock(out_ch))
            in_ch = out_ch
        self.main = nn.Sequential(*stages)

        self.main.apply(init_weight)

    def forward(self, inputs):
        """Run the stack on a ``(batch, length, vae_test_dim)`` tensor.

        The last two axes are swapped before and after ``self.main`` because
        ``Conv1d`` expects channels first.
        """
        channels_first = inputs.permute(0, 2, 1)
        return self.main(channels_first).permute(0, 2, 1)
|
|
|
|
|
|
class VQDecoderV4(nn.Module):
    """Convolutional decoder that upsamples x2 in every stage except the last.

    Layout: two ``ResBlock`` layers, then ``args.vae_layer`` stages of
    nearest-neighbour ``Upsample`` -> ``Conv1d(k=3, s=1, p=1)`` ->
    ``LeakyReLU(0.2)``, then one final ``Conv1d``. The last stage's upsample
    has scale factor 1.
    """

    def __init__(self, args):
        """Assemble the layers into ``self.main`` and apply ``init_weight``.

        Args:
            args: Config object providing ``vae_layer`` (stage count),
                ``vae_length`` (latent/hidden width) and ``vae_test_dim``
                (output feature size).
        """
        super().__init__()
        n_up = args.vae_layer
        width = args.vae_length
        # n_up hidden widths followed by the output width.
        channels = [width] * (n_up - 1) + [width, args.vae_test_dim]
        input_size = args.vae_length
        n_resblk = 2
        assert len(channels) == n_up + 1

        layers = []
        if input_size != channels[0]:
            # Project the input to the hidden width when they differ.
            layers.append(nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1))
        layers.extend(ResBlock(channels[0]) for _ in range(n_resblk))

        last_stage = n_up - 1
        for idx, (in_ch, out_ch) in enumerate(zip(channels[:-1], channels[1:])):
            scale = 1 if idx >= last_stage else 2
            layers.append(nn.Upsample(scale_factor=scale, mode="nearest"))
            layers.append(nn.Conv1d(in_ch, out_ch, kernel_size=3, stride=1, padding=1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1))

        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

    def forward(self, inputs):
        """Decode a ``(batch, length, vae_length)`` tensor.

        The last two axes are swapped before and after ``self.main`` because
        ``Conv1d`` expects channels first.
        """
        channels_first = inputs.permute(0, 2, 1)
        return self.main(channels_first).permute(0, 2, 1)
|
|
|
|
|
|
class VQDecoderV5(nn.Module):
    """Length-preserving convolutional decoder (no upsampling layers).

    Layout: two ``ResBlock`` layers, then ``args.vae_layer`` stages of
    ``Conv1d(k=3, s=1, p=1)`` -> ``LeakyReLU(0.2)``, then one final
    ``Conv1d``.
    """

    def __init__(self, args):
        """Assemble the layers into ``self.main`` and apply ``init_weight``.

        Args:
            args: Config object providing ``vae_layer`` (stage count),
                ``vae_length`` (latent/hidden width) and ``vae_test_dim``
                (output feature size).
        """
        super().__init__()
        n_up = args.vae_layer
        width = args.vae_length
        # n_up hidden widths followed by the output width.
        channels = [width] * (n_up - 1) + [width, args.vae_test_dim]
        input_size = args.vae_length
        n_resblk = 2
        assert len(channels) == n_up + 1

        layers = []
        if input_size != channels[0]:
            # Project the input to the hidden width when they differ.
            layers.append(nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1))
        layers.extend(ResBlock(channels[0]) for _ in range(n_resblk))

        # The original computed an ``up_factor`` per stage but never used it
        # (this variant has no Upsample layer); the dead local is removed.
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers.append(nn.Conv1d(in_ch, out_ch, kernel_size=3, stride=1, padding=1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1))

        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

    def forward(self, inputs):
        """Decode a ``(batch, length, vae_length)`` tensor.

        The last two axes are swapped before and after ``self.main`` because
        ``Conv1d`` expects channels first.
        """
        channels_first = inputs.permute(0, 2, 1)
        return self.main(channels_first).permute(0, 2, 1)
|
|
|
|
|
|
class VQDecoderV7(nn.Module):
    """Length-preserving decoder whose output has ``vae_test_dim + 4`` features.

    Layout: two ``ResBlock`` layers, then ``args.vae_layer`` stages of
    ``Conv1d(k=3, s=1, p=1)`` -> ``LeakyReLU(0.2)``, then one final
    ``Conv1d``. No upsampling layers are used.
    """

    def __init__(self, args):
        """Assemble the layers into ``self.main`` and apply ``init_weight``.

        Args:
            args: Config object providing ``vae_layer`` (stage count),
                ``vae_length`` (latent/hidden width) and ``vae_test_dim``
                (the output width is this value plus 4).
        """
        super().__init__()
        n_up = args.vae_layer
        width = args.vae_length
        # n_up hidden widths followed by the (widened) output width.
        channels = [width] * (n_up - 1) + [width, args.vae_test_dim + 4]
        input_size = args.vae_length
        n_resblk = 2
        assert len(channels) == n_up + 1

        layers = []
        if input_size != channels[0]:
            # Project the input to the hidden width when they differ.
            layers.append(nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1))
        layers.extend(ResBlock(channels[0]) for _ in range(n_resblk))

        # The original computed an ``up_factor`` per stage but never used it
        # (this variant has no Upsample layer); the dead local is removed.
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers.append(nn.Conv1d(in_ch, out_ch, kernel_size=3, stride=1, padding=1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1))

        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

    def forward(self, inputs):
        """Decode a ``(batch, length, vae_length)`` tensor.

        The last two axes are swapped before and after ``self.main`` because
        ``Conv1d`` expects channels first.
        """
        channels_first = inputs.permute(0, 2, 1)
        return self.main(channels_first).permute(0, 2, 1)
|
|
|
|
|
|
class VQDecoderV3(nn.Module):
    """Convolutional decoder that upsamples x2 in every stage.

    Layout: two ``ResBlock`` layers, then ``args.vae_layer`` stages of
    nearest-neighbour ``Upsample(scale_factor=2)`` -> ``Conv1d(k=3, s=1,
    p=1)`` -> ``LeakyReLU(0.2)``, then one final ``Conv1d``.
    """

    def __init__(self, args):
        """Assemble the layers into ``self.main`` and apply ``init_weight``.

        Args:
            args: Config object providing ``vae_layer`` (stage count),
                ``vae_length`` (latent/hidden width) and ``vae_test_dim``
                (output feature size).
        """
        super().__init__()
        n_up = args.vae_layer
        width = args.vae_length
        # n_up hidden widths followed by the output width.
        channels = [width] * (n_up - 1) + [width, args.vae_test_dim]
        input_size = args.vae_length
        n_resblk = 2
        assert len(channels) == n_up + 1

        layers = []
        if input_size != channels[0]:
            # Project the input to the hidden width when they differ.
            layers.append(nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1))
        layers.extend(ResBlock(channels[0]) for _ in range(n_resblk))

        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers.append(nn.Upsample(scale_factor=2, mode="nearest"))
            layers.append(nn.Conv1d(in_ch, out_ch, kernel_size=3, stride=1, padding=1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1))

        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

    def forward(self, inputs):
        """Decode a ``(batch, length, vae_length)`` tensor.

        The last two axes are swapped before and after ``self.main`` because
        ``Conv1d`` expects channels first.
        """
        channels_first = inputs.permute(0, 2, 1)
        return self.main(channels_first).permute(0, 2, 1)
|
|
|
|
|
|
class VQDecoderV6(nn.Module):
    """Length-preserving decoder that takes a ``2 * vae_length``-wide input.

    Layout: an input ``Conv1d`` projecting ``2 * vae_length`` to the hidden
    width (added whenever the two differ), two ``ResBlock`` layers, then
    ``args.vae_layer`` stages of ``Conv1d(k=3, s=1, p=1)`` ->
    ``LeakyReLU(0.2)``, then one final ``Conv1d``.
    """

    def __init__(self, args):
        """Assemble the layers into ``self.main`` and apply ``init_weight``.

        Args:
            args: Config object providing ``vae_layer`` (stage count),
                ``vae_length`` (hidden width; the input is twice as wide) and
                ``vae_test_dim`` (output feature size).
        """
        super().__init__()
        n_up = args.vae_layer
        width = args.vae_length
        # n_up hidden widths followed by the output width.
        channels = [width] * (n_up - 1) + [width, args.vae_test_dim]
        input_size = args.vae_length * 2
        n_resblk = 2
        assert len(channels) == n_up + 1

        layers = []
        if input_size != channels[0]:
            layers.append(nn.Conv1d(input_size, channels[0], kernel_size=3, stride=1, padding=1))
        layers.extend(ResBlock(channels[0]) for _ in range(n_resblk))

        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers.append(nn.Conv1d(in_ch, out_ch, kernel_size=3, stride=1, padding=1))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.Conv1d(channels[-1], channels[-1], kernel_size=3, stride=1, padding=1))

        self.main = nn.Sequential(*layers)
        self.main.apply(init_weight)

    def forward(self, inputs):
        """Decode a ``(batch, length, 2 * vae_length)`` tensor.

        The last two axes are swapped before and after ``self.main`` because
        ``Conv1d`` expects channels first.
        """
        channels_first = inputs.permute(0, 2, 1)
        return self.main(channels_first).permute(0, 2, 1)
|
|
|
|
|
|
|
|
from .layer import reparameterize, ConvNormRelu, BasicBlock
|
|
|
|
"""
|
|
from Trimodal,
|
|
encoder:
|
|
bs, n, c_in --conv--> bs, n/k, c_out_0 --mlp--> bs, c_out_1, only support fixed length
|
|
decoder:
|
|
bs, c_out_1 --mlp--> bs, n/k*c_out_0 --> bs, n/k, c_out_0 --deconv--> bs, n, c_in
|
|
"""
|
|
|
|
|
|
class PoseEncoderConv(nn.Module):
    """Conv + MLP pose-sequence encoder with ``mu`` / ``logvar`` heads.

    A stack of ``ConvNormRelu`` blocks and one ``Conv1d`` is flattened and
    passed through an MLP; two linear heads produce ``mu`` and ``logvar``,
    each ``feature_length`` wide.
    """

    def __init__(self, length, dim, feature_length=32):
        """Build the conv stack, the MLP and the two output heads.

        Args:
            length: Not used by this constructor. The MLP input size is
                hard-coded to ``12 * feature_length``, so the conv stack must
                leave exactly 12 time steps -- presumably tied to one fixed
                sequence length; confirm against ``ConvNormRelu``.
            dim: Per-frame pose feature size (conv input channels).
            feature_length: Base channel width and latent size.
        """
        super().__init__()
        base = feature_length
        self.base = base

        self.net = nn.Sequential(
            ConvNormRelu(dim, base, batchnorm=True),
            ConvNormRelu(base, base * 2, batchnorm=True),
            ConvNormRelu(base * 2, base * 2, True, batchnorm=True),
            nn.Conv1d(base * 2, base, 3),
        )

        # NOTE(review): ``nn.LeakyReLU(True)`` passes True as negative_slope
        # (i.e. 1.0), making the activation an identity. Left unchanged
        # because fixing it would alter outputs of already-trained weights.
        mlp = [
            nn.Linear(12 * base, base * 4),
            nn.BatchNorm1d(base * 4),
            nn.LeakyReLU(True),
            nn.Linear(base * 4, base * 2),
            nn.BatchNorm1d(base * 2),
            nn.LeakyReLU(True),
            nn.Linear(base * 2, base),
        ]
        self.out_net = nn.Sequential(*mlp)

        self.fc_mu = nn.Linear(base, base)
        self.fc_logvar = nn.Linear(base, base)

    def forward(self, poses, variational_encoding=None):
        """Encode ``poses`` and return ``(z, mu, logvar)``.

        Args:
            poses: Tensor whose axes 1 and 2 are swapped before the conv
                stack, i.e. ``(batch, length, dim)``.
            variational_encoding: When truthy, ``z`` is drawn with
                ``reparameterize(mu, logvar)``; otherwise ``z`` is ``mu``.
        """
        feats = self.net(poses.transpose(1, 2)).flatten(1)
        feats = self.out_net(feats)
        mu, logvar = self.fc_mu(feats), self.fc_logvar(feats)
        z = reparameterize(mu, logvar) if variational_encoding else mu
        return z, mu, logvar
|
|
|
|
|
|
class PoseDecoderFC(nn.Module):
    """Fully-connected decoder from a 32-d latent code to a pose sequence.

    Optionally conditions on previous poses, which are flattened, embedded to
    32 features and concatenated in front of the latent code.
    """

    def __init__(self, gen_length, pose_dim, use_pre_poses=False):
        """Build the MLP (and the optional previous-pose embedding).

        Args:
            gen_length: Number of frames to generate.
            pose_dim: Feature size of one frame.
            use_pre_poses: When true, create ``pre_pose_net``, whose input
                width is ``pose_dim * 4`` (i.e. the flattened previous poses
                must hold ``4 * pose_dim`` values per sample).
        """
        super().__init__()
        self.gen_length = gen_length
        self.pose_dim = pose_dim
        self.use_pre_poses = use_pre_poses

        in_size = 32  # width of the latent code
        if use_pre_poses:
            self.pre_pose_net = nn.Sequential(
                nn.Linear(pose_dim * 4, 32),
                nn.BatchNorm1d(32),
                nn.ReLU(),
                nn.Linear(32, 32),
            )
            in_size += 32

        # Linear -> BatchNorm -> ReLU for each hidden width, then the output layer.
        blocks = []
        prev = in_size
        for hidden in (128, 128, 256, 512):
            blocks.append(nn.Linear(prev, hidden))
            blocks.append(nn.BatchNorm1d(hidden))
            blocks.append(nn.ReLU())
            prev = hidden
        blocks.append(nn.Linear(prev, gen_length * pose_dim))
        self.net = nn.Sequential(*blocks)

    def forward(self, latent_code, pre_poses=None):
        """Decode ``latent_code`` into ``(batch, gen_length, pose_dim)``.

        Args:
            latent_code: ``(batch, 32)`` tensor.
            pre_poses: Previous poses; only read when ``use_pre_poses`` is
                set. Flattened per sample before being embedded.
        """
        feat = latent_code
        if self.use_pre_poses:
            flat_pre = pre_poses.reshape(pre_poses.shape[0], -1)
            feat = torch.cat((self.pre_pose_net(flat_pre), latent_code), dim=1)
        decoded = self.net(feat)
        return decoded.view(-1, self.gen_length, self.pose_dim)
|
|
|
|
|
|
class PoseDecoderConv(nn.Module):
    """MLP + (transposed) convolution decoder from a latent code to poses.

    The latent code (optionally preceded by a 32-d embedding of previous
    poses) is expanded by ``pre_net`` to ``feat_size // 8`` channels times
    ``length`` steps, then refined by two ``ConvTranspose1d`` and two
    ``Conv1d`` layers with kernel size 3, which leaves the step count
    unchanged overall.
    """

    def __init__(self, length, dim, use_pre_poses=False, feature_length=32):
        """Build the networks for one of the supported sequence lengths.

        Args:
            length: Output sequence length; must be 64, 34 or 32.
            dim: Feature size of one output frame.
            use_pre_poses: When true, create ``pre_pose_net`` (input width
                ``dim * 4``) and widen ``feat_size`` by 32.
            feature_length: Width of the latent code.

        Raises:
            AssertionError: If ``length`` is not one of the supported values.
        """
        super().__init__()
        self.use_pre_poses = use_pre_poses
        self.feat_size = feature_length

        if use_pre_poses:
            self.pre_pose_net = nn.Sequential(
                nn.Linear(dim * 4, 32),
                nn.BatchNorm1d(32),
                nn.ReLU(),
                nn.Linear(32, 32),
            )
            self.feat_size += 32

        # length 64 keeps the hidden width at feat_size; 34 and 32 double it.
        if length == 64:
            hidden, frames = self.feat_size, 64
        elif length == 34:
            hidden, frames = self.feat_size * 2, 34
        elif length == 32:
            hidden, frames = self.feat_size * 2, 32
        else:
            # Raised explicitly (instead of ``assert False``) so the check
            # survives ``python -O``; the exception type is kept for callers.
            raise AssertionError(f"unsupported length {length!r}; expected 64, 34 or 32")

        # NOTE(review): ``nn.LeakyReLU(True)`` passes True as negative_slope
        # (i.e. 1.0), making the activation an identity. Left unchanged
        # because fixing it would alter outputs of already-trained weights.
        self.pre_net = nn.Sequential(
            nn.Linear(self.feat_size, hidden),
            nn.BatchNorm1d(hidden),
            nn.LeakyReLU(True),
            nn.Linear(hidden, self.feat_size // 8 * frames),
        )

        self.decoder_size = self.feat_size // 8
        self.net = nn.Sequential(
            nn.ConvTranspose1d(self.decoder_size, self.feat_size, 3),
            nn.BatchNorm1d(self.feat_size),
            nn.LeakyReLU(0.2, True),
            nn.ConvTranspose1d(self.feat_size, self.feat_size, 3),
            nn.BatchNorm1d(self.feat_size),
            nn.LeakyReLU(0.2, True),
            nn.Conv1d(self.feat_size, self.feat_size * 2, 3),
            nn.Conv1d(self.feat_size * 2, dim, 3),
        )

    def forward(self, feat, pre_poses=None):
        """Decode ``feat`` into a ``(batch, length, dim)`` pose tensor.

        Args:
            feat: ``(batch, feature_length)`` latent code.
            pre_poses: Previous poses; only read when ``use_pre_poses`` is
                set. Flattened per sample before being embedded.
        """
        if self.use_pre_poses:
            pre_pose_feat = self.pre_pose_net(pre_poses.reshape(pre_poses.shape[0], -1))
            feat = torch.cat((pre_pose_feat, feat), dim=1)

        out = self.pre_net(feat)
        # (batch, decoder_size * length) -> (batch, decoder_size, length)
        out = out.view(feat.shape[0], self.decoder_size, -1)
        out = self.net(out)
        return out.transpose(1, 2)
|
|
|
|
|
|
"""
|
|
Our CaMN Modification
|
|
"""
|
|
|
|
|
|
class PoseEncoderConvResNet(nn.Module):
    """Residual-block pose-sequence encoder with ``mu`` / ``logvar`` heads.

    Four ``BasicBlock`` layers are applied in sequence, the result is
    flattened and passed through an MLP, and two linear heads produce ``mu``
    and ``logvar``, each ``feature_length`` wide.
    """

    def __init__(self, length, dim, feature_length=32):
        """Build the residual blocks, the MLP and the two output heads.

        Args:
            length: Not used by this constructor. The MLP input size is
                hard-coded to ``17 * feature_length``, so the blocks must
                leave exactly 17 time steps -- presumably tied to one fixed
                sequence length; confirm against ``BasicBlock``.
            dim: Per-frame pose feature size (input channels of ``conv1``).
            feature_length: Base channel width and latent size.
        """
        super().__init__()
        base = feature_length
        self.base = base

        self.conv1 = BasicBlock(dim, base, reduce_first=1, downsample=False, first_dilation=1)
        self.conv2 = BasicBlock(base, base * 2, downsample=False, first_dilation=1)
        self.conv3 = BasicBlock(base * 2, base * 2, first_dilation=1, downsample=True, stride=2)
        self.conv4 = BasicBlock(base * 2, base, first_dilation=1, downsample=False)

        # NOTE(review): ``nn.LeakyReLU(True)`` passes True as negative_slope
        # (i.e. 1.0), making the activation an identity. Left unchanged
        # because fixing it would alter outputs of already-trained weights.
        mlp = [
            nn.Linear(17 * base, base * 4),
            nn.BatchNorm1d(base * 4),
            nn.LeakyReLU(True),
            nn.Linear(base * 4, base * 2),
            nn.BatchNorm1d(base * 2),
            nn.LeakyReLU(True),
            nn.Linear(base * 2, base),
        ]
        self.out_net = nn.Sequential(*mlp)

        self.fc_mu = nn.Linear(base, base)
        self.fc_logvar = nn.Linear(base, base)

    def forward(self, poses, variational_encoding=None):
        """Encode ``poses`` and return ``(z, mu, logvar)``.

        Args:
            poses: Tensor whose axes 1 and 2 are swapped before the first
                block, i.e. ``(batch, length, dim)``.
            variational_encoding: When truthy, ``z`` is drawn with
                ``reparameterize(mu, logvar)``; otherwise ``z`` is ``mu``.
        """
        hidden = poses.transpose(1, 2)
        for block in (self.conv1, self.conv2, self.conv3, self.conv4):
            hidden = block(hidden)
        feats = self.out_net(hidden.flatten(1))
        mu, logvar = self.fc_mu(feats), self.fc_logvar(feats)
        z = reparameterize(mu, logvar) if variational_encoding else mu
        return z, mu, logvar
|
|
|
|
|
|
|
|
"""
|
|
bs, n, c_in --> bs, n, c_out or bs, 1 (hidden), c_out
|
|
"""
|
|
|
|
|
|
class AELSTM(nn.Module):
    """Sequence auto-encoder built around a 4-layer bidirectional LSTM.

    Frames are linearly embedded, passed through the LSTM, the forward and
    backward outputs are summed, and a small MLP reconstructs the frames.
    """

    def __init__(self, args):
        """Create the embedding, the LSTM and the reconstruction head.

        Args:
            args: Config object providing ``vae_test_dim`` (per-frame feature
                size) and ``vae_length`` (embedding / LSTM hidden size).
        """
        super().__init__()
        feat_dim, hidden = args.vae_test_dim, args.vae_length

        self.motion_emb = nn.Linear(feat_dim, hidden)
        self.lstm = nn.LSTM(
            hidden,
            hidden_size=hidden,
            num_layers=4,
            batch_first=True,
            bidirectional=True,
            dropout=0.3,
        )
        self.out = nn.Sequential(
            nn.Linear(hidden, hidden // 2),
            nn.LeakyReLU(0.2, True),
            nn.Linear(hidden // 2, feat_dim),
        )
        self.hidden_size = hidden

    def forward(self, inputs):
        """Encode and reconstruct a ``(batch, length, vae_test_dim)`` tensor.

        Returns:
            dict with ``"poses_feat"`` (summed LSTM directions, last axis
            ``vae_length``) and ``"rec_pose"`` (reconstruction, last axis
            ``vae_test_dim``).
        """
        embedded = self.motion_emb(inputs)
        lstm_out, _ = self.lstm(embedded)

        # The LSTM concatenates both directions on the last axis; add them.
        forward_half = lstm_out[:, :, : self.hidden_size]
        backward_half = lstm_out[:, :, self.hidden_size :]
        feat = forward_half + backward_half

        return {"poses_feat": feat, "rec_pose": self.out(feat)}
|
|
|
|
|
|
class PoseDecoderLSTM(nn.Module):
    """Bidirectional-LSTM decoder from per-step latent codes to poses.

    Original note: input bs*n*64.
    """

    def __init__(self, pose_dim, feature_length):
        """Create the 4-layer bidirectional LSTM and the output MLP.

        Args:
            pose_dim: Feature size of one output frame.
            feature_length: Feature size of one latent step (LSTM input size).
        """
        super().__init__()
        self.pose_dim = pose_dim
        self.base = feature_length
        self.hidden_size = 256

        self.lstm_d = nn.LSTM(
            self.base,
            hidden_size=self.hidden_size,
            num_layers=4,
            batch_first=True,
            bidirectional=True,
            dropout=0.3,
        )
        # NOTE(review): ``nn.LeakyReLU(True)`` passes True as negative_slope
        # (i.e. 1.0), making the activation an identity. Left unchanged
        # because fixing it would alter outputs of already-trained weights.
        self.out_d = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size // 2),
            nn.LeakyReLU(True),
            nn.Linear(self.hidden_size // 2, self.pose_dim),
        )

    def forward(self, latent_code):
        """Decode ``(batch, length, feature_length)`` into ``(batch, length, pose_dim)``."""
        batch, steps = latent_code.shape[0], latent_code.shape[1]

        lstm_out, _ = self.lstm_d(latent_code)
        # The LSTM concatenates both directions on the last axis; add them.
        merged = lstm_out[:, :, : self.hidden_size] + lstm_out[:, :, self.hidden_size :]

        # Apply the MLP per time step on a flattened (batch * length, hidden) view.
        flat = merged.reshape(-1, merged.shape[2])
        return self.out_d(flat).view(batch, steps, -1)
|
|
|
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    Even feature indices hold ``sin`` terms and odd indices hold ``cos``
    terms. The table is precomputed for ``max_len`` positions and stored in
    the non-trainable buffer ``pe`` of shape ``(1, max_len, d_model)``.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        """Precompute the encoding table.

        Args:
            d_model: Feature size of the sequence elements (even or odd).
            dropout: Dropout probability applied after adding the encoding.
            max_len: Number of positions to precompute.
        """
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(max_len, d_model)
        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even column;
        # trimming avoids the shape-mismatch error the original raised.
        pe[:, 1::2] = torch.cos(angles)[:, : d_model // 2]

        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        """Return ``dropout(x + pe[:, :length])`` for ``x`` of shape ``(batch, length, d_model)``."""
        x = x + self.pe[:, : x.shape[1]]
        return self.dropout(x)
|
|
|
|
|
|
class Encoder_TRANSFORMER(nn.Module):
    """Transformer encoder over a batch-first sequence of frames.

    Frames are linearly embedded, sinusoidal position encodings are added,
    and a 4-layer ``nn.TransformerEncoder`` is applied. The first two output
    positions are returned as ``mu`` and ``logvar``.
    """

    def __init__(self, args):
        """Create the embedding, the position encoder and the transformer.

        Args:
            args: Config object providing ``vae_test_dim`` (per-frame feature
                size) and ``vae_length`` (model width; must be divisible by
                the 4 attention heads).
        """
        super().__init__()
        self.skelEmbedding = nn.Linear(args.vae_test_dim, args.vae_length)
        self.sequence_pos_encoder = PositionalEncoding(args.vae_length, 0.3)
        # NOTE(review): dim_feedforward=1025 looks like a typo for 1024 (the
        # value Decoder_TRANSFORMER uses). Kept as-is because changing it
        # alters parameter shapes and would break existing checkpoints.
        seqTransEncoderLayer = nn.TransformerEncoderLayer(
            d_model=args.vae_length, nhead=4, dim_feedforward=1025, dropout=0.3, activation="gelu", batch_first=True
        )
        self.seqTransEncoder = nn.TransformerEncoder(seqTransEncoderLayer, num_layers=4)

    def _generate_square_subsequent_mask(self, sz):
        """Return an ``(sz, sz)`` float mask: 0 on/below the diagonal, -inf above.

        Not used by ``forward`` in this class.
        """
        return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)

    def forward(self, inputs):
        """Encode a ``(batch, length, vae_test_dim)`` tensor.

        Returns:
            ``(final, mu, logvar)`` where ``final`` is the full encoder
            output and ``mu`` / ``logvar`` are its positions 0 and 1, each
            kept as a length-1 slice along the sequence axis. ``logvar`` is
            empty when the sequence has a single position.
        """
        x = self.skelEmbedding(inputs)
        xseq = self.sequence_pos_encoder(x)

        # No attention mask is passed: every position attends to all others.
        final = self.seqTransEncoder(xseq)

        mu = final[:, 0:1, :]
        logvar = final[:, 1:2, :]
        return final, mu, logvar
|
|
|
|
|
|
class Decoder_TRANSFORMER(nn.Module):
    """Transformer decoder that generates a fixed-length sequence from memory.

    The decoder queries are all-zero vectors plus sinusoidal position
    encodings; the encoder output is supplied as ``memory``. A final linear
    layer maps the model width to the per-frame feature size.
    """

    def __init__(self, args):
        """Create the position encoder, the transformer and the output layer.

        Args:
            args: Config object providing ``vae_test_len`` (number of frames
                to generate), ``vae_length`` (model width; must be divisible
                by the 4 attention heads) and ``vae_test_dim`` (per-frame
                feature size).
        """
        super().__init__()
        self.vae_test_len = args.vae_test_len
        self.vae_length = args.vae_length
        self.sequence_pos_encoder = PositionalEncoding(args.vae_length, 0.3)

        decoder_layer = nn.TransformerDecoderLayer(
            d_model=args.vae_length,
            nhead=4,
            dim_feedforward=1024,
            dropout=0.3,
            activation="gelu",
            batch_first=True,
        )
        self.seqTransDecoder = nn.TransformerDecoder(decoder_layer, num_layers=4)
        self.finallayer = nn.Linear(args.vae_length, args.vae_test_dim)

    def forward(self, inputs):
        """Decode ``inputs`` (used as memory) into ``(batch, vae_test_len, vae_test_dim)``."""
        batch = inputs.shape[0]

        # Queries carry position information only.
        queries = torch.zeros(batch, self.vae_test_len, self.vae_length, device=inputs.device)
        queries = self.sequence_pos_encoder(queries)

        decoded = self.seqTransDecoder(tgt=queries, memory=inputs)
        return self.finallayer(decoded)
|
|
|