import argparse
import logging
import math
import os
from typing import List, Literal, Tuple
import cv2
import numpy as np
import nvdiffrast.torch as dr
import torch
import trimesh
import utils3d
import xatlas
from tqdm import tqdm
from asset3d_gen.data.mesh_operator import MeshFixer
from asset3d_gen.data.utils import (
CameraSetting,
get_images_from_grid,
init_kal_camera,
normalize_vertices_array,
post_process_texture,
save_mesh_with_mtl,
)
from asset3d_gen.models.delight_model import DelightingModel
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)
class TextureBaker(object):
"""Baking textures onto a mesh from multiple observations.
This class take 3D mesh data, camera settings and texture baking parameters
to generate texture map by projecting images to the mesh from diff views.
It supports both a fast texture baking approach and a more optimized method
with total variation regularization.
Attributes:
vertices (torch.Tensor): The vertices of the mesh.
faces (torch.Tensor): The faces of the mesh, defined by vertex indices.
uvs (torch.Tensor): The UV coordinates of the mesh.
camera_params (CameraSetting): Camera setting (intrinsics, extrinsics).
device (str): The device to run computations on ("cpu" or "cuda").
w2cs (torch.Tensor): World-to-camera transformation matrices.
projections (torch.Tensor): Camera projection matrices.
Example:
        >>> vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces)  # noqa
        >>> texture_baker = TextureBaker(vertices, faces, uvs, camera_params)
>>> images = get_images_from_grid(args.input_image, image_size)
        >>> texture = texture_baker.bake_texture(
... images, texture_size=args.texture_size, mode=args.baker_mode
... )
>>> texture = post_process_texture(texture)
"""
def __init__(
self,
vertices: np.ndarray,
faces: np.ndarray,
uvs: np.ndarray,
camera_params: CameraSetting,
device: str = "cuda",
) -> None:
self.vertices = (
torch.tensor(vertices, device=device)
if isinstance(vertices, np.ndarray)
else vertices.to(device)
)
self.faces = (
torch.tensor(faces.astype(np.int32), device=device)
if isinstance(faces, np.ndarray)
else faces.to(device)
)
self.uvs = (
torch.tensor(uvs, device=device)
if isinstance(uvs, np.ndarray)
else uvs.to(device)
)
self.camera_params = camera_params
self.device = device
camera = init_kal_camera(camera_params)
        matrix_mv = camera.view_matrix()  # (n_cam, 4, 4), world-to-camera
        matrix_mv = kaolin_to_opencv_view(matrix_mv)
        matrix_p = (
            camera.intrinsics.projection_matrix()
        )  # (n_cam, 4, 4), camera-to-pixel
self.w2cs = matrix_mv.to(self.device)
self.projections = matrix_p.to(self.device)
@staticmethod
    def parametrize_mesh(
        vertices: np.ndarray, faces: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
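        """Generate UV coordinates for a mesh via xatlas unwrapping.

        Note that xatlas duplicates vertices along UV seams, so the returned
        vertices and faces may differ from the inputs.
        """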
vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
vertices = vertices[vmapping]
faces = indices
return vertices, faces, uvs
def _bake_fast(self, observations, w2cs, projections, texture_size, masks):
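        """Bake a texture by splatting observations into UV space.

        Each view is rasterized once to obtain per-pixel UVs; observed
        colors are scatter-added into the texture with nearest-neighbor
        indexing, averaged by per-texel hit counts, and remaining holes
        are inpainted.
        """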
        texture = torch.zeros(
            (texture_size * texture_size, 3),
            dtype=torch.float32,
            device=self.device,
        )
        texture_weights = torch.zeros(
            (texture_size * texture_size),
            dtype=torch.float32,
            device=self.device,
        )
rastctx = utils3d.torch.RastContext(backend="cuda")
        for observation, w2c, projection, obs_mask in tqdm(
            zip(observations, w2cs, projections, masks),
            total=len(observations),
            desc="Texture baking (fast)",
        ):
with torch.no_grad():
rast = utils3d.torch.rasterize_triangle_faces(
rastctx,
self.vertices[None],
self.faces,
observation.shape[1],
observation.shape[0],
uv=self.uvs[None],
view=w2c,
projection=projection,
)
uv_map = rast["uv"][0].detach().flip(0)
mask = rast["mask"][0].detach().bool() & masks[0]
# nearest neighbor interpolation
                uv_map = (uv_map * texture_size).floor().long()
                uv_map = uv_map.clamp(0, texture_size - 1)  # guard uv == 1.0
obs = observation[mask]
uv_map = uv_map[mask]
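                # Flatten (u, v) into a 1D texel index; v is flipped because
                # row 0 of the texture image is the top.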
idx = (
uv_map[:, 0] + (texture_size - uv_map[:, 1] - 1) * texture_size
)
texture = texture.scatter_add(
0, idx.view(-1, 1).expand(-1, 3), obs
)
texture_weights = texture_weights.scatter_add(
0,
idx,
torch.ones(
(obs.shape[0]), dtype=torch.float32, device=texture.device
),
)
mask = texture_weights > 0
texture[mask] /= texture_weights[mask][:, None]
texture = np.clip(
texture.reshape(texture_size, texture_size, 3).cpu().numpy() * 255,
0,
255,
).astype(np.uint8)
# inpaint
mask = (
(texture_weights == 0)
.cpu()
.numpy()
.astype(np.uint8)
.reshape(texture_size, texture_size)
)
texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
return texture
def _bake_opt(
self,
observations,
w2cs,
projections,
texture_size,
lambda_tv,
masks,
total_steps,
):
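        """Bake a texture by optimizing it against all observations.

        Per-view UV maps are precomputed once; the texture is then fitted
        with Adam under an L1 reconstruction loss plus optional total
        variation regularization, using a cosine-annealed learning rate.
        """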
rastctx = utils3d.torch.RastContext(backend="cuda")
        observations = [obs.flip(0) for obs in observations]
        masks = [m.flip(0) for m in masks]
_uv = []
_uv_dr = []
for observation, w2c, projection in tqdm(
zip(observations, w2cs, projections),
total=len(w2cs),
):
with torch.no_grad():
rast = utils3d.torch.rasterize_triangle_faces(
rastctx,
self.vertices[None],
self.faces,
observation.shape[1],
observation.shape[0],
uv=self.uvs[None],
view=w2c,
projection=projection,
)
_uv.append(rast["uv"].detach())
_uv_dr.append(rast["uv_dr"].detach())
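        # The texture image itself is the only learnable parameter; it is
        # sampled through the precomputed UV maps with nvdiffrast's
        # differentiable texture lookup and fitted to the observations.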
        texture = torch.nn.Parameter(
            torch.zeros(
                (1, texture_size, texture_size, 3),
                dtype=torch.float32,
                device=self.device,
            )
        )
optimizer = torch.optim.Adam([texture], betas=(0.5, 0.9), lr=1e-2)
        def cosine_annealing(step, total_steps, start_lr, end_lr):
            return end_lr + 0.5 * (start_lr - end_lr) * (
                1 + np.cos(np.pi * step / total_steps)
            )
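        # Anisotropic total variation: L1 differences between horizontally
        # and vertically adjacent texels, which suppresses noise in texels
        # that are rarely or never observed.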
def tv_loss(texture):
return torch.nn.functional.l1_loss(
texture[:, :-1, :, :], texture[:, 1:, :, :]
) + torch.nn.functional.l1_loss(
texture[:, :, :-1, :], texture[:, :, 1:, :]
)
with tqdm(total=total_steps, desc="Texture baking") as pbar:
for step in range(total_steps):
optimizer.zero_grad()
selected = np.random.randint(0, len(w2cs))
uv, uv_dr, observation, mask = (
_uv[selected],
_uv_dr[selected],
observations[selected],
masks[selected],
)
render = dr.texture(texture, uv, uv_dr)[0]
loss = torch.nn.functional.l1_loss(
render[mask], observation[mask]
)
if lambda_tv > 0:
loss += lambda_tv * tv_loss(texture)
loss.backward()
optimizer.step()
                optimizer.param_groups[0]["lr"] = cosine_annealing(
                    step, total_steps, 1e-2, 1e-5
                )
pbar.set_postfix({"loss": loss.item()})
pbar.update()
texture = np.clip(
texture[0].flip(0).detach().cpu().numpy() * 255, 0, 255
).astype(np.uint8)
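        # Rasterize the mesh directly in UV space to find texels covered by
        # any triangle; everything outside that coverage is inpainted.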
mask = 1 - utils3d.torch.rasterize_triangle_faces(
rastctx,
(self.uvs * 2 - 1)[None],
self.faces,
texture_size,
texture_size,
)["mask"][0].detach().cpu().numpy().astype(np.uint8)
texture = cv2.inpaint(texture, mask, 3, cv2.INPAINT_TELEA)
return texture
def bake_texture(
self,
        images: List[np.ndarray],
texture_size: int = 1024,
mode: Literal["fast", "opt"] = "opt",
lambda_tv: float = 1e-2,
opt_step: int = 2000,
):
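        """Bake a texture map from multi-view observations.

        Args:
            images: Per-view RGB images (uint8, 0-255) aligned with the
                cameras in `camera_params`.
            texture_size: Output texture resolution in texels.
            mode: "fast" for weighted projection, "opt" for optimization
                with total variation regularization.
            lambda_tv: Weight of the TV regularizer (used in "opt" mode).
            opt_step: Number of optimization steps (used in "opt" mode).

        Returns:
            The baked texture as a (texture_size, texture_size, 3) uint8 array.
        """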
masks = [np.any(img > 0, axis=-1) for img in images]
masks = [torch.tensor(m > 0).bool().to(self.device) for m in masks]
images = [
torch.tensor(obs / 255.0).float().to(self.device) for obs in images
]
if mode == "fast":
return self._bake_fast(
images, self.w2cs, self.projections, texture_size, masks
)
elif mode == "opt":
return self._bake_opt(
images,
self.w2cs,
self.projections,
texture_size,
lambda_tv,
masks,
opt_step,
)
else:
raise ValueError(f"Unknown mode: {mode}")
def kaolin_to_opencv_view(raw_matrix):
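    """Convert Kaolin world-to-camera matrices to the OpenCV convention.

    The rotation columns are permuted so the camera x/y/z axes are taken
    from Kaolin's z/x/y axes; the translation is left unchanged.
    """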
R_orig = raw_matrix[:, :3, :3]
t_orig = raw_matrix[:, :3, 3]
R_target = torch.zeros_like(R_orig)
R_target[:, :, 0] = R_orig[:, :, 2]
R_target[:, :, 1] = R_orig[:, :, 0]
R_target[:, :, 2] = R_orig[:, :, 1]
t_target = t_orig
target_matrix = (
torch.eye(4, device=raw_matrix.device)
.unsqueeze(0)
.repeat(raw_matrix.size(0), 1, 1)
)
target_matrix[:, :3, :3] = R_target
target_matrix[:, :3, 3] = t_target
return target_matrix
def parse_args():
parser = argparse.ArgumentParser(description="Render settings")
parser.add_argument(
"--mesh_path",
type=str,
nargs="+",
required=True,
help="Paths to the mesh files for rendering.",
)
    parser.add_argument(
        "--input_image",
        type=str,
        nargs="+",
        required=True,
        help="Paths to the input images used for texture baking.",
    )
parser.add_argument(
"--output_root",
type=str,
default="./outputs",
help="Root directory for output",
)
    parser.add_argument(
        "--uuid",
        type=str,
        nargs="+",
        default=None,
        help="UUIDs used to name the saved outputs.",
    )
parser.add_argument(
"--num_images", type=int, default=6, help="Number of images to render."
)
parser.add_argument(
"--elevation",
type=float,
nargs="+",
default=[20.0, -10.0],
help="Elevation angles for the camera (default: [20.0, -10.0])",
)
parser.add_argument(
"--distance",
type=float,
default=5,
help="Camera distance (default: 5)",
)
parser.add_argument(
"--resolution_hw",
type=int,
nargs=2,
default=(512, 512),
help="Resolution of the output images (default: (512, 512))",
)
parser.add_argument(
"--fov",
type=float,
default=30,
help="Field of view in degrees (default: 30)",
)
parser.add_argument(
"--device",
type=str,
choices=["cpu", "cuda"],
default="cuda",
help="Device to run on (default: `cuda`)",
)
parser.add_argument(
"--texture_size",
type=int,
default=1024,
help="Texture size for texture baking (default: 1024)",
)
    parser.add_argument(
        "--baker_mode",
        type=str,
        choices=["fast", "opt"],
        default="opt",
        help="Texture baking mode, `fast` or `opt` (default: opt)",
    )
parser.add_argument(
"--opt_step",
type=int,
default=2500,
help="Optimization steps for texture baking (default: 2500)",
)
    parser.add_argument(
        "--mesh_simplify_ratio",
        type=float,
        default=0.9,
        help="Mesh simplification ratio (default: 0.9)",
    )
parser.add_argument(
"--no_coor_trans",
action="store_true",
help="Do not transform the asset coordinate system.",
)
parser.add_argument(
"--delight", action="store_true", help="Use delighting model."
)
    parser.add_argument(
        "--skip_fix_mesh",
        action="store_true",
        help="Skip mesh geometry fixing.",
    )
args = parser.parse_args()
if args.uuid is None:
args.uuid = []
for path in args.mesh_path:
            uuid = os.path.splitext(os.path.basename(path))[0]
args.uuid.append(uuid)
return args
def entrypoint() -> None:
args = parse_args()
camera_params = CameraSetting(
num_images=args.num_images,
elevation=args.elevation,
distance=args.distance,
resolution_hw=args.resolution_hw,
fov=math.radians(args.fov),
device=args.device,
)
for mesh_path, uuid, img_path in zip(
args.mesh_path, args.uuid, args.input_image
):
mesh = trimesh.load(mesh_path)
if isinstance(mesh, trimesh.Scene):
mesh = mesh.dump(concatenate=True)
vertices, scale, center = normalize_vertices_array(mesh.vertices)
if not args.no_coor_trans:
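            # Rotate the asset into the renderer's coordinate frame; the
            # rotations act on row vectors (`v @ R`) and are inverted again
            # before saving.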
x_rot = torch.Tensor([[1, 0, 0], [0, 0, 1], [0, -1, 0]])
z_rot = torch.Tensor([[0, 1, 0], [-1, 0, 0], [0, 0, 1]])
vertices = vertices @ x_rot
vertices = vertices @ z_rot
        faces = mesh.faces.astype(np.int32)  # trimesh faces are np.ndarray
vertices = vertices.cpu().numpy().astype(np.float32)
if not args.skip_fix_mesh:
mesh_fixer = MeshFixer(vertices, faces, args.device)
vertices, faces = mesh_fixer(
                filter_ratio=args.mesh_simplify_ratio,
max_hole_size=0.04,
resolution=1024,
num_views=1000,
norm_mesh_ratio=0.5,
)
vertices, faces, uvs = TextureBaker.parametrize_mesh(vertices, faces)
        texture_baker = TextureBaker(
vertices,
faces,
uvs,
camera_params,
)
images = get_images_from_grid(
img_path, img_size=camera_params.resolution_hw[0]
)
if args.delight:
delight_model = DelightingModel(
model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/hunyuan3d-delight-v2-0" # noqa
)
delight_images = [delight_model(img) for img in images]
images = [np.array(img) for img in delight_images]
        texture = texture_baker.bake_texture(
images=[img[..., :3] for img in images],
texture_size=args.texture_size,
mode=args.baker_mode,
opt_step=args.opt_step,
)
texture = post_process_texture(texture)
        if not args.no_coor_trans:
            vertices = vertices @ np.linalg.inv(z_rot.numpy())
            vertices = vertices @ np.linalg.inv(x_rot.numpy())
        vertices = vertices / scale
        vertices = vertices + center
output_path = os.path.join(args.output_root, f"{uuid}.obj")
mesh = save_mesh_with_mtl(vertices, faces, uvs, texture, output_path)
return
if __name__ == "__main__":
entrypoint()