# MIT License

# Copyright (c) 2022 Intelligent Systems Lab Org

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

# File author: Shariq Farooq Bhat, Zhenyu Li

from abc import ABC, abstractmethod

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler
from tqdm.auto import tqdm

from zoedepth.utils.align.loss import SILogLoss, gradl1_loss, edge_aware_smoothness_per_pixel
# from utils.misc import *
from .depth_alignment import apply_depth_smoothing, scale_shift_linear


def as_bchw_tensor(input_tensor, num, device=None):
    """Lift a 2D (H, W) or 3D (C, H, W) array/tensor to a 4D (B, C, H, W) tensor.

    `num` (the expected channel count) is unused but kept for API compatibility.
    """
    if len(input_tensor.shape) == 2:
        input_tensor = torch.as_tensor(input_tensor).unsqueeze(dim=0).unsqueeze(dim=0)
    elif len(input_tensor.shape) == 3:
        input_tensor = torch.as_tensor(input_tensor).unsqueeze(dim=0)

    if device is not None:
        input_tensor = input_tensor.to(device)
    return input_tensor


def get_mlp(in_channels, out_channels):
    """Per-pixel MLP implemented as a stack of 1x1 convolutions."""
    conv_config = dict(kernel_size=1, padding=0, stride=1)
    net = nn.Sequential(
        # nn.Conv2d(in_channels, 64, kernel_size=7, padding=3, stride=1),
        # nn.GELU(),
        nn.Conv2d(in_channels, 64, **conv_config),
        nn.GELU(),
        nn.Conv2d(64, 128, **conv_config),
        nn.GELU(),
        nn.Conv2d(128, out_channels, **conv_config),
    )
    # Optionally initialize the last layer to predict zeros:
    # net[-1].weight.data.zero_()
    # net[-1].bias.data.zero_()
    return net


def smoothness_loss(depth):
    """First-order smoothness: mean absolute forward differences in x and y."""
    depth_dx = depth[:, :, :-1, :-1] - depth[:, :, :-1, 1:]
    depth_dy = depth[:, :, :-1, :-1] - depth[:, :, 1:, :-1]
    depth_dx = depth_dx.abs().mean()
    depth_dy = depth_dy.abs().mean()
    return depth_dx + depth_dy


def curvature_loss(depth):
    """Second-order smoothness: mean absolute second differences (xx, yy, xy)."""
    depth_dx = depth[:, :, :-1, :-1] - depth[:, :, :-1, 1:]
    depth_dy = depth[:, :, :-1, :-1] - depth[:, :, 1:, :-1]
    depth_dxx = depth_dx[:, :, :, :-1] - depth_dx[:, :, :, 1:]
    depth_dyy = depth_dy[:, :, :-1, :] - depth_dy[:, :, 1:, :]
    depth_dxy = depth_dx[:, :, :-1, :-1] - depth_dx[:, :, 1:, 1:]
    depth_dxx = depth_dxx.abs().mean()
    depth_dyy = depth_dyy.abs().mean()
    depth_dxy = depth_dxy.abs().mean()
    return depth_dxx + depth_dyy + depth_dxy


def multi_scale_curvature_loss(depth, scales=(1, 2, 4)):
    """Curvature loss accumulated over several downsampled resolutions."""
    loss = 0
    for s in scales:
        loss += curvature_loss(F.interpolate(depth, scale_factor=1 / s, mode='bilinear', align_corners=False))
    return loss


def tv_loss(x):
    """Total variation loss."""
    dh = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
    dw = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
    return torch.sum(dh) + torch.sum(dw)


def scale_invariant_gradient_loss(pred, gt):
    """Scale-invariant loss on the gradients of log depth (Sobel filters)."""
    alpha = 1e-10
    kernel_grad_x = torch.Tensor([[1, 0, -1],
                                  [2, 0, -2],
                                  [1, 0, -1]]).view((1, 1, 3, 3)).to(pred.device)
    kernel_grad_y = torch.Tensor([[1, 2, 1],
                                  [0, 0, 0],
                                  [-1, -2, -1]]).view((1, 1, 3, 3)).to(pred.device)
    g = torch.log(pred + alpha) - torch.log(gt + alpha)
    g_x = F.conv2d(g, kernel_grad_x, padding=1)
    g_y = F.conv2d(g, kernel_grad_y, padding=1)
    # n, c, h, w = g.shape
    # norm = 1/(h*w)
    # Dg = norm * torch.sum(g**2) - (0.85/(norm**2)) * (torch.sum(g))**2
    Dgx = torch.var(g_x) + 0.5 * torch.pow(torch.mean(g_x), 2)
    Dgy = torch.var(g_y) + 0.5 * torch.pow(torch.mean(g_y), 2)
    loss = 10 * torch.sqrt(Dgx) + 10 * torch.sqrt(Dgy)
    return loss
def positionalencoding2d(d_model, height, width):
    """
    :param d_model: dimension of the model (must be divisible by 4)
    :param height: height of the positions
    :param width: width of the positions
    :return: d_model x height x width position matrix
    """
    if d_model % 4 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "a dimension not divisible by 4 (got dim={:d})".format(d_model))
    pe = torch.zeros(d_model, height, width)
    # Each spatial dimension uses half of d_model.
    d_model = int(d_model / 2)
    div_term = torch.exp(torch.arange(0., d_model, 2) * -(np.log(10000.0) / d_model))
    pos_w = torch.arange(0., width).unsqueeze(1)
    pos_h = torch.arange(0., height).unsqueeze(1)
    pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    return pe


def gaussian_rff(d_model, height, width, sigma=10):
    """Gaussian random Fourier features over a normalized [-1, 1] pixel grid."""
    assert d_model % 2 == 0
    B = torch.randn((d_model // 2, 2)) * sigma
    x = torch.linspace(-1, 1, width)
    y = torch.linspace(-1, 1, height)
    # 'xy' indexing yields (height, width) grids, matching the view below.
    x, y = torch.meshgrid(x, y, indexing='xy')
    xy = torch.stack([x, y], dim=-1).view(-1, 2)
    xy = torch.matmul(B, xy.T).T
    xy = 2 * np.pi * xy.view(height, width, d_model // 2)
    enc = torch.cat([torch.sin(xy), torch.cos(xy)], dim=-1)
    return enc.permute(2, 0, 1)


def get_depth(pred, D):
    """Decode the 3-channel MLP output into a depth map aligned to D."""
    a, b, c = torch.split(pred, 1, dim=1)
    return 1e-7 + torch.relu(torch.exp(a) * D + (torch.sigmoid(c) - 0.5) * torch.exp(b))
    # return 1e-4 + torch.exp(a) * D + b
    # return nn.Softplus()(a)
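
# Both encoders above produce a (d_model, H, W) feature tensor, e.g.
#     positionalencoding2d(32, 120, 160).shape == torch.Size([32, 120, 160])
#     gaussian_rff(32, 120, 160, sigma=5).shape == torch.Size([32, 120, 160])
# get_depth then decodes the three MLP channels (a, b, c) as
#     depth = 1e-7 + relu(exp(a) * D + (sigmoid(c) - 0.5) * exp(b)),
# i.e. a strictly positive per-pixel scale exp(a) applied to the source depth D,
# plus a per-pixel shift bounded in magnitude by exp(b) / 2; the 1e-7 offset
# keeps the decoded depth strictly positive.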

def train_mlp(image, mask, dr, dp, lr=3e-2, num_iters=3000, device='cuda:0', pos_dim=32,
              loss_config=dict(beta=0.99), w_smooth=1, w_curvature=0.0, w_gl1=0.1,
              w_tv=0.1, w_shift_reg=0.1, **kwargs):
    """Fit a per-pixel MLP that maps the source depth `dp` toward the target
    depth `dr` on the pixels selected by `mask`, conditioned on the image and
    a positional encoding."""
    mlp = get_mlp(pos_dim + 4, 3)  # inputs: RGB (3) + source depth (1) + positional encoding
    # mlp = get_mlp(4, 3)
    mlp = mlp.to(device)
    optimizer = optim.AdamW(mlp.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, lr, epochs=num_iters, steps_per_epoch=1)

    image = as_bchw_tensor(image, 3, device=device).detach()
    h, w = image.shape[-2:]
    pe = positionalencoding2d(pos_dim, h, w)
    pe = as_bchw_tensor(pe, pos_dim, device=device)
    D = as_bchw_tensor(dp, 1, device=device).detach()
    # pe = as_bchw_tensor(gaussian_rff(pos_dim, h, w, sigma=5), pos_dim, device=device)
    X = torch.cat([image, D, pe], dim=1)  # bchw
    # X = torch.cat([image, D], dim=1)  # bchw
    Y = as_bchw_tensor(dr, 1, device=device).detach()
    mask = as_bchw_tensor(mask, 1, device=device).detach()

    pbar = tqdm(range(num_iters), desc="Training")
    # beta_min, beta_max = 0.
    si_log = SILogLoss(**loss_config)
    for i in pbar:
        optimizer.zero_grad()
        # pred = dr.max().item() * torch.sigmoid(mlp(X))
        pred = mlp(X)
        a, b, c = torch.split(pred, 1, dim=1)  # only used by the optional regularizers below
        pred = get_depth(pred, D.detach())
        loss_si = si_log(pred[mask], Y[mask])
        loss = (loss_si
                + w_curvature * multi_scale_curvature_loss(pred)
                + w_gl1 * gradl1_loss(pred, D.detach())
                + w_smooth * edge_aware_smoothness_per_pixel(image, pred))
        # loss_tv = w_tv * (tv_loss(a) + tv_loss(b) + tv_loss(c))
        # loss_gl1 = w_gl1 * scale_invariant_gradient_loss(pred, D.detach())
        # loss_shift_reg = w_shift_reg * torch.mean(b**2)
        # loss = loss_si + loss_gl1
        # loss = F.mse_loss(pred[mask], Y[mask])
        loss.backward()
        optimizer.step()
        scheduler.step()
        pbar.set_postfix(loss=loss.item(), si=loss_si.item())
    return mlp


def predict_aligned(mlp, image, dp, pos_dim=32, **kwargs):
    """Run the trained MLP on the full image and decode the aligned depth."""
    device = next(mlp.parameters()).device
    image = as_bchw_tensor(image, 3, device=device)
    h, w = image.shape[-2:]
    pe = positionalencoding2d(pos_dim, h, w)
    pe = as_bchw_tensor(pe, pos_dim, device=device)
    D = as_bchw_tensor(dp, 1, device=device)
    # pe = as_bchw_tensor(gaussian_rff(pos_dim, h, w, sigma=5), pos_dim, device=device)
    X = torch.cat([image, D, pe], dim=1)  # bchw
    pred = mlp(X)
    pred = get_depth(pred, D)
    return pred.detach()


def align_by_mlp(image, mask, dr, dp, **kwargs):
    mlp = train_mlp(image, mask, dr, dp, **kwargs)
    pred = predict_aligned(mlp, image, dp, **kwargs)
    return pred


class DepthAligner(ABC):
    """Abstract base class for depth alignment. All depth alignment methods
    should inherit from this class, which defines the alignment interface."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def align(self, depth_src, depth_target, valid_mask, *args, **kwargs):
        """Aligns depth_src to depth_target such that the aligned depth_src
        is as close as possible to depth_target."""
        raise NotImplementedError


class MLPAligner(DepthAligner):
    def __init__(self):
        super().__init__()

    def align(self, depth_src, depth_target, valid_mask, image, **kwargs):
        depth_src = as_bchw_tensor(depth_src, 1)
        depth_target = as_bchw_tensor(depth_target, 1)
        valid_mask = as_bchw_tensor(valid_mask, 1)
        # Coarse alignment first: global least-squares scale and shift.
        depth_target = scale_shift_linear(depth_target, depth_src, valid_mask)
        # Fine, spatially varying alignment via the per-pixel MLP.
        aligned = align_by_mlp(image, valid_mask, depth_target, depth_src, **kwargs)
        # Edge-preserving smoothing of the aligned depth.
        depth_numpy = aligned.squeeze().float().cpu().numpy()
        blur_bilateral = cv2.bilateralFilter(depth_numpy, 5, 140, 140)
        blur_gaussian = cv2.GaussianBlur(blur_bilateral, (5, 5), 0)
        blur_gaussian = torch.from_numpy(blur_gaussian).to(aligned)
        return blur_gaussian.unsqueeze(0)
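
# ---------------------------------------------------------------------------
# Minimal usage sketch. This block is illustrative and not part of the original
# module: the random inputs, resolution, and the `num_iters`/`device` overrides
# are placeholder assumptions. In practice `image` is an RGB frame, `depth_src`
# is the depth map being trusted, and `depth_target` is the prediction that
# gets aligned to it.
if __name__ == "__main__":
    H, W = 240, 320
    image = np.random.rand(3, H, W).astype(np.float32)            # CHW RGB in [0, 1]
    depth_src = np.random.rand(H, W).astype(np.float32) + 0.1     # source depth (placeholder)
    depth_target = np.random.rand(H, W).astype(np.float32) + 0.1  # depth to align (placeholder)
    valid_mask = np.ones((H, W), dtype=bool)                      # pixels with trusted depth

    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    aligned = MLPAligner().align(depth_src, depth_target, valid_mask, image,
                                 num_iters=300, device=device)
    print(aligned.shape)  # torch.Size([1, H, W])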