# NOTE: captured from a Hugging Face Space page (status shown: "Runtime error"); not part of the program.
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: ps-license@tuebingen.mpg.de
import os | |
import gc | |
import logging | |
from lib.common.config import cfg | |
from lib.dataset.mesh_util import ( | |
load_checkpoint, | |
update_mesh_shape_prior_losses, | |
blend_rgb_norm, | |
unwrap, | |
remesh, | |
tensor2variable, | |
) | |
from lib.dataset.TestDataset import TestDataset | |
from lib.common.render import query_color | |
from lib.net.local_affine import LocalAffine | |
from pytorch3d.structures import Meshes | |
from apps.ICON import ICON | |
from termcolor import colored | |
import numpy as np | |
from PIL import Image | |
import trimesh | |
import numpy as np | |
from tqdm import tqdm | |
import torch | |
# Enable cuDNN benchmark mode (autotuned convolution algorithm selection).
torch.backends.cudnn.benchmark = True
# Suppress trimesh log records below ERROR.
logging.getLogger("trimesh").setLevel(logging.ERROR)


def _normal_to_uint8(normal):
    """Convert the first item of a batched normal map to a channel-last uint8 image.

    ``normal[0]`` is permuted so its first axis becomes the last one
    (presumably C,H,W -> H,W,C -- TODO confirm), values are mapped from
    [-1, 1] to [0, 255] and cast with ``astype`` (truncation, not rounding).
    """
    return (
        ((normal[0].permute(1, 2, 0) + 1.0) * 255.0 / 2.0)
        .detach()
        .cpu()
        .numpy()
        .astype(np.uint8)
    )


def _fit_smpl(model, dataset, data, in_tensor, losses, hps_type, device,
              n_iters, patience):
    """Optimise body parameters so the rendered body matches the predicted normals.

    Each iteration renders front/back normals of the current body, predicts
    clothed normals with ``model.netG.normal_filter`` (no grad), and minimises
    a weighted sum of a normal-difference loss and a silhouette loss with SGD.

    Side effects:
        Writes ``T_normal_F/B``, ``normal_F/B`` and ``smpl_verts`` into
        ``in_tensor`` and the ``normal`` / ``silhouette`` values into ``losses``.

    Returns:
        ``(pose, trans, betas, orient)`` -- the optimised leaf tensors.
    """
    optimed_pose = torch.tensor(
        data["body_pose"], device=device, requires_grad=True
    )  # [1,23,3,3]
    optimed_trans = torch.tensor(
        data["trans"], device=device, requires_grad=True
    )  # [3]
    optimed_betas = torch.tensor(
        data["betas"], device=device, requires_grad=True
    )  # [1,10]
    optimed_orient = torch.tensor(
        data["global_orient"], device=device, requires_grad=True
    )  # [1,1,3,3]

    optimizer = torch.optim.SGD(
        [optimed_pose, optimed_trans, optimed_betas, optimed_orient],
        lr=1e-3,
        momentum=0.9,
    )
    # ``verbose`` is not passed: it defaulted to off and is deprecated in
    # recent PyTorch releases.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        min_lr=1e-5,
        patience=patience,
    )

    # Loop-invariant constants (previously rebuilt every iteration).
    flip_yz = torch.tensor([1.0, -1.0, -1.0]).to(device)
    flip_z = torch.tensor([1.0, 1.0, -1.0]).to(device)
    bg_color = torch.Tensor([0.5, 0.5, 0.5]).unsqueeze(0).unsqueeze(0).to(device)

    loop = tqdm(range(n_iters))
    for _ in loop:
        optimizer.zero_grad()

        if hps_type != "pixie":
            smpl_out = dataset.smpl_model(
                betas=optimed_betas,
                body_pose=optimed_pose,
                global_orient=optimed_orient,
                pose2rot=False,
            )
            smpl_verts = (smpl_out.vertices + optimed_trans) * data["scale"]
        else:
            smpl_verts, _, _ = dataset.smpl_model(
                shape_params=optimed_betas,
                expression_params=tensor2variable(data["exp"], device),
                body_pose=optimed_pose,
                global_pose=optimed_orient,
                jaw_pose=tensor2variable(data["jaw_pose"], device),
                left_hand_pose=tensor2variable(data["left_hand_pose"], device),
                right_hand_pose=tensor2variable(data["right_hand_pose"], device),
            )
            smpl_verts = (smpl_verts + optimed_trans) * data["scale"]

        # render optimized mesh (normal, T_normal, image [-1,1])
        in_tensor["T_normal_F"], in_tensor["T_normal_B"] = dataset.render_normal(
            smpl_verts * flip_yz, in_tensor["smpl_faces"]
        )
        T_mask_F, T_mask_B = dataset.render.get_silhouette_image()

        with torch.no_grad():
            in_tensor["normal_F"], in_tensor["normal_B"] = model.netG.normal_filter(
                in_tensor
            )

        # normal loss: rendered body normals vs. predicted normals
        diff_F = torch.abs(in_tensor["T_normal_F"] - in_tensor["normal_F"])
        diff_B = torch.abs(in_tensor["T_normal_B"] - in_tensor["normal_B"])
        losses["normal"]["value"] = (diff_F + diff_B).mean()

        # silhouette loss: a pixel counts as foreground when the channel-sum of
        # (predicted normal colour - 0.5 grey background) is non-zero
        smpl_arr = torch.cat([T_mask_F, T_mask_B], dim=-1)[0]
        gt_arr = torch.cat(
            [in_tensor["normal_F"][0], in_tensor["normal_B"][0]], dim=2
        ).permute(1, 2, 0)
        gt_arr = ((gt_arr + 1.0) * 0.5).to(device)
        gt_arr = ((gt_arr - bg_color).sum(dim=-1) != 0.0).float()
        losses["silhouette"]["value"] = torch.abs(smpl_arr - gt_arr).mean()

        # Weighted sum of the losses
        smpl_loss = 0.0
        desc = "Body Fitting --- "
        for k in ["normal", "silhouette"]:
            desc += f"{k}: {losses[k]['value'] * losses[k]['weight']:.3f} | "
            smpl_loss += losses[k]["value"] * losses[k]["weight"]
        desc += f"Total: {smpl_loss:.3f}"
        loop.set_description(desc)

        smpl_loss.backward()
        optimizer.step()
        scheduler.step(smpl_loss)
        in_tensor["smpl_verts"] = smpl_verts * flip_z

    return optimed_pose, optimed_trans, optimed_betas, optimed_orient


def _refine_cloth(dataset, in_tensor, verts_refine, faces_refine, losses,
                  device, n_iters, patience):
    """Deform the remeshed clothed mesh so its rendered normals match the predictions.

    Optimises a ``LocalAffine`` deformation with Adam against the cloth
    normal loss plus the stiffness / rigidity / shape-prior terms in
    ``losses`` (body-fitting terms and zero-weight terms are excluded).

    Side effects:
        Writes ``P_normal_F/B`` into ``in_tensor`` and cloth loss values into
        ``losses``.

    Returns:
        The deformed ``Meshes`` object.
    """
    mesh_pr = Meshes(verts_refine, faces_refine).to(device)
    local_affine_model = LocalAffine(
        mesh_pr.verts_padded().shape[1],
        mesh_pr.verts_padded().shape[0],
        mesh_pr.edges_packed(),
    ).to(device)
    optimizer = torch.optim.Adam(
        [{'params': local_affine_model.parameters()}], lr=1e-4, amsgrad=True
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.1,
        min_lr=1e-5,
        patience=patience,
    )

    loop = tqdm(range(n_iters))
    for _ in loop:
        optimizer.zero_grad()

        deformed_verts, stiffness, rigid = local_affine_model(
            verts_refine.to(device), return_stiff=True
        )
        mesh_pr = mesh_pr.update_padded(deformed_verts)

        # losses for laplacian, edge, normal consistency
        update_mesh_shape_prior_losses(mesh_pr, losses)

        in_tensor["P_normal_F"], in_tensor["P_normal_B"] = dataset.render_normal(
            mesh_pr.verts_padded(), mesh_pr.faces_padded()
        )
        diff_F = torch.abs(in_tensor["P_normal_F"] - in_tensor["normal_F"])
        diff_B = torch.abs(in_tensor["P_normal_B"] - in_tensor["normal_B"])
        losses["cloth"]["value"] = (diff_F + diff_B).mean()
        losses["stiffness"]["value"] = torch.mean(stiffness)
        losses["rigid"]["value"] = torch.mean(rigid)

        # Weighted sum of the cloth losses
        cloth_loss = torch.tensor(0.0, requires_grad=True).to(device)
        desc = "Cloth Refinement --- "
        for k in losses.keys():
            if k not in ["normal", "silhouette"] and losses[k]["weight"] > 0.0:
                cloth_loss = cloth_loss + losses[k]["value"] * losses[k]["weight"]
                desc += f"{k}:{losses[k]['value'] * losses[k]['weight']:.5f} | "
        desc += f"Total: {cloth_loss:.5f}"
        loop.set_description(desc)

        # update params
        cloth_loss.backward()
        optimizer.step()
        scheduler.step(cloth_loss)

    return mesh_pr


def _process_sample(model, dataset, data, hps_type, device, config_dict, out_root):
    """Fit, reconstruct, refine and export results for one dataset sample.

    Args:
        model: ICON network (``netG.normal_filter`` and ``test_single`` are used).
        dataset: the ``TestDataset`` supplying the body model and renderer.
        data: one sample yielded by ``dataset``.
        hps_type: ``"pixie"`` selects the body-model call that also passes
            expression, jaw and hand poses; anything else the plain call.
        device: torch device used for optimisation.
        config_dict: provides ``loop_smpl``, ``loop_cloth`` and ``patience``.
        out_root: ``<out_dir>/<cfg.name>``; its ``vid``/``png``/``obj``
            sub-directories must already exist.

    Returns:
        ``[smpl_glb, smpl_obj, smpl_npy, refine_glb, refine_obj, video,
        video, overlap_png]``. The ``refine_*`` files are only written when
        ``config_dict['loop_cloth'] > 0``.
    """
    name = data["name"]
    obj_dir = os.path.join(out_root, "obj")
    smpl_obj_path = os.path.join(obj_dir, f"{name}_smpl.obj")
    smpl_glb_path = os.path.join(obj_dir, f"{name}_smpl.glb")
    smpl_npy_path = os.path.join(obj_dir, f"{name}_smpl.npy")
    recon_obj_path = os.path.join(obj_dir, f"{name}_recon.obj")
    refine_obj_path = os.path.join(obj_dir, f"{name}_refine.obj")
    refine_glb_path = os.path.join(obj_dir, f"{name}_refine.glb")
    video_path = os.path.join(out_root, "vid", f"{name}_cloth.mp4")
    overlap_path = os.path.join(out_root, "png", f"{name}_overlap.png")

    in_tensor = {"smpl_faces": data["smpl_faces"], "image": data["image"]}

    losses = {
        # Cloth: Normal_recon - Normal_pred
        "cloth": {"weight": 1e1, "value": 0.0},
        # Cloth: [RT]_v1 - [RT]_v2 (v1-edge-v2)
        "stiffness": {"weight": 1e5, "value": 0.0},
        # Cloth: det(R) = 1
        "rigid": {"weight": 1e5, "value": 0.0},
        # Cloth: edge length
        "edge": {"weight": 0, "value": 0.0},
        # Cloth: normal consistency
        "nc": {"weight": 0, "value": 0.0},
        # Cloth: laplacian smooth
        "laplacian": {"weight": 1e2, "value": 0.0},
        # Body: Normal_pred - Normal_smpl
        "normal": {"weight": 1e0, "value": 0.0},
        # Body: Silhouette_pred - Silhouette_smpl
        "silhouette": {"weight": 1e0, "value": 0.0},
    }

    # ---- 1. body fitting ---------------------------------------------------
    optimed_pose, optimed_trans, optimed_betas, optimed_orient = _fit_smpl(
        model, dataset, data, in_tensor, losses, hps_type, device,
        config_dict['loop_smpl'], config_dict['patience'],
    )

    # Input image blended with the predicted front/back normals (xxx_overlap.png)
    norm_orig_F = unwrap(_normal_to_uint8(in_tensor["normal_F"]), data)
    norm_orig_B = unwrap(_normal_to_uint8(in_tensor["normal_B"]), data)
    mask_orig = unwrap(
        np.repeat(
            data["mask"].permute(1, 2, 0).detach().cpu().numpy(), 3, axis=2
        ).astype(np.uint8),
        data,
    )
    rgb_norm_F = blend_rgb_norm(data["ori_image"], norm_orig_F, mask_orig)
    rgb_norm_B = blend_rgb_norm(data["ori_image"], norm_orig_B, mask_orig)
    Image.fromarray(
        np.concatenate(
            [data["ori_image"].astype(np.uint8), rgb_norm_F, rgb_norm_B], axis=1
        )
    ).save(overlap_path)

    # Fitted body mesh coloured by vertex normals (xxx_smpl.obj / .glb)
    smpl_obj = trimesh.Trimesh(
        in_tensor["smpl_verts"].detach().cpu()[0] * torch.tensor([1.0, -1.0, 1.0]),
        in_tensor["smpl_faces"].detach().cpu()[0],
        process=False,
        maintains_order=True,
    )
    smpl_obj.visual.vertex_colors = (smpl_obj.vertex_normals + 1.0) * 255.0 * 0.5
    smpl_obj.export(smpl_obj_path)
    smpl_obj.export(smpl_glb_path)

    # Body parameters (xxx_smpl.npy). Store detached CPU copies so the file
    # can be unpickled without a GPU and without autograd state.
    smpl_info = {
        'betas': optimed_betas.detach().cpu(),
        'pose': optimed_pose.detach().cpu(),
        'orient': optimed_orient.detach().cpu(),
        'trans': optimed_trans.detach().cpu(),
    }
    np.save(smpl_npy_path, smpl_info, allow_pickle=True)

    # ---- 2. clothed reconstruction ----------------------------------------
    in_tensor.update(
        dataset.compute_vis_cmap(in_tensor["smpl_verts"][0], in_tensor["smpl_faces"][0])
    )
    if cfg.net.prior_type == "pamir":
        in_tensor.update(
            dataset.compute_voxel_verts(
                optimed_pose,
                optimed_orient,
                optimed_betas,
                optimed_trans,
                data["scale"],
            )
        )

    with torch.no_grad():
        verts_pr, faces_pr, _ = model.test_single(in_tensor)

    recon_obj = trimesh.Trimesh(verts_pr, faces_pr, process=False, maintains_order=True)
    recon_obj.visual.vertex_colors = (recon_obj.vertex_normals + 1.0) * 255.0 * 0.5
    recon_obj.export(recon_obj_path)

    # Isotropic Explicit Remeshing for better geometry topology
    verts_refine, faces_refine = remesh(recon_obj_path, 0.5, device)

    # ---- 3. cloth refinement (xxx_refine.obj / .glb) ----------------------
    final = None
    if config_dict['loop_cloth'] > 0:
        mesh_pr = _refine_cloth(
            dataset, in_tensor, verts_refine, faces_refine, losses, device,
            config_dict['loop_cloth'], config_dict['patience'],
        )
        final = trimesh.Trimesh(
            mesh_pr.verts_packed().detach().squeeze(0).cpu(),
            mesh_pr.faces_packed().detach().squeeze(0).cpu(),
            process=False,
            maintains_order=True,
        )
        # only with front texture
        tex_colors = query_color(
            mesh_pr.verts_packed().detach().squeeze(0).cpu(),
            mesh_pr.faces_packed().detach().squeeze(0).cpu(),
            in_tensor["image"],
            device=device,
        )
        # full normal textures
        norm_colors = (
            mesh_pr.verts_normals_padded().squeeze(0).detach().cpu() + 1.0
        ) * 0.5 * 255.0

        final.visual.vertex_colors = tex_colors
        final.export(refine_obj_path)
        final.visual.vertex_colors = norm_colors
        final.export(refine_glb_path)

    # ---- 4. self-rotation video -------------------------------------------
    # Always rendered, regardless of cloth refinement: fall back to the
    # unrefined reconstruction when refinement is disabled (``final`` is None).
    clothed = final if final is not None else recon_obj
    dataset.render.load_meshes(
        [smpl_obj.vertices, clothed.vertices], [smpl_obj.faces, clothed.faces]
    )
    dataset.render.get_rendered_video(
        [data["ori_image"], rgb_norm_F, rgb_norm_B], video_path
    )

    return [smpl_glb_path, smpl_obj_path, smpl_npy_path,
            refine_glb_path, refine_obj_path,
            video_path, video_path, overlap_path]


def generate_model(in_path, model_type):
    """Reconstruct a clothed human from an input image and export the results.

    Pipeline for the first ``TestDataset`` sample built from ``in_path``:
    body fitting against predicted normals/silhouettes, ICON reconstruction,
    remeshing, local-affine cloth refinement, then export of meshes, body
    parameters, an overlap image and a self-rotation video under
    ``./results/<cfg.name>/``.

    Args:
        in_path: image path passed to ``TestDataset`` as ``image_path``.
        model_type: ``'ICON'`` (mapped to ``icon-filter``) or another config
            name, lower-cased and loaded from ``./configs/<name>.yaml``.

    Returns:
        ``[smpl_glb, smpl_obj, smpl_npy, refine_glb, refine_obj, video,
        video, overlap_png]`` output paths (the video path appears twice).

    Raises:
        ValueError: if ``TestDataset`` yields no sample for ``in_path``.
    """
    torch.cuda.empty_cache()

    if model_type == 'ICON':
        model_type = 'icon-filter'
    else:
        model_type = model_type.lower()

    config_dict = {
        'loop_smpl': 100,
        'loop_cloth': 200,
        'patience': 5,
        'out_dir': './results',
        'hps_type': 'pymaf',
        'config': f"./configs/{model_type}.yaml",
    }

    # ``cfg`` is module-global and gets frozen below. Without defrosting, every
    # call after the first fails inside ``merge_from_file`` on the immutable
    # node (yacs-style CfgNode assumed).
    # NOTE(review): merging into the shared cfg keeps keys that only a
    # previously loaded model config set -- confirm the yaml files overlap.
    cfg.defrost()
    cfg.merge_from_file(config_dict['config'])
    cfg.merge_from_file("./lib/pymaf/configs/pymaf_config.yaml")

    os.makedirs(config_dict['out_dir'], exist_ok=True)

    cfg.merge_from_list([
        "test_gpus", [0],
        "mcube_res", 256,
        "clean_mesh", True,
    ])
    cfg.freeze()

    # NOTE(review): only effective if CUDA is not yet initialised in this
    # process, which is not the case on repeated calls.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    device = torch.device("cuda:0")

    # load model and dataloader
    model = ICON(cfg)
    model = load_checkpoint(model, cfg)

    dataset_param = {
        'image_path': in_path,
        'seg_dir': None,
        'has_det': True,  # w/ or w/o detection
        'hps_type': 'pymaf',  # pymaf/pare/pixie
    }

    # With the hard-coded ``hps_type='pymaf'`` above this never triggers.
    if config_dict['hps_type'] == "pixie" and "pamir" in config_dict['config']:
        print(colored("PIXIE isn't compatible with PaMIR, thus switch to PyMAF", "red"))
        dataset_param["hps_type"] = "pymaf"

    dataset = TestDataset(dataset_param, device)

    print(colored(f"Dataset Size: {len(dataset)}", "green"))
    if len(dataset) == 0:
        raise ValueError(f"No input sample could be built from {in_path!r}")

    # refinement/ is created but nothing in this function writes to it.
    out_root = os.path.join(config_dict['out_dir'], cfg.name)
    for sub_dir in ("refinement", "vid", "png", "obj"):
        os.makedirs(os.path.join(out_root, sub_dir), exist_ok=True)

    output_paths = None
    pbar = tqdm(dataset)
    for data in pbar:
        pbar.set_description(f"{data['name']}")
        output_paths = _process_sample(
            model, dataset, data, dataset_param["hps_type"], device,
            config_dict, out_root,
        )
        # Only the first sample is processed (the demo passes a single image).
        break
    pbar.close()

    # Drop the references that keep GPU memory alive before asking the CUDA
    # caching allocator to release it. Mutating ``locals()`` (what the old
    # cleanup loop did) does not unbind a function's local variables.
    model = dataset = pbar = data = None
    gc.collect()
    torch.cuda.empty_cache()

    return output_paths