# Multi-HMR
# Copyright (c) 2024-present NAVER Corp.
# CC BY-NC-SA 4.0 license

from torch import nn
import torch
import numpy as np
import roma
import copy
import einops

from utils import unpatch, inverse_perspective_projection, undo_focal_length_normalization, undo_log_depth
from utils import rot6d_to_rotmat, rebatch, pad_to_max
from utils.constants import MEAN_PARAMS
from blocks import Dinov2Backbone, FourierPositionEncoding, TransformerDecoder, SMPL_Layer


class Model(nn.Module):
    """ A ViT backbone followed by an "HPH" head (a stack of cross-attention layers whose
    queries correspond to detected humans). """

    def __init__(self,
                 backbone='dinov2_vitb14',
                 img_size=896,
                 camera_embedding='geometric',  # 'geometric' encodes viewing directions with a Fourier encoding
                 camera_embedding_num_bands=16,  # increases the size of the camera embedding
                 camera_embedding_max_resolution=64,  # does not increase the size of the camera embedding
                 nearness=True,  # regress log(1/z)
                 xat_depth=2,  # number of cross-attention blocks (SA, CA, MLP) in the HPH head
                 xat_num_heads=8,  # number of attention heads
                 dict_smpl_layer=None,
                 person_center='head',
                 clip_dist=True,
                 *args, **kwargs):
        super().__init__()

        # Save options
        self.img_size = img_size
        self.nearness = nearness
        self.clip_dist = clip_dist
        self.xat_depth = xat_depth
        self.xat_num_heads = xat_num_heads

        # Setup backbone
        self.backbone = Dinov2Backbone(backbone)
        self.embed_dim = self.backbone.embed_dim
        self.patch_size = self.backbone.patch_size
        assert self.img_size % self.patch_size == 0, "Invalid img size"

        # Camera intrinsics
        self.fovn = 60
        self.camera_embedding = camera_embedding
        self.camera_embed_dim = 0
        if self.camera_embedding is not None:
            if not self.camera_embedding == 'geometric':
                raise NotImplementedError("Only geometric camera embedding is implemented")
            self.camera = FourierPositionEncoding(n=3,
                                                  num_bands=camera_embedding_num_bands,
                                                  max_resolution=camera_embedding_max_resolution)
            self.camera_embed_dim = self.camera.channels

        # Heads - Detection
        self.mlp_classif = regression_mlp([self.embed_dim, self.embed_dim, 1])  # bg or human

        # Heads - Human properties
        self.mlp_offset = regression_mlp([self.embed_dim, self.embed_dim, 2])  # offset

        # Dense vector indices
        self.nrot = 53
        self.idx_score, self.idx_offset, self.idx_dist = [0], [1, 2], [3]
        self.idx_pose = list(range(4, 4 + self.nrot * 9))
        self.idx_shape = list(range(4 + self.nrot * 9, 4 + self.nrot * 9 + 11))
        self.idx_expr = list(range(4 + self.nrot * 9 + 11, 4 + self.nrot * 9 + 11 + 10))

        # SMPL layers
        dict_smpl_layer = {'neutral': {10: SMPL_Layer(type='smplx', gender='neutral', num_betas=10,
                                                      kid=False, person_center=person_center)}}
        _moduleDict = []
        for k, _smpl_layer in dict_smpl_layer.items():
            _moduleDict.append([k, copy.deepcopy(_smpl_layer[10])])
        self.smpl_layer = nn.ModuleDict(_moduleDict)

        self.x_attention_head = HPH(num_body_joints=self.nrot - 1,  # = 52
                                    context_dim=self.embed_dim + self.camera_embed_dim,
                                    dim=1024,
                                    depth=self.xat_depth,
                                    heads=self.xat_num_heads,
                                    mlp_dim=1024,
                                    dim_head=32,
                                    dropout=0.0,
                                    emb_dropout=0.0,
                                    at_token_res=self.img_size // self.patch_size)

    def detection(self, z, nms_kernel_size, det_thresh, N):
        """ Detection scores on the entire low-resolution token grid. """
        scores = _sigmoid(self.mlp_classif(z))  # per-token detection score

        # Restore the height and width dimensions.
        scores = unpatch(scores, patch_size=1, c=scores.shape[2], img_size=int(np.sqrt(N)))

        if nms_kernel_size > 1:
            # Easy NMS: suppress adjacent high scores with max pooling.
            scores = _nms(scores, kernel=nms_kernel_size)
        _scores = torch.permute(scores, (0, 2, 3, 1))

        # Binary decision (keep confident detections)
        idx = apply_threshold(det_thresh, _scores)

        # Scores
        scores_detected = scores[idx[0], idx[3], idx[1], idx[2]]  # scores of the detected humans only
        scores = torch.permute(scores, (0, 2, 3, 1))

        return scores, scores_detected, idx

    def embedd_camera(self, K, z):
        """ Embed viewing directions with a Fourier encoding. """
        bs = z.shape[0]
        _h, _w = list(z.shape[-2:])
        points = torch.stack([torch.arange(0, _h, 1).reshape(-1, 1).repeat(1, _w),
                              torch.arange(0, _w, 1).reshape(1, -1).repeat(_h, 1)], -1).to(z.device).float()  # [h,w,2]
        points = points * self.patch_size + self.patch_size // 2  # move to pixel space - pixel center of each token
        points = points.reshape(1, -1, 2).repeat(bs, 1, 1)  # (bs, N, 2): 2D points
        distance = torch.ones(bs, points.shape[1], 1).to(K.device)  # (bs, N, 1): distance in the 3D world
        rays = inverse_perspective_projection(points, K, distance)  # (bs, N, 3)
        rays_embeddings = self.camera(pos=rays)
        # Reshape to a per-token embedding map for each element of the batch.
        z_K = rays_embeddings.reshape(bs, _h, _w, self.camera_embed_dim)  # [bs,h,w,D]
        return z_K

    def to_euclidean_dist(self, x, dist, _K):
        # Undo the focal-length normalization
        focal = _K[:, [0], [0]]
        dist = undo_focal_length_normalization(dist, focal, fovn=self.fovn, img_size=x.shape[-1])

        # Undo the log-space parameterization
        if self.nearness:
            dist = undo_log_depth(dist)

        # Clamping
        if self.clip_dist:
            dist = torch.clamp(dist, 0, 50)

        return dist

    def forward(self, x, idx=None, det_thresh=0.5, nms_kernel_size=3, K=None, *args, **kwargs):
        """ Forward pass of the model.
        Args:
            - x: RGB image - [bs,3,img_size,img_size]
            - idx: GT location of persons - tuple of 3 tensors of shape [p]
            - idx_j2d: GT location of 2d-kpts for each detected human - tensor of shape [bs',14,2] - locations in pixel space
            - K: camera intrinsics - [bs,3,3]
        Return:
            - persons: list of dicts, one per detected human, with detection and SMPL-X outputs
        """
        persons = []
        out = {}

        # Feature extraction
        z = self.backbone(x)
        B, N, C = z.size()  # [bs,N,D] with N = (img_size // patch_size) ** 2

        # Detection
        scores, scores_det, idx = self.detection(z, nms_kernel_size=nms_kernel_size,
                                                 det_thresh=det_thresh, N=N)
        if len(idx[0]) == 0:
            # No humans detected in the frame.
            return persons

        # Map of dense features
        z = unpatch(z, patch_size=1, c=z.shape[2], img_size=int(np.sqrt(N)))  # [bs,D,h,w]
        z_all = z

        # Extract the 'central' features
        z = torch.reshape(z, (z.shape[0], 1, z.shape[1] // 1, z.shape[2], z.shape[3]))  # [bs,stack_K,D,h,w]
        z_central = z[idx[0], idx[3], :, idx[1], idx[2]]  # dense vectors

        # 2D offset regression
        offset = self.mlp_offset(z_central)

        # Camera intrinsics
        K_det = K[idx[0]]  # cameras of the detected persons
        z_K = self.embedd_camera(K, z)  # embed viewing directions
        z_central = torch.cat([z_central, z_K[idx[0], idx[1], idx[2]]], 1)  # add to the query tokens
        z_all = torch.cat([z_all, z_K.permute(0, 3, 1, 2)], 1)  # for the cross-attention only
        z = torch.cat([z, z_K.permute(0, 3, 1, 2).unsqueeze(1)], 2)

        # 2D location of the primary keypoint, used to estimate the 3D location.
        loc = torch.stack([idx[2], idx[1]]).permute(1, 0)
        # Move from token resolution to pixel space, with the sub-token offset.
        loc = (loc + 0.5 + offset) * self.patch_size

        # SMPL parameter regression
        kv = z_all[idx[0]]  # retrieve the dense features associated with each central vector
        pred_smpl_params, pred_cam = self.x_attention_head(z_central, kv, idx_0=idx[0], idx_det=idx)

        # Prepare the inputs of the SMPL layer from the HPH outputs.
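        # pred_smpl_params: 'global_orient' [nb_det,1,3,3], 'body_pose' [nb_det,52,3,3],
        # 'betas' [nb_det,10], 'expression' [nb_det,10]; pred_cam: [nb_det,3] (channel 0 is the distance).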
        shape = pred_smpl_params['betas']
        rotmat = torch.cat([pred_smpl_params['global_orient'], pred_smpl_params['body_pose']], 1)
        expression = pred_smpl_params['expression']
        rotvec = roma.rotmat_to_rotvec(rotmat)

        # Distance
        dist = pred_cam[:, 0][:, None]
        out['dist_postprocessed'] = dist  # before undoing the post-processing (focal-length normalization, inverse or log)
        dist = self.to_euclidean_dist(x, dist, K_det)

        # Populate the output dictionary
        out.update({'scores': scores, 'offset': offset, 'dist': dist, 'expression': expression,
                    'rotmat': rotmat, 'shape': shape, 'rotvec': rotvec, 'loc': loc})
        assert rotvec.shape[0] == shape.shape[0] == loc.shape[0] == dist.shape[0], "Incoherent shapes"

        # Neutral SMPL-X layer
        smpl_out = self.smpl_layer['neutral'](rotvec, shape, loc, dist, None, K=K_det, expression=expression)
        out.update(smpl_out)

        # Populate a dictionary for each person
        for i in range(idx[0].shape[0]):
            person = {
                # Detection
                'scores': scores_det[i],  # detection score
                'loc': out['loc'][i],  # 2d pixel location of the primary keypoint
                # SMPL-X params
                'transl': out['transl'][i],  # of the primary keypoint, i.e. the head
                'transl_pelvis': out['transl_pelvis'][i],  # of the pelvis joint
                'rotvec': out['rotvec'][i],
                'expression': out['expression'][i],
                'shape': out['shape'][i],
                # SMPL-X mesh
                'verts_smplx': out['verts_smplx_cam'][i],
                'j3d_smplx': out['j3d'][i],
                'j2d_smplx': out['j2d'][i],
            }
            persons.append(person)

        return persons


class HPH(nn.Module):
    """ Cross-attention based SMPL transformer decoder.

    Code modified from:
    https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/heads/smpl_head.py#L17
    https://github.com/shubham-goel/4D-Humans/blob/a0def798c7eac811a63c8220fcc22d983b39785e/hmr2/models/components/pose_transformer.py#L301
    """

    def __init__(self,
                 num_body_joints=52,
                 context_dim=1280,
                 dim=1024,
                 depth=2,
                 heads=8,
                 mlp_dim=1024,
                 dim_head=64,
                 dropout=0.0,
                 emb_dropout=0.0,
                 at_token_res=32):
        super().__init__()

        self.joint_rep_type, self.joint_rep_dim = '6d', 6
        self.num_body_joints = num_body_joints
        self.nrot = self.num_body_joints + 1
        npose = self.joint_rep_dim * (self.num_body_joints + 1)
        self.npose = npose
        self.depth = depth
        self.heads = heads
        self.res = at_token_res

        self.input_is_mean_shape = True
        _context_dim = context_dim  # for the central features

        # Transformer decoder setup.
        # Based on https://github.com/shubham-goel/4D-Humans/blob/8830bb330558eea2395b7f57088ef0aae7f8fa22/hmr2/configs_hydra/experiment/hmr_vit_transformer.yaml#L35
        transformer_args = dict(
            num_tokens=1,
            token_dim=(npose + 10 + 3 + _context_dim) if self.input_is_mean_shape else 1,
            dim=dim,
            depth=depth,
            heads=heads,
            mlp_dim=mlp_dim,
            dim_head=dim_head,
            dropout=dropout,
            emb_dropout=emb_dropout,
            context_dim=context_dim,
        )
        self.transformer = TransformerDecoder(**transformer_args)
        dim = transformer_args['dim']

        # Final decoders to regress the targets
        self.decpose = nn.Linear(dim, npose)
        self.decshape = nn.Linear(dim, 10)
        self.deccam = nn.Linear(dim, 3)
        self.decexpression = nn.Linear(dim, 10)

        # Register buffers holding the mean SMPL parameters.
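        # (used in forward() as the initial estimate to which the decoded residuals are added)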
        self.set_smpl_init()

        # Init learned embeddings for the cross-attention queries.
        self.init_learned_queries(context_dim)

    def init_learned_queries(self, context_dim, std=0.2):
        """ Init learned embeddings for the queries. """
        self.cross_queries_x = nn.Parameter(torch.zeros(self.res, context_dim))
        torch.nn.init.normal_(self.cross_queries_x, std=std)

        self.cross_queries_y = nn.Parameter(torch.zeros(self.res, context_dim))
        torch.nn.init.normal_(self.cross_queries_y, std=std)

        self.cross_values_x = nn.Parameter(torch.zeros(self.res, context_dim))
        torch.nn.init.normal_(self.cross_values_x, std=std)

        self.cross_values_y = nn.Parameter(torch.zeros(self.res, context_dim))
        torch.nn.init.normal_(self.cross_values_y, std=std)

    def set_smpl_init(self):
        """ Fetch the saved mean SMPL parameters and register them as buffers. """
        mean_params = np.load(MEAN_PARAMS)
        if self.nrot == 53:
            init_body_pose = torch.eye(3).reshape(1, 3, 3).repeat(self.nrot, 1, 1)[:, :, :2].flatten(1).reshape(1, -1)
            init_body_pose[:, :24 * 6] = torch.from_numpy(mean_params['pose'][:]).float()  # global_orient + body_pose from SMPL
        else:
            init_body_pose = torch.from_numpy(mean_params['pose'].astype(np.float32)).unsqueeze(0)
        init_betas = torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)
        init_cam = torch.from_numpy(mean_params['cam'].astype(np.float32)).unsqueeze(0)
        init_betas_kid = torch.cat([init_betas, torch.zeros_like(init_betas[:, [0]])], 1)
        init_expression = 0. * torch.from_numpy(mean_params['shape'].astype('float32')).unsqueeze(0)

        self.register_buffer('init_body_pose', init_body_pose)
        self.register_buffer('init_betas', init_betas)
        self.register_buffer('init_betas_kid', init_betas_kid)
        self.register_buffer('init_cam', init_cam)
        self.register_buffer('init_expression', init_expression)

    def cross_attn_inputs(self, x, x_central, idx_0, idx_det):
        """ Reshape and pad x_central to the right shape for cross-attention processing.
        Inject learned embeddings into the query and key/value inputs at the locations of detected people.
        """
        h, w = x.shape[2], x.shape[3]
        x = einops.rearrange(x, 'b c h w -> b (h w) c')

        assert idx_0 is not None, "Learned cross queries only work with multicross"

        if idx_0.shape[0] > 0:
            # Reconstruct the batch/nb_people dimensions: pad images with fewer people than the max.
            counts, idx_det_0 = rebatch(idx_0, idx_det)
            old_shape = x_central.shape

            # Legacy check for old versions
            assert idx_det is not None, 'idx_det needed for learned_attention'

            # xx is the tensor with all features
            xx = einops.rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)

            # Get learned embeddings for the queries, at positions with detected people.
            queries_xy = self.cross_queries_x[idx_det[1]] + self.cross_queries_y[idx_det[2]]

            # Add the embedding to the central features.
            x_central = x_central + queries_xy
            assert x_central.shape == old_shape, "Problem with shape"

            # Make it a tensor of dim. [batch, max_ppl_along_batch, ...]
            x_central, mask = pad_to_max(x_central, counts)
            xx = xx[torch.cumsum(counts, dim=0) - 1]

            # Inject learned embeddings for the keys/values at the detected locations.
            values_xy = self.cross_values_x[idx_det[1]] + self.cross_values_y[idx_det[2]]
            xx[idx_det_0, :, idx_det[1], idx_det[2]] += values_xy

            x = einops.rearrange(xx, 'b c h w -> b (h w) c')
            num_ppl = x_central.shape[1]
        else:
            mask = None
            num_ppl = 1
            counts = None

        return x, x_central, mask, num_ppl, counts

    def forward(self, x_central, x, idx_0=None, idx_det=None, **kwargs):
        """ Forward the HPH module.
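        Args:
            - x_central: features of the detected (central) tokens, including the camera embedding - [nb_det, context_dim]
            - x: dense feature maps used as cross-attention context - [nb_det, context_dim, h, w]
            - idx_0: image index of each detection - [nb_det]
            - idx_det: full detection indices (image, row, col, channel) as returned by apply_threshold
        Return:
            - pred_smpl_params: dict with 'global_orient', 'body_pose', 'betas' and 'expression'
            - pred_cam: regressed camera parameters - [nb_det, 3] (only channel 0, the distance, is used downstream)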
""" batch_size = x.shape[0] # Reshape inputs for cross attention and inject learned embeddings for queries and values. x, x_central, mask, num_ppl, counts = self.cross_attn_inputs(x, x_central, idx_0, idx_det) # Add init (mean smpl params) to the query for each quantity being regressed. bs = x_central.shape[0] if idx_0.shape[0] else batch_size expand = lambda x: x.expand(bs, num_ppl , -1) pred_body_pose, pred_betas, pred_cam, pred_expression = [expand(x) for x in [self.init_body_pose, self.init_betas, self.init_cam, self.init_expression]] token = torch.cat([x_central, pred_body_pose, pred_betas, pred_cam], dim=-1) if len(token.shape) == 2: token = token[:,None,:] # Process query and inputs with the cross-attention module. token_out = self.transformer(token, context=x, mask=mask) # Reshape outputs from [batch_size, nmax_ppl, ...] to [total_ppl, ...] if mask is not None: # Stack along batch axis. token_out_list = [token_out[i, :c, ...] for i, c in enumerate(counts)] token_out = torch.concat(token_out_list, dim=0) else: token_out = token_out.squeeze(1) # (B, C) # Decoded output token and add to init for each quantity to regress. reshape = (lambda x: x) if idx_0.shape[0] == 0 else (lambda x: x[0, 0, ...][None, ...]) decoders = [self.decpose, self.decshape, self.deccam, self.decexpression] inits = [pred_body_pose, pred_betas, pred_cam, pred_expression] pred_body_pose, pred_betas, pred_cam, pred_expression = [d(token_out) + reshape(i) for d, i in zip(decoders, inits)] # Convert self.joint_rep_type -> rotmat joint_conversion_fn = rot6d_to_rotmat # conversion pred_body_pose = joint_conversion_fn(pred_body_pose).view(batch_size, self.num_body_joints+1, 3, 3) # Build the output dict pred_smpl_params = {'global_orient': pred_body_pose[:, [0]], 'body_pose': pred_body_pose[:, 1:], 'betas': pred_betas, #'betas_kid': pred_betas_kid, 'expression': pred_expression} return pred_smpl_params, pred_cam #, pred_smpl_params_list def regression_mlp(layers_sizes): """ Return a fully connected network. """ assert len(layers_sizes) >= 2 in_features = layers_sizes[0] layers = [] for i in range(1, len(layers_sizes)-1): out_features = layers_sizes[i] layers.append(torch.nn.Linear(in_features, out_features)) layers.append(torch.nn.ReLU()) in_features = out_features layers.append(torch.nn.Linear(in_features, layers_sizes[-1])) return torch.nn.Sequential(*layers) def apply_threshold(det_thresh, _scores): """ Apply thresholding to detection scores; if stack_K is used and det_thresh is a list, apply to each channel separately """ if isinstance(det_thresh, list): det_thresh = det_thresh[0] idx = torch.where(_scores >= det_thresh) return idx def _nms(heat, kernel=3): """ easy non maximal supression (as in CenterNet) """ if kernel not in [2, 4]: pad = (kernel - 1) // 2 else: if kernel == 2: pad = 1 else: pad = 2 hmax = nn.functional.max_pool2d( heat, (kernel, kernel), stride=1, padding=pad) if hmax.shape[2] > heat.shape[2]: hmax = hmax[:, :, :heat.shape[2], :heat.shape[3]] keep = (hmax == heat).float() return heat * keep def _sigmoid(x): y = torch.clamp(x.sigmoid_(), min=1e-4, max=1-1e-4) return y if __name__ == "__main__": Model()