# Duplicated by nasttam from cownclown/Image-and-3D-Model-Creator
# (commit d444fe9, 3.36 kB).
import io
import os
import torch
from skimage.io import imread
import numpy as np
import cv2
from tqdm import tqdm_notebook as tqdm
import base64
from IPython.display import HTML
# Util function for loading meshes
from pytorch3d.io import load_objs_as_meshes
from IPython.display import HTML
from base64 import b64encode
# Data structures and functions for rendering
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
look_at_view_transform,
OpenGLOrthographicCameras,
PointLights,
DirectionalLights,
Materials,
RasterizationSettings,
MeshRenderer,
MeshRasterizer,
SoftPhongShader,
HardPhongShader,
TexturesVertex
)
def set_renderer(device="cuda:0", image_size=512):
    """Build a hard-Phong mesh renderer with an orthographic camera.

    Args:
        device: Torch device (or device string) for the camera, lights and
            shader. Defaults to ``"cuda:0"``; when it is a CUDA device it is
            also made the current CUDA device.
        image_size: Side length, in pixels, of the square rendered images.

    Returns:
        A ``MeshRenderer`` whose default camera is built from
        ``look_at_view_transform(2.0, 0, 180)`` (distance 2.0, elevation 0,
        azimuth 180). Callers can override the pose per call by passing
        ``R=``/``T=`` to the renderer.
    """
    device = torch.device(device)
    # torch.cuda.set_device only accepts CUDA devices.
    if device.type == "cuda":
        torch.cuda.set_device(device)

    # Orthographic (not perspective) camera using PyTorch3D's OpenGL convention.
    R, T = look_at_view_transform(2.0, 0, 180)
    cameras = OpenGLOrthographicCameras(device=device, R=R, T=T)

    raster_settings = RasterizationSettings(
        image_size=image_size,
        blur_radius=0.0,
        faces_per_pixel=1,
        # Leave binning parameters to PyTorch3D's defaults.
        bin_size=None,
        max_faces_per_bin=None,
    )

    # Single point light placed at (2, 2, 2).
    lights = PointLights(device=device, location=((2.0, 2.0, 2.0),))

    return MeshRenderer(
        rasterizer=MeshRasterizer(
            cameras=cameras,
            raster_settings=raster_settings,
        ),
        shader=HardPhongShader(
            device=device,
            cameras=cameras,
            lights=lights,
        ),
    )
def get_verts_rgb_colors(obj_path):
    """Read per-vertex RGB colours from an OBJ file.

    Only vertex lines of the form ``v x y z r g b`` contribute a colour: exactly
    seven whitespace-separated tokens, the first being ``v``. Every other line
    (normals, texture coordinates, faces, comments, uncoloured vertices) is
    ignored.

    Args:
        obj_path: Path to the ``.obj`` file.

    Returns:
        float32 array of shape ``(1, N, 3)`` holding the ``r g b`` values of
        the N coloured vertex lines in file order. ``N`` may be 0.
    """
    rgb_colors = []
    with open(obj_path, encoding="utf-8", errors="replace") as f:
        for line in f:
            # split() (rather than split(' ')) tolerates repeated spaces, tabs
            # and the trailing newline. The 'v' check stops other 7-token
            # lines (e.g. a 6-vertex 'f' face) from being read as colours.
            tokens = line.split()
            if len(tokens) == 7 and tokens[0] == "v":
                rgb_colors.append(tokens[-3:])
    # reshape keeps the array 2-D even when no colours were found, so adding
    # the leading batch axis never fails.
    return np.array(rgb_colors, dtype="float32").reshape(-1, 3)[None, :, :]
def _render_bgr(renderer, mesh, R, T):
    """Render ``mesh`` at pose (R, T) and return the first image as HxWx3 uint8 BGR."""
    images = renderer(mesh, R=R, T=T)
    rgb = np.clip(images[0, ..., :3].cpu().numpy(), 0.0, 1.0)
    # OpenCV expects BGR channel order.
    return (rgb[:, :, ::-1] * 255).astype('uint8')


def generate_video_from_obj(obj_path, video_path, renderer, device="cuda:0"):
    """Write a 90-frame turntable MP4 of an OBJ mesh.

    Each frame shows two renders side by side: on the left the mesh coloured
    with the per-vertex colours from ``get_verts_rgb_colors``, on the right the
    same geometry in uniform grey (0.75). The camera orbits at distance 1.8 and
    elevation 0, stepping the azimuth by 4 degrees per frame (0..356). Frames
    are written at 20 fps.

    Args:
        obj_path: Path to the ``.obj`` file. Assumes every vertex line carries
            an ``r g b`` colour, so the colour count matches the vertex count.
            TODO confirm for all inputs.
        video_path: Output video path.
        renderer: Mesh renderer (e.g. from ``set_renderer``) that accepts
            ``R=``/``T=`` overrides.
        device: Torch device for the mesh and textures. Defaults to
            ``"cuda:0"``. For a CUDA device it also becomes the current CUDA
            device.
    """
    device = torch.device(device)
    if device.type == "cuda":
        torch.cuda.set_device(device)

    # Per-vertex colours, plus a flat grey variant to show the bare geometry.
    verts_rgb_colors = torch.from_numpy(get_verts_rgb_colors(obj_path)).to(device)
    textures = TexturesVertex(verts_features=verts_rgb_colors)
    wo_textures = TexturesVertex(verts_features=torch.ones_like(verts_rgb_colors) * 0.75)

    # Load the geometry and pair it with each texture.
    mesh = load_objs_as_meshes([obj_path], device=device)
    verts = mesh.verts_list()
    faces = mesh.faces_list()
    mesh_w_tex = Meshes(verts, faces, textures)
    mesh_wo_tex = Meshes(verts, faces, wo_textures)

    fourcc = cv2.VideoWriter_fourcc(*'MP4V')
    out = None
    try:
        with torch.no_grad():
            for i in tqdm(range(90)):
                R, T = look_at_view_transform(1.8, 0, i * 4, device=device)
                frame = np.concatenate(
                    [
                        _render_bgr(renderer, mesh_w_tex, R, T),
                        _render_bgr(renderer, mesh_wo_tex, R, T),
                    ],
                    axis=1,
                )
                if out is None:
                    # Size the writer from the real frame. VideoWriter silently
                    # drops frames whose size differs from frameSize.
                    height, width = frame.shape[:2]
                    out = cv2.VideoWriter(video_path, fourcc, 20.0, (width, height))
                out.write(frame)
    finally:
        # Release even if rendering fails so the container is finalised.
        if out is not None:
            out.release()
def video(path):
    """Embed an MP4 file as an inline HTML5 ``<video>`` element.

    The file is base64-encoded into a ``data:`` URL, so the returned HTML does
    not depend on the file remaining on disk.

    Args:
        path: Path to an MP4 file.

    Returns:
        ``HTML`` object containing a looping video player of width 500 with
        controls.
    """
    # Use a context manager so the file handle is closed promptly.
    with open(path, 'rb') as f:
        mp4 = f.read()
    data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    return HTML(f'<video width=500 controls loop> <source src="{data_url}" type="video/mp4"></video>')