|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Collection of Losses. |
|
""" |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
from torchtyping import TensorType |
|
|
|
from math import exp |
|
|
|
|
|
|
|
|
|
L1Loss = nn.L1Loss |
|
MSELoss = nn.MSELoss |
|
|
|
LOSSES = {"L1": L1Loss, "MSE": MSELoss} |
|
|
|
EPS = 1.0e-7 |
|
|
|
|
|
def outer( |
|
t0_starts: TensorType[..., "num_samples_0"], |
|
t0_ends: TensorType[..., "num_samples_0"], |
|
t1_starts: TensorType[..., "num_samples_1"], |
|
t1_ends: TensorType[..., "num_samples_1"], |
|
y1: TensorType[..., "num_samples_1"], |
|
) -> TensorType[..., "num_samples_0"]: |
|
"""Faster version of |
|
|
|
https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L117 |
|
https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L64 |
|
|
|
Args: |
|
        t0_starts: left edges of the query intervals
        t0_ends: right edges of the query intervals
        t1_starts: left edges of the reference intervals
        t1_ends: right edges of the reference intervals
        y1: weights of the reference intervals
|
""" |
|
cy1 = torch.cat([torch.zeros_like(y1[..., :1]), torch.cumsum(y1, dim=-1)], dim=-1) |
|
|
|
idx_lo = torch.searchsorted(t1_starts.contiguous(), t0_starts.contiguous(), side="right") - 1 |
|
idx_lo = torch.clamp(idx_lo, min=0, max=y1.shape[-1] - 1) |
|
idx_hi = torch.searchsorted(t1_ends.contiguous(), t0_ends.contiguous(), side="right") |
|
idx_hi = torch.clamp(idx_hi, min=0, max=y1.shape[-1] - 1) |
|
cy1_lo = torch.take_along_dim(cy1[..., :-1], idx_lo, dim=-1) |
|
cy1_hi = torch.take_along_dim(cy1[..., 1:], idx_hi, dim=-1) |
|
y0_outer = cy1_hi - cy1_lo |
|
|
|
return y0_outer |
|
|
|
|
|
def lossfun_outer( |
|
t: TensorType[..., "num_samples+1"], |
|
w: TensorType[..., "num_samples"], |
|
t_env: TensorType[..., "num_samples+1"], |
|
w_env: TensorType[..., "num_samples"], |
|
): |
|
""" |
|
https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L136 |
|
https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L80 |
|
|
|
Args: |
|
t: interval edges |
|
w: weights |
|
        t_env: interval edges of the upper-bound envelope histogram
|
w_env: weights that should upper bound the inner (t,w) histogram |
|
""" |
|
w_outer = outer(t[..., :-1], t[..., 1:], t_env[..., :-1], t_env[..., 1:], w_env) |
|
return torch.clip(w - w_outer, min=0) ** 2 / (w + EPS) |
|
|
|
|
|
def ray_samples_to_sdist(ray_samples): |
|
"""Convert ray samples to s space""" |
|
starts = ray_samples.spacing_starts |
|
ends = ray_samples.spacing_ends |
|
sdist = torch.cat([starts[..., 0], ends[..., -1:, 0]], dim=-1) |
|
return sdist |
|
|
|
|
|
def interlevel_loss(weights_list, ray_samples_list): |
|
"""Calculates the proposal loss in the MipNeRF-360 paper. |
|
|
|
https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/model.py#L515 |
|
https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/train_utils.py#L133 |
|
""" |
|
c = ray_samples_to_sdist(ray_samples_list[-1]).detach() |
|
w = weights_list[-1][..., 0].detach() |
|
loss_interlevel = 0.0 |
|
for ray_samples, weights in zip(ray_samples_list[:-1], weights_list[:-1]): |
|
sdist = ray_samples_to_sdist(ray_samples) |
|
cp = sdist |
|
wp = weights[..., 0] |
|
loss_interlevel += torch.mean(lossfun_outer(c, w, cp, wp)) |
|
return loss_interlevel |
|
|
|
|
|
|
|
def blur_stepfun(x, y, r): |
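    """Blur the step function with knots `x` and values `y` using a box filter of
    radius `r`; returns the knots `x_r` of the resulting piecewise-linear function
    and its values `y_r`."""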
|
x_c = torch.cat([x - r, x + r], dim=-1) |
|
x_r, x_idx = torch.sort(x_c, dim=-1) |
|
zeros = torch.zeros_like(y[:, :1]) |
|
y_1 = (torch.cat([y, zeros], dim=-1) - torch.cat([zeros, y], dim=-1)) / (2 * r) |
|
x_idx = x_idx[:, :-1] |
|
    y_2 = torch.cat([y_1, -y_1], dim=-1)[
        torch.arange(x_idx.shape[0], device=x_idx.device).reshape(-1, 1).expand(x_idx.shape), x_idx
    ]
|
|
|
y_r = torch.cumsum((x_r[:, 1:] - x_r[:, :-1]) * torch.cumsum(y_2, dim=-1), dim=-1) |
|
y_r = torch.cat([zeros, y_r], dim=-1) |
|
return x_r, y_r |
|
|
|
|
|
def interlevel_loss_zip(weights_list, ray_samples_list): |
|
"""Calculates the proposal loss in the Zip-NeRF paper.""" |
|
c = ray_samples_to_sdist(ray_samples_list[-1]).detach() |
|
w = weights_list[-1][..., 0].detach() |
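    # convert bin weights into a density (weight per unit length) before blurring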
|
|
|
|
|
w_normalize = w / (c[:, 1:] - c[:, :-1]) |
|
|
|
loss_interlevel = 0.0 |
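    # blur radius per proposal level (coarse to fine), as in Zip-NeRF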
|
for ray_samples, weights, r in zip(ray_samples_list[:-1], weights_list[:-1], [0.03, 0.003]): |
|
|
|
x_r, y_r = blur_stepfun(c, w_normalize, r) |
|
y_r = torch.clip(y_r, min=0) |
|
assert (y_r >= 0.0).all() |
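        # integrate the blurred density with the trapezoid rule to obtain its CDF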
|
|
|
|
|
y_cum = torch.cumsum((y_r[:, 1:] + y_r[:, :-1]) * 0.5 * (x_r[:, 1:] - x_r[:, :-1]), dim=-1) |
|
y_cum = torch.cat([torch.zeros_like(y_cum[:, :1]), y_cum], dim=-1) |
|
|
|
|
|
sdist = ray_samples_to_sdist(ray_samples) |
|
cp = sdist |
|
wp = weights[..., 0] |
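        # linearly interpolate the blurred CDF at the proposal's bin edges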
|
|
|
|
|
inds = torch.searchsorted(x_r, cp, side="right") |
|
below = torch.clamp(inds - 1, 0, x_r.shape[-1] - 1) |
|
above = torch.clamp(inds, 0, x_r.shape[-1] - 1) |
|
cdf_g0 = torch.gather(x_r, -1, below) |
|
bins_g0 = torch.gather(y_cum, -1, below) |
|
cdf_g1 = torch.gather(x_r, -1, above) |
|
bins_g1 = torch.gather(y_cum, -1, above) |
|
|
|
t = torch.clip(torch.nan_to_num((cp - cdf_g0) / (cdf_g1 - cdf_g0), 0), 0, 1) |
|
bins = bins_g0 + t * (bins_g1 - bins_g0) |
|
|
|
        # blurred NeRF weight mass that falls inside each proposal bin
        w_gt = bins[:, 1:] - bins[:, :-1]
|
|
|
|
|
loss_interlevel += torch.mean(torch.clip(w_gt - wp, min=0) ** 2 / (wp + 1e-5)) |
|
|
|
return loss_interlevel |
|
|
|
|
|
|
|
def lossfun_distortion(t, w): |
|
""" |
|
https://github.com/kakaobrain/NeRF-Factory/blob/f61bb8744a5cb4820a4d968fb3bfbed777550f4a/src/model/mipnerf360/helper.py#L142 |
|
https://github.com/google-research/multinerf/blob/b02228160d3179300c7d499dca28cb9ca3677f32/internal/stepfun.py#L266 |
|
""" |
|
ut = (t[..., 1:] + t[..., :-1]) / 2 |
|
dut = torch.abs(ut[..., :, None] - ut[..., None, :]) |
|
loss_inter = torch.sum(w * torch.sum(w[..., None, :] * dut, dim=-1), dim=-1) |
|
|
|
loss_intra = torch.sum(w**2 * (t[..., 1:] - t[..., :-1]), dim=-1) / 3 |
|
|
|
return loss_inter + loss_intra |
|
|
|
|
|
def distortion_loss(weights_list, ray_samples_list): |
|
"""From mipnerf360""" |
|
c = ray_samples_to_sdist(ray_samples_list[-1]) |
|
w = weights_list[-1][..., 0] |
|
loss = torch.mean(lossfun_distortion(c, w)) |
|
    return loss


|
def orientation_loss( |
|
weights: TensorType["bs":..., "num_samples", 1], |
|
normals: TensorType["bs":..., "num_samples", 3], |
|
viewdirs: TensorType["bs":..., 3], |
|
): |
|
"""Orientation loss proposed in Ref-NeRF. |
|
Loss that encourages that all visible normals are facing towards the camera. |
|
""" |
|
w = weights |
|
n = normals |
|
v = viewdirs |
|
    n_dot_v = (n * v[..., None, :]).sum(dim=-1)
|
return (w[..., 0] * torch.fmin(torch.zeros_like(n_dot_v), n_dot_v) ** 2).sum(dim=-1) |
|
|
|
|
|
def pred_normal_loss( |
|
weights: TensorType["bs":..., "num_samples", 1], |
|
normals: TensorType["bs":..., "num_samples", 3], |
|
pred_normals: TensorType["bs":..., "num_samples", 3], |
|
): |
|
"""Loss between normals calculated from density and normals from prediction network.""" |
|
return (weights[..., 0] * (1.0 - torch.sum(normals * pred_normals, dim=-1))).sum(dim=-1) |
|
|
|
|
|
def monosdf_normal_loss(normal_pred: torch.Tensor, normal_gt: torch.Tensor): |
|
"""normal consistency loss as monosdf |
|
|
|
Args: |
|
normal_pred (torch.Tensor): volume rendered normal |
|
normal_gt (torch.Tensor): monocular normal |
|
""" |
|
normal_gt = torch.nn.functional.normalize(normal_gt, p=2, dim=-1) |
|
normal_pred = torch.nn.functional.normalize(normal_pred, p=2, dim=-1) |
|
l1 = torch.abs(normal_pred - normal_gt).sum(dim=-1).mean() |
|
cos = (1.0 - torch.sum(normal_pred * normal_gt, dim=-1)).mean() |
|
return l1 + cos |
|
|
|
|
|
|
|
def compute_scale_and_shift(prediction, target, mask): |
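    """Solve, per image, the closed-form 2x2 least-squares system for the scale and
    shift that best align `prediction` with `target` over `mask` (MiDaS-style)."""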
|
|
|
a_00 = torch.sum(mask * prediction * prediction, (1, 2)) |
|
a_01 = torch.sum(mask * prediction, (1, 2)) |
|
a_11 = torch.sum(mask, (1, 2)) |
|
|
|
|
|
b_0 = torch.sum(mask * prediction * target, (1, 2)) |
|
b_1 = torch.sum(mask * target, (1, 2)) |
|
|
|
|
|
x_0 = torch.zeros_like(b_0) |
|
x_1 = torch.zeros_like(b_1) |
|
|
|
det = a_00 * a_11 - a_01 * a_01 |
|
valid = det.nonzero() |
|
|
|
x_0[valid] = (a_11[valid] * b_0[valid] - a_01[valid] * b_1[valid]) / det[valid] |
|
x_1[valid] = (-a_01[valid] * b_0[valid] + a_00[valid] * b_1[valid]) / det[valid] |
|
|
|
return x_0, x_1 |
|
|
|
|
|
def reduction_batch_based(image_loss, M): |
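    """Average the loss over all valid pixels in the batch (0 if none are valid)."""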
|
|
|
|
|
|
|
divisor = torch.sum(M) |
|
|
|
if divisor == 0: |
|
return 0 |
|
else: |
|
return torch.sum(image_loss) / divisor |
|
|
|
|
|
def reduction_image_based(image_loss, M): |
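    """Average each image's loss over its valid pixels, then take the batch mean."""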
|
|
|
|
|
|
|
valid = M.nonzero() |
|
|
|
image_loss[valid] = image_loss[valid] / M[valid] |
|
|
|
return torch.mean(image_loss) |
|
|
|
|
|
def mse_loss(prediction, target, mask, reduction=reduction_batch_based): |
|
M = torch.sum(mask, (1, 2)) |
|
res = prediction - target |
|
image_loss = torch.sum(mask * res * res, (1, 2)) |
|
|
|
return reduction(image_loss, 2 * M) |
|
|
|
|
|
def gradient_loss(prediction, target, mask, reduction=reduction_batch_based): |
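    """Gradient-matching term on the masked prediction-target difference (single scale)."""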
|
M = torch.sum(mask, (1, 2)) |
|
|
|
diff = prediction - target |
|
diff = torch.mul(mask, diff) |
|
|
|
grad_x = torch.abs(diff[:, :, 1:] - diff[:, :, :-1]) |
|
mask_x = torch.mul(mask[:, :, 1:], mask[:, :, :-1]) |
|
grad_x = torch.mul(mask_x, grad_x) |
|
|
|
grad_y = torch.abs(diff[:, 1:, :] - diff[:, :-1, :]) |
|
mask_y = torch.mul(mask[:, 1:, :], mask[:, :-1, :]) |
|
grad_y = torch.mul(mask_y, grad_y) |
|
|
|
image_loss = torch.sum(grad_x, (1, 2)) + torch.sum(grad_y, (1, 2)) |
|
|
|
return reduction(image_loss, M) |
|
|
|
|
|
class MiDaSMSELoss(nn.Module): |
|
def __init__(self, reduction="batch-based"): |
|
super().__init__() |
|
|
|
if reduction == "batch-based": |
|
self.__reduction = reduction_batch_based |
|
else: |
|
self.__reduction = reduction_image_based |
|
|
|
def forward(self, prediction, target, mask): |
|
return mse_loss(prediction, target, mask, reduction=self.__reduction) |
|
|
|
|
|
class GradientLoss(nn.Module): |
|
def __init__(self, scales=4, reduction="batch-based"): |
|
super().__init__() |
|
|
|
if reduction == "batch-based": |
|
self.__reduction = reduction_batch_based |
|
else: |
|
self.__reduction = reduction_image_based |
|
|
|
self.__scales = scales |
|
|
|
def forward(self, prediction, target, mask): |
|
total = 0 |
|
|
|
for scale in range(self.__scales): |
|
step = pow(2, scale) |
|
|
|
total += gradient_loss( |
|
prediction[:, ::step, ::step], |
|
target[:, ::step, ::step], |
|
mask[:, ::step, ::step], |
|
reduction=self.__reduction, |
|
) |
|
|
|
return total |
|
|
|
|
|
class ScaleAndShiftInvariantLoss(nn.Module): |
|
def __init__(self, alpha=0.5, scales=4, reduction="batch-based"): |
|
super().__init__() |
|
|
|
self.__data_loss = MiDaSMSELoss(reduction=reduction) |
|
self.__regularization_loss = GradientLoss(scales=scales, reduction=reduction) |
|
self.__alpha = alpha |
|
|
|
self.__prediction_ssi = None |
|
|
|
def forward(self, prediction, target, mask): |
|
scale, shift = compute_scale_and_shift(prediction, target, mask) |
|
self.__prediction_ssi = scale.view(-1, 1, 1) * prediction + shift.view(-1, 1, 1) |
|
|
|
total = self.__data_loss(self.__prediction_ssi, target, mask) |
|
if self.__alpha > 0: |
|
total += self.__alpha * self.__regularization_loss(self.__prediction_ssi, target, mask) |
|
|
|
return total |
|
|
|
def __get_prediction_ssi(self): |
|
return self.__prediction_ssi |
|
|
|
prediction_ssi = property(__get_prediction_ssi) |
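
# Minimal usage sketch (shapes assumed from the reductions above): `prediction` and
# `target` are (batch, height, width) depth/disparity maps and `mask` is a binary
# tensor of the same shape:
#
#   criterion = ScaleAndShiftInvariantLoss(alpha=0.5, scales=4)
#   loss = criterion(pred_depth, gt_depth, mask)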
|
|
|
|
|
|
|
|
|
|
|
|
|
class SSIM(nn.Module): |
|
"""Layer to compute the SSIM loss between a pair of images""" |
|
|
|
def __init__(self, patch_size): |
|
        super().__init__()
|
self.mu_x_pool = nn.AvgPool2d(patch_size, 1) |
|
self.mu_y_pool = nn.AvgPool2d(patch_size, 1) |
|
self.sig_x_pool = nn.AvgPool2d(patch_size, 1) |
|
self.sig_y_pool = nn.AvgPool2d(patch_size, 1) |
|
self.sig_xy_pool = nn.AvgPool2d(patch_size, 1) |
|
|
|
self.refl = nn.ReflectionPad2d(patch_size // 2) |
|
|
|
self.C1 = 0.01**2 |
|
self.C2 = 0.03**2 |
|
|
|
def forward(self, x, y): |
|
x = self.refl(x) |
|
y = self.refl(y) |
|
|
|
mu_x = self.mu_x_pool(x) |
|
mu_y = self.mu_y_pool(y) |
|
|
|
sigma_x = self.sig_x_pool(x**2) - mu_x**2 |
|
sigma_y = self.sig_y_pool(y**2) - mu_y**2 |
|
sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y |
|
|
|
SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2) |
|
SSIM_d = (mu_x**2 + mu_y**2 + self.C1) * (sigma_x + sigma_y + self.C2) |
|
|
|
return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1) |
|
|
|
|
|
|
|
class NCC(nn.Module): |
|
"""Layer to compute the normalization cross correlation (NCC) of patches""" |
|
|
|
def __init__(self, patch_size: int = 11, min_patch_variance: float = 0.01): |
|
        super().__init__()
|
self.patch_size = patch_size |
|
self.min_patch_variance = min_patch_variance |
|
|
|
def forward(self, x, y): |
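        # x, y: (batch, channels, height, width) patches; collapse channels to grayscale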
|
|
|
|
|
x = torch.mean(x, dim=1) |
|
y = torch.mean(y, dim=1) |
|
|
|
x_mean = torch.mean(x, dim=(1, 2), keepdim=True) |
|
y_mean = torch.mean(y, dim=(1, 2), keepdim=True) |
|
|
|
x_normalized = x - x_mean |
|
y_normalized = y - y_mean |
|
|
|
norm = torch.sum(x_normalized * y_normalized, dim=(1, 2)) |
|
var = torch.square(x_normalized).sum(dim=(1, 2)) * torch.square(y_normalized).sum(dim=(1, 2)) |
|
denom = torch.sqrt(var + 1e-6) |
|
|
|
ncc = norm / (denom + 1e-6) |
|
|
|
|
|
not_valid = (torch.square(x_normalized).sum(dim=(1, 2)) < self.min_patch_variance) | ( |
|
torch.square(y_normalized).sum(dim=(1, 2)) < self.min_patch_variance |
|
) |
|
ncc[not_valid] = 1.0 |
|
|
|
score = 1 - ncc.clip(-1.0, 1.0) |
|
return score[:, None, None, None] |
|
|
|
|
|
class MultiViewLoss(nn.Module): |
|
"""compute multi-view consistency loss""" |
|
|
|
def __init__(self, patch_size: int = 11, topk: int = 4, min_patch_variance: float = 0.01): |
|
        super().__init__()
|
self.patch_size = patch_size |
|
self.topk = topk |
|
self.min_patch_variance = min_patch_variance |
|
|
|
|
|
|
|
        # patch similarity is measured with NCC, despite the attribute name
        self.ssim = NCC(patch_size=patch_size, min_patch_variance=min_patch_variance)
|
|
|
self.iter = 0 |
|
|
|
def forward(self, patches: torch.Tensor, valid: torch.Tensor): |
|
"""take the mim |
|
|
|
Args: |
|
patches (torch.Tensor): _description_ |
|
valid (torch.Tensor): _description_ |
|
|
|
Returns: |
|
_type_: _description_ |
|
""" |
|
num_imgs, num_rays, _, num_channels = patches.shape |
|
|
|
if num_rays <= 0: |
|
return torch.tensor(0.0).to(patches.device) |
|
|
|
ref_patches = ( |
|
patches[:1, ...] |
|
.reshape(1, num_rays, self.patch_size, self.patch_size, num_channels) |
|
.expand(num_imgs - 1, num_rays, self.patch_size, self.patch_size, num_channels) |
|
.reshape(-1, self.patch_size, self.patch_size, num_channels) |
|
.permute(0, 3, 1, 2) |
|
) |
|
src_patches = ( |
|
patches[1:, ...] |
|
.reshape(num_imgs - 1, num_rays, self.patch_size, self.patch_size, num_channels) |
|
.reshape(-1, self.patch_size, self.patch_size, num_channels) |
|
.permute(0, 3, 1, 2) |
|
) |
|
|
|
|
|
src_patches_valid = ( |
|
valid[1:, ...] |
|
.reshape(num_imgs - 1, num_rays, self.patch_size, self.patch_size, 1) |
|
.reshape(-1, self.patch_size, self.patch_size, 1) |
|
.permute(0, 3, 1, 2) |
|
) |
|
|
|
ssim = self.ssim(ref_patches.detach(), src_patches) |
|
ssim = torch.mean(ssim, dim=(1, 2, 3)) |
|
ssim = ssim.reshape(num_imgs - 1, num_rays) |
|
|
|
|
|
ssim_valid = ( |
|
src_patches_valid.reshape(-1, self.patch_size * self.patch_size).all(dim=-1).reshape(num_imgs - 1, num_rays) |
|
) |
|
|
|
|
|
|
|
min_ssim, idx = torch.topk(ssim, k=self.topk, largest=False, dim=0, sorted=True) |
|
|
|
min_ssim_valid = ssim_valid[idx, torch.arange(num_rays)[None].expand_as(idx)] |
|
|
|
min_ssim[torch.logical_not(min_ssim_valid)] = 0.0 |
|
|
|
        if False:  # debug: dump patch-match visualizations to "vis/" (disabled)
|
|
|
|
|
import cv2 |
|
import numpy as np |
|
|
|
vis_patch_num = num_rays |
|
K = min(100, vis_patch_num) |
|
|
|
image = ( |
|
patches[:, :vis_patch_num, :, :] |
|
.reshape(-1, vis_patch_num, self.patch_size, self.patch_size, 3) |
|
.permute(1, 2, 0, 3, 4) |
|
.reshape(vis_patch_num * self.patch_size, -1, 3) |
|
) |
|
|
|
src_patches_reshaped = src_patches.reshape( |
|
num_imgs - 1, num_rays, 3, self.patch_size, self.patch_size |
|
).permute(1, 0, 3, 4, 2) |
|
idx = idx.permute(1, 0) |
|
|
|
selected_patch = ( |
|
src_patches_reshaped[torch.arange(num_rays)[:, None].expand(idx.shape), idx] |
|
.permute(0, 2, 1, 3, 4) |
|
.reshape(num_rays, self.patch_size, self.topk * self.patch_size, 3)[:vis_patch_num] |
|
.reshape(-1, self.topk * self.patch_size, 3) |
|
) |
|
|
|
|
|
src_patches_valid_reshaped = src_patches_valid.reshape( |
|
num_imgs - 1, num_rays, 1, self.patch_size, self.patch_size |
|
).permute(1, 0, 3, 4, 2) |
|
|
|
selected_patch_valid = ( |
|
src_patches_valid_reshaped[torch.arange(num_rays)[:, None].expand(idx.shape), idx] |
|
.permute(0, 2, 1, 3, 4) |
|
.reshape(num_rays, self.patch_size, self.topk * self.patch_size, 1)[:vis_patch_num] |
|
.reshape(-1, self.topk * self.patch_size, 1) |
|
) |
|
|
|
selected_patch_valid = selected_patch_valid.expand_as(selected_patch).float() |
|
|
|
|
|
image = torch.cat([selected_patch_valid, selected_patch, image], dim=1) |
|
|
|
|
|
image = image.reshape(num_rays, self.patch_size, -1, 3) |
|
|
|
_, idx2 = torch.topk( |
|
torch.sum(min_ssim, dim=0) / (min_ssim_valid.float().sum(dim=0) + 1e-6), |
|
k=K, |
|
largest=True, |
|
dim=0, |
|
sorted=True, |
|
) |
|
|
|
image = image[idx2].reshape(K * self.patch_size, -1, 3) |
|
|
|
cv2.imwrite(f"vis/{self.iter}.png", (image.detach().cpu().numpy() * 255).astype(np.uint8)[..., ::-1]) |
|
self.iter += 1 |
|
if self.iter == 9: |
|
breakpoint() |
|
|
|
        return torch.sum(min_ssim) / (min_ssim_valid.float().sum() + 1e-6)


|
r"""Implements Stochastic Structural SIMilarity(S3IM) algorithm. |
|
It is proposed in the ICCV2023 paper |
|
`S3IM: Stochastic Structural SIMilarity and Its Unreasonable Effectiveness for Neural Fields`. |
|
|
|
Arguments: |
|
s3im_kernel_size (int): kernel size in ssim's convolution(default: 4) |
|
s3im_stride (int): stride in ssim's convolution(default: 4) |
|
s3im_repeat_time (int): repeat time in re-shuffle virtual patch(default: 10) |
|
s3im_patch_height (height): height of virtual patch(default: 64) |
|
""" |
|
|
|
class S3IM(torch.nn.Module): |
|
def __init__(self, s3im_kernel_size = 4, s3im_stride=4, s3im_repeat_time=10, s3im_patch_height=64, size_average = True): |
|
super(S3IM, self).__init__() |
|
self.s3im_kernel_size = s3im_kernel_size |
|
self.s3im_stride = s3im_stride |
|
self.s3im_repeat_time = s3im_repeat_time |
|
self.s3im_patch_height = s3im_patch_height |
|
self.size_average = size_average |
|
self.channel = 1 |
|
self.s3im_kernel = self.create_kernel(s3im_kernel_size, self.channel) |
|
|
|
|
|
    def gaussian(self, s3im_kernel_size, sigma):
        gauss = torch.tensor([exp(-((x - s3im_kernel_size // 2) ** 2) / float(2 * sigma**2)) for x in range(s3im_kernel_size)])
        return gauss / gauss.sum()
|
|
|
    def create_kernel(self, s3im_kernel_size, channel):
        _1D_window = self.gaussian(s3im_kernel_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        # torch.autograd.Variable has been a no-op since PyTorch 0.4, so the wrapper is dropped
        s3im_kernel = _2D_window.expand(channel, 1, s3im_kernel_size, s3im_kernel_size).contiguous()
        return s3im_kernel
|
|
|
    def _ssim(self, img1, img2, s3im_kernel, s3im_kernel_size, channel, size_average=True, s3im_stride=None):
        pad = (s3im_kernel_size - 1) // 2
        mu1 = F.conv2d(img1, s3im_kernel, padding=pad, groups=channel, stride=s3im_stride)
        mu2 = F.conv2d(img2, s3im_kernel, padding=pad, groups=channel, stride=s3im_stride)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(img1 * img1, s3im_kernel, padding=pad, groups=channel, stride=s3im_stride) - mu1_sq
        sigma2_sq = F.conv2d(img2 * img2, s3im_kernel, padding=pad, groups=channel, stride=s3im_stride) - mu2_sq
        sigma12 = F.conv2d(img1 * img2, s3im_kernel, padding=pad, groups=channel, stride=s3im_stride) - mu1_mu2

        C1 = 0.01**2
        C2 = 0.03**2

        ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))

        if size_average:
            return ssim_map.mean()
        else:
            return ssim_map.mean(1).mean(1).mean(1)
|
|
|
def ssim_loss(self, img1, img2): |
|
""" |
|
img1, img2: torch.Tensor([b,c,h,w]) |
|
""" |
|
(_, channel, _, _) = img1.size() |
|
|
|
if channel == self.channel and self.s3im_kernel.data.type() == img1.data.type(): |
|
s3im_kernel = self.s3im_kernel |
|
else: |
|
s3im_kernel = self.create_kernel(self.s3im_kernel_size, channel) |
|
|
|
if img1.is_cuda: |
|
s3im_kernel = s3im_kernel.cuda(img1.get_device()) |
|
s3im_kernel = s3im_kernel.type_as(img1) |
|
|
|
self.s3im_kernel = s3im_kernel |
|
self.channel = channel |
|
|
|
|
|
return self._ssim(img1, img2, s3im_kernel, self.s3im_kernel_size, channel, self.size_average, s3im_stride=self.s3im_stride) |
|
|
|
def forward(self, src_vec, tar_vec): |
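        """src_vec, tar_vec: flat per-ray color tensors of shape (num_rays, 3)."""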
|
loss = 0.0 |
|
index_list = [] |
|
for i in range(self.s3im_repeat_time): |
|
if i == 0: |
|
tmp_index = torch.arange(len(tar_vec)) |
|
index_list.append(tmp_index) |
|
else: |
|
ran_idx = torch.randperm(len(tar_vec)) |
|
index_list.append(ran_idx) |
|
res_index = torch.cat(index_list) |
|
tar_all = tar_vec[res_index] |
|
src_all = src_vec[res_index] |
|
tar_patch = tar_all.permute(1, 0).reshape(1, 3, self.s3im_patch_height, -1) |
|
src_patch = src_all.permute(1, 0).reshape(1, 3, self.s3im_patch_height, -1) |
|
loss = (1 - self.ssim_loss(src_patch, tar_patch)) |
|
return loss |
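
# Minimal usage sketch (assumed tensors): `rendered_rgb` and `gt_rgb` are flat
# (num_rays, 3) color batches, with s3im_repeat_time * num_rays divisible by
# s3im_patch_height:
#
#   s3im = S3IM(s3im_patch_height=64)
#   loss_s3im = s3im(rendered_rgb, gt_rgb)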
|
|
|
|