DECO / vis /visualize.py
ac5113's picture
added files
99a05f0
raw
history blame
11.2 kB
import cv2
import os
import trimesh
import PIL.Image as pil_img
import numpy as np
import pyrender
from common import constants
os.environ['PYOPENGL_PLATFORM'] = 'egl'
def render_image(scene, img_res, img=None, viewer=False):
'''
Render the given pyrender scene and return the image. Can also overlay the mesh on an image.
'''
if viewer:
pyrender.Viewer(scene, use_raymond_lighting=True)
return 0
else:
r = pyrender.OffscreenRenderer(viewport_width=img_res,
viewport_height=img_res,
point_size=1.0)
color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA)
color = color.astype(np.float32) / 255.0
if img is not None:
valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis]
input_img = img.detach().cpu().numpy()
output_img = (color[:, :, :-1] * valid_mask +
(1 - valid_mask) * input_img)
else:
output_img = color
return output_img
def create_scene(mesh, img, focal_length=500, camera_center=250, img_res=500):
# Setup the scene
scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0, 1.0],
ambient_light=(0.3, 0.3, 0.3))
# add mesh for camera
camera_pose = np.eye(4)
camera_rotation = np.eye(3, 3)
camera_translation = np.array([0., 0, 2.5])
camera_pose[:3, :3] = camera_rotation
camera_pose[:3, 3] = camera_rotation @ camera_translation
pyrencamera = pyrender.camera.IntrinsicsCamera(
fx=focal_length, fy=focal_length,
cx=camera_center, cy=camera_center)
scene.add(pyrencamera, pose=camera_pose)
# create and add light
light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=1)
light_pose = np.eye(4)
for lp in [[1, 1, 1], [-1, 1, 1], [1, -1, 1], [-1, -1, 1]]:
light_pose[:3, 3] = mesh.vertices.mean(0) + np.array(lp)
# out_mesh.vertices.mean(0) + np.array(lp)
scene.add(light, pose=light_pose)
# add body mesh
material = pyrender.MetallicRoughnessMaterial(
metallicFactor=0.0,
alphaMode='OPAQUE',
baseColorFactor=(1.0, 1.0, 0.9, 1.0))
mesh_images = []
# resize input image to fit the mesh image height
# print(img.shape)
img_height = img_res
img_width = int(img_height * img.shape[1] / img.shape[0])
img = cv2.resize(img, (img_width, img_height))
mesh_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
for sideview_angle in [0, 90, 180, 270]:
out_mesh = mesh.copy()
rot = trimesh.transformations.rotation_matrix(
np.radians(sideview_angle), [0, 1, 0])
out_mesh.apply_transform(rot)
out_mesh = pyrender.Mesh.from_trimesh(
out_mesh,
material=material)
mesh_pose = np.eye(4)
scene.add(out_mesh, pose=mesh_pose, name='mesh')
output_img = render_image(scene, img_res)
output_img = pil_img.fromarray((output_img * 255).astype(np.uint8))
output_img = np.asarray(output_img)[:, :, :3]
mesh_images.append(output_img)
# delete the previous mesh
prev_mesh = scene.get_nodes(name='mesh').pop()
scene.remove_node(prev_mesh)
# show upside down view
for topview_angle in [90, 270]:
out_mesh = mesh.copy()
rot = trimesh.transformations.rotation_matrix(
np.radians(topview_angle), [1, 0, 0])
out_mesh.apply_transform(rot)
out_mesh = pyrender.Mesh.from_trimesh(
out_mesh,
material=material)
mesh_pose = np.eye(4)
scene.add(out_mesh, pose=mesh_pose, name='mesh')
output_img = render_image(scene, img_res)
output_img = pil_img.fromarray((output_img * 255).astype(np.uint8))
output_img = np.asarray(output_img)[:, :, :3]
mesh_images.append(output_img)
# delete the previous mesh
prev_mesh = scene.get_nodes(name='mesh').pop()
scene.remove_node(prev_mesh)
# stack images
IMG = np.hstack(mesh_images)
IMG = pil_img.fromarray(IMG)
IMG.thumbnail((3000, 3000))
return IMG
# img = cv2.imread('../samples/prox_N3OpenArea_03301_01_s001_frame_00694.jpg')
# mesh = trimesh.load('../samples/mesh.ply', process=False)
# comb_img = create_scene(mesh, img)
# comb_img.save('../samples/combined_image.png')
def unsplit(img, palette):
rgb_img = np.zeros((img.shape[0], img.shape[1], 3))
for i in range(img.shape[0]):
for j in range(img.shape[1]):
id = np.argmax(img[i, j, :])
rgb_img[i, j, :] = palette[id]
return rgb_img
def gen_render(output, normalize=True):
img = output['img'].cpu().numpy()
contact_labels_3d = output['contact_labels_3d_gt'].cpu().numpy()
contact_labels_3d_pred = output['contact_labels_3d_pred'].cpu().numpy()
sem_mask_gt = output['sem_mask_gt'].cpu().numpy()
sem_mask_pred = output['sem_mask_pred'].cpu().numpy()
part_mask_gt = output['part_mask_gt'].cpu().numpy()
part_mask_pred = output['part_mask_pred'].cpu().numpy()
contact_2d_gt_rgb = output['contact_2d_gt'].cpu().numpy()
contact_2d_pred_rgb = output['contact_2d_pred_rgb'].cpu().numpy()
mesh_path = './data/smpl/smpl_neutral_tpose.ply'
gt_mesh = trimesh.load(mesh_path, process=False)
pred_mesh = trimesh.load(mesh_path, process=False)
img = np.transpose(img[0], (1, 2, 0))
if normalize:
# unnormalize the image before displaying
mean = np.array(constants.IMG_NORM_MEAN, dtype=np.float32)
std = np.array(constants.IMG_NORM_STD, dtype=np.float32)
img = img * std + mean
img = img * 255
img = img.astype(np.uint8)
color = np.array([0, 0, 0, 255])
th = 0.5
contact_labels_3d = contact_labels_3d[0, :]
for vid, val in enumerate(contact_labels_3d):
if val >= th:
gt_mesh.visual.vertex_colors[vid] = color
contact_labels_3d_pred = contact_labels_3d_pred[0, :]
for vid, val in enumerate(contact_labels_3d_pred):
if val >= th:
pred_mesh.visual.vertex_colors[vid] = color
gt_rend = create_scene(gt_mesh, img)
pred_rend = create_scene(pred_mesh, img)
sem_palette = [[220, 20, 60], [119, 11, 32], [0, 0, 142], [0, 0, 230], [106, 0, 228], [0, 60, 100], [0, 80, 100], [0, 0, 70], [0, 0, 192], [250, 170, 30], [100, 170, 30], [220, 220, 0], [175, 116, 175], [250, 0, 30], [165, 42, 42], [255, 77, 255], [0, 226, 252], [182, 182, 255], [0, 82, 0], [120, 166, 157], [110, 76, 0], [174, 57, 255], [199, 100, 0], [72, 0, 118], [255, 179, 240], [0, 125, 92], [209, 0, 151], [188, 208, 182], [0, 220, 176], [255, 99, 164], [92, 0, 73], [133, 129, 255], [78, 180, 255], [0, 228, 0], [174, 255, 243], [45, 89, 255], [134, 134, 103], [145, 148, 174], [255, 208, 186], [197, 226, 255], [171, 134, 1], [109, 63, 54], [207, 138, 255], [151, 0, 95], [9, 80, 61], [84, 105, 51], [74, 65, 105], [166, 196, 102], [208, 195, 210], [255, 109, 65], [0, 143, 149], [179, 0, 194], [209, 99, 106], [5, 121, 0], [227, 255, 205], [147, 186, 208], [153, 69, 1], [3, 95, 161], [163, 255, 0], [119, 0, 170], [0, 182, 199], [0, 165, 120], [183, 130, 88], [95, 32, 0], [130, 114, 135], [110, 129, 133], [166, 74, 118], [219, 142, 185], [79, 210, 114], [178, 90, 62], [65, 70, 15], [127, 167, 115], [59, 105, 106], [142, 108, 45], [196, 172, 0], [95, 54, 80], [128, 76, 255], [201, 57, 1], [246, 0, 122], [191, 162, 208], [255, 255, 128], [147, 211, 203], [150, 100, 100], [168, 171, 172], [146, 112, 198], [210, 170, 100], [92, 136, 89], [218, 88, 184], [241, 129, 0], [217, 17, 255], [124, 74, 181], [70, 70, 70], [255, 228, 255], [154, 208, 0], [193, 0, 92], [76, 91, 113], [255, 180, 195], [106, 154, 176], [230, 150, 140], [60, 143, 255], [128, 64, 128], [92, 82, 55], [254, 212, 124], [73, 77, 174], [255, 160, 98], [255, 255, 255], [104, 84, 109], [169, 164, 131], [225, 199, 255], [137, 54, 74], [135, 158, 223], [7, 246, 231], [107, 255, 200], [58, 41, 149], [183, 121, 142], [255, 73, 97], [107, 142, 35], [190, 153, 153], [146, 139, 141], [70, 130, 180], [134, 199, 156], [209, 226, 140], [96, 36, 108], [96, 96, 96], [64, 170, 64], [152, 251, 152], [208, 229, 228], [206, 186, 171], [152, 161, 64], [116, 112, 0], [0, 114, 143], [102, 102, 156], [250, 141, 255]]
# part_palette = [(0,0,0), (128,0,0), (255,0,0), (0,85,0), (170,0,51), (255,85,0), (0,0,85), (0,119,221), (85,85,0), (0,85,85), (85,51,0), (52,86,128), (0,128,0), (0,0,255), (51,170,221), (0,255,255), (85,255,170), (170,255,85), (255,255,0), (255,170,0)]
part_palette = [[0, 0, 0], [220, 20, 60], [119, 11, 32], [0, 0, 142], [0, 0, 230], [106, 0, 228], [0, 60, 100], [0, 80, 100], [0, 0, 70], [0, 0, 192], [250, 170, 30], [100, 170, 30], [220, 220, 0], [175, 116, 175], [250, 0, 30], [165, 42, 42], [255, 77, 255], [0, 226, 252], [182, 182, 255], [0, 82, 0], [120, 166, 157], [110, 76, 0], [174, 57, 255], [199, 100, 0], [72, 0, 118], [255, 179, 240]]
hot_palette = [[0, 0, 0], [220, 20, 60], [119, 11, 32], [0, 0, 142], [0, 0, 230], [106, 0, 228], [0, 60, 100], [0, 80, 100], [0, 0, 70], [0, 0, 192], [250, 170, 30], [100, 170, 30], [220, 220, 0], [175, 116, 175], [250, 0, 30], [165, 42, 42], [255, 77, 255], [0, 226, 252]]
sem_mask_gt = np.transpose(sem_mask_gt[0], (1, 2, 0))*255
sem_mask_gt = sem_mask_gt.astype(np.uint8)
sem_mask_pred = np.transpose(sem_mask_pred[0], (1, 2, 0))*255
sem_mask_pred = sem_mask_pred.astype(np.uint8)
part_mask_gt = np.transpose(part_mask_gt[0], (1, 2, 0))*255
part_mask_gt = part_mask_gt.astype(np.uint8)
part_mask_pred = np.transpose(part_mask_pred[0], (1, 2, 0))*255
part_mask_pred = part_mask_pred.astype(np.uint8)
contact_2d_gt_rgb = contact_2d_gt_rgb[0]*255
contact_2d_gt_rgb = contact_2d_gt_rgb.astype(np.uint8)
contact_2d_pred_rgb = contact_2d_pred_rgb[0]*255
contact_2d_pred_rgb = contact_2d_pred_rgb.astype(np.uint8)
sem_mask_rgb = unsplit(sem_mask_gt, sem_palette)
sem_pred_rgb = unsplit(sem_mask_pred, sem_palette)
part_mask_rgb = unsplit(part_mask_gt, part_palette)
part_pred_rgb = unsplit(part_mask_pred, part_palette)
sem_mask_rgb = sem_mask_rgb.astype(np.uint8)
sem_pred_rgb = sem_pred_rgb.astype(np.uint8)
part_mask_rgb = part_mask_rgb.astype(np.uint8)
part_pred_rgb = part_pred_rgb.astype(np.uint8)
sem_mask_rgb = pil_img.fromarray(sem_mask_rgb)
sem_pred_rgb = pil_img.fromarray(sem_pred_rgb)
part_mask_rgb = pil_img.fromarray(part_mask_rgb)
part_pred_rgb = pil_img.fromarray(part_pred_rgb)
contact_2d_gt_rgb = pil_img.fromarray(contact_2d_gt_rgb)
contact_2d_pred_rgb = pil_img.fromarray(contact_2d_pred_rgb)
tot_rend = pil_img.new('RGB', (3000, 2000))
tot_rend.paste(gt_rend, (0, 0))
tot_rend.paste(pred_rend, (0, 450))
tot_rend.paste(sem_mask_rgb, (0, 900))
tot_rend.paste(sem_pred_rgb, (400, 900))
tot_rend.paste(part_mask_rgb, (0, 1300))
tot_rend.paste(part_pred_rgb, (400, 1300))
tot_rend.paste(contact_2d_gt_rgb, (0, 1700))
tot_rend.paste(contact_2d_pred_rgb, (400, 1700))
return tot_rend