# jaxmetaverse's picture
# Upload folder using huggingface_hub
# 82ea528 verified
import torch
from typing import Union, List, Optional, Tuple, Dict
import nvdiffrast.torch as dr
import utils3d
class RastContext:
    """
    Handle to an nvdiffrast rasterization context.

    ``rasterize_triangle_faces`` reads ``ctx.nvd_ctx``, so this class exposes
    that attribute. The underlying nvdiffrast context is created lazily on
    first access instead of in ``__init__``, which keeps construction free of
    any GPU/OpenGL initialisation. A context created elsewhere can still be
    injected by plain assignment: ``ctx.nvd_ctx = my_context``.

    Args:
        backend: Which nvdiffrast rasterizer to create on demand; ``'cuda'``
            or ``'gl'``. The value is stored as given and only validated when
            ``nvd_ctx`` is first accessed.
    """

    def __init__(self, backend='cuda'):
        self.backend = backend
        # Filled in on first access of `nvd_ctx` (or by explicit assignment).
        self._nvd_ctx = None

    @property
    def nvd_ctx(self):
        """
        Return the nvdiffrast context, creating it on first use.

        Raises:
            ValueError: If ``backend`` is neither ``'cuda'`` nor ``'gl'``.
        """
        # getattr keeps this working for instances built without __init__.
        existing = getattr(self, '_nvd_ctx', None)
        if existing is not None:
            return existing

        # Validate before importing so a bad backend name is reported as such
        # even on machines where nvdiffrast is not installed.
        if self.backend not in ('cuda', 'gl'):
            raise ValueError(f'Unknown backend: {self.backend}')

        import nvdiffrast.torch as dr

        if self.backend == 'cuda':
            self._nvd_ctx = dr.RasterizeCudaContext()
        else:
            self._nvd_ctx = dr.RasterizeGLContext()
        return self._nvd_ctx

    @nvd_ctx.setter
    def nvd_ctx(self, value):
        # Allow callers to supply a context they created themselves.
        self._nvd_ctx = value
def rasterize_triangle_faces(
    ctx: RastContext,
    vertices: torch.Tensor,
    faces: torch.Tensor,
    width: int,
    height: int,
    attr: Optional[torch.Tensor] = None,
    uv: Optional[torch.Tensor] = None,
    texture: Optional[torch.Tensor] = None,
    model: Optional[torch.Tensor] = None,
    view: Optional[torch.Tensor] = None,
    projection: Optional[torch.Tensor] = None,
    antialiasing: Union[bool, List[int]] = True,
    diff_attrs: Union[None, List[int]] = None,
) -> Dict[str, torch.Tensor]:
    """
    Rasterize a mesh with vertex attributes.

    Args:
        ctx: Rasterization context; its ``nvd_ctx`` attribute is handed to
            ``dr.rasterize``.
        vertices: 3-D tensor whose last dimension is 2, 3 or 4. 2-D points are
            padded with ``z = 0, w = 1`` and 3-D points with ``w = 1``; 4-D
            input is used as homogeneous coordinates unchanged.
        faces: 2-D tensor of triangle vertex indices (presumably ``[T, 3]``
            int32, as nvdiffrast expects -- confirm against callers).
        width: Output image width in pixels.
        height: Output image height in pixels.
        attr: Optional per-vertex attributes to interpolate over the image
            (presumably ``[B, N, C]`` -- confirm against callers).
        uv: Optional per-vertex texture coordinates to interpolate.
        texture: Optional texture sampled at the interpolated ``uv``. Ignored
            when ``uv`` is None.
        model: Optional model matrix, applied first.
        view: Optional view matrix, applied after ``model``.
        projection: Optional projection matrix, applied last. Identity is
            used when omitted.
        antialiasing: ``True`` antialiases every channel of the interpolated
            attributes; a list of channel indices antialiases only those
            channels; anything else disables antialiasing.
        diff_attrs: Forwarded to ``dr.interpolate`` to select the attribute
            channels for which screen-space derivatives are computed.

    Returns:
        A dict with these keys (rows are flipped along the height axis for
        every entry except ``'uv'`` and ``'uv_dr'``):

        - ``'depth'``: ``z/w`` remapped by ``* 0.5 + 0.5``; uncovered pixels
          are set to 1.0.
        - ``'mask'``: float mask, 1.0 where a triangle covers the pixel.
        - ``'face_id'``: triangle id channel of the rasterizer output as
          returned by nvdiffrast (0 means no triangle).
        - ``'image'``: interpolated ``attr``, channels-first. Only when
          ``attr`` is given.
        - ``'image_dr'``: derivatives of ``'image'``, channels-first. Only
          when both ``attr`` and ``diff_attrs`` are given.
        - ``'uv'``, ``'uv_dr'``: interpolated ``uv`` and its derivatives,
          exactly as returned by ``dr.interpolate`` (not flipped or
          permuted). Only when ``uv`` is given.
        - ``'texture'``: sampled texture, channels-first. Only when both
          ``uv`` and ``texture`` are given.

    Raises:
        ValueError: If the last dimension of ``vertices`` is not 2, 3 or 4.
    """
    assert vertices.ndim == 3
    assert faces.ndim == 2

    # Promote vertices to homogeneous (x, y, z, w) coordinates.
    if vertices.shape[-1] == 2:
        vertices = torch.cat(
            [vertices, torch.zeros_like(vertices[..., :1]), torch.ones_like(vertices[..., :1])],
            dim=-1,
        )
    elif vertices.shape[-1] == 3:
        vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1)
    elif vertices.shape[-1] == 4:
        pass
    else:
        raise ValueError(f'Wrong shape of vertices: {vertices.shape}')

    # Compose projection @ view @ model. The identity fallback matches the
    # vertices' dtype as well as their device so the matmul below cannot fail
    # on a dtype mismatch.
    if projection is not None:
        mvp = projection
    else:
        mvp = torch.eye(4, dtype=vertices.dtype, device=vertices.device)
    if view is not None:
        mvp = mvp @ view
    if model is not None:
        mvp = mvp @ model

    # Row-vector convention: v_clip = v @ MVP^T.
    pos_clip = vertices @ mvp.transpose(-1, -2)
    faces = faces.contiguous()
    if attr is not None:
        attr = attr.contiguous()

    rast_out, rast_db = dr.rasterize(
        ctx.nvd_ctx, pos_clip, faces, resolution=[height, width], grad_db=True
    )

    # The last axis of rast_out is (u, v, z/w, triangle_id). Dim 1 is the
    # height axis of nvdiffrast's [B, H, W, C] output; flipping it reverses
    # the row order of the returned images.
    face_id = rast_out[..., 3].flip(1)
    depth = rast_out[..., 2].flip(1)
    mask = (face_id > 0).float()
    # Remap z/w by *0.5 + 0.5 and push uncovered pixels to 1.0.
    depth = (depth * 0.5 + 0.5) * mask + (1.0 - mask)

    ret = {
        'depth': depth,
        'mask': mask,
        'face_id': face_id,
    }

    if attr is not None:
        image, image_dr = dr.interpolate(attr, rast_out, faces, rast_db, diff_attrs=diff_attrs)
        if antialiasing == True:  # noqa: E712 -- a list must not take this branch
            image = dr.antialias(image, rast_out, pos_clip, faces)
        elif isinstance(antialiasing, list):
            # Antialias only the requested channels and write them back.
            aa_image = dr.antialias(image[..., antialiasing], rast_out, pos_clip, faces)
            image[..., antialiasing] = aa_image
        ret['image'] = image.flip(1).permute(0, 3, 1, 2)

        # Derivatives exist only when attributes were interpolated; keeping
        # this inside the `attr` branch avoids reading an unbound `image_dr`
        # when `diff_attrs` is passed without `attr`.
        if diff_attrs is not None:
            ret['image_dr'] = image_dr.flip(1).permute(0, 3, 1, 2)

    if uv is not None:
        uv_map, uv_map_dr = dr.interpolate(uv, rast_out, faces, rast_db, diff_attrs='all')
        # NOTE(review): unlike the other outputs these two are returned
        # without the flip/permute -- confirm callers rely on that layout.
        ret['uv'] = uv_map
        ret['uv_dr'] = uv_map_dr
        if texture is not None:
            texture_map = dr.texture(texture, uv_map, uv_map_dr)
            ret['texture'] = texture_map.flip(1).permute(0, 3, 1, 2)

    return ret