import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
import torch_scatter
from torch_scatter import scatter_sum

from . import fastba
from . import altcorr
from . import lietorch
from .lietorch import SE3

from .extractor import BasicEncoder, BasicEncoder4
from .blocks import GradientClip, GatedResidual, SoftAgg

from .utils import *
from .ba import BA
from . import projective_ops as pops

autocast = torch.cuda.amp.autocast
import matplotlib.pyplot as plt

DIM = 384


class Update(nn.Module):
    def __init__(self, p):
        super(Update, self).__init__()

        self.c1 = nn.Sequential(
            nn.Linear(DIM, DIM),
            nn.ReLU(inplace=True),
            nn.Linear(DIM, DIM))

        self.c2 = nn.Sequential(
            nn.Linear(DIM, DIM),
            nn.ReLU(inplace=True),
            nn.Linear(DIM, DIM))

        self.norm = nn.LayerNorm(DIM, eps=1e-3)

        self.agg_kk = SoftAgg(DIM)
        self.agg_ij = SoftAgg(DIM)

        self.gru = nn.Sequential(
            nn.LayerNorm(DIM, eps=1e-3),
            GatedResidual(DIM),
            nn.LayerNorm(DIM, eps=1e-3),
            GatedResidual(DIM),
        )

        self.corr = nn.Sequential(
            nn.Linear(2*49*p*p, DIM),
            nn.ReLU(inplace=True),
            nn.Linear(DIM, DIM),
            nn.LayerNorm(DIM, eps=1e-3),
            nn.ReLU(inplace=True),
            nn.Linear(DIM, DIM),
        )

        # heads for patch-center flow revisions and confidence weights
        self.d = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Linear(DIM, 2),
            GradientClip())

        self.w = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Linear(DIM, 2),
            GradientClip(),
            nn.Sigmoid())

    def forward(self, net, inp, corr, flow, ii, jj, kk):
        """ update operator """

        net = net + inp + self.corr(corr)
        net = self.norm(net)

        # message passing between temporally adjacent edges
        ix, jx = fastba.neighbors(kk, jj)
        mask_ix = (ix >= 0).float().reshape(1, -1, 1)
        mask_jx = (jx >= 0).float().reshape(1, -1, 1)

        net = net + self.c1(mask_ix * net[:,ix])
        net = net + self.c2(mask_jx * net[:,jx])

        # softmax aggregation over edges sharing a patch (kk) or a frame pair (ii, jj)
        net = net + self.agg_kk(net, kk)
        net = net + self.agg_ij(net, ii*12345 + jj)

        net = self.gru(net)

        return net, (self.d(net), self.w(net), None)


class Patchifier(nn.Module):
    def __init__(self, patch_size=3):
        super(Patchifier, self).__init__()
        self.patch_size = patch_size
        self.fnet = BasicEncoder4(output_dim=128, norm_fn='instance')
        self.inet = BasicEncoder4(output_dim=DIM, norm_fn='none')

    def __image_gradient(self, images):
        gray = ((images + 0.5) * (255.0 / 2)).sum(dim=2)
        dx = gray[...,:-1,1:] - gray[...,:-1,:-1]
        dy = gray[...,1:,:-1] - gray[...,:-1,:-1]
        g = torch.sqrt(dx**2 + dy**2)
        g = F.avg_pool2d(g, 4, 4)
        return g

    def forward(self, images, patches_per_image=80, disps=None, gradient_bias=False, return_color=False):
        """ extract patches from input images """

        fmap = self.fnet(images) / 4.0
        imap = self.inet(images) / 4.0

        b, n, c, h, w = fmap.shape
        P = self.patch_size

        # bias patch selection towards regions with high gradient
        if gradient_bias:
            g = self.__image_gradient(images)
            x = torch.randint(1, w-1, size=[n, 3*patches_per_image], device="cuda")
            y = torch.randint(1, h-1, size=[n, 3*patches_per_image], device="cuda")

            coords = torch.stack([x, y], dim=-1).float()
            g = altcorr.patchify(g[0,:,None], coords, 0).view(n, 3 * patches_per_image)

            # keep the candidates with the largest gradient magnitude
            ix = torch.argsort(g, dim=1)
            x = torch.gather(x, 1, ix[:, -patches_per_image:])
            y = torch.gather(y, 1, ix[:, -patches_per_image:])

        else:
            x = torch.randint(1, w-1, size=[n, patches_per_image], device="cuda")
            y = torch.randint(1, h-1, size=[n, patches_per_image], device="cuda")

        coords = torch.stack([x, y], dim=-1).float()
        imap = altcorr.patchify(imap[0], coords, 0).view(b, -1, DIM, 1, 1)
        gmap = altcorr.patchify(fmap[0], coords, P//2).view(b, -1, 128, P, P)

        if return_color:
            clr = altcorr.patchify(images[0], 4*(coords + 0.5), 0).view(b, -1, 3)

        if disps is None:
            disps = torch.ones(b, n, h, w, device="cuda")
device="cuda") grid, _ = coords_grid_with_index(disps, device=fmap.device) patches = altcorr.patchify(grid[0], coords, P//2).view(b, -1, 3, P, P) index = torch.arange(n, device="cuda").view(n, 1) index = index.repeat(1, patches_per_image).reshape(-1) if return_color: return fmap, gmap, imap, patches, index, clr return fmap, gmap, imap, patches, index class CorrBlock: def __init__(self, fmap, gmap, radius=3, dropout=0.2, levels=[1,4]): self.dropout = dropout self.radius = radius self.levels = levels self.gmap = gmap self.pyramid = pyramidify(fmap, lvls=levels) def __call__(self, ii, jj, coords): corrs = [] for i in range(len(self.levels)): corrs += [ altcorr.corr(self.gmap, self.pyramid[i], coords / self.levels[i], ii, jj, self.radius, self.dropout) ] return torch.stack(corrs, -1).view(1, len(ii), -1) class VONet(nn.Module): def __init__(self, use_viewer=False): super(VONet, self).__init__() self.P = 3 self.patchify = Patchifier(self.P) self.update = Update(self.P) self.DIM = DIM self.RES = 4 @autocast(enabled=False) def forward(self, images, poses, disps, intrinsics, M=1024, STEPS=12, P=1, structure_only=False, rescale=False): """ Estimates SE3 or Sim3 between pair of frames """ images = 2 * (images / 255.0) - 0.5 intrinsics = intrinsics / 4.0 disps = disps[:, :, 1::4, 1::4].float() fmap, gmap, imap, patches, ix = self.patchify(images, disps=disps) corr_fn = CorrBlock(fmap, gmap) b, N, c, h, w = fmap.shape p = self.P patches_gt = patches.clone() Ps = poses d = patches[..., 2, p//2, p//2] patches = set_depth(patches, torch.rand_like(d)) kk, jj = flatmeshgrid(torch.where(ix < 8)[0], torch.arange(0,8, device="cuda")) ii = ix[kk] imap = imap.view(b, -1, DIM) net = torch.zeros(b, len(kk), DIM, device="cuda", dtype=torch.float) Gs = SE3.IdentityLike(poses) if structure_only: Gs.data[:] = poses.data[:] traj = [] bounds = [-64, -64, w + 64, h + 64] while len(traj) < STEPS: Gs = Gs.detach() patches = patches.detach() n = ii.max() + 1 if len(traj) >= 8 and n < images.shape[1]: if not structure_only: Gs.data[:,n] = Gs.data[:,n-1] kk1, jj1 = flatmeshgrid(torch.where(ix < n)[0], torch.arange(n, n+1, device="cuda")) kk2, jj2 = flatmeshgrid(torch.where(ix == n)[0], torch.arange(0, n+1, device="cuda")) ii = torch.cat([ix[kk1], ix[kk2], ii]) jj = torch.cat([jj1, jj2, jj]) kk = torch.cat([kk1, kk2, kk]) net1 = torch.zeros(b, len(kk1) + len(kk2), DIM, device="cuda") net = torch.cat([net1, net], dim=1) if np.random.rand() < 0.1: k = (ii != (n - 4)) & (jj != (n - 4)) ii = ii[k] jj = jj[k] kk = kk[k] net = net[:,k] patches[:,ix==n,2] = torch.median(patches[:,(ix == n-1) | (ix == n-2),2]) n = ii.max() + 1 coords = pops.transform(Gs, patches, intrinsics, ii, jj, kk) coords1 = coords.permute(0, 1, 4, 2, 3).contiguous() corr = corr_fn(kk, jj, coords1) net, (delta, weight, _) = self.update(net, imap[:,kk], corr, None, ii, jj, kk) lmbda = 1e-4 target = coords[...,p//2,p//2,:] + delta ep = 10 for itr in range(2): Gs, patches = BA(Gs, patches, intrinsics, target, weight, lmbda, ii, jj, kk, bounds, ep=ep, fixedp=1, structure_only=structure_only) kl = torch.as_tensor(0) dij = (ii - jj).abs() k = (dij > 0) & (dij <= 2) coords = pops.transform(Gs, patches, intrinsics, ii[k], jj[k], kk[k]) coords_gt, valid, _ = pops.transform(Ps, patches_gt, intrinsics, ii[k], jj[k], kk[k], jacobian=True) traj.append((valid, coords, coords_gt, Gs[:,:n], Ps[:,:n], kl)) return traj