# Multi-HMR
# Copyright (c) 2024-present NAVER Corp.
# CC BY-NC-SA 4.0 license
import os
import sys

sys.path.append("./")
sys.path.append("./engine")
sys.path.append("./engine/pose_estimation")

import copy

import einops
import numpy as np
import roma
import torch
import torch.nn as nn
from blocks import (
    Dinov2Backbone,
    FourierPositionEncoding,
    SMPL_Layer,
    TransformerDecoder,
)
from pose_utils import (
    inverse_perspective_projection,
    pad_to_max,
    rebatch,
    rot6d_to_rotmat,
    undo_focal_length_normalization,
    undo_log_depth,
    unpatch,
)


def unravel_index(index, shape):
    out = []
    for dim in reversed(shape):
        out.append(index % dim)
        index = index // dim
    return tuple(reversed(out))
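

# Example (minimal sketch, not from the original module): unravel_index mirrors
# numpy.unravel_index for torch index tensors, e.g.
#   unravel_index(torch.tensor(7), (2, 3, 4))  # -> (0, 1, 3), since 7 = 0*12 + 1*4 + 3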


class Model(nn.Module):
    """A ViT backbone followed by an HPH head (a stack of cross-attention layers
    whose queries correspond to the detected humans)."""

    def __init__(
        self,
        backbone="dinov2_vitb14",
        pretrained_backbone=False,
        img_size=896,
        camera_embedding="geometric",  # "geometric" encodes viewing directions with a Fourier encoding
        camera_embedding_num_bands=16,  # increases the size of the camera embedding
        camera_embedding_max_resolution=64,  # does not increase the size of the camera embedding
        nearness=True,  # regress log(1/z) instead of z
        xat_depth=2,  # number of cross-attention blocks (SA, CA, MLP) in the HPH head
        xat_num_heads=8,  # number of attention heads
        dict_smpl_layer=None,
        person_center="head",
        clip_dist=True,
        num_betas=10,
        smplx_dir=None,
        *args,
        **kwargs,
    ):
        super().__init__()

        # Save options
        self.img_size = img_size
        self.nearness = nearness
        self.clip_dist = clip_dist
        self.xat_depth = xat_depth
        self.xat_num_heads = xat_num_heads
        self.num_betas = num_betas
        self.output_mesh = True

        # Setup backbone
        self.backbone = Dinov2Backbone(backbone, pretrained=pretrained_backbone)
        self.embed_dim = self.backbone.embed_dim
        self.patch_size = self.backbone.patch_size
        assert self.img_size % self.patch_size == 0, "Invalid img size"
        # Camera intrinsics
        self.fovn = 60
        self.camera_embedding = camera_embedding
        self.camera_embed_dim = 0
        if self.camera_embedding is not None:
            if self.camera_embedding != "geometric":
                raise NotImplementedError(
                    "Only geometric camera embedding is implemented"
                )
            self.camera = FourierPositionEncoding(
                n=3,
                num_bands=camera_embedding_num_bands,
                max_resolution=camera_embedding_max_resolution,
            )
            self.camera_embed_dim = self.camera.channels

        # Heads - Detection
        self.mlp_classif = regression_mlp(
            [self.embed_dim, self.embed_dim, 1]
        )  # bg or human

        # Heads - Human properties
        self.mlp_offset = regression_mlp([self.embed_dim, self.embed_dim, 2])  # offset

        # SMPL layers
        self.nrot = 53
        # Note: this overrides the dict_smpl_layer constructor argument.
        dict_smpl_layer = {
            "neutral": {
                10: SMPL_Layer(
                    smplx_dir,
                    type="smplx",
                    gender="neutral",
                    num_betas=10,
                    kid=False,
                    person_center=person_center,
                ),
                11: SMPL_Layer(
                    smplx_dir,
                    type="smplx",
                    gender="neutral",
                    num_betas=11,
                    kid=False,
                    person_center=person_center,
                ),
            }
        }
        _moduleDict = []
        for k, _smpl_layer in dict_smpl_layer.items():
            for x, y in _smpl_layer.items():
                _moduleDict.append([f"{k}_{x}", copy.deepcopy(y)])
        self.smpl_layer = nn.ModuleDict(_moduleDict)
        self.x_attention_head = HPH(
            num_body_joints=self.nrot - 1,  # 52 body joints for SMPL-X
            context_dim=self.embed_dim + self.camera_embed_dim,
            dim=1024,
            depth=self.xat_depth,
            heads=self.xat_num_heads,
            mlp_dim=1024,
            dim_head=32,
            dropout=0.0,
            emb_dropout=0.0,
            at_token_res=self.img_size // self.patch_size,
            num_betas=self.num_betas,
            smplx_dir=smplx_dir,
        )
        print(f"person center is {person_center}")

    def set_filter(self, apply_filter):
        """Set whether to apply filtering."""
        self.apply_filter = apply_filter

    def detection(
        self,
        z,
        nms_kernel_size,
        det_thresh,
        N,
        idx=None,
        max_dist=None,
        is_training=False,
    ):
        """Detection score on the entire low-res image."""
        scores = _sigmoid(self.mlp_classif(z))  # per-token detection score
        # Restore height and width dimensions.
        scores = unpatch(
            scores, patch_size=1, c=scores.shape[2], img_size=int(np.sqrt(N))
        )

        pseudo_idx = idx
        if not is_training:
            if (
                nms_kernel_size > 1
            ):  # Easy NMS: suppress adjacent high scores with max pooling.
                scores = _nms(scores, kernel=nms_kernel_size)
            _scores = torch.permute(scores, (0, 2, 3, 1))
            # Binary decision (keep confident detections)
            idx = apply_threshold(det_thresh, _scores)
            if pseudo_idx is not None:
                max_dist = 4 if max_dist is None else max_dist
                mask = (torch.abs(idx[1] - pseudo_idx[1]) <= max_dist) & (
                    torch.abs(idx[2] - pseudo_idx[2]) <= max_dist
                )
                idx_num = torch.sum(mask)
                if idx_num < 1:
                    # No detection near the pseudo index: fall back to the best
                    # score inside the surrounding neighborhood.
                    top = torch.clamp(
                        pseudo_idx[1] - max_dist, min=0, max=_scores.shape[1] - 1
                    )
                    bottom = torch.clamp(
                        pseudo_idx[1] + max_dist, min=0, max=_scores.shape[1]
                    )
                    left = torch.clamp(
                        pseudo_idx[2] - max_dist, min=0, max=_scores.shape[2] - 1
                    )
                    right = torch.clamp(
                        pseudo_idx[2] + max_dist, min=0, max=_scores.shape[2]
                    )
                    neighborhoods = _scores[:, top:bottom, left:right, :]
                    idx = torch.argmax(neighborhoods)
                    try:
                        idx = unravel_index(idx, neighborhoods.shape)
                    except Exception as e:
                        print(pseudo_idx)
                        raise e
                    idx = (
                        pseudo_idx[0],
                        idx[1] + pseudo_idx[1] - max_dist,
                        idx[2] + pseudo_idx[2] - max_dist,
                        pseudo_idx[3],
                    )
                else:  # TODO: handle idx_num > 1 differently if needed
                    idx = (idx[0][mask], idx[1][mask], idx[2][mask], idx[3][mask])
        else:
            assert idx is not None  # training time

        # Scores
        scores_detected = scores[
            idx[0], idx[3], idx[1], idx[2]
        ]  # scores of the detected humans only
        scores = torch.permute(scores, (0, 2, 3, 1))
        return scores, scores_detected, idx
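
    # Note on the returned idx (inferred from the indexing above): it is the
    # torch.where index tuple over scores permuted to channels-last, shape
    # [bs, H, W, 1], i.e. four aligned tensors (batch, row, col, channel), so
    # (idx[0][i], idx[1][i], idx[2][i]) locates the i-th detection on the token grid.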

    def embedd_camera(self, K, z):
        """Embed viewing directions using a Fourier encoding."""
        bs = z.shape[0]
        _h, _w = list(z.shape[-2:])
        points = (
            torch.stack(
                [
                    torch.arange(0, _h, 1).reshape(-1, 1).repeat(1, _w),
                    torch.arange(0, _w, 1).reshape(1, -1).repeat(_h, 1),
                ],
                -1,
            )
            .to(z.device)
            .float()
        )  # [h,w,2]
        points = (
            points * self.patch_size + self.patch_size // 2
        )  # move to pixel space - we use the pixel center of each token
        points = points.reshape(1, -1, 2).repeat(bs, 1, 1)  # (bs, N, 2): 2D points
        distance = torch.ones(bs, points.shape[1], 1).to(
            K.device
        )  # (bs, N, 1): distance in the 3D world
        rays = inverse_perspective_projection(points, K, distance)  # (bs, N, 3)
        rays_embeddings = self.camera(pos=rays)
        # Reshape to a spatial grid for each element of the batch.
        z_K = rays_embeddings.reshape(bs, _h, _w, self.camera_embed_dim)  # [bs,h,w,D]
        return z_K
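
    # Shape sketch (assuming the default 896-px input with a 14-px patch size,
    # i.e. a 64x64 token grid; values are illustrative):
    #   z_K = self.embedd_camera(K, z)  # [bs, 64, 64, camera_embed_dim]
    # Each token center is back-projected at unit distance to obtain a viewing
    # ray, which is then Fourier-encoded.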

    def to_euclidean_dist(self, x, dist, _K):
        # Undo focal-length normalization.
        focal = _K[:, [0], [0]]
        dist = undo_focal_length_normalization(
            dist, focal, fovn=self.fovn, img_size=x.shape[-1]
        )
        # Undo the log/nearness encoding.
        if self.nearness:
            dist = undo_log_depth(dist)
        # Clamp to a plausible range.
        if self.clip_dist:
            dist = torch.clamp(dist, 0, 50)
        return dist
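
    # Minimal usage sketch (dist_raw is a hypothetical name for the raw head
    # output, pred_cam[:, [0]] in forward):
    #   dist = self.to_euclidean_dist(x, dist_raw, K_det)  # metric, clamped to [0, 50]
    # This undoes the focal-length normalization and, if nearness is enabled,
    # the log encoding of the regressed distance.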

    def get_smpl(self):
        return self.smpl_layer[f"neutral_{self.num_betas}"]

    def generate_meshes(self, out):
        """
        Generate meshes for each person detected in the image.
        This function processes the output of the detection model, which includes rotation vectors,
        shapes, locations, distances, expressions, and other information related to SMPL-X parameters.
        Parameters:
            out (dict): A dictionary containing detection results and SMPL-X related parameters.
        Returns:
            list: A list of dictionaries, each containing information about a detected person's mesh.
        """
        # Neutral
        persons = []
        rotvec, shape, loc, dist, expression, K_det = (
            out["rotvec"],
            out["shape"],
            out["loc"],
            out["dist"],
            out["expression"],
            out["K_det"],
        )
        scores_det = out["scores_det"]
        idx = out["idx"]
        smpl_out = self.smpl_layer[f"neutral_{self.num_betas}"](
            rotvec, shape, loc, dist, None, K=K_det, expression=expression
        )
        out.update(smpl_out)
        for i in range(idx[0].shape[0]):
            person = {
                # Detection
                "scores": scores_det[i],  # detection scores
                "loc": out["loc"][i],  # 2D pixel location of the primary keypoint
                # SMPL-X params
                "transl": out["transl"][i],  # from the primary keypoint, i.e. the head
                "transl_pelvis": out["transl_pelvis"][i],  # of the pelvis joint
                "rotvec": out["rotvec"][i],
                "expression": out["expression"][i],
                "shape": out["shape"][i],
                # SMPL-X meshes
                "v3d": out["v3d"][i],
                "j3d": out["j3d"][i],
                "j2d": out["j2d"][i],
            }
            persons.append(person)
        return persons

    def forward(
        self,
        x,
        idx=None,
        max_dist=None,
        det_thresh=0.3,
        nms_kernel_size=3,
        K=None,
        is_training=False,
        *args,
        **kwargs,
    ):
        """
        Forward pass: detect humans, then regress SMPL-X parameters for each detection.
        Args:
            - x: RGB image - [bs,3,H,W]
            - idx: GT location of persons - tuple of index tensors (used at training time)
            - K: camera intrinsics - [bs,3,3]
        Return:
            - out (dict) at training time, otherwise a list of per-person dicts.
        """
        persons = []
        out = {}

        # Feature extraction
        z = self.backbone(x)
        B, N, C = z.size()  # e.g. [bs, 4096, 768] for a 896-px input and ViT-B/14

        # Detection
        scores, scores_det, idx = self.detection(
            z,
            nms_kernel_size=nms_kernel_size,
            det_thresh=det_thresh,
            N=N,
            idx=idx,
            max_dist=max_dist,
            is_training=is_training,
        )
        if torch.any(scores_det < 0.1):
            return persons
        if len(idx[1]) == 0 and not is_training:
            # No humans detected in the frame.
            return persons

        # Dense feature map
        z = unpatch(
            z, patch_size=1, c=z.shape[2], img_size=int(np.sqrt(N))
        )  # [bs, D, h, w] with h = w = sqrt(N)
        z_all = z

        # Extract the 'central' features
        z = torch.reshape(
            z, (z.shape[0], 1, z.shape[1] // 1, z.shape[2], z.shape[3])
        )  # [bs, stack_K, D, h, w]
        z_central = z[idx[0], idx[3], :, idx[1], idx[2]]  # dense vectors

        # 2D offset regression
        offset = self.mlp_offset(z_central)

        # Camera intrinsics
        K_det = K[idx[0]]  # cameras for the detected persons
        z_K = self.embedd_camera(K, z)  # Embed viewing directions.
        z_central = torch.cat(
            [z_central, z_K[idx[0], idx[1], idx[2]]], 1
        )  # Add to query tokens.
        z_all = torch.cat(
            [z_all, z_K.permute(0, 3, 1, 2)], 1
        )  # for the cross-attention only
        z = torch.cat([z, z_K.permute(0, 3, 1, 2).unsqueeze(1)], 2)

        # Location of the primary keypoint on the token grid, refined by the
        # predicted offset and mapped back to pixel space.
        loc = torch.stack([idx[2], idx[1]]).permute(1, 0)
        loc = (loc + 0.5 + offset) * self.patch_size

        # SMPL parameter regression
        kv = z_all[
            idx[0]
        ]  # retrieve the dense features associated with each central vector
        pred_smpl_params, pred_cam = self.x_attention_head(
            z_central, kv, idx_0=idx[0], idx_det=idx
        )

        # Get outputs from the SMPL layer.
        shape = pred_smpl_params["betas"]
        rotmat = torch.cat(
            [pred_smpl_params["global_orient"], pred_smpl_params["body_pose"]], 1
        )
        expression = pred_smpl_params["expression"]
        rotvec = roma.rotmat_to_rotvec(rotmat)

        # Distance
        dist = pred_cam[:, 0][:, None]
        out["dist_postprocessed"] = (
            dist  # raw prediction, before undoing focal-length normalization or the log encoding
        )
        dist = self.to_euclidean_dist(x, dist, K_det)

        # Populate the output dictionary
        out.update(
            {
                "scores": scores,
                "offset": offset,
                "dist": dist,
                "expression": expression,
                "rotmat": rotmat,
                "shape": shape,
                "rotvec": rotvec,
                "loc": loc,
            }
        )
        assert (
            rotvec.shape[0] == shape.shape[0] == loc.shape[0] == dist.shape[0]
        ), "Incoherent shapes"

        if not self.output_mesh:
            out.update(
                {
                    "K_det": K_det,
                    "scores_det": scores_det,
                    "idx": idx,
                }
            )
            return out

        # Neutral
        smpl_out = self.smpl_layer[f"neutral_{self.num_betas}"](
            rotvec, shape, loc, dist, None, K=K_det, expression=expression
        )
        out.update(smpl_out)

        # Return
        if is_training:
            return out
        else:
            # Populate a dictionary for each person
            for i in range(idx[0].shape[0]):
                person = {
                    # Detection
                    "scores": scores_det[i],  # detection scores
                    "loc": out["loc"][i],  # 2D pixel location of the primary keypoint
                    # SMPL-X params
                    "transl": out["transl"][i],  # from the primary keypoint, i.e. the head
                    "transl_pelvis": out["transl_pelvis"][i],  # of the pelvis joint
                    "rotvec": out["rotvec"][i],
                    "expression": out["expression"][i],
                    "shape": out["shape"][i],
                    # SMPL-X meshes
                    "v3d": out["v3d"][i],
                    "j3d": out["j3d"][i],
                    "j2d": out["j2d"][i],
                    "dist": out["dist"][i],
                    "offset": out["offset"][i],
                }
                persons.append(person)
            return persons
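

# Usage sketch (hypothetical values; smplx_dir must point to real SMPL-X assets):
#   model = Model(img_size=896, smplx_dir="/path/to/smplx").eval()
#   x = torch.randn(1, 3, 896, 896)            # RGB batch at model resolution
#   K = torch.tensor([[[1000.0, 0.0, 448.0],
#                      [0.0, 1000.0, 448.0],
#                      [0.0, 0.0, 1.0]]])      # pinhole intrinsics, [1, 3, 3]
#   with torch.no_grad():
#       persons = model(x, K=K, det_thresh=0.3, nms_kernel_size=3)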


class HPH(nn.Module):
    """Cross-attention based SMPL transformer decoder.

    Code modified from:
    https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/heads/smpl_head.py#L17
    https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L301
    """

    def __init__(
        self,
        num_body_joints=52,
        context_dim=1280,
        dim=1024,
        depth=2,
        heads=8,
        mlp_dim=1024,
        dim_head=64,
        dropout=0.0,
        emb_dropout=0.0,
        at_token_res=32,
        num_betas=10,
        smplx_dir=None,
    ):
        super().__init__()
        self.joint_rep_type, self.joint_rep_dim = "6d", 6
        self.num_body_joints = num_body_joints
        self.nrot = self.num_body_joints + 1
        npose = self.joint_rep_dim * (self.num_body_joints + 1)
        self.npose = npose
        self.depth = depth
        self.heads = heads
        self.res = at_token_res
        self.input_is_mean_shape = True
        _context_dim = context_dim  # for the central features
        self.num_betas = num_betas
        assert num_betas in [10, 11]

        # Transformer decoder setup.
        # Based on https://github.com/shubham-goel/4D-Humans/blob/8830bb330558eea2395b7f57088ef0aae7f8fa22/hmr2/configs_hydra/experiment/hmr_vit_transformer.yaml#L35
        transformer_args = dict(
            num_tokens=1,
            token_dim=(
                (npose + self.num_betas + 3 + _context_dim)
                if self.input_is_mean_shape
                else 1
            ),
            dim=dim,
            depth=depth,
            heads=heads,
            mlp_dim=mlp_dim,
            dim_head=dim_head,
            dropout=dropout,
            emb_dropout=emb_dropout,
            context_dim=context_dim,
        )
        self.transformer = TransformerDecoder(**transformer_args)
        dim = transformer_args["dim"]

        # Final decoders to regress the targets
        self.decpose, self.decshape, self.deccam, self.decexpression = [
            nn.Linear(dim, od) for od in [npose, num_betas, 3, 10]
        ]

        # Register buffers for the SMPL layer.
        self.set_smpl_init(smplx_dir)

        # Init learned embeddings for the cross-attention queries
        self.init_learned_queries(context_dim)

    def init_learned_queries(self, context_dim, std=0.2):
        """Init learned embeddings for queries and values."""
        self.cross_queries_x = nn.Parameter(torch.zeros(self.res, context_dim))
        torch.nn.init.normal_(self.cross_queries_x, std=std)
        self.cross_queries_y = nn.Parameter(torch.zeros(self.res, context_dim))
        torch.nn.init.normal_(self.cross_queries_y, std=std)
        self.cross_values_x = nn.Parameter(torch.zeros(self.res, context_dim))
        torch.nn.init.normal_(self.cross_values_x, std=std)
        self.cross_values_y = nn.Parameter(torch.zeros(self.res, context_dim))
        torch.nn.init.normal_(self.cross_values_y, std=std)
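
    # The tables above are axis-separable: the embedding injected at grid cell
    # (i, j) is table_x[i] + table_y[j] (see cross_attn_inputs), which costs
    # 2 * res * context_dim parameters per query/value pair instead of
    # res**2 * context_dim for a full 2D table.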

    def set_smpl_init(self, smplx_dir):
        """Fetch saved SMPL parameters and register buffers."""
        mean_params = np.load(os.path.join(smplx_dir, "smpl_mean_params.npz"))
        if self.nrot == 53:
            init_body_pose = (
                torch.eye(3)
                .reshape(1, 3, 3)
                .repeat(self.nrot, 1, 1)[:, :, :2]
                .flatten(1)
                .reshape(1, -1)
            )
            init_body_pose[:, : 24 * 6] = torch.from_numpy(
                mean_params["pose"][:]
            ).float()  # global_orient + body_pose from SMPL
        else:
            init_body_pose = torch.from_numpy(
                mean_params["pose"].astype(np.float32)
            ).unsqueeze(0)
        init_betas = torch.from_numpy(mean_params["shape"].astype("float32")).unsqueeze(
            0
        )
        init_cam = torch.from_numpy(mean_params["cam"].astype(np.float32)).unsqueeze(0)
        init_betas_kid = torch.cat(
            [init_betas, torch.zeros_like(init_betas[:, [0]])], 1
        )
        init_expression = 0.0 * torch.from_numpy(
            mean_params["shape"].astype("float32")
        ).unsqueeze(0)
        if self.num_betas == 11:
            init_betas = torch.cat([init_betas, torch.zeros_like(init_betas[:, :1])], 1)
        self.register_buffer("init_body_pose", init_body_pose)
        self.register_buffer("init_betas", init_betas)
        self.register_buffer("init_betas_kid", init_betas_kid)
        self.register_buffer("init_cam", init_cam)
        self.register_buffer("init_expression", init_expression)

    def cross_attn_inputs(self, x, x_central, idx_0, idx_det):
        """Reshape and pad x_central to the right shape for cross-attention processing.
        Inject learned embeddings into the query and key/value inputs at the locations of detected people.
        """
        h, w = x.shape[2], x.shape[3]
        x = einops.rearrange(x, "b c h w -> b (h w) c")
        assert idx_0 is not None, "Learned cross queries only work with multicross"
        if idx_0.shape[0] > 0:
            # Reconstruct the batch/nb_people dimensions: pad for images with fewer people than the max.
            counts, idx_det_0 = rebatch(idx_0, idx_det)
            old_shape = x_central.shape
            # Legacy check for old versions
            assert idx_det is not None, "idx_det needed for learned_attention"
            # xx is the tensor with all features
            xx = einops.rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
            # Get learned embeddings for queries, at positions with detected people.
            queries_xy = (
                self.cross_queries_x[idx_det[1]] + self.cross_queries_y[idx_det[2]]
            )
            # Add the embedding to the central features.
            x_central = x_central + queries_xy
            assert x_central.shape == old_shape, "Problem with shape"
            # Make it a tensor of dim. [batch, max_ppl_along_batch, ...]
            x_central, mask = pad_to_max(x_central, counts)
            xx = xx[torch.cumsum(counts, dim=0) - 1]
            # Inject learned embeddings for keys/values at detected locations.
            values_xy = (
                self.cross_values_x[idx_det[1]] + self.cross_values_y[idx_det[2]]
            )
            xx[idx_det_0, :, idx_det[1], idx_det[2]] += values_xy
            x = einops.rearrange(xx, "b c h w -> b (h w) c")
            num_ppl = x_central.shape[1]
        else:
            mask = None
            num_ppl = 1
            counts = None
        return x, x_central, mask, num_ppl, counts

    def forward(self, x_central, x, idx_0=None, idx_det=None, **kwargs):
        """
        Forward the HPH module.
        """
        batch_size = x.shape[0]

        # Reshape inputs for cross-attention and inject learned embeddings for queries and values.
        x, x_central, mask, num_ppl, counts = self.cross_attn_inputs(
            x, x_central, idx_0, idx_det
        )

        # Add the init (mean SMPL params) to the query for each quantity being regressed.
        bs = x_central.shape[0] if idx_0.shape[0] else batch_size
        expand = lambda x: x.expand(bs, num_ppl, -1)
        pred_body_pose, pred_betas, pred_cam, pred_expression = [
            expand(x)
            for x in [
                self.init_body_pose,
                self.init_betas,
                self.init_cam,
                self.init_expression,
            ]
        ]
        token = torch.cat([x_central, pred_body_pose, pred_betas, pred_cam], dim=-1)
        if len(token.shape) == 2:
            token = token[:, None, :]

        # Process the query and inputs with the cross-attention module.
        token_out = self.transformer(token, context=x, mask=mask)

        # Reshape outputs from [batch_size, nmax_ppl, ...] to [total_ppl, ...]
        if mask is not None:
            # Stack along the batch axis.
            token_out_list = [token_out[i, :c, ...] for i, c in enumerate(counts)]
            token_out = torch.concat(token_out_list, dim=0)
        else:
            token_out = token_out.squeeze(1)  # (B, C)

        # Decode the output token and add the init for each quantity to regress.
        reshape = (
            (lambda x: x)
            if idx_0.shape[0] == 0
            else (lambda x: x[0, 0, ...][None, ...])
        )
        decoders = [self.decpose, self.decshape, self.deccam, self.decexpression]
        inits = [pred_body_pose, pred_betas, pred_cam, pred_expression]
        pred_body_pose, pred_betas, pred_cam, pred_expression = [
            d(token_out) + reshape(i) for d, i in zip(decoders, inits)
        ]

        # Convert self.joint_rep_type -> rotation matrices
        joint_conversion_fn = rot6d_to_rotmat
        pred_body_pose = joint_conversion_fn(pred_body_pose).view(
            batch_size, self.num_body_joints + 1, 3, 3
        )

        # Build the output dict
        pred_smpl_params = {
            "global_orient": pred_body_pose[:, [0]],
            "body_pose": pred_body_pose[:, 1:],
            "betas": pred_betas,
            "expression": pred_expression,
        }
        return pred_smpl_params, pred_cam


def regression_mlp(layers_sizes):
    """
    Return a fully connected network.
    """
    assert len(layers_sizes) >= 2
    in_features = layers_sizes[0]
    layers = []
    for i in range(1, len(layers_sizes) - 1):
        out_features = layers_sizes[i]
        layers.append(torch.nn.Linear(in_features, out_features))
        layers.append(torch.nn.ReLU())
        in_features = out_features
    layers.append(torch.nn.Linear(in_features, layers_sizes[-1]))
    return torch.nn.Sequential(*layers)
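

# Example (sketch): regression_mlp([768, 768, 1]) builds
#   Linear(768, 768) -> ReLU -> Linear(768, 1)
# i.e. every hidden layer is followed by a ReLU and the output layer stays linear.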


def apply_threshold(det_thresh, _scores):
    """Apply thresholding to detection scores; if det_thresh is a list (stack_K case), use the first threshold."""
    if isinstance(det_thresh, list):
        det_thresh = det_thresh[0]
    idx = torch.where(_scores >= det_thresh)
    return idx
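

# Example (sketch): for _scores of shape [bs, H, W, 1], apply_threshold returns
# the torch.where index tuple, i.e. four aligned 1-D tensors
# (batch, row, col, channel) listing every token whose score reaches det_thresh.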


def _nms(heat, kernel=3):
    """Easy non-maximum suppression (as in CenterNet)."""
    if kernel not in [2, 4]:
        pad = (kernel - 1) // 2
    else:
        pad = 1 if kernel == 2 else 2
    hmax = nn.functional.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
    if hmax.shape[2] > heat.shape[2]:
        hmax = hmax[:, :, : heat.shape[2], : heat.shape[3]]
    keep = (hmax == heat).float()
    return heat * keep
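

# Example (minimal sketch): max-pool NMS keeps only local maxima of the heatmap.
#   heat = torch.tensor([[[[0.1, 0.9, 0.8],
#                          [0.2, 0.5, 0.3],
#                          [0.4, 0.2, 0.1]]]])
#   _nms(heat, kernel=3)  # only the 0.9 peak survives; its neighbors are zeroed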


def _sigmoid(x):
    y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)
    return y


if __name__ == "__main__":
    # Smoke test; note that instantiation needs a valid smplx_dir to load
    # smpl_mean_params.npz and the SMPL-X model files.
    Model()