from math import log, pi, prod
from typing import List, Tuple
import torch
FNS = {
"sqrto": lambda x: torch.sqrt(x + 1),
"sqrt": lambda x: torch.sqrt(x + 1e-4),
"log": lambda x: torch.log(x + 1e-4),
"log1": lambda x: torch.log(x + 1),
    # smooth transition between log and reciprocal behavior:
    # as x -> 0, log(1 + 50 / x) ~ log(50 / x); as x -> inf,
    # log(1 + 50 / x) ~ 50 / x plus higher-order terms
"log1i": lambda x: torch.log(1 + 50 / (1e-4 + x)),
"log10": lambda x: torch.log10(1e-4 + x),
"log2": lambda x: torch.log2(1e-4 + x),
"linear": lambda x: x,
"square": torch.square,
"disp": lambda x: 1 / (x + 1e-4),
"disp1": lambda x: 1 / (1 + x),
}
# Approximate inverses of the transforms in FNS; the small epsilon offsets are
# ignored, and only a subset of the transforms is covered.
FNS_INV = {
"sqrt": torch.square,
"log": torch.exp,
"log1": lambda x: torch.exp(x) - 1,
"linear": lambda x: x,
"square": torch.sqrt,
"disp": lambda x: 1 / x,
}
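# Editorial sketch: the inverses undo the forward maps only approximately,
# since the epsilon offsets in FNS are ignored, e.g.
#   x = torch.tensor([1.0, 4.0])
#   FNS_INV["log"](FNS["log"](x))    # == x + 1e-4
#   FNS_INV["sqrt"](FNS["sqrt"](x))  # == x + 1e-4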
def masked_mean_var(
    data: torch.Tensor, mask: torch.Tensor, dim: List[int], keepdim: bool = True
):
    """Mean and variance of `data` over `dim`, counting only entries where
    `mask` is True; NaN entries are excluded from the statistics."""
    if mask is None:
        return data.mean(dim=dim, keepdim=keepdim), data.var(dim=dim, keepdim=keepdim)
    # drop pixels that are NaN in any channel (channels assumed on dim 1)
    mask = (mask & ~data.isnan().any(dim=1, keepdim=True)).float()
    data = torch.nan_to_num(data, nan=0.0)
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
mask_sum, min=1.0
)
mask_var = torch.sum(
mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
) / torch.clamp(mask_sum, min=1.0)
if not keepdim:
mask_mean, mask_var = mask_mean.squeeze(dim), mask_var.squeeze(dim)
return mask_mean, mask_var
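# Illustrative sketch (assumed shapes, not part of the original API): masked
# statistics of a (B, C, H, W) map with a per-pixel validity mask.
def _example_masked_mean_var():
    depth = torch.rand(2, 1, 8, 8) + 0.5
    valid = depth > 0.6  # (B, 1, H, W) boolean mask
    mean, var = masked_mean_var(depth, valid, dim=[-2, -1])
    # mean and var have shape (2, 1, 1, 1): one statistic per sample/channel
    return mean, var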
def masked_mean(data: torch.Tensor, mask: torch.Tensor | None, dim: List[int]):
    """Mean of `data` over `dim`, counting only entries where `mask` is True."""
    if mask is None:
        return data.mean(dim=dim, keepdim=True)
mask = mask.float()
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
mask_mean = torch.sum(
torch.nan_to_num(data, nan=0.0) * mask, dim=dim, keepdim=True
) / mask_sum.clamp(min=1.0)
return mask_mean
def masked_quantile(
data: torch.Tensor, mask: torch.Tensor | None, dims: List[int], q: float
):
"""
    Compute the quantile of the data along the specified dimensions, using only entries where the mask is 1.
Args:
data (torch.Tensor): The input data tensor.
mask (torch.Tensor): The mask tensor with the same shape as data, containing 1s where data should be considered.
dims (list of int): The dimensions to compute the quantile over.
q (float): The quantile to compute, must be between 0 and 1.
Returns:
torch.Tensor: The quantile computed over the specified dimensions, ignoring masked values.
"""
masked_data = data * mask if mask is not None else data
# Get a list of all dimensions
all_dims = list(range(masked_data.dim()))
    # Map negative dimensions to their positive indices
dims = [d % masked_data.dim() for d in dims]
# Find the dimensions to keep (not included in the `dims` list)
keep_dims = [d for d in all_dims if d not in dims]
# Permute dimensions to bring `dims` to the front
permute_order = dims + keep_dims
permuted_data = masked_data.permute(permute_order)
# Reshape into 2D: (-1, remaining_dims)
collapsed_shape = (
-1,
prod([permuted_data.size(d) for d in range(len(dims), permuted_data.dim())]),
)
reshaped_data = permuted_data.reshape(collapsed_shape)
if mask is None:
return torch.quantile(reshaped_data, q, dim=0)
permuted_mask = mask.permute(permute_order)
reshaped_mask = permuted_mask.reshape(collapsed_shape)
    # Quantile per column along the first dimension, using only valid entries
    quantiles = []
    for i in range(reshaped_data.shape[1]):
        valid_data = reshaped_data[:, i][reshaped_mask[:, i]]
        if valid_data.numel() == 0:
            # no valid entries: fall back to just below the column minimum
            quantiles.append(reshaped_data[:, i].min() * 0.99)
        else:
            quantiles.append(torch.quantile(valid_data, q, dim=0))
# Stack back into a tensor with reduced dimensions
quantiles = torch.stack(quantiles)
quantiles = quantiles.reshape(
[permuted_data.size(d) for d in range(len(dims), permuted_data.dim())]
)
return quantiles
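# Illustrative sketch (assumed shapes, not part of the original API): the
# dimensions in `dims` are collapsed by the quantile, the rest are kept.
def _example_masked_quantile():
    depth = torch.rand(2, 1, 8, 8)
    valid = depth > 0.1
    # 5th percentile over channel/height/width, one value per batch element
    return masked_quantile(depth, valid, dims=[1, 2, 3], q=0.05)  # shape (2,)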
def masked_median(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
    """Median of the valid entries of `data`. Note that boolean indexing with a
    full-shape mask flattens all dimensions, so a single global median is returned."""
    ndim = data.ndim
    data = data.flatten(ndim - len(dim))
    mask = mask.flatten(ndim - len(dim))
    mask_median = torch.median(data[..., mask], dim=-1).values
    return mask_median
def masked_median_mad(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
    """Median and mean absolute deviation (MAD) of the valid entries of `data`;
    as in `masked_median`, the median is computed globally over all valid entries."""
    ndim = data.ndim
data = data.flatten(ndim - len(dim))
mask = mask.flatten(ndim - len(dim))
mask_median = torch.median(data[mask], dim=-1, keepdim=True).values
mask_mad = masked_mean((data - mask_median).abs(), mask, dim=(-1,))
return mask_median, mask_mad
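# Editorial note: median and MAD are robust counterparts of mean and std; a
# typical (hypothetical) use is outlier-resistant normalization:
#   med, mad = masked_median_mad(depth, valid, dim=[-2, -1])
#   depth_n = (depth - med) / mad.clamp(min=1e-6)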
def masked_weighted_mean_var(
data: torch.Tensor, mask: torch.Tensor, weights: torch.Tensor, dim: Tuple[int, ...]
):
if mask is None:
return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
mask = mask.float()
mask_mean = torch.sum(data * mask * weights, dim=dim, keepdim=True) / torch.sum(
mask * weights, dim=dim, keepdim=True
).clamp(min=1.0)
    # reliability-weights correction: V1 = sum_i w_i, V2 = sum_i w_i**2
    denom = torch.sum(weights * mask, dim=dim, keepdim=True).square() - torch.sum(
        (mask * weights).square(), dim=dim, keepdim=True
    )
    # correction factor V1 / (V1**2 - V2); for w_i = 1 this reduces to
    # N / (N**2 - N) = 1 / (N - 1), the usual unbiased variance estimator
    correction_factor = torch.sum(mask * weights, dim=dim, keepdim=True) / denom.clamp(
        min=1.0
    )
mask_var = correction_factor * torch.sum(
weights * mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
)
return mask_mean, mask_var
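# Editorial note: this is the standard unbiased weighted variance with
# reliability weights,
#   var = V1 / (V1**2 - V2) * sum_i w_i * (x_i - mean)**2,
# where V1 = sum_i w_i and V2 = sum_i w_i**2; for w_i = 1 it reduces to the
# familiar 1/(N-1) Bessel correction.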
def stable_masked_mean_var(
    input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, dim: list[int]
):
    """Masked mean/var of input and target, recomputed after restricting the
    mask to points within a 95% confidence interval of both distributions."""
    input_detach = input.detach()
input_mean, input_var = masked_mean_var(input_detach, mask=mask, dim=dim)
target_mean, target_var = masked_mean_var(target, mask=mask, dim=dim)
input_std = (input_var).clip(min=1e-6).sqrt()
target_std = (target_var).clip(min=1e-6).sqrt()
stable_points_input = torch.logical_and(
input_detach > input_mean - 1.96 * input_std,
input_detach < input_mean + 1.96 * input_std,
)
stable_points_target = torch.logical_and(
target > target_mean - 1.96 * target_std,
target < target_mean + 1.96 * target_std,
)
stable_mask = stable_points_target & stable_points_input & mask
input_mean, input_var = masked_mean_var(input, mask=stable_mask, dim=dim)
target_mean, target_var = masked_mean_var(target, mask=stable_mask, dim=dim)
return input_mean, input_var, target_mean, target_var, stable_mask
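# Editorial note: 1.96 standard deviations correspond to a two-sided 95%
# interval under a Gaussian assumption, so the second pass estimates the
# statistics on a trimmed, outlier-resistant subset of the valid pixels.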
def ssi(
input: torch.Tensor,
target: torch.Tensor,
mask: torch.Tensor,
dim: list[int],
*args,
**kwargs,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# recalculate mask with points in 95% confidence interval
input_mean, input_var, target_mean, target_var, stable_mask = (
stable_masked_mean_var(input, target, mask, dim)
)
    if input_var.isnan().any():
        print("Warning: input variance is NaN")
    if input_var.isinf().any():
        print("Warning: input variance is inf")
    if input_mean.isnan().any():
        print("Warning: input mean is NaN")
    if input_mean.isinf().any():
        print("Warning: input mean is inf")
target_normalized = (target - target_mean) / FNS["sqrt"](target_var)
input_normalized = (input - input_mean) / FNS["sqrt"](input_var)
return input_normalized, target_normalized, stable_mask
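# Illustrative sketch (assumed usage, not the original training code): a
# scale-shift-invariant loss penalizes the residual between the two
# normalized maps, averaged over the stable mask.
def _example_ssi_loss():
    pred = torch.rand(2, 1, 8, 8, requires_grad=True)
    gt = torch.rand(2, 1, 8, 8)
    valid = gt > 0.05
    pred_n, gt_n, stable = ssi(pred, gt, valid, dim=[-2, -1])
    residual = charbonnier(pred_n - gt_n, gamma=1.0)  # defined below
    return masked_mean(residual, stable, dim=[-2, -1]).mean()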
def ssi_nd(
input: torch.Tensor,
target: torch.Tensor,
mask: torch.Tensor,
dim: list[int],
input_info: torch.Tensor,
target_info: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
input_mean, input_var, target_mean, target_var, stable_mask = (
stable_masked_mean_var(input_info, target_info, mask, dim)
)
    if input_var.isnan().any():
        print("Warning: input variance is NaN")
    if input_var.isinf().any():
        print("Warning: input variance is inf")
    if input_mean.isnan().any():
        print("Warning: input mean is NaN")
    if input_mean.isinf().any():
        print("Warning: input mean is inf")
target_normalized = (target - target_mean) / FNS["sqrt"](target_var)
input_normalized = (input - input_mean) / FNS["sqrt"](input_var)
return input_normalized, target_normalized, stable_mask
def stable_ssi(
    input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, dim: list[int]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Like `ssi`, but without the confidence-interval trimming: plain masked
    statistics are used, with the variance clamped for numerical stability."""
    input_mean, input_var = masked_mean_var(input, mask=mask, dim=dim)
target_mean, target_var = masked_mean_var(target, mask=mask, dim=dim)
target_normalized = (target - target_mean) / torch.sqrt(target_var.clamp(min=1e-6))
input_normalized = (input - input_mean) / torch.sqrt(input_var.clamp(min=1e-6))
return input_normalized, target_normalized, mask
def ind2sub(idx, cols):
    """Flat index -> (row, col) for a row-major grid with `cols` columns."""
    r = idx // cols
    c = idx % cols
    return r, c
def sub2ind(r, c, cols):
    """(row, col) -> flat index for a row-major grid with `cols` columns."""
    idx = r * cols + c
    return idx
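# Example: with 5 columns, flat index 7 maps to (row 1, col 2) and back:
#   ind2sub(7, 5) == (1, 2) and sub2ind(1, 2, 5) == 7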
def l2(input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs) -> torch.Tensor:
return (input_tensor / gamma) ** 2
def l1(input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs) -> torch.Tensor:
    # gamma is accepted only for interface compatibility with the other penalties
    return torch.abs(input_tensor)
def charbonnier(
    input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs
) -> torch.Tensor:
    # gamma * (sqrt((x/gamma)^2 + 1) - 1): zero at x = 0 and identical to
    # robust_loss with alpha = 1 (parentheses fixed; the gradient is unchanged)
    return gamma * (torch.sqrt(torch.square(input_tensor / gamma) + 1) - 1)
def cauchy(
    input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs
) -> torch.Tensor:
    # Cauchy negative log-likelihood, with the log term scaled by gamma
    return gamma * torch.log(torch.square(input_tensor / gamma) + 1) + log(gamma * pi)
def geman_mcclure(
input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs
) -> torch.Tensor:
return gamma * torch.square(input_tensor) / (torch.square(input_tensor) + gamma)
def robust_loss(
    input_tensor: torch.Tensor, alpha: float, gamma: float = 1.0, *args, **kwargs
) -> torch.Tensor:
    # Barron's general robust loss; requires alpha not in {0, 2}
    coeff = abs(alpha - 2) / alpha
    power = torch.square(input_tensor / gamma) / abs(alpha - 2) + 1
    # multiplied by gamma to keep the gradient magnitude invariant w.r.t. gamma
    return gamma * coeff * (torch.pow(power, alpha / 2) - 1)
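# Editorial note: robust_loss is Barron's general robust loss (scaled by
# gamma); charbonnier recovers the alpha = 1 case exactly, l2 matches
# alpha = 2 up to a constant factor, and cauchy is the (scaled) Cauchy
# negative log-likelihood, closely related to the alpha -> 0 limit.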
REGRESSION_DICT = {
"l2": l2,
"l1": l1,
"cauchy": cauchy,
"charbonnier": charbonnier,
"geman_mcclure": geman_mcclure,
"robust_loss": robust_loss,
}
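# Illustrative sketch: the dict lets the penalty be selected by name, e.g.
# from a config; "robust_loss" additionally expects an `alpha` argument.
#   loss_fn = REGRESSION_DICT["charbonnier"]
#   loss = loss_fn(pred_n - gt_n, gamma=1.0)
#   loss = REGRESSION_DICT["robust_loss"](pred_n - gt_n, alpha=1.5, gamma=1.0)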