import torch
import torch.nn as nn

from unik3d.utils.constants import VERBOSE
from unik3d.utils.misc import profile_method

from .utils import (FNS, REGRESSION_DICT, masked_mean, masked_mean_var,
                    masked_quantile)


class SILog(nn.Module):
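    """Scale-invariant logarithmic (SILog) loss for dense regression.

    The loss is the per-sample variance of the error between `input_fn(input)`
    and `input_fn(target)` over masked pixels, optionally augmented with a
    robust regression term (weighted by `integrated`) and mapped through
    `output_fn`.
    """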
    def __init__(
        self,
        weight: float,
        input_fn: str = "linear",
        output_fn: str = "sqrt",
        fn: str = "l1",
        integrated: float = 0.0,
        dims: tuple[int, ...] = (-3, -2, -1),
        quantile: float = 0.0,
        alpha: float = 1.0,
        gamma: float = 1.0,
        eps: float = 1e-5,
    ):
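        """Valid `input_fn`/`output_fn` keys come from FNS and `fn` keys from
        REGRESSION_DICT (both imported from .utils); `quantile` is the fraction
        of highest-error pixels trimmed per sample.
        """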
        super().__init__()
        self.name: str = self.__class__.__name__
        self.weight: float = weight

        self.dims = dims
        self.input_fn = FNS[input_fn]
        self.output_fn = FNS[output_fn]
        self.fn = REGRESSION_DICT[fn]
        self.eps: float = eps
        self.integrated = integrated
        self.quantile = quantile
        self.alpha = alpha
        self.gamma = gamma

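    # Force fp32 execution below: the variance of the error is numerically
    # fragile under autocast half precision.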
    @profile_method(verbose=VERBOSE)
    @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        mask: torch.Tensor,
        si: torch.Tensor,
        quality: torch.Tensor | None = None,
        **kwargs,
    ) -> torch.Tensor:
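        """Compute the per-sample SILog loss.

        Args:
            input: prediction, shape (B, C, H, W).
            target: ground truth, same shape as `input`.
            mask: boolean validity mask over pixels.
            si: per-sample boolean flags; flagged samples are median-rescaled
                to the target before the error is computed.
            quality: optional per-sample quality level (1 or 2) steering the
                quantile-based outlier trimming.
        """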
        mask = mask.bool()

        # Scale-invariant samples: rescale the prediction so that its median
        # over the valid pixels matches the target's median.
        if si.any():
            rescale = torch.stack(
                [x[m > 0].median() for x, m in zip(target, mask)]
            ) / torch.stack([x[m > 0].detach().median() for x, m in zip(input, mask)])
            if rescale.isnan().any():
                print(
                    "NaN in rescale", rescale.isnan().squeeze(), mask.sum(dim=[1, 2, 3])
                )
                rescale = torch.nan_to_num(rescale, nan=1.0)
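            # Blend: keep input unchanged where si is 0, use the rescaled
            # input where si is 1.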
            input = (1 - si.int()).view(-1, 1, 1, 1) * input + (
                rescale * si.int()
            ).view(-1, 1, 1, 1) * input

        error = self.input_fn(input.float()) - self.input_fn(target.float())
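        # Quantile-based outlier trimming: drop the highest-|error| pixels per
        # sample; higher quality levels are trimmed more aggressively
        # (q = 1 - quantile * quality_level).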
        if quality is not None:
            for quality_level in [1, 2]:
                current_quality = quality == quality_level
                if current_quality.sum() > 0:
                    error_qtl = error[current_quality].detach().abs()
                    mask_qtl = error_qtl < masked_quantile(
                        error_qtl,
                        mask[current_quality],
                        dims=[1, 2, 3],
                        q=1 - self.quantile * quality_level,
                    ).view(-1, 1, 1, 1)
                    mask[current_quality] = mask[current_quality] & mask_qtl
        else:
            error_qtl = error.detach().abs()
            mask = mask & (
                error_qtl
                < masked_quantile(
                    error_qtl, mask, dims=[1, 2, 3], q=1 - self.quantile
                ).view(-1, 1, 1, 1)
            )

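        # Scale-invariant core: the loss is driven by the per-sample variance
        # of the error over the surviving masked pixels.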
        mean_error, var_error = masked_mean_var(
            data=error, mask=mask, dim=self.dims, keepdim=False
        )
        if var_error.ndim > 1:
            var_error = var_error.mean(dim=-1)

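        # Optionally mix in a robust regression term on the raw error,
        # weighted by the `integrated` coefficient.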
        if self.integrated > 0.0:
            scale_error = masked_mean(
                self.fn(error, alpha=self.alpha, gamma=self.gamma),
                mask=mask,
                dim=self.dims,
            ).reshape(-1)
            var_error = var_error + self.integrated * scale_error

        out_loss = self.output_fn(var_error)
        if out_loss.isnan().any():
            print(
                "NaN in SILog variance, input, target, mask, target>0, error",
                var_error.isnan().squeeze(),
                input[mask].isnan().any(),
                target[mask].isnan().any(),
                mask.any(dim=[1, 2, 3]),
                (target > 0.0).any(dim=[1, 2, 3]),
                error[mask].isnan().any(),
            )
        return out_loss

    @classmethod
    def build(cls, config):
        obj = cls(
            weight=config["weight"],
            dims=config["dims"],
            output_fn=config["output_fn"],
            input_fn=config["input_fn"],
            fn=config["fn"],
            alpha=config["alpha"],
            gamma=config["gamma"],
            integrated=config.get("integrated", 0.0),
            quantile=config["quantile"],
        )
        return obj
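

# Minimal usage sketch (hypothetical values, not part of the training code);
# it assumes the FNS and REGRESSION_DICT keys used in the defaults above
# ("linear", "sqrt", "l1") exist in .utils.
if __name__ == "__main__":
    config = {
        "weight": 1.0,
        "dims": (-3, -2, -1),
        "input_fn": "linear",
        "output_fn": "sqrt",
        "fn": "l1",
        "alpha": 1.0,
        "gamma": 1.0,
        "quantile": 0.1,
    }
    criterion = SILog.build(config)
    pred = torch.rand(2, 1, 8, 8) + 0.1  # fake prediction, batch of 2
    gt = torch.rand(2, 1, 8, 8) + 0.1  # fake ground truth
    valid = torch.ones_like(gt, dtype=torch.bool)  # all pixels valid
    si = torch.zeros(2, dtype=torch.bool)  # no scale-invariant samples
    print(criterion(pred, gt, mask=valid, si=si))  # per-sample loss, shape (2,)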