Spaces:
Sleeping
Sleeping
""" | |
motion_in and motion_out are all (bs, t, c), not (bs, t, j, c//j) | |
input: | |
audio: (bs, audio_t) | |
speaker_id: (bs, 1) | |
seed_frames: int | |
seed_motion: (bs, t, j*6) # rot6d | |
output: | |
motion: (bs, t, j*6) # rot6d | |
motion_axis_angle: (bs, t, j*3) # axis-angle | |
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from transformers import PreTrainedModel | |
from .configuration_camn_audio import CamnAudioConfig | |
# ------------------ utils ---------------------- # | |
# Per-joint boolean masks over a 55-entry joint list. True marks a joint whose
# values are supplied by the model; False marks a joint that stays zero when
# the full pose is rebuilt from the masked one.
MASK_DICT = {
    # Every joint is kept except the twelve indices listed here.
    "local_upper": [
        joint not in {0, 1, 2, 4, 5, 7, 8, 10, 11, 22, 23, 24}
        for joint in range(55)
    ],
    # Everything except joint 0.
    "local_full": [False, *([True] * 54)],
}
def _copysign(a, b): | |
signs_differ = (a < 0) != (b < 0) | |
return torch.where(signs_differ, -a, a) | |
def _sqrt_positive_part(x): | |
ret = torch.zeros_like(x) | |
positive_mask = x > 0 | |
ret[positive_mask] = torch.sqrt(x[positive_mask]) | |
return ret | |
def matrix_to_quaternion(matrix):
    """Convert rotation matrices to quaternions with the real part first.

    Args:
        matrix: tensor whose last two dimensions are both 3.

    Returns:
        Tensor with the trailing ``(3, 3)`` replaced by ``4``. The first
        component is computed from the trace and is never negative; the other
        three take their magnitude from the diagonal and their sign from the
        antisymmetric off-diagonal differences.

    Raises:
        ValueError: if the last two dimensions are not both 3.
    """
    if not (matrix.size(-1) == 3 and matrix.size(-2) == 3):
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
    d0, d1, d2 = matrix[..., 0, 0], matrix[..., 1, 1], matrix[..., 2, 2]
    # Component magnitudes, recovered from the diagonal only.
    real = 0.5 * _sqrt_positive_part(1 + d0 + d1 + d2)
    mag_i = 0.5 * _sqrt_positive_part(1 + d0 - d1 - d2)
    mag_j = 0.5 * _sqrt_positive_part(1 - d0 + d1 - d2)
    mag_k = 0.5 * _sqrt_positive_part(1 - d0 - d1 + d2)
    # Signs of the imaginary components come from the off-diagonal differences.
    imag_i = _copysign(mag_i, matrix[..., 2, 1] - matrix[..., 1, 2])
    imag_j = _copysign(mag_j, matrix[..., 0, 2] - matrix[..., 2, 0])
    imag_k = _copysign(mag_k, matrix[..., 1, 0] - matrix[..., 0, 1])
    return torch.stack([real, imag_i, imag_j, imag_k], dim=-1)
def quaternion_to_axis_angle(quaternions):
    """Convert quaternions (real part first) to axis-angle vectors.

    Args:
        quaternions: tensor of shape ``(..., 4)``.

    Returns:
        Tensor of shape ``(..., 3)``: the vector part of each quaternion
        divided by ``sin(angle / 2) / angle``.
    """
    vector_part = quaternions[..., 1:]
    vector_norm = torch.norm(vector_part, p=2, dim=-1, keepdim=True)
    half = torch.atan2(vector_norm, quaternions[..., :1])
    full = 2 * half
    tiny = full.abs() < 1e-6
    regular = ~tiny
    scale = torch.empty_like(full)
    scale[regular] = torch.sin(half[regular]) / full[regular]
    # Near zero, sin(x / 2) / x is replaced by its series 1/2 - x^2 / 48 to
    # avoid dividing by a vanishing angle.
    scale[tiny] = 0.5 - (full[tiny] * full[tiny]) / 48
    return vector_part / scale
def matrix_to_axis_angle(matrix):
    """Convert rotation matrices to axis-angle vectors by way of quaternions."""
    quaternions = matrix_to_quaternion(matrix)
    return quaternion_to_axis_angle(quaternions)
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """Build rotation matrices from a 6D representation by Gram-Schmidt.

    The first three values of the last dimension give the first row
    direction and the remaining values give the second; the third row is
    their cross product.

    Args:
        d6: tensor of shape ``(..., 6)``.

    Returns:
        Tensor of shape ``(..., 3, 3)`` whose rows are the orthonormalised
        vectors.
    """
    first, second = d6[..., :3], d6[..., 3:]
    row_x = F.normalize(first, dim=-1)
    # Remove the component of ``second`` that lies along ``row_x``.
    overlap = (row_x * second).sum(-1, keepdim=True)
    row_y = F.normalize(second - overlap * row_x, dim=-1)
    row_z = torch.cross(row_x, row_y, dim=-1)
    return torch.stack((row_x, row_y, row_z), dim=-2)
def rotation_6d_to_axis_angle(rot6d):
    """Convert 6D rotations to axis-angle vectors by way of rotation matrices."""
    matrices = rotation_6d_to_matrix(rot6d)
    return matrix_to_axis_angle(matrices)
def recover_from_mask_ts(selected_motion: torch.Tensor, mask: list[bool]) -> torch.Tensor:
    """Scatter masked-joint features back into a full, zero-filled joint layout.

    Args:
        selected_motion: tensor of shape ``(..., k * c)`` where ``k`` is the
            number of truthy entries in ``mask`` and ``c`` the per-joint
            channel count (derived as ``last_dim // k``).
        mask: one flag per joint of the full layout; truthy entries mark the
            joints present in ``selected_motion``, in order.

    Returns:
        Tensor of shape ``(..., len(mask) * c)`` with the same dtype and
        device as ``selected_motion``. Kept joints carry the input values and
        every other joint is zero.
    """
    # The mask is host-side data, so counting and indexing are done in Python.
    # Reading the count back from a device tensor (``.item()``) or indexing
    # with a device boolean mask would force a device-to-host sync per call.
    kept_joints = [idx for idx, keep in enumerate(mask) if keep]
    num_joints = len(mask)
    num_kept = len(kept_joints)
    channels = selected_motion.shape[-1] // num_kept
    lead_shape = selected_motion.shape[:-1]
    per_joint = selected_motion.reshape(*lead_shape, num_kept, channels)
    recovered = selected_motion.new_zeros(*lead_shape, num_joints, channels)
    joint_index = torch.tensor(kept_joints, dtype=torch.long, device=selected_motion.device)
    recovered[..., joint_index, :] = per_joint
    return recovered.reshape(*lead_shape, num_joints * channels)
# ------------------ network ---------------------- # | |
class BasicBlock(nn.Module):
    """1D residual block: two conv + norm stages added to a skip path.

    ``first_dilation`` is passed as the ``padding`` of the first convolution
    and of the projection on the skip path (both use dilation 1). The skip
    path is projected by a conv + norm pair whenever the stride or the
    channel count changes; otherwise the input is added unchanged.
    """

    def __init__(self, inplanes, planes, ker_size, stride=1, first_dilation=None,
                 norm_layer=nn.BatchNorm1d, act_layer=nn.LeakyReLU):
        super().__init__()
        # Main path, stage 1: strided convolution.
        self.conv1 = nn.Conv1d(
            inplanes, planes,
            kernel_size=ker_size, stride=stride,
            padding=first_dilation, dilation=1, bias=True,
        )
        self.bn1 = norm_layer(planes)
        self.act1 = act_layer(inplace=True)
        # Main path, stage 2: length-preserving convolution.
        self.conv2 = nn.Conv1d(
            planes, planes,
            kernel_size=ker_size, padding=ker_size // 2, bias=True,
        )
        self.bn2 = norm_layer(planes)
        self.act2 = act_layer(inplace=True)
        # Skip path: project only when the output shape differs from the input.
        needs_projection = stride != 1 or inplanes != planes
        self.downsample = (
            nn.Sequential(
                nn.Conv1d(
                    inplanes, planes,
                    kernel_size=ker_size, stride=stride,
                    padding=first_dilation, bias=True,
                ),
                norm_layer(planes),
            )
            if needs_projection
            else None
        )

    def forward(self, x):
        """Run both conv stages, add the skip path, then apply the final activation."""
        out = self.act1(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        skip = x if self.downsample is None else self.downsample(x)
        out += skip
        return self.act2(out)
class WavEncoder(nn.Module):
    """Waveform encoder built from six stacked 1D residual blocks.

    Args:
        out_dim: channel count of the produced features. The intermediate
            blocks use ``out_dim // 4`` and ``out_dim // 2`` channels, which
            gives the 32 / 64 / 128 layout for ``out_dim == 128``. Values
            below 4 are not meaningful because the first stage would have
            zero channels.
    """

    def __init__(self, out_dim):
        super().__init__()
        # Derive the widths from ``out_dim`` so the encoder output always
        # matches the feature size the caller asked for.
        quarter, half = out_dim // 4, out_dim // 2
        self.feat_extractor = nn.Sequential(
            # The first block pads the raw waveform by 1600 samples per side.
            BasicBlock(1, quarter, 15, 5, first_dilation=1600),
            BasicBlock(quarter, quarter, 15, 6, first_dilation=0),
            BasicBlock(quarter, quarter, 15, 1, first_dilation=7),
            BasicBlock(quarter, half, 15, 6, first_dilation=0),
            BasicBlock(half, half, 15, 1, first_dilation=7),
            BasicBlock(half, out_dim, 15, 6, first_dilation=0),
        )

    def forward(self, wav_data):
        """Encode ``(bs, audio_t)`` waveforms into ``(bs, t, out_dim)`` features."""
        # Conv1d expects a channel axis: (bs, audio_t) -> (bs, 1, audio_t).
        wav_data = wav_data.unsqueeze(1)
        out = self.feat_extractor(wav_data)
        # (bs, channels, t) -> (bs, t, channels)
        return out.transpose(1, 2)
class MLP(nn.Module):
    """Two-layer perceptron: linear -> in-place LeakyReLU(0.1) -> linear."""

    def __init__(self, in_dim, middle_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, middle_dim)
        self.fc2 = nn.Linear(middle_dim, out_dim)
        self.act = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, x):
        """Project ``x`` from ``in_dim`` to ``out_dim`` along its last dimension."""
        hidden = self.act(self.fc1(x))
        return self.fc2(hidden)
class Empty(nn.Module):
    """Identity module: ``forward`` hands its argument back untouched."""

    def forward(self, x):
        # No transformation and no copy; the very same object is returned.
        return x
class CamnAudioPreTrainedModel(PreTrainedModel):
    """Common base that binds CaMN audio models to ``CamnAudioConfig``."""

    config_class = CamnAudioConfig
    base_model_prefix = "camn_audio"

    def _init_weights(self, module):
        """Deliberately do nothing: no custom weight initialisation is applied here."""
        return None
class CamnAudioModel(CamnAudioPreTrainedModel):
    """CaMNAudio model for audio-driven motion generation.

    This model assumes that the config (CamnAudioConfig) can be initialized from a dict-like object
    or OmegaConf directly by passing them as kwargs. For example:

        from omegaconf import OmegaConf
        cfg = OmegaConf.load("configs/camn_audio.yaml")
        config = CamnAudioConfig(config_obj=cfg.model)

    This way all attributes from cfg.model become config attributes without having to manually map each one.
    """

    def __init__(self, config: CamnAudioConfig):
        super().__init__(config)
        self.pose_rep = config.pose_rep
        self.cfg = config

        self.audio_encoder = WavEncoder(self.cfg.audio_f)
        # No speaker conditioning when the configured speaker feature size is 0.
        self.speaker_embedding = (
            nn.Embedding(self.cfg.speaker_dims, self.cfg.speaker_f)
            if self.cfg.speaker_f > 0
            else None
        )
        self.motion_encoder = Empty()
        self.joint_mask = MASK_DICT[config.joint_mask]

        # Body decoder input: seed pose (+1 seed flag) + speaker + audio features.
        input_dim_body = self.cfg.pose_dims + 1 + self.cfg.speaker_f + self.cfg.audio_f
        self.body_motion_decoder = nn.LSTM(
            input_dim_body, hidden_size=self.cfg.hidden_size,
            num_layers=self.cfg.n_layer, batch_first=True,
            bidirectional=True, dropout=self.cfg.dropout_prob,
        )
        self.body_out = MLP(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.body_dims)

        # The hands decoder additionally sees the predicted body motion.
        input_dim_hands = input_dim_body + self.cfg.body_dims
        self.hands_motion_decoder = nn.LSTM(
            input_dim_hands, hidden_size=self.cfg.hidden_size,
            num_layers=self.cfg.n_layer, batch_first=True,
            bidirectional=True, dropout=self.cfg.dropout_prob,
        )
        self.hands_out = MLP(self.cfg.hidden_size, self.cfg.hidden_size, self.cfg.hands_dims)

        if self.pose_rep == "bvh":
            self.bvh_dims = self.cfg.body_dims + self.cfg.hands_dims

    def recombine(self, body_out, hands_out):
        """Merge body and hands predictions into a single pose tensor.

        Args:
            body_out: ``(bs, t, body_dims)`` body prediction.
            hands_out: ``(bs, t, hands_dims)`` hands prediction.

        Returns:
            For ``pose_rep == "bvh"``: ``(bs, t, body_dims + hands_dims)``.
            For ``pose_rep == "smplx"``: ``(bs, t, pose_dims // 6, 6)`` with
            the body joints first and the hand joints after them.

        Raises:
            ValueError: if ``pose_rep`` is neither ``"bvh"`` nor ``"smplx"``.
        """
        bs, t, _ = body_out.shape
        if self.pose_rep == "bvh":
            merged = torch.zeros((bs, t, self.bvh_dims), device=body_out.device)
            merged[:, :, :self.cfg.body_dims] = body_out
            merged[:, :, self.cfg.body_dims:] = hands_out
        elif self.pose_rep == "smplx":
            body_joints = self.cfg.body_dims // 6
            body_out = body_out.reshape(bs, t, body_joints, 6)
            hands_out = hands_out.reshape(bs, t, self.cfg.hands_dims // 6, 6)
            merged = torch.zeros((bs, t, self.cfg.pose_dims // 6, 6), device=body_out.device)
            merged[:, :, :body_joints] = body_out
            merged[:, :, body_joints:] = hands_out
        else:
            # Fail with a clear message instead of an unbound-local error.
            raise ValueError(
                f"Unsupported pose_rep {self.pose_rep!r}; expected 'bvh' or 'smplx'."
            )
        return merged

    def forward(self, audio, speaker_id, seed_frames=4, seed_motion=None, return_axis_angle=True):
        """Generate motion from audio.

        Args:
            audio: ``(bs, audio_t)`` waveform.
            speaker_id: ``(bs, 1)`` speaker indices; only used when a speaker
                embedding is configured.
            seed_frames: number of leading frames taken from ``seed_motion``
                and flagged as known.
            seed_motion: optional ``(bs, t_m, pose_dims)`` motion. When its
                length differs from the audio feature length ``t`` it is cut
                or zero-padded to ``t``. When omitted, an all-zero seed is
                used.
            return_axis_angle: also convert the result to axis-angle.

        Returns:
            Dict with ``"motion"`` (the output of :meth:`recombine`) and
            ``"motion_axis_angle"``: ``(bs, t, -1)`` axis-angle values
            scattered into the full joint layout given by the joint mask, or
            ``None`` when ``return_axis_angle`` is false.
        """
        audio_feat = self.audio_encoder(audio)
        bs, t, _ = audio_feat.shape

        if self.speaker_embedding is not None:
            # (bs, 1, speaker_f) -> (bs, t, speaker_f)
            speaker_feat = self.speaker_embedding(speaker_id).repeat(1, t, 1)
        else:
            speaker_feat = torch.zeros(bs, t, 0, device=audio.device)

        # The extra last channel flags which frames carry seed poses.
        seed_dims = self.cfg.pose_dims + 1
        if seed_motion is None:
            seed_motion = torch.zeros(bs, t, seed_dims, device=audio.device)
            seed_motion[:, :seed_frames, -1] = 1
        else:
            _, t_m, _ = seed_motion.shape
            seed_motion_pad = torch.zeros(bs, t_m, seed_dims, device=audio.device)
            seed_motion_pad[:, :seed_frames, :-1] = seed_motion[:, :seed_frames]
            seed_motion_pad[:, :seed_frames, -1] = 1
            seed_motion = seed_motion_pad
            # Align the seed sequence with the audio feature length.
            if t_m > t:
                seed_motion = seed_motion[:, :t, :]
            elif t_m < t:
                # Frames past the seed carry no information, so pad with
                # zeros (pose and flag) up to exactly ``t`` frames.
                tail = seed_motion.new_zeros(bs, t - t_m, seed_dims)
                seed_motion = torch.cat((seed_motion, tail), dim=1)

        in_fea = torch.cat((audio_feat, speaker_feat, seed_motion), dim=2)

        body_out, _ = self.body_motion_decoder(in_fea)
        # Sum the forward and backward halves of the bidirectional output.
        body_out = body_out[:, :, :self.cfg.hidden_size] + body_out[:, :, self.cfg.hidden_size:]
        body_out = self.body_out(body_out)

        in_fea_hands = torch.cat((in_fea, body_out), dim=2)
        hands_out, _ = self.hands_motion_decoder(in_fea_hands)
        hands_out = hands_out[:, :, :self.cfg.hidden_size] + hands_out[:, :, self.cfg.hidden_size:]
        hands_out = self.hands_out(hands_out)

        recombine = self.recombine(body_out, hands_out)

        motion_axis_angle = None
        if return_axis_angle:
            # NOTE(review): this treats the output as rot6d (6 values per
            # joint); confirm that is also intended for pose_rep == "bvh".
            motion_axis_angle = rotation_6d_to_axis_angle(
                recombine.reshape(-1, self.cfg.pose_dims // 6, 6)
            ).reshape(bs, t, -1)
            motion_axis_angle = recover_from_mask_ts(motion_axis_angle, self.joint_mask)

        return {
            "motion": recombine,
            "motion_axis_angle": motion_axis_angle,
        }