File size: 3,510 Bytes
cf92dec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
import os
import tyro
import mediapy
import torch
import numpy as np
import pyvista as pv
import trimesh
from PIL import Image

from dreifus.matrix import Intrinsics, Pose, CameraCoordinateConvention, PoseType
from dreifus.pyvista import add_camera_frustum, render_from_camera

from pixel3dmm.utils.utils_3d import rotation_6d_to_matrix
from pixel3dmm.env_paths import PREPROCESSED_DATA, TRACKING_OUTPUT


def main(vid_name: str,
         HEAD_CENTRIC: bool = True,
         DO_PROJECTION_TEST: bool = False,
         ):
    """Visualize tracked FLAME meshes together with their per-frame cameras.

    Loads every per-frame mesh (``*.ply`` files whose name does not contain
    ``canonical``) and every per-frame checkpoint (``*.frame``) from the
    tracking output directory of ``vid_name``. Meshes and checkpoints are paired
    purely by their sorted filename order. Each mesh and its camera frustum are
    added to one interactive PyVista scene, colour-coded from blue (first frame)
    to red (last frame).

    Args:
        vid_name: Name of the tracked video. Used to locate the tracking output
            directory and, for the projection test, the cropped input frames.
        HEAD_CENTRIC: If True, only the neck joint transform is undone on the
            vertices, and the head pose (plus the neck transform) is folded into
            the camera extrinsics instead. If False, vertices are moved from
            FLAME space into world space with the tracked head rotation and
            translation, and the base camera is used unchanged.
        DO_PROJECTION_TEST: If True, additionally render each mesh off-screen
            from its camera, blend it 50/50 with the corresponding cropped input
            frame, and write the frames to ``<tracking_dir>/projection_test.mp4``.

    Raises:
        ValueError: If there are fewer checkpoints than meshes, so that not
            every mesh can be paired with a checkpoint.
    """
    # Resolution (pixels) at which intrinsics, off-screen renders and GT frames
    # are expressed.
    img_size = 256
    tracking_dir = f'{TRACKING_OUTPUT}/{vid_name}_nV1_noPho_uv2000.0_n1000.0'

    meshes = sorted(f for f in os.listdir(f'{tracking_dir}/mesh/')
                    if f.endswith('.ply') and 'canonical' not in f)
    ckpts = sorted(f for f in os.listdir(f'{tracking_dir}/checkpoint/')
                   if f.endswith('.frame'))

    # Pairing is positional, so every mesh needs a checkpoint at the same index.
    if len(ckpts) < len(meshes):
        raise ValueError(
            f'Found {len(meshes)} meshes but only {len(ckpts)} checkpoints in '
            f'{tracking_dir}; cannot pair every mesh with a checkpoint.'
        )

    n_steps = len(meshes)

    pl = pv.Plotter()
    vid_frames = []
    for i, (mesh_name, ckpt_name) in enumerate(zip(meshes, ckpts)):
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints from a trusted source.
        ckpt = torch.load(f'{tracking_dir}/checkpoint/{ckpt_name}', weights_only=False)
        mesh = trimesh.load(f'{tracking_dir}/mesh/{mesh_name}', process=False)

        # Head rotation from its 6D parameterisation; [0] takes the first batch entry.
        head_rot = rotation_6d_to_matrix(torch.from_numpy(ckpt['flame']['R'])).numpy()[0]
        head_trans = np.squeeze(ckpt['flame']['t'])

        # Base camera extrinsics (world -> camera, OpenGL convention).
        world_to_cam = np.eye(4)
        world_to_cam[:3, :3] = ckpt['camera']['R_base_0'][0]
        world_to_cam[:3, 3] = ckpt['camera']['t_base_0'][0]

        if HEAD_CENTRIC:
            neck = ckpt['joint_transforms'][0, 1, :, :]
            # Undo the neck joint transform on the (row-vector) vertices.
            verts_hom = np.concatenate(
                [mesh.vertices, np.ones_like(mesh.vertices[..., :1])], axis=-1)
            mesh.vertices = (verts_hom @ np.linalg.inv(neck).T)[..., :3]

            # Fold head pose and neck transform into the extrinsics:
            # world_to_cam @ flame2world @ neck @ inv(neck) @ v equals the
            # world-space placement used in the non-head-centric branch.
            flame2world = np.eye(4)
            flame2world[:3, :3] = head_rot
            flame2world[:3, 3] = head_trans
            world_to_cam = world_to_cam @ flame2world @ neck
        else:
            # Move the mesh from FLAME space into world space.
            mesh.vertices = mesh.vertices @ head_rot.T + head_trans

        cam_pose = Pose(world_to_cam,
                        camera_coordinate_convention=CameraCoordinateConvention.OPEN_GL,
                        pose_type=PoseType.WORLD_2_CAM)

        # Focal length and principal point appear to be stored normalized and
        # are rescaled to pixels here (pp presumably in [-1, 1] -- TODO confirm).
        focal = ckpt['camera']['fl'][0, 0] * img_size
        K = np.eye(3)
        K[0, 0] = focal
        K[1, 1] = focal
        K[:2, 2] = ckpt['camera']['pp'][0] * (img_size / 2 + 0.5) + img_size / 2 + 0.5
        intr = Intrinsics(K)

        # Blue for the first frame, fading to red for the last one.
        color = [i / n_steps, 0, (n_steps - i) / n_steps]
        pl.add_mesh(mesh, color=color)
        add_camera_frustum(pl, cam_pose, intr, color=color)

        if DO_PROJECTION_TEST:
            offscreen = pv.Plotter(off_screen=True, window_size=(img_size, img_size))
            try:
                offscreen.add_mesh(mesh)
                img = render_from_camera(offscreen, cam_pose, intr)
            finally:
                # Release the per-frame off-screen render window.
                offscreen.close()

            frame_path = f'{PREPROCESSED_DATA}/{vid_name}/cropped/{i:05d}.jpg'
            with Image.open(frame_path) as frame:
                # convert('RGB') guards against grayscale/RGBA frames, which
                # would not broadcast against the 3-channel render.
                gt_img = np.array(frame.convert('RGB').resize((img_size, img_size)))

            # 50/50 blend of the GT frame and the RGB channels of the render.
            overlay = (gt_img * 0.5 + img[..., :3] * 0.5).astype(np.uint8)
            vid_frames.append(overlay)

    pl.show()

    if DO_PROJECTION_TEST and vid_frames:
        mediapy.write_video(f'{tracking_dir}/projection_test.mp4', images=vid_frames)


if __name__ == "__main__":
    # Build a command-line interface from main()'s signature with tyro and run it.
    tyro.cli(main)