Spaces:
Sleeping
Sleeping
from einops.einops import rearrange | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from dkm.utils.utils import warp_kpts | |
class DepthRegressionLoss(nn.Module):
    """Multi-scale dense matching loss supervised by depth-based warps.

    For every scale in ``dense_corresps`` the predicted dense flow is compared
    with the location obtained by warping query pixel centres into the support
    image via ``warp_kpts`` (depths, relative pose ``T_1to2`` and intrinsics
    ``K1``/``K2`` taken from ``batch``). The per-scale loss is::

        ce_weight * BCE(certainty_logits, valid_mask) + mean(match_error[valid])

    optionally divided by ``scale`` when ``scale_normalize`` is set. The
    per-scale losses are summed.
    """

    def __init__(
        self,
        robust=True,
        center_coords=False,
        scale_normalize=False,
        ce_weight=0.01,
        local_loss=True,
        local_dist=4.0,
        local_largest_scale=8,
    ):
        super().__init__()
        # NOTE(review): `robust` and `center_coords` are stored but never read
        # in this class; presumably kept for config compatibility.
        self.robust = robust
        self.center_coords = center_coords
        # When set, each scale's loss is divided by the scale value.
        self.scale_normalize = scale_normalize
        # Weight of the certainty (BCE) term relative to the match-error term.
        self.ce_weight = ce_weight
        # Restrict fine-scale supervision to pixels whose coarser-scale match
        # error was already small (see `_restrict_to_local`).
        self.local_loss = local_loss
        # Local radius: multiplied by the scale and mapped to normalized
        # coordinates with a factor 2 / 512, so presumably pixels of a
        # 512-pixel extent per unit scale -- TODO confirm.
        self.local_dist = local_dist
        # The local restriction only applies to scales <= this value.
        self.local_largest_scale = local_largest_scale

    @staticmethod
    def _pixel_centers(b, h, w, device):
        """Return normalized ``(x, y)`` pixel-centre coordinates of an ``h x w`` grid.

        Centres lie at ``-1 + 1/n, ..., 1 - 1/n`` along each axis, in row-major
        order (x varies fastest), repeated for each of the ``b`` batch items.

        Returns:
            Float tensor of shape ``(b, h * w, 2)``.
        """
        ys = torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device)
        xs = torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device)
        # Broadcast instead of torch.meshgrid: avoids the `indexing`
        # deprecation warning and the unused batch coordinate grid.
        grid = torch.stack((xs[None, :].expand(h, w), ys[:, None].expand(h, w)), dim=-1)
        return grid.reshape(1, h * w, 2).repeat(b, 1, 1)

    def geometric_dist(self, depth1, depth2, T_1to2, K1, K2, dense_matches, scale):
        """Per-pixel distance between predicted matches and depth-warped matches.

        Args:
            depth1, depth2: query / support depth maps passed to ``warp_kpts``.
            T_1to2: relative pose passed to ``warp_kpts``.
            K1, K2: intrinsics passed to ``warp_kpts``.
            dense_matches: predicted match coordinates ``(b, h1, w1, 2)``
                (the last dim must be 2 for the reshape of the warped points).
            scale: unused here; kept for interface compatibility.

        Returns:
            ``(gd, prob)``: ``gd`` is the Euclidean distance ``(b, h1, w1)``
            between ``dense_matches`` and the warped pixel centres; ``prob`` is
            ``warp_kpts``'s mask as a float ``(b, h1, w1)`` tensor (presumably
            1 where the warp is valid -- semantics defined by ``warp_kpts``).
            NOTE(review): ``gd`` is likely float64 because the warp runs in
            double precision.
        """
        b, h1, w1, d = dense_matches.shape
        with torch.no_grad():
            x1_n = self._pixel_centers(b, h1, w1, dense_matches.device)
            mask, x2 = warp_kpts(
                x1_n.double(),
                depth1.double(),
                depth2.double(),
                T_1to2.double(),
                K1.double(),
                K2.double(),
            )
            prob = mask.float().reshape(b, h1, w1)
        # Outside no_grad so gradients flow back into `dense_matches`.
        gd = (dense_matches - x2.reshape(b, h1, w1, 2)).norm(dim=-1)  # *scale?
        return gd, prob

    def dense_depth_loss(self, dense_certainty, prob, gd, scale, eps=1e-8):
        """Certainty (BCE) and match-error losses for a single scale.

        Args:
            dense_certainty: certainty logits; channel 0 (``[:, 0]``) is
                compared against ``prob`` -- presumably ``(b, 1, h, w)``.
            prob: float target mask ``(b, h, w)``; entries ``> 0`` mark
                supervised pixels.
            gd: per-pixel match error, same shape as ``prob``.
            scale: scale key, used only to name the returned entries.
            eps: unused; kept for interface compatibility.

        Returns:
            dict with scalar tensors ``f"ce_loss_{scale}"`` and
            ``f"depth_loss_{scale}"``. The depth loss is the mean of ``gd``
            over supervised pixels, or zero when there are none.
        """
        ce_loss = F.binary_cross_entropy_with_logits(dense_certainty[:, 0], prob)
        valid = prob > 0
        if valid.any():
            depth_loss = gd[valid].mean()
        else:
            # The mean of an empty selection is NaN; `gd * 0.0` gives a zero
            # that stays connected to gd's autograd graph instead.
            depth_loss = (gd * 0.0).mean()
        return {
            f"ce_loss_{scale}": ce_loss.mean(),
            f"depth_loss_{scale}": depth_loss,
        }

    def _restrict_to_local(self, prob, prev_gd, scale):
        """Zero ``prob`` where the previous (coarser) scale's error was too large.

        Fine scales should refine matches, not be punished for coarse
        mistakes. A pixel therefore stays supervised only if the upsampled
        previous-scale error is below ``(2 / 512) * local_dist * scale``.
        The restriction is skipped when ``local_loss`` is off, when
        ``scale > local_largest_scale``, or when no previous scale exists
        (``prev_gd is None``).

        Args:
            prob: float mask ``(b, h, w)``.
            prev_gd: previous scale's match error ``(b, h', w')`` or ``None``.
            scale: current scale value.

        Returns:
            The (possibly) masked ``prob``, same shape as the input.
        """
        if prev_gd is None or not (self.local_loss and scale <= self.local_largest_scale):
            return prob
        h, w = prob.shape[-2:]
        prev_gd_up = F.interpolate(prev_gd[:, None], size=(h, w), mode="nearest")[:, 0]
        return prob * (prev_gd_up < (2 / 512) * (self.local_dist * scale))

    def forward(self, dense_corresps, batch):
        """Sum the per-scale losses over all scales in ``dense_corresps``.

        Args:
            dense_corresps: mapping ``scale -> {"dense_certainty", "dense_flow"}``
                with ``dense_flow`` laid out ``(b, d, h, w)``. Scales are visited
                in the mapping's iteration order; the local restriction assumes
                this order is coarse to fine.
            batch: mapping with ``"query_depth"``, ``"support_depth"``,
                ``"T_1to2"``, ``"K1"`` and ``"K2"``.

        Returns:
            The total loss (``0.0`` if ``dense_corresps`` is empty).
        """
        tot_loss = 0.0
        # Match error at the previously visited scale. It starts as None, not
        # 0.0, so the local restriction is skipped (instead of crashing on a
        # float) when the first scale already qualifies for it.
        prev_gd = None
        for scale, dense_scale_corresps in dense_corresps.items():
            dense_scale_certainty = dense_scale_corresps["dense_certainty"]
            dense_scale_coords = rearrange(
                dense_scale_corresps["dense_flow"], "b d h w -> b h w d"
            )
            gd, prob = self.geometric_dist(
                batch["query_depth"],
                batch["support_depth"],
                batch["T_1to2"],
                batch["K1"],
                batch["K2"],
                dense_scale_coords,
                scale,
            )
            prob = self._restrict_to_local(prob, prev_gd, scale)
            depth_losses = self.dense_depth_loss(dense_scale_certainty, prob, gd, scale)
            scale_loss = (
                self.ce_weight * depth_losses[f"ce_loss_{scale}"]
                + depth_losses[f"depth_loss_{scale}"]
            )
            if self.scale_normalize:
                scale_loss = scale_loss / scale
            tot_loss = tot_loss + scale_loss
            prev_gd = gd.detach()
        return tot_loss