ImgRoboAssetGen / asset3d_gen /data /mesh_operator.py
xinjie.wang
update
146eff7
import logging
from typing import Tuple, Union
import spaces
import igraph
import numpy as np
import pyvista as pv
import torch
import utils3d
from pymeshfix import _meshfix
from tqdm import tqdm
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(message)s", level=logging.INFO
)
logger = logging.getLogger(__name__)
__all__ = ["MeshFixer"]
def radical_inverse(base, n):
val = 0
inv_base = 1.0 / base
inv_base_n = inv_base
while n > 0:
digit = n % base
val += digit * inv_base_n
n //= base
inv_base_n *= inv_base
return val
def halton_sequence(dim, n):
PRIMES = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 53]
return [radical_inverse(PRIMES[dim], n) for dim in range(dim)]
def hammersley_sequence(dim, n, num_samples):
return [n / num_samples] + halton_sequence(dim - 1, n)
def sphere_hammersley_sequence(n, num_samples, offset=(0, 0), remap=False):
"""Generate a point on a unit sphere using the Hammersley sequence.
Args:
n (int): The index of the sample.
num_samples (int): The total number of samples.
offset (tuple, optional): Offset for the u and v coordinates.
remap (bool, optional): Whether to remap the u coordinate.
Returns:
list: A list containing the spherical coordinates [phi, theta].
"""
u, v = hammersley_sequence(2, n, num_samples)
u += offset[0] / num_samples
v += offset[1]
if remap:
u = 2 * u if u < 0.25 else 2 / 3 * u + 1 / 3
theta = np.arccos(1 - 2 * u) - np.pi / 2
phi = v * 2 * np.pi
return [phi, theta]
class MeshFixer(object):
"""Reduce and postprocess 3D meshes, simplifying and filling holes."""
def __init__(
self,
vertices: Union[torch.Tensor, np.ndarray],
faces: Union[torch.Tensor, np.ndarray],
device: str = "cuda",
) -> None:
self.device = device
self.vertices = (
torch.tensor(vertices, device=device)
if isinstance(vertices, np.ndarray)
else vertices.to(device)
)
self.faces = (
torch.tensor(faces.astype(np.int32), device=device)
if isinstance(faces, np.ndarray)
else faces.to(device)
)
@staticmethod
def log_mesh_changes(method):
def wrapper(self, *args, **kwargs):
logger.info(
f"Before {method.__name__}: {self.vertices.shape[0]} vertices, {self.faces.shape[0]} faces" # noqa
)
result = method(self, *args, **kwargs)
logger.info(
f"After {method.__name__}: {self.vertices.shape[0]} vertices, {self.faces.shape[0]} faces" # noqa
)
return result
return wrapper
@log_mesh_changes
def fill_holes(
self,
max_hole_size: float,
max_hole_nbe: int,
resolution: int,
num_views: int,
norm_mesh_ratio: float = 1.0,
) -> None:
self.vertices = self.vertices * norm_mesh_ratio
vertices, self.faces = self._fill_holes(
self.vertices,
self.faces,
max_hole_size,
max_hole_nbe,
resolution,
num_views,
)
self.vertices = vertices / norm_mesh_ratio
@staticmethod
@torch.no_grad()
def _fill_holes(
vertices: torch.Tensor,
faces: torch.Tensor,
max_hole_size: float,
max_hole_nbe: int,
resolution: int,
num_views: int,
) -> Union[torch.Tensor, torch.Tensor]:
yaws, pitchs = [], []
for i in range(num_views):
y, p = sphere_hammersley_sequence(i, num_views)
yaws.append(y)
pitchs.append(p)
yaws, pitchs = torch.tensor(yaws).to(vertices), torch.tensor(
pitchs
).to(vertices)
radius, fov = 2.0, torch.deg2rad(torch.tensor(40)).to(vertices)
projection = utils3d.torch.perspective_from_fov_xy(fov, fov, 1, 3)
views = []
for yaw, pitch in zip(yaws, pitchs):
orig = (
torch.tensor(
[
torch.sin(yaw) * torch.cos(pitch),
torch.cos(yaw) * torch.cos(pitch),
torch.sin(pitch),
]
).to(vertices)
* radius
)
view = utils3d.torch.view_look_at(
orig,
torch.tensor([0, 0, 0]).to(vertices),
torch.tensor([0, 0, 1]).to(vertices),
)
views.append(view)
views = torch.stack(views, dim=0)
# Rasterize the mesh
visibility = torch.zeros(
faces.shape[0], dtype=torch.int32, device=faces.device
)
rastctx = utils3d.torch.RastContext(backend="cuda")
for i in tqdm(
range(views.shape[0]), total=views.shape[0], desc="Rasterizing"
):
view = views[i]
buffers = utils3d.torch.rasterize_triangle_faces(
rastctx,
vertices[None],
faces,
resolution,
resolution,
view=view,
projection=projection,
)
face_id = buffers["face_id"][0][buffers["mask"][0] > 0.95] - 1
face_id = torch.unique(face_id).long()
visibility[face_id] += 1
# Normalize visibility by the number of views
visibility = visibility.float() / num_views
# Mincut: Identify outer and inner faces
edges, face2edge, edge_degrees = utils3d.torch.compute_edges(faces)
boundary_edge_indices = torch.nonzero(edge_degrees == 1).reshape(-1)
connected_components = utils3d.torch.compute_connected_components(
faces, edges, face2edge
)
outer_face_indices = torch.zeros(
faces.shape[0], dtype=torch.bool, device=faces.device
)
for i in range(len(connected_components)):
outer_face_indices[connected_components[i]] = visibility[
connected_components[i]
] > min(
max(
visibility[connected_components[i]].quantile(0.75).item(),
0.25,
),
0.5,
)
outer_face_indices = outer_face_indices.nonzero().reshape(-1)
inner_face_indices = torch.nonzero(visibility == 0).reshape(-1)
if inner_face_indices.shape[0] == 0:
return vertices, faces
# Construct dual graph (faces as nodes, edges as edges)
dual_edges, dual_edge2edge = utils3d.torch.compute_dual_graph(
face2edge
)
dual_edge2edge = edges[dual_edge2edge]
dual_edges_weights = torch.norm(
vertices[dual_edge2edge[:, 0]] - vertices[dual_edge2edge[:, 1]],
dim=1,
)
# Mincut: Construct main graph and solve the mincut problem
g = igraph.Graph()
g.add_vertices(faces.shape[0])
g.add_edges(dual_edges.cpu().numpy())
g.es["weight"] = dual_edges_weights.cpu().numpy()
g.add_vertex("s") # source
g.add_vertex("t") # target
g.add_edges(
[(f, "s") for f in inner_face_indices],
attributes={
"weight": torch.ones(
inner_face_indices.shape[0], dtype=torch.float32
)
.cpu()
.numpy()
},
)
g.add_edges(
[(f, "t") for f in outer_face_indices],
attributes={
"weight": torch.ones(
outer_face_indices.shape[0], dtype=torch.float32
)
.cpu()
.numpy()
},
)
cut = g.mincut("s", "t", (np.array(g.es["weight"]) * 1000).tolist())
remove_face_indices = torch.tensor(
[v for v in cut.partition[0] if v < faces.shape[0]],
dtype=torch.long,
device=faces.device,
)
# Check if the cut is valid with each connected component
to_remove_cc = utils3d.torch.compute_connected_components(
faces[remove_face_indices]
)
valid_remove_cc = []
cutting_edges = []
for cc in to_remove_cc:
# Check visibility median for connected component
visibility_median = visibility[remove_face_indices[cc]].median()
if visibility_median > 0.25:
continue
# Check if the cutting loop is small enough
cc_edge_indices, cc_edges_degree = torch.unique(
face2edge[remove_face_indices[cc]], return_counts=True
)
cc_boundary_edge_indices = cc_edge_indices[cc_edges_degree == 1]
cc_new_boundary_edge_indices = cc_boundary_edge_indices[
~torch.isin(cc_boundary_edge_indices, boundary_edge_indices)
]
if len(cc_new_boundary_edge_indices) > 0:
cc_new_boundary_edge_cc = (
utils3d.torch.compute_edge_connected_components(
edges[cc_new_boundary_edge_indices]
)
)
cc_new_boundary_edges_cc_center = [
vertices[edges[cc_new_boundary_edge_indices[edge_cc]]]
.mean(dim=1)
.mean(dim=0)
for edge_cc in cc_new_boundary_edge_cc
]
cc_new_boundary_edges_cc_area = []
for i, edge_cc in enumerate(cc_new_boundary_edge_cc):
_e1 = (
vertices[
edges[cc_new_boundary_edge_indices[edge_cc]][:, 0]
]
- cc_new_boundary_edges_cc_center[i]
)
_e2 = (
vertices[
edges[cc_new_boundary_edge_indices[edge_cc]][:, 1]
]
- cc_new_boundary_edges_cc_center[i]
)
cc_new_boundary_edges_cc_area.append(
torch.norm(torch.cross(_e1, _e2, dim=-1), dim=1).sum()
* 0.5
)
cutting_edges.append(cc_new_boundary_edge_indices)
if any(
[
_l > max_hole_size
for _l in cc_new_boundary_edges_cc_area
]
):
continue
valid_remove_cc.append(cc)
if len(valid_remove_cc) > 0:
remove_face_indices = remove_face_indices[
torch.cat(valid_remove_cc)
]
mask = torch.ones(
faces.shape[0], dtype=torch.bool, device=faces.device
)
mask[remove_face_indices] = 0
faces = faces[mask]
faces, vertices = utils3d.torch.remove_unreferenced_vertices(
faces, vertices
)
tqdm.write(f"Removed {(~mask).sum()} faces by mincut")
else:
tqdm.write(f"Removed 0 faces by mincut")
# Fill small boundaries (holes)
mesh = _meshfix.PyTMesh()
mesh.load_array(vertices.cpu().numpy(), faces.cpu().numpy())
mesh.fill_small_boundaries(nbe=max_hole_nbe, refine=True)
_vertices, _faces = mesh.return_arrays()
vertices = torch.tensor(_vertices).to(vertices)
faces = torch.tensor(_faces).to(faces)
return vertices, faces
@property
def vertices_np(self) -> np.ndarray:
return self.vertices.cpu().numpy()
@property
def faces_np(self) -> np.ndarray:
return self.faces.cpu().numpy()
@log_mesh_changes
def simplify(self, ratio: float) -> None:
"""Simplify the mesh using quadric edge collapse decimation.
Args:
ratio (float): Ratio of faces to filter out.
"""
if ratio <= 0 or ratio >= 1:
raise ValueError("Simplify ratio must be between 0 and 1.")
# Convert to PyVista format for simplification
mesh = pv.PolyData(
self.vertices_np,
np.hstack([np.full((self.faces.shape[0], 1), 3), self.faces_np]),
)
mesh = mesh.decimate(ratio, progress_bar=True)
# Update vertices and faces
self.vertices = torch.tensor(
mesh.points, device=self.device, dtype=torch.float32
)
self.faces = torch.tensor(
mesh.faces.reshape(-1, 4)[:, 1:],
device=self.device,
dtype=torch.int32,
)
@spaces.GPU
def __call__(
self,
filter_ratio: float,
max_hole_size: float,
resolution: int,
num_views: int,
norm_mesh_ratio: float = 1.0,
) -> Tuple[np.ndarray, np.ndarray]:
"""Post-process the mesh by simplifying and filling holes.
This method performs a two-step process:
1. Simplifies mesh by reducing faces using quadric edge decimation.
2. Fills holes by removing invisible faces, repairing small boundaries.
Args:
filter_ratio (float): Ratio of faces to simplify out.
Must be in the range (0, 1).
max_hole_size (float): Maximum area of a hole to fill. Connected
components of holes larger than this size will not be repaired.
resolution (int): Resolution of the rasterization buffer.
num_views (int): Number of viewpoints to sample for rasterization.
norm_mesh_ratio (float, optional): A scaling factor applied to the
vertices of the mesh during processing.
Returns:
Tuple[np.ndarray, np.ndarray]:
- vertices: Simplified and repaired vertex array of (V, 3).
- faces: Simplified and repaired face array of (F, 3).
"""
self.simplify(ratio=filter_ratio)
self.fill_holes(
max_hole_size=max_hole_size,
max_hole_nbe=int(250 * np.sqrt(1 - filter_ratio)),
resolution=resolution,
num_views=num_views,
norm_mesh_ratio=norm_mesh_ratio,
)
return self.vertices_np, self.faces_np