import torch
import torch.nn as nn
from unik3d.utils.constants import VERBOSE
from unik3d.utils.misc import profile_method
from .utils import (FNS, REGRESSION_DICT, masked_mean, masked_mean_var,
masked_quantile)


class SILog(nn.Module):
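    """Scale-invariant logarithmic (SILog) loss.

    Computes the masked variance of the per-pixel difference between the
    ``input_fn``-transformed prediction and target, optionally filters
    outliers with a quantile threshold, and can add a robust regression
    term weighted by ``integrated``. The result is passed through
    ``output_fn`` (``sqrt`` by default).
    """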
def __init__(
self,
weight: float,
input_fn: str = "linear",
output_fn: str = "sqrt",
fn: str = "l1",
        integrated: float = 0.0,
        dims: tuple[int, ...] = (-3, -2, -1),
quantile: float = 0.0,
alpha: float = 1.0,
gamma: float = 1.0,
eps: float = 1e-5,
):
super().__init__()
self.name: str = self.__class__.__name__
self.weight: float = weight
self.dims = dims
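        # Resolve the element-wise input/output transforms and the robust
        # regression penalty from the shared registries in .utils.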
self.input_fn = FNS[input_fn]
self.output_fn = FNS[output_fn]
self.fn = REGRESSION_DICT[fn]
self.eps: float = eps
self.integrated = integrated
self.quantile = quantile
self.alpha = alpha
self.gamma = gamma
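
    # Profiling hook; autocast is disabled so the loss is always computed in
    # float32, presumably for numerical stability of the variance estimate.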
@profile_method(verbose=VERBOSE)
@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
def forward(
self,
input: torch.Tensor,
target: torch.Tensor,
mask: torch.Tensor | None = None,
si: torch.Tensor | None = None,
quality: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
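        """Compute the loss.

        Args:
            input: predicted map, typically of shape (B, C, H, W).
            target: ground-truth map, same shape as ``input``.
            mask: boolean validity mask; despite the Optional annotation it
                is expected to be provided.
            si: per-sample flags marking scale-invariant supervision; also
                expected to be provided.
            quality: optional per-sample quality levels (1 or 2) used to
                modulate the quantile-based outlier filtering.
        """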
mask = mask.bool()
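        # For samples flagged as scale-invariant, rescale the prediction by the
        # ratio of median target to median (detached) prediction over pixels
        # with positive ground truth.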
if si.any():
rescale = torch.stack(
[x[m > 0].median() for x, m in zip(target, target)]
) / torch.stack([x[m > 0].detach().median() for x, m in zip(input, target)])
if rescale.isnan().any():
print(
"NaN in rescale", rescale.isnan().squeeze(), mask.sum(dim=[1, 2, 3])
)
rescale = torch.nan_to_num(rescale, nan=1.0)
input = (1 - si.int()).view(-1, 1, 1, 1) * input + (
rescale * si.int()
).view(-1, 1, 1, 1) * input
error = self.input_fn(input.float()) - self.input_fn(target.float())
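        # Quantile-based outlier rejection: drop the largest absolute errors,
        # with the cut scaled by the sample's quality level when quality
        # labels are given.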
if quality is not None:
for quality_level in [1, 2]:
current_quality = quality == quality_level
if current_quality.sum() > 0:
error_qtl = error[current_quality].detach().abs()
mask_qtl = error_qtl < masked_quantile(
error_qtl,
mask[current_quality],
dims=[1, 2, 3],
q=1 - self.quantile * quality_level,
).view(-1, 1, 1, 1)
mask[current_quality] = mask[current_quality] & mask_qtl
else:
error_qtl = error.detach().abs()
mask = mask & (
error_qtl
< masked_quantile(
error_qtl, mask, dims=[1, 2, 3], q=1 - self.quantile
).view(-1, 1, 1, 1)
)
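        # Masked mean/variance of the error over the reduction dims; the
        # variance is the core SILog term (the mean is discarded).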
mean_error, var_error = masked_mean_var(
data=error, mask=mask, dim=self.dims, keepdim=False
)
if var_error.ndim > 1:
var_error = var_error.mean(dim=-1)
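        # Optionally blend in a robust regression penalty (self.fn with
        # alpha/gamma), weighted by the `integrated` coefficient.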
if self.integrated > 0.0:
scale_error = masked_mean(
self.fn(error, alpha=self.alpha, gamma=self.gamma),
mask=mask,
dim=self.dims,
).reshape(-1)
var_error = var_error + self.integrated * scale_error
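        # Final transform of the variance (sqrt by default); warn if NaNs appear.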
out_loss = self.output_fn(var_error)
if out_loss.isnan().any():
print(
"NaN in SILog variance, input, target, mask, target>0, error",
var_error.isnan().squeeze(),
input[mask].isnan().any(),
target[mask].isnan().any(),
mask.any(dim=[1, 2, 3]),
(target > 0.0).any(dim=[1, 2, 3]),
error[mask].isnan().any(),
)
return out_loss

    @classmethod
def build(cls, config):
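        """Instantiate the loss from a configuration dictionary."""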
obj = cls(
weight=config["weight"],
dims=config["dims"],
output_fn=config["output_fn"],
input_fn=config["input_fn"],
fn=config["fn"],
alpha=config["alpha"],
gamma=config["gamma"],
            integrated=config.get("integrated", 0.0),
quantile=config["quantile"],
)
return obj
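

# Illustrative usage (a sketch, not part of the original module). The config
# keys mirror SILog.build above, and the "linear"/"sqrt"/"l1" names are the
# defaults referenced in __init__:
#
#   loss = SILog.build({
#       "weight": 1.0, "dims": (-3, -2, -1), "output_fn": "sqrt",
#       "input_fn": "linear", "fn": "l1", "alpha": 1.0, "gamma": 1.0,
#       "quantile": 0.0,
#   })
#   out = loss(input=pred, target=gt, mask=valid, si=si_flags)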