|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import torch |
|
import nvdiffrast.torch as dr |
|
|
|
|
|
|
|
|
|
|
|
def dot(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Return the inner product of x and y along the last axis.

    The reduced axis is kept with size 1, so the result broadcasts back
    against x and y (e.g. for scaling vectors by their dot product).
    """
    elementwise = x * y
    return elementwise.sum(dim=-1, keepdim=True)
|
|
|
def length(x: torch.Tensor, eps: float =1e-8) -> torch.Tensor:
    """Return the Euclidean length of x along its last axis (axis kept, size 1).

    The *squared* length is floored at ``eps`` before the square root, so the
    smallest value returned is sqrt(eps). This also keeps the gradient of sqrt
    finite for zero vectors.
    """
    squared = dot(x, x)
    return squared.clamp(min=eps).sqrt()
|
|
|
def safe_normalize(x: torch.Tensor, eps: float =1e-8) -> torch.Tensor:
    """Scale x to unit length along its last axis without dividing by zero.

    The divisor is ``length(x, eps)``, which never drops below sqrt(eps).
    Near-zero vectors therefore stay near zero instead of becoming inf/NaN.
    """
    norm = length(x, eps)
    return x / norm
|
|
|
def perspective(fovy=0.7854, aspect=1.0, n=0.1, f=1000.0, device=None):
    """Build a 4x4 perspective projection matrix as a float32 tensor.

    Args:
        fovy: vertical field of view in radians (default ~45 degrees).
        aspect: width / height ratio.
        n, f: near and far clip distances.
        device: torch device for the returned tensor.

    The layout is OpenGL-style: the last row produces w = -z, and the depth
    row uses -(f+n)/(f-n) and -2fn/(f-n). The y scale is negated (1/-y),
    which flips the vertical axis of the projected image.
    """
    half_tan = np.tan(fovy / 2)
    depth_span = f - n
    rows = [
        [1 / (half_tan * aspect), 0, 0, 0],
        [0, 1 / -half_tan, 0, 0],
        [0, 0, -(f + n) / depth_span, -(2 * f * n) / depth_span],
        [0, 0, -1, 0],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=device)
|
|
|
def translate(x, y, z, device=None):
    """Return a 4x4 homogeneous translation matrix (float32) by (x, y, z).

    The offset is stored in the last column, i.e. the matrix is meant to
    multiply column vectors: M @ [px, py, pz, 1] = [px+x, py+y, pz+z, 1].
    """
    offsets = (x, y, z)
    # Identity 3x3 block with the offset appended to each row.
    rows = [[int(r == c) for c in range(3)] + [offsets[r]] for r in range(3)]
    rows.append([0, 0, 0, 1])
    return torch.tensor(rows, dtype=torch.float32, device=device)
|
|
|
@torch.no_grad()
def random_rotation_translation(t, device=None):
    """Return a random 4x4 rigid transform (rotation + translation) as float32.

    The rotation rows are orthonormalised with cross products. The second row is
    row0 x row2, and the third row is row0 x row1. This makes the determinant
    positive, so the result is a proper rotation. Each translation component is
    drawn uniformly from [-t, t]. Randomness comes from NumPy's global RNG.
    """
    basis = np.random.normal(size=[3, 3])
    basis[1] = np.cross(basis[0], basis[2])
    basis[2] = np.cross(basis[0], basis[1])
    basis /= np.linalg.norm(basis, axis=1, keepdims=True)

    xfm = np.eye(4)
    xfm[:3, :3] = basis
    xfm[:3, 3] = np.random.uniform(-t, t, size=[3])
    return torch.tensor(xfm, dtype=torch.float32, device=device)
|
|
|
def rotate_x(a, device=None):
    """Return a 4x4 float32 rotation matrix about the x axis for angle ``a`` (radians).

    NOTE: the 3x3 block is [[1, 0, 0], [0, c, s], [0, -s, c]]. For column vectors
    this is the standard R_x(-a), which is the opposite sign convention to
    ``rotate_y`` in this file. It is kept as-is because callers may depend on it.
    """
    cos_a = np.cos(a)
    sin_a = np.sin(a)
    rows = [
        [1, 0, 0, 0],
        [0, cos_a, sin_a, 0],
        [0, -sin_a, cos_a, 0],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=device)
|
|
|
def rotate_y(a, device=None):
    """Return a 4x4 float32 rotation matrix about the y axis for angle ``a`` (radians).

    The 3x3 block is [[c, 0, s], [0, 1, 0], [-s, 0, c]], which is the standard
    R_y(a) when applied to column vectors.
    """
    cos_a = np.cos(a)
    sin_a = np.sin(a)
    rows = [
        [cos_a, 0, sin_a, 0],
        [0, 1, 0, 0],
        [-sin_a, 0, cos_a, 0],
        [0, 0, 0, 1],
    ]
    return torch.tensor(rows, dtype=torch.float32, device=device)
|
|
|
class SimpleMesh:
    """Minimal triangle-mesh container holding vertex positions and face indices.

    ``faces`` is indexed as ``faces[:, k]`` for k in 0..2, so it is presumably an
    integer index tensor of shape [num_faces, 3]. That shape is not checked here.
    """

    def __init__(self, vertices, faces):
        self.vertices = vertices
        self.faces = faces

    def auto_normals(self):
        """Compute one unit normal per face and store it in ``self.nrm``.

        Each normal is the normalized cross product (v1 - v0) x (v2 - v0), so its
        direction follows the winding order of the face's vertex indices.
        Degenerate faces produce near-zero vectors rather than NaN, because
        ``safe_normalize`` floors the length.
        """
        v0, v1, v2 = (self.vertices[self.faces[:, k], :] for k in range(3))
        edge_a = v1 - v0
        edge_b = v2 - v0
        self.nrm = safe_normalize(torch.cross(edge_a, edge_b, dim=-1))
|
|
|
def xfm_points(points, matrix):
    '''Transform 3D points by 4x4 homogeneous matrices.

    Args:
        points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3]
            or [1, num_vertices, 3]. The latter broadcasts against a batch of matrices.
        matrix: 4x4 transform(s) with shape [minibatch_size, 4, 4]. A single unbatched
            [4, 4] matrix is also accepted and applied to every point set.

    Returns:
        Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4].
        No perspective divide is applied.

    Raises:
        AssertionError: only when torch anomaly detection is enabled and the output
            contains inf or NaN.
    '''
    # Append w = 1 so that the translation column of the matrix takes effect.
    points_h = torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0)
    # Points are row vectors, so p' = p @ M^T (the same as M @ p for column vectors).
    # Swapping the last two dims, rather than dims 1 and 2, also supports an
    # unbatched [4, 4] matrix. It is identical for [B, 4, 4] input.
    out = torch.matmul(points_h, matrix.transpose(-2, -1))
    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN"
    return out
|
|
|
def interpolate(attr, rast, attr_idx, rast_db=None):
    """Thin wrapper around ``nvdiffrast.torch.interpolate``.

    Image-space attribute derivatives are requested for every attribute
    (``diff_attrs='all'``) exactly when ``rast_db`` is supplied; otherwise no
    derivatives are requested. The return value is passed through unchanged.
    Per the nvdiffrast docs, that value is presumably an (attributes, derivatives) tuple.
    """
    diff_attrs = 'all' if rast_db is not None else None
    return dr.interpolate(attr, rast, attr_idx, rast_db=rast_db, diff_attrs=diff_attrs)