# PatchFusion: zoedepth/utils/align/mlp_alignment.py
# MIT License
# Copyright (c) 2022 Intelligent Systems Lab Org
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# File author: Shariq Farooq Bhat, Zhenyu Li
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler
from abc import ABC, abstractmethod
from tqdm.auto import tqdm

from zoedepth.utils.align.loss import SILogLoss, gradl1_loss, edge_aware_smoothness_per_pixel
from .depth_alignment import apply_depth_smoothing, scale_shift_linear
def as_bchw_tensor(input_tensor, num, device=None):
    """Coerce a 2D (H, W) or 3D (C, H, W) array/tensor into BCHW shape.

    `num` is the expected channel count; it is currently unused and kept only
    for readability at the call sites.
    """
    if len(input_tensor.shape) == 2:
        input_tensor = torch.as_tensor(input_tensor).unsqueeze(dim=0).unsqueeze(dim=0)
    elif len(input_tensor.shape) == 3:
        input_tensor = torch.as_tensor(input_tensor).unsqueeze(dim=0)
    if device is not None:
        input_tensor = input_tensor.to(device)
    return input_tensor
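# Example (sketch): numpy arrays and tensors are both accepted, e.g.
#   as_bchw_tensor(np.zeros((480, 640)), 1).shape      -> (1, 1, 480, 640)
#   as_bchw_tensor(torch.zeros(3, 480, 640), 3).shape  -> (1, 3, 480, 640)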
def get_mlp(in_channels, out_channels):
    """Build a per-pixel MLP implemented with 1x1 convolutions."""
    conv_config = dict(kernel_size=1, padding=0, stride=1)
    net = nn.Sequential(
        # Alternative first layer with spatial context:
        # nn.Conv2d(in_channels, 64, kernel_size=7, padding=3, stride=1),
        # nn.GELU(),
        nn.Conv2d(in_channels, 64, **conv_config),
        nn.GELU(),
        nn.Conv2d(64, 128, **conv_config),
        nn.GELU(),
        nn.Conv2d(128, out_channels, **conv_config),
    )
    # Optionally initialize the last layer to predict zeroes:
    # net[-1].weight.data.zero_()
    # net[-1].bias.data.zero_()
    return net
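# Sanity check (sketch): the 1x1 kernels make this a per-pixel MLP, so any
# spatial resolution works:
#   net = get_mlp(36, 3)
#   net(torch.zeros(1, 36, 8, 8)).shape  -> (1, 3, 8, 8)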
def smoothness_loss(depth):
    """First-order smoothness penalty: mean absolute forward differences."""
    depth_dx = depth[:, :, :-1, :-1] - depth[:, :, :-1, 1:]
    depth_dy = depth[:, :, :-1, :-1] - depth[:, :, 1:, :-1]
    depth_dx = depth_dx.abs().mean()
    depth_dy = depth_dy.abs().mean()
    return depth_dx + depth_dy

def curvature_loss(depth):
    """Second-order (curvature) penalty: mean absolute second differences."""
    depth_dx = depth[:, :, :-1, :-1] - depth[:, :, :-1, 1:]
    depth_dy = depth[:, :, :-1, :-1] - depth[:, :, 1:, :-1]
    depth_dxx = depth_dx[:, :, :, :-1] - depth_dx[:, :, :, 1:]
    depth_dyy = depth_dy[:, :, :-1, :] - depth_dy[:, :, 1:, :]
    depth_dxy = depth_dx[:, :, :-1, :-1] - depth_dx[:, :, 1:, 1:]
    depth_dxx = depth_dxx.abs().mean()
    depth_dyy = depth_dyy.abs().mean()
    depth_dxy = depth_dxy.abs().mean()
    return depth_dxx + depth_dyy + depth_dxy
def multi_scale_curvature_loss(depth, scales=(1, 2, 4)):
    """Sum of curvature_loss evaluated at several downsampled resolutions."""
    loss = 0
    for s in scales:
        loss += curvature_loss(F.interpolate(depth, scale_factor=1/s, mode='bilinear', align_corners=False))
    return loss
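# Example (sketch): with the default scales a 256x256 depth map is also scored
# at 128x128 and 64x64, penalizing low-frequency bending that the
# full-resolution curvature term alone would miss:
#   multi_scale_curvature_loss(torch.rand(1, 1, 256, 256))  # scalar tensor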
def tv_loss(x):
    """Total variation loss."""
    dh = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
    dw = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
    return torch.sum(dh) + torch.sum(dw)
def scale_invariant_gradient_loss(pred, gt):
    """Scale-invariant loss on Sobel gradients of log depth (single-channel BCHW inputs)."""
    alpha = 1e-10
    kernel_grad_x = torch.Tensor([[1, 0, -1], [2, 0, -2], [1, 0, -1]]).view((1, 1, 3, 3)).to(pred.device)
    kernel_grad_y = torch.Tensor([[1, 2, 1], [0, 0, 0], [-1, -2, -1]]).view((1, 1, 3, 3)).to(pred.device)
    g = torch.log(pred + alpha) - torch.log(gt + alpha)
    g_x = F.conv2d(g, kernel_grad_x, padding=1)
    g_y = F.conv2d(g, kernel_grad_y, padding=1)
    Dgx = torch.var(g_x) + 0.5 * torch.pow(torch.mean(g_x), 2)
    Dgy = torch.var(g_y) + 0.5 * torch.pow(torch.mean(g_y), 2)
    loss = 10 * torch.sqrt(Dgx) + 10 * torch.sqrt(Dgy)
    return loss
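# Example (sketch): both inputs are positive single-channel BCHW depth maps
# (the Sobel kernels above are defined for one channel):
#   pred, gt = torch.rand(1, 1, 64, 64) + 0.1, torch.rand(1, 1, 64, 64) + 0.1
#   scale_invariant_gradient_loss(pred, gt)  # scalar tensor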
def positionalencoding2d(d_model, height, width):
    """
    :param d_model: dimension of the model (must be divisible by 4)
    :param height: height of the positions
    :param width: width of the positions
    :return: d_model*height*width position matrix
    """
    if d_model % 4 != 0:
        raise ValueError("Cannot use sin/cos positional encoding with "
                         "a dimension not divisible by 4 (got dim={:d})".format(d_model))
    pe = torch.zeros(d_model, height, width)
    # Each spatial dimension uses half of d_model
    d_model = int(d_model / 2)
    div_term = torch.exp(torch.arange(0., d_model, 2) *
                         -(np.log(10000.0) / d_model))
    pos_w = torch.arange(0., width).unsqueeze(1)
    pos_h = torch.arange(0., height).unsqueeze(1)
    pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[d_model + 1::2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    return pe
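# Example (sketch): the encoding is deterministic and channel-first, so it can
# be computed once per resolution and reused across training iterations:
#   positionalencoding2d(32, 480, 640).shape  -> (32, 480, 640)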
def gaussian_rff(d_model, height, width, sigma=10):
    """Gaussian random Fourier feature encoding of shape (d_model, height, width)."""
    assert d_model % 2 == 0
    B = torch.randn((d_model // 2, 2)) * sigma
    x = torch.linspace(-1, 1, width)
    y = torch.linspace(-1, 1, height)
    # 'xy' indexing yields (height, width) grids, so the flattened coordinates
    # line up with the (height, width, ...) reshape below.
    x, y = torch.meshgrid(x, y, indexing='xy')
    xy = torch.stack([x, y], dim=-1).view(-1, 2)
    xy = torch.matmul(B, xy.T).T
    xy = 2 * np.pi * xy.view(height, width, d_model // 2)
    enc = torch.cat([torch.sin(xy), torch.cos(xy)], dim=-1)
    return enc.permute(2, 0, 1)
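# Example (sketch): random Fourier features are an alternative to the sin/cos
# grid above (referenced only in a commented-out line in train_mlp); sigma sets
# the frequency bandwidth:
#   gaussian_rff(32, 480, 640, sigma=5).shape  -> (32, 480, 640)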
def get_depth(pred, D):
    """Decode the 3-channel MLP output into depth: a per-pixel log-scale `a`
    applied to D, plus a shift bounded by +-exp(b)/2 with direction set by
    sigmoid(c); the ReLU and 1e-7 floor keep the result positive."""
    a, b, c = torch.split(pred, 1, dim=1)
    return 1e-7 + torch.relu(torch.exp(a) * D + (torch.sigmoid(c) - 0.5) * torch.exp(b))
def train_mlp(image, mask, dr, dp, lr=3e-2, num_iters=3000, device='cuda:0', pos_dim=32,
              loss_config=dict(beta=0.99), w_smooth=1, w_curvature=0.0, w_gl1=0.1,
              w_tv=0.1, w_shift_reg=0.1, **kwargs):
    """Fit a per-pixel MLP mapping (RGB, depth `dp`, positional encoding) to a
    corrected depth, supervised by the reference depth `dr` on `mask` pixels.
    `w_tv` and `w_shift_reg` weight the optional regularizers left commented
    out in the loop below."""
    mlp = get_mlp(pos_dim + 4, 3)
    mlp = mlp.to(device)
    optimizer = optim.AdamW(mlp.parameters(), lr=lr)
    scheduler = optim.lr_scheduler.OneCycleLR(optimizer, lr, epochs=num_iters, steps_per_epoch=1)
    image = as_bchw_tensor(image, 3, device=device).detach()
    h, w = image.shape[-2:]
    pe = positionalencoding2d(pos_dim, h, w)
    pe = as_bchw_tensor(pe, pos_dim, device=device)
    # Alternative encoding: pe = as_bchw_tensor(gaussian_rff(pos_dim, h, w, sigma=5), pos_dim, device=device)
    D = as_bchw_tensor(dp, 1, device=device).detach()
    X = torch.cat([image, D, pe], dim=1)  # bchw
    Y = as_bchw_tensor(dr, 1, device=device).detach()
    mask = as_bchw_tensor(mask, 1, device=device).detach()
    pbar = tqdm(range(num_iters), desc="Training")
    si_log = SILogLoss(**loss_config)
    for _ in pbar:
        optimizer.zero_grad()
        pred = mlp(X)
        # a, b, c = torch.split(pred, 1, dim=1)  # raw outputs; only needed by
        # the optional regularizers below
        pred = get_depth(pred, D.detach())
        loss_si = si_log(pred[mask], Y[mask])
        loss = (loss_si
                + w_curvature * multi_scale_curvature_loss(pred)
                + w_gl1 * gradl1_loss(pred, D.detach())
                + w_smooth * edge_aware_smoothness_per_pixel(image, pred))
        # Optional regularizers (disabled):
        # loss = loss + w_tv * (tv_loss(a) + tv_loss(b) + tv_loss(c))
        # loss = loss + w_shift_reg * torch.mean(b**2)
        loss.backward()
        optimizer.step()
        scheduler.step()
        pbar.set_postfix(loss=loss.item(), si=loss_si.item())
    return mlp
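# Usage (sketch, assuming HxW inputs): dr is the reference depth the output is
# aligned to, dp the depth being corrected, and mask marks pixels where dr is
# trusted:
#   mlp = train_mlp(image, mask, dr, dp, num_iters=500, device='cuda:0')
#   aligned = predict_aligned(mlp, image, dp)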
def predict_aligned(mlp, image, dp, pos_dim=32, **kwargs):
    """Apply the trained MLP to the full image; inputs mirror train_mlp."""
    device = next(mlp.parameters()).device
    with torch.no_grad():
        image = as_bchw_tensor(image, 3, device=device)
        h, w = image.shape[-2:]
        pe = positionalencoding2d(pos_dim, h, w)
        pe = as_bchw_tensor(pe, pos_dim, device=device)
        # Alternative encoding (must match training): gaussian_rff(pos_dim, h, w, sigma=5)
        D = as_bchw_tensor(dp, 1, device=device)
        X = torch.cat([image, D, pe], dim=1)  # bchw; same layout as in train_mlp
        pred = mlp(X)
        pred = get_depth(pred, D)
    return pred
def align_by_mlp(image, mask, dr, dp, **kwargs):
    """Train an alignment MLP on (image, dr, dp) and return dp aligned to dr."""
    mlp = train_mlp(image, mask, dr, dp, **kwargs)
    pred = predict_aligned(mlp, image, dp, **kwargs)
    return pred
class DepthAligner(ABC):
    """Abstract base class for depth alignment; all depth alignment methods
    should inherit from this class and implement `align`."""

    def __init__(self):
        super().__init__()

    @abstractmethod
    def align(self, depth_src, depth_target, valid_mask, *args, **kwargs):
        """
        Aligns depth_src to depth_target, returning a version of depth_src
        that is as close as possible to depth_target.
        """
        raise NotImplementedError
class MLPAligner(DepthAligner):
    def __init__(self):
        super().__init__()

    def align(self, depth_src, depth_target, valid_mask, image, **kwargs):
        depth_src = as_bchw_tensor(depth_src, 1)
        depth_target = as_bchw_tensor(depth_target, 1)
        valid_mask = as_bchw_tensor(valid_mask, 1)
        # Coarse global (scale/shift) alignment first, then per-pixel MLP refinement.
        depth_target = scale_shift_linear(depth_target, depth_src, valid_mask)
        aligned = align_by_mlp(image, valid_mask, depth_target, depth_src, **kwargs)
        # Light bilateral + Gaussian smoothing to suppress per-pixel noise.
        depth_numpy = aligned.squeeze().float().cpu().numpy()
        blur_bilateral = cv2.bilateralFilter(depth_numpy, 5, 140, 140)
        blur_gaussian = cv2.GaussianBlur(blur_bilateral, (5, 5), 0)
        blur_gaussian = torch.from_numpy(blur_gaussian).to(aligned)
        return blur_gaussian.unsqueeze(0)
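if __name__ == "__main__":
    # Minimal smoke test (a sketch on random data, not part of the original
    # pipeline): exercises MLPAligner end to end on CPU.
    h, w = 64, 64
    image = torch.rand(3, h, w)
    depth_src = torch.rand(h, w) + 0.5
    depth_target = 2.0 * depth_src + 0.1  # synthetic target, affine in depth_src
    valid_mask = torch.ones(h, w, dtype=torch.bool)
    aligned = MLPAligner().align(depth_src, depth_target, valid_mask, image,
                                 num_iters=50, device='cpu')
    print(aligned.shape)  # torch.Size([1, 64, 64])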