# StyleNeRF / training / stylenerf.py
# Jiatao Gu -- add code from the original repo (commit 94ada0b)
# [web viewer links: raw | history | blame] -- 115 kB
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from bdb import set_trace
import copy
from email import generator
import imp
import math
from platform import architecture
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import grad
from training.networks import *
from dnnlib.camera import *
from dnnlib.geometry import (
positional_encoding, upsample, downsample
)
from dnnlib.util import dividable, hash_func, EasyDict
from torch_utils.ops.hash_sample import hash_sample
from torch_utils.ops.grid_sample_gradfix import grid_sample
from torch_utils.ops.nerf_utils import topp_masking
from einops import repeat, rearrange
# --------------------------------- basic modules ------------------------------------------- #
@persistence.persistent_class
class Style2Layer(nn.Module):
    """Simplified version of SynthesisLayer used inside the NeRF block.

    There is no noise input and the kernel is always 1x1.  With ``w_dim > 0``
    the layer is a style-modulated convolution; otherwise it is a plain
    linear / 1x1-conv layer followed by ReLU.
    """

    def __init__(self,
        in_channels,
        out_channels,
        w_dim,
        activation='lrelu',
        resample_filter=[1,3,3,1],
        magnitude_ema_beta = -1,    # -1 means not using magnitude ema
        **unused_kwargs):
        super().__init__()

        # Plain configuration values.
        self.activation = activation
        self.conv_clamp = None
        self.register_buffer('resample_filter', upfirdn2d.setup_filter(resample_filter))
        self.padding = 0
        self.act_gain = bias_act.activation_funcs[activation].def_gain
        self.w_dim = w_dim
        self.in_features = in_channels
        self.out_features = out_channels

        if w_dim > 0:
            # Modulated path: affine maps the style vector to per-input-channel scales.
            self.affine = FullyConnectedLayer(w_dim, in_channels, bias_init=1)
            init_weight = torch.randn([out_channels, in_channels, 1, 1])
            self.weight = torch.nn.Parameter(
                init_weight.to(memory_format=torch.contiguous_format))
            self.bias = torch.nn.Parameter(torch.zeros([out_channels]))
        else:
            # Unmodulated path: 2-D weight, initialised with Kaiming-uniform
            # weights and a fan-in-scaled uniform bias.
            self.weight = torch.nn.Parameter(torch.Tensor(out_channels, in_channels))
            self.bias = torch.nn.Parameter(torch.Tensor(out_channels))
            self.weight_gain = 1.

            torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
            fan_in = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)[0]
            limit = 1 / math.sqrt(fan_in)
            torch.nn.init.uniform_(self.bias, -limit, limit)

        self.magnitude_ema_beta = magnitude_ema_beta
        if magnitude_ema_beta > 0:
            # Running estimate of the mean squared input magnitude.
            self.register_buffer('w_avg', torch.ones([]))

    def extra_repr(self) -> str:
        template = 'in_features={}, out_features={}, style={}'
        return template.format(self.in_features, self.out_features, self.w_dim)

    def forward(self, x, w=None, fused_modconv=None, gain=1, up=1, **unused_kwargs):
        """Apply the layer to ``x``; ``w`` is the style input for the modulated path."""
        if self.magnitude_ema_beta > 0:
            if self.training:  # refresh the EMA from the current batch
                with torch.autograd.profiler.record_function('update_magnitude_ema'):
                    current = x.detach().to(torch.float32).square().mean()
                    self.w_avg.copy_(current.lerp(self.w_avg, self.magnitude_ema_beta))
            x = x * self.w_avg.rsqrt()

        if fused_modconv is None:
            with misc.suppress_tracer_warnings():  # this value will be treated as a constant
                fused_modconv = not self.training

        if not (self.w_dim > 0):
            # Unmodulated: 2-D input is treated as an MLP batch, anything else as a feature map.
            if x.ndim == 2:
                return F.relu(F.linear(x, self.weight, self.bias.to(x.dtype)))
            return F.relu(F.conv2d(x, self.weight[:, :, None, None], self.bias))

        # Modulated convolution.
        assert x.ndim == 4,  "currently not support modulated MLP"
        styles = self.affine(w)  # Batch x style_dim
        if x.size(0) > styles.size(0):
            # Input batch is a multiple of the style batch: tile the styles to match.
            styles = repeat(styles, 'b c -> (b s) c', s=x.size(0) // styles.size(0))

        x = modulated_conv2d(
            x=x, weight=self.weight, styles=styles, noise=None, up=up,
            padding=self.padding, resample_filter=self.resample_filter,
            flip_weight=True, fused_modconv=fused_modconv)

        scaled_gain = self.act_gain * gain
        scaled_clamp = None if self.conv_clamp is None else self.conv_clamp * gain
        return bias_act.bias_act(
            x, self.bias.to(x.dtype), act=self.activation, gain=scaled_gain, clamp=scaled_clamp)
@persistence.persistent_class
class SDFDensityLaplace(nn.Module):  # alpha * Laplace(loc=0, scale=beta).cdf(-sdf)
    """Turns signed-distance values into densities through a Laplace CDF.

    ``params_init`` maps attribute names to initial values; each entry becomes
    a learnable parameter on the module (``get_beta`` reads ``self.beta``).
    """

    def __init__(self, params_init={}, noise_std=0.0, beta_min=0.001, exp_beta=False):
        super().__init__()
        self.noise_std = noise_std
        for name, value in params_init.items():
            setattr(self, name, nn.Parameter(torch.tensor(value)))
        self.beta_min = beta_min
        self.exp_beta = exp_beta
        if exp_beta:  # any truthy setting (including 'upper') needs the step counter
            self.register_buffer("steps", torch.scalar_tensor(0).float())

    def density_func(self, sdf, beta=None):
        """Return the density for ``sdf``; ``beta`` defaults to ``get_beta()``."""
        if beta is None:
            beta = self.get_beta()
        alpha = 1 / beta
        tail = torch.expm1(-sdf.abs() / beta)
        cdf = 0.5 + 0.5 * sdf.sign() * tail  # TODO: need abs maybe, not sure
        return alpha * cdf

    def get_beta(self):
        """Return the effective scale: |beta| (optionally capped/frozen) plus ``beta_min``."""
        magnitude = self.beta.abs()
        if self.exp_beta == 'upper':
            # Cap |beta| by a ceiling that decays with the step counter.
            ceiling = 0.12 * torch.exp(-0.003 * (self.steps / 1e3))
            return min(magnitude, ceiling) + self.beta_min
        if self.exp_beta and not (self.steps < 500000):
            # From step 500000 onwards beta no longer receives gradients.
            magnitude = magnitude.detach()
        return magnitude + self.beta_min

    def set_steps(self, steps):
        """Overwrite the step counter, if this module tracks one."""
        if not hasattr(self, "steps"):
            return
        self.steps = self.steps * 0 + steps
# ------------------------------------------------------------------------------------------- #
@persistence.persistent_class
class NeRFBlock(nn.Module):
    '''
    Predicts volume density and color from 3D location, viewing
    direction, and latent code z.

    Every class attribute below is a default; ``__init__`` overrides any of
    them whose name appears as a key in ``nerf_kwargs``.
    '''
    # dimensions
    input_dim = 3
    w_dim = 512  # style latent
    z_dim = 0  # input latent
    rgb_out_dim = 128
    hidden_size = 128
    n_blocks = 8
    img_channels = 3
    magnitude_ema_beta = -1
    disable_latents = False
    max_batch_size = 2 ** 18
    shuffle_factor = 1
    implementation = 'batch_reshape'  # option: [flatten_2d, batch_reshape]

    # architecture settings
    activation = 'lrelu'
    use_skip = False
    use_viewdirs = False
    add_rgb = False
    predict_rgb = False
    inverse_sphere = False
    merge_sigma_feat = False  # use one MLP for sigma and features
    no_sigma = False  # do not predict sigma, only output features
    tcnn_backend = False
    use_style = None
    use_normal = False
    use_sdf = None
    volsdf_exp_beta = False
    normalized_feat = False
    final_sigmoid_act = False

    # positional encoding input
    use_pos = False
    n_freq_posenc = 10
    n_freq_posenc_views = 4
    downscale_p_by = 1
    gauss_dim_pos = 20
    gauss_dim_view = 4
    gauss_std = 10.
    positional_encoding = "normal"

    def __init__(self, nerf_kwargs):
        """Build the network from ``nerf_kwargs`` (unknown keys are ignored)."""
        super().__init__()
        for key in nerf_kwargs:
            if hasattr(self, key):
                setattr(self, key, nerf_kwargs[key])

        # Keep the configured SDF mode string, and expose use_sdf as a bool.
        self.sdf_mode = self.use_sdf
        self.use_sdf = self.use_sdf is not None
        # The mode must be read from sdf_mode: use_sdf is already a bool here,
        # so comparing it with 'volsdf' could never be true.
        if self.sdf_mode == 'volsdf':
            self.density_transform = SDFDensityLaplace(
                params_init={'beta': 0.1},
                beta_min=0.0001,
                exp_beta=self.volsdf_exp_beta)

        # ----------- input module -------------------------
        D = self.input_dim if not self.inverse_sphere else self.input_dim + 1
        if self.positional_encoding == 'gauss':
            # Fixed-seed random projection matrices, stored as buffers.
            rng = np.random.RandomState(2021)
            B_pos = self.gauss_std * torch.from_numpy(rng.randn(D, self.gauss_dim_pos * D)).float()
            B_view = self.gauss_std * torch.from_numpy(rng.randn(3, self.gauss_dim_view * 3)).float()
            self.register_buffer("B_pos", B_pos)
            self.register_buffer("B_view", B_view)
            dim_embed = D * self.gauss_dim_pos * 2
            dim_embed_view = 3 * self.gauss_dim_view * 2
        elif self.positional_encoding == 'normal':
            dim_embed = D * self.n_freq_posenc * 2
            dim_embed_view = 3 * self.n_freq_posenc_views * 2
        else:  # not using positional encoding
            dim_embed, dim_embed_view = D, 3

        if self.use_pos:
            dim_embed, dim_embed_view = dim_embed + D, dim_embed_view + 3

        self.dim_embed = dim_embed
        self.dim_embed_view = dim_embed_view

        # ------------ Layers --------------------------
        assert not (self.add_rgb and self.predict_rgb), "only one could be achieved"
        assert not ((self.use_viewdirs or self.use_normal) and (self.merge_sigma_feat or self.no_sigma)), \
            "merged MLP does not support."

        if self.disable_latents:
            w_dim = 0
        elif self.z_dim > 0:  # if input global latents, disable using style vectors
            w_dim, dim_embed, dim_embed_view = 0, dim_embed + self.z_dim, dim_embed_view + self.z_dim
        else:
            w_dim = self.w_dim

        # NOTE(review): forward_nerf appends the sigma gradients whenever
        # use_normal OR use_sdf is set (see forward_inputs), but the input width
        # below only grows when use_normal is set -- confirm the sdf-only case.
        final_in_dim = self.hidden_size
        if self.use_normal:
            final_in_dim += D

        final_out_dim = self.rgb_out_dim * self.shuffle_factor
        if self.merge_sigma_feat:
            final_out_dim += self.shuffle_factor  # predicting sigma
        if self.add_rgb:
            final_out_dim += self.img_channels

        # start building the model
        if self.tcnn_backend:
            try:
                import tinycudann as tcnn
            except ImportError as err:
                raise ImportError("This sample requires the tiny-cuda-nn extension for PyTorch.") from err

            assert self.merge_sigma_feat and (not self.predict_rgb) and (not self.add_rgb)
            assert w_dim == 0, "do not use any modulating inputs"

            tcnn_config = {"otype": "FullyFusedMLP", "activation": "ReLU", "output_activation": "None", "n_neurons": 64, "n_hidden_layers": 1}
            self.network = tcnn.Network(dim_embed, final_out_dim, tcnn_config)
            self.num_ws = 0
        else:
            self.fc_in = Style2Layer(dim_embed, self.hidden_size, w_dim, activation=self.activation)
            self.num_ws = 1
            # Index (within self.blocks) of the layer that re-reads the encoded input.
            self.skip_layer = self.n_blocks // 2 - 1 if self.use_skip else None
            if self.n_blocks > 1:
                self.blocks = nn.ModuleList([
                    Style2Layer(
                        self.hidden_size if i != self.skip_layer else self.hidden_size + dim_embed,
                        self.hidden_size,
                        w_dim, activation=self.activation,
                        magnitude_ema_beta=self.magnitude_ema_beta)
                    for i in range(self.n_blocks - 1)])
                self.num_ws += (self.n_blocks - 1)

            if not (self.merge_sigma_feat or self.no_sigma):
                self.sigma_out = ToRGBLayer(self.hidden_size, self.shuffle_factor, w_dim, kernel_size=1)
                self.num_ws += 1
            self.feat_out = ToRGBLayer(final_in_dim, final_out_dim, w_dim, kernel_size=1)
            if (self.z_dim == 0 and (not self.disable_latents)):
                self.num_ws += 1
            else:
                # No style vectors are consumed when latents are disabled or z is used.
                self.num_ws = 0

            if self.use_viewdirs:
                assert self.predict_rgb, "only works when predicting RGB"
                self.from_ray = Conv2dLayer(dim_embed_view, final_out_dim, kernel_size=1, activation='linear')

            if self.predict_rgb:  # predict RGB over features
                self.to_rgb = Conv2dLayer(final_out_dim, self.img_channels * self.shuffle_factor, kernel_size=1, activation='linear')

    def set_steps(self, steps):
        """Propagate the training step counter.

        This block registers no ``steps`` buffer itself; the counter lives in
        ``density_transform`` (when one was built), so forward the value there.
        """
        if hasattr(self, "steps"):
            self.steps.fill_(steps)
        density_transform = getattr(self, "density_transform", None)
        if density_transform is not None:
            density_transform.set_steps(steps)

    def transform_points(self, p, views=False):
        """Encode points (or view directions when ``views`` is True) for the MLP."""
        p = p / self.downscale_p_by
        if self.positional_encoding == 'gauss':
            B = self.B_view if views else self.B_pos
            p_transformed = positional_encoding(p, B, 'gauss', self.use_pos)
        elif self.positional_encoding == 'normal':
            L = self.n_freq_posenc_views if views else self.n_freq_posenc
            p_transformed = positional_encoding(p, L, 'normal', self.use_pos)
        else:
            p_transformed = p
        return p_transformed

    def forward(self, p_in, ray_d, z_shape=None, z_app=None, ws=None, shape=None, requires_grad=False, impl=None):
        """Evaluate the field at ``p_in`` and return ``(feat, sigma_raw)``.

        Gradients are enabled when training, when SDF mode is on, or when the
        caller asks for them through ``requires_grad``.
        """
        with torch.set_grad_enabled(self.training or self.use_sdf or requires_grad):
            impl = 'mlp' if self.tcnn_backend else impl
            option, p_in = self.forward_inputs(p_in, shape=shape, impl=impl)
            if self.tcnn_backend:
                with torch.cuda.amp.autocast():
                    p = p_in.squeeze(-1).squeeze(-1)
                    o = self.network(p)
                    # The first shuffle_factor channels are sigma, the rest are features.
                    sigma_raw, feat = o[:, :self.shuffle_factor], o[:, self.shuffle_factor:]
                sigma_raw = rearrange(sigma_raw, '(b s) d -> b s d', s=option[2]).to(p_in.dtype)
                feat = rearrange(feat, '(b s) d -> b s d', s=option[2]).to(p_in.dtype)
            else:
                feat, sigma_raw = self.forward_nerf(option, p_in, ray_d, ws=ws, z_shape=z_shape, z_app=z_app)
        return feat, sigma_raw

    def forward_inputs(self, p_in, shape=None, impl=None):
        """Reshape ``p_in`` to ``(b s) d h w`` and return ``(option, p_in)``.

        ``option`` is ``(height, width, n_steps, use_normal)``.

        Raises:
            NotImplementedError: if ``impl`` is not a supported implementation
                (or 'batch_reshape' is requested without ``shape``).
        """
        # prepare the inputs
        impl = impl if impl is not None else self.implementation
        if (shape is not None) and (impl == 'batch_reshape'):
            height, width, n_steps = shape[1:]
        elif impl == 'flatten_2d':
            (height, width), n_steps = dividable(p_in.shape[1]), 1
        elif impl == 'mlp':
            height, width, n_steps = 1, 1, p_in.shape[1]
        else:
            raise NotImplementedError("looking for more efficient implementation.")
        p_in = rearrange(p_in, 'b (h w s) d -> (b s) d h w', h=height, w=width, s=n_steps)
        use_normal = self.use_normal or self.use_sdf
        if use_normal:
            # forward_nerf differentiates sigma with respect to the input points.
            p_in.requires_grad_(True)
        return (height, width, n_steps, use_normal), p_in

    def forward_nerf(self, option, p_in, ray_d=None, ws=None, z_shape=None, z_app=None):
        """Run the styled MLP and return ``(feat_out, sigma_out)``.

        ``sigma_out`` is None when ``no_sigma`` is set.
        """
        height, width, n_steps, use_normal = option

        # forward nerf feature networks
        p = self.transform_points(p_in.permute(0,2,3,1))
        if (self.z_dim > 0) and (not self.disable_latents):
            assert (z_shape is not None) and (ws is None)
            z_shape = repeat(z_shape, 'b c -> (b s) h w c', h=height, w=width, s=n_steps)
            p = torch.cat([p, z_shape], -1)
        p = p.permute(0,3,1,2)  # BS x C x H x W

        if height == width == 1:  # MLP
            p = p.squeeze(-1).squeeze(-1)

        net = self.fc_in(p, ws[:, 0] if ws is not None else None)
        if self.n_blocks > 1:
            for idx, layer in enumerate(self.blocks):
                ws_i = ws[:, idx + 1] if ws is not None else None
                if (self.skip_layer is not None) and (idx == self.skip_layer):
                    net = torch.cat([net, p], 1)
                net = layer(net, ws_i, up=1)

        # forward to get the final results
        w_idx = self.n_blocks  # fc_in, self.blocks

        feat_inputs = [net]
        if not (self.merge_sigma_feat or self.no_sigma):
            ws_i = ws[:, w_idx] if ws is not None else None
            sigma_out = self.sigma_out(net, ws_i)
            if use_normal:
                # d(sigma)/d(p_in), kept in the graph so it can be trained through.
                gradients, = grad(
                    outputs=sigma_out, inputs=p_in,
                    grad_outputs=torch.ones_like(sigma_out, requires_grad=False),
                    retain_graph=True, create_graph=True, only_inputs=True)
                feat_inputs.append(gradients)

        ws_i = ws[:, -1] if ws is not None else None
        net = torch.cat(feat_inputs, 1) if len(feat_inputs) > 1 else net
        feat_out = self.feat_out(net, ws_i)  # this is used for lowres output

        if self.merge_sigma_feat:  # split sigma from the feature
            sigma_out, feat_out = feat_out[:, :self.shuffle_factor], feat_out[:, self.shuffle_factor:]
        elif self.no_sigma:
            sigma_out = None

        if self.predict_rgb:
            if self.use_viewdirs and ray_d is not None:
                ray_d = ray_d / torch.norm(ray_d, dim=-1, keepdim=True)
                ray_d = self.transform_points(ray_d, views=True)
                if self.z_dim > 0:
                    ray_d = torch.cat([ray_d, repeat(z_app, 'b c -> b (h w s) c', h=height, w=width, s=n_steps)], -1)
                ray_d = rearrange(ray_d, 'b (h w s) d -> (b s) d h w', h=height, w=width, s=n_steps)
                feat_ray = self.from_ray(ray_d)
                rgb = self.to_rgb(F.leaky_relu(feat_out + feat_ray))
            else:
                rgb = self.to_rgb(feat_out)

            if self.final_sigmoid_act:
                rgb = torch.sigmoid(rgb)
            if self.normalized_feat:
                feat_out = feat_out / (1e-7 + feat_out.norm(dim=-1, keepdim=True))
            # RGB channels are placed in front of the feature channels.
            feat_out = torch.cat([rgb, feat_out], 1)

        # transform back
        if feat_out.ndim == 2:  # mlp mode
            sigma_out = rearrange(sigma_out, '(b s) d -> b s d', s=n_steps) if sigma_out is not None else None
            feat_out = rearrange(feat_out, '(b s) d -> b s d', s=n_steps)
        else:
            sigma_out = rearrange(sigma_out, '(b s) d h w -> b (h w s) d', s=n_steps) if sigma_out is not None else None
            feat_out = rearrange(feat_out, '(b s) d h w -> b (h w s) d', s=n_steps)
        return feat_out, sigma_out
@persistence.persistent_class
class CameraGenerator(torch.nn.Module):
    """Small fully connected network: two 'lrelu' layers followed by a projection layer."""

    def __init__(self, in_dim=2, hi_dim=128, out_dim=2):
        super().__init__()
        self.affine1 = FullyConnectedLayer(in_dim, hi_dim, activation='lrelu')
        self.affine2 = FullyConnectedLayer(hi_dim, hi_dim, activation='lrelu')
        self.proj = FullyConnectedLayer(hi_dim, out_dim)

    def forward(self, x):
        """Pass ``x`` through affine1, affine2 and proj in that order."""
        hidden = self.affine1(x)
        hidden = self.affine2(hidden)
        return self.proj(hidden)
@persistence.persistent_class
class CameraRay(object):
    """Camera sampling, pixel/ray construction and depth handling for rendering.

    The class attributes are defaults; ``__init__`` overrides any of them whose
    name appears as a key in ``camera_kwargs``.
    """

    range_u = (0, 0)
    range_v = (0.25, 0.25)
    range_radius = (2.732, 2.732)
    depth_range = [0.5, 6.]
    gaussian_camera = False
    angular_camera = False
    intersect_ball = False
    fov = 49.13
    bg_start = 1.0
    depth_transform = None  # "LogWarp" or "InverseWarp"
    dists_normalized = False  # use normalized interval instead of real dists
    random_rotate = False
    ray_align_corner = True
    nonparam_cameras = None

    def __init__(self, camera_kwargs, **other_kwargs):
        """Apply ``camera_kwargs`` overrides and build the intrinsic matrix.

        NOTE(review): when ``camera_kwargs`` is empty it is updated in place
        with ``other_kwargs``, so the caller's dict is modified.
        """
        if len(camera_kwargs) == 0:  # for compatibility with old checkpoints
            camera_kwargs.update(other_kwargs)
        for key in camera_kwargs:
            if hasattr(self, key):
                setattr(self, key, camera_kwargs[key])
        self.camera_matrix = get_camera_mat(fov=self.fov)

    def prepare_pixels(self, img_res, tgt_res, vol_res, camera_matrices, theta, margin=0, **unused):
        """Return ``(vol_pixels, tgt_pixels)`` pixel coordinates.

        ``vol_pixels`` is laid out as ``b (h w) c`` and ``tgt_pixels`` as
        ``b c h w``.
        """
        if self.ray_align_corner:
            # Build the full-resolution grid once and subsample it with
            # nearest-neighbour interpolation.
            all_pixels = self.get_pixel_coords(img_res, camera_matrices, theta=theta)
            all_pixels = rearrange(all_pixels, 'b (h w) c -> b c h w', h=img_res, w=img_res)
            tgt_pixels = F.interpolate(all_pixels, size=(tgt_res, tgt_res), mode='nearest') if tgt_res < img_res else all_pixels.clone()
            vol_pixels = F.interpolate(tgt_pixels, size=(vol_res, vol_res), mode='nearest') if tgt_res > vol_res else tgt_pixels.clone()
            vol_pixels = rearrange(vol_pixels, 'b c h w -> b (h w) c')
        else:  # coordinates not aligned!
            tgt_pixels = self.get_pixel_coords(tgt_res, camera_matrices, corner_aligned=False, theta=theta)
            vol_pixels = self.get_pixel_coords(vol_res, camera_matrices, corner_aligned=False, theta=theta, margin=margin) \
                if (tgt_res > vol_res) or (margin > 0) else tgt_pixels.clone()
            tgt_pixels = rearrange(tgt_pixels, 'b (h w) c -> b c h w', h=tgt_res, w=tgt_res)
        return vol_pixels, tgt_pixels

    def prepare_pixels_regularization(self, tgt_pixels, n_reg_samples):
        """Pick a jittered, strided subset of ``tgt_pixels``.

        Returns ``(rand_pixels, rand_indexs)``: the gathered coordinates and
        their flat indices into the ``(h w)`` axis.
        """
        # only apply when size is bigger than voxel resolution
        pace = tgt_pixels.size(-1) // n_reg_samples
        idxs = torch.arange(0, tgt_pixels.size(-1), pace, device=tgt_pixels.device)  # n_reg_samples
        # One random (x, y) offset in [0, pace) per batch element.
        u_xy = torch.rand(tgt_pixels.size(0), 2, device=tgt_pixels.device)
        u_xy = (u_xy * pace).floor().long()  # batch_size x 2
        x_idxs, y_idxs = idxs[None,:] + u_xy[:,:1], idxs[None,:] + u_xy[:,1:]
        rand_indexs = (x_idxs[:,None,:] + y_idxs[:,:,None] * tgt_pixels.size(-1)).reshape(tgt_pixels.size(0), -1)
        tgt_pixels = rearrange(tgt_pixels, 'b c h w -> b (h w) c')
        rand_pixels = tgt_pixels.gather(1, rand_indexs.unsqueeze(-1).repeat(1,1,2))
        return rand_pixels, rand_indexs

    def get_roll(self, ws, training=True, theta=None, **unused):
        """Return a per-sample roll angle (radians when sampled), or ``theta`` broadcast.

        NOTE(review): ``random_rotate`` defaults to False, which still passes
        the ``is not None`` test, so during training the sampled branch is
        taken and the noise is scaled by False -- confirm this is intended.
        """
        if (self.random_rotate is not None) and training:
            theta = torch.randn(ws.size(0)).to(ws.device) * self.random_rotate / 2
            theta = theta / 180 * math.pi
        else:
            if theta is not None:
                theta = torch.ones(ws.size(0)).to(ws.device) * theta
        return theta

    def get_camera(self, batch_size, device, mode='random', fov=None, force_uniform=False):
        """Return ``(camera_mat, world_mat, mode, reg_loss)``.

        ``mode`` selects how the pose is produced: a list of 3 or 4 values is
        passed to ``get_camera_pose``, a tensor to ``get_camera_pose_v2``, and
        anything else samples with ``get_random_pose``.  ``reg_loss`` is
        always None.
        """
        if fov is not None:
            camera_matrix = get_camera_mat(fov)
        else:
            camera_matrix = self.camera_matrix
        camera_mat = camera_matrix.repeat(batch_size, 1, 1).to(device)
        reg_loss = None  # TODO: useless

        if isinstance(mode, list):
            # default camera generator, we assume input mode is linear
            if len(mode) == 3:
                val_u, val_v, val_r = mode
                r0 = self.range_radius[0]
                r1 = self.range_radius[1]
            else:
                # The fourth entry rescales the radius range.
                val_u, val_v, val_r, r_s = mode
                r0 = self.range_radius[0] * r_s
                r1 = self.range_radius[1] * r_s

            world_mat = get_camera_pose(
                self.range_u, self.range_v, [r0, r1],
                val_u, val_v, val_r,
                batch_size=batch_size,
                gaussian=False,  # input mode is by default uniform
                angular=self.angular_camera).to(device)

        elif isinstance(mode, torch.Tensor):
            world_mat, mode = get_camera_pose_v2(
                self.range_u, self.range_v, self.range_radius, mode,
                gaussian=self.gaussian_camera and (not force_uniform),
                angular=self.angular_camera)
            world_mat = world_mat.to(device)
            mode = torch.stack(mode, 1).to(device)

        else:
            world_mat, mode = get_random_pose(
                self.range_u, self.range_v,
                self.range_radius, batch_size,
                gaussian=self.gaussian_camera,
                angular=self.angular_camera)
            world_mat = world_mat.to(device)
            mode = torch.stack(mode, 1).to(device)
        return camera_mat.float(), world_mat.float(), mode, reg_loss

    def get_transformed_depth(self, di, reversed=False):
        """Map between normalized depth and warped depth.

        With ``reversed=False`` the normalized value ``di`` is interpolated
        between the two ``depth_range`` ends in the warped domain and mapped
        back.  With ``reversed=True`` a depth ``di`` is mapped to a normalized
        value clipped to [0, 1].

        Raises:
            NotImplementedError: for an unknown ``depth_transform``.
        """
        depth_range = self.depth_range
        if (self.depth_transform is None) or (self.depth_transform == 'None'):
            g_fwd, g_inv = lambda x: x, lambda x: x
        elif self.depth_transform == 'LogWarp':
            # g_fwd is applied to the Python floats in depth_range and, in the
            # reversed direction, to the tensor `di`.  math.log only handles
            # scalars, so tensors are routed through torch.log.
            g_fwd = lambda x: torch.log(x) if torch.is_tensor(x) else math.log(x)
            g_inv = torch.exp
        elif self.depth_transform == 'InverseWarp':
            g_fwd, g_inv = lambda x: 1/x, lambda x: 1/x
        else:
            raise NotImplementedError

        if not reversed:
            return g_inv(g_fwd(depth_range[1]) * di + g_fwd(depth_range[0]) * (1 - di))
        else:
            d0 = (g_fwd(di) - g_fwd(depth_range[0])) / (g_fwd(depth_range[1]) - g_fwd(depth_range[0]))
            return d0.clip(min=0, max=1)

    def get_evaluation_points(self, pixels_world=None, camera_world=None, di=None, p_i=None, no_reshape=False, transform=None):
        """Compute sample points along rays, optionally warping them.

        When ``p_i`` is None the points are ``camera_world + di * ray``;
        otherwise the given points are only warped (``no_reshape`` must be
        True).  Returns ``p_i`` alone when ``no_reshape`` is set, else
        ``(p_i, ray_i)`` flattened to ``(batch, -1, 3)``.
        """
        if p_i is None:
            batch_size = pixels_world.shape[0]
            n_steps = di.shape[-1]
            ray_i = pixels_world - camera_world
            p_i = camera_world.unsqueeze(-2).contiguous() + \
                di.unsqueeze(-1).contiguous() * ray_i.unsqueeze(-2).contiguous()
            ray_i = ray_i.unsqueeze(-2).repeat(1, 1, n_steps, 1)
        else:
            assert no_reshape, "only used to transform points to a warped space"

        if transform is None:
            transform = self.depth_transform

        if transform == 'LogWarp':
            c = torch.tensor([1., 0., 0.]).to(p_i.device)
            p_i = normalization_inverse_sqrt_dist_centered(
                p_i, c[None, None, None, :], self.depth_range[1])
        elif transform == 'InverseWarp':
            # https://arxiv.org/pdf/2111.12077.pdf
            p_n = p_i.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-7)
            # Points with norm >= 1 are contracted; the others are left as they are.
            con = p_n.ge(1).type_as(p_n)
            p_i = p_i * (1 -con) + (2 - 1 / p_n) * (p_i / p_n) * con

        if no_reshape:
            return p_i

        assert(p_i.shape == ray_i.shape)
        p_i = p_i.reshape(batch_size, -1, 3)
        ray_i = ray_i.reshape(batch_size, -1, 3)
        return p_i, ray_i

    def get_evaluation_points_bg(self, pixels_world, camera_world, di):
        """Return background sample points (4 values each) and unit ray directions."""
        batch_size = pixels_world.shape[0]
        n_steps = di.shape[-1]
        n_pixels = pixels_world.shape[1]
        ray_world = pixels_world - camera_world
        ray_world = ray_world / ray_world.norm(dim=-1, keepdim=True)  # normalize
        camera_world = camera_world.unsqueeze(-2).expand(batch_size, n_pixels, n_steps, 3)
        ray_world = ray_world.unsqueeze(-2).expand(batch_size, n_pixels, n_steps, 3)
        bg_pts, _ = depth2pts_outside(camera_world, ray_world, di)  # di: 1 ---> 0
        bg_pts = bg_pts.reshape(batch_size, -1, 4)
        ray_world = ray_world.reshape(batch_size, -1, 3)
        return bg_pts, ray_world

    def add_noise_to_interval(self, di):
        """Jitter each sample uniformly inside the interval bounded by its neighbours' midpoints."""
        di_mid = .5 * (di[..., 1:] + di[..., :-1])
        di_high = torch.cat([di_mid, di[..., -1:]], dim=-1)
        di_low = torch.cat([di[..., :1], di_mid], dim=-1)
        noise = torch.rand_like(di_low)
        ti = di_low + (di_high - di_low) * noise
        return ti

    def calc_volume_weights(self, sigma, z_vals=None, ray_vector=None, dists=None, last_dist=1e10):
        """Return ``(weights, T[..., -1], dists)`` for alpha compositing.

        When ``dists`` is not given it is derived from consecutive ``z_vals``
        (scaled by the ray length if ``ray_vector`` is given) and padded with
        ``last_dist`` for the final sample.
        """
        if dists is None:
            dists = z_vals[..., 1:] - z_vals[..., :-1]
            if ray_vector is not None:
                dists = dists * torch.norm(ray_vector, dim=-1, keepdim=True)
            dists = torch.cat([dists, torch.ones_like(dists[..., :1]) * last_dist], dim=-1)
        alpha = 1.-torch.exp(-F.relu(sigma)*dists)

        if last_dist > 0:
            # Force the last sample to be fully opaque.
            alpha[..., -1] = 1

        # alpha = 1.-torch.exp(-sigma * dists)
        # T holds the accumulated transmittance in front of each sample.
        T = torch.cumprod(torch.cat([
            torch.ones_like(alpha[:, :, :1]),
            (1. - alpha + 1e-10), ], dim=-1), dim=-1)[..., :-1]
        weights = alpha * T
        return weights, T[..., -1], dists

    def get_pixel_coords(self, tgt_res, camera_matrices, corner_aligned=True, margin=0, theta=None, invert_y=True):
        """Return a ``tgt_res`` x ``tgt_res`` pixel grid, rotated by ``theta`` if given."""
        device = camera_matrices[0].device
        batch_size = camera_matrices[0].shape[0]
        # margin = self.margin if margin is None else margin
        full_pixels = arange_pixels((tgt_res, tgt_res),
            batch_size, invert_y_axis=invert_y, margin=margin,
            corner_aligned=corner_aligned).to(device)
        if (theta is not None):
            # 2-D rotation of the (x, y) coordinates by theta.
            theta = theta.unsqueeze(-1)
            x = full_pixels[..., 0] * torch.cos(theta) - full_pixels[..., 1] * torch.sin(theta)
            y = full_pixels[..., 0] * torch.sin(theta) + full_pixels[..., 1] * torch.cos(theta)
            full_pixels = torch.stack([x, y], -1)
        return full_pixels

    def get_origin_direction(self, pixels, camera_matrices):
        """Return ``(pixels_world, camera_world, ray_vector)`` for ``pixels``.

        Camera matrices with a smaller batch than ``pixels`` are tiled to match.
        """
        camera_mat, world_mat = camera_matrices[:2]
        if camera_mat.size(0) < pixels.size(0):
            camera_mat = repeat(camera_mat, 'b c d -> (b s) c d', s=pixels.size(0)//camera_mat.size(0))
        if world_mat.size(0) < pixels.size(0):
            world_mat = repeat(world_mat, 'b c d -> (b s) c d', s=pixels.size(0)//world_mat.size(0))
        pixels_world = image_points_to_world(pixels, camera_mat=camera_mat, world_mat=world_mat)
        camera_world = origin_to_world(pixels.size(1), camera_mat=camera_mat, world_mat=world_mat)
        ray_vector = pixels_world - camera_world
        return pixels_world, camera_world, ray_vector

    def set_camera_prior(self, dataset_cams):
        """Store ``dataset_cams`` as the non-parametric camera prior."""
        self.nonparam_cameras = dataset_cams
@persistence.persistent_class
class VolumeRenderer(object):
n_ray_samples = 14
n_bg_samples = 4
n_final_samples = None # final nerf steps after upsampling (optional)
sigma_type = 'relu' # other allowed options including, "abs", "shiftedsoftplus", "exp"
hierarchical = True
fine_only = False
no_background = False
white_background = False
mask_background = False
pre_volume_size = None
bound = None
density_p_target = 1.0
tv_loss_weight = 0.0 # for now only works for density-based voxels
def __init__(self, renderer_kwargs, camera_ray, input_encoding=None, **other_kwargs):
    """Override class-level defaults from ``renderer_kwargs`` and store collaborators.

    An empty ``renderer_kwargs`` is filled in place from ``other_kwargs``
    (compatibility with old checkpoints).  Keys that do not match an existing
    attribute are ignored.
    """
    if len(renderer_kwargs) == 0:
        renderer_kwargs.update(other_kwargs)
    for name in renderer_kwargs:
        if not hasattr(self, name):
            continue
        setattr(self, name, renderer_kwargs[name])
    self.C = camera_ray       # camera / ray helper
    self.I = input_encoding   # optional input-feature encoder
def split_feat(self, x, img_channels, white_color=None, split_rgb=True):
    """Return ``(x, img)`` where ``img`` is the first ``img_channels`` channels of ``x``.

    With ``split_rgb`` the returned ``x`` drops those channels.  ``white_color``
    is added to ``img`` only when the renderer uses a white background.
    """
    image = x[:, :img_channels]
    remainder = x[:, img_channels:] if split_rgb else x
    add_white = (white_color is not None) and self.white_background
    if add_white:
        image = image + white_color
    return remainder, image
def get_bound(self):
    """Return the scene bound: the explicit ``bound`` if set, else one derived from the camera."""
    if self.bound is not None:
        return self.bound
    if self.C.depth_transform == 'InverseWarp':
        # when applying normalization, the points are restricted inside R=2 ball
        return 2
    # TODO: this is a bit hacky as we assume object at origin
    near, far = self.C.depth_range[0], self.C.depth_range[1]
    return (far - near)
def get_density(self, sigma_raw, fg_nerf, no_noise=False, training=False):
    """Convert raw network output into a density according to the active mode.

    SDF-based fields use their own density transform; otherwise ``sigma_type``
    selects the activation, and unknown types pass ``sigma_raw`` through.
    """
    if fg_nerf.use_sdf:
        return fg_nerf.density_transform.density_func(sigma_raw)

    kind = self.sigma_type
    if kind == 'relu':
        if training and (not no_noise):  # adding noise to pass gradient?
            sigma_raw = sigma_raw + torch.randn_like(sigma_raw)
        return F.relu(sigma_raw)
    if kind == 'shiftedsoftplus':  # https://arxiv.org/pdf/2111.11215.pdf
        return F.softplus(sigma_raw - 1)  # 1 is the shifted bias.
    if kind == 'exp_truncated':  # density in the log-space
        # up-bound = 5, also shifted by 1
        return torch.exp(5 - F.relu(5 - (sigma_raw - 1)))
    return sigma_raw
def forward_hierarchical_sampling(self, di, weights, n_steps, det=False):
    """Draw ``n_steps`` new depths per ray with ``sample_pdf``.

    Bins are the midpoints of consecutive ``di`` values; the first and last
    weights are dropped before sampling.
    """
    midpoints = 0.5 * (di[..., :-1] + di[..., 1:])
    n_bins = midpoints.size(-1)
    flat_bins = midpoints.reshape(-1, n_bins)
    flat_weights = weights.reshape(-1, n_bins+1)[:, 1:-1]
    samples = sample_pdf(flat_bins, flat_weights, n_steps, det=det)
    return samples.reshape(di.size(0), -1, n_steps)
def forward_rendering_with_pre_density(self, H, output, fg_nerf, nerf_input_cams, nerf_input_feats, latent_codes, styles):
    """Render foreground features using densities read from the input features.

    The leading ``self.I.sigma_dim`` channels of the queried input features are
    summed into a raw density; the remaining channels are fed to ``fg_nerf``.
    Results are written into ``output`` (``feat``, ``fg_weights``,
    ``fg_depths``), which is also returned.
    """
    pixels_world, camera_world, ray_vector = nerf_input_cams
    z_shape_obj, z_app_obj = latent_codes[:2]
    height, width = dividable(H.n_points)
    n_steps = H.n_steps
    fg_shape = [H.batch_size, height, width, n_steps]
    bound = self.get_bound()

    # sample points
    di = torch.linspace(0., 1., steps=n_steps).to(H.device)
    di = repeat(di, 's -> b n s', b=H.batch_size, n=H.n_points)
    jitter = (H.training and (not H.get('disable_noise', False))) or H.get('force_noise', False)
    if jitter:
        di = self.C.add_noise_to_interval(di)
    di_trs = self.C.get_transformed_depth(di)

    points, rays = self.C.get_evaluation_points(pixels_world, camera_world, di_trs)
    queried = self.I.query_input_features(points, nerf_input_feats, fg_shape, bound)

    # Split the queried features into the density part and the NeRF input part.
    sigma_dim = self.I.sigma_dim
    pre_sigma_raw = queried[..., :sigma_dim].sum(dim=-1, keepdim=True)
    nerf_inputs = queried[..., sigma_dim:]

    pre_sigma = self.get_density(
        rearrange(pre_sigma_raw, 'b (n s) () -> b n s', s=n_steps),
        fg_nerf, training=H.training)
    depths_for_weights = di if self.C.dists_normalized else di_trs
    pre_weights = self.C.calc_volume_weights(
        pre_sigma, depths_for_weights, ray_vector, last_dist=1e10)[0]

    feat, _ = fg_nerf(nerf_inputs, rays, z_shape_obj, z_app_obj, ws=styles, shape=fg_shape)
    feat = rearrange(feat, 'b (n s) d -> b n s d', s=n_steps)
    feat = torch.sum(pre_weights.unsqueeze(-1) * feat, dim=-2)

    output.feat += [feat]
    output.fg_weights = pre_weights
    output.fg_depths = (di, di_trs)
    return output
def forward_sampling(self, H, output, fg_nerf, nerf_input_cams, nerf_input_feats, latent_codes, styles):
    """Experimental renderer that evaluates the field one depth slice at a time.

    Marked by its author as not functional yet.  It overrides ``H.n_steps``
    with 64, queries ``fg_nerf`` once per depth step with gradients enabled,
    and then composites the per-step results.  Results are written into
    ``output``, which is also returned.
    """
    # TODO: experimental research code. Not functional yet.
    pixels_world, camera_world, ray_vector = nerf_input_cams
    z_shape_obj, z_app_obj = latent_codes[:2]
    height, width = dividable(H.n_points)
    bound = self.get_bound()
    # just to simulate
    H.n_steps = 64
    di = torch.linspace(0., 1., steps=H.n_steps).to(H.device)
    di = repeat(di, 's -> b n s', b=H.batch_size, n=H.n_points)
    if (H.training and (not H.get('disable_noise', False))) or H.get('force_noise', False):
        di = self.C.add_noise_to_interval(di)
    di_trs = self.C.get_transformed_depth(di)
    # Each query covers a single depth step, hence the trailing 1.
    fg_shape = [H.batch_size, height, width, 1]
    # iteration in the loop (?)
    feats, sigmas = [], []
    with torch.enable_grad():
        di_trs.requires_grad_(True)
        for s in range(di_trs.shape[-1]):
            di_s = di_trs[..., s:s+1]
            p_i, r_i = self.C.get_evaluation_points(pixels_world, camera_world, di_s)
            if nerf_input_feats is not None:
                p_i = self.I.query_input_features(p_i, nerf_input_feats, fg_shape, bound)
            feat, sigma_raw = fg_nerf(p_i, r_i, z_shape_obj, z_app_obj, ws=styles, shape=fg_shape, requires_grad=True)
            sigma = self.get_density(sigma_raw, fg_nerf, training=H.training)
            feats += [feat]
            sigmas += [sigma]
    # Reassemble the per-step outputs along a new depth axis.
    feat, sigma = torch.stack(feats, 2), torch.cat(sigmas, 2)
    fg_weights, bg_lambda = self.C.calc_volume_weights(
        sigma, di if self.C.dists_normalized else di_trs, # use real dists for computing weights
        ray_vector, last_dist=0 if not H.fg_inf_depth else 1e10)[:2]
    fg_feat = torch.sum(fg_weights.unsqueeze(-1) * feat, dim=-2)
    output.feat += [fg_feat]
    output.full_out += [feat]
    output.fg_weights = fg_weights
    output.bg_lambda = bg_lambda
    output.fg_depths = (di, di_trs)
    return output
def forward_rendering(self, H, output, fg_nerf, nerf_input_cams, nerf_input_feats, latent_codes, styles):
    """Render the foreground field, optionally with a hierarchical second pass.

    A first pass samples ``H.n_steps`` depths per ray (jittered when noise is
    enabled).  When hierarchical sampling is active, a second set of depths is
    drawn from the first-pass weights, evaluated, merged with the first set,
    sorted by depth, and the compositing weights are recomputed.  Results are
    written into ``output`` (``feat``, ``full_out``, ``fg_weights``,
    ``bg_lambda``, ``fg_depths``), which is also returned.
    """
    pixels_world, camera_world, ray_vector = nerf_input_cams
    z_shape_obj, z_app_obj = latent_codes[:2]
    height, width = dividable(H.n_points)
    fg_shape = [H.batch_size, height, width, H.n_steps]
    bound = self.get_bound()
    # sample points
    di = torch.linspace(0., 1., steps=H.n_steps).to(H.device)
    di = repeat(di, 's -> b n s', b=H.batch_size, n=H.n_points)
    if (H.training and (not H.get('disable_noise', False))) or H.get('force_noise', False):
        di = self.C.add_noise_to_interval(di)
    di_trs = self.C.get_transformed_depth(di)
    p_i, r_i = self.C.get_evaluation_points(pixels_world, camera_world, di_trs)
    if nerf_input_feats is not None:
        p_i = self.I.query_input_features(p_i, nerf_input_feats, fg_shape, bound)
    feat, sigma_raw = fg_nerf(p_i, r_i, z_shape_obj, z_app_obj, ws=styles, shape=fg_shape)
    feat = rearrange(feat, 'b (n s) d -> b n s d', s=H.n_steps)
    sigma_raw = rearrange(sigma_raw.squeeze(-1), 'b (n s) -> b n s', s=H.n_steps)
    sigma = self.get_density(sigma_raw, fg_nerf, training=H.training)
    fg_weights, bg_lambda = self.C.calc_volume_weights(
        sigma, di if self.C.dists_normalized else di_trs, # use real dists for computing weights
        ray_vector, last_dist=0 if not H.fg_inf_depth else 1e10)[:2]
    if self.hierarchical and (not H.get('disable_hierarchical', False)):
        # The new depths are drawn without tracking gradients; sampling is
        # deterministic outside of training.
        with torch.no_grad():
            di_fine = self.forward_hierarchical_sampling(di, fg_weights, H.n_steps, det=(not H.training))
        di_trs_fine = self.C.get_transformed_depth(di_fine)
        p_f, r_f = self.C.get_evaluation_points(pixels_world, camera_world, di_trs_fine)
        if nerf_input_feats is not None:
            p_f = self.I.query_input_features(p_f, nerf_input_feats, fg_shape, bound)
        feat_f, sigma_raw_f = fg_nerf(p_f, r_f, z_shape_obj, z_app_obj, ws=styles, shape=fg_shape)
        feat_f = rearrange(feat_f, 'b (n s) d -> b n s d', s=H.n_steps)
        sigma_raw_f = rearrange(sigma_raw_f.squeeze(-1), 'b (n s) -> b n s', s=H.n_steps)
        sigma_f = self.get_density(sigma_raw_f, fg_nerf, training=H.training)
        # Concatenate both passes along the step axis ...
        feat = torch.cat([feat_f, feat], 2)
        sigma = torch.cat([sigma_f, sigma], 2)
        sigma_raw = torch.cat([sigma_raw_f, sigma_raw], 2)
        di = torch.cat([di_fine, di], 2)
        di_trs = torch.cat([di_trs_fine, di_trs], 2)
        # ... then sort by normalized depth and reorder every tensor the same way.
        di, indices = torch.sort(di, dim=2)
        di_trs = torch.gather(di_trs, 2, indices)
        sigma = torch.gather(sigma, 2, indices)
        sigma_raw = torch.gather(sigma_raw, 2, indices)
        feat = torch.gather(feat, 2, repeat(indices, 'b n s -> b n s d', d=feat.size(-1)))
        fg_weights, bg_lambda = self.C.calc_volume_weights(
            sigma, di if self.C.dists_normalized else di_trs, # use real dists for computing weights,
            ray_vector, last_dist=0 if not H.fg_inf_depth else 1e10)[:2]
    # Weighted sum over the step axis gives one feature vector per ray.
    fg_feat = torch.sum(fg_weights.unsqueeze(-1) * feat, dim=-2)
    output.feat += [fg_feat]
    output.full_out += [feat]
    output.fg_weights = fg_weights
    output.bg_lambda = bg_lambda
    output.fg_depths = (di, di_trs)
    return output
def forward_rendering_background(self, H, output, bg_nerf, nerf_input_cams, latent_codes, styles_bg):
    """Volume-render the background field and record the result in ``output``.

    Depth samples are laid out linearly, optionally jittered, turned into
    evaluation points by the camera module's background point generator, and
    the queried features are composited along each ray with the
    volume-rendering weights.

    Args:
        H: rendering hyper-parameters; reads ``n_points``, ``batch_size``,
            ``n_bg_steps``, ``fixed_input_cams``, ``device``, ``training`` and
            the optional ``disable_noise`` / ``force_noise`` entries.
        output: accumulator; ``feat`` and ``full_out`` are extended and
            ``bg_weights`` / ``bg_depths`` are set on it.
        bg_nerf: background field, called with the points, ray directions and codes.
        nerf_input_cams: ``(pixels_world, camera_world, _)`` triple.
        latent_codes: sequence whose entries from index 2 onwards are the
            background shape and appearance codes.
        styles_bg: style vectors forwarded to ``bg_nerf`` as ``ws``.

    Returns:
        The same ``output`` object that was passed in.
    """
    pixels_world, camera_world, _ = nerf_input_cams
    z_shape_bg, z_app_bg = latent_codes[2:]
    grid_h, grid_w = dividable(H.n_points)
    bg_shape = [H.batch_size, grid_h, grid_w, H.n_bg_steps]
    frozen_cams = H.fixed_input_cams
    if frozen_cams is not None:
        # a frozen background camera replaces the per-sample rays
        pixels_world, camera_world, _ = frozen_cams

    # render background, use NeRF++ inverse sphere parameterization
    depths = torch.linspace(-1., 0., steps=H.n_bg_steps).to(H.device)
    depths = repeat(depths, 's -> b n s', b=H.batch_size, n=H.n_points) * self.C.bg_start
    jitter = (H.training and (not H.get('disable_noise', False))) or H.get('force_noise', False)
    if jitter:
        depths = self.C.add_noise_to_interval(depths)

    points, ray_dirs = self.C.get_evaluation_points_bg(pixels_world, camera_world, -depths)
    feat, sigma_raw = bg_nerf(points, ray_dirs, z_shape_bg, z_app_bg, ws=styles_bg, shape=bg_shape)
    feat = rearrange(feat, 'b (n s) d -> b n s d', s=H.n_bg_steps)
    sigma_raw = rearrange(sigma_raw.squeeze(-1), 'b (n s) -> b n s', s=H.n_bg_steps)
    sigma = self.get_density(sigma_raw, bg_nerf, training=H.training)
    bg_weights = self.C.calc_volume_weights(sigma, depths, None)[0]
    bg_feat = (bg_weights.unsqueeze(-1) * feat).sum(dim=-2)

    # scale the background by the foreground's residual term, when one was recorded
    bg_lambda = output.get('bg_lambda', None)
    if bg_lambda is not None:
        bg_feat = bg_lambda.unsqueeze(-1) * bg_feat

    output.feat.append(bg_feat)
    output.full_out.append(feat)
    output.bg_weights = bg_weights
    output.bg_depths = depths
    return output
def forward_volume_rendering(
    self,
    nerf_modules, # (fg_nerf, bg_nerf)
    camera_matrices, # camera (K, RT)
    vol_pixels,
    nerf_input_feats = None,
    latent_codes = None,
    styles = None,
    styles_bg = None,
    not_render_background = False,
    only_render_background = False,
    render_option = None,
    return_full = False,
    alpha = 0,
    **unused):
    """Render foreground/background fields into a 2D feature map (and image).

    Args:
        nerf_modules: ``(fg_nerf, bg_nerf)`` pair.
        camera_matrices: camera tuple; element 0 supplies device and batch size.
        vol_pixels: pixel coordinates, or a ``(vol_pixels, rand_pixels)``
            tuple whose parts are concatenated along dim 1 for rendering.
        nerf_input_feats: optional synthesized input features; only allowed
            when the renderer has no background field.
        latent_codes, styles: at least one of them must be given.
        styles_bg: styles forwarded to the background renderer.
        not_render_background / only_render_background: mutually exclusive switches.
        render_option: comma-separated option string (``None`` is treated as ``""``).
        return_full: with ``rand_pixels``, return un-aggregated per-sample
            outputs and weights in ``rand_out``.
        alpha: stored in the hyper-parameters; ``-1`` caps the ray samples at 20.
        **unused: every extra keyword becomes an entry of the hyper-parameter dict ``H``.

    Returns:
        The ``output`` EasyDict with ``full_out`` (and ``rand_out`` when
        ``rand_pixels`` were given), or the result of ``gen_optional_output``
        when ``render_option`` contains ``'early'`` but not ``'value'``.
    """
    assert (latent_codes is not None) or (styles is not None)
    assert self.no_background or (nerf_input_feats is None), "input features do not support background field"

    # hyper-parameters for rendering
    H = EasyDict(**unused)
    output = EasyDict()
    output.reg_loss = EasyDict()
    output.feat = []
    output.full_out = []

    if render_option is None:
        render_option = ""
    H.render_option = render_option
    H.alpha = alpha

    # prepare for rendering (parameters)
    fg_nerf, bg_nerf = nerf_modules

    H.training = fg_nerf.training
    H.device = camera_matrices[0].device
    H.batch_size = camera_matrices[0].shape[0]
    H.img_channels = fg_nerf.img_channels
    H.n_steps = self.n_ray_samples
    H.n_bg_steps = self.n_bg_samples
    if alpha == -1:
        H.n_steps = 20 # just for memory safe.
    if "steps" in render_option:
        # the first "steps:<int>" entry overrides the number of ray samples
        H.n_steps = [int(r.split(':')[1]) for r in H.render_option.split(',') if r[:5] == 'steps'][0]

    # prepare for pixels for generating images
    if isinstance(vol_pixels, tuple):
        vol_pixels, rand_pixels = vol_pixels
        pixels = torch.cat([vol_pixels, rand_pixels], 1)
        H.rnd_res = int(math.sqrt(rand_pixels.size(1)))
    else:
        pixels, rand_pixels, H.rnd_res = vol_pixels, None, None
    # tgt_res is taken from the volume pixels only; n_points counts all pixels
    H.tgt_res, H.n_points = int(math.sqrt(vol_pixels.size(1))), pixels.size(1)
    nerf_input_cams = self.C.get_origin_direction(pixels, camera_matrices)

    # set up a frozen camera for background if necessary
    if ('freeze_bg' in H.render_option) and (bg_nerf is not None):
        pitch, yaw = 0.2 + np.pi/2, 0
        range_u, range_v = self.C.range_u, self.C.range_v
        # express yaw/pitch as fractions of the camera's u/v ranges
        u = (yaw - range_u[0]) / (range_u[1] - range_u[0])
        v = (pitch - range_v[0]) / (range_v[1] - range_v[0])
        fixed_camera = self.C.get_camera(
            batch_size=H.batch_size, mode=[u, v, 0.5], device=H.device)
        H.fixed_input_cams = self.C.get_origin_direction(pixels, fixed_camera)
    else:
        H.fixed_input_cams = None

    H.fg_inf_depth = (self.no_background or not_render_background) and (not self.white_background)
    assert(not (not_render_background and only_render_background))

    # volume rendering options: bg_weights, bg_lambda = None, None
    if (nerf_input_feats is not None) and \
        len(nerf_input_feats) == 4 and \
        nerf_input_feats[2] == 'tri_vector' and \
        self.I.sigma_dim > 0 and H.fg_inf_depth:
        # volume rendering with pre-computed density similar to tensor-decomposition
        output = self.forward_rendering_with_pre_density(
            H, output, fg_nerf, nerf_input_cams, nerf_input_feats, latent_codes, styles)

    else:
        # standard volume rendering
        if not only_render_background:
            output = self.forward_rendering(
                H, output, fg_nerf, nerf_input_cams, nerf_input_feats, latent_codes, styles)

        # background rendering (NeRF++)
        if (not not_render_background) and (not self.no_background):
            output = self.forward_rendering_background(
                H, output, bg_nerf, nerf_input_cams, latent_codes, styles_bg)

    # visualization shortcut: skip the normal output assembly
    if ('early' in render_option) and ('value' not in render_option):
        return self.gen_optional_output(
            H, fg_nerf, nerf_input_cams, nerf_input_feats, latent_codes, styles, output)

    # ------------------------------------------- PREPARE FULL OUTPUT (NO 2D aggregation) -------------------------------------------- #
    vol_len = vol_pixels.size(1)
    # add up the feature contributions appended by the renderers above
    feat_map = sum(output.feat)
    full_x = rearrange(feat_map[:, :vol_len], 'b (h w) d -> b d h w', h=H.tgt_res)
    split_rgb = fg_nerf.add_rgb or fg_nerf.predict_rgb

    full_out = self.split_feat(full_x, H.img_channels, None, split_rgb=split_rgb)
    if rand_pixels is not None: # used in full supervision (debug later)
        if return_full:
            assert (fg_nerf.predict_rgb or fg_nerf.add_rgb)
            # keep only the entries belonging to the random pixels (after the first vol_len)
            rand_outputs = [f[:,vol_pixels.size(1):] for f in output.full_out]
            full_weights = torch.cat([output.fg_weights, output.bg_weights * output.bg_lambda.unsqueeze(-1)], -1) \
                if output.get('bg_weights', None) is not None else output.fg_weights
            full_weights = full_weights[:,vol_pixels.size(1):]
            full_weights = rearrange(full_weights, 'b (h w) s -> b s h w', h=H.rnd_res, w=H.rnd_res)

            # tile the per-sample outputs into an (lh x lw) grid of rnd_res x rnd_res patches
            lh, lw = dividable(full_weights.size(1))
            full_x = rearrange(torch.cat(rand_outputs, 2), 'b (h w) (l m) d -> b d (l h) (m w)',
                h=H.rnd_res, w=H.rnd_res, l=lh, m=lw)
            full_x, full_img = self.split_feat(full_x, H.img_channels, split_rgb=split_rgb)
            output.rand_out = (full_x, full_img, full_weights)

        else:
            rand_x = rearrange(feat_map[:, vol_len:], 'b (h w) d -> b d h w', h=H.rnd_res)
            output.rand_out = self.split_feat(rand_x, H.img_channels, split_rgb=split_rgb)
    # from here on full_out holds the split result, not the per-renderer list
    output.full_out = full_out
    return output
def post_process_outputs(self, outputs, freeze_nerf=False):
    """Unpack renderer outputs into a fixed ``(x, img, probs)`` triple.

    Args:
        outputs: sequence whose first two entries are returned as ``x`` and
            ``img``; a third entry is returned as ``probs``.
        freeze_nerf: when true, every ``torch.Tensor`` entry is detached
            first; other entries pass through unchanged.

    Returns:
        ``(x, img, probs)``; ``probs`` is ``None`` unless ``outputs`` has
        exactly three entries.
    """
    if freeze_nerf:
        detached = []
        for item in outputs:
            detached.append(item.detach() if isinstance(item, torch.Tensor) else item)
        outputs = detached

    x, img = outputs[0], outputs[1]
    if len(outputs) == 3:
        return x, img, outputs[2]
    return x, img, None
def gen_optional_output(self, H, fg_nerf, nerf_input_cams, nerf_input_feats, latent_codes, styles, output):
    """Build a debug/visualization image from the rendered foreground.

    The default image is the expected 3D hit point of each ray. Options in
    the comma-separated ``H.render_option`` can replace it with a slice of
    the input features (``input_feats:<a>:<b>``), with finite-difference
    normals of the current image (``normal``), or with the negated,
    weight-averaged density gradient (``gradient``).

    Returns:
        ``{'full_out': (None, img)}``.
    """
    _, camera_world, ray_vector = nerf_input_cams
    z_shape_obj, z_app_obj = latent_codes[:2]
    # expected depth per ray: weights times the second entry of fg_depths
    fg_depth_map = torch.sum(output.fg_weights * output.fg_depths[1], dim=-1, keepdim=True)
    img = camera_world[:, :1] + fg_depth_map * ray_vector
    img = img.permute(0,2,1).reshape(-1, 3, H.tgt_res, H.tgt_res)

    if 'input_feats' in H.render_option:
        # option format "input_feats:<a>:<b>"; a is the first of three channels, b a slice/level index
        a, b = [r.split(':')[1:] for r in H.render_option.split(',') if r.startswith('input_feats')][0]
        a, b = int(a), int(b)
        if nerf_input_feats[0] == 'volume':
            img = nerf_input_feats[1][:,a:a+3,b,:,:]
        elif nerf_input_feats[0] == 'tri_plane':
            img = nerf_input_feats[1][:,b,a:a+3,:,:]
        elif nerf_input_feats[0] == 'hash_table':
            assert self.I.hash_mode == 'grid_hash'
            # entries of hash level b, reshaped as a cube of side siz
            img = nerf_input_feats[1][:,self.I.offsets[b]:self.I.offsets[b+1], :]
            siz = int(np.ceil(img.size(1)**(1/3)))
            img = rearrange(img, 'b (d h w) c -> b (d c) h w', h=siz, w=siz, d=siz)
            img = img[:, a:a+3]
        else:
            raise NotImplementedError

    if 'normal' in H.render_option.split(','):
        # central differences along both image axes, then their cross product
        # NOTE(review): normalize() is defined elsewhere; its first return value is presumably the normalized tensor — confirm
        shift_l, shift_r = img[:,:,2:,:], img[:,:,:-2,:]
        shift_u, shift_d = img[:,:,:,2:], img[:,:,:,:-2]
        diff_hor = normalize(shift_r - shift_l, axis=1)[0][:, :, :, 1:-1]
        diff_ver = normalize(shift_u - shift_d, axis=1)[0][:, :, 1:-1, :]
        normal = torch.cross(diff_hor, diff_ver, dim=1)
        img = normalize(normal, axis=1)[0]

    if 'gradient' in H.render_option.split(','):
        points, _ = self.C.get_evaluation_points(camera_world + ray_vector, camera_world, output.fg_depths[1])
        fg_shape = [H.batch_size, H.tgt_res, H.tgt_res, output.fg_depths[1].size(-1)]
        # gradients are needed even if the caller runs under no_grad
        with torch.enable_grad():
            points.requires_grad_(True)
            inputs = self.I.query_input_features(points, nerf_input_feats, fg_shape, self.get_bound(), True) \
                if nerf_input_feats is not None else points
            if (nerf_input_feats is not None) and len(nerf_input_feats) == 4 and nerf_input_feats[2] == 'tri_vector' and (self.I.sigma_dim > 0):
                # NOTE(review): the literal 8 presumably mirrors self.I.sigma_dim — confirm
                sigma_out = inputs[..., :8].sum(dim=-1, keepdim=True)
            else:
                _, sigma_out = fg_nerf(inputs, None, ws=styles, shape=fg_shape, z_shape=z_shape_obj, z_app=z_app_obj, requires_grad=True)
            gradients, = grad(
                outputs=sigma_out, inputs=points,
                grad_outputs=torch.ones_like(sigma_out, requires_grad=False),
                retain_graph=True, create_graph=True, only_inputs=True)
        # average the per-sample gradients with the rendering weights, then normalize
        gradients = rearrange(gradients, 'b (n s) d -> b n s d', s=output.fg_depths[1].size(-1))
        avg_grads = (gradients * output.fg_weights.unsqueeze(-1)).sum(-2)
        avg_grads = F.normalize(avg_grads, p=2, dim=-1)
        normal = rearrange(avg_grads, 'b (h w) s -> b s h w', h=H.tgt_res, w=H.tgt_res)
        img = -normal

    return {'full_out': (None, img)}
@persistence.persistent_class
class Upsampler(object):
    """Configuration holder and driver for the stack of 2D upsampling blocks.

    Class attributes are the defaults; ``__init__`` overrides any of them
    that appear in the supplied keyword dictionaries. The blocks themselves
    are created by ``build_network`` and are owned by the caller, which passes
    them back into ``forward_ws_split`` / ``forward_network``.
    """

    # when true, no 2D blocks are used (block_resolutions becomes empty)
    no_2d_renderer = False
    # when true, blocks receive None instead of the running image
    no_residual_img = False
    # optional explicit list of block resolutions; None derives powers of two above in_res
    block_reses = None
    # when true, each block's toRGB styles are read from the tail of ws
    shared_rgb_style = False
    # not read by the methods of this class
    upsample_type = 'default'
    img_channels = 3
    in_res = 32
    out_res = 512
    # multiplied by 32768 to get the channel budget, unless channel_base_sz is set
    channel_base = 1
    channel_base_sz = None
    channel_max = 512
    # optional explicit {resolution: channels} tables
    channel_dict = None
    out_channel_dict = None

    def __init__(self, upsampler_kwargs, **other_kwargs):
        """Apply configuration overrides and derive the block resolutions.

        Args:
            upsampler_kwargs: mapping of overrides; only keys naming an
                existing attribute are applied. It is not modified.
            **other_kwargs: fallback overrides, used only for keys absent
                from ``upsampler_kwargs`` (compatibility with old checkpoints).
        """
        # Merge into a private copy: writing the fallbacks back into the
        # caller's dict would leak them into any later use of that dict.
        settings = dict(upsampler_kwargs)
        for key, value in other_kwargs.items():
            if hasattr(self, key) and (key not in settings):
                settings[key] = value
        for key, value in settings.items():
            if hasattr(self, key):
                setattr(self, key, value)

        self.out_res_log2 = int(np.log2(self.out_res))

        # set up upsamplers
        if self.block_reses is None:
            # powers of two from 4 up to out_res, keeping those above in_res
            candidates = [2 ** i for i in range(2, self.out_res_log2 + 1)]
            self.block_resolutions = [b for b in candidates if b > self.in_res]
        else:
            self.block_resolutions = self.block_reses

        if self.no_2d_renderer:
            self.block_resolutions = []

    def build_network(self, w_dim, input_dim, **block_kwargs):
        """Construct one synthesis block per entry of ``block_resolutions``.

        Args:
            w_dim: style dimensionality passed to every block.
            input_dim: channel count of the features entering the first block.
            **block_kwargs: forwarded to the block constructor; the optional
                ``block_name`` entry selects the block class.

        Returns:
            List of dicts with keys ``'block'``, ``'num_ws'`` and ``'name'``.
            ``self.num_ws`` is set to the sum of the ``'num_ws'`` entries.
        """
        upsamplers = []
        if len(self.block_resolutions) > 0:  # nerf resolution smaller than image
            channel_base = int(self.channel_base * 32768) if self.channel_base_sz is None else self.channel_base_sz
            fp16_resolution = self.block_resolutions[0] * 2  # do not use fp16 for the first block

            if self.channel_dict is None:
                channels_dict = {res: min(channel_base // res, self.channel_max) for res in self.block_resolutions}
            else:
                channels_dict = self.channel_dict

            if self.out_channel_dict is not None:
                img_channels = self.out_channel_dict
            else:
                img_channels = {res: self.img_channels for res in self.block_resolutions}

            for ir, res in enumerate(self.block_resolutions):
                res_before = self.block_resolutions[ir-1] if ir > 0 else self.in_res
                in_channels = channels_dict[res_before] if ir > 0 else input_dim
                out_channels = channels_dict[res]
                use_fp16 = (res >= fp16_resolution)  # TRY False
                is_last = (ir == (len(self.block_resolutions) - 1))
                # a repeated resolution means the block must not upsample
                no_upsample = (res == res_before)
                block = util.construct_class_by_name(
                    class_name=block_kwargs.get('block_name', "training.networks.SynthesisBlock"),
                    in_channels=in_channels,
                    out_channels=out_channels,
                    w_dim=w_dim,
                    resolution=res,
                    img_channels=img_channels[res],
                    is_last=is_last,
                    use_fp16=use_fp16,
                    disable_upsample=no_upsample,
                    block_id=ir,
                    **block_kwargs)

                upsamplers += [{
                    'block': block,
                    # only the last block counts its toRGB styles
                    'num_ws': block.num_conv if not is_last else block.num_conv + block.num_torgb,
                    'name': f'b{res}' if res_before != res else f'b{res}_l{ir}'
                }]
        self.num_ws = sum([u['num_ws'] for u in upsamplers])
        return upsamplers

    def forward_ws_split(self, ws, blocks):
        """Slice ``ws`` along dim 1 into one style tensor per block.

        Each slice holds ``num_conv + num_torgb`` styles while the read
        offset advances by ``num_conv`` only, so consecutive slices overlap.
        With ``shared_rgb_style`` the toRGB styles come from the end of ``ws``.
        """
        block_ws, w_idx = [], 0
        for ir, _res in enumerate(self.block_resolutions):
            block = blocks[ir]
            if self.shared_rgb_style:
                w = ws.narrow(1, w_idx, block.num_conv)
                w_img = ws.narrow(1, -block.num_torgb, block.num_torgb)  # TODO: tRGB to use the same style (?)
                block_ws.append(torch.cat([w, w_img], 1))
            else:
                block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
            w_idx += block.num_conv
        return block_ws

    def forward_network(self, blocks, block_ws, x, img, target_res, alpha, skip_up=False, **block_kwargs):
        """Run the blocks in order up to ``target_res``.

        Args:
            blocks: block modules aligned with ``block_resolutions``.
            block_ws: per-block styles from ``forward_ws_split``.
            x, img: incoming features and image.
            target_res: blocks with a larger resolution are not run.
            alpha: accepted for interface compatibility; not used here.
            skip_up: forwarded to every block.
            **block_kwargs: forwarded to every block; an optional
                ``voxel_noise`` entry is indexed per block as ``block_noise``.

        Returns:
            List with the image produced by each executed block.
        """
        imgs = []
        for index_l, (res, cur_ws) in enumerate(zip(self.block_resolutions, block_ws)):
            if res > target_res:
                break

            block = blocks[index_l]
            block_noise = block_kwargs['voxel_noise'][index_l] if "voxel_noise" in block_kwargs else None
            x, img = block(
                x,
                img if not self.no_residual_img else None,
                cur_ws,
                block_noise=block_noise,
                skip_up=skip_up,
                **block_kwargs)

            imgs += [img]
        return imgs
@persistence.persistent_class
class NeRFInput(Upsampler):
    """ Instead of positional encoding, it learns additional features for each points.
        However, it is important to normalize the input points

    ``output_mode`` selects how the features are synthesized (tri-planes,
    volumes, hash tables, ...); ``query_input_features`` looks up the
    features for a batch of 3D points.
    """
    # selects the branch used by build_network / forward_network
    output_mode = 'none'
    input_mode = 'random'  # coordinates
    # copied into the block kwargs as 'architecture'
    architecture = 'skip'

    # only useful for triplane/volume inputs
    in_res = 4
    out_res = 256
    out_dim = 32
    sigma_dim = 8
    split_size = 64

    # only useful for hashtable inputs
    hash_n_min = 16
    hash_n_max = 512
    hash_size = 16
    hash_level = 16
    hash_dim_in = 32
    hash_dim_mid = None
    hash_dim_out = 2
    hash_n_layer = 4
    hash_mode = 'fast_hash'  # grid_hash (like volumes)

    # > 0: append a positional encoding of the points; 0: append the raw points; -1: neither
    keep_posenc = -1
    keep_nerf_latents = False

    def build_network(self, w_dim, **block_kwargs):
        """Create the modules that synthesize the input features.

        Sets ``self._flag`` (0: convolutional blocks, 1: modulated hash
        embedding, 2: NeRF-generated multi-resolution volume), which
        ``forward_ws_split`` later uses to decide how to split the styles.

        Returns:
            List of dicts with keys ``'block'``, ``'num_ws'`` and ``'name'``.

        Raises:
            NotImplementedError: for an unknown ``output_mode``.
        """
        # change global settings for input field.
        kwargs_copy = copy.deepcopy(block_kwargs)
        kwargs_copy['kernel_size'] = 3
        kwargs_copy['upsample_mode'] = 'default'
        kwargs_copy['use_noise'] = True
        kwargs_copy['architecture'] = self.architecture
        self._flag = 0

        assert self.input_mode == 'random', \
            "currently only support normal StyleGAN2. in the future we may work on other inputs."

        # plane-based inputs with modulated 2D convolutions
        if self.output_mode == 'tri_plane_reshape':
            self.img_channels, in_channels, const = 3 * self.out_dim, 0, None
        elif self.output_mode == 'tri_plane_product':  #TODO: sigma_dim is for density
            self.img_channels, in_channels = 3 * (self.out_dim + self.sigma_dim), 0
            # learned per-axis vectors, stored as the trailing 'const' block
            const = torch.nn.Parameter(0.1 * torch.randn([self.img_channels, self.out_res]))
        elif self.output_mode == 'multi_planes':
            self.img_channels, in_channels, const = self.out_dim * self.split_size, 0, None
            kwargs_copy['architecture'] = 'orig'

        # volume-based inputs with modulated 3D convolutions
        elif self.output_mode == '3d_volume':  # use 3D convolution to generate
            kwargs_copy['architecture'] = 'orig'
            kwargs_copy['mode'] = '3d'
            self.img_channels, in_channels, const = self.out_dim, 0, None
        elif self.output_mode == 'ms_volume':  # multi-resolution volume, between hashtable and volumes
            kwargs_copy['architecture'] = 'orig'
            kwargs_copy['mode'] = '3d'
            self.img_channels, in_channels, const = self.out_dim, 0, None

        # embedding-based inputs with modulated MLPs
        elif self.output_mode == 'hash_table':
            if self.hash_mode == 'grid_hash':
                assert self.hash_size % 3 == 0, "needs to be 3D"
            kwargs_copy['hash_size'], self._flag = 2 ** self.hash_size, 1
            assert self.hash_dim_out * self.hash_level == self.out_dim, "size must matched"
            return self.build_modulated_embedding(w_dim, **kwargs_copy)

        elif self.output_mode == 'ms_nerf_hash':
            self.hash_mode, self._flag = 'grid_hash', 2
            ms_nerf = NeRFBlock({
                'rgb_out_dim': self.hash_dim_out * self.hash_level,  # HACK
                'magnitude_ema_beta': block_kwargs['magnitude_ema_beta'],
                'no_sigma': True, 'predict_rgb': False, 'add_rgb': False,
                'n_freq_posenc': 5,
            })
            self.num_ws = ms_nerf.num_ws
            return [{'block': ms_nerf, 'num_ws': ms_nerf.num_ws, 'name': 'ms_nerf'}]

        else:
            raise NotImplementedError

        networks = super().build_network(w_dim, in_channels, **kwargs_copy)
        if const is not None:
            networks.append({'block': const, 'num_ws': 0, 'name': 'const'})
        return networks

    def forward_ws_split(self, ws, blocks):
        """Split ``ws`` according to the network kind chosen in ``build_network``."""
        if self._flag == 1:
            # one style per block; the trailing block is the hash table itself
            return ws.split(1, dim=1)[:len(blocks)-1]
        elif self._flag == 0:
            return super().forward_ws_split(ws, blocks)
        else:
            return ws  # do not split

    def forward_network(self, blocks, block_ws, batch_size, **block_kwargs):
        """Synthesize the input features.

        Args:
            blocks: modules returned by ``build_network`` (in the same order).
            block_ws: styles from ``forward_ws_split``.
            batch_size: batch size used to expand shared parameters.
            **block_kwargs: forwarded to the convolutional blocks.

        Returns:
            Tuple tagged by kind: ``('tri_plane', planes)``,
            ``('tri_plane', planes, 'tri_vector', vectors)``,
            ``('volume', vol)``, ``('ms_volume', vol)`` or ``('hash_table', table)``.

        Raises:
            NotImplementedError: for an unsupported ``output_mode``.
        """
        x, img, out = None, None, None

        def _forward_conv_networks(x, img, blocks, block_ws):
            # run the convolutional blocks in order and keep the final image
            for index_l, (res, cur_ws) in enumerate(zip(self.block_resolutions, block_ws)):
                x, img = blocks[index_l](x, img, cur_ws, **block_kwargs)
            return img

        def _forward_ffn_networks(x, blocks, block_ws):
            #TODO: FFN is implemented as 1x1 conv for now #
            h, w = dividable(x.size(0))
            x = repeat(x, 'n d -> b n d', b=batch_size)
            x = rearrange(x, 'b (h w) d -> b d h w', h=h, w=w)
            for index_l, cur_ws in enumerate(block_ws):
                block, cur_ws = blocks[index_l], cur_ws[:, 0]
                x = block(x, cur_ws)
            return x

        # tri-plane outputs
        if 'tri_plane' in self.output_mode:
            img = _forward_conv_networks(x, img, blocks, block_ws)
            if self.output_mode == 'tri_plane_reshape':
                out = ('tri_plane', rearrange(img, 'b (s c) h w -> b s c h w', s=3))
            elif self.output_mode == 'tri_plane_product':
                # blocks[-1] is the 'const' parameter appended by build_network
                out = ('tri_plane', rearrange(img, 'b (s c) h w -> b s c h w', s=3),
                       'tri_vector', repeat(rearrange(blocks[-1], '(s c) d -> s c d', s=3), 's c d -> b s c d', b=img.size(0)))
            else:
                raise NotImplementedError("remove support for other types of tri-plane implementation.")

        # volume/3d voxel outputs
        elif self.output_mode == 'multi_planes':
            img = _forward_conv_networks(x, img, blocks, block_ws)
            out = ('volume', rearrange(img, 'b (s c) h w -> b s c h w', s=self.out_dim))
        elif self.output_mode == '3d_volume':
            img = _forward_conv_networks(x, img, blocks, block_ws)
            out = ('volume', img)

        # multi-resolution 3d volume outputs (similar to hash-table)
        elif self.output_mode == 'ms_volume':
            img = _forward_conv_networks(x, img, blocks, block_ws)
            out = ('ms_volume', rearrange(img, 'b (l m) d h w -> b l m d h w', l=self.hash_level))

        # hash-table outputs (need hash sample implemented #TODO#
        elif self.output_mode == 'hash_table':
            # the last block is the hash table parameter, the rest is the modulated MLP
            x, blocks = blocks[-1], blocks[:-1]
            if len(blocks) > 0:
                x = _forward_ffn_networks(x, blocks, block_ws)
                out = ('hash_table', rearrange(x, 'b d h w -> b (h w) d'))
            else:
                out = ('hash_table', repeat(x, 'n d -> b n d', b=batch_size))

        elif self.output_mode == 'ms_nerf_hash':
            # prepare inputs for nerf: a dense out_res^3 grid of coordinates in [-1, 1].
            # The grid side is used consistently below; the reshape previously
            # hard-coded 32 and therefore only worked for out_res == 32.
            res = self.out_res
            x = torch.linspace(-1, 1, steps=res, device=block_ws.device)
            x = torch.stack(torch.meshgrid(x,x,x), -1).reshape(-1, 3)
            x = repeat(x, 'n s -> b n s', b=block_ws.size(0))
            x = blocks[0](x, None, ws=block_ws, shape=[block_ws.size(0), res, res, res])[0]
            x = rearrange(x, 'b (d h w) (l m) -> b l m d h w', l=self.hash_level, d=res, h=res, w=res)
            out = ('ms_volume', x)

        else:
            raise NotImplementedError

        return out

    def query_input_features(self, p_i, input_feats, p_shape, bound, grad_inputs=False):
        """Look up the synthesized features at the given 3D points.

        Args:
            p_i: points; divided by ``bound`` before the lookup.
            input_feats: tagged tuple produced by ``forward_network``.
            p_shape: ``(batch_size, height, width, n_steps)``.
            bound: scene bound used to normalize the points.
            grad_inputs: forwarded to ``hash_sample`` (hash-table inputs only).

        Returns:
            Per-point features, with the positional encoding or the raw
            normalized points appended when ``keep_posenc > -1``.

        Raises:
            NotImplementedError: for an unknown feature kind.
        """
        batch_size, height, width, n_steps = p_shape
        p_i = p_i / bound

        if input_feats[0] == 'tri_plane':
            # TODO!! Our world space, x->depth, y->width, z->height
            # lay the samples of each ray out as an (lh x lw) tiling of the image grid
            lh, lw = dividable(n_steps)
            p_ds = rearrange(p_i, 'b (h w l m) d -> b (l h) (m w) d',
                b=batch_size, h=height, w=width, l=lh, m=lw).split(1, dim=-1)
            px, py, pz = p_ds[0], p_ds[1], p_ds[2]

            # project points onto three planes
            p_xy = torch.cat([px, py], -1)
            p_xz = torch.cat([px, pz], -1)
            p_yz = torch.cat([py, pz], -1)
            # stack the three planes along the batch dim so one grid_sample call serves all
            p_gs = torch.cat([p_xy, p_xz, p_yz], 0)
            f_in = torch.cat([input_feats[1][:, i] for i in range(3)], 0)
            p_f = grid_sample(f_in, p_gs)  # gradient-fix bilinear interpolation
            p_f = [p_f[i * batch_size: (i+1) * batch_size] for i in range(3)]

            # project points to three vectors (optional)
            if len(input_feats) == 4 and input_feats[2] == 'tri_vector':
                # TODO: PyTorch did not support grid_sample for 1D data. Maybe need custom code.
                p_gs_vec = torch.cat([pz, py, px], 0)
                f_in_vec = torch.cat([input_feats[3][:, i] for i in range(3)], 0)
                # sample the 1D vectors as 2D grids of width 1 (first grid coordinate fixed to 0)
                p_f_vec = grid_sample(f_in_vec.unsqueeze(-1), torch.cat([torch.zeros_like(p_gs_vec), p_gs_vec], -1))
                p_f_vec = [p_f_vec[i * batch_size: (i+1) * batch_size] for i in range(3)]

                # multiply on the triplane features
                p_f = [m * v for m, v in zip(p_f, p_f_vec)]

            p_f = sum(p_f)
            p_f = rearrange(p_f, 'b d (l h) (m w) -> b (h w l m) d', l=lh, m=lw)

        elif input_feats[0] == 'volume':
            # TODO!! Our world space, x->depth, y->width, z->height
            # (width-c, height-c, depth-c), volume (B x N x D x H x W)
            p_ds = rearrange(p_i, 'b (h w s) d -> b s h w d',
                b=batch_size, h=height, w=width, s=n_steps).split(1, dim=-1)
            px, py, pz = p_ds[0], p_ds[1], p_ds[2]
            p_yzx = torch.cat([py, -pz, px], -1)
            p_f = F.grid_sample(input_feats[1], p_yzx, mode='bilinear', align_corners=False)
            p_f = rearrange(p_f, 'b c s h w -> b (h w s) c')

        elif input_feats[0] == 'ms_volume':
            # TODO!! Multi-resolution volumes (experimental)
            # for smoothness, maybe we should expand the volume? (TODO)
            # print(p_i.shape)
            # pad the volume by one cell per axis: first write a flipped copy
            # into the upper corner, then the volume itself into the lower corner
            ms_v = input_feats[1].new_zeros(
                batch_size, self.hash_level, self.hash_dim_out, self.out_res+1, self.out_res+1, self.out_res+1)
            ms_v[..., 1:, 1:, 1:] = input_feats[1].flip([3,4,5])
            ms_v[..., :self.out_res, :self.out_res, :self.out_res] = input_feats[1]
            v_size = ms_v.size(-1)

            # multi-resolutions: geometric progression from hash_n_min to hash_n_max
            b = math.exp((math.log(self.hash_n_max) - math.log(self.hash_n_min))/(self.hash_level-1))
            hash_res_ls = [round(self.hash_n_min * b ** l) for l in range(self.hash_level)]

            # prepare interpolate grids
            p_ds = rearrange(p_i, 'b (h w s) d -> b s h w d',
                b=batch_size, h=height, w=width, s=n_steps).split(1, dim=-1)
            px, py, pz = p_ds[0], p_ds[1], p_ds[2]
            p_yzx = torch.cat([py, -pz, px], -1)
            p_yzx = ((p_yzx + 1) / 2).clamp(min=0, max=1)  # normalize to 0~1 (just for safe)
            # levels finer than the stored volume wrap around via fmod
            p_yzx = torch.stack([p_yzx if n < v_size else torch.fmod(p_yzx * n, v_size) / v_size for n in hash_res_ls], 1)
            p_yzx = (p_yzx * 2 - 1).view(-1, n_steps, height, width, 3)  # back to -1~1
            ms_v = ms_v.view(-1, self.hash_dim_out, v_size, v_size, v_size)
            p_f = F.grid_sample(ms_v, p_yzx, mode='bilinear', align_corners=False)
            p_f = rearrange(p_f, '(b l) c s h w -> b (h w s) (l c)', l=self.hash_level)

        elif input_feats[0] == 'hash_table':
            # TODO:!! Experimental code trying to learn hashtable used in (maybe buggy)
            # https://nvlabs.github.io/instant-ngp/assets/mueller2022instant.pdf
            p_xyz = ((p_i + 1) / 2).clamp(min=0, max=1)  # normalize to 0~1
            p_f = hash_sample(
                p_xyz, input_feats[1], self.offsets.to(p_xyz.device),
                self.beta, self.hash_n_min, grad_inputs, mode=self.hash_mode)

        else:
            raise NotImplementedError

        if self.keep_posenc > -1:
            if self.keep_posenc > 0:
                p_f = torch.cat([p_f, positional_encoding(p_i, self.keep_posenc, use_pos=True)], -1)
            else:
                p_f = torch.cat([p_f, p_i], -1)
        return p_f

    def build_hashtable_info(self, hash_size):
        """Compute per-level resolutions and table offsets for the hash table.

        Sets ``self.beta`` (growth factor between levels),
        ``self.hash_res_ls`` and ``self.offsets`` (int32 tensor with
        ``hash_level + 1`` entries).

        Returns:
            Total number of table entries over all levels.
        """
        self.beta = math.exp((math.log(self.hash_n_max) - math.log(self.hash_n_min)) / (self.hash_level-1))
        self.hash_res_ls = [round(self.hash_n_min * self.beta ** l) for l in range(self.hash_level)]
        offsets, offset = [], 0
        for i in range(self.hash_level):
            resolution = self.hash_res_ls[i]
            # a level never needs more entries than its dense grid has vertices
            params_in_level = min(hash_size, (resolution + 1) ** 3)
            offsets.append(offset)
            offset += params_in_level
        offsets.append(offset)
        self.offsets = torch.from_numpy(np.array(offsets, dtype=np.int32))
        return offset

    def build_modulated_embedding(self, w_dim, hash_size, **block_kwargs):
        """Create the hash table parameter and its optional modulated MLP.

        Args:
            w_dim: style dimensionality of the modulated layers.
            hash_size: maximum number of entries per level.
            **block_kwargs: accepted for interface compatibility; not used here.

        Returns:
            List of block dicts; the last entry (``'hash_const'``) is the
            table itself. ``self.num_ws`` is set to the number of styles used.
        """
        # allocate parameters
        offset = self.build_hashtable_info(hash_size)
        hash_const = torch.nn.Parameter(torch.zeros(
            [offset, self.hash_dim_in if self.hash_n_layer > -1 else self.hash_dim_out]))
        hash_const.data.uniform_(-1e-4, 1e-4)

        hash_networks = []
        if self.hash_n_layer > -1:
            input_dim = self.hash_dim_in
            for l in range(self.hash_n_layer):
                output_dim = self.hash_dim_mid if self.hash_dim_mid is not None else self.hash_dim_in
                hash_networks.append({
                    'block': Style2Layer(input_dim, output_dim, w_dim),
                    'num_ws': 1, 'name': f'hmlp{l}'
                })
                input_dim = output_dim
            hash_networks.append({
                'block': ToRGBLayer(input_dim, self.hash_dim_out, w_dim, kernel_size=1),
                'num_ws': 1, 'name': 'hmlpout'})
        hash_networks.append({'block': hash_const, 'num_ws': 0, 'name': 'hash_const'})
        self.num_ws = sum([h['num_ws'] for h in hash_networks])
        return hash_networks
@persistence.persistent_class
class NeRFSynthesisNetwork(torch.nn.Module):
def __init__(self,
    w_dim, # Intermediate latent (W) dimensionality.
    img_resolution, # Output image resolution.
    img_channels, # Number of color channels.
    channel_base = 1,
    channel_max = 1024,
    # module settings
    camera_kwargs = {},
    renderer_kwargs = {},
    upsampler_kwargs = {},
    input_kwargs = {},
    foreground_kwargs = {},
    background_kwargs = {},
    # nerf space settings
    z_dim = 256,
    z_dim_bg = 128,
    rgb_out_dim = 256,
    rgb_out_dim_bg = None,
    resolution_vol = 32,
    resolution_start = None,
    progressive = True,
    prog_nerf_only = False,
    interp_steps = None, # (optional) "start_step:final_step"
    # others (regularization)
    regularization = [], # nv_beta, nv_vol
    predict_camera = False,
    camera_condition = None,
    n_reg_samples = 0,
    reg_full = False,
    cam_based_sampler = False,
    rectangular = None,
    freeze_nerf = False,
    **block_kwargs, # Other arguments for SynthesisBlock.
):
    """Assemble camera, input encoding, renderer, NeRF fields and 2D upsampler.

    ``self.num_ws`` accumulates the number of style vectors consumed by the
    foreground/background fields, the optional input-encoding network, the
    upsampler blocks and (with ``predict_camera``) one extra camera style.

    The ``*_kwargs`` dictionaries that this constructor fills in are copied
    first, so neither the caller's objects nor the shared default ``{}``
    values are modified.
    """
    assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
    super().__init__()

    # The dictionaries below are updated in place further down. Work on
    # shallow copies: the defaults are shared between all calls, so mutating
    # them would leak settings from one network into the next.
    input_kwargs = copy.copy(input_kwargs)
    upsampler_kwargs = copy.copy(upsampler_kwargs)
    foreground_kwargs = copy.copy(foreground_kwargs)
    background_kwargs = copy.copy(background_kwargs)

    # dimensions
    self.w_dim = w_dim
    self.z_dim = z_dim
    self.z_dim_bg = z_dim_bg
    self.num_ws = 0
    self.rgb_out_dim = rgb_out_dim
    self.rgb_out_dim_bg = rgb_out_dim_bg if rgb_out_dim_bg is not None else rgb_out_dim

    self.img_resolution = img_resolution
    # the volume resolution never exceeds the image resolution
    self.resolution_vol = resolution_vol if resolution_vol < img_resolution else img_resolution
    self.resolution_start = resolution_start if resolution_start is not None else resolution_vol
    self.img_resolution_log2 = int(np.log2(img_resolution))
    self.img_channels = img_channels

    # number of samples
    self.n_reg_samples = n_reg_samples
    self.reg_full = reg_full
    self.use_noise = block_kwargs.get('use_noise', False)

    # ---------------------------------- Initialize Modules ---------------------------------------- -#
    # camera module
    self.C = CameraRay(camera_kwargs, **block_kwargs)

    # input encoding module
    if (len(input_kwargs) > 0) and (input_kwargs['output_mode'] != 'none'):  # using synthesized inputs
        input_kwargs['channel_base'] = input_kwargs.get('channel_base', channel_base)
        input_kwargs['channel_max'] = input_kwargs.get('channel_max', channel_max)
        self.I = NeRFInput(input_kwargs, **block_kwargs)
    else:
        self.I = None

    # volume renderer module
    self.V = VolumeRenderer(renderer_kwargs, camera_ray=self.C, input_encoding=self.I, **block_kwargs)

    # upsampler module
    upsampler_kwargs.update(dict(
        img_channels=img_channels,
        in_res=resolution_vol,
        out_res=img_resolution,
        channel_max=channel_max,
        channel_base=channel_base))
    self.U = Upsampler(upsampler_kwargs, **block_kwargs)

    # full model resolutions
    self.block_resolutions = copy.deepcopy(self.U.block_resolutions)
    if self.resolution_start < self.resolution_vol:
        # prepend the halved resolutions between resolution_vol and resolution_start
        r = self.resolution_vol
        while r > self.resolution_start:
            self.block_resolutions.insert(0, r)
            r = r // 2

    self.predict_camera = predict_camera
    if predict_camera:  # encoder side camera predictor (not very useful)
        self.camera_generator = CameraGenerator()
    self.camera_condition = camera_condition
    if self.camera_condition is not None:  # style vector modulated by the camera poses (uv)
        self.camera_map = MappingNetwork(z_dim=0, c_dim=16, w_dim=self.w_dim, num_ws=None, w_avg_beta=None, num_layers=2)

    # ray level choices
    self.regularization = regularization
    self.margin = block_kwargs.get('margin', 0)
    self.activation = block_kwargs.get('activation', 'lrelu')
    self.rectangular_crop = rectangular  # [384, 512] ??

    # nerf (foreground/background)
    foreground_kwargs.update(dict(
        z_dim=self.z_dim,
        w_dim=w_dim,
        rgb_out_dim=self.rgb_out_dim,
        activation=self.activation))

    # disable positional encoding if input encoding is given
    if self.I is not None:
        foreground_kwargs.update(dict(
            disable_latents=(not self.I.keep_nerf_latents),
            input_dim=self.I.out_dim + 3 * (2 * self.I.keep_posenc + 1)
                if self.I.keep_posenc > -1 else self.I.out_dim,
            positional_encoding='none'))
    self.fg_nerf = NeRFBlock(foreground_kwargs)
    self.num_ws += self.fg_nerf.num_ws

    if not self.V.no_background:
        background_kwargs.update(dict(
            z_dim=self.z_dim_bg, w_dim=w_dim,
            rgb_out_dim=self.rgb_out_dim_bg,
            activation=self.activation))
        self.bg_nerf = NeRFBlock(background_kwargs)
        self.num_ws += self.bg_nerf.num_ws
    else:
        self.bg_nerf = None

    # ---------------------------------- Build Networks ---------------------------------------- -#
    # input encoding (optional)
    if self.I is not None:
        assert self.V.no_background, "does not support background field"
        nerf_inputs = self.I.build_network(w_dim, **block_kwargs)
        self.input_block_names = ['in_' + i['name'] for i in nerf_inputs]
        self.num_ws += sum([i['num_ws'] for i in nerf_inputs])
        # register each block as an attribute named 'in_<name>'
        for i in nerf_inputs:
            setattr(self, 'in_' + i['name'], i['block'])

    # upsampler
    upsamplers = self.U.build_network(w_dim, self.fg_nerf.rgb_out_dim, **block_kwargs)
    if len(upsamplers) > 0:
        self.block_names = [u['name'] for u in upsamplers]
        self.num_ws += sum([u['num_ws'] for u in upsamplers])
        for u in upsamplers:
            setattr(self, u['name'], u['block'])

    # data-sampler
    if cam_based_sampler:
        self.sampler = (CameraQueriedSampler, {'camera_module': self.C})

    # other hyperparameters
    self.progressive_growing = progressive
    self.progressive_nerf_only = prog_nerf_only
    assert not (self.progressive_growing and self.progressive_nerf_only)
    if prog_nerf_only:
        assert (self.n_reg_samples == 0) and (not reg_full), "does not support regularization"

    self.register_buffer("alpha", torch.scalar_tensor(-1))
    if predict_camera:
        self.num_ws += 1  # additional w for camera
    self.freeze_nerf = freeze_nerf
    self.steps = None
    self.interp_steps = [int(a) for a in interp_steps.split(':')] \
        if interp_steps is not None else None  #TODO two-stage training trick (from EG3d paper, not working so far)
def set_alpha(self, alpha):
    """Overwrite the ``alpha`` buffer in place; ``None`` leaves it unchanged."""
    if alpha is None:
        return
    self.alpha.fill_(alpha)
def set_steps(self, steps):
    """Store ``steps / 1000.0`` in ``self.steps``.

    Nothing happens when the object has no ``steps`` attribute. An existing
    non-``None`` value is updated as ``old * 0 + new`` rather than replaced.
    """
    if not hasattr(self, "steps"):
        return
    if self.steps is None:
        self.steps = steps / 1000.0
        return
    self.steps = self.steps * 0 + steps / 1000.0
def forward(self, ws, **block_kwargs):
block_ws, imgs, rand_imgs = [], [], []
batch_size = block_kwargs['batch_size'] = ws.size(0)
n_levels, end_l, _, target_res = self.get_current_resolution()
# save ws for potential usage.
block_kwargs['ws_detach'] = ws.detach()
# cameras, background codes
if self.camera_condition is not None:
cam_cond = self.get_camera_samples(batch_size, ws, block_kwargs, gen_cond=True)
if "camera_matrices" not in block_kwargs:
block_kwargs['camera_matrices'] = self.get_camera_samples(batch_size, ws, block_kwargs)
if (self.camera_condition is not None) and (cam_cond is None):
cam_cond = block_kwargs['camera_matrices']
block_kwargs['theta'] = self.C.get_roll(ws, self.training, **block_kwargs)
# get latent codes instead of style vectors (used in GRAF & GIRAFFE)
if "latent_codes" not in block_kwargs:
block_kwargs["latent_codes"] = self.get_latent_codes(batch_size, device=ws.device)
if (self.camera_condition is not None) and (self.camera_condition == 'full'):
cam_cond = normalize_2nd_moment(self.camera_map(None, cam_cond[1].reshape(-1, 16)))
ws = ws * cam_cond[:, None, :]
# generate features for input points (Optional, default not use)
with torch.autograd.profiler.record_function('nerf_input_feats'):
if self.I is not None:
ws = ws.to(torch.float32)
blocks = [getattr(self, name) for name in self.input_block_names]
block_ws = self.I.forward_ws_split(ws, blocks)
nerf_input_feats = self.I.forward_network(blocks, block_ws, **block_kwargs)
ws = ws[:, self.I.num_ws:]
else:
nerf_input_feats = None
# prepare for NeRF part
with torch.autograd.profiler.record_function('prepare_nerf_path'):
if self.progressive_nerf_only and (self.alpha > -1):
cur_resolution = int(self.resolution_start * (1 - self.alpha) + self.resolution_vol * self.alpha)
elif (end_l == 0) or len(self.block_resolutions) == 0:
cur_resolution = self.resolution_start
else:
cur_resolution = self.block_resolutions[end_l-1]
vol_resolution = self.resolution_vol if self.resolution_vol < cur_resolution else cur_resolution
nerf_resolution = vol_resolution
if (self.interp_steps is not None) and (self.steps is not None) and (self.alpha > 0): # interpolation trick (maybe work??)
if self.steps < self.interp_steps[0]:
nerf_resolution = vol_resolution // 2
elif self.steps < self.interp_steps[1]:
nerf_resolution = (self.steps - self.interp_steps[0]) / (self.interp_steps[1] - self.interp_steps[0])
nerf_resolution = int(nerf_resolution * (vol_resolution / 2) + vol_resolution / 2)
vol_pixels, tgt_pixels = self.C.prepare_pixels(self.img_resolution, cur_resolution, nerf_resolution, **block_kwargs)
if (end_l > 0) and (self.n_reg_samples > 0) and self.training:
rand_pixels, rand_indexs = self.C.prepare_pixels_regularization(tgt_pixels, self.n_reg_samples)
else:
rand_pixels, rand_indexs = None, None
if self.fg_nerf.num_ws > 0: # use style vector instead of latent codes?
block_kwargs["styles"] = ws[:, :self.fg_nerf.num_ws]
ws = ws[:, self.fg_nerf.num_ws:]
if (self.bg_nerf is not None) and self.bg_nerf.num_ws > 0:
block_kwargs["styles_bg"] = ws[:, :self.bg_nerf.num_ws]
ws = ws[:, self.bg_nerf.num_ws:]
# volume rendering
with torch.autograd.profiler.record_function('nerf'):
if (rand_pixels is not None) and self.training:
vol_pixels = (vol_pixels, rand_pixels)
outputs = self.V.forward_volume_rendering(
nerf_modules=(self.fg_nerf, self.bg_nerf),
vol_pixels=vol_pixels,
nerf_input_feats=nerf_input_feats,
return_full=self.reg_full,
alpha=self.alpha,
**block_kwargs)
reg_loss = outputs.get('reg_loss', {})
x, img, _ = self.V.post_process_outputs(outputs['full_out'], self.freeze_nerf)
if nerf_resolution < vol_resolution:
x = F.interpolate(x, vol_resolution, mode='bilinear', align_corners=False)
img = F.interpolate(img, vol_resolution, mode='bilinear', align_corners=False)
# early output from the network (used for visualization)
if 'meshes' in block_kwargs:
from dnnlib.geometry import render_mesh
block_kwargs['voxel_noise'] = render_mesh(block_kwargs['meshes'], block_kwargs["camera_matrices"])
if (len(self.U.block_resolutions) == 0) or \
(x is None) or \
(block_kwargs.get("render_option", None) is not None and
'early' in block_kwargs['render_option']):
if 'value' in block_kwargs['render_option']:
img = x[:,:3]
img = img / img.norm(dim=1, keepdim=True)
assert img is not None, "need to add RGB"
return img
if 'rand_out' in outputs:
x_rand, img_rand, rand_probs = self.V.post_process_outputs(outputs['rand_out'], self.freeze_nerf)
lh, lw = dividable(rand_probs.size(1))
rand_imgs += [img_rand]
# append low-resolution image
if img is not None:
if self.progressive_nerf_only and (img.size(-1) < self.resolution_vol):
x = upsample(x, self.resolution_vol)
img = upsample(img, self.resolution_vol)
block_kwargs['img_nerf'] = img
# Use 2D upsampler
if (cur_resolution > self.resolution_vol) or self.progressive_nerf_only:
imgs += [img]
if (self.camera_condition is not None) and (self.camera_condition != 'full'):
cam_cond = normalize_2nd_moment(self.camera_map(None, cam_cond[1].reshape(-1, 16)))
ws = ws * cam_cond[:, None, :]
# 2D feature map upsampling
with torch.autograd.profiler.record_function('upsampling'):
ws = ws.to(torch.float32)
blocks = [getattr(self, name) for name in self.block_names]
block_ws = self.U.forward_ws_split(ws, blocks)
imgs += self.U.forward_network(blocks, block_ws, x, img, target_res, self.alpha, **block_kwargs)
img = imgs[-1]
if len(rand_imgs) > 0: # nerf path regularization
rand_imgs += self.U.forward_network(
blocks, block_ws, x_rand, img_rand, target_res, self.alpha, skip_up=True, **block_kwargs)
img_rand = rand_imgs[-1]
with torch.autograd.profiler.record_function('rgb_interp'):
if (self.alpha > -1) and (not self.progressive_nerf_only) and self.progressive_growing:
if (self.alpha < 1) and (self.alpha > 0):
alpha, _ = math.modf(self.alpha * n_levels)
img_nerf = imgs[-2]
if img_nerf.size(-1) < img.size(-1): # need upsample image
img_nerf = upsample(img_nerf, 2 * img_nerf.size(-1))
img = img_nerf * (1 - alpha) + img * alpha
if len(rand_imgs) > 0:
img_rand = rand_imgs[-2] * (1 - alpha) + img_rand * alpha
with torch.autograd.profiler.record_function('nerf_path_reg_loss'):
if len(rand_imgs) > 0: # and self.training: # random pixel regularization??
assert self.progressive_growing
if self.reg_full: # aggregate RGB in the end.
lh, lw = img_rand.size(2) // self.n_reg_samples, img_rand.size(3) // self.n_reg_samples
img_rand = rearrange(img_rand, 'b d (l h) (m w) -> b d (l m) h w', l=lh, m=lw)
img_rand = (img_rand * rand_probs[:, None]).sum(2)
if self.V.white_background:
img_rand = img_rand + (1 - rand_probs.sum(1, keepdim=True))
rand_indexs = repeat(rand_indexs, 'b n -> b d n', d=img_rand.size(1))
img_ff = rearrange(rearrange(img, 'b d l h -> b d (l h)').gather(2, rand_indexs), 'b d (l h) -> b d l h', l=self.n_reg_samples)
def l2(img_ff, img_nf):
batch_size = img_nf.size(0)
return ((img_ff - img_nf) ** 2).sum(1).reshape(batch_size, -1).mean(-1, keepdim=True)
reg_loss['reg_loss'] = l2(img_ff, img_rand) * 2.0
if len(reg_loss) > 0:
for key in reg_loss:
block_kwargs[key] = reg_loss[key]
if self.rectangular_crop is not None: # in case rectangular
h, w = self.rectangular_crop
c = int(img.size(-1) * (1 - h / w) / 2)
mask = torch.ones_like(img)
mask[:, :, c:-c, :] = 0
img = img.masked_fill(mask > 0, -1)
block_kwargs['img'] = img
return block_kwargs
def get_current_resolution(self):
    """Derive the progressive-growing state from ``self.alpha``.

    Returns:
        A tuple ``(n_levels, end_l, before_res, target_res)``:

        * ``n_levels``   -- ``len(self.block_resolutions)``.
        * ``end_l``      -- number of entries of ``self.block_resolutions``
          that are currently in use.
        * ``before_res`` -- resolution one level below the target
          (``self.resolution_start`` when ``end_l <= 1``).
        * ``target_res`` -- resolution reached after ``end_l`` levels
          (``self.resolution_start`` when ``end_l <= 0``).
    """
    n_levels = len(self.block_resolutions)

    # Every level is active unless the progressive schedule says otherwise.
    # Starting from this default also covers ``alpha > 1``, which previously
    # left ``end_l`` unassigned and raised ``UnboundLocalError``.
    end_l = n_levels
    if self.progressive_growing and (self.alpha > -1) and (not self.progressive_nerf_only):
        if self.alpha == 0:
            end_l = 0
        elif self.alpha < 1:
            # The integer part of ``alpha * n_levels`` counts the finished
            # levels; the level being faded in is added on top.
            end_l = int(math.modf(self.alpha * n_levels)[1] + 1)

    target_res = self.resolution_start if end_l <= 0 else self.block_resolutions[end_l - 1]
    before_res = self.resolution_start if end_l <= 1 else self.block_resolutions[end_l - 2]
    return n_levels, end_l, before_res, target_res
def get_latent_codes(self, batch_size=32, device="cpu", tmp=1.):
    """Sample Gaussian latent codes for the object and the background.

    Args:
        batch_size: number of codes per returned tensor.
        device: device the sampled tensors are moved to.
        tmp: factor every sample is multiplied by.

    Returns:
        ``(z_shape_obj, z_app_obj, z_shape_bg, z_app_bg)``.  The object codes
        have ``self.z_dim`` columns and the background codes
        ``self.z_dim_bg`` columns; both background entries are ``None`` when
        ``self.V.no_background`` is set.
    """
    obj_dim, bg_dim = self.z_dim, self.z_dim_bg

    def draw(width):
        # The first draw is thrown away.  It is kept because it advances the
        # global RNG, so results for a fixed seed stay unchanged.
        torch.randn(batch_size, width).to(device)
        noise = torch.randn(batch_size, width).to(device)
        return noise * tmp

    shape_obj = draw(obj_dim)
    app_obj = draw(obj_dim)
    if self.V.no_background:
        shape_bg = None
        app_bg = None
    else:
        shape_bg = draw(bg_dim)
        app_bg = draw(bg_dim)
    return shape_obj, app_obj, shape_bg, app_bg
def get_camera(self, *args, **kwargs):
    """Forward the call unchanged to ``self.C.get_camera`` (kept for compatibility)."""
    camera_module = self.C
    return camera_module.get_camera(*args, **kwargs)
def get_camera_samples(self, batch_size, ws, block_kwargs, gen_cond=False):
    """Sample the cameras used for one forward pass.

    Args:
        batch_size: number of cameras requested from ``self.get_camera``.
        ws: tensor whose ``device`` (and, when a camera is predicted, first
            dimension) is used for sampling.
        block_kwargs: dict of optional overrides; the keys read here are
            ``'camera_matrices'``, ``'camera_mode'``, ``'camera_RT'`` and
            ``'camera_UV'``.
        gen_cond: select the generator-conditioning variant, which may
            return ``None``.

    Returns:
        Whatever ``self.get_camera`` returns, converted to a tuple when
        entries were overwritten with ``camera_RT`` / ``camera_UV``, or
        ``None`` in the ``gen_cond`` variant.
    """
    # --- camera condition for the generator (a special variant) ----------
    if gen_cond:
        if ('camera_matrices' in block_kwargs) and (not self.training):
            # rendering: use a fixed mode
            return self.get_camera(batch_size, device=ws.device, mode=[0.5, 0.5, 0.5])
        if self.training and (np.random.rand() > 0.5):
            return self.get_camera(batch_size, device=ws.device)
        return None

    # --- caller supplied an explicit camera mode -------------------------
    if 'camera_mode' in block_kwargs:
        return self.get_camera(batch_size, device=ws.device, mode=block_kwargs["camera_mode"])

    # --- regular sampling, optionally shifted by the camera generator ----
    if not self.predict_camera:
        cameras = self.get_camera(batch_size, device=ws.device)
    else:
        noise = ws.new_zeros(ws.size(0), 2)
        if self.C.gaussian_camera:
            noise = noise.normal_()
            offset = self.camera_generator(noise)
        else:
            noise = noise.uniform_()
            offset = self.camera_generator(noise - 0.5)
        mode = noise if self.alpha <= 0 else noise + offset * 0.1
        cameras = self.get_camera(batch_size, device=ws.device, mode=mode)

    use_rt = 'camera_RT' in block_kwargs
    if not (use_rt or ('camera_UV' in block_kwargs)):
        return cameras

    # Overwrite a random subset of the sampled cameras with the supplied
    # ones; an entry is selected when its uniform draw is below ``alpha``.
    cameras = list(cameras)
    chosen = torch.rand(batch_size).type_as(cameras[1]).lt(self.alpha)
    if use_rt:
        given_rt = block_kwargs['camera_RT'].reshape(-1, 4, 4)
        cameras[1][chosen] = given_rt[chosen]
    else:
        # sample uv instead of sampling the extrinsic matrix
        given_uv = block_kwargs['camera_UV']
        given_rt = self.get_camera(batch_size, device=ws.device, mode=given_uv, force_uniform=True)[1]
        cameras[1][chosen] = given_rt[chosen]
        cameras[2][chosen] = given_uv[chosen]
    return tuple(cameras)
@persistence.persistent_class
class Discriminator(torch.nn.Module):
    """Image discriminator with optional camera / style prediction heads.

    The main path is a stack of ``DiscriminatorBlock`` modules named
    ``b{res}`` (one per entry of ``block_resolutions``) followed by a
    ``DiscriminatorEpilogue`` stored as ``b4``.  Depending on the
    configuration the module also owns:

    * ``dual{res}`` blocks -- a second stack limited to ``lowres_head``,
      built when ``dual_discriminator`` is set;
    * ``c{res}`` blocks -- a separate camera-encoder stack, built when
      ``camera_kwargs.camera_encoder`` is set;
    * ``projector`` / ``projector_styles`` -- heads that regress a camera
      and style codes from the features of the last block.

    The scalar buffer ``alpha`` (initialised to ``-1``) drives progressive
    growing; see ``get_block_resolutions``.
    """

    def __init__(self,
        c_dim,                          # Conditioning label (C) dimensionality.
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 1,        # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 0,        # Use FP16 for the N highest resolutions.
        conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping.
        cmap_dim            = None,     # Dimensionality of mapped conditioning label, None = default.
        lowres_head         = None,     # add a low-resolution discriminator head
        dual_discriminator  = False,    # add low-resolution (NeRF) image
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        epilogue_kwargs     = {},       # Arguments for DiscriminatorEpilogue.
        camera_kwargs       = {},       # Arguments for Camera predictor and condition (optional, refactoring)
        upsample_type       = 'default',
        progressive         = False,
        resize_real_early   = False,    # Perform resizing before the training loop
        enable_ema          = False,    # Additionally save an EMA checkpoint
        **unused
    ):
        super().__init__()

        # setup parameters
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        self.architecture = architecture
        self.lowres_head = lowres_head
        self.dual_discriminator = dual_discriminator
        self.upsample_type = upsample_type
        self.progressive = progressive
        self.resize_real_early = resize_real_early
        self.enable_ema = enable_ema
        if self.progressive:
            assert self.architecture == 'skip', "not supporting other types for now."

        channel_base = int(channel_base * 32768)
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        # camera prediction module
        self.camera_kwargs = EasyDict(
            predict_camera=False,
            predict_styles=False,
            camera_type='3d',
            camera_encoder=True,
            camera_encoder_progressive=False,
            camera_disc=True)

        ## ------ for compatibility ------- #
        # Legacy flags arrive through ``**unused``; explicit ``camera_kwargs``
        # are applied last and therefore win.
        self.camera_kwargs.predict_camera = unused.get('predict_camera', False)
        self.camera_kwargs.camera_type = '9d' if unused.get('predict_9d_camera', False) else '3d'
        self.camera_kwargs.camera_disc = not unused.get('no_camera_condition', False)
        self.camera_kwargs.camera_encoder = unused.get('saperate_camera', False)
        self.camera_kwargs.update(camera_kwargs)
        ## ------ for compatibility ------- #

        self.c_dim = c_dim
        if self.camera_kwargs.predict_camera:
            if self.camera_kwargs.camera_type == '3d':
                self.c_dim = out_dim = 3  # (u, v) on the sphere
            elif self.camera_kwargs.camera_type == '9d':
                self.c_dim, out_dim = 16, 9
            elif self.camera_kwargs.camera_type == '16d':
                self.c_dim = out_dim = 16
            else:
                raise NotImplementedError('Wrong camera type')
            if not self.camera_kwargs.camera_disc:
                # the camera is predicted but not used as a condition
                self.c_dim = c_dim
            self.projector = EqualConv2d(channels_dict[4], out_dim, 4, padding=0, bias=False)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if self.c_dim == 0:
            cmap_dim = 0
        if self.c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=self.c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)

        if self.camera_kwargs.predict_styles:
            self.w_dim, self.num_ws = self.camera_kwargs.w_dim, self.camera_kwargs.num_ws
            self.projector_styles = EqualConv2d(channels_dict[4], self.w_dim * self.num_ws, 4, padding=0, bias=False)
            self.mapping_styles = MappingNetwork(z_dim=0, c_dim=self.w_dim * self.num_ws, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)

        # main discriminator blocks
        common_kwargs = dict(img_channels=self.img_channels, architecture=architecture, conv_clamp=conv_clamp)

        def build_blocks(layer_name='b', low_resolution=False):
            # Register one DiscriminatorBlock per resolution as
            # ``{layer_name}{res}``; ``low_resolution`` keeps only the
            # resolutions up to ``lowres_head``.
            cur_layer_idx = 0
            block_resolutions = self.block_resolutions
            if low_resolution:
                block_resolutions = [r for r in self.block_resolutions if r <= self.lowres_head]
            for res in block_resolutions:
                in_channels = channels_dict[res] if res < img_resolution else 0
                tmp_channels = channels_dict[res]
                out_channels = channels_dict[res // 2]
                use_fp16 = (res >= fp16_resolution)
                block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                    first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
                setattr(self, f'{layer_name}{res}', block)
                cur_layer_idx += block.num_layers

        build_blocks(layer_name='b')  # main blocks
        if self.dual_discriminator:
            build_blocks(layer_name='dual', low_resolution=True)
        if self.camera_kwargs.camera_encoder:
            build_blocks(layer_name='c', low_resolution=(not self.camera_kwargs.camera_encoder_progressive))

        # final output module
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
        self.register_buffer("alpha", torch.scalar_tensor(-1))

    def set_alpha(self, alpha):
        """Overwrite the ``alpha`` buffer; ``None`` leaves it untouched."""
        if alpha is not None:
            self.alpha = self.alpha * 0 + alpha

    def set_resolution(self, res):
        """Store the status tuple that ``get_block_resolutions`` unpacks."""
        self.curr_status = res

    def forward_blocks_progressive(self, img, mode="disc", **block_kwargs):
        """Run one of the block stacks on ``img``.

        Args:
            img: image tensor, or a dict holding it under ``'img'``.
            mode: ``'disc'`` (main ``b*`` blocks), ``'dual_disc'``
                (``dual*`` blocks) or ``'cam_enc'`` (``c*`` blocks).
            **block_kwargs: forwarded to every block.

        Returns:
            ``(output, x, img)`` where ``output`` is a dict that may contain
            ``'cam'`` and ``'styles'`` predictions and ``x`` / ``img`` are
            what the last block returned.
        """
        # mode from ['disc', 'dual_disc', 'cam_enc']
        if isinstance(img, dict):
            img = img['img']

        block_resolutions, alpha, lowres_head = self.get_block_resolutions(img)
        layer_name, progressive = 'b', self.progressive
        if mode == "cam_enc":
            assert self.camera_kwargs.predict_camera and self.camera_kwargs.camera_encoder
            layer_name = 'c'
            if not self.camera_kwargs.camera_encoder_progressive:
                block_resolutions, progressive = [r for r in self.block_resolutions if r <= self.lowres_head], False
                img = downsample(img, self.lowres_head)
        elif mode == 'dual_disc':
            layer_name = 'dual'
            block_resolutions, progressive = [r for r in self.block_resolutions if r <= self.lowres_head], False

        # Half-resolution copy used to fade in the newest block.
        img0 = downsample(img, img.size(-1) // 2) if \
            progressive and (self.lowres_head is not None) and (self.alpha > -1) and (self.alpha < 1) and (alpha > 0) \
            else None

        x = None if (not progressive) or (block_resolutions[0] == self.img_resolution) \
            else getattr(self, f'{layer_name}{block_resolutions[0]}').fromrgb(img)

        for res in block_resolutions:
            block = getattr(self, f'{layer_name}{res}')
            if (lowres_head == res) and (self.alpha > -1) and (self.alpha < 1) and (alpha > 0):
                if progressive:
                    # blend the current path with the downsampled input
                    if self.architecture == 'skip':
                        img = img * alpha + img0 * (1 - alpha)
                    x = x * alpha + block.fromrgb(img0) * (1 - alpha)
            x, img = block(x, img, **block_kwargs)

        output = {}
        if (mode == 'cam_enc') or \
           (mode == 'disc' and self.camera_kwargs.predict_camera and (not self.camera_kwargs.camera_encoder)):
            c = self.projector(x)[:, :, 0, 0]
            if self.camera_kwargs.camera_type == '9d':
                c = camera_9d_to_16d(c)
            output['cam'] = c
            if self.camera_kwargs.predict_styles:
                w = self.projector_styles(x)[:, :, 0, 0]
                output['styles'] = w
        return output, x, img

    def get_camera_loss(self, RT=None, UV=None, c=None):
        """Regression loss between the predicted camera ``c`` and the target.

        Returns ``None`` when either ``RT`` or ``UV`` is missing.  For the
        ``'3d'`` camera type the target is ``UV`` (MSE); otherwise it is the
        flattened ``RT`` (smooth L1, scaled by 10).
        """
        if (RT is None) or (UV is None):
            return None
        if self.camera_kwargs.camera_type == '3d':  # UV has higher priority?
            return F.mse_loss(UV, c)
        else:
            return F.smooth_l1_loss(RT.reshape(RT.size(0), -1), c) * 10

    def get_styles_loss(self, WS=None, w=None):
        """MSE between target styles ``WS`` and prediction ``w`` scaled by 0.1, or ``None`` without a target."""
        if WS is None:
            return None
        return F.mse_loss(WS, w) * 0.1

    def get_block_resolutions(self, input_img):
        """Select the active resolutions for the current progressive state.

        Returns:
            ``(block_resolutions, alpha, lowres_head)``.  Outside of
            progressive training these are ``self.block_resolutions``,
            ``self.alpha`` and ``self.lowres_head`` unchanged.
        """
        block_resolutions = self.block_resolutions
        lowres_head = self.lowres_head
        alpha = self.alpha
        img_res = input_img.size(-1)
        if self.progressive and (self.lowres_head is not None) and (self.alpha > -1):
            if (self.alpha < 1) and (self.alpha > 0):
                try:
                    n_levels, _, before_res, target_res = self.curr_status
                    alpha, _ = math.modf(self.alpha * n_levels)
                except (AttributeError, TypeError, ValueError):
                    # TODO: this is a hack, better to save status as buffers.
                    # ``curr_status`` is missing or malformed: fall back to
                    # the resolution of the incoming image.
                    before_res = target_res = img_res
                if before_res == target_res:
                    # no upsampling was used in generator, do not increase the discriminator
                    alpha = 0
                block_resolutions = [res for res in self.block_resolutions if res <= target_res]
                lowres_head = before_res
            elif self.alpha == 0:
                block_resolutions = [res for res in self.block_resolutions if res <= lowres_head]
        return block_resolutions, alpha, lowres_head

    def forward(self, inputs, c=None, aug_pipe=None, return_camera=False, **block_kwargs):
        """Score a batch of images.

        Args:
            inputs: image tensor, or a dict with ``'img'`` and optionally
                ``'img_nerf'``, ``'camera_matrices'``, ``'ws_detach'`` and
                ``'batch_size'``.  When a dict is passed and the dual
                discriminator is enabled, ``'img_nerf'`` is added to it if
                absent.
            c: conditioning label; ``None`` or a tensor with an empty last
                dimension means "no label", in which case a predicted camera
                (if any) is used as the condition.
            aug_pipe: optional callable applied to the image before the main
                discriminator blocks.
            return_camera: also return the condition under ``'camera'``.
            **block_kwargs: forwarded to every block.

        Returns:
            Dict with ``'logits'`` and, when available, ``'camera_loss'``,
            ``'styles_loss'`` and ``'camera'``.

        Raises:
            RuntimeError: if ``predict_styles`` is enabled but no styles were
                predicted (a non-empty label ``c`` was supplied).
        """
        if not isinstance(inputs, dict):
            inputs = {'img': inputs}
        img = inputs['img']

        # this is to handle real images
        block_resolutions, _, _ = self.get_block_resolutions(img)
        if img.size(-1) > block_resolutions[0]:
            img = downsample(img, block_resolutions[0])
        if self.dual_discriminator and ('img_nerf' not in inputs):
            inputs['img_nerf'] = downsample(img, self.lowres_head)

        RT = inputs['camera_matrices'][1].detach() if 'camera_matrices' in inputs else None
        UV = inputs['camera_matrices'][2].detach() if 'camera_matrices' in inputs else None
        WS = inputs['ws_detach'].reshape(inputs['batch_size'], -1) if 'ws_detach' in inputs else None

        # ``c`` defaults to None, which is treated like an empty label.
        no_condition = (c is None) or (c.size(-1) == 0)

        # These are only produced by the prediction branches below; bind
        # them up front so the checks at the end never hit unbound names.
        w = camera_loss = styles_loss = None

        # forward separate camera encoder, which can also be progressive...
        if self.camera_kwargs.camera_encoder:
            out_camenc, _, _ = self.forward_blocks_progressive(img, mode='cam_enc', **block_kwargs)
            if no_condition and ('cam' in out_camenc):
                c, camera_loss = out_camenc['cam'], self.get_camera_loss(RT, UV, out_camenc['cam'])
                if 'styles' in out_camenc:
                    w, styles_loss = out_camenc['styles'], self.get_styles_loss(WS, out_camenc['styles'])
                no_condition = False

        # forward another dual discriminator only for low resolution images
        if self.dual_discriminator:
            _, x_nerf, img_nerf = self.forward_blocks_progressive(inputs['img_nerf'], mode='dual_disc', **block_kwargs)

        # if applied data augmentation for discriminator
        if aug_pipe is not None:
            img = aug_pipe(img)

        # perform main discriminator block
        out_disc, x, img = self.forward_blocks_progressive(img, mode='disc', **block_kwargs)
        if no_condition and ('cam' in out_disc):
            c, camera_loss = out_disc['cam'], self.get_camera_loss(RT, UV, out_disc['cam'])
            if 'styles' in out_disc:
                w, styles_loss = out_disc['styles'], self.get_styles_loss(WS, out_disc['styles'])
            no_condition = False

        # camera conditional discriminator
        cmap = None
        if self.c_dim > 0:
            cc = c.clone().detach()
            cmap = self.mapping(None, cc)
        if self.camera_kwargs.predict_styles:
            if w is None:
                raise RuntimeError('predict_styles is enabled but no styles were predicted for this batch')
            ww = w.clone().detach()
            cmap = [cmap] + [self.mapping_styles(None, ww)]

        logits = self.b4(x, img, cmap)
        if self.dual_discriminator:
            logits = torch.cat([logits, self.b4(x_nerf, img_nerf, cmap)], 0)

        outputs = {'logits': logits}
        if self.camera_kwargs.predict_camera and (camera_loss is not None):
            outputs['camera_loss'] = camera_loss
        if self.camera_kwargs.predict_styles and (styles_loss is not None):
            outputs['styles_loss'] = styles_loss
        if return_camera:
            outputs['camera'] = c
        return outputs
@persistence.persistent_class
class Encoder(torch.nn.Module):
    """Image encoder built from ``DiscriminatorBlock`` modules.

    The blocks are registered as ``b{res}`` for every entry of
    ``block_resolutions``; a final ``EqualConv2d`` (``projector``) maps the
    features of the last block to ``out_dim`` values, which are returned as
    style codes and, when ``predict_camera`` is set, a camera.

    Only the output modes ``'W'``, ``'W+'`` and ``'None'`` are implemented;
    any other value of ``model_kwargs['output_mode']`` (including the
    default ``'styles'``) raises ``NotImplementedError``.
    """

    def __init__(self,
        img_resolution,                 # Input resolution.
        img_channels,                   # Number of input color channels.
        bottleneck_factor   = 2,        # By default, the same as discriminator we use 4x4 features
        architecture        = 'resnet', # Architecture: 'orig', 'skip', 'resnet'.
        channel_base        = 1,        # Overall multiplier for the number of channels.
        channel_max         = 512,      # Maximum number of channels in any layer.
        num_fp16_res        = 0,        # Use FP16 for the N highest resolutions.
        conv_clamp          = None,     # Clamp the output of convolution layers to +-X, None = disable clamping
        lowres_head         = None,     # add a low-resolution discriminator head
        block_kwargs        = {},       # Arguments for DiscriminatorBlock.
        model_kwargs        = {},
        upsample_type       = 'default',
        progressive         = False,
        **unused
    ):
        super().__init__()
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, bottleneck_factor, -1)]
        self.architecture = architecture
        self.lowres_head = lowres_head
        self.upsample_type = upsample_type
        self.progressive = progressive
        self.model_kwargs = model_kwargs
        self.output_mode = model_kwargs.get('output_mode', 'styles')
        if self.progressive:
            assert self.architecture == 'skip', "not supporting other types for now."
        self.predict_camera = model_kwargs.get('predict_camera', False)

        channel_base = int(channel_base * 32768)
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        common_kwargs = dict(img_channels=self.img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers

        # this is an encoder
        if self.output_mode in ['W', 'W+', 'None']:
            self.num_ws = self.model_kwargs.get('num_ws', 0)
            # 'W+' predicts one latent per style, 'W' a single one, 'None' no latent
            self.n_latents = self.num_ws if self.output_mode == 'W+' else (0 if self.output_mode == 'None' else 1)
            self.w_dim = self.model_kwargs.get('w_dim', 512)
            # 9 extra outputs are reserved for the camera when it is predicted
            self.add_dim = self.model_kwargs.get('add_dim', 0) if not self.predict_camera else 9
            self.out_dim = self.w_dim * self.n_latents + self.add_dim
            assert self.out_dim > 0, 'output dimenstion has to be larger than 0'
            assert self.block_resolutions[-1] // 2 == 4, "make sure the last resolution is 4x4"
            self.projector = EqualConv2d(channels_dict[4], self.out_dim, 4, padding=0, bias=False)
        else:
            raise NotImplementedError
        self.register_buffer("alpha", torch.scalar_tensor(-1))

    def set_alpha(self, alpha):
        """Overwrite the ``alpha`` buffer in place; ``None`` leaves it untouched."""
        if alpha is not None:
            self.alpha.fill_(alpha)

    def set_resolution(self, res):
        """Store the status tuple that ``get_block_resolutions`` unpacks."""
        self.curr_status = res

    def get_block_resolutions(self, input_img):
        """Select the active resolutions for the current progressive state.

        Returns:
            ``(block_resolutions, alpha, lowres_head)``.  Outside of
            progressive training these are ``self.block_resolutions``,
            ``self.alpha`` and ``self.lowres_head`` unchanged.
        """
        block_resolutions = self.block_resolutions
        lowres_head = self.lowres_head
        alpha = self.alpha
        img_res = input_img.size(-1)
        if self.progressive and (self.lowres_head is not None) and (self.alpha > -1):
            if (self.alpha < 1) and (self.alpha > 0):
                try:
                    n_levels, _, before_res, target_res = self.curr_status
                    alpha, _ = math.modf(self.alpha * n_levels)
                except (AttributeError, TypeError, ValueError):
                    # TODO: this is a hack, better to save status as buffers.
                    # ``curr_status`` is missing or malformed: fall back to
                    # the resolution of the incoming image.
                    before_res = target_res = img_res
                if before_res == target_res:
                    # no upsampling was used in generator, do not increase the discriminator
                    alpha = 0
                block_resolutions = [res for res in self.block_resolutions if res <= target_res]
                lowres_head = before_res
            elif self.alpha == 0:
                block_resolutions = [res for res in self.block_resolutions if res <= lowres_head]
        return block_resolutions, alpha, lowres_head

    def forward(self, inputs, **block_kwargs):
        """Encode a batch of images.

        Args:
            inputs: image tensor, or a dict holding it under ``'img'``.
            **block_kwargs: forwarded to every block.

        Returns:
            Dict with ``'ws'`` (``None`` for output mode ``'None'``) and,
            when ``predict_camera`` is set, ``'camera'``.
        """
        if isinstance(inputs, dict):
            img = inputs['img']
        else:
            img = inputs

        block_resolutions, alpha, lowres_head = self.get_block_resolutions(img)
        if img.size(-1) > block_resolutions[0]:
            img = downsample(img, block_resolutions[0])

        # The fade-in blend needs the half-resolution image, which only
        # exists in progressive mode.  A single flag keeps the two places in
        # sync, so a non-progressive encoder can no longer reach the blend
        # with ``img0`` undefined.
        fade_in = bool(
            self.progressive and (self.lowres_head is not None)
            and (self.alpha > -1) and (self.alpha < 1) and (alpha > 0))
        img0 = downsample(img, img.size(-1) // 2) if fade_in else None

        x = None if (not self.progressive) or (block_resolutions[0] == self.img_resolution) \
            else getattr(self, f'b{block_resolutions[0]}').fromrgb(img)

        for res in block_resolutions:
            block = getattr(self, f'b{res}')
            if fade_in and (lowres_head == res):
                if self.architecture == 'skip':
                    img = img * alpha + img0 * (1 - alpha)
                x = x * alpha + block.fromrgb(img0) * (1 - alpha)  # combine from img0
            x, img = block(x, img, **block_kwargs)

        outputs = {}
        if self.output_mode in ['W', 'W+', 'None']:
            out = self.projector(x)[:, :, 0, 0]
            if self.predict_camera:
                # the first 9 outputs encode the camera, the rest the styles
                out, out_cam_9d = out[:, 9:], out[:, :9]
                outputs['camera'] = camera_9d_to_16d(out_cam_9d)

            if self.output_mode == 'W+':
                out = rearrange(out, 'b (n s) -> b n s', n=self.num_ws, s=self.w_dim)
            elif self.output_mode == 'W':
                out = repeat(out, 'b s -> b n s', n=self.num_ws)
            else:
                out = None
            outputs['ws'] = out

        return outputs
# ------------------------------------------------------------------------------------------- #
class CameraQueriedSampler(torch.utils.data.Sampler):
    """Infinite sampler that picks dataset items by camera similarity.

    Until ``update_dataset_cameras`` has been called, indices are drawn
    uniformly at random.  Afterwards, ``B`` cameras are sampled from
    ``camera_module`` at a time and, for each of them, one of the ``K``
    dataset items with the closest estimated camera is yielded.
    """

    def __init__(self, dataset, camera_module, nearest_neighbors=400, rank=0, num_replicas=1, device='cpu', seed=0):
        assert len(dataset) > 0
        super().__init__(dataset)

        self.dataset = dataset
        self.dataset_cameras = None  # (len(dataset), 16) once estimated
        self.seed = seed
        self.rank = rank
        self.device = device
        self.num_replicas = num_replicas

        self.C = camera_module
        self.K = nearest_neighbors  # neighbours considered per sampled camera
        self.B = 1000               # cameras sampled per round

    def update_dataset_cameras(self, estimator):
        """Estimate a camera for every dataset item and cache the result.

        Each replica runs ``estimator`` on its own strided subset of the
        dataset; with more than one replica the predictions and their
        indices are gathered before being written to ``dataset_cameras``.
        """
        import tqdm
        from torch_utils.distributed_utils import gather_list_and_concat
        output = torch.ones(len(self.dataset), 16).to(self.device)
        with torch.no_grad():
            predicted_cameras, image_indices, bsz = [], [], 64
            item_subset = [(i * self.num_replicas + self.rank) % len(self.dataset) for i in range((len(self.dataset) - 1) // self.num_replicas + 1)]
            loader = torch.utils.data.DataLoader(
                dataset=copy.deepcopy(self.dataset), sampler=item_subset, batch_size=bsz)
            for images, _, indices in tqdm.tqdm(
                    loader, total=len(loader), colour='red',
                    desc='Estimating camera poses for the training set at'):
                # images are scaled from [0, 255] to [-1, 1] before estimation
                predicted_cameras += [estimator(images.to(self.device).to(torch.float32) / 127.5 - 1)]
                image_indices += [indices.to(self.device).long()]
            predicted_cameras = torch.cat(predicted_cameras, 0)
            image_indices = torch.cat(image_indices, 0)
            if self.num_replicas > 1:
                predicted_cameras = gather_list_and_concat(predicted_cameras)
                image_indices = gather_list_and_concat(image_indices)
            output[image_indices] = predicted_cameras
        self.dataset_cameras = output

    def get_knn_cameras(self):
        """Return the indices of the nearest dataset cameras for ``B`` fresh samples.

        The result has shape ``(k, B)`` with ``k = min(K, number of cached
        cameras)``; clamping ``k`` keeps ``topk`` valid for datasets smaller
        than ``K``.
        """
        sampled = self.C.get_camera(self.B, self.device)[0].reshape(1, self.B, 16)
        distances = torch.norm(self.dataset_cameras.unsqueeze(1) - sampled, dim=2, p=None)
        k = min(self.K, distances.size(0))
        return distances.topk(k, largest=False, dim=0)[1]  # k x B

    def __iter__(self):
        order = np.arange(len(self.dataset))
        rnd = np.random.RandomState(self.seed + self.rank)
        while True:
            if self.dataset_cameras is None:
                # no camera estimates yet: plain uniform sampling
                rand_idx = rnd.randint(order.size)
                yield rand_idx
            else:
                knn_idxs = self.get_knn_cameras()
                # Equals ``self.K`` unless the dataset is smaller than that.
                num_neighbors = knn_idxs.size(0)
                for i in range(self.B):
                    rand_idx = rnd.randint(num_neighbors)
                    yield knn_idxs[rand_idx, i].item()