# Author: xinjie.wang (commit c85a832)
import logging
import os
import struct
from dataclasses import dataclass, field
from typing import Optional, Union
import cv2
import numpy as np
import torch
from gsplat.cuda._wrapper import spherical_harmonics
from gsplat.rendering import rasterization
from plyfile import PlyData
from scipy.spatial.transform import Rotation
from torch.nn import functional as F
# Module-level logging setup; kept here because this file doubles as a script
# (see the __main__ guard at the bottom).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Names exported via `from <module> import *`.
__all__ = [
"RenderResult",
"GaussianOperator",
]
def quat_mult(q1, q2):
    """Hamilton product of batched quaternions in (w, x, y, z) order.

    ``q1`` is the rotation being applied; ``q2`` is the quaternion being
    rotated.  Both are (N, 4) tensors; the result has the same shape.
    """
    w1, x1, y1, z1 = torch.unbind(q1, dim=-1)
    w2, x2, y2, z2 = torch.unbind(q2, dim=-1)
    out_w = w1 * w2 - x1 * x2 - y1 * y2 - z1 * z2
    out_x = w1 * x2 + x1 * w2 + y1 * z2 - z1 * y2
    out_y = w1 * y2 + y1 * w2 + z1 * x2 - x1 * z2
    out_z = w1 * z2 + z1 * w2 + x1 * y2 - y1 * x2
    return torch.stack((out_w, out_x, out_y, out_z), dim=-1)
def quat_to_rotmat(quats: torch.Tensor, mode="wxyz") -> torch.Tensor:
    """Convert (..., 4) quaternions to (..., 3, 3) rotation matrices.

    Args:
        quats: quaternion tensor; normalized internally before conversion.
        mode: component ordering of the input, "wxyz" or "xyzw".

    Raises:
        ValueError: if ``mode`` is not one of the supported orderings.
    """
    unit = F.normalize(quats, p=2, dim=-1)
    components = torch.unbind(unit, dim=-1)
    if mode == "xyzw":
        x, y, z, w = components
    elif mode == "wxyz":
        w, x, y, z = components
    else:
        raise ValueError(f"Invalid mode: {mode}.")
    xx, yy, zz = x * x, y * y, z * z
    # Row-major entries of the standard quaternion-to-matrix formula.
    entries = [
        1 - 2 * (yy + zz), 2 * (x * y - w * z), 2 * (x * z + w * y),
        2 * (x * y + w * z), 1 - 2 * (xx + zz), 2 * (y * z - w * x),
        2 * (x * z - w * y), 2 * (y * z + w * x), 1 - 2 * (xx + yy),
    ]
    flat = torch.stack(entries, dim=-1)
    return flat.reshape(unit.shape[:-1] + (3, 3))
def gamma_shs(shs: torch.Tensor, gamma: float) -> torch.Tensor:
    """Apply gamma correction to DC spherical-harmonics coefficients.

    The coefficients are mapped to approximate [0, 1] colors via the SH
    band-0 constant, gamma-corrected, then mapped back to SH space.
    """
    C0 = 0.28209479177387814  # SH band-0 basis constant, 1 / (2 * sqrt(pi))
    colors = torch.clip(shs * C0 + 0.5, 0.0, 1.0)
    return (torch.pow(colors, gamma) - 0.5) / C0
@dataclass
class RenderResult:
    """Post-processed render outputs with a derived alpha mask and RGBA image.

    Tensor inputs are converted to uint8 numpy images in ``__post_init__``;
    numpy inputs are taken as-is.
    """

    rgb: np.ndarray  # color render; converted via cv2 when given as a tensor
    depth: np.ndarray  # depth map
    opacity: np.ndarray  # alpha channel, scaled to [0, 255] when a tensor
    mask_threshold: float = 10  # opacity cutoff (0-255 scale) for the mask
    mask: Optional[np.ndarray] = None  # derived binary mask, (H, W, 1) uint8
    rgba: Optional[np.ndarray] = None  # rgb + mask stacked on the last axis

    def __post_init__(self):
        if isinstance(self.rgb, torch.Tensor):
            rgb_u8 = (self.rgb.detach().cpu().numpy() * 255).astype(np.uint8)
            self.rgb = cv2.cvtColor(rgb_u8, cv2.COLOR_BGR2RGB)
        if isinstance(self.depth, torch.Tensor):
            self.depth = self.depth.detach().cpu().numpy()
        if isinstance(self.opacity, torch.Tensor):
            alpha_u8 = (
                self.opacity.detach().cpu().numpy() * 255
            ).astype(np.uint8)
            self.opacity = cv2.cvtColor(alpha_u8, cv2.COLOR_GRAY2RGB)
        binary = np.where(self.opacity > self.mask_threshold, 255, 0)
        self.mask = binary[..., 0:1].astype(np.uint8)
        self.rgba = np.concatenate([self.rgb, self.mask], axis=-1)
@dataclass
class GaussianBase:
    """Base container for 3D Gaussian-splatting model parameters.

    All tensors are stored in raw (pre-activation) form: ``_opacities`` are
    pre-sigmoid values and ``_scales`` are log-scales.  ``_features_dc`` and
    ``_features_rest`` hold the DC and higher-order spherical-harmonics
    coefficients; ``_rgbs`` optionally caches flattened SH / color data.
    """

    _opacities: torch.Tensor  # (N, 1) raw opacities (pre-sigmoid)
    _means: torch.Tensor  # (N, 3) gaussian centers
    _scales: torch.Tensor  # (N, 3) log-scales (pre-exp)
    _quats: torch.Tensor  # (N, 4) rotations; wxyz order assumed — TODO confirm
    _rgbs: Optional[torch.Tensor] = None  # optional flattened SH / colors
    _features_dc: Optional[torch.Tensor] = None  # (N, 3) SH DC coefficients
    _features_rest: Optional[torch.Tensor] = None  # (N, K, 3) higher SH bands
    sh_degree: Optional[int] = 0  # maximum spherical-harmonics degree
    device: str = "cuda"

    def __post_init__(self):
        # Mirrors ``sh_degree``; kept for renderers that ramp SH degree up.
        self.active_sh_degree: int = self.sh_degree
        self.to(self.device)

    def to(self, device: str) -> None:
        """Move every tensor attribute to ``device`` in place."""
        for k, v in self.__dict__.items():
            if not isinstance(v, torch.Tensor):
                continue
            self.__dict__[k] = v.to(device)

    def get_numpy_data(self) -> dict:
        """Return all tensor attributes as detached CPU numpy arrays."""
        data = {}
        for k, v in self.__dict__.items():
            if not isinstance(v, torch.Tensor):
                continue
            data[k] = v.detach().cpu().numpy()
        return data

    def quat_norm(self, x: torch.Tensor) -> torch.Tensor:
        """L2-normalize quaternions along the last dimension."""
        return x / x.norm(dim=-1, keepdim=True)

    @classmethod
    def load_from_ply(
        cls,
        path: str,
        gamma: float = 1.0,
        device: str = "cuda",
    ) -> "GaussianBase":
        """Load gaussian parameters from a 3DGS-style binary PLY file.

        Args:
            path: input ``.ply`` file path.
            gamma: optional gamma correction applied to the DC color band;
                values other than 1.0 also zero the higher SH bands and
                slightly reduce opacity.
            device: target device for the loaded tensors.
        """
        plydata = PlyData.read(path)
        xyz = torch.stack(
            (
                torch.tensor(plydata.elements[0]["x"], dtype=torch.float32),
                torch.tensor(plydata.elements[0]["y"], dtype=torch.float32),
                torch.tensor(plydata.elements[0]["z"], dtype=torch.float32),
            ),
            dim=1,
        )
        opacities = torch.tensor(
            plydata.elements[0]["opacity"], dtype=torch.float32
        ).unsqueeze(-1)
        # DC spherical-harmonics coefficients (one per color channel).
        features_dc = torch.zeros((xyz.shape[0], 3), dtype=torch.float32)
        for ch in range(3):
            features_dc[:, ch] = torch.tensor(
                plydata.elements[0][f"f_dc_{ch}"], dtype=torch.float32
            )
        # Per-axis log-scales, ordered by their numeric suffix.
        scale_names = [
            p.name
            for p in plydata.elements[0].properties
            if p.name.startswith("scale_")
        ]
        scale_names = sorted(scale_names, key=lambda x: int(x.split("_")[-1]))
        scales = torch.zeros(
            (xyz.shape[0], len(scale_names)), dtype=torch.float32
        )
        for idx, attr_name in enumerate(scale_names):
            scales[:, idx] = torch.tensor(
                plydata.elements[0][attr_name], dtype=torch.float32
            )
        # Rotation quaternions, normalized after loading.
        rot_names = [
            p.name
            for p in plydata.elements[0].properties
            if p.name.startswith("rot_")
        ]
        rot_names = sorted(rot_names, key=lambda x: int(x.split("_")[-1]))
        rots = torch.zeros((xyz.shape[0], len(rot_names)), dtype=torch.float32)
        for idx, attr_name in enumerate(rot_names):
            rots[:, idx] = torch.tensor(
                plydata.elements[0][attr_name], dtype=torch.float32
            )
        rots = rots / torch.norm(rots, dim=-1, keepdim=True)
        # Higher-order SH coefficients ("f_rest_*"), if present.
        extra_f_names = [
            p.name
            for p in plydata.elements[0].properties
            if p.name.startswith("f_rest_")
        ]
        extra_f_names = sorted(
            extra_f_names, key=lambda x: int(x.split("_")[-1])
        )
        # There are 3 * ((deg + 1)^2 - 1) rest coefficients; solve for deg.
        max_sh_degree = int(np.sqrt((len(extra_f_names) + 3) / 3) - 1)
        if max_sh_degree != 0:
            features_extra = torch.zeros(
                (xyz.shape[0], len(extra_f_names)), dtype=torch.float32
            )
            for idx, attr_name in enumerate(extra_f_names):
                features_extra[:, idx] = torch.tensor(
                    plydata.elements[0][attr_name], dtype=torch.float32
                )
            features_extra = features_extra.view(
                (features_extra.shape[0], 3, (max_sh_degree + 1) ** 2 - 1)
            )
            features_extra = features_extra.permute(0, 2, 1)
            if abs(gamma - 1.0) > 1e-3:
                # Gamma-correct the DC band, drop the view-dependent bands,
                # and damp opacity to soften the recolored result.
                features_dc = gamma_shs(features_dc, gamma)
                features_extra[..., :] = 0.0
                opacities *= 0.8
            shs = torch.cat(
                [
                    features_dc.reshape(-1, 3),
                    features_extra.reshape(len(features_dc), -1),
                ],
                dim=-1,
            )
        else:
            # sh_degree == 0: only DC features are available.
            shs = features_dc
            features_extra = None
        return cls(
            sh_degree=max_sh_degree,
            _means=xyz,
            _opacities=opacities,
            _rgbs=shs,
            _scales=scales,
            _quats=rots,
            _features_dc=features_dc,
            _features_rest=features_extra,
            device=device,
        )

    def save_to_ply(
        self, path: str, colors: torch.Tensor = None, enable_mask: bool = False
    ):
        """Write the model to a 3DGS-style binary little-endian PLY file.

        Args:
            path: output file path; parent directories are created as needed.
            colors: optional (N, C) colors in [0, 1]; when given they are
                converted to DC SH coefficients and written instead of the
                stored features.
            enable_mask: drop points containing NaN/Inf in any attribute.
        """
        out_dir = os.path.dirname(path)
        if out_dir:  # os.makedirs("") raises, so guard bare filenames
            os.makedirs(out_dir, exist_ok=True)
        numpy_data = self.get_numpy_data()
        means = numpy_data["_means"]
        scales = numpy_data["_scales"]
        quats = numpy_data["_quats"]
        opacities = numpy_data["_opacities"]
        sh0 = numpy_data["_features_dc"]
        shN = numpy_data.get("_features_rest", np.zeros((means.shape[0], 0)))
        shN = shN.reshape(means.shape[0], -1)
        if enable_mask:
            # Identify rows holding NaN or Inf in any attribute.
            # NOTE: opacities is (N, 1), so it must be reduced along axis=1
            # like the other per-point arrays; reducing along axis=0 would
            # collapse over points and wrongly discard every row.
            invalid_mask = (
                np.isnan(means).any(axis=1)
                | np.isinf(means).any(axis=1)
                | np.isnan(scales).any(axis=1)
                | np.isinf(scales).any(axis=1)
                | np.isnan(quats).any(axis=1)
                | np.isinf(quats).any(axis=1)
                | np.isnan(opacities).any(axis=1)
                | np.isinf(opacities).any(axis=1)
                | np.isnan(sh0).any(axis=1)
                | np.isinf(sh0).any(axis=1)
                | np.isnan(shN).any(axis=1)
                | np.isinf(shN).any(axis=1)
            )
            # Filter out rows with NaNs or Infs from all data arrays.
            means = means[~invalid_mask]
            scales = scales[~invalid_mask]
            quats = quats[~invalid_mask]
            opacities = opacities[~invalid_mask]
            sh0 = sh0[~invalid_mask]
            shN = shN[~invalid_mask]
        num_points = means.shape[0]
        with open(path, "wb") as f:
            # PLY header.
            f.write(b"ply\n")
            f.write(b"format binary_little_endian 1.0\n")
            f.write(f"element vertex {num_points}\n".encode())
            f.write(b"property float x\n")
            f.write(b"property float y\n")
            f.write(b"property float z\n")
            f.write(b"property float nx\n")
            f.write(b"property float ny\n")
            f.write(b"property float nz\n")
            if colors is not None:
                for j in range(colors.shape[1]):
                    f.write(f"property float f_dc_{j}\n".encode())
            else:
                for i, data in enumerate([sh0, shN]):
                    prefix = "f_dc" if i == 0 else "f_rest"
                    for j in range(data.shape[1]):
                        f.write(f"property float {prefix}_{j}\n".encode())
            f.write(b"property float opacity\n")
            for i in range(scales.shape[1]):
                f.write(f"property float scale_{i}\n".encode())
            for i in range(quats.shape[1]):
                f.write(f"property float rot_{i}\n".encode())
            f.write(b"end_header\n")
            if colors is not None:
                # Hoisted out of the per-point loop: one device sync, not N.
                color = colors.detach().cpu().numpy()
            # Vertex records.
            for i in range(num_points):
                f.write(struct.pack("<fff", *means[i]))  # x, y, z
                f.write(struct.pack("<fff", 0, 0, 0))  # normals (unused)
                if colors is not None:
                    for j in range(color.shape[1]):
                        # Invert the SH DC mapping: color = C0 * f_dc + 0.5.
                        f_dc = (color[i, j] - 0.5) / 0.2820947917738781
                        f.write(struct.pack("<f", f_dc))
                else:
                    for data in [sh0, shN]:
                        for j in range(data.shape[1]):
                            f.write(struct.pack("<f", data[i, j]))
                # .item() avoids passing a size-1 array to struct.pack
                # (deprecated in NumPy 1.25, an error in NumPy 2.x).
                f.write(struct.pack("<f", opacities[i].item()))
                for data in [scales, quats]:
                    for j in range(data.shape[1]):
                        f.write(struct.pack("<f", data[i, j]))
@dataclass
class GaussianOperator(GaussianBase):
    """GaussianBase extended with pose transforms, rescaling, and rendering."""

    def _compute_transform(
        self,
        means: torch.Tensor,
        quats: torch.Tensor,
        instance_pose: torch.Tensor,
    ):
        """Apply a rigid transform to gaussian means and quaternions.

        Args:
            means: (N, 3) gaussian centers.
            quats: (N, 4) gaussian rotations (wxyz order assumed).
            instance_pose: pose as [x, y, z, qx, qy, qz, qw].

        Returns:
            Tuple of transformed ``(means, quats)``.
        """
        # (x y z qx qy qz qw) -> (x y z qw qx qy qz)
        instance_pose = instance_pose[[0, 1, 2, 6, 3, 4, 5]]
        cur_instances_quats = self.quat_norm(instance_pose[3:])
        rot_cur = quat_to_rotmat(cur_instances_quats, mode="wxyz")
        # Broadcast the single pose across all gaussians.
        num_gs = means.shape[0]
        trans_per_pts = torch.stack([instance_pose[:3]] * num_gs, dim=0)
        quat_per_pts = torch.stack([instance_pose[3:]] * num_gs, dim=0)
        rot_per_pts = torch.stack([rot_cur] * num_gs, dim=0)  # (num_gs, 3, 3)
        # Rotate then translate the means.
        cur_means = (
            torch.bmm(rot_per_pts, means.unsqueeze(-1)).squeeze(-1)
            + trans_per_pts
        )
        # Rotate the per-gaussian orientations (pose quat applied first).
        _quats = self.quat_norm(quats)
        cur_quats = quat_mult(quat_per_pts, _quats)
        return cur_means, cur_quats

    def get_gaussians(
        self,
        c2w: torch.Tensor = None,
        instance_pose: torch.Tensor = None,
        apply_activate: bool = False,
    ) -> "GaussianBase":
        """Return a copy of the gaussians, optionally posed and activated.

        Args:
            c2w: (4, 4) camera-to-world matrix used for view-dependent SH
                evaluation; identity when omitted.
            instance_pose: optional [x y z qx qy qz qw] pose applied to the
                model before color evaluation.
            apply_activate: apply sigmoid/exp activations to opacities and
                scales in the returned copy.
        """
        if c2w is None:
            c2w = torch.eye(4).to(self.device)
        if instance_pose is not None:
            # Compute the transformed gs means and quats.
            world_means, world_quats = self._compute_transform(
                self._means, self._quats, instance_pose.float().to(self.device)
            )
        else:
            world_means, world_quats = self._means, self._quats
        # Gather SH coefficients: DC band plus optional higher bands.
        if self._features_rest is not None:
            colors = torch.cat(
                (self._features_dc[:, None, :], self._features_rest), dim=1
            )
        else:
            colors = self._features_dc[:, None, :]
        if self.sh_degree > 0:
            # View-dependent color from spherical harmonics.
            viewdirs = world_means.detach() - c2w[..., :3, 3]  # (N, 3)
            viewdirs = viewdirs / viewdirs.norm(dim=-1, keepdim=True)
            rgbs = spherical_harmonics(self.sh_degree, viewdirs, colors)
            rgbs = torch.clamp(rgbs + 0.5, 0.0, 1.0)
        else:
            rgbs = torch.sigmoid(colors[:, 0, :])
        # NOTE: leftover debug print of self.device removed.
        return GaussianOperator(
            _means=world_means,
            _opacities=(
                torch.sigmoid(self._opacities)
                if apply_activate
                else self._opacities
            ),
            _rgbs=rgbs,
            _scales=(
                torch.exp(self._scales) if apply_activate else self._scales
            ),
            _quats=self.quat_norm(world_quats),
            _features_dc=self._features_dc,
            _features_rest=self._features_rest,
            sh_degree=self.sh_degree,
            device=self.device,
        )

    def rescale(self, scale: float):
        """Uniformly rescale the model; scales are stored in log space."""
        if scale != 1.0:
            self._means *= scale
            self._scales += torch.log(self._scales.new_tensor(scale))

    def set_scale_by_height(self, real_height: float) -> None:
        """Rescale the model so its y-axis extent equals ``real_height``."""

        def _ptp(tensor, dim):
            # Peak-to-peak (max - min) along ``dim``, as a python list.
            val = tensor.max(dim=dim).values - tensor.min(dim=dim).values
            return val.tolist()

        xyz_scale = max(_ptp(self._means, dim=0))
        self.rescale(1 / (xyz_scale + 1e-6))  # normalize max extent to ~1
        raw_height = _ptp(self._means, dim=0)[1]  # y-axis extent
        scale = real_height / raw_height
        self.rescale(scale)
        return

    @staticmethod
    def resave_ply(
        in_ply: str,
        out_ply: str,
        real_height: float = None,
        instance_pose: np.ndarray = None,
        sh_degree: int = 0,
        device: str = "cuda",
    ) -> None:
        """Load a PLY, optionally re-pose / rescale it, and save it again.

        Args:
            in_ply: source ply path.
            out_ply: destination ply path.
            real_height: target y-extent in world units, if given.
            instance_pose: optional [x y z qx qy qz qw] pose.
            sh_degree: unused; kept for backward compatibility (the degree
                is inferred from the ply contents by ``load_from_ply``).
            device: device used while processing.
        """
        # BUGFIX: ``sh_degree`` was previously passed positionally into the
        # ``gamma`` parameter of load_from_ply, corrupting colors.
        gs_model = GaussianOperator.load_from_ply(in_ply, device=device)
        if instance_pose is not None:
            gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
        if real_height is not None:
            gs_model.set_scale_by_height(real_height)
        gs_model.save_to_ply(out_ply)
        return

    @staticmethod
    def trans_to_quatpose(
        rot_matrix: list[list[float]],
        trans_matrix: tuple = (0, 0, 0),
    ) -> torch.Tensor:
        """Build an [x y z qx qy qz qw] pose from a rotation matrix.

        Args:
            rot_matrix: 3x3 rotation matrix (nested list or array).
            trans_matrix: translation (x, y, z); defaults to no translation.
                (A tuple default avoids the mutable-default-argument pitfall.)
        """
        if isinstance(rot_matrix, list):
            rot_matrix = np.array(rot_matrix)
        rot = Rotation.from_matrix(rot_matrix)
        qx, qy, qz, qw = rot.as_quat()  # scipy returns scalar-last (xyzw)
        instance_pose = torch.tensor([*trans_matrix, qx, qy, qz, qw])
        return instance_pose

    def render(
        self,
        c2w: torch.Tensor,
        Ks: torch.Tensor,
        image_width: int,
        image_height: int,
    ) -> RenderResult:
        """Rasterize the gaussians for one camera and wrap the outputs.

        Args:
            c2w: (4, 4) camera-to-world matrix.
            Ks: (3, 3) camera intrinsics.
            image_width: output image width in pixels.
            image_height: output image height in pixels.
        """
        gs = self.get_gaussians(c2w, apply_activate=True)
        renders, alphas, _ = rasterization(
            means=gs._means,
            quats=gs._quats,
            scales=gs._scales,
            opacities=gs._opacities.squeeze(),
            colors=gs._rgbs,
            viewmats=torch.linalg.inv(c2w)[None, ...],
            Ks=Ks[None, ...],
            width=image_width,
            height=image_height,
            packed=False,
            absgrad=True,
            sparse_grad=False,
            rasterize_mode="antialiased",
            near_plane=0.01,
            far_plane=1000000000,
            radius_clip=0.0,
            render_mode="RGB+ED",  # RGB plus expected depth channel
        )
        renders = renders[0]
        alphas = alphas[0].squeeze(-1)
        assert renders.shape[-1] == 4, "Must render rgb, depth and alpha"
        rendered_rgb, rendered_depth = torch.split(renders, [3, 1], dim=-1)
        return RenderResult(
            torch.clamp(rendered_rgb, min=0, max=1),
            rendered_depth,
            alphas[..., None],
        )
if __name__ == "__main__":
    # Ad-hoc smoke test: load a gaussian ply, flip it, rescale, and re-save.
    input_gs = "outputs/test/debug.ply"
    output_gs = "./debug_v3.ply"
    gs_model: GaussianOperator = GaussianOperator.load_from_ply(input_gs)
    # Rotate 180 degrees about the x-axis.
    R_x = [[1, 0, 0], [0, -1, 0], [0, 0, -1]]
    instance_pose = gs_model.trans_to_quatpose(R_x)
    gs_model = gs_model.get_gaussians(instance_pose=instance_pose)
    gs_model.rescale(2)
    gs_model.set_scale_by_height(1.3)
    gs_model.save_to_ply(output_gs)