# File size: 4,671 Bytes
# ffbcf9e
from pathlib import Path
from jaxtyping import Float
import numpy as np
from scipy.spatial.transform import Rotation as R
from plyfile import PlyData, PlyElement
import torch
from torch import Tensor
from einops import rearrange, einsum
def construct_list_of_attributes(num_rest: int) -> list[str]:
    """Return the per-vertex PLY property names in the order they are written.

    The layout is: position (``x y z``), normal (``nx ny nz``), three DC colour
    coefficients (``f_dc_*``), ``num_rest`` higher-order coefficients
    (``f_rest_*``), ``opacity``, three ``scale_*`` values and four ``rot_*``
    quaternion components.

    Args:
        num_rest: Number of ``f_rest_*`` properties to include. Values of zero
            or less yield no ``f_rest_*`` entries.

    Returns:
        The property names as a flat list of strings.
    """
    return [
        "x", "y", "z",
        "nx", "ny", "nz",
        *(f"f_dc_{i}" for i in range(3)),
        *(f"f_rest_{i}" for i in range(num_rest)),
        "opacity",
        *(f"scale_{i}" for i in range(3)),
        *(f"rot_{i}" for i in range(4)),
    ]
def export_ply(
    means: Float[Tensor, "gaussian 3"],
    scales: Float[Tensor, "gaussian 3"],
    rotations: Float[Tensor, "gaussian 4"],
    harmonics: Float[Tensor, "gaussian 3 d_sh"],
    opacities: Float[Tensor, "gaussian"],
    path: Path,
):
    """Write a set of 3D Gaussians to a PLY file with a single "vertex" element.

    Before writing, the scene is normalised: positions are shifted so the
    per-axis median sits at the origin and divided by the largest per-axis
    95th-percentile absolute coordinate. Scales are divided by the same
    factor, multiplied by 4 and clamped to at most 0.0075 (NOTE(review): these
    constants look tuned for a particular viewer -- confirm), then stored as
    natural logarithms.

    Args:
        means: Gaussian centres, one row per Gaussian.
        scales: Per-axis Gaussian scales (linear, not log).
        rotations: Quaternions. They are handed to scipy's ``Rotation.from_quat``
            with its default scalar-last ``(x, y, z, w)`` layout and written
            back scalar-first as ``rot_0..rot_3 = (w, x, y, z)``.
        harmonics: Colour coefficients. Either ``(gaussian, 3, d_sh)``, in which
            case only band 0 (DC) is exported, or an already reduced
            ``(gaussian, 3)`` tensor.
        opacities: Opacity per Gaussian, shaped ``(gaussian,)`` or
            ``(gaussian, 1)``. Written unchanged.
        path: Output file. Missing parent directories are created.
    """
    path = Path(path)

    # Shift the scene so that the median Gaussian is at the origin.
    means = means - means.median(dim=0).values

    # Rescale the scene so that most Gaussians are within range [-1, 1].
    scale_factor = means.abs().quantile(0.95, dim=0).max()
    # A degenerate scene (e.g. a single Gaussian) has a zero extent; skip the
    # rescale there instead of dividing by zero and exporting NaN/inf.
    scale_factor = torch.where(
        scale_factor > 0, scale_factor, torch.ones_like(scale_factor)
    )
    means = means / scale_factor
    scales = scales / scale_factor * 4.0
    # Scales are stored as logs below, so keep them strictly positive: a lower
    # bound of exactly 0 would be written to the file as -inf.
    scales = torch.clamp(scales, torch.finfo(scales.dtype).tiny, 0.0075)

    # World re-orientation applied to positions and orientations. It is
    # currently the identity. Earlier variants (now disabled) used an axis
    # swizzle that makes +Z the world up vector, a 45 degree adjustment for the
    # Polycam viewer's initial angle, and a world-to-camera rotation.
    rotation = torch.eye(3, dtype=means.dtype, device=means.device)

    # Apply the rotation to the means (Gaussian positions).
    means = einsum(rotation, means, "i j, ... j -> ... i")

    # Apply the rotation to the Gaussian rotations. The round trip through
    # rotation matrices also normalises the quaternions.
    gaussian_matrices = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
    world_matrices = rotation.detach().cpu().numpy() @ gaussian_matrices
    quats_xyzw = R.from_matrix(world_matrices).as_quat()
    # scipy returns scalar-last quaternions; the file stores them scalar-first.
    rotations = quats_xyzw[:, [3, 0, 1, 2]]

    # Only the DC band is exported. Accept either the full coefficient tensor
    # or one that the caller already reduced to the DC band.
    harmonics_view_invariant = harmonics[..., 0] if harmonics.ndim == 3 else harmonics

    # Every attribute must be a 2-D column block for the concatenation below.
    if opacities.ndim == 1:
        opacities = opacities[:, None]

    dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0)]
    elements = np.empty(means.shape[0], dtype=dtype_full)

    means_np = means.detach().cpu().numpy()
    # Column order must match construct_list_of_attributes(0).
    attributes = np.concatenate(
        (
            means_np,
            np.zeros_like(means_np),  # normals are not estimated; export zeros
            harmonics_view_invariant.detach().cpu().numpy(),
            opacities.detach().cpu().numpy(),
            scales.log().detach().cpu().numpy(),
            rotations,
        ),
        axis=1,
    )
    elements[:] = list(map(tuple, attributes))

    path.parent.mkdir(exist_ok=True, parents=True)
    PlyData([PlyElement.describe(elements, "vertex")]).write(path)
def save_ply(outputs, path, num_gauss=3, pad=32, height=256, width=384):
    """Crop the padded Gaussian maps in ``outputs`` and export them as a PLY.

    Only the first batch element is exported. For every map the padding border
    is removed and the ``num_gauss`` Gaussians per pixel are flattened into a
    single list of Gaussians before calling ``export_ply``.

    Args:
        outputs: Mapping holding the tensors under the keys
            ``('gauss_means', 0, 0)``, ``('gauss_scaling', 0, 0)``,
            ``('gauss_rotation', 0, 0)``, ``('gauss_opacity', 0, 0)`` and
            ``('gauss_features_dc', 0, 0)``. Means are stored with a flattened
            spatial axis ``(b*v, c, H*W)`` and only their first three channels
            are used; the other maps are ``(b*v, c, H, W)``.
        path: Destination of the PLY file.
        num_gauss: Number of Gaussians per pixel (the ``v`` factor folded into
            the leading axis).
        pad: Border width, in pixels, removed from each side of every map.
        height: Height of the un-padded image.
        width: Width of the un-padded image.
    """
    padded_h = height + 2 * pad
    padded_w = width + 2 * pad

    def crop(t):
        # Drop the padding border from the two trailing (spatial) axes.
        return t[..., pad:padded_h - pad, pad:padded_w - pad]

    def crop_flat(t):
        # Same crop for a tensor whose spatial axes are flattened into one.
        t = rearrange(t, "b c (h w) -> b c h w", h=padded_h, w=padded_w)
        return rearrange(crop(t), "b c h w -> b c (h w)")

    def first_sample(key):
        # Crop a (b*v, c, H, W) map and list the Gaussians of batch element 0.
        cropped = crop(outputs[(key, 0, 0)])
        return rearrange(cropped, "(b v) c h w -> b (v h w) c", v=num_gauss)[0]

    means = rearrange(
        crop_flat(outputs[("gauss_means", 0, 0)]),
        "(b v) c n -> b (v n) c",
        v=num_gauss,
    )[0, :, :3]
    scales = first_sample("gauss_scaling")
    rotations = first_sample("gauss_rotation")
    opacities = first_sample("gauss_opacity")
    harmonics = first_sample("gauss_features_dc")

    export_ply(means, scales, rotations, harmonics, opacities, path)