# python scripts/datascripts/get_part_seg_mask.py --data_npz data/rich_val_smplx_small.npz --model_type 'smplx'
import os
import sys
sys.path.append('/is/cluster/work/achatterjee/dca_contact')
import cv2
import argparse
import numpy as np
import torch
from tqdm import tqdm

from common import constants
from models.smpl import SMPL
from smplx import SMPLX
from utils.mesh_utils import save_results_mesh
from utils.image_utils import get_body_part_texture
from utils.diff_renderer import Pytorch3D


class PART_LABELER:
    def __init__(self, body_params, img_w, img_h, model_type, gpu=0, debug=False):
        """
        Render part segmentation masks for images.

        Args:
            body_params: dict/npz with SMPL(-X) parameters ('pose', 'shape', 'transl', 'cam_k')
            img_w: image width
            img_h: image height
            model_type: 'smpl' or 'smplx'
            gpu: CUDA device id
            debug: if True, dump the posed meshes as .obj files
        """
        self.device = torch.device(f'cuda:{gpu}') if torch.cuda.is_available() else torch.device('cpu')
        self.model_type = model_type

        # Set up the body model
        if self.model_type == 'smpl':
            self.body_model = SMPL(constants.SMPL_MODEL_DIR).to(self.device)
        elif self.model_type == 'smplx':
            self.body_model = SMPLX(constants.SMPL_MODEL_DIR, num_betas=10, use_pca=False).to(self.device)
        else:
            raise ValueError(f'unknown model_type: {model_type}')

        self.body_part_vertex_colors, self.body_part_texture = get_body_part_texture(
            self.body_model.faces, model_type=self.model_type, non_parametric=False)

        # bins are discrete part labels; add eps to avoid quantization errors
        eps = 1e-2
        self.part_label_bins = torch.linspace(0, constants.N_PARTS - 1, constants.N_PARTS) + eps

        # Run the body model forward to get the posed vertices
        self.body_params = body_params
        self.smpl_verts, self.smpl_joints = self.get_posed_mesh(debug)

        # Assume the same focal length for all frames in a sequence
        focal_length = self.body_params['cam_k'][0, 0, 0]

        # Set up the PyTorch3D renderer with the part-segmentation texture
        focal_length = torch.FloatTensor([focal_length])
        smpl_faces = torch.from_numpy(self.body_model.faces.astype(np.int32)).to(self.device)
        self.renderer = Pytorch3D(img_h=img_h,
                                  img_w=img_w,
                                  focal_length=focal_length,
                                  smpl_faces=smpl_faces,
                                  texture_mode='partseg',
                                  vertex_colors=self.body_part_vertex_colors,
                                  face_textures=self.body_part_texture,
                                  model_type=self.model_type)

    def get_posed_mesh(self, debug=False):
        betas = torch.from_numpy(self.body_params['shape']).float().to(self.device)
        pose = torch.from_numpy(self.body_params['pose']).float().to(self.device)
        transl = torch.from_numpy(self.body_params['transl']).float().to(self.device)

        # extra SMPL-X parameters (all zeros); ignored by the SMPL model
        extra_args = {'jaw_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
                      'leye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
                      'reye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
                      'expression': torch.zeros((betas.shape[0], 10)).float().to(self.device),
                      'left_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device),
                      'right_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device)}
        smpl_output = self.body_model(betas=betas,
                                      body_pose=pose[:, 3:],
                                      global_orient=pose[:, :3],
                                      pose2rot=True,
                                      transl=transl,
                                      **extra_args)
        smpl_verts = smpl_output.vertices.detach().cpu().numpy()
        smpl_joints = smpl_output.joints.detach().cpu().numpy()

        if debug:
            for mesh_i in range(smpl_verts.shape[0]):
                out_dir = 'temp_meshes'
                os.makedirs(out_dir, exist_ok=True)
                out_file = os.path.join(out_dir, f'temp_mesh_{mesh_i:04d}.obj')
                save_results_mesh(smpl_verts[mesh_i], self.body_model.faces, out_file)
        return smpl_verts, smpl_joints

    def bucketize_part_image(self, color_rgb, mask):
        # reduce the rendered part colors to a single channel of discrete labels
        body_parts = color_rgb.clone()
        body_parts *= 255.  # scale to [0, 255] so the labels are well separated
        body_parts = body_parts.max(-1)[0]  # reduce to a single channel
        body_parts = torch.bucketize(body_parts, self.part_label_bins, right=True)
        # shift by 1 so that the background gets label 0
        body_parts = body_parts.long() + 1
        body_parts = body_parts * mask.detach()
        return body_parts.long()

    def create_part_masks(self, body_parts):
        # extract a separate binary mask per part label (label 0 is background)
        part_masks = []
        for part_id in range(1, constants.N_PARTS + 1):
            part_mask = (body_parts == part_id)
            part_masks.append(part_mask)
        return part_masks

    def render_part_mask_p3d(self, img_paths, out_paths):
        with torch.no_grad():
            for index, img_path in tqdm(enumerate(img_paths), total=len(img_paths), dynamic_ncols=True):
                # If the original image is missing, fall back to a cluster-local copy
                # with flattened file names
                if not os.path.exists(img_path):
                    if 'train' in img_path:
                        split = 'train'
                    elif 'val' in img_path:
                        split = 'val'
                    else:
                        split = 'test'
                    new_img_name = img_path[img_path.index(split) + 4:].replace('/', '_')
                    new_path = os.path.join('/is/cluster/work/achatterjee/rich/images', split,
                                            new_img_name.replace('jpeg', 'bmp'))
                    if not os.path.exists(new_path):
                        new_path = new_path.replace('bmp', 'png')
                    img_path = new_path

                # skip frames whose part mask has already been written
                if os.path.exists(out_paths[index]):
                    continue

                chosen_vert_arr = torch.FloatTensor(self.smpl_verts[[index]]).to(self.device)

                front_view = self.renderer(chosen_vert_arr)
                front_view_rgb = front_view[0, :3, :, :].permute(1, 2, 0).detach().cpu()
                front_view_mask = front_view[0, 3, :, :].detach().cpu()
                # front_view_depth = front_view[0, 4, :, :].detach().cpu()

                body_parts = self.bucketize_part_image(front_view_rgb, front_view_mask)
                # cast to uint8: part labels are small and OpenCV cannot write int64 images
                body_parts = body_parts.numpy().astype(np.uint8)
                # (for debugging, per-part binary masks can be obtained via self.create_part_masks)

                # write the label image as a 3-channel png to the per-frame output path
                body_parts = cv2.merge((body_parts, body_parts, body_parts))
                out_file = out_paths[index]
                cv2.imwrite(out_file, body_parts)


def main(args):
    data_md = np.load(args.data_npz)

    # image paths and per-frame output mask paths stored in the npz
    img_paths = data_md['imgname']
    seg_paths = data_md['part_seg']
    print(f'found {len(img_paths)} images')

    # load the first image to get the render resolution
    img = cv2.imread(img_paths[0])
    img_h, img_w, _ = img.shape

    labeler = PART_LABELER(body_params=data_md, img_w=img_w, img_h=img_h,
                           model_type=args.model_type, gpu=args.gpu, debug=args.debug)
    labeler.render_part_mask_p3d(img_paths=img_paths, out_paths=seg_paths)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--out_dir', type=str, default='./temp_part_masks/',
                        help='output folder for part masks (currently unused; masks are written to the part_seg paths in the npz)')
    parser.add_argument('--data_npz', type=str, default='.', help='path to the npz with SMPL/SMPL-X parameters')
    parser.add_argument('--model_type', type=str, default='smplx',
                        choices=['smpl', 'smplx'], help='body model type')
    parser.add_argument('--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--debug', action='store_true', help='debug mode: dump posed meshes as .obj files')
    args = parser.parse_args()
    main(args)
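

# ---------------------------------------------------------------------------
# Illustrative sketch (not called anywhere): a hypothetical helper showing the
# npz layout this script expects, based on the keys read above ('imgname',
# 'part_seg', 'pose', 'shape', 'transl', 'cam_k'). The array shapes are
# assumptions inferred from how the fields are used (10 betas, per-frame 3x3
# intrinsics, axis-angle pose with the global orientation in the first 3 values).
def _example_build_data_npz(out_npz, img_paths, seg_paths, pose, shape, transl, cam_k):
    """Hypothetical helper: write a minimal npz compatible with this script."""
    np.savez(out_npz,
             imgname=np.asarray(img_paths),   # (N,) input image paths
             part_seg=np.asarray(seg_paths),  # (N,) output paths for the rendered part masks
             pose=pose,                       # (N, 66) SMPL-X or (N, 72) SMPL axis-angle, global orient first
             shape=shape,                     # (N, 10) betas
             transl=transl,                   # (N, 3) root translation
             cam_k=cam_k)                     # (N, 3, 3) intrinsics; cam_k[0, 0, 0] is used as the focal length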