import torch
from torch import nn
from mmcv.runner import BaseModule

from ..builder import SUBMODULES
from .position_encoding import (LearnedPositionalEncoding,
                                SinusoidalPositionalEncoding)
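
# NOTE: `build_MLP` is referenced by the condition projections below but is
# not defined or imported in this file. A minimal sketch, assuming a
# two-layer MLP with a GELU non-linearity; the real helper may live
# elsewhere in the package and differ in depth or activation.
def build_MLP(in_dim, out_dim):
    """Hypothetical two-layer MLP: in_dim -> out_dim -> out_dim."""
    return nn.Sequential(
        nn.Linear(in_dim, out_dim),
        nn.GELU(),
        nn.Linear(out_dim, out_dim))
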
@SUBMODULES.register_module()
class ACTOREncoder(BaseModule):

    def __init__(self,
                 max_seq_len=16,
                 njoints=None,
                 nfeats=None,
                 input_feats=None,
                 latent_dim=256,
                 output_dim=256,
                 condition_dim=None,
                 num_heads=4,
                 ff_size=1024,
                 num_layers=8,
                 activation='gelu',
                 dropout=0.1,
                 use_condition=False,
                 num_class=None,
                 use_final_proj=False,
                 output_var=False,
                 pos_embedding='sinusoidal',
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.njoints = njoints
        self.nfeats = nfeats
        if input_feats is None:
            assert self.njoints is not None and self.nfeats is not None
            self.input_feats = njoints * nfeats
        else:
            self.input_feats = input_feats
        self.max_seq_len = max_seq_len
        self.latent_dim = latent_dim
        self.condition_dim = condition_dim
        self.use_condition = use_condition
        self.num_class = num_class
        self.use_final_proj = use_final_proj
        self.output_var = output_var
        # Project per-frame pose features into the latent space.
        self.skelEmbedding = nn.Linear(self.input_feats, self.latent_dim)
        if self.use_condition:
            if num_class is None:
                # Continuous condition: map it to mu (and sigma) query tokens.
                self.mu_layer = build_MLP(self.condition_dim, self.latent_dim)
                if self.output_var:
                    self.sigma_layer = build_MLP(self.condition_dim,
                                                 self.latent_dim)
            else:
                # Categorical condition: one learned query embedding per class.
                self.mu_layer = nn.Parameter(
                    torch.randn(num_class, self.latent_dim))
                if self.output_var:
                    self.sigma_layer = nn.Parameter(
                        torch.randn(num_class, self.latent_dim))
        else:
            # Unconditional: learned query token(s); two when a variance
            # output is requested (one for mu, one for sigma).
            if self.output_var:
                self.query = nn.Parameter(torch.randn(2, self.latent_dim))
            else:
                self.query = nn.Parameter(torch.randn(1, self.latent_dim))
        if self.use_final_proj:
            # Optional projection of the pooled token(s) to output_dim;
            # used by forward() below.
            self.final_mu = nn.Linear(self.latent_dim, output_dim)
            if self.output_var:
                self.final_sigma = nn.Linear(self.latent_dim, output_dim)
        if pos_embedding == 'sinusoidal':
            self.pos_encoder = SinusoidalPositionalEncoding(
                latent_dim, dropout)
        else:
            self.pos_encoder = LearnedPositionalEncoding(
                latent_dim, dropout, max_len=max_seq_len + 2)
        seqTransEncoderLayer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation)
        self.seqTransEncoder = nn.TransformerEncoder(
            seqTransEncoderLayer, num_layers=num_layers)
    def forward(self, motion, motion_mask=None, condition=None):
        B, T = motion.shape[:2]
        motion = motion.view(B, T, -1)
        feature = self.skelEmbedding(motion)
        if self.use_condition:
            if self.output_var:
                if self.num_class is None:
                    sigma_query = self.sigma_layer(condition).view(B, 1, -1)
                else:
                    sigma_query = \
                        self.sigma_layer[condition.long()].view(B, 1, -1)
                feature = torch.cat((sigma_query, feature), dim=1)
            if self.num_class is None:
                mu_query = self.mu_layer(condition).view(B, 1, -1)
            else:
                mu_query = self.mu_layer[condition.long()].view(B, 1, -1)
            feature = torch.cat((mu_query, feature), dim=1)
        else:
            query = self.query.view(1, -1, self.latent_dim).repeat(B, 1, 1)
            feature = torch.cat((query, feature), dim=1)
        # Extend the padding mask to cover the prepended query token(s).
        # motion_mask is 1 for valid frames, while src_key_padding_mask
        # expects True at padded positions, hence the `1 - motion_mask` flip.
        if self.output_var:
            motion_mask = torch.cat(
                (torch.zeros(B, 2).to(motion.device), 1 - motion_mask),
                dim=1).bool()
        else:
            motion_mask = torch.cat(
                (torch.zeros(B, 1).to(motion.device), 1 - motion_mask),
                dim=1).bool()
        # nn.TransformerEncoder expects (T, B, C).
        feature = feature.permute(1, 0, 2).contiguous()
        feature = self.pos_encoder(feature)
        feature = self.seqTransEncoder(
            feature, src_key_padding_mask=motion_mask)
        # The first output token(s) carry the pooled mu (and sigma).
        if self.use_final_proj:
            mu = self.final_mu(feature[0])
            if self.output_var:
                sigma = self.final_sigma(feature[1])
                return mu, sigma
            return mu
        if self.output_var:
            return feature[0], feature[1]
        return feature[0]
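

# A hedged usage sketch for ACTOREncoder (not from the original file; the
# joint/feature counts are illustrative only):
#
#   encoder = ACTOREncoder(njoints=24, nfeats=6, output_var=True)
#   motion = torch.randn(2, 16, 24, 6)        # (B, T, njoints, nfeats)
#   motion_mask = torch.ones(2, 16)           # 1 = valid frame, 0 = padding
#   mu, sigma = encoder(motion, motion_mask)  # each of shape (B, latent_dim)
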
@SUBMODULES.register_module()
class ACTORDecoder(BaseModule):

    def __init__(self,
                 max_seq_len=16,
                 njoints=None,
                 nfeats=None,
                 input_feats=None,
                 input_dim=256,
                 latent_dim=256,
                 condition_dim=None,
                 num_heads=4,
                 ff_size=1024,
                 num_layers=8,
                 activation='gelu',
                 dropout=0.1,
                 use_condition=False,
                 num_class=None,
                 pos_embedding='sinusoidal',
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Map the latent input to the decoder width if the two differ.
        if input_dim != latent_dim:
            self.linear = nn.Linear(input_dim, latent_dim)
        else:
            self.linear = nn.Identity()
        self.njoints = njoints
        self.nfeats = nfeats
        if input_feats is None:
            assert self.njoints is not None and self.nfeats is not None
            self.input_feats = njoints * nfeats
        else:
            self.input_feats = input_feats
        self.max_seq_len = max_seq_len
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.condition_dim = condition_dim
        self.use_condition = use_condition
        self.num_class = num_class
        if self.use_condition:
            if num_class is None:
                # Continuous condition: project it to an additive bias.
                self.condition_bias = build_MLP(condition_dim, latent_dim)
            else:
                # Categorical condition: one learned bias per class.
                self.condition_bias = nn.Parameter(
                    torch.randn(num_class, latent_dim))
        if pos_embedding == 'sinusoidal':
            self.pos_encoder = SinusoidalPositionalEncoding(
                latent_dim, dropout)
        else:
            self.pos_encoder = LearnedPositionalEncoding(
                latent_dim, dropout, max_len=max_seq_len)
        seqTransDecoderLayer = nn.TransformerDecoderLayer(
            d_model=self.latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation)
        self.seqTransDecoder = nn.TransformerDecoder(
            seqTransDecoderLayer, num_layers=num_layers)
        self.final = nn.Linear(self.latent_dim, self.input_feats)
    def forward(self, input, motion_mask=None, condition=None):
        B = input.shape[0]
        T = self.max_seq_len
        input = self.linear(input)
        if self.use_condition:
            if self.num_class is None:
                condition = self.condition_bias(condition)
            else:
                condition = self.condition_bias[condition.long()].squeeze(1)
            input = input + condition
        # Positional encodings serve as the decoder queries; the latent
        # vector is the single-token memory attended to at every frame.
        query = self.pos_encoder.pe[:T, :].view(T, 1, -1).repeat(1, B, 1)
        input = input.view(1, B, -1)
        feature = self.seqTransDecoder(
            tgt=query,
            memory=input,
            tgt_key_padding_mask=(1 - motion_mask).bool())
        # (T, B, C) -> (B, T, input_feats)
        pose = self.final(feature).permute(1, 0, 2).contiguous()
        return pose
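

# Hedged smoke test, not part of the original file: the joint/feature counts
# below are illustrative. Because this module uses relative imports, run it
# with `python -m <package>.<module>` rather than invoking the file directly.
if __name__ == '__main__':
    B, T, njoints, nfeats = 2, 16, 24, 6
    encoder = ACTOREncoder(njoints=njoints, nfeats=nfeats, max_seq_len=T)
    decoder = ACTORDecoder(njoints=njoints, nfeats=nfeats, max_seq_len=T)
    motion = torch.randn(B, T, njoints, nfeats)
    motion_mask = torch.ones(B, T)      # 1 = valid frame, 0 = padding
    z = encoder(motion, motion_mask)    # pooled latent, (B, latent_dim)
    recon = decoder(z, motion_mask)     # reconstruction, (B, T, njoints * nfeats)
    assert z.shape == (B, 256)
    assert recon.shape == (B, T, njoints * nfeats)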