File size: 7,793 Bytes
7ca9b42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 |
import torch
import torch.nn.functional as F
import numpy as np
import imageio
from math import pi
from tqdm import tqdm
from lib.data import get_dataloader, get_meanpose
from lib.util.general import get_config
from lib.util.visualization import motion2video_np, hex2rgb
import os
eps = 1e-16
def localize_motion_torch(motion):
"""
:param motion: B x J x D x T
:return:
"""
B, J, D, T = motion.size()
# subtract centers to local coordinates
centers = motion[:, 8:9, :, :] # B x 1 x D x (T-1)
motion = motion - centers
# adding velocity
translation = centers[:, :, :, 1:] - centers[:, :, :, :-1] # B x 1 x D x (T-1)
velocity = F.pad(translation, [1, 0], "constant", 0.) # B x 1 x D x T
motion = torch.cat([motion[:, :8], motion[:, 9:], velocity], dim=1)
return motion
def normalize_motion_torch(motion, meanpose, stdpose):
"""
:param motion: (B, J, D, T)
:param meanpose: (J, D)
:param stdpose: (J, D)
:return:
"""
B, J, D, T = motion.size()
if D == 2 and meanpose.size(1) == 3:
meanpose = meanpose[:, [0, 2]]
if D == 2 and stdpose.size(1) == 3:
stdpose = stdpose[:, [0, 2]]
return (motion - meanpose.view(1, J, D, 1)) / stdpose.view(1, J, D, 1)
def normalize_motion_inv_torch(motion, meanpose, stdpose):
"""
:param motion: (B, J, D, T)
:param meanpose: (J, D)
:param stdpose: (J, D)
:return:
"""
B, J, D, T = motion.size()
if D == 2 and meanpose.size(1) == 3:
meanpose = meanpose[:, [0, 2]]
if D == 2 and stdpose.size(1) == 3:
stdpose = stdpose[:, [0, 2]]
return motion * stdpose.view(1, J, D, 1) + meanpose.view(1, J, D, 1)
def globalize_motion_torch(motion):
"""
:param motion: B x J x D x T
:return:
"""
B, J, D, T = motion.size()
motion_inv = torch.zeros_like(motion)
motion_inv[:, :8] = motion[:, :8]
motion_inv[:, 9:] = motion[:, 8:-1]
velocity = motion[:, -1:, :, :]
centers = torch.zeros_like(velocity)
displacement = torch.zeros_like(velocity[:, :, :, 0])
for t in range(T):
displacement += velocity[:, :, :, t]
centers[:, :, :, t] = displacement
motion_inv = motion_inv + centers
return motion_inv
def restore_world_space(motion, meanpose, stdpose, n_joints=15):
B, C, T = motion.size()
motion = motion.view(B, n_joints, C // n_joints, T)
motion = normalize_motion_inv_torch(motion, meanpose, stdpose)
motion = globalize_motion_torch(motion)
return motion
def convert_to_learning_space(motion, meanpose, stdpose):
B, J, D, T = motion.size()
motion = localize_motion_torch(motion)
motion = normalize_motion_torch(motion, meanpose, stdpose)
motion = motion.view(B, J*D, T)
return motion
# tensor operations for rotating and projecting 3d skeleton sequence
def get_body_basis(motion_3d):
"""
Get the unit vectors for vector rectangular coordinates for given 3D motion
:param motion_3d: 3D motion from 3D joints positions, shape (B, n_joints, 3, seq_len).
:param angles: (K, 3), Rotation angles around each axis.
:return: unit vectors for vector rectangular coordinates's , shape (B, 3, 3).
"""
B = motion_3d.size(0)
# 2 RightArm 5 LeftArm 9 RightUpLeg 12 LeftUpLeg
horizontal = (motion_3d[:, 2] - motion_3d[:, 5] + motion_3d[:, 9] - motion_3d[:, 12]) / 2 # [B, 3, seq_len]
horizontal = horizontal.mean(dim=-1) # [B, 3]
horizontal = horizontal / horizontal.norm(dim=-1).unsqueeze(-1) # [B, 3]
vector_z = torch.tensor([0., 0., 1.], device=motion_3d.device, dtype=motion_3d.dtype).unsqueeze(0).repeat(B, 1) # [B, 3]
vector_y = torch.cross(horizontal, vector_z) # [B, 3]
vector_y = vector_y / vector_y.norm(dim=-1).unsqueeze(-1)
vector_x = torch.cross(vector_y, vector_z)
vectors = torch.stack([vector_x, vector_y, vector_z], dim=2) # [B, 3, 3]
vectors = vectors.detach()
return vectors
def rotate_basis_euler(basis_vectors, angles):
"""
Rotate vector rectangular coordinates from given angles.
:param basis_vectors: [B, 3, 3]
:param angles: [B, K, T, 3] Rotation angles around each axis.
:return: [B, K, T, 3, 3]
"""
B, K, T, _ = angles.size()
cos, sin = torch.cos(angles), torch.sin(angles)
cx, cy, cz = cos[:, :, :, 0], cos[:, :, :, 1], cos[:, :, :, 2] # [B, K, T]
sx, sy, sz = sin[:, :, :, 0], sin[:, :, :, 1], sin[:, :, :, 2] # [B, K, T]
x = basis_vectors[:, 0, :] # [B, 3]
o = torch.zeros_like(x[:, 0]) # [B]
x_cpm_0 = torch.stack([o, -x[:, 2], x[:, 1]], dim=1) # [B, 3]
x_cpm_1 = torch.stack([x[:, 2], o, -x[:, 0]], dim=1) # [B, 3]
x_cpm_2 = torch.stack([-x[:, 1], x[:, 0], o], dim=1) # [B, 3]
x_cpm = torch.stack([x_cpm_0, x_cpm_1, x_cpm_2], dim=1) # [B, 3, 3]
x_cpm = x_cpm.unsqueeze(1).unsqueeze(2) # [B, 1, 1, 3, 3]
x = x.unsqueeze(-1) # [B, 3, 1]
xx = torch.matmul(x, x.transpose(-1, -2)).unsqueeze(1).unsqueeze(2) # [B, 1, 1, 3, 3]
eye = torch.eye(n=3, dtype=basis_vectors.dtype, device=basis_vectors.device)
eye = eye.unsqueeze(0).unsqueeze(0).unsqueeze(0) # [1, 1, 1, 3, 3]
mat33_x = cx.unsqueeze(-1).unsqueeze(-1) * eye \
+ sx.unsqueeze(-1).unsqueeze(-1) * x_cpm \
+ (1. - cx).unsqueeze(-1).unsqueeze(-1) * xx # [B, K, T, 3, 3]
o = torch.zeros_like(cz)
i = torch.ones_like(cz)
mat33_z_0 = torch.stack([cz, sz, o], dim=3) # [B, K, T, 3]
mat33_z_1 = torch.stack([-sz, cz, o], dim=3) # [B, K, T, 3]
mat33_z_2 = torch.stack([o, o, i], dim=3) # [B, K, T, 3]
mat33_z = torch.stack([mat33_z_0, mat33_z_1, mat33_z_2], dim=3) # [B, K, T, 3, 3]
basis_vectors = basis_vectors.unsqueeze(1).unsqueeze(2)
basis_vectors = basis_vectors @ mat33_x.transpose(-1, -2) @ mat33_z
return basis_vectors
def change_of_basis(motion_3d, basis_vectors=None, project_2d=False):
# motion_3d: (B, n_joints, 3, seq_len)
# basis_vectors: (B, K, T, 3, 3)
if basis_vectors is None:
motion_proj = motion_3d[:, :, [0, 2], :] # [B, n_joints, 2, seq_len]
else:
if project_2d: basis_vectors = basis_vectors[:, :, :, [0, 2], :]
_, K, seq_len, _, _ = basis_vectors.size()
motion_3d = motion_3d.unsqueeze(1).repeat(1, K, 1, 1, 1)
motion_3d = motion_3d.permute([0, 1, 4, 3, 2]) # [B, K, J, 3, T] -> [B, K, T, 3, J]
motion_proj = basis_vectors @ motion_3d # [B, K, T, 2, 3] @ [B, K, T, 3, J] -> [B, K, T, 2, J]
motion_proj = motion_proj.permute([0, 1, 4, 3, 2]) # [B, K, T, 3, J] -> [B, K, J, 3, T]
return motion_proj
def rotate_and_maybe_project_world(X, angles=None, body_reference=True, project_2d=False):
out_dim = 2 if project_2d else 3
batch_size, n_joints, _, seq_len = X.size()
if angles is not None:
K = angles.size(1)
basis_vectors = get_body_basis(X) if body_reference else \
torch.eye(3, device=X.device).unsqueeze(0).repeat(batch_size, 1, 1)
basis_vectors = rotate_basis_euler(basis_vectors, angles)
X_trans = change_of_basis(X, basis_vectors, project_2d=project_2d)
X_trans = X_trans.reshape(batch_size * K, n_joints, out_dim, seq_len)
else:
X_trans = change_of_basis(X, project_2d=project_2d)
X_trans = X_trans.reshape(batch_size, n_joints, out_dim, seq_len)
return X_trans
def rotate_and_maybe_project_learning(X, meanpose, stdpose, angles=None, body_reference=True, project_2d=False):
batch_size, channels, seq_len = X.size()
n_joints = channels // 3
X = restore_world_space(X, meanpose, stdpose, n_joints)
X = rotate_and_maybe_project_world(X, angles, body_reference, project_2d)
X = convert_to_learning_space(X, meanpose, stdpose)
return X
|