import os
import tyro
import mediapy
import torch
import numpy as np
import pyvista as pv
import trimesh
from PIL import Image
from dreifus.matrix import Intrinsics, Pose, CameraCoordinateConvention, PoseType
from dreifus.pyvista import add_camera_frustum, render_from_camera

from pixel3dmm.utils.utils_3d import rotation_6d_to_matrix
from pixel3dmm.env_paths import PREPROCESSED_DATA, TRACKING_OUTPUT
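
# Visualization of pixel3dmm tracking output: loads the per-frame tracked FLAME meshes and
# checkpoints, shows them in a shared PyVista scene together with the corresponding camera
# frustums, and optionally writes a projection-test video that overlays renderings from the
# estimated cameras on the preprocessed input crops. Run from the command line; tyro builds
# the CLI from main()'s signature (vid_name is required, the two boolean flags are optional).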


def main(vid_name: str,
         HEAD_CENTRIC: bool = True,
         DO_PROJECTION_TEST: bool = False,
         ):
    tracking_dir = f'{TRACKING_OUTPUT}/{vid_name}_nV1_noPho_uv2000.0_n1000.0'

    # Per-frame tracked meshes (the canonical mesh is skipped) and tracking checkpoints.
    meshes = [f for f in os.listdir(f'{tracking_dir}/mesh/') if f.endswith('.ply') and 'canonical' not in f]
    meshes.sort()
    ckpts = [f for f in os.listdir(f'{tracking_dir}/checkpoint/') if f.endswith('.frame')]
    ckpts.sort()

    N_STEPS = len(meshes)

    pl = pv.Plotter()
    vid_frames = []

    for i in range(N_STEPS):
        # Tracked mesh and tracking checkpoint (FLAME and camera parameters) for frame i.
        ckpt = torch.load(f'{tracking_dir}/checkpoint/{ckpts[i]}', weights_only=False)
        mesh = trimesh.load(f'{tracking_dir}/mesh/{meshes[i]}', process=False)

        # Global head rotation, stored as a 6D rotation in the checkpoint.
        head_rot = rotation_6d_to_matrix(torch.from_numpy(ckpt['flame']['R'])).numpy()[0]

        if not HEAD_CENTRIC:
            # move mesh from FLAME Space into World Space
            mesh.vertices = mesh.vertices @ head_rot.T + ckpt['flame']['t']
        else:
            # undo neck rotation
            verts_hom = np.concatenate([mesh.vertices, np.ones_like(mesh.vertices[..., :1])], axis=-1)
            verts_hom = verts_hom @ np.linalg.inv(ckpt['joint_transforms'][0, 1, :, :]).T
            mesh.vertices = verts_hom[..., :3]
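
        # World-to-camera extrinsics (OpenGL convention) assembled from the tracked camera
        # parameters; in head-centric mode the head pose is folded into the camera instead
        # of being applied to the mesh.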
        extr_open_gl_world_to_cam = np.eye(4)
        extr_open_gl_world_to_cam[:3, :3] = ckpt['camera']['R_base_0'][0]
        extr_open_gl_world_to_cam[:3, 3] = ckpt['camera']['t_base_0'][0]

        if HEAD_CENTRIC:
            flame2world = np.eye(4)
            flame2world[:3, :3] = head_rot
            flame2world[:3, 3] = np.squeeze(ckpt['flame']['t'])
            # TODO: include neck transform as well
            extr_open_gl_world_to_cam = extr_open_gl_world_to_cam @ flame2world @ ckpt['joint_transforms'][0, 1, :, :]

        extr_open_gl_world_to_cam = Pose(extr_open_gl_world_to_cam,
                                         camera_coordinate_convention=CameraCoordinateConvention.OPEN_GL,
                                         pose_type=PoseType.WORLD_2_CAM)
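
        # Convert the tracked focal length and principal point to pixel-space intrinsics at
        # the 256x256 resolution used below, then add the mesh and its camera frustum to the
        # combined scene; the color encodes the frame index (blue = first, red = last).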
        intr = np.eye(3)
        intr[0, 0] = ckpt['camera']['fl'][0, 0] * 256
        intr[1, 1] = ckpt['camera']['fl'][0, 0] * 256
        intr[:2, 2] = ckpt['camera']['pp'][0] * (256 / 2 + 0.5) + 256 / 2 + 0.5
        intr = Intrinsics(intr)

        pl.add_mesh(mesh, color=[(i / N_STEPS), 0, ((N_STEPS - i) / N_STEPS)])
        add_camera_frustum(pl, extr_open_gl_world_to_cam, intr, color=[(i / N_STEPS), 0, ((N_STEPS - i) / N_STEPS)])
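
        # Optional projection test: render the mesh from the estimated camera and blend it
        # with the corresponding preprocessed input crop.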
        if DO_PROJECTION_TEST:
            pll = pv.Plotter(off_screen=True, window_size=(256, 256))
            pll.add_mesh(mesh)
            img = render_from_camera(pll, extr_open_gl_world_to_cam, intr)
            gt_img = np.array(Image.open(f'{PREPROCESSED_DATA}/{vid_name}/cropped/{i:05d}.jpg').resize((256, 256)))
            alpha = img[..., 3]  # alpha channel of the rendering (currently unused)
            overlay = (gt_img * 0.5 + img[..., :3] * 0.5).astype(np.uint8)
            vid_frames.append(overlay)

    pl.show()

    if DO_PROJECTION_TEST:
        mediapy.write_video(f'{tracking_dir}/projection_test.mp4', images=vid_frames)


if __name__ == '__main__':
    tyro.cli(main)