import torch
from torch import nn
from mmcv.runner import BaseModule

from ..builder import SUBMODULES
# NOTE: build_MLP is used below but was never imported in the original file;
# the import path here is an assumption about where the helper is defined.
from .mlp import build_MLP
from .position_encoding import SinusoidalPositionalEncoding, LearnedPositionalEncoding


@SUBMODULES.register_module()
class ACTOREncoder(BaseModule):
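    """Transformer VAE encoder in the style of ACTOR.

    Learned distribution token(s) (mu, plus sigma when ``output_var`` is
    set) are prepended to the embedded motion sequence; after the
    transformer encoder, the outputs at those token positions are returned
    as the latent distribution parameters.
    """
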
    def __init__(self,
                 max_seq_len=16,
                 njoints=None,
                 nfeats=None,
                 input_feats=None,
                 latent_dim=256,
                 output_dim=256,
                 condition_dim=None,
                 num_heads=4,
                 ff_size=1024,
                 num_layers=8,
                 activation='gelu',
                 dropout=0.1,
                 use_condition=False,
                 num_class=None,
                 use_final_proj=False,
                 output_var=False,
                 pos_embedding='sinusoidal',
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.njoints = njoints
        self.nfeats = nfeats
        if input_feats is None:
            assert self.njoints is not None and self.nfeats is not None
            self.input_feats = njoints * nfeats
        else:
            self.input_feats = input_feats
        self.max_seq_len = max_seq_len
        self.latent_dim = latent_dim
        self.condition_dim = condition_dim
        self.use_condition = use_condition
        self.num_class = num_class
        self.use_final_proj = use_final_proj
        self.output_var = output_var
        self.skelEmbedding = nn.Linear(self.input_feats, self.latent_dim)
        if self.use_condition:
            if num_class is None:
                # Continuous condition: map the condition vector to the
                # distribution-token queries with small MLPs.
                self.mu_layer = build_MLP(self.condition_dim, self.latent_dim)
                if self.output_var:
                    self.sigma_layer = build_MLP(self.condition_dim, self.latent_dim)
            else:
                # Categorical condition: one learned query per class.
                self.mu_layer = nn.Parameter(torch.randn(num_class, self.latent_dim))
                if self.output_var:
                    self.sigma_layer = nn.Parameter(torch.randn(num_class, self.latent_dim))
        else:
            # Unconditional: learned query token(s) prepended to the sequence.
            if self.output_var:
                self.query = nn.Parameter(torch.randn(2, self.latent_dim))
            else:
                self.query = nn.Parameter(torch.randn(1, self.latent_dim))
        if pos_embedding == 'sinusoidal':
            self.pos_encoder = SinusoidalPositionalEncoding(latent_dim, dropout)
        else:
            self.pos_encoder = LearnedPositionalEncoding(latent_dim, dropout, max_len=max_seq_len + 2)
        seqTransEncoderLayer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation)
        self.seqTransEncoder = nn.TransformerEncoder(
            seqTransEncoderLayer,
            num_layers=num_layers)
        if self.use_final_proj:
            # NOTE: final_mu/final_sigma are used in forward() but were never
            # defined in the original file; defining them as projections to
            # output_dim (the otherwise-unused argument) is an assumption.
            self.final_mu = nn.Linear(self.latent_dim, output_dim)
            if self.output_var:
                self.final_sigma = nn.Linear(self.latent_dim, output_dim)

    def forward(self, motion, motion_mask=None, condition=None):
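        """Encode a motion sequence into its latent distribution parameters.

        Args (shapes inferred from the code, stated here as a reading aid):
            motion: (B, T, ...) tensor, flattened to (B, T, input_feats).
            motion_mask: (B, T) tensor, 1 for valid frames, 0 for padding.
            condition: condition vector (``num_class is None``) or class
                indices; only used when ``use_condition`` is set.

        Returns:
            ``mu`` of shape (B, latent_dim), plus ``sigma`` when
            ``output_var`` is set.
        """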
        B, T = motion.shape[:2]
        motion = motion.view(B, T, -1)
        feature = self.skelEmbedding(motion)
        if self.use_condition:
            if self.output_var:
                if self.num_class is None:
                    sigma_query = self.sigma_layer(condition).view(B, 1, -1)
                else:
                    sigma_query = self.sigma_layer[condition.long()].view(B, 1, -1)
                feature = torch.cat((sigma_query, feature), dim=1)
            if self.num_class is None:
                mu_query = self.mu_layer(condition).view(B, 1, -1)
            else:
                mu_query = self.mu_layer[condition.long()].view(B, 1, -1)
            feature = torch.cat((mu_query, feature), dim=1)
        else:
            query = self.query.view(1, -1, self.latent_dim).repeat(B, 1, 1)
            feature = torch.cat((query, feature), dim=1)
        # Prepend zeros for the distribution token(s); src_key_padding_mask
        # expects True at padded positions, hence the (1 - motion_mask) flip.
        if self.output_var:
            motion_mask = torch.cat(
                (torch.zeros(B, 2).to(motion.device), 1 - motion_mask),
                dim=1).bool()
        else:
            motion_mask = torch.cat(
                (torch.zeros(B, 1).to(motion.device), 1 - motion_mask),
                dim=1).bool()
        feature = feature.permute(1, 0, 2).contiguous()
        feature = self.pos_encoder(feature)
        feature = self.seqTransEncoder(feature, src_key_padding_mask=motion_mask)
        if self.use_final_proj:
            mu = self.final_mu(feature[0])
            if self.output_var:
                sigma = self.final_sigma(feature[1])
                return mu, sigma
            return mu
        else:
            if self.output_var:
                return feature[0], feature[1]
            else:
                return feature[0]
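
# Usage sketch for ACTOREncoder (shapes follow from the code above; the
# concrete numbers are illustrative assumptions, not values from this repo):
#
#   encoder = ACTOREncoder(njoints=24, nfeats=6, latent_dim=256)
#   motion = torch.randn(8, 16, 24, 6)      # (B, T, njoints, nfeats)
#   mask = torch.ones(8, 16)                # 1 = valid frame, 0 = padding
#   mu = encoder(motion, motion_mask=mask)  # (B, latent_dim)
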
@SUBMODULES.register_module()
class ACTORDecoder(BaseModule):
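    """Transformer decoder in the style of ACTOR.

    A single latent vector (optionally shifted by a condition bias) serves
    as the decoder memory, while the positional encodings of the
    ``max_seq_len`` output frames act as the queries.
    """
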
    def __init__(self,
                 max_seq_len=16,
                 njoints=None,
                 nfeats=None,
                 input_feats=None,
                 input_dim=256,
                 latent_dim=256,
                 condition_dim=None,
                 num_heads=4,
                 ff_size=1024,
                 num_layers=8,
                 activation='gelu',
                 dropout=0.1,
                 use_condition=False,
                 num_class=None,
                 pos_embedding='sinusoidal',
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        if input_dim != latent_dim:
            self.linear = nn.Linear(input_dim, latent_dim)
        else:
            self.linear = nn.Identity()
        self.njoints = njoints
        self.nfeats = nfeats
        if input_feats is None:
            assert self.njoints is not None and self.nfeats is not None
            self.input_feats = njoints * nfeats
        else:
            self.input_feats = input_feats
        self.max_seq_len = max_seq_len
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.condition_dim = condition_dim
        self.use_condition = use_condition
        self.num_class = num_class
        if self.use_condition:
            if num_class is None:
                self.condition_bias = build_MLP(condition_dim, latent_dim)
            else:
                self.condition_bias = nn.Parameter(torch.randn(num_class, latent_dim))
        if pos_embedding == 'sinusoidal':
            self.pos_encoder = SinusoidalPositionalEncoding(latent_dim, dropout)
        else:
            self.pos_encoder = LearnedPositionalEncoding(latent_dim, dropout, max_len=max_seq_len)
        seqTransDecoderLayer = nn.TransformerDecoderLayer(
            d_model=self.latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation)
        self.seqTransDecoder = nn.TransformerDecoder(
            seqTransDecoderLayer,
            num_layers=num_layers)
        self.final = nn.Linear(self.latent_dim, self.input_feats)

    def forward(self, input, motion_mask=None, condition=None):
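        """Decode a latent vector into a pose sequence.

        Args (shapes inferred from the code):
            input: (B, input_dim) latent tensor.
            motion_mask: (B, max_seq_len) tensor, 1 for valid frames.
            condition: condition vector or class indices, used only when
                ``use_condition`` is set.

        Returns:
            (B, max_seq_len, input_feats) tensor of reconstructed poses.
        """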
        B = input.shape[0]
        T = self.max_seq_len
        input = self.linear(input)
        if self.use_condition:
            if self.num_class is None:
                condition = self.condition_bias(condition)
            else:
                condition = self.condition_bias[condition.long()].squeeze(1)
            input = input + condition
        # Queries are the positional encodings of the T output frames; the
        # latent acts as a length-1 memory sequence for the decoder.
        query = self.pos_encoder.pe[:T, :].view(T, 1, -1).repeat(1, B, 1)
        input = input.view(1, B, -1)
        feature = self.seqTransDecoder(
            tgt=query,
            memory=input,
            tgt_key_padding_mask=(1 - motion_mask).bool())
        pose = self.final(feature).permute(1, 0, 2).contiguous()
        return pose
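
# Usage sketch for ACTORDecoder (illustrative assumptions, matching the
# encoder sketch above):
#
#   decoder = ACTORDecoder(njoints=24, nfeats=6, input_dim=256, max_seq_len=16)
#   z = torch.randn(8, 256)               # latent vector from the encoder
#   mask = torch.ones(8, 16)              # 1 = valid frame
#   pose = decoder(z, motion_mask=mask)   # (B, max_seq_len, njoints * nfeats)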