# python scripts/datascripts/get_part_seg_mask.py --data_npz data/rich_val_smplx_small.npz --model_type 'smplx'
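# Renders a body-part segmentation mask for every image listed in the npz: the
# SMPL/SMPL-X bodies are posed, rendered with a part-segmentation texture through
# PyTorch3D, and the rendered colors are bucketized into integer part labels that
# are written as 3-channel label images to the per-frame 'part_seg' paths.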
import os
import sys
sys.path.append('/is/cluster/work/achatterjee/dca_contact')
import cv2
import argparse
import numpy as np
import torch
from common import constants
from models.smpl import SMPL
from smplx import SMPLX
from utils.mesh_utils import save_results_mesh
import trimesh
from tqdm import tqdm
from utils.image_utils import get_body_part_texture, generate_part_labels
from utils.diff_renderer import Pytorch3D


class PART_LABELER:
    def __init__(self, body_params, img_w, img_h, model_type, debug=False):
        """
        Render part-segmentation masks for images.
        Args:
            body_params: SMPL/SMPL-X parameters ('shape', 'pose', 'transl', 'cam_k')
            img_w: image width
            img_h: image height
            model_type: 'smpl' or 'smplx'
            debug: if True, also dump the posed meshes as OBJ files
        """
        # NOTE: relies on the module-level `args` parsed in __main__ for the GPU id
        self.device = torch.device('cuda:{}'.format(args.gpu)) if torch.cuda.is_available() else torch.device('cpu')
        self.model_type = model_type

        # Set up the body model
        if self.model_type == 'smpl':
            self.body_model = SMPL(constants.SMPL_MODEL_DIR).to(self.device)
        elif self.model_type == 'smplx':
            self.body_model = SMPLX(constants.SMPL_MODEL_DIR,
                                    num_betas=10,
                                    use_pca=False).to(self.device)

        self.body_part_vertex_colors, self.body_part_texture = get_body_part_texture(self.body_model.faces,
                                                                                     model_type=self.model_type,
                                                                                     non_parametric=False)
        # bins are discrete part labels; add eps to avoid quantization error
        eps = 1e-2
        # self.part_label_bins = (torch.arange(int(constants.N_PARTS)) / float(constants.N_PARTS)) + eps
        self.part_label_bins = torch.linspace(0, constants.N_PARTS - 1, constants.N_PARTS) + eps
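        # NOTE: these bins, together with torch.bucketize in bucketize_part_image below,
        # map the rendered part colors back to integer part labels; the exact
        # color-to-label convention is defined by get_body_part_texture.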

        # Run the body model forward to get posed vertices and joints
        self.body_params = body_params
        self.smpl_verts, self.smpl_joints = self.get_posed_mesh(debug)

        # Assume the same focal length for all frames in a sequence
        focal_length = self.body_params['cam_k'][0, 0, 0]
        # focal_length = focal_length[0]

        # Setup Pyrender renderer (unused; the PyTorch3D renderer below is used instead)
        # self.renderer = Renderer(focal_length=focal_length, img_w=img_w, img_h=img_h,
        #                          faces=self.smpl_model.faces,
        #                          same_mesh_color=False)

        # Set up the PyTorch3D renderer with the part-segmentation texture
        focal_length = torch.FloatTensor([focal_length])
        smpl_faces = torch.from_numpy(self.body_model.faces.astype(np.int32)).to(self.device)
        self.renderer = Pytorch3D(img_h=img_h,
                                  img_w=img_w,
                                  focal_length=focal_length,
                                  smpl_faces=smpl_faces,
                                  texture_mode='partseg',
                                  vertex_colors=self.body_part_vertex_colors,
                                  face_textures=self.body_part_texture,
                                  model_type=self.model_type)

    def get_posed_mesh(self, debug=False):
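        """Run the body model forward and return posed vertices and joints as numpy arrays.

        If debug is True, every posed mesh is also written to temp_meshes/ as an OBJ file.
        """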
        betas = torch.from_numpy(self.body_params['shape']).float().to(self.device)
        pose = torch.from_numpy(self.body_params['pose']).float().to(self.device)
        transl = torch.from_numpy(self.body_params['transl']).float().to(self.device)

        # extra SMPL-X params (zeros for jaw, eyes, expression and hands)
        extra_args = {'jaw_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
                      'leye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
                      'reye_pose': torch.zeros((betas.shape[0], 3)).float().to(self.device),
                      'expression': torch.zeros((betas.shape[0], 10)).float().to(self.device),
                      'left_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device),
                      'right_hand_pose': torch.zeros((betas.shape[0], 45)).float().to(self.device)}

        smpl_output = self.body_model(betas=betas,
                                      body_pose=pose[:, 3:],
                                      global_orient=pose[:, :3],
                                      pose2rot=True,
                                      transl=transl,
                                      **extra_args)
        smpl_verts = smpl_output.vertices.detach().cpu().numpy()
        smpl_joints = smpl_output.joints.detach().cpu().numpy()

        if debug:
            for mesh_i in range(smpl_verts.shape[0]):
                out_dir = 'temp_meshes'
                os.makedirs(out_dir, exist_ok=True)
                out_file = os.path.join(out_dir, f'temp_mesh_{mesh_i:04d}.obj')
                save_results_mesh(smpl_verts[mesh_i], self.body_model.faces, out_file)
        return smpl_verts, smpl_joints

    def bucketize_part_image(self, color_rgb, mask):
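        """Convert the rendered part-color image into a single-channel integer label map
        (0 = background, 1..N_PARTS = body parts)."""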
        # reduce the rendered part colors to a single channel of discrete labels
        body_parts = color_rgb.clone()
        body_parts *= 255.  # scale to [0, 255] so the part labels are well separated
        body_parts = body_parts.max(-1)[0]  # reduce to a single channel
        body_parts = torch.bucketize(body_parts, self.part_label_bins, right=True)  # np.digitize(body_parts, bins, right=True)
        # shift labels to 1..N_PARTS so that multiplying by the mask sets the background to 0
        body_parts = body_parts.long() + 1
        body_parts = body_parts * mask.detach()
        return body_parts.long()

    def create_part_masks(self, body_parts):
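        """Split a label map into one boolean mask per body part (labels 1..N_PARTS)."""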
        # extract a boolean mask for every body part (label 0 is the background)
        part_masks = []
        for part_id in range(1, constants.N_PARTS + 1):
            part_mask = (body_parts == part_id)
            part_masks.append(part_mask)
        return part_masks

    def render_part_mask_p3d(self, img_paths, out_dir):
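        """Render a part-label image for every frame and write it to out_dir[index].

        Note: despite its name, out_dir here is a per-frame sequence of output file
        paths (the 'part_seg' entries from the npz), not a directory.
        """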
        with torch.no_grad():
            # os.makedirs(out_dir, exist_ok=True)
            for index, img_path in tqdm(enumerate(img_paths), dynamic_ncols=True):
                # Resolve the image path on the cluster if the original location does not exist
                if not os.path.exists(img_path):
                    if 'train' in img_path:
                        split = 'train'
                    elif 'val' in img_path:
                        split = 'val'
                    else:
                        split = 'test'
                    new_img_name = img_path[img_path.index(split) + 4:].replace('/', '_')
                    new_path = os.path.join('/is/cluster/work/achatterjee/rich/images', split, new_img_name.replace('jpeg', 'bmp'))
                    if not os.path.exists(new_path):
                        new_path = new_path.replace('bmp', 'png')
                    img_path = new_path

                # skip frames whose part mask has already been written
                if os.path.exists(out_dir[index]):
                    continue
                # img_bgr = cv2.imread(img_path)

                # render the posed mesh for this frame with the part-segmentation texture
                chosen_vert_arr = torch.FloatTensor(self.smpl_verts[[index]]).to(self.device)
                front_view = self.renderer(chosen_vert_arr)
                front_view_rgb = front_view[0, :3, :, :].permute(1, 2, 0).detach().cpu()
                front_view_mask = front_view[0, 3, :, :].detach().cpu()
                # front_view_depth = front_view[0, 4, :, :].detach().cpu()

                body_parts = self.bucketize_part_image(front_view_rgb, front_view_mask)
                body_parts = body_parts.numpy()
                front_view_rgb = front_view_rgb.numpy()

                # body_part_masks = self.create_part_masks(body_parts)
                # display part masks
                # for part_id, part_mask in enumerate(body_part_masks):
                #     part_mask = part_mask * 255
                #     part_dir = os.path.join(out_dir, f'frame_{index:04d}_parts')
                #     os.makedirs(part_dir, exist_ok=True)
                #     out_file = os.path.join(part_dir, f'part_{part_id:02d}_{index:04d}.png')
                #     cv2.imwrite(out_file, part_mask)

                # out_file = os.path.join(out_dir, f'front_view_{index:04d}.png')
                # cv2.imwrite(out_file, front_view_rgb[:, :, [2, 1, 0]] * 255)
                # print(f'wrote front view to {out_file}')

                # write the label map as a 3-channel image; cast to uint8 since OpenCV
                # cannot write 64-bit integer images (part labels fit comfortably in uint8)
                body_parts = body_parts.astype(np.uint8)
                body_parts = cv2.merge((body_parts, body_parts, body_parts))
                # out_file = os.path.join(out_dir, f'body_parts_{index:04d}.png')
                out_file = out_dir[index]
                cv2.imwrite(out_file, body_parts)
                # print(f'wrote body part masks to {out_file}')


def main(args):
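    """Load the data npz, infer the image resolution from the first image, and render
    the part-segmentation masks for all frames."""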
    out_dir = args.out_dir
    data_md = np.load(args.data_npz)

    # image paths and the per-frame output paths for the part-segmentation masks
    img_paths = data_md['imgname']
    seg_paths = data_md['part_seg']
    print(f'found {len(img_paths)} images')

    # load the first image to get the image resolution
    img = cv2.imread(img_paths[0])
    img_h, img_w, _ = img.shape

    labeler = PART_LABELER(body_params=data_md, img_w=img_w, img_h=img_h,
                           model_type=args.model_type, debug=args.debug)
    labeler.render_part_mask_p3d(img_paths=img_paths, out_dir=seg_paths)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--out_dir', type=str, default='./temp_part_masks/', help='output folder for part masks')
    parser.add_argument('--data_npz', type=str, default='.', help='npz file with SMPL/SMPL-X parameters')
    parser.add_argument('--model_type', type=str, default='smplx', choices=['smpl', 'smplx'], help='body model type')
    parser.add_argument('--gpu', type=int, default=0, help='gpu id')
    parser.add_argument('--debug', action='store_true', default=False, help='debug mode: also dump posed meshes as OBJ files')
    args = parser.parse_args()
    main(args)