Spaces:
Build error
Build error
import os | |
import cv2 | |
import torch | |
import matplotlib | |
import numpy as np | |
import open3d as o3d | |
from PIL import Image | |
from copy import deepcopy | |
from omegaconf import OmegaConf | |
from scipy.spatial import cKDTree | |
def gen_config(cfg_path):
    '''
    Load the configuration file at ``cfg_path`` through ``OmegaConf.load``
    and hand back the resulting config object unchanged.
    '''
    config = OmegaConf.load(cfg_path)
    return config
def get_focal_from_fov(new_fov, H, W):
    '''
    Focal length (in pixels) such that the longer image side spans
    ``new_fov`` degrees.

    new_fov: field of view in degrees
    H, W   : image height and width in pixels
    '''
    # NOTE: top-left pixel should be (0,0)
    half_angle = np.deg2rad(new_fov / 2.0)
    longer_side = W if W >= H else H
    return (longer_side / 2.0) / np.tan(half_angle)
def get_intrins_from_fov(new_fov, H, W):
    '''
    Build a 3x3 pinhole intrinsic matrix from a field of view (degrees)
    and an image size.

    The focal length comes from ``get_focal_from_fov`` and is used for both
    axes; the principal point is placed at (W/2 - 0.5, H/2 - 0.5).
    '''
    # NOTE: top-left pixel should be (0,0)
    focal = get_focal_from_fov(new_fov, H, W)
    center_u = (W / 2.0) - 0.5
    center_v = (H / 2.0) - 0.5
    return np.array([
        [focal, 0,     center_u],
        [0,     focal, center_v],
        [0,     0,     1       ],
    ])
def dpt2xyz(dpt, intrinsic):
    '''
    Back-project a 2D depth map into per-pixel 3D points (numpy version).

    dpt      : H*W depth map
    intrinsic: 3*3 camera intrinsic matrix
    output   : H*W*3 array, inv(intrinsic) @ (u*d, v*d, d) for every pixel
    '''
    rows, cols = dpt.shape[0:2]
    # pixel grid: u runs along columns, v runs along rows
    u_coords, v_coords = np.meshgrid(np.arange(cols), np.arange(rows))
    homogeneous = np.stack([u_coords, v_coords, np.ones_like(v_coords)], axis=-1)
    # scale the homogeneous pixel coordinates by depth
    scaled = homogeneous * dpt[:, :, None]
    # apply the inverse intrinsic to every pixel
    k_inv = np.linalg.inv(intrinsic)
    return np.einsum('ab,hwb->hwa', k_inv, scaled)
def dpt2xyz_torch(dpt, intrinsic):
    '''
    Back-project a 2D depth map into per-pixel 3D points (torch version).

    dpt      : H*W depth tensor
    intrinsic: 3*3 camera intrinsic tensor
    output   : H*W*3 tensor, inv(intrinsic) @ (u*d, v*d, d) for every pixel
    '''
    rows, cols = dpt.shape[0:2]
    # pixel grid: u runs along columns, v runs along rows
    u_coords = torch.arange(cols).unsqueeze(0).expand(rows, cols)
    v_coords = torch.arange(rows).unsqueeze(1).expand(rows, cols)
    homogeneous = torch.stack(
        [u_coords, v_coords, torch.ones_like(v_coords)], dim=-1
    ).to(dpt)  # match dtype/device of the depth map
    # scale the homogeneous pixel coordinates by depth
    scaled = homogeneous * dpt[:, :, None]
    # apply the inverse intrinsic to every pixel
    k_inv = torch.linalg.inv(intrinsic)
    return torch.einsum('ab,hwb->hwa', k_inv, scaled)
def visual_pcd(xyz, color=None, normal = True):
    '''
    Show a single point cloud in an Open3D viewer window.

    xyz   : either an array-like with an ``ndim`` attribute holding 3D points
            (any leading shape, last axis = 3) or an object used directly as
            the point cloud (presumably an o3d PointCloud -- confirm at caller)
    color : optional per-point colors, reshaped to N*3
    normal: estimate normals before drawing

    Array input is rescaled so that the mean point norm is 1. When ``xyz`` is
    not an array, colors/normals are written onto the object that was passed in.
    '''
    if hasattr(xyz,'ndim'):
        # flatten first, so the norm below is taken over the xyz axis of each
        # point even when the input is an H*W*3 map rather than N*3
        xyz = xyz.reshape(-1,3)
        xyz_norm = np.mean(np.sqrt(np.sum(np.square(xyz),axis=1)))
        xyz = xyz / xyz_norm
        pcd = o3d.geometry.PointCloud()
        pcd.points = o3d.utility.Vector3dVector(xyz)
    else: pcd = xyz
    if color is not None:
        color = color.reshape(-1,3)
        pcd.colors = o3d.utility.Vector3dVector(color)
    if normal:
        pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(0.2, 20))
    o3d.visualization.draw_geometries([pcd])
def visual_pcds(xyzs, normal = True):
    '''
    Show several point clouds together in one Open3D viewer window.

    xyzs  : iterable whose items are either arrays with an ``ndim`` attribute
            (reshaped to N*3 and painted with a random uniform color) or
            objects used directly as point clouds
    normal: estimate normals for every cloud before drawing
    '''
    clouds = []
    for item in xyzs:
        cloud = item
        if hasattr(item, 'ndim'):
            # raw points: wrap them in a fresh, randomly colored point cloud
            cloud = o3d.geometry.PointCloud()
            cloud.points = o3d.utility.Vector3dVector(item.reshape(-1, 3))
            cloud.paint_uniform_color(np.random.rand(3))
        if normal:
            cloud.estimate_normals(
                search_param=o3d.geometry.KDTreeSearchParamHybrid(0.2, 20)
            )
        clouds.append(cloud)
    o3d.visualization.draw_geometries(clouds)
def save_pic(input_pic: np.ndarray, save_fn, normalize=True):
    '''
    Convert an array to the 0..255 range and optionally write it as an image.

    input_pic: array to convert; it is copied, never modified
    save_fn  : output path handed to PIL, or None to skip writing
    normalize: stretch the 2nd..98th percentile range to 0..1 before scaling
    output   : float32 array clipped to [0, 255] (the uint8 cast is applied
               only to what is written to disk)
    '''
    # float32 copy, so the caller's array is left untouched
    pic = np.array(input_pic, dtype=np.float32)
    pic = np.nan_to_num(pic)
    if normalize:
        vmin = np.percentile(pic, 2)
        vmax = np.percentile(pic, 98)
        span = vmax - vmin
        # a constant image has zero span; dividing by it would turn the
        # whole picture into NaN, so fall back to a unit span (all zeros)
        if span == 0:
            span = 1.0
        pic = (pic - vmin) / span
    pic = (pic * 255.0).clip(0, 255)
    if save_fn is not None:
        Image.fromarray(pic.astype(np.uint8)).save(save_fn)
    return pic
def depth_colorize(dpt,sky_mask=None):
    '''
    Map a depth map to RGB colors with matplotlib's "Spectral" colormap.

    The depth is first normalized by ``dpt_normalize`` (with the optional
    sky mask); the colormap's alpha channel is dropped.
    output: H*W*3 float array, values from 0 to 1
    '''
    spectral = matplotlib.colormaps["Spectral"]
    normalized = dpt_normalize(dpt, sky_mask)
    rgba = spectral(normalized, bytes=False)
    return rgba[:, :, 0:3]
def dpt_normalize(dpt, sky_mask = None):
    '''
    Normalize a depth map so its 2nd..98th percentile range maps to 0..1.

    dpt     : depth array (assumed to be a numpy array); not modified
    sky_mask: optional boolean mask; masked pixels are excluded from the
              percentile statistics and set to 1 in the output
    output  : new array, (dpt - p2) / (p98 - p2); values are not clipped
    '''
    pic = dpt if sky_mask is None else dpt[~sky_mask]
    vmin = np.percentile(pic, 2)
    vmax = np.percentile(pic, 98)
    span = vmax - vmin
    # a constant depth map has zero span; dividing by it would produce NaN
    if span == 0:
        span = 1.0
    # the subtraction allocates a new array, so the input is left untouched
    dpt = (dpt - vmin) / span
    if sky_mask is not None:
        dpt[sky_mask] = 1.
    return dpt
def transform_points(pts,transform):
    '''
    Apply a linear or rigid transform to row-vector points.

    pts      : N*3 points
    transform: 3*3 (rotation/linear only), 3*4 or 4*4 ([R|t]) matrix
    output   : N*3 transformed points
    raises NotImplementedError for any other transform shape
    '''
    shape = tuple(transform.shape)
    if shape == (3, 3):
        return pts @ transform.T
    if shape in ((3, 4), (4, 4)):
        # upper-left 3*3 is the rotation, last column (first 3 rows) the translation
        rotation = transform[0:3, :3]
        translation = transform[0:3, 3:]
        return pts @ rotation.T + translation.T
    raise NotImplementedError
def get_nml_from_quant(quant):
    '''
    input  N*4 quaternions, ordered (w, x, y, z)
    output N*3 normals
    follow https://arxiv.org/pdf/2404.17774
    '''
    w, x, y, z = (quant[:, k] for k in range(4))
    components = (
        2*x*z+2*y*w,
        2*y*z-2*x*w,
        1-2*x*x-2*y*y,
    )
    # one column per component -> N*3
    return torch.stack(components, dim=1)
def quaternion_from_matrix(M):
    '''
    Convert a batch of rotation matrices to quaternions (torch version).

    input  N*3*3 matrices
    output N*4 quaternions ordered (w, x, y, z), sign chosen so that w >= 0
    '''
    m00 = M[..., 0, 0]
    m01 = M[..., 0, 1]
    m02 = M[..., 0, 2]
    m10 = M[..., 1, 0]
    m11 = M[..., 1, 1]
    m12 = M[..., 1, 2]
    m20 = M[..., 2, 0]
    m21 = M[..., 2, 1]
    m22 = M[..., 2, 2]
    # only the lower triangle of the symmetric 4*4 matrix is filled in;
    # eigh reads the lower triangle by default
    K = torch.zeros((len(M),4,4)).to(M)
    K[:,0,0] = m00 - m11 - m22
    K[:,1,0] = m01 + m10
    K[:,1,1] = m11 - m00 - m22
    K[:,2,0] = m02 + m20
    K[:,2,1] = m12 + m21
    K[:,2,2] = m22 - m00 - m11
    K[:,3,0] = m21 - m12
    K[:,3,1] = m02 - m20
    K[:,3,2] = m10 - m01
    K[:,3,3] = m00 + m11 + m22
    K = K/3
    # quaternion is eigenvector of K that corresponds to largest eigenvalue
    w, V = torch.linalg.eigh(K)
    q = V[torch.arange(len(V)),:,torch.argmax(w,dim=1)]
    # eigenvector layout is (x, y, z, w); reorder to (w, x, y, z)
    q = q[:,[3, 0, 1, 2]]
    # q and -q are the same rotation: flip every row with negative w in one
    # vectorized step instead of a per-row Python loop
    q = torch.where(q[:, :1] < 0., -q, q)
    return q
def numpy_quaternion_from_matrix(M):
    '''
    Convert a map of rotation matrices to quaternions (numpy version).

    input  H*W*3*3 matrices
    output H*W*4 quaternions ordered (w, x, y, z), sign chosen so that w >= 0
    '''
    H,W = M.shape[0:2]
    M = M.reshape(-1,3,3)
    m00 = M[..., 0, 0]
    m01 = M[..., 0, 1]
    m02 = M[..., 0, 2]
    m10 = M[..., 1, 0]
    m11 = M[..., 1, 1]
    m12 = M[..., 1, 2]
    m20 = M[..., 2, 0]
    m21 = M[..., 2, 1]
    m22 = M[..., 2, 2]
    # only the lower triangle of the symmetric 4*4 matrix is filled in;
    # eigh reads the lower triangle by default
    K = np.zeros((len(M),4,4))
    K[...,0,0] = m00 - m11 - m22
    K[...,1,0] = m01 + m10
    K[...,1,1] = m11 - m00 - m22
    K[...,2,0] = m02 + m20
    K[...,2,1] = m12 + m21
    K[...,2,2] = m22 - m00 - m11
    K[...,3,0] = m21 - m12
    K[...,3,1] = m02 - m20
    K[...,3,2] = m10 - m01
    K[...,3,3] = m00 + m11 + m22
    K = K/3
    # quaternion is eigenvector of K that corresponds to largest eigenvalue
    w, V = np.linalg.eigh(K)
    q = V[np.arange(len(V)),:,np.argmax(w,axis=1)]
    # eigenvector layout is (x, y, z, w); reorder to (w, x, y, z)
    q = q[...,[3, 0, 1, 2]]
    # q and -q are the same rotation: flip every row with negative w in one
    # vectorized step instead of a per-row Python loop
    q[q[:,0] < 0.] *= -1
    q = q.reshape(H,W,4)
    return q
def numpy_normalize(input):
    '''
    Scale vectors along the last axis to (approximately) unit length.

    A constant 1e-5 is added to each length before dividing, so all-zero
    vectors stay zero instead of producing NaN.
    '''
    squared_sum = np.sum(np.square(input), axis=-1, keepdims=True)
    length = np.sqrt(squared_sum) + 1e-5
    return input / length
class suppress_stdout_stderr(object):
    '''
    Avoid terminal output of diffusion processing!
    A context manager for doing a "deep suppression" of stdout and stderr in
    Python, i.e. will suppress all print, even if the print originates in a
    compiled C/Fortran sub-function.
    This will not suppress raised exceptions, since exceptions are printed
    to stderr just before a script exits, and after the context manager has
    exited (at least, I think that is why it lets exceptions through).

    The descriptors are opened in __init__ and all of them are closed in
    __exit__, so each instance can be used for one ``with`` block only.
    '''
    def __init__(self):
        # Open a pair of null files
        self.null_fds = [os.open(os.devnull, os.O_RDWR) for _ in range(2)]
        # Save the actual stdout (1) and stderr (2) file descriptors.
        self.save_fds = (os.dup(1), os.dup(2))

    def __enter__(self):
        # Assign the null pointers to stdout and stderr.
        os.dup2(self.null_fds[0], 1)
        os.dup2(self.null_fds[1], 2)
        return self

    def __exit__(self, *_):
        # Re-assign the real stdout/stderr back to (1) and (2)
        os.dup2(self.save_fds[0], 1)
        os.dup2(self.save_fds[1], 2)
        # Close the null files AND the saved duplicates; leaving the
        # duplicates open would leak two descriptors per use
        for fd in (*self.null_fds, *self.save_fds):
            os.close(fd)
import torch.nn.functional as F | |
def nei_delta(input,pad=2):
    '''
    Gather the (2*pad+1) x (2*pad+1) neighborhood of every pixel.

    input : H*W or H*W*C map; a torch tensor, or a numpy array (converted to
            a float32 tensor). Borders are extended by replicate padding.
    pad   : neighborhood radius in pixels
    output: (max, min, patches) where patches is H*W*K for a single-channel
            map or H*W*C*K otherwise (K = window size), and max/min are
            taken over the K axis
    '''
    # isinstance also accepts Tensor subclasses, which have no .astype
    if not isinstance(input, torch.Tensor):
        input = torch.from_numpy(input.astype(np.float32))
    if len(input.shape) < 3:
        input = input[:,:,None]
    h,w,c = input.shape
    # reshape to 1*C*H*W as expected by pad/unfold
    input = input.permute(2,0,1)[None]
    input = F.pad(input, pad=(pad,pad,pad,pad), mode='replicate')
    kernel = 2*pad + 1
    input = F.unfold(input,[kernel,kernel],padding=0)
    input = input.reshape(c,-1,h,w).permute(2,3,0,1) # h*w*c*K
    # drop only the channel axis when it is singleton; a bare squeeze()
    # would also collapse H, W or K axes of size 1
    input = input.squeeze(2)
    return torch.amax(input,dim=-1),torch.amin(input,dim=-1),input
def inpaint_mask(render_dpt,render_rgb):
    '''
    Build the mask of pixels to inpaint in a rendered view and lightly
    repair the rendered color under that mask.

    render_dpt: rendered depth tensor; values below 1e-3 count as holes
    render_rgb: rendered color tensor; it is scaled by 255 and cast to uint8,
                so values are presumably in [0, 1] -- confirm at caller
    output    : (mask, render_rgb) -- boolean numpy mask after erode/dilate
                clean-up, and the repaired color as a float32 tensor moved
                to the dtype/device of ``render_dpt``
    '''
    # edge threshold: 20% of the spread between the 15% and 85% ranked valid depths
    ranked = torch.sort(render_dpt[render_dpt>1e-3]).values
    count = len(ranked)
    depth_hi = ranked[int(.85*count)]
    depth_lo = ranked[int(.15*count)]
    jump_thres = (depth_hi-depth_lo) * 0.2
    # neighborhood check: a large local depth range marks an edge
    local_max, local_min, _ = nei_delta(render_dpt,pad=1)
    is_edge = (local_max - local_min) > jump_thres
    # pixels the render left empty
    is_hole = render_dpt < 1e-3
    # raw mask -- noisy and sparse at this point
    mask = (is_edge | is_hole).cpu().float().numpy()
    # patch small holes in the color image with its median-blurred version
    rgb_u8 = (render_rgb.detach().cpu().numpy()*255).astype(np.uint8)
    rgb_blur = cv2.medianBlur(rgb_u8,5)
    replace = mask>.5
    rgb_u8[replace] = rgb_blur[replace]
    render_rgb = torch.from_numpy((rgb_u8/255).astype(np.float32)).to(render_dpt)
    # clean the mask: erode away speckles, then dilate to grow what remains
    kernel = np.ones((5,5),np.uint8)
    mask = cv2.erode(mask,kernel,iterations=2)
    mask = cv2.dilate(mask,kernel,iterations=7)
    return mask > 0.5, render_rgb
def alpha_inpaint_mask(render_alpha):
    '''
    Turn a rendered alpha tensor into a boolean inpainting mask.

    Pixels whose alpha rounds to 0 are marked; the mask is then eroded once
    and dilated three times with a 5x5 kernel to clean it up.
    output: boolean numpy array
    '''
    alpha = render_alpha.detach().squeeze().cpu().numpy()
    # 1 where alpha rounds to 0, 0 where it rounds to 1
    paint_mask = 1.-np.around(alpha)
    # slightly clean mask
    kernel = np.ones((5,5),np.uint8)
    eroded = cv2.erode(paint_mask,kernel,iterations=1)
    grown = cv2.dilate(eroded,kernel,iterations=3)
    return grown > 0.5
def edge_filter(metric_dpt,sky=None,times=0.1):
    '''
    Flag depth-discontinuity pixels in a depth map.

    metric_dpt: depth map (numpy array)
    sky       : optional boolean mask of pixels excluded from the depth-range
                statistics; defaults to "no sky"
    times     : fraction of the 5th..95th percentile depth range used as the
                edge threshold
    output    : boolean numpy array, True where the local (``nei_delta``
                default window) max-min depth exceeds the threshold
    '''
    if sky is None:
        sky = np.zeros_like(metric_dpt,bool)
    non_sky = metric_dpt[~sky]
    depth_range = np.percentile(non_sky,95) - np.percentile(non_sky, 5)
    local_max, local_min, _ = nei_delta(metric_dpt)
    local_spread = (local_max-local_min).numpy()
    return local_spread > times*depth_range
def fill_mask_with_nearest(imgs, mask):
    '''
    Overwrite masked pixels with the value of the nearest unmasked pixel.

    imgs  : list of images sharing the mask's H*W layout (assumed to be
            numpy arrays); they are modified IN PLACE
    mask  : 2D map, pixels > 0.5 are filled from pixels < 0.5
    output: the same ``imgs`` list

    If there is nothing to fill, or no unmasked pixel to copy from, the
    images are returned unchanged.
    '''
    # mask and un-mask pixel coors
    mask_coords = np.column_stack(np.where(mask > .5))
    non_mask_coords = np.column_stack(np.where(mask < .5))
    if len(mask_coords) == 0 or len(non_mask_coords) == 0:
        return imgs
    # kd-tree on un-masked pixels
    tree = cKDTree(non_mask_coords)
    # nn search of masked pixels
    _, idxs = tree.query(mask_coords)
    nearest = non_mask_coords[idxs]
    dst_rows, dst_cols = mask_coords[:, 0], mask_coords[:, 1]
    src_rows, src_cols = nearest[:, 0], nearest[:, 1]
    # replace and fill in one indexed assignment per image; sources are
    # always unmasked pixels, which are never written, so this matches a
    # pixel-by-pixel copy
    for img in imgs:
        img[dst_rows, dst_cols] = img[src_rows, src_cols]
    return imgs
def edge_rectify(metric_dpt,rgb,sky=None):
    '''
    Replace depth-edge pixels of a depth map and its color image with the
    values of their nearest non-edge pixels.

    Edges come from ``edge_filter``; filling is done by
    ``fill_mask_with_nearest``. ``rgb`` is deep-copied first, while
    ``metric_dpt`` is handed over as-is.
    output: (metric_dpt, process_rgb)
    '''
    edge_mask = edge_filter(metric_dpt, sky)
    rgb_copy = deepcopy(rgb)
    filled = fill_mask_with_nearest([metric_dpt, rgb_copy], edge_mask)
    metric_dpt, rgb_copy = filled
    return metric_dpt, rgb_copy
from plyfile import PlyData, PlyElement | |
def color2feat(color):
    '''
    Convert N*3 colors into spherical-harmonics style feature tensors.

    The DC coefficient is (color - 0.5) / 0.28209479177387814; all higher
    coefficients (up to degree 3, i.e. 15 per channel) are zero.
    output: (features_dc N*3*1, features_rest N*3*15), float32
    '''
    max_sh_degree = 3
    num_coeffs = (max_sh_degree + 1) ** 2
    fused_color = (color-0.5)/0.28209479177387814
    # zero-initialised, so every non-DC coefficient is already 0
    features = torch.zeros((fused_color.shape[0], 3, num_coeffs), dtype=torch.float32)
    features[:, :3, 0] = fused_color
    return features[:, :, 0:1], features[:, :, 1:]
def construct_list_of_attributes(features_dc,features_rest,scale,rotation):
    '''
    List the per-vertex attribute names in the order they are written to
    the PLY file: position, normal, f_dc_*, f_rest_*, opacity, scale_*, rot_*.

    The number of f_dc_* / f_rest_* entries is shape[1]*shape[2] of the
    respective feature array; scale_* / rot_* follow shape[1].
    '''
    num_dc = features_dc.shape[1]*features_dc.shape[2]
    num_rest = features_rest.shape[1]*features_rest.shape[2]
    names = ['x', 'y', 'z', 'nx', 'ny', 'nz']
    names += ['f_dc_{}'.format(i) for i in range(num_dc)]
    names += ['f_rest_{}'.format(i) for i in range(num_rest)]
    names.append('opacity')
    names += ['scale_{}'.format(i) for i in range(scale.shape[1])]
    names += ['rot_{}'.format(i) for i in range(rotation.shape[1])]
    return names
def save_ply(scene,path):
    '''
    Write all gaussians of ``scene`` to a PLY file at ``path``.

    Every frame in ``scene.gaussian_frames`` contributes its xyz, scale,
    opacity, rotation and rgb tensors; rgb goes through a sigmoid and
    ``color2feat`` before being stored as f_dc_* / f_rest_*. Normals are
    written as zeros. All attributes are stored as 'f4' under a single
    'vertex' element.
    '''
    frames = scene.gaussian_frames

    def _to_numpy(tensor):
        return tensor.detach().cpu().numpy()

    # concatenate the per-frame gaussians into flat per-point arrays
    xyz = _to_numpy(torch.cat([gf.xyz.reshape(-1,3) for gf in frames],dim=0))
    scale = _to_numpy(torch.cat([gf.scale.reshape(-1,3) for gf in frames],dim=0))
    opacities = _to_numpy(torch.cat([gf.opacity.reshape(-1) for gf in frames],dim=0)[:,None])
    rotation = _to_numpy(torch.cat([gf.rotation.reshape(-1,4) for gf in frames],dim=0))
    rgb = torch.sigmoid(torch.cat([gf.rgb.reshape(-1,3) for gf in frames],dim=0))
    # rgb -> feature coefficients, flattened to one row per point
    features_dc, features_rest = color2feat(rgb)
    f_dc = _to_numpy(features_dc.flatten(start_dim=1))
    f_rest = _to_numpy(features_rest.flatten(start_dim=1))
    normals = np.zeros_like(xyz)
    # save: one structured record per point, columns in attribute-name order
    attribute_names = construct_list_of_attributes(features_dc,features_rest,scale,rotation)
    dtype_full = [(attribute, 'f4') for attribute in attribute_names]
    columns = (xyz, normals, f_dc, f_rest, opacities, scale, rotation)
    attributes = np.concatenate(columns, axis=1)
    elements = np.empty(xyz.shape[0], dtype=dtype_full)
    elements[:] = list(map(tuple, attributes))
    vertex = PlyElement.describe(elements, 'vertex')
    PlyData([vertex]).write(path)