# Captured from a Hugging Face Spaces page (Space status: "Sleeping").
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from collections import OrderedDict | |
import torch_scatter | |
from torch_scatter import scatter_sum | |
from . import fastba | |
from . import altcorr | |
from . import lietorch | |
from .lietorch import SE3 | |
from .extractor import BasicEncoder, BasicEncoder4 | |
from .blocks import GradientClip, GatedResidual, SoftAgg | |
from .utils import * | |
from .ba import BA | |
from . import projective_ops as pops | |
autocast = torch.cuda.amp.autocast | |
import matplotlib.pyplot as plt | |
DIM = 384  # feature width shared by the update operator, context features and edge hidden state


class Update(nn.Module):
    """Update operator applied to the per-edge hidden state.

    Args:
        p: patch size. The correlation input is expected to carry
            ``2 * 49 * p * p`` features per edge (presumably two pyramid
            levels times a 7x7 lookup window per patch pixel -- TODO confirm).
    """

    def __init__(self, p):
        super().__init__()

        def two_layer_mlp():
            # Linear -> ReLU -> Linear on DIM features.
            return nn.Sequential(
                nn.Linear(DIM, DIM),
                nn.ReLU(inplace=True),
                nn.Linear(DIM, DIM),
            )

        # Submodules are built in a fixed order: parameter registration order
        # (state_dict / optimizer layout) and RNG-driven initialisation depend on it.
        self.c1 = two_layer_mlp()
        self.c2 = two_layer_mlp()

        self.norm = nn.LayerNorm(DIM, eps=1e-3)

        self.agg_kk = SoftAgg(DIM)
        self.agg_ij = SoftAgg(DIM)

        self.gru = nn.Sequential(
            nn.LayerNorm(DIM, eps=1e-3),
            GatedResidual(DIM),
            nn.LayerNorm(DIM, eps=1e-3),
            GatedResidual(DIM),
        )

        corr_features = 2 * 49 * p * p
        self.corr = nn.Sequential(
            nn.Linear(corr_features, DIM),
            nn.ReLU(inplace=True),
            nn.Linear(DIM, DIM),
            nn.LayerNorm(DIM, eps=1e-3),
            nn.ReLU(inplace=True),
            nn.Linear(DIM, DIM),
        )

        # The two output heads use a non-in-place ReLU because both read the
        # same hidden tensor, which is also returned to the caller.
        self.d = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Linear(DIM, 2),
            GradientClip(),
        )
        self.w = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Linear(DIM, 2),
            GradientClip(),
            nn.Sigmoid(),
        )

    def forward(self, net, inp, corr, flow, ii, jj, kk):
        """Update the edge hidden state and predict per-edge outputs.

        Args:
            net: current hidden state, one DIM vector per edge.
            inp: context features gathered for each edge.
            corr: correlation features fed through ``self.corr``.
            flow: unused.
            ii, jj, kk: per-edge index tensors; ``kk`` and ``jj`` are passed
                to ``fastba.neighbors`` and to the soft aggregations.

        Returns:
            ``(net, (delta, weight, None))``, where ``delta`` is the output of
            ``self.d`` and ``weight`` is the sigmoid output of ``self.w``.
        """
        hidden = self.norm(net + inp + self.corr(corr))

        # Neighbour indices; entries < 0 mean "no neighbour" and are masked out.
        nbr_i, nbr_j = fastba.neighbors(kk, jj)
        keep_i = (nbr_i >= 0).float().reshape(1, -1, 1)
        keep_j = (nbr_j >= 0).float().reshape(1, -1, 1)

        hidden = hidden + self.c1(keep_i * hidden[:, nbr_i])
        hidden = hidden + self.c2(keep_j * hidden[:, nbr_j])

        # Soft aggregation over edges sharing a patch (kk), then over edges
        # sharing an (ii, jj) pair -- the pair is folded into one integer key,
        # which assumes jj < 12345.
        hidden = hidden + self.agg_kk(hidden, kk)
        hidden = hidden + self.agg_ij(hidden, ii * 12345 + jj)

        hidden = self.gru(hidden)
        return hidden, (self.d(hidden), self.w(hidden), None)
class Patchifier(nn.Module):
    """Extract randomly placed feature patches from a sequence of frames.

    ``fnet`` produces 128-channel matching features (instance norm) and
    ``inet`` produces ``DIM``-channel context features (no norm); both
    encoder outputs are scaled by 1/4 in ``forward``.
    """

    def __init__(self, patch_size=3):
        super(Patchifier, self).__init__()
        self.patch_size = patch_size
        self.fnet = BasicEncoder4(output_dim=128, norm_fn='instance')
        self.inet = BasicEncoder4(output_dim=DIM, norm_fn='none')

    def __image_gradient(self, images):
        """Return a 4x average-pooled gradient-magnitude map of ``images``.

        ``(images + 0.5) * (255 / 2)`` undoes the ``2 * (x / 255) - 0.5``
        scaling applied in ``VONet.forward``. Dim 2 (presumably colour) is
        summed into a single intensity image before forward differences are taken.
        """
        gray = ((images + 0.5) * (255.0 / 2)).sum(dim=2)
        dx = gray[..., :-1, 1:] - gray[..., :-1, :-1]
        dy = gray[..., 1:, :-1] - gray[..., :-1, :-1]
        g = torch.sqrt(dx**2 + dy**2)
        g = F.avg_pool2d(g, 4, 4)
        return g

    def forward(self, images, patches_per_image=80, disps=None, gradient_bias=False, return_color=False):
        """Extract patches from input images.

        Args:
            images: input frames for ``fnet`` / ``inet``
                (presumably (b, n, 3, H, W) -- TODO confirm).
            patches_per_image: number of patches sampled per frame.
            disps: maps at feature resolution used to build the patch
                coordinate grid; defaults to all ones.
            gradient_bias: if True, sample ``3 * patches_per_image``
                candidates per frame and keep those with the largest image gradient.
            return_color: if True, also return per-patch colours sampled from ``images``.

        Returns:
            ``(fmap, gmap, imap, patches, index)``, plus ``clr`` as a sixth
            element when ``return_color`` is set. ``index[k]`` is the frame
            patch ``k`` was sampled from.

        Note:
            Only batch element 0 is sampled (``fmap[0]``, ``imap[0]``,
            ``images[0]``), so this appears to assume b == 1.
        """
        fmap = self.fnet(images) / 4.0
        imap = self.inet(images) / 4.0

        b, n, c, h, w = fmap.shape
        P = self.patch_size

        # Allocate on the feature device. A bare "cuda" means the *current*
        # CUDA device, which need not be where the model and inputs live.
        device = fmap.device

        # bias patch selection towards regions with high gradient
        if gradient_bias:
            g = self.__image_gradient(images)
            x = torch.randint(1, w-1, size=[n, 3*patches_per_image], device=device)
            y = torch.randint(1, h-1, size=[n, 3*patches_per_image], device=device)

            coords = torch.stack([x, y], dim=-1).float()
            g = altcorr.patchify(g[0, :, None], coords, 0).view(n, 3 * patches_per_image)

            # keep the patches_per_image candidates with the largest gradient
            ix = torch.argsort(g, dim=1)
            x = torch.gather(x, 1, ix[:, -patches_per_image:])
            y = torch.gather(y, 1, ix[:, -patches_per_image:])

        else:
            x = torch.randint(1, w-1, size=[n, patches_per_image], device=device)
            y = torch.randint(1, h-1, size=[n, patches_per_image], device=device)

        coords = torch.stack([x, y], dim=-1).float()
        imap = altcorr.patchify(imap[0], coords, 0).view(b, -1, DIM, 1, 1)
        gmap = altcorr.patchify(fmap[0], coords, P//2).view(b, -1, 128, P, P)

        if return_color:
            # coords index the feature map; 4 * (coords + 0.5) maps them into image pixels
            clr = altcorr.patchify(images[0], 4*(coords + 0.5), 0).view(b, -1, 3)

        if disps is None:
            disps = torch.ones(b, n, h, w, device=device)

        grid, _ = coords_grid_with_index(disps, device=device)
        patches = altcorr.patchify(grid[0], coords, P//2).view(b, -1, 3, P, P)

        # frame id of every patch: 0 repeated patches_per_image times, then 1, ...
        index = torch.arange(n, device=device).view(n, 1)
        index = index.repeat(1, patches_per_image).reshape(-1)

        if return_color:
            return fmap, gmap, imap, patches, index, clr

        return fmap, gmap, imap, patches, index
class CorrBlock:
    """Correlation lookup between patch features and a multi-level feature pyramid.

    Args:
        fmap: frame feature maps, turned into a pyramid by ``pyramidify``.
        gmap: per-patch features.
        radius: lookup radius forwarded to ``altcorr.corr``.
        dropout: dropout value forwarded to ``altcorr.corr``.
        levels: pyramid factors; lookup coordinates are divided by each factor.
    """

    def __init__(self, fmap, gmap, radius=3, dropout=0.2, levels=(1, 4)):
        self.dropout = dropout
        self.radius = radius
        # Immutable default + per-instance copy: a list default would be one
        # object shared (and stored) by every CorrBlock.
        self.levels = list(levels)
        self.gmap = gmap
        self.pyramid = pyramidify(fmap, lvls=self.levels)

    def __call__(self, ii, jj, coords):
        """Return correlation features for edges ``(ii, jj)`` at ``coords``.

        One ``altcorr.corr`` lookup per pyramid level; the results are stacked
        on the last axis and reshaped to ``(1, len(ii), -1)``.
        """
        corrs = [
            altcorr.corr(self.gmap, self.pyramid[lvl], coords / scale,
                         ii, jj, self.radius, self.dropout)
            for lvl, scale in enumerate(self.levels)
        ]
        return torch.stack(corrs, -1).view(1, len(ii), -1)
class VONet(nn.Module):
    """Patch-based visual-odometry network (training forward pass).

    Combines a ``Patchifier`` (feature / patch extraction) with an ``Update``
    operator and alternates learned updates with bundle adjustment (``BA``)
    over a patch graph that grows by one frame per iteration.
    """

    def __init__(self, use_viewer=False):
        # use_viewer is accepted for interface compatibility; it is not used here.
        super(VONet, self).__init__()
        self.P = 3  # patch size
        self.patchify = Patchifier(self.P)
        self.update = Update(self.P)

        self.DIM = DIM
        self.RES = 4

    def forward(self, images, poses, disps, intrinsics, M=1024, STEPS=12, P=1, structure_only=False, rescale=False):
        """Run ``STEPS`` rounds of learned update + bundle adjustment on a frame sequence.

        Args:
            images: frames, rescaled here as ``2 * (images / 255) - 0.5``
                (presumably 0..255 pixel values -- TODO confirm).
            poses: ground-truth poses (SE3-like). Used for the supervision
                targets, and as the pose estimate when ``structure_only``.
            disps: ground-truth disparity maps, subsampled with ``[:, :, 1::4, 1::4]``.
            intrinsics: camera intrinsics, divided by 4 to match that subsampling.
            M, P, rescale: accepted but unused.
            STEPS: number of iterations, i.e. the length of the returned list.
            structure_only: if True, poses start from ``poses``, newly added
                frames get no propagated pose, and BA runs with ``structure_only=True``.

        Returns:
            List of ``STEPS`` tuples ``(valid, coords, coords_gt, Gs[:, :n], Ps[:, :n], kl)``.
            The coordinates are evaluated on edges with ``0 < |ii - jj| <= 2``,
            and ``kl`` is always a zero tensor.
        """
        images = 2 * (images / 255.0) - 0.5
        intrinsics = intrinsics / 4.0
        disps = disps[:, :, 1::4, 1::4].float()

        fmap, gmap, imap, patches, ix = self.patchify(images, disps=disps)

        corr_fn = CorrBlock(fmap, gmap)

        b, N, c, h, w = fmap.shape
        p = self.P

        # Allocate on the feature device instead of the implicit current
        # CUDA device ("cuda"), which breaks on a non-default GPU.
        device = fmap.device

        patches_gt = patches.clone()
        Ps = poses

        # Replace channel 2 of every patch (the value `set_depth` writes) with random values.
        d = patches[..., 2, p//2, p//2]
        patches = set_depth(patches, torch.rand_like(d))

        # Initial graph: every patch from frames 0..7 is connected to frames 0..7.
        kk, jj = flatmeshgrid(torch.where(ix < 8)[0], torch.arange(0, 8, device=device))
        ii = ix[kk]  # frame each edge's patch was sampled from

        imap = imap.view(b, -1, DIM)
        net = torch.zeros(b, len(kk), DIM, device=device, dtype=torch.float)

        Gs = SE3.IdentityLike(poses)

        if structure_only:
            Gs.data[:] = poses.data[:]

        traj = []
        bounds = [-64, -64, w + 64, h + 64]

        while len(traj) < STEPS:
            # No gradient flows through estimates from earlier iterations.
            Gs = Gs.detach()
            patches = patches.detach()

            n = ii.max() + 1
            if len(traj) >= 8 and n < images.shape[1]:
                # Add frame n: existing patches -> frame n, and frame n's patches -> frames 0..n.
                if not structure_only:
                    Gs.data[:, n] = Gs.data[:, n-1]  # start from the previous frame's pose
                kk1, jj1 = flatmeshgrid(torch.where(ix < n)[0], torch.arange(n, n+1, device=device))
                kk2, jj2 = flatmeshgrid(torch.where(ix == n)[0], torch.arange(0, n+1, device=device))

                ii = torch.cat([ix[kk1], ix[kk2], ii])
                jj = torch.cat([jj1, jj2, jj])
                kk = torch.cat([kk1, kk2, kk])

                # New edges start with a zero hidden state.
                net1 = torch.zeros(b, len(kk1) + len(kk2), DIM, device=device, dtype=torch.float)
                net = torch.cat([net1, net], dim=1)

                # With probability 0.1, drop every edge touching frame n - 4.
                if np.random.rand() < 0.1:
                    k = (ii != (n - 4)) & (jj != (n - 4))
                    ii = ii[k]
                    jj = jj[k]
                    kk = kk[k]
                    net = net[:, k]

                # Initialise the new frame's patch channel 2 with the median
                # of the patches from frames n-1 and n-2.
                patches[:, ix == n, 2] = torch.median(patches[:, (ix == n-1) | (ix == n-2), 2])
                n = ii.max() + 1

            coords = pops.transform(Gs, patches, intrinsics, ii, jj, kk)
            coords1 = coords.permute(0, 1, 4, 2, 3).contiguous()

            corr = corr_fn(kk, jj, coords1)
            net, (delta, weight, _) = self.update(net, imap[:, kk], corr, None, ii, jj, kk)

            lmbda = 1e-4
            target = coords[..., p//2, p//2, :] + delta

            ep = 10
            for _ in range(2):
                Gs, patches = BA(Gs, patches, intrinsics, target, weight, lmbda, ii, jj, kk,
                                 bounds, ep=ep, fixedp=1, structure_only=structure_only)

            kl = torch.as_tensor(0)
            dij = (ii - jj).abs()
            k = (dij > 0) & (dij <= 2)

            coords = pops.transform(Gs, patches, intrinsics, ii[k], jj[k], kk[k])
            coords_gt, valid, _ = pops.transform(Ps, patches_gt, intrinsics, ii[k], jj[k], kk[k], jacobian=True)

            traj.append((valid, coords, coords_gt, Gs[:, :n], Ps[:, :n], kl))

        return traj