import argparse
import json
import logging
import math
import os
from collections import defaultdict
from typing import List, Union

import cv2
import imageio
import numpy as np
import nvdiffrast.torch as dr
import torch
from PIL import Image
from tqdm import tqdm

from asset3d_gen.data.utils import (
    CameraSetting,
    DiffrastRender,
    RenderItems,
    as_list,
    calc_vertex_normals,
    import_kaolin_mesh,
    init_kal_camera,
    normalize_vertices_array,
    render_pbr,
    save_images,
)
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" | |
os.environ["TORCH_EXTENSIONS_DIR"] = os.path.expanduser( | |
"~/.cache/torch_extensions" | |
) | |
logging.basicConfig( | |
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO | |
) | |
logger = logging.getLogger(__name__) | |


def create_gif_from_images(images, output_path, fps=10):
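    """Save a sequence of float RGBA images in [0, 1] as an animated GIF."""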
    pil_images = []
    for image in images:
        image = image.clip(min=0, max=1)
        image = (255.0 * image).astype(np.uint8)
        image = Image.fromarray(image, mode="RGBA")
        pil_images.append(image.convert("RGB"))

    duration = 1000 // fps
    pil_images[0].save(
        output_path,
        save_all=True,
        append_images=pil_images[1:],
        duration=duration,
        loop=0,
    )
    logger.info(f"GIF saved to {output_path}")


def create_mp4_from_images(images, output_path, fps=10, prompt=None):
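    """Save a sequence of float RGBA images in [0, 1] as an MP4 video,
    optionally drawing a text prompt onto each frame.
    """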
    font = cv2.FONT_HERSHEY_SIMPLEX  # Font face.
    font_scale = 0.5  # Font size.
    font_thickness = 1  # Stroke thickness.
    color = (255, 255, 255)  # Text color (white).
    position = (20, 25)  # Top-left corner (x, y).
    with imageio.get_writer(output_path, fps=fps) as writer:
        for image in images:
            image = image.clip(min=0, max=1)
            image = (255.0 * image).astype(np.uint8)
            # Drop the alpha channel; copy into a contiguous array so that
            # `cv2.putText` can draw on it in place.
            image = np.ascontiguousarray(image[..., :3])
            if prompt is not None:
                cv2.putText(
                    image,
                    prompt,
                    position,
                    font,
                    font_scale,
                    color,
                    font_thickness,
                )
            writer.append_data(image)
    logger.info(f"MP4 video saved to {output_path}")


class ImageRender(object):
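    """Multi-view mesh renderer built on nvdiffrast.

    Renders alpha masks, normal maps, position maps, depth maps, and PBR
    color images for a mesh under a fixed multi-view camera setup, and can
    optionally export the color/normal sequences as GIF or MP4 files.
    """
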
    def __init__(
        self,
        render_items: list[RenderItems],
        camera_params: CameraSetting,
        recompute_vtx_normal: bool = True,
        device: str = "cuda",
        with_mtl: bool = False,
        gen_color_gif: bool = False,
        gen_color_mp4: bool = False,
        gen_viewnormal_mp4: bool = False,
        gen_glonormal_mp4: bool = False,
        no_index_file: bool = False,
        light_factor: float = 1.0,
    ) -> None:
        camera_params.device = device
        camera = init_kal_camera(camera_params)
        self.camera = camera

        # Setup the MVP matrices and renderer.
        mv = camera.view_matrix()  # (n, 4, 4), world-to-camera.
        p = camera.intrinsics.projection_matrix()
        # NOTE: negate P[1, 1] because the y axis is flipped in the
        # `nvdiffrast` output.
        p[:, 1, 1] = -p[:, 1, 1]
        # mvp = torch.bmm(p, mv)  # camera.view_projection_matrix()
        self.mv = mv
        self.p = p
        renderer = DiffrastRender(
            p_matrix=p,
            mv_matrix=mv,
            resolution_hw=camera_params.resolution_hw,
            context=dr.RasterizeCudaContext(),
            mask_thresh=0.5,
            grad_db=False,
            device=camera_params.device,
            antialias_mask=True,
        )
        self.renderer = renderer
        self.recompute_vtx_normal = recompute_vtx_normal
        self.render_items = render_items
        self.device = device
        self.with_mtl = with_mtl
        self.gen_color_gif = gen_color_gif
        self.gen_color_mp4 = gen_color_mp4
        self.gen_viewnormal_mp4 = gen_viewnormal_mp4
        self.gen_glonormal_mp4 = gen_glonormal_mp4
        self.light_factor = light_factor
        self.no_index_file = no_index_file

    def render_mesh(
        self,
        mesh_path: Union[str, List[str]],
        output_root: str,
        uuid: Union[str, List[str]] = None,
        prompts: List[str] = None,
    ) -> None:
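        """Render one or more meshes into ``output_root/<uuid>`` and write
        an ``index.json`` of per-mesh results unless ``no_index_file`` is set.
        """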
        mesh_path = as_list(mesh_path)
        if uuid is None:
            uuid = [os.path.basename(p).split(".")[0] for p in mesh_path]
        uuid = as_list(uuid)
        assert len(mesh_path) == len(uuid)
        os.makedirs(output_root, exist_ok=True)

        meta_info = dict()
        for idx, (path, uid) in tqdm(
            enumerate(zip(mesh_path, uuid)), total=len(mesh_path)
        ):
            output_dir = os.path.join(output_root, uid)
            os.makedirs(output_dir, exist_ok=True)
            prompt = prompts[idx] if prompts else None
            data_dict = self(path, output_dir, prompt)
            meta_info[uid] = data_dict

        if self.no_index_file:
            return

        index_file = os.path.join(output_root, "index.json")
        with open(index_file, "w") as fout:
            json.dump(meta_info, fout)
        logger.info(f"Rendering meta info logged in {index_file}")

    def __call__(
        self, mesh_path: str, output_dir: str, prompt: str = None
    ) -> dict[str, str]:
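        """Render a single mesh into ``output_dir`` and return a dict mapping
        each rendered item to its saved file paths, or ``None`` if mesh
        loading or PBR rendering fails.
        """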
        try:
            mesh = import_kaolin_mesh(mesh_path, self.with_mtl)
        except Exception as e:
            logger.error(f"[ERROR MESH LOAD]: {e}, skip {mesh_path}")
            return

        mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
        if self.recompute_vtx_normal:
            mesh.vertex_normals = calc_vertex_normals(
                mesh.vertices, mesh.faces
            )

        mesh = mesh.to(self.device)
        vertices, faces, vertex_normals = (
            mesh.vertices,
            mesh.faces,
            mesh.vertex_normals,
        )

        # Perform rendering.
        data_dict = defaultdict(list)
        if RenderItems.ALPHA.value in self.render_items:
            masks, _ = self.renderer.render_rast_alpha(vertices, faces)
            render_paths = save_images(
                masks, f"{output_dir}/{RenderItems.ALPHA}"
            )
            data_dict[RenderItems.ALPHA.value] = render_paths

        if RenderItems.GLOBAL_NORMAL.value in self.render_items:
            rendered_normals, masks = self.renderer.render_global_normal(
                vertices, faces, vertex_normals
            )
            if self.gen_glonormal_mp4:
                if isinstance(rendered_normals, torch.Tensor):
                    rendered_normals = rendered_normals.detach().cpu().numpy()
                create_mp4_from_images(
                    rendered_normals,
                    output_path=f"{output_dir}/normal.mp4",
                    fps=15,
                    prompt=prompt,
                )
            else:
                render_paths = save_images(
                    rendered_normals,
                    f"{output_dir}/{RenderItems.GLOBAL_NORMAL}",
                    cvt_color=cv2.COLOR_BGR2RGB,
                )
                data_dict[RenderItems.GLOBAL_NORMAL.value] = render_paths
        if RenderItems.VIEW_NORMAL.value in self.render_items:
            assert (
                RenderItems.GLOBAL_NORMAL.value in self.render_items
            ), f"Must render global normal first, got render_items: {self.render_items}."  # noqa
            rendered_view_normals = self.renderer.transform_normal(
                rendered_normals, self.mv, masks, to_view=True
            )
            # rendered_inv_view_normals = renderer.transform_normal(rendered_view_normals, torch.linalg.inv(mv), masks, to_view=False)  # noqa

            if self.gen_viewnormal_mp4:
                if isinstance(rendered_view_normals, torch.Tensor):
                    rendered_view_normals = (
                        rendered_view_normals.detach().cpu().numpy()
                    )
                create_mp4_from_images(
                    rendered_view_normals,
                    output_path=f"{output_dir}/view_normal.mp4",
                    fps=15,
                    prompt=prompt,
                )
            else:
                render_paths = save_images(
                    rendered_view_normals,
                    f"{output_dir}/{RenderItems.VIEW_NORMAL}",
                    cvt_color=cv2.COLOR_BGR2RGB,
                )
                data_dict[RenderItems.VIEW_NORMAL.value] = render_paths
        if RenderItems.POSITION_MAP.value in self.render_items:
            rendered_position, masks = self.renderer.render_position(
                vertices, faces
            )
            norm_position = self.renderer.normalize_map_by_mask(
                rendered_position, masks
            )
            render_paths = save_images(
                norm_position,
                f"{output_dir}/{RenderItems.POSITION_MAP}",
                cvt_color=cv2.COLOR_BGR2RGB,
            )
            data_dict[RenderItems.POSITION_MAP.value] = render_paths

        if RenderItems.DEPTH.value in self.render_items:
            rendered_depth, masks = self.renderer.render_depth(vertices, faces)
            norm_depth = self.renderer.normalize_map_by_mask(
                rendered_depth, masks
            )
            render_paths = save_images(
                norm_depth,
                f"{output_dir}/{RenderItems.DEPTH}",
            )
            data_dict[RenderItems.DEPTH.value] = render_paths

            render_paths = save_images(
                rendered_depth,
                f"{output_dir}/{RenderItems.DEPTH}_exr",
                to_uint8=False,
                format=".exr",
            )
            data_dict[f"{RenderItems.DEPTH.value}_exr"] = render_paths
        if RenderItems.IMAGE.value in self.render_items:
            images = []
            albedos = []
            diffuses = []
            masks, _ = self.renderer.render_rast_alpha(vertices, faces)
            try:
                for idx, cam in enumerate(self.camera):
                    image, albedo, diffuse, _ = render_pbr(
                        mesh, cam, light_factor=self.light_factor
                    )
                    image = torch.cat([image[0], masks[idx]], dim=-1)
                    images.append(image.detach().cpu().numpy())

                    if RenderItems.ALBEDO.value in self.render_items:
                        albedo = torch.cat([albedo[0], masks[idx]], dim=-1)
                        albedos.append(albedo.detach().cpu().numpy())

                    if RenderItems.DIFFUSE.value in self.render_items:
                        diffuse = torch.cat([diffuse[0], masks[idx]], dim=-1)
                        diffuses.append(diffuse.detach().cpu().numpy())
            except Exception as e:
                logger.error(f"[ERROR pbr render]: {e}, skip {mesh_path}")
                return
            if self.gen_color_gif:
                create_gif_from_images(
                    images,
                    output_path=f"{output_dir}/color.gif",
                    fps=15,
                )

            if self.gen_color_mp4:
                create_mp4_from_images(
                    images,
                    output_path=f"{output_dir}/color.mp4",
                    fps=15,
                    prompt=prompt,
                )

            if self.gen_color_mp4 or self.gen_color_gif:
                return data_dict

            render_paths = save_images(
                images,
                f"{output_dir}/{RenderItems.IMAGE}",
                cvt_color=cv2.COLOR_BGRA2RGBA,
            )
            data_dict[RenderItems.IMAGE.value] = render_paths
            if RenderItems.ALBEDO.value in self.render_items:
                render_paths = save_images(
                    albedos,
                    f"{output_dir}/{RenderItems.ALBEDO}",
                    cvt_color=cv2.COLOR_BGRA2RGBA,
                )
                data_dict[RenderItems.ALBEDO.value] = render_paths

            if RenderItems.DIFFUSE.value in self.render_items:
                render_paths = save_images(
                    diffuses,
                    f"{output_dir}/{RenderItems.DIFFUSE}",
                    cvt_color=cv2.COLOR_BGRA2RGBA,
                )
                data_dict[RenderItems.DIFFUSE.value] = render_paths
data_dict["status"] = "success" | |
logger.info(f"Finish rendering in {output_dir}") | |
return data_dict | |


def parse_args():
    parser = argparse.ArgumentParser(description="Render settings")
    parser.add_argument(
        "--mesh_path",
        type=str,
        nargs="+",
        help="Paths to the mesh files for rendering.",
    )
    parser.add_argument(
        "--output_root",
        type=str,
        help="Root directory for output.",
    )
    parser.add_argument(
        "--uuid",
        type=str,
        nargs="+",
        default=None,
        help="UUIDs under which to save the renderings.",
    )
    parser.add_argument(
        "--num_images", type=int, default=6, help="Number of images to render."
    )
    parser.add_argument(
        "--elevation",
        type=float,
        nargs="+",
        default=[20.0, -10.0],
        help="Elevation angles for the camera (default: [20.0, -10.0]).",
    )
    parser.add_argument(
        "--distance",
        type=float,
        default=5,
        help="Camera distance (default: 5).",
    )
    parser.add_argument(
        "--resolution_hw",
        type=int,
        nargs=2,
        default=(512, 512),
        help="Resolution of the output images (default: (512, 512)).",
    )
    parser.add_argument(
        "--fov",
        type=float,
        default=30,
        help="Field of view in degrees (default: 30).",
    )
    parser.add_argument(
        "--pbr_light_factor",
        type=float,
        default=1.0,
        help="Light factor for mesh PBR rendering (default: 1.0).",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda",
        help="Device to run on (default: 'cuda').",
    )
    parser.add_argument(
        "--with_mtl",
        action="store_true",
        help="Whether to render with the mesh material.",
    )
    parser.add_argument(
        "--gen_color_gif",
        action="store_true",
        help="Whether to generate a color .gif rendering file.",
    )
    parser.add_argument(
        "--gen_color_mp4",
        action="store_true",
        help="Whether to generate a color .mp4 rendering file.",
    )
    parser.add_argument(
        "--gen_viewnormal_mp4",
        action="store_true",
        help="Whether to generate a view-space normal .mp4 rendering file.",
    )
    parser.add_argument(
        "--gen_glonormal_mp4",
        action="store_true",
        help="Whether to generate a global (world-space) normal .mp4 rendering file.",
    )
    parser.add_argument(
        "--prompts",
        type=str,
        nargs="+",
        default=None,
        help="Text prompts for the rendering.",
    )
    args = parser.parse_args()

    if args.uuid is None and args.mesh_path is not None:
        args.uuid = []
        for path in args.mesh_path:
            uuid = os.path.basename(path).split(".")[0]
            args.uuid.append(uuid)

    return args


def entrypoint(**kwargs) -> None:
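    """CLI entrypoint: parse args (optionally overridden via ``kwargs``),
    build the camera settings and render items, and run the renderer.
    """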
    args = parse_args()
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)

    camera_settings = CameraSetting(
        num_images=args.num_images,
        elevation=args.elevation,
        distance=args.distance,
        resolution_hw=args.resolution_hw,
        fov=math.radians(args.fov),
        device=args.device,
    )

    render_items = [
        RenderItems.ALPHA.value,
        RenderItems.GLOBAL_NORMAL.value,
        RenderItems.VIEW_NORMAL.value,
        RenderItems.POSITION_MAP.value,
        RenderItems.IMAGE.value,
        RenderItems.DEPTH.value,
        # RenderItems.ALBEDO.value,
        # RenderItems.DIFFUSE.value,
    ]

    gen_video = (
        args.gen_color_gif
        or args.gen_color_mp4
        or args.gen_viewnormal_mp4
        or args.gen_glonormal_mp4
    )
    if gen_video:
        # In video mode, render only the items the requested videos need.
        render_items = []
        if args.gen_color_gif or args.gen_color_mp4:
            render_items.append(RenderItems.IMAGE.value)
        if args.gen_glonormal_mp4:
            render_items.append(RenderItems.GLOBAL_NORMAL.value)
        if args.gen_viewnormal_mp4:
            render_items.append(RenderItems.VIEW_NORMAL.value)
            # View-space normals are derived from global normals.
            if RenderItems.GLOBAL_NORMAL.value not in render_items:
                render_items.append(RenderItems.GLOBAL_NORMAL.value)

    image_render = ImageRender(
        render_items=render_items,
        camera_params=camera_settings,
        with_mtl=args.with_mtl,
        gen_color_gif=args.gen_color_gif,
        gen_color_mp4=args.gen_color_mp4,
        gen_viewnormal_mp4=args.gen_viewnormal_mp4,
        gen_glonormal_mp4=args.gen_glonormal_mp4,
        light_factor=args.pbr_light_factor,
        no_index_file=gen_video,
    )
    image_render.render_mesh(
        mesh_path=args.mesh_path,
        output_root=args.output_root,
        uuid=args.uuid,
        prompts=args.prompts,
    )

    return


if __name__ == "__main__":
    entrypoint()
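
# Example invocation (a sketch; the script name and asset paths below are
# hypothetical). Note that `--gen_color_mp4` switches to video mode, which
# skips writing per-view images and index.json:
#
#   python render_mesh.py \
#       --mesh_path assets/chair.glb \
#       --output_root outputs/renders \
#       --num_images 6 \
#       --fov 30 \
#       --gen_color_mp4 \
#       --prompts "a wooden chair"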