Spaces:
Sleeping
Sleeping
import cv2 | |
import os | |
import trimesh | |
import PIL.Image as pil_img | |
import numpy as np | |
import pyrender | |
from common import constants | |
os.environ['PYOPENGL_PLATFORM'] = 'egl' | |
def render_image(scene, img_res, img=None, viewer=False): | |
''' | |
Render the given pyrender scene and return the image. Can also overlay the mesh on an image. | |
''' | |
if viewer: | |
pyrender.Viewer(scene, use_raymond_lighting=True) | |
return 0 | |
else: | |
r = pyrender.OffscreenRenderer(viewport_width=img_res, | |
viewport_height=img_res, | |
point_size=1.0) | |
color, _ = r.render(scene, flags=pyrender.RenderFlags.RGBA) | |
color = color.astype(np.float32) / 255.0 | |
if img is not None: | |
valid_mask = (color[:, :, -1] > 0)[:, :, np.newaxis] | |
input_img = img.detach().cpu().numpy() | |
output_img = (color[:, :, :-1] * valid_mask + | |
(1 - valid_mask) * input_img) | |
else: | |
output_img = color | |
return output_img | |
def create_scene(mesh, img, focal_length=500, camera_center=250, img_res=500): | |
# Setup the scene | |
scene = pyrender.Scene(bg_color=[1.0, 1.0, 1.0, 1.0], | |
ambient_light=(0.3, 0.3, 0.3)) | |
# add mesh for camera | |
camera_pose = np.eye(4) | |
camera_rotation = np.eye(3, 3) | |
camera_translation = np.array([0., 0, 2.5]) | |
camera_pose[:3, :3] = camera_rotation | |
camera_pose[:3, 3] = camera_rotation @ camera_translation | |
pyrencamera = pyrender.camera.IntrinsicsCamera( | |
fx=focal_length, fy=focal_length, | |
cx=camera_center, cy=camera_center) | |
scene.add(pyrencamera, pose=camera_pose) | |
# create and add light | |
light = pyrender.PointLight(color=[1.0, 1.0, 1.0], intensity=1) | |
light_pose = np.eye(4) | |
for lp in [[1, 1, 1], [-1, 1, 1], [1, -1, 1], [-1, -1, 1]]: | |
light_pose[:3, 3] = mesh.vertices.mean(0) + np.array(lp) | |
# out_mesh.vertices.mean(0) + np.array(lp) | |
scene.add(light, pose=light_pose) | |
# add body mesh | |
material = pyrender.MetallicRoughnessMaterial( | |
metallicFactor=0.0, | |
alphaMode='OPAQUE', | |
baseColorFactor=(1.0, 1.0, 0.9, 1.0)) | |
mesh_images = [] | |
# resize input image to fit the mesh image height | |
# print(img.shape) | |
img_height = img_res | |
img_width = int(img_height * img.shape[1] / img.shape[0]) | |
img = cv2.resize(img, (img_width, img_height)) | |
mesh_images.append(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) | |
for sideview_angle in [0, 90, 180, 270]: | |
out_mesh = mesh.copy() | |
rot = trimesh.transformations.rotation_matrix( | |
np.radians(sideview_angle), [0, 1, 0]) | |
out_mesh.apply_transform(rot) | |
out_mesh = pyrender.Mesh.from_trimesh( | |
out_mesh, | |
material=material) | |
mesh_pose = np.eye(4) | |
scene.add(out_mesh, pose=mesh_pose, name='mesh') | |
output_img = render_image(scene, img_res) | |
output_img = pil_img.fromarray((output_img * 255).astype(np.uint8)) | |
output_img = np.asarray(output_img)[:, :, :3] | |
mesh_images.append(output_img) | |
# delete the previous mesh | |
prev_mesh = scene.get_nodes(name='mesh').pop() | |
scene.remove_node(prev_mesh) | |
# show upside down view | |
for topview_angle in [90, 270]: | |
out_mesh = mesh.copy() | |
rot = trimesh.transformations.rotation_matrix( | |
np.radians(topview_angle), [1, 0, 0]) | |
out_mesh.apply_transform(rot) | |
out_mesh = pyrender.Mesh.from_trimesh( | |
out_mesh, | |
material=material) | |
mesh_pose = np.eye(4) | |
scene.add(out_mesh, pose=mesh_pose, name='mesh') | |
output_img = render_image(scene, img_res) | |
output_img = pil_img.fromarray((output_img * 255).astype(np.uint8)) | |
output_img = np.asarray(output_img)[:, :, :3] | |
mesh_images.append(output_img) | |
# delete the previous mesh | |
prev_mesh = scene.get_nodes(name='mesh').pop() | |
scene.remove_node(prev_mesh) | |
# stack images | |
IMG = np.hstack(mesh_images) | |
IMG = pil_img.fromarray(IMG) | |
IMG.thumbnail((3000, 3000)) | |
return IMG | |
# img = cv2.imread('../samples/prox_N3OpenArea_03301_01_s001_frame_00694.jpg') | |
# mesh = trimesh.load('../samples/mesh.ply', process=False) | |
# comb_img = create_scene(mesh, img) | |
# comb_img.save('../samples/combined_image.png') | |
def unsplit(img, palette): | |
rgb_img = np.zeros((img.shape[0], img.shape[1], 3)) | |
for i in range(img.shape[0]): | |
for j in range(img.shape[1]): | |
id = np.argmax(img[i, j, :]) | |
rgb_img[i, j, :] = palette[id] | |
return rgb_img | |
def gen_render(output, normalize=True): | |
img = output['img'].cpu().numpy() | |
contact_labels_3d = output['contact_labels_3d_gt'].cpu().numpy() | |
contact_labels_3d_pred = output['contact_labels_3d_pred'].cpu().numpy() | |
sem_mask_gt = output['sem_mask_gt'].cpu().numpy() | |
sem_mask_pred = output['sem_mask_pred'].cpu().numpy() | |
part_mask_gt = output['part_mask_gt'].cpu().numpy() | |
part_mask_pred = output['part_mask_pred'].cpu().numpy() | |
contact_2d_gt_rgb = output['contact_2d_gt'].cpu().numpy() | |
contact_2d_pred_rgb = output['contact_2d_pred_rgb'].cpu().numpy() | |
mesh_path = './data/smpl/smpl_neutral_tpose.ply' | |
gt_mesh = trimesh.load(mesh_path, process=False) | |
pred_mesh = trimesh.load(mesh_path, process=False) | |
img = np.transpose(img[0], (1, 2, 0)) | |
if normalize: | |
# unnormalize the image before displaying | |
mean = np.array(constants.IMG_NORM_MEAN, dtype=np.float32) | |
std = np.array(constants.IMG_NORM_STD, dtype=np.float32) | |
img = img * std + mean | |
img = img * 255 | |
img = img.astype(np.uint8) | |
color = np.array([0, 0, 0, 255]) | |
th = 0.5 | |
contact_labels_3d = contact_labels_3d[0, :] | |
for vid, val in enumerate(contact_labels_3d): | |
if val >= th: | |
gt_mesh.visual.vertex_colors[vid] = color | |
contact_labels_3d_pred = contact_labels_3d_pred[0, :] | |
for vid, val in enumerate(contact_labels_3d_pred): | |
if val >= th: | |
pred_mesh.visual.vertex_colors[vid] = color | |
gt_rend = create_scene(gt_mesh, img) | |
pred_rend = create_scene(pred_mesh, img) | |
sem_palette = [[220, 20, 60], [119, 11, 32], [0, 0, 142], [0, 0, 230], [106, 0, 228], [0, 60, 100], [0, 80, 100], [0, 0, 70], [0, 0, 192], [250, 170, 30], [100, 170, 30], [220, 220, 0], [175, 116, 175], [250, 0, 30], [165, 42, 42], [255, 77, 255], [0, 226, 252], [182, 182, 255], [0, 82, 0], [120, 166, 157], [110, 76, 0], [174, 57, 255], [199, 100, 0], [72, 0, 118], [255, 179, 240], [0, 125, 92], [209, 0, 151], [188, 208, 182], [0, 220, 176], [255, 99, 164], [92, 0, 73], [133, 129, 255], [78, 180, 255], [0, 228, 0], [174, 255, 243], [45, 89, 255], [134, 134, 103], [145, 148, 174], [255, 208, 186], [197, 226, 255], [171, 134, 1], [109, 63, 54], [207, 138, 255], [151, 0, 95], [9, 80, 61], [84, 105, 51], [74, 65, 105], [166, 196, 102], [208, 195, 210], [255, 109, 65], [0, 143, 149], [179, 0, 194], [209, 99, 106], [5, 121, 0], [227, 255, 205], [147, 186, 208], [153, 69, 1], [3, 95, 161], [163, 255, 0], [119, 0, 170], [0, 182, 199], [0, 165, 120], [183, 130, 88], [95, 32, 0], [130, 114, 135], [110, 129, 133], [166, 74, 118], [219, 142, 185], [79, 210, 114], [178, 90, 62], [65, 70, 15], [127, 167, 115], [59, 105, 106], [142, 108, 45], [196, 172, 0], [95, 54, 80], [128, 76, 255], [201, 57, 1], [246, 0, 122], [191, 162, 208], [255, 255, 128], [147, 211, 203], [150, 100, 100], [168, 171, 172], [146, 112, 198], [210, 170, 100], [92, 136, 89], [218, 88, 184], [241, 129, 0], [217, 17, 255], [124, 74, 181], [70, 70, 70], [255, 228, 255], [154, 208, 0], [193, 0, 92], [76, 91, 113], [255, 180, 195], [106, 154, 176], [230, 150, 140], [60, 143, 255], [128, 64, 128], [92, 82, 55], [254, 212, 124], [73, 77, 174], [255, 160, 98], [255, 255, 255], [104, 84, 109], [169, 164, 131], [225, 199, 255], [137, 54, 74], [135, 158, 223], [7, 246, 231], [107, 255, 200], [58, 41, 149], [183, 121, 142], [255, 73, 97], [107, 142, 35], [190, 153, 153], [146, 139, 141], [70, 130, 180], [134, 199, 156], [209, 226, 140], [96, 36, 108], [96, 96, 96], [64, 170, 64], [152, 251, 152], [208, 229, 228], [206, 186, 171], [152, 161, 64], [116, 112, 0], [0, 114, 143], [102, 102, 156], [250, 141, 255]] | |
# part_palette = [(0,0,0), (128,0,0), (255,0,0), (0,85,0), (170,0,51), (255,85,0), (0,0,85), (0,119,221), (85,85,0), (0,85,85), (85,51,0), (52,86,128), (0,128,0), (0,0,255), (51,170,221), (0,255,255), (85,255,170), (170,255,85), (255,255,0), (255,170,0)] | |
part_palette = [[0, 0, 0], [220, 20, 60], [119, 11, 32], [0, 0, 142], [0, 0, 230], [106, 0, 228], [0, 60, 100], [0, 80, 100], [0, 0, 70], [0, 0, 192], [250, 170, 30], [100, 170, 30], [220, 220, 0], [175, 116, 175], [250, 0, 30], [165, 42, 42], [255, 77, 255], [0, 226, 252], [182, 182, 255], [0, 82, 0], [120, 166, 157], [110, 76, 0], [174, 57, 255], [199, 100, 0], [72, 0, 118], [255, 179, 240]] | |
hot_palette = [[0, 0, 0], [220, 20, 60], [119, 11, 32], [0, 0, 142], [0, 0, 230], [106, 0, 228], [0, 60, 100], [0, 80, 100], [0, 0, 70], [0, 0, 192], [250, 170, 30], [100, 170, 30], [220, 220, 0], [175, 116, 175], [250, 0, 30], [165, 42, 42], [255, 77, 255], [0, 226, 252]] | |
sem_mask_gt = np.transpose(sem_mask_gt[0], (1, 2, 0))*255 | |
sem_mask_gt = sem_mask_gt.astype(np.uint8) | |
sem_mask_pred = np.transpose(sem_mask_pred[0], (1, 2, 0))*255 | |
sem_mask_pred = sem_mask_pred.astype(np.uint8) | |
part_mask_gt = np.transpose(part_mask_gt[0], (1, 2, 0))*255 | |
part_mask_gt = part_mask_gt.astype(np.uint8) | |
part_mask_pred = np.transpose(part_mask_pred[0], (1, 2, 0))*255 | |
part_mask_pred = part_mask_pred.astype(np.uint8) | |
contact_2d_gt_rgb = contact_2d_gt_rgb[0]*255 | |
contact_2d_gt_rgb = contact_2d_gt_rgb.astype(np.uint8) | |
contact_2d_pred_rgb = contact_2d_pred_rgb[0]*255 | |
contact_2d_pred_rgb = contact_2d_pred_rgb.astype(np.uint8) | |
sem_mask_rgb = unsplit(sem_mask_gt, sem_palette) | |
sem_pred_rgb = unsplit(sem_mask_pred, sem_palette) | |
part_mask_rgb = unsplit(part_mask_gt, part_palette) | |
part_pred_rgb = unsplit(part_mask_pred, part_palette) | |
sem_mask_rgb = sem_mask_rgb.astype(np.uint8) | |
sem_pred_rgb = sem_pred_rgb.astype(np.uint8) | |
part_mask_rgb = part_mask_rgb.astype(np.uint8) | |
part_pred_rgb = part_pred_rgb.astype(np.uint8) | |
sem_mask_rgb = pil_img.fromarray(sem_mask_rgb) | |
sem_pred_rgb = pil_img.fromarray(sem_pred_rgb) | |
part_mask_rgb = pil_img.fromarray(part_mask_rgb) | |
part_pred_rgb = pil_img.fromarray(part_pred_rgb) | |
contact_2d_gt_rgb = pil_img.fromarray(contact_2d_gt_rgb) | |
contact_2d_pred_rgb = pil_img.fromarray(contact_2d_pred_rgb) | |
tot_rend = pil_img.new('RGB', (3000, 2000)) | |
tot_rend.paste(gt_rend, (0, 0)) | |
tot_rend.paste(pred_rend, (0, 450)) | |
tot_rend.paste(sem_mask_rgb, (0, 900)) | |
tot_rend.paste(sem_pred_rgb, (400, 900)) | |
tot_rend.paste(part_mask_rgb, (0, 1300)) | |
tot_rend.paste(part_pred_rgb, (400, 1300)) | |
tot_rend.paste(contact_2d_gt_rgb, (0, 1700)) | |
tot_rend.paste(contact_2d_pred_rgb, (400, 1700)) | |
return tot_rend |