import tqdm
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import nerfacc

import comfy.utils

from pytorch_msssim import SSIM, MS_SSIM

from kiui.op import safe_normalize
from kiui.cam import orbit_camera
from kiui.nn import MLP, trunc_exp

from shared_utils.image_utils import prepare_torch_img

class InstantNGP(nn.Module):
    """Instant-NGP-style NeRF (hash-grid encoders + small MLPs) fitted to a set of
    reference orbit-camera views and rendered with nerfacc occupancy-grid ray marching."""

    def __init__(self, resolution=128, device="cuda"):
        super().__init__()
        # Imported here to defer loading the kiui grid-encoder extension until construction.
        from kiui.gridencoder import GridEncoder

        self.device = torch.device(device)
        self.ref_size_H = resolution
        self.ref_size_W = resolution

        # Ray-marching step size and scene bounding box for the occupancy-grid estimator.
        self.render_step_size = 5e-3
        self.aabb = torch.tensor([-1.0, -1.0, -1.0, 1.0, 1.0, 1.0], device=self.device)
        self.estimator = nerfacc.OccGridEstimator(roi_aabb=self.aabb, resolution=64, levels=1)

        # Separate hash-grid encoders and MLP heads for density and color.
        self.encoder_density = GridEncoder(num_levels=12)
        self.encoder = GridEncoder(num_levels=12)
        self.mlp_density = MLP(self.encoder_density.output_dim, 1, 32, 2, bias=False)
        self.mlp = MLP(self.encoder.output_dim, 3, 32, 2, bias=False)

    def get_rays(self, pose, h, w, fovy, opengl=True):
        # Generate per-pixel ray origins and directions in world space from a
        # camera-to-world pose, image size, and vertical field of view (degrees).
        x, y = torch.meshgrid(
            torch.arange(w, device=pose.device),
            torch.arange(h, device=pose.device),
            indexing="xy",
        )
        x = x.flatten()
        y = y.flatten()

        cx = w * 0.5
        cy = h * 0.5

        focal = h * 0.5 / np.tan(0.5 * np.deg2rad(fovy))

        camera_dirs = F.pad(
            torch.stack(
                [
                    (x - cx + 0.5) / focal,
                    (y - cy + 0.5) / focal * (-1.0 if opengl else 1.0),
                ],
                dim=-1,
            ),
            (0, 1),
            value=(-1.0 if opengl else 1.0),
        )

        # Rotate camera-space directions into world space and broadcast the camera origin.
        rays_d = camera_dirs @ pose[:3, :3].transpose(0, 1)
        rays_o = pose[:3, 3].unsqueeze(0).expand_as(rays_d)

        rays_o = rays_o.view(h, w, 3)
        rays_d = safe_normalize(rays_d).view(h, w, 3)

        return rays_o, rays_d

    def get_color(self, xs):
        # Query RGB color at 3D positions xs via the color hash grid + MLP.
        return torch.sigmoid(self.mlp(self.encoder(xs.to(self.device))))

    def get_density(self, xs):
        # Query volume density at 3D positions xs; trunc_exp keeps densities positive.
        xs = xs.to(self.device)
        prefix = xs.shape[:-1]
        xs = xs.view(-1, 3)
        feats = self.encoder_density(xs)
        density = trunc_exp(self.mlp_density(feats))
        density = density.view(*prefix, 1)
        return density

    def prepare_training(self, reference_images, reference_masks, reference_orbit_camera_poses, reference_orbit_camera_fovy):
        # Cache the reference views, masks, and their orbit-camera parameters as fitting
        # targets, resized to the render resolution.
        self.ref_imgs_num = len(reference_images)

        self.all_ref_cam_poses = reference_orbit_camera_poses
        self.ref_cam_fovy = reference_orbit_camera_fovy

        ref_imgs_torch_list = []
        ref_masks_torch_list = []
        for i in range(self.ref_imgs_num):
            ref_imgs_torch_list.append(prepare_torch_img(reference_images[i].unsqueeze(0), self.ref_size_H, self.ref_size_W, self.device))
            ref_masks_torch_list.append(prepare_torch_img(reference_masks[i].unsqueeze(2).unsqueeze(0), self.ref_size_H, self.ref_size_W, self.device))

        self.ref_imgs_torch = torch.cat(ref_imgs_torch_list, dim=0)
        self.ref_masks_torch = torch.cat(ref_masks_torch_list, dim=0).squeeze(1)

    def render_nerf(self, pose, bg_color=1):
        # Render one view with nerfacc occupancy-grid accelerated ray marching.
        # Returns (color [3, H, W], alpha [H, W]).
        pose = torch.from_numpy(pose.astype(np.float32)).to(self.device)

        rays_o, rays_d = self.get_rays(pose, self.ref_size_H, self.ref_size_W, self.ref_cam_fovy)
        hw = rays_o.shape[0] * rays_o.shape[1]
        rays_o = rays_o.view(hw, 3)
        rays_d = rays_d.view(hw, 3)

        # During training, periodically update the occupancy grid so empty space is skipped.
        if self.training:
            def occ_eval_fn(xs):
                sigmas = self.get_density(xs)
                return self.render_step_size * sigmas

            self.estimator.update_every_n_steps(self.render_step, occ_eval_fn=occ_eval_fn, occ_thre=0.01, n=8)
            self.render_step += 1

        def sigma_fn(t_starts, t_ends, ray_indices):
            # Density at segment midpoints, used by the estimator to prune samples.
            t_origins = rays_o[ray_indices]
            t_dirs = rays_d[ray_indices]
            xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
            sigmas = self.get_density(xs)
            return sigmas.squeeze(-1)

        with torch.no_grad():
            ray_indices, t_starts, t_ends = self.estimator.sampling(
                rays_o,
                rays_d,
                sigma_fn=sigma_fn,
                near_plane=0.01,
                far_plane=100,
                render_step_size=self.render_step_size,
                stratified=self.training,
                cone_angle=0,
            )

        # Evaluate density and color at the sampled segment midpoints.
        t_origins = rays_o[ray_indices]
        t_dirs = rays_d[ray_indices]
        xs = t_origins + t_dirs * (t_starts + t_ends)[:, None] / 2.0
        sigmas = self.get_density(xs).squeeze(-1)
        rgbs = torch.sigmoid(self.mlp(self.encoder(xs)))

        # Volume rendering: composite colors and opacities along each ray.
        n_rays = rays_o.shape[0]
        weights, trans, alphas = nerfacc.render_weight_from_density(t_starts, t_ends, sigmas, ray_indices=ray_indices, n_rays=n_rays)
        color = nerfacc.accumulate_along_rays(weights, values=rgbs, ray_indices=ray_indices, n_rays=n_rays)
        alpha = nerfacc.accumulate_along_rays(weights, values=None, ray_indices=ray_indices, n_rays=n_rays)

        # Blend with the background color where rays are not fully opaque.
        color = color + (1.0 - alpha) * bg_color

        color = color.view(self.ref_size_H, self.ref_size_W, 3).clamp(0, 1).permute(2, 0, 1).contiguous()
        alpha = alpha.view(self.ref_size_H, self.ref_size_W).clamp(0, 1).contiguous()

        return color, alpha

    def fit_nerf(self, iters=512, bg_color=1):
        # Fit the NeRF to the cached reference views by minimizing an RGB + mask MSE loss.
        optimizer = torch.optim.Adam([
            {'params': self.encoder_density.parameters(), 'lr': 1e-2},
            {'params': self.encoder.parameters(), 'lr': 1e-2},
            {'params': self.mlp_density.parameters(), 'lr': 1e-3},
            {'params': self.mlp.parameters(), 'lr': 1e-3},
        ])

        print("[INFO] fitting nerf...")
        self.render_step = 0

        ref_imgs_num_minus_1 = self.ref_imgs_num - 1

        comfy_pbar = comfy.utils.ProgressBar(iters)
        pbar = tqdm.trange(iters)
        for step in pbar:
            # Pick a random reference view and rebuild its orbit camera pose.
            i = random.randint(0, ref_imgs_num_minus_1)

            radius, elevation, azimuth, center_X, center_Y, center_Z = self.all_ref_cam_poses[i]

            orbit_target = np.array([center_X, center_Y, center_Z], dtype=np.float32)
            pose = orbit_camera(elevation, azimuth, radius, target=orbit_target)

            image_gt = self.ref_imgs_torch[i]
            alpha_gt = self.ref_masks_torch[i]
            image_pred, alpha_pred = self.render_nerf(pose, bg_color)

            # RGB reconstruction loss plus a weighted mask (alpha) loss.
            loss_mse = F.mse_loss(image_pred, image_gt) + 0.1 * F.mse_loss(alpha_pred, alpha_gt)
            loss = loss_mse

            loss.backward()
            # Total-variation regularization on the density hash grid to smooth the density field.
            self.encoder_density.grad_total_variation(1e-8)

            optimizer.step()
            optimizer.zero_grad()

            pbar.set_description(f"NeRF Fitting Loss = {loss_mse.item():.6f}")
            comfy_pbar.update_absolute(step + 1)

        torch.cuda.synchronize()

        print("[INFO] finished fitting nerf!")