import argparse
import logging
import math
import os

import cv2
import numpy as np
import nvdiffrast.torch as dr
import spaces
import torch
import torch.nn.functional as F
import trimesh
import xatlas
from PIL import Image

from asset3d_gen.data.mesh_operator import MeshFixer
from asset3d_gen.data.utils import (
    CameraSetting,
    DiffrastRender,
    get_images_from_grid,
    init_kal_camera,
    normalize_vertices_array,
    post_process_texture,
    save_mesh_with_mtl,
)
from asset3d_gen.models.delight_model import DelightingModel
from asset3d_gen.models.sr_model import ImageRealESRGAN

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)


__all__ = [
    "TextureBacker",
]


def transform_vertices(
    mtx: torch.Tensor, pos: torch.Tensor, keepdim: bool = False
) -> torch.Tensor:
    """Transform 3D vertices using a projection matrix."""
    t_mtx = torch.as_tensor(mtx, device=pos.device, dtype=pos.dtype)
    if pos.size(-1) == 3:
        # Promote to homogeneous coordinates before applying the 4x4 matrix.
        pos = torch.cat([pos, torch.ones_like(pos[..., :1])], dim=-1)

    result = pos @ t_mtx.T
    return result if keepdim else result.unsqueeze(0)


def _bilinear_interpolation_scattering(
    image_h: int, image_w: int, coords: torch.Tensor, values: torch.Tensor
) -> torch.Tensor:
    """Bilinear interpolation scattering for grid-based value accumulation."""
    device = values.device
    dtype = values.dtype
    C = values.shape[-1]

    indices = coords * torch.tensor(
        [image_h - 1, image_w - 1], dtype=dtype, device=device
    )
    i, j = indices.unbind(-1)

    # Clamp each axis against its own bound so the +1 neighbors stay in range.
    i0 = i.floor().long().clamp(0, image_h - 2)
    j0 = j.floor().long().clamp(0, image_w - 2)
    i1, j1 = i0 + 1, j0 + 1

    w_i = i - i0.float()
    w_j = j - j0.float()
    weights = torch.stack(
        [(1 - w_i) * (1 - w_j), (1 - w_i) * w_j, w_i * (1 - w_j), w_i * w_j],
        dim=1,
    )

    indices_comb = torch.stack(
        [
            torch.stack([i0, j0], dim=1),
            torch.stack([i0, j1], dim=1),
            torch.stack([i1, j0], dim=1),
            torch.stack([i1, j1], dim=1),
        ],
        dim=1,
    )

    grid = torch.zeros(image_h, image_w, C, device=device, dtype=dtype)
    cnt = torch.zeros(image_h, image_w, 1, device=device, dtype=dtype)

    for k in range(4):
        idx = indices_comb[:, k]
        w = weights[:, k].unsqueeze(-1)

        stride = torch.tensor([image_w, 1], device=device, dtype=torch.long)
        flat_idx = (idx * stride).sum(-1)

        grid.view(-1, C).scatter_add_(
            0, flat_idx.unsqueeze(-1).expand(-1, C), values * w
        )
        cnt.view(-1, 1).scatter_add_(0, flat_idx.unsqueeze(-1), w)

    # Normalize accumulated values by the accumulated weights.
    mask = cnt.squeeze(-1) > 0
    grid[mask] = grid[mask] / cnt[mask].repeat(1, C)

    return grid
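
# Illustrative-only sketch (not called anywhere in the pipeline): it shows
# the (row, col)-in-[0, 1] coordinate convention the scattering helper
# expects, which is also what `TextureBacker._scatter_texture` passes in.
def _example_scatter() -> torch.Tensor:
    """Splat three samples into a 4x4 single-channel grid (runs on CPU)."""
    coords = torch.tensor([[0.0, 0.0], [0.5, 0.5], [1.0, 1.0]])
    values = torch.tensor([[1.0], [0.5], [0.25]])
    # The corner samples land on single texels; the middle sample spreads
    # bilinear weights over the four texels around grid position (1.5, 1.5).
    return _bilinear_interpolation_scattering(4, 4, coords, values)
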
def _texture_inpaint_smooth(
    texture: np.ndarray,
    mask: np.ndarray,
    vertices: np.ndarray,
    faces: np.ndarray,
    uv_map: np.ndarray,
) -> tuple[np.ndarray, np.ndarray]:
    """Perform texture inpainting using vertex-based color propagation."""
    image_h, image_w, C = texture.shape
    N = vertices.shape[0]

    # Initialize vertex data structures.
    vtx_mask = np.zeros(N, dtype=np.float32)
    vtx_colors = np.zeros((N, C), dtype=np.float32)
    unprocessed = []
    adjacency = [[] for _ in range(N)]

    # Build adjacency graph and initial color assignment.
    for face_idx in range(faces.shape[0]):
        for k in range(3):
            # Vertex and UV indices coincide after xatlas re-parametrization.
            uv_idx_k = faces[face_idx, k]
            v_idx = faces[face_idx, k]

            # Convert UV to pixel coordinates with boundary clamping.
            u = np.clip(
                int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
            )
            v = np.clip(
                int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
                0,
                image_h - 1,
            )

            if mask[v, u]:
                vtx_mask[v_idx] = 1.0
                vtx_colors[v_idx] = texture[v, u]
            elif v_idx not in unprocessed:
                unprocessed.append(v_idx)

            # Build undirected adjacency graph.
            neighbor = faces[face_idx, (k + 1) % 3]
            if neighbor not in adjacency[v_idx]:
                adjacency[v_idx].append(neighbor)
            if v_idx not in adjacency[neighbor]:
                adjacency[neighbor].append(v_idx)

    # Color propagation with dynamic stopping.
    remaining_iters, prev_count = 2, 0
    while remaining_iters > 0:
        current_unprocessed = []

        for v_idx in unprocessed:
            valid_neighbors = [n for n in adjacency[v_idx] if vtx_mask[n] > 0]
            if not valid_neighbors:
                current_unprocessed.append(v_idx)
                continue

            # Calculate inverse square distance weights.
            neighbors_pos = vertices[valid_neighbors]
            dist_sq = np.sum((vertices[v_idx] - neighbors_pos) ** 2, axis=1)
            weights = 1 / np.maximum(dist_sq, 1e-8)

            vtx_colors[v_idx] = np.average(
                vtx_colors[valid_neighbors], weights=weights, axis=0
            )
            vtx_mask[v_idx] = 1.0

        # Update iteration control: stop after the frontier stalls twice.
        if len(current_unprocessed) == prev_count:
            remaining_iters -= 1
        else:
            remaining_iters = min(remaining_iters + 1, 2)
        prev_count = len(current_unprocessed)
        unprocessed = current_unprocessed

    # Generate output texture.
    inpainted_texture, updated_mask = texture.copy(), mask.copy()
    for face_idx in range(faces.shape[0]):
        for k in range(3):
            v_idx = faces[face_idx, k]
            if not vtx_mask[v_idx]:
                continue

            # UV coordinate conversion.
            uv_idx_k = faces[face_idx, k]
            u = np.clip(
                int(round(uv_map[uv_idx_k, 0] * (image_w - 1))), 0, image_w - 1
            )
            v = np.clip(
                int(round((1.0 - uv_map[uv_idx_k, 1]) * (image_h - 1))),
                0,
                image_h - 1,
            )

            inpainted_texture[v, u] = vtx_colors[v_idx]
            updated_mask[v, u] = 255

    return inpainted_texture, updated_mask
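
# Minimal sketch of the propagation above (illustrative data only): a single
# triangle whose third corner lies outside the valid mask and receives a
# distance-weighted blend of its two painted neighbors.
def _example_inpaint() -> tuple[np.ndarray, np.ndarray]:
    texture = np.zeros((8, 8, 3), dtype=np.float32)
    texture[7, 0] = (1.0, 0.0, 0.0)  # UV (0, 0): painted red
    texture[7, 7] = (0.0, 0.0, 1.0)  # UV (1, 0): painted blue
    mask = np.zeros((8, 8), dtype=np.uint8)
    mask[7, 0] = mask[7, 7] = 255
    vertices = np.array(
        [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.5, 1.0, 0.0]], dtype=np.float32
    )
    faces = np.array([[0, 1, 2]], dtype=np.int64)
    uv_map = np.array([[0.0, 0.0], [1.0, 0.0], [0.5, 1.0]], dtype=np.float32)
    # Vertex 2 is unpainted and equidistant from both neighbors, so its
    # texel becomes the average of red and blue: (0.5, 0, 0.5).
    return _texture_inpaint_smooth(texture, mask, vertices, faces, uv_map)
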
class TextureBacker:
    """Texture baking pipeline for multi-view projection and fusion."""

    def __init__(
        self,
        camera_params: CameraSetting,
        view_weights: list[float],
        render_wh: tuple[int, int] = (2048, 2048),
        texture_wh: tuple[int, int] = (2048, 2048),
        bake_angle_thresh: int = 75,
        mask_thresh: float = 0.5,
    ):
        self.camera_params = camera_params
        self.renderer = None
        self.view_weights = view_weights
        self.device = camera_params.device
        self.render_wh = render_wh
        self.texture_wh = texture_wh

        self.mask_thresh = mask_thresh
        self.bake_angle_thresh = bake_angle_thresh
        # Kernel size scales with render resolution (2 px at 512).
        self.bake_unreliable_kernel_size = int(
            (2 / 512) * max(self.render_wh[0], self.render_wh[1])
        )

    def _lazy_init_render(self, camera_params, mask_thresh):
        if self.renderer is None:
            camera = init_kal_camera(camera_params)
            mv = camera.view_matrix()  # (n 4 4) world2cam
            p = camera.intrinsics.projection_matrix()
            # NOTE: negate P[1, 1] because the y axis is flipped in
            # `nvdiffrast` output.
            p[:, 1, 1] = -p[:, 1, 1]
            self.renderer = DiffrastRender(
                p_matrix=p,
                mv_matrix=mv,
                resolution_hw=camera_params.resolution_hw,
                context=dr.RasterizeCudaContext(),
                mask_thresh=mask_thresh,
                grad_db=False,
                device=self.device,
                antialias_mask=True,
            )

    def load_mesh(self, mesh: trimesh.Trimesh) -> trimesh.Trimesh:
        mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
        self.scale, self.center = scale, center

        # Re-parametrize UVs with xatlas and remap vertices accordingly.
        vmapping, indices, uvs = xatlas.parametrize(mesh.vertices, mesh.faces)
        uvs[:, 1] = 1 - uvs[:, 1]
        mesh.vertices = mesh.vertices[vmapping]
        mesh.faces = indices
        mesh.visual.uv = uvs

        return mesh

    def get_mesh_np_attrs(
        self,
        mesh: trimesh.Trimesh,
        scale: float = None,
        center: np.ndarray = None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        vertices = mesh.vertices.copy()
        faces = mesh.faces.copy()
        uv_map = mesh.visual.uv.copy()
        uv_map[:, 1] = 1.0 - uv_map[:, 1]

        if scale is not None:
            vertices = vertices / scale
        if center is not None:
            vertices = vertices + center

        return vertices, faces, uv_map

    def _render_depth_edges(self, depth_image: torch.Tensor) -> torch.Tensor:
        depth_image_np = depth_image.cpu().numpy()
        depth_image_np = (depth_image_np * 255).astype(np.uint8)
        depth_edges = cv2.Canny(depth_image_np, 30, 80)
        sketch_image = (
            torch.from_numpy(depth_edges).to(depth_image.device).float() / 255
        )
        sketch_image = sketch_image.unsqueeze(-1)

        return sketch_image

    def compute_enhanced_viewnormal(
        self, mv_mtx: torch.Tensor, vertices: torch.Tensor, faces: torch.Tensor
    ) -> torch.Tensor:
        rast, _ = self.renderer.compute_dr_raster(vertices, faces)
        rendered_view_normals = []
        for idx in range(len(mv_mtx)):
            pos_cam = transform_vertices(mv_mtx[idx], vertices, keepdim=True)
            pos_cam = pos_cam[:, :3] / pos_cam[:, 3:]
            v0, v1, v2 = (pos_cam[faces[:, i]] for i in range(3))
            face_norm = F.normalize(
                torch.cross(v1 - v0, v2 - v0, dim=-1), dim=-1
            )
            vertex_norm = (
                torch.from_numpy(
                    trimesh.geometry.mean_vertex_normals(
                        len(pos_cam), faces.cpu(), face_norm.cpu()
                    )
                )
                .to(vertices.device)
                .contiguous()
            )
            im_base_normals, _ = dr.interpolate(
                vertex_norm[None, ...].float(),
                rast[idx : idx + 1],
                faces.to(torch.int32),
            )
            rendered_view_normals.append(im_base_normals)

        rendered_view_normals = torch.cat(rendered_view_normals, dim=0)

        return rendered_view_normals

    def back_project(
        self, image, vis_mask, depth, normal, uv
    ) -> tuple[torch.Tensor, torch.Tensor]:
        image = np.array(image)
        image = torch.as_tensor(image, device=self.device, dtype=torch.float32)
        if image.ndim == 2:
            image = image.unsqueeze(-1)
        image = image / 255

        depth_inv = (1.0 - depth) * vis_mask
        sketch_image = self._render_depth_edges(depth_inv)

        # Confidence is the cosine between the view direction and the normal;
        # grazing angles beyond the bake angle threshold are zeroed out.
        cos = F.cosine_similarity(
            torch.tensor([[0.0, 0.0, 1.0]], device=self.device),
            normal.view(-1, 3),
        ).view_as(normal[..., :1])
        cos[cos < np.cos(np.radians(self.bake_angle_thresh))] = 0

        k = self.bake_unreliable_kernel_size * 2 + 1
        kernel = torch.ones((1, 1, k, k), device=self.device)

        # Erode the visibility mask to drop unreliable border pixels.
        vis_mask = vis_mask.permute(2, 0, 1).unsqueeze(0).float()
        vis_mask = F.conv2d(
            1.0 - vis_mask,
            kernel,
            padding=k // 2,
        )
        vis_mask = 1.0 - (vis_mask > 0).float()
        vis_mask = vis_mask.squeeze(0).permute(1, 2, 0)

        # Dilate depth edges: pixels near depth discontinuities are unreliable.
        sketch_image = sketch_image.permute(2, 0, 1).unsqueeze(0)
        sketch_image = F.conv2d(sketch_image, kernel, padding=k // 2)
        sketch_image = (sketch_image > 0).float()
        sketch_image = sketch_image.squeeze(0).permute(1, 2, 0)
        vis_mask = vis_mask * (sketch_image < 0.5)

        cos[vis_mask == 0] = 0
        valid_pixels = (vis_mask != 0).view(-1)

        return (
            self._scatter_texture(uv, image, valid_pixels),
            self._scatter_texture(uv, cos, valid_pixels),
        )

    def _scatter_texture(self, uv, data, mask):
        def __filter_data(data, mask):
            return data.view(-1, data.shape[-1])[mask]

        # UVs are (u, v); the scattering helper expects (row, col).
        return _bilinear_interpolation_scattering(
            self.texture_wh[1],
            self.texture_wh[0],
            __filter_data(uv, mask)[..., [1, 0]],
            __filter_data(data, mask),
        )

    @torch.no_grad()
    def fast_bake_texture(
        self, textures: list[torch.Tensor], confidence_maps: list[torch.Tensor]
    ) -> tuple[torch.Tensor, torch.Tensor]:
        channel = textures[0].shape[-1]
        # Allocate (H, W, C) buffers to match `_scatter_texture` output.
        texture_merge = torch.zeros(
            [self.texture_wh[1], self.texture_wh[0], channel]
        ).to(self.device)
        trust_map_merge = torch.zeros(
            [self.texture_wh[1], self.texture_wh[0], 1]
        ).to(self.device)
        for texture, cos_map in zip(textures, confidence_maps):
            view_sum = (cos_map > 0).sum()
            painted_sum = ((cos_map > 0) * (trust_map_merge > 0)).sum()
            if painted_sum / view_sum > 0.99:
                # Skip views that contribute almost no new coverage.
                continue
            texture_merge += texture * cos_map
            trust_map_merge += cos_map
        texture_merge = texture_merge / torch.clamp(trust_map_merge, min=1e-8)

        return texture_merge, trust_map_merge > 1e-8

    def uv_inpaint(
        self, mesh: trimesh.Trimesh, texture: np.ndarray, mask: np.ndarray
    ) -> np.ndarray:
        vertices, faces, uv_map = self.get_mesh_np_attrs(mesh)

        texture, mask = _texture_inpaint_smooth(
            texture, mask, vertices, faces, uv_map
        )
        texture = texture.clip(0, 1)
        texture = cv2.inpaint(
            (texture * 255).astype(np.uint8),
            255 - mask,
            3,
            cv2.INPAINT_NS,
        )

        return texture

    @spaces.GPU
    def compute_texture(
        self,
        colors: list[Image.Image],
        mesh: trimesh.Trimesh,
    ) -> tuple[np.ndarray, np.ndarray]:
        self._lazy_init_render(self.camera_params, self.mask_thresh)

        vertices = torch.from_numpy(mesh.vertices).to(self.device).float()
        faces = torch.from_numpy(mesh.faces).to(self.device).to(torch.int)
        uv_map = torch.from_numpy(mesh.visual.uv).to(self.device).float()

        rendered_depth, masks = self.renderer.render_depth(vertices, faces)
        norm_deps = self.renderer.normalize_map_by_mask(rendered_depth, masks)
        render_uvs, _ = self.renderer.render_uv(vertices, faces, uv_map)
        view_normals = self.compute_enhanced_viewnormal(
            self.renderer.mv_mtx, vertices, faces
        )

        textures, weighted_cos_maps = [], []
        for color, mask, dep, normal, uv, weight in zip(
            colors,
            masks,
            norm_deps,
            view_normals,
            render_uvs,
            self.view_weights,
        ):
            texture, cos_map = self.back_project(color, mask, dep, normal, uv)
            textures.append(texture)
            weighted_cos_maps.append(weight * (cos_map**4))

        texture, mask = self.fast_bake_texture(textures, weighted_cos_maps)

        texture_np = texture.cpu().numpy()
        mask_np = (mask.squeeze(-1).cpu().numpy() * 255).astype(np.uint8)

        return texture_np, mask_np

    def __call__(
        self,
        colors: list[Image.Image],
        mesh: trimesh.Trimesh,
        output_path: str,
    ) -> trimesh.Trimesh:
        mesh = self.load_mesh(mesh)
        texture_np, mask_np = self.compute_texture(colors, mesh)
        texture_np = self.uv_inpaint(mesh, texture_np, mask_np)
        texture_np = post_process_texture(texture_np)

        vertices, faces, uv_map = self.get_mesh_np_attrs(
            mesh, self.scale, self.center
        )
        textured_mesh = save_mesh_with_mtl(
            vertices, faces, uv_map, texture_np, output_path
        )

        return textured_mesh
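
# Minimal end-to-end usage sketch for `TextureBacker`, assuming six view
# images, a CUDA device, and the same camera setup as `entrypoint` below.
def _example_bake(
    views: list[Image.Image], mesh: trimesh.Trimesh, out_path: str
) -> trimesh.Trimesh:
    camera_params = CameraSetting(
        num_images=6,
        elevation=[20.0, -10.0],
        distance=5,
        resolution_hw=(2048, 2048),
        fov=math.radians(30),
        device="cuda",
    )
    backer = TextureBacker(
        camera_params=camera_params,
        view_weights=[1, 0.1, 0.02, 0.1, 1, 0.02],  # one weight per view
        render_wh=camera_params.resolution_hw,
        texture_wh=[2048, 2048],
    )
    return backer(views, mesh, out_path)
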
def parse_args():
    parser = argparse.ArgumentParser(description="Backproject texture")
    parser.add_argument(
        "--color_path",
        type=str,
        help="Path to the multiview color image grid (6 views of 512x512)",
    )
    parser.add_argument(
        "--mesh_path",
        type=str,
        help="Mesh path, .obj, .glb or .ply",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        help="Output mesh path, including file suffix",
    )
    parser.add_argument(
        "--num_images", type=int, default=6, help="Number of images to render."
    )
    parser.add_argument(
        "--elevation",
        nargs=2,
        type=float,
        default=[20.0, -10.0],
        help="Elevation angles for the camera (default: [20.0, -10.0])",
    )
    parser.add_argument(
        "--distance",
        type=float,
        default=5,
        help="Camera distance (default: 5)",
    )
    parser.add_argument(
        "--resolution_hw",
        type=int,
        nargs=2,
        default=(2048, 2048),
        help="Resolution of the output images (default: (2048, 2048))",
    )
    parser.add_argument(
        "--fov",
        type=float,
        default=30,
        help="Field of view in degrees (default: 30)",
    )
    parser.add_argument(
        "--device",
        type=str,
        choices=["cpu", "cuda"],
        default="cuda",
        help="Device to run on (default: `cuda`)",
    )
    parser.add_argument(
        "--skip_fix_mesh",
        action="store_true",
        help="Skip mesh geometry fixing.",
    )
    parser.add_argument(
        "--texture_wh",
        nargs=2,
        type=int,
        default=[2048, 2048],
        help="Texture resolution width and height",
    )
    parser.add_argument(
        "--mesh_simplify_ratio",
        type=float,
        default=0.9,
        help="Mesh simplification ratio (default: 0.9)",
    )
    parser.add_argument(
        "--delight", action="store_true", help="Use delighting model."
    )
    args = parser.parse_args()

    return args
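
# Example invocation (script and file names are placeholders):
#   python backproject_texture.py \
#       --color_path multiviews_6x512x512.png \
#       --mesh_path mesh.obj \
#       --output_path outputs/mesh_textured.obj \
#       --delight
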
def entrypoint(
    delight_model: DelightingModel = None,
    imagesr_model: ImageRealESRGAN = None,
    **kwargs,
) -> trimesh.Trimesh:
    args = parse_args()
    for k, v in kwargs.items():
        if hasattr(args, k) and v is not None:
            setattr(args, k, v)

    # Set up camera parameters.
    camera_params = CameraSetting(
        num_images=args.num_images,
        elevation=args.elevation,
        distance=args.distance,
        resolution_hw=args.resolution_hw,
        fov=math.radians(args.fov),
        device=args.device,
    )
    view_weights = [1, 0.1, 0.02, 0.1, 1, 0.02]

    color_grid = Image.open(args.color_path)
    if args.delight:
        if delight_model is None:
            delight_model = DelightingModel(
                model_path="/horizon-bucket/robot_lab/users/xinjie.wang/weights/hunyuan3d-delight-v2-0"  # noqa
            )
        save_dir = os.path.dirname(args.output_path)
        os.makedirs(save_dir, exist_ok=True)
        color_grid = delight_model(color_grid)
        color_grid.save(f"{save_dir}/color_grid_delight.png")

    multiviews = get_images_from_grid(color_grid, img_size=512)

    # Use RealESRGAN_x4plus for x4 (512 -> 2048) image super resolution.
    if imagesr_model is None:
        imagesr_model = ImageRealESRGAN(outscale=4)
    multiviews = [imagesr_model(img) for img in multiviews]
    multiviews = [img.convert("RGB") for img in multiviews]

    mesh = trimesh.load(args.mesh_path)
    if isinstance(mesh, trimesh.Scene):
        mesh = mesh.dump(concatenate=True)

    if not args.skip_fix_mesh:
        mesh.vertices, scale, center = normalize_vertices_array(mesh.vertices)
        mesh_fixer = MeshFixer(mesh.vertices, mesh.faces, args.device)
        mesh.vertices, mesh.faces = mesh_fixer(
            filter_ratio=args.mesh_simplify_ratio,
            max_hole_size=0.04,
            resolution=1024,
            num_views=1000,
            norm_mesh_ratio=0.5,
        )
        # Restore the original scale and position.
        mesh.vertices = mesh.vertices / scale
        mesh.vertices = mesh.vertices + center

    # Bake the multiview textures onto the mesh.
    texture_backer = TextureBacker(
        camera_params=camera_params,
        view_weights=view_weights,
        render_wh=camera_params.resolution_hw,
        texture_wh=args.texture_wh,
    )
    textured_mesh = texture_backer(multiviews, mesh, args.output_path)

    return textured_mesh


if __name__ == "__main__":
    entrypoint()
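
# `entrypoint` can also be invoked programmatically; keyword arguments
# override the parsed CLI defaults, and preloaded models avoid reloading
# weights across calls. Paths below are placeholders.
#
#   sr_model = ImageRealESRGAN(outscale=4)
#   textured = entrypoint(
#       imagesr_model=sr_model,
#       color_path="multiviews_6x512x512.png",
#       mesh_path="mesh.obj",
#       output_path="outputs/mesh_textured.obj",
#   )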