Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
from unik3d.utils.constants import VERBOSE | |
from unik3d.utils.misc import profile_method | |
from .utils import FNS, REGRESSION_DICT, masked_mean, masked_quantile | |
class Scale(nn.Module): | |
def __init__( | |
self, | |
weight: float, | |
output_fn: str = "sqrt", | |
input_fn: str = "disp", | |
fn: str = "l1", | |
quantile: float = 0.0, | |
gamma: float = 1.0, | |
alpha: float = 1.0, | |
eps: float = 1e-5, | |
): | |
super().__init__() | |
self.name: str = self.__class__.__name__ | |
self.weight: float = weight | |
self.dims = [-2, -1] | |
self.output_fn = FNS[output_fn] | |
self.input_fn = FNS[input_fn] | |
self.fn = REGRESSION_DICT[fn] | |
self.gamma = gamma | |
self.alpha = alpha | |
self.quantile = quantile | |
self.eps = eps | |
def forward( | |
self, | |
input: torch.Tensor, | |
target: torch.Tensor, | |
mask: torch.Tensor, | |
quality: torch.Tensor | None = None, | |
**kwargs, | |
) -> torch.Tensor: | |
mask = mask.bool() | |
input = self.input_fn(input.float()) | |
target = self.input_fn(target.float()) | |
error = self.fn(target - input, alpha=self.alpha, gamma=self.gamma) | |
if self.quantile > 0.0: | |
if quality is not None: | |
for quality_level in [1, 2]: | |
current_quality = quality == quality_level | |
if current_quality.sum() > 0: | |
error_qtl = error[current_quality].detach().abs() | |
mask_qtl = error_qtl < masked_quantile( | |
error_qtl, | |
mask[current_quality], | |
dims=[1, 2, 3], | |
q=1 - self.quantile * quality_level, | |
).view(-1, 1, 1, 1) | |
mask[current_quality] = mask[current_quality] & mask_qtl | |
else: | |
error_qtl = error.detach().abs() | |
mask = mask & ( | |
error_qtl | |
< masked_quantile( | |
error_qtl, mask, dims=[1, 2, 3], q=1 - self.quantile | |
).view(-1, 1, 1, 1) | |
) | |
error_image = masked_mean(data=error, mask=mask, dim=self.dims).squeeze(1, 2, 3) | |
error_image = self.output_fn(error_image) | |
return error_image | |
def build(cls, config): | |
obj = cls( | |
weight=config["weight"], | |
input_fn=config["input_fn"], | |
fn=config["fn"], | |
output_fn=config["output_fn"], | |
gamma=config["gamma"], | |
alpha=config["alpha"], | |
quantile=config.get("quantile", 0.1), | |
) | |
return obj | |