Spaces:
Running
Running
import torch | |
import lietorch | |
import numpy as np | |
import matplotlib.pyplot as plt | |
from lietorch import SE3 | |
from modules.corr import CorrBlock, AltCorrBlock | |
import geom.projective_ops as pops | |
from glob import glob | |
class FactorGraph:
    """Graph of frame-to-frame edges ``ii[k] -> jj[k]`` over a shared video buffer.

    Every active edge k carries:
      * ``net[:, k]`` / ``inp[:, k]`` -- features taken from ``video.nets`` /
        ``video.inps`` of the source frame and refined by ``update_op``;
      * a correlation entry (``corr``, only when ``corr_impl == "volume"``);
      * ``target[:, k]`` and ``weight[:, k]`` of shape ``[ht, wd, 2]`` at 1/8
        of the video resolution;
      * ``age[k]``, incremented on every :meth:`update`.

    Edges removed with ``store=True`` become *inactive* (``*_inac``): their
    last target/weight can still be fed to bundle adjustment. Edges rejected
    by :meth:`filter_edges` are remembered as *bad* (indices only) so
    :meth:`add_proximity_factors` does not immediately re-add them.
    """

    def __init__(self, video, update_op, device="cuda:0", corr_impl="volume", max_factors=-1, upsample=False):
        """
        Args:
            video: shared frame buffer (poses, disparities, features, masks, ...).
            update_op: network called as ``update_op(net, inp, corr, motn, ii, jj)``
                returning ``(net, delta, weight, damping, upmask)``.
            device: device on which edge tensors are allocated.
            corr_impl: ``"volume"`` builds a ``CorrBlock`` per edge in
                :meth:`add_factors` (required by :meth:`update`); other values
                skip it (:meth:`update_lowmem` builds correlation on the fly).
            max_factors: cap on the number of active edges; values <= 0
                disable the cap.
            upsample: if true, call ``video.upsample`` after each update.
        """
        self.video = video
        self.update_op = update_op
        self.device = device
        self.max_factors = max_factors
        self.corr_impl = corr_impl
        self.upsample = upsample

        # operator at 1/8 resolution
        self.ht = ht = video.ht // 8
        self.wd = wd = video.wd // 8

        self.coords0 = pops.coords_grid(ht, wd, device=device)
        self.ii = torch.as_tensor([], dtype=torch.long, device=device)
        self.jj = torch.as_tensor([], dtype=torch.long, device=device)
        self.age = torch.as_tensor([], dtype=torch.long, device=device)

        self.corr, self.net, self.inp = None, None, None
        # per-frame damping, indexed by frame id in update()
        self.damping = 1e-6 * torch.ones_like(self.video.disps)

        self.target = torch.zeros([1, 0, ht, wd, 2], device=device, dtype=torch.float)
        self.weight = torch.zeros([1, 0, ht, wd, 2], device=device, dtype=torch.float)

        # inactive factors
        self.ii_inac = torch.as_tensor([], dtype=torch.long, device=device)
        self.jj_inac = torch.as_tensor([], dtype=torch.long, device=device)
        # edges rejected by filter_edges (indices only)
        self.ii_bad = torch.as_tensor([], dtype=torch.long, device=device)
        self.jj_bad = torch.as_tensor([], dtype=torch.long, device=device)

        self.target_inac = torch.zeros([1, 0, ht, wd, 2], device=device, dtype=torch.float)
        self.weight_inac = torch.zeros([1, 0, ht, wd, 2], device=device, dtype=torch.float)

    def __filter_repeated_edges(self, ii, jj):
        """Return the edges of ``(ii, jj)`` that are not yet in the graph.

        An edge is dropped when it matches an active edge, an inactive edge,
        or an earlier edge of the same batch, so the result is duplicate-free.
        """
        keep = torch.zeros(ii.shape[0], dtype=torch.bool, device=ii.device)
        eset = set(zip(self.ii.tolist(), self.jj.tolist()))
        eset.update(zip(self.ii_inac.tolist(), self.jj_inac.tolist()))

        for k, edge in enumerate(zip(ii.tolist(), jj.tolist())):
            if edge not in eset:
                keep[k] = True
                eset.add(edge)  # also reject later copies within this batch

        return ii[keep], jj[keep]

    def print_edges(self):
        """Print ``(i, j, mean weight)`` for every active edge, sorted by i."""
        ii = self.ii.cpu().numpy()
        jj = self.jj.cpu().numpy()

        ix = np.argsort(ii)
        ii = ii[ix]
        jj = jj[ix]

        w = torch.mean(self.weight, dim=[0, 2, 3, 4]).cpu().numpy()
        w = w[ix]
        for e in zip(ii, jj, w):
            print(e)
        print()

    def filter_edges(self):
        """Drop low-confidence long edges and remember them as bad edges.

        An edge is dropped when it spans more than 2 frames and its mean
        weight is below 1e-3. Dropped edges are not stored as inactive.
        """
        conf = torch.mean(self.weight, dim=[0, 2, 3, 4])
        mask = (torch.abs(self.ii - self.jj) > 2) & (conf < 0.001)

        self.ii_bad = torch.cat([self.ii_bad, self.ii[mask]])
        self.jj_bad = torch.cat([self.jj_bad, self.jj[mask]])
        self.rm_factors(mask, store=False)

    def clear_edges(self):
        """Remove every active edge (without storing it) and reset net/inp."""
        self.rm_factors(self.ii >= 0)
        self.net = None
        self.inp = None

    def add_factors(self, ii, jj, remove=False):
        """Add edges ``ii[k] -> jj[k]`` to the factor graph.

        Edges already present (active or inactive), and repeats within the
        batch, are skipped. New edges start with age 0, a target taken from
        ``video.reproject(ii, jj)`` and zero weight.

        Args:
            ii, jj: source / destination frame indices (tensor or sequence).
            remove: when true, ``max_factors > 0`` and a correlation exists,
                first move the oldest active edges to the inactive set so that
                ``max_factors - len(new edges)`` old edges remain.
        """
        if not isinstance(ii, torch.Tensor):
            ii = torch.as_tensor(ii, dtype=torch.long, device=self.device)

        if not isinstance(jj, torch.Tensor):
            jj = torch.as_tensor(jj, dtype=torch.long, device=self.device)

        # remove duplicate edges
        ii, jj = self.__filter_repeated_edges(ii, jj)

        if ii.shape[0] == 0:
            return

        # place limit on number of factors
        if self.max_factors > 0 and self.ii.shape[0] + ii.shape[0] > self.max_factors \
                and self.corr is not None and remove:
            # argsort returns a permutation, not each edge's rank; invert it so
            # rank[k] is edge k's position when sorted youngest-first, then
            # evict every edge whose rank is past the number we can keep.
            order = torch.argsort(self.age)
            rank = torch.empty_like(order)
            rank[order] = torch.arange(order.shape[0], device=order.device)
            self.rm_factors(rank >= self.max_factors - ii.shape[0], store=True)

        net = self.video.nets[ii].to(self.device).unsqueeze(0)

        # correlation volume for new edges
        if self.corr_impl == "volume":
            # self-edges (i == i) correlate against feature map 1 of the frame
            c = (ii == jj).long()
            fmap1 = self.video.fmaps[ii, 0].to(self.device).unsqueeze(0)
            fmap2 = self.video.fmaps[jj, c].to(self.device).unsqueeze(0)
            corr = CorrBlock(fmap1, fmap2)
            self.corr = corr if self.corr is None else self.corr.cat(corr)

            inp = self.video.inps[ii].to(self.device).unsqueeze(0)
            self.inp = inp if self.inp is None else torch.cat([self.inp, inp], 1)

        with torch.cuda.amp.autocast(enabled=False):
            target, _ = self.video.reproject(ii, jj)
            weight = torch.zeros_like(target)

        self.ii = torch.cat([self.ii, ii], 0)
        self.jj = torch.cat([self.jj, jj], 0)
        self.age = torch.cat([self.age, torch.zeros_like(ii)], 0)

        # reprojection factors
        self.net = net if self.net is None else torch.cat([self.net, net], 1)

        self.target = torch.cat([self.target, target], 1)
        self.weight = torch.cat([self.weight, weight], 1)

    def rm_factors(self, mask, store=False):
        """Drop the active edges selected by the boolean ``mask``.

        Args:
            mask: boolean tensor with one entry per active edge.
            store: if true, append the dropped edges (indices, target and
                weight) to the inactive set.
        """
        # store estimated factors
        if store:
            self.ii_inac = torch.cat([self.ii_inac, self.ii[mask]], 0)
            self.jj_inac = torch.cat([self.jj_inac, self.jj[mask]], 0)
            self.target_inac = torch.cat([self.target_inac, self.target[:, mask]], 1)
            self.weight_inac = torch.cat([self.weight_inac, self.weight[:, mask]], 1)

        keep = ~mask
        self.ii = self.ii[keep]
        self.jj = self.jj[keep]
        self.age = self.age[keep]

        # corr is None until the first add_factors call
        if self.corr_impl == "volume" and self.corr is not None:
            self.corr = self.corr[keep]

        if self.net is not None:
            self.net = self.net[:, keep]

        if self.inp is not None:
            self.inp = self.inp[:, keep]

        self.target = self.target[:, keep]
        self.weight = self.weight[:, keep]

    def rm_keyframe(self, ix):
        """Remove keyframe ``ix`` and every edge that touches it.

        Buffer slot ``ix`` is overwritten with slot ``ix + 1`` (under the
        video lock). Edge indices >= ix (active, inactive and bad) are then
        decremented by one. Only one slot is copied here, so this matches the
        buffer only if ``ix + 1`` is the last slot that must move -- TODO
        confirm against callers.
        """
        with self.video.get_lock():
            self.video.images[ix] = self.video.images[ix+1]
            self.video.poses[ix] = self.video.poses[ix+1]
            self.video.disps[ix] = self.video.disps[ix+1]
            self.video.disps_sens[ix] = self.video.disps_sens[ix+1]
            self.video.intrinsics[ix] = self.video.intrinsics[ix+1]

            self.video.nets[ix] = self.video.nets[ix+1]
            self.video.inps[ix] = self.video.inps[ix+1]
            self.video.fmaps[ix] = self.video.fmaps[ix+1]
            self.video.tstamp[ix] = self.video.tstamp[ix+1]
            self.video.masks[ix] = self.video.masks[ix+1]

        # inactive edges: shift indices and drop edges touching ix
        m = (self.ii_inac == ix) | (self.jj_inac == ix)
        self.ii_inac[self.ii_inac >= ix] -= 1
        self.jj_inac[self.jj_inac >= ix] -= 1

        if torch.any(m):
            self.ii_inac = self.ii_inac[~m]
            self.jj_inac = self.jj_inac[~m]
            self.target_inac = self.target_inac[:, ~m]
            self.weight_inac = self.weight_inac[:, ~m]

        # bad edges are frame indices too; keep them aligned with the buffer
        m = (self.ii_bad == ix) | (self.jj_bad == ix)
        self.ii_bad[self.ii_bad >= ix] -= 1
        self.jj_bad[self.jj_bad >= ix] -= 1
        self.ii_bad = self.ii_bad[~m]
        self.jj_bad = self.jj_bad[~m]

        # active edges
        m = (self.ii == ix) | (self.jj == ix)

        self.ii[self.ii >= ix] -= 1
        self.jj[self.jj >= ix] -= 1
        self.rm_factors(m, store=False)

    def update(self, t0=None, t1=None, itrs=3, use_inactive=False, EP=1e-7, motion_only=False):
        """Run the update operator once on all active edges, then dense BA.

        Requires ``corr_impl == "volume"`` (uses ``self.corr``). Confidence
        weights are zeroed wherever ``video.masks[ii] > 0``.

        Args:
            t0, t1: frame range forwarded to ``video.ba``; ``t0`` defaults to
                ``max(1, min(ii) + 1)``.
            itrs: bundle-adjustment iterations.
            use_inactive: also feed inactive edges whose endpoints are both
                >= ``t0 - 3`` to bundle adjustment.
            EP: constant added to the (scaled) damping.
            motion_only: forwarded to ``video.ba``.
        """
        # motion features
        with torch.cuda.amp.autocast(enabled=False):
            coords1, _ = self.video.reproject(self.ii, self.jj)
            motn = torch.cat([coords1 - self.coords0, self.target - coords1], dim=-1)
            motn = motn.permute(0, 1, 4, 2, 3).clamp(-64.0, 64.0)

        # correlation features
        corr = self.corr(coords1)

        self.net, delta, weight, damping, upmask = \
            self.update_op(self.net, self.inp, corr, motn, self.ii, self.jj)

        # Shapes:
        #   weight: [1, k, h//8, w//8, 2]
        #   self.ii: [k]; self.jj: [k]
        # Zero the confidence wherever the source frame's mask is positive
        # (presumably regions to exclude from BA -- confirm mask semantics).
        msk = self.video.masks[self.ii] > 0
        weight[:, msk] = 0.0

        if t0 is None:
            t0 = max(1, self.ii.min().item() + 1)

        with torch.cuda.amp.autocast(enabled=False):
            self.target = coords1 + delta.to(dtype=torch.float)
            self.weight = weight.to(dtype=torch.float)

            ht, wd = self.coords0.shape[0:2]
            self.damping[torch.unique(self.ii)] = damping

            if use_inactive:
                m = (self.ii_inac >= t0 - 3) & (self.jj_inac >= t0 - 3)
                ii = torch.cat([self.ii_inac[m], self.ii], 0)
                jj = torch.cat([self.jj_inac[m], self.jj], 0)
                target = torch.cat([self.target_inac[:, m], self.target], 1)
                weight = torch.cat([self.weight_inac[:, m], self.weight], 1)
            else:
                ii, jj, target, weight = self.ii, self.jj, self.target, self.weight

            damping = .2 * self.damping[torch.unique(ii)].contiguous() + EP

            target = target.view(-1, ht, wd, 2).permute(0, 3, 1, 2).contiguous()
            weight = weight.view(-1, ht, wd, 2).permute(0, 3, 1, 2).contiguous()

            # dense bundle adjustment
            self.video.ba(target, weight, damping, ii, jj, t0, t1,
                          itrs=itrs, lm=1e-4, ep=0.1, motion_only=motion_only)

            if self.upsample:
                self.video.upsample(torch.unique(self.ii), upmask)

        self.age += 1

    def update_lowmem(self, t0=None, t1=None, itrs=2, use_inactive=False, EP=1e-7, steps=8):
        """Reduced-memory global update: chunked update operator + BA, ``steps`` times.

        Correlation is computed on the fly with ``AltCorrBlock``, and edges are
        processed in chunks of 8 source frames. Bundle adjustment always runs
        over frames [1, ``video.counter.value``). ``t0``, ``t1`` and
        ``use_inactive`` are accepted for signature compatibility but unused.

        Args:
            itrs: bundle-adjustment iterations per step.
            EP: constant added to the (scaled) damping.
            steps: number of update + BA rounds.
        """
        # alternate corr implementation
        t = self.video.counter.value

        num, rig, ch, fht, fwd = self.video.fmaps.shape
        corr_op = AltCorrBlock(self.video.fmaps.view(1, num * rig, ch, fht, fwd))
        ht, wd = self.coords0.shape[0:2]

        print("Global BA Iteration with {} steps".format(steps))
        for _ in range(steps):
            with torch.cuda.amp.autocast(enabled=False):
                coords1, _ = self.video.reproject(self.ii, self.jj)
                motn = torch.cat([coords1 - self.coords0, self.target - coords1], dim=-1)
                motn = motn.permute(0, 1, 4, 2, 3).clamp(-64.0, 64.0)

            s = 8
            # Chunks are selected by source frame ii, so iterate up to max(ii).
            for i in range(0, int(self.ii.max().item()) + 1, s):
                v = (self.ii >= i) & (self.ii < i + s)
                if not v.any():
                    continue  # no edges start in this chunk

                iis = self.ii[v]
                jjs = self.jj[v]

                corr1 = corr_op(coords1[:, v], rig * iis, rig * jjs + (iis == jjs).long())

                with torch.cuda.amp.autocast(enabled=True):
                    net, delta, weight, damping, upmask = \
                        self.update_op(self.net[:, v], self.video.inps[None, iis], corr1, motn[:, v], iis, jjs)

                    if self.upsample:
                        self.video.upsample(torch.unique(iis), upmask)

                # Shapes:
                #   weight: [1, k, h//8, w//8, 2]
                #   iis: [k]; jjs: [k]
                msk = self.video.masks[iis] > 0
                weight[:, msk] = 0.0

                self.net[:, v] = net
                self.target[:, v] = coords1[:, v] + delta.float()
                self.weight[:, v] = weight.float()
                self.damping[torch.unique(iis)] = damping

            damping = .2 * self.damping[torch.unique(self.ii)].contiguous() + EP
            target = self.target.view(-1, ht, wd, 2).permute(0, 3, 1, 2).contiguous()
            weight = self.weight.view(-1, ht, wd, 2).permute(0, 3, 1, 2).contiguous()

            # dense bundle adjustment
            self.video.ba(target, weight, damping, self.ii, self.jj, 1, t,
                          itrs=itrs, lm=1e-5, ep=1e-2, motion_only=False)

            self.video.dirty[:t] = True

    def add_neighborhood_factors(self, t0, t1, r=3):
        """Add edges between frames in [t0, t1) whose index gap is in (c, r].

        ``c`` is 1 when ``video.stereo`` is set and 0 otherwise, so self-edges
        are never added here.
        """
        ii, jj = torch.meshgrid(torch.arange(t0, t1), torch.arange(t0, t1), indexing='ij')
        ii = ii.reshape(-1).to(dtype=torch.long, device=self.device)
        jj = jj.reshape(-1).to(dtype=torch.long, device=self.device)

        c = 1 if self.video.stereo else 0

        keep = ((ii - jj).abs() > c) & ((ii - jj).abs() <= r)
        self.add_factors(ii[keep], jj[keep])

    def add_proximity_factors(self, t0=0, t1=0, rad=2, nms=2, beta=0.25, thresh=16.0, remove=False):
        """Add local edges plus long-range edges chosen by frame distance.

        With ``t = video.counter.value``:
          1. every frame i in [t0, t) gets edges to and from frames
             [max(i - rad - 1, 0), i), plus a self-edge in stereo mode;
          2. candidates (i, j) with i in [t0, t), j in [t1, t) and
             j <= i - rad are visited by increasing ``video.distance``; each
             candidate with distance <= ``thresh`` is added in both directions
             and suppresses nearby candidates (L1 radius up to ``nms`` in
             index space). Existing, bad and inactive edges suppress their
             neighbourhoods up front.
        When ``max_factors > 0``, selection stops once more than
        ``max_factors`` edges have been collected.

        Args:
            t0: first source frame considered.
            t1: first destination frame considered for long-range edges.
            rad: local window radius.
            nms: suppression radius.
            beta: forwarded to ``video.distance``.
            thresh: maximum distance for a long-range edge.
            remove: forwarded to :meth:`add_factors`.
        """
        t = self.video.counter.value
        ix = torch.arange(t0, t)
        jx = torch.arange(t1, t)

        ii, jj = torch.meshgrid(ix, jx, indexing='ij')
        ii = ii.reshape(-1)
        jj = jj.reshape(-1)

        d = self.video.distance(ii, jj, beta=beta)
        d[ii - rad < jj] = np.inf
        d[d > 100] = np.inf

        ncols = t - t1

        def suppress(i, j):
            """Mark candidate (i, j) unusable, if it lies inside the search grid."""
            if t0 <= i < t and t1 <= j < t:
                d[(i - t0) * ncols + (j - t1)] = np.inf

        def suppress_around(i, j):
            """Suppress candidates within an L1 ball around edge (i, j)."""
            radius = max(min(abs(i - j) - 2, nms), 0)
            for di in range(-nms, nms + 1):
                for dj in range(-nms, nms + 1):
                    if abs(di) + abs(dj) <= radius:
                        suppress(i + di, j + dj)

        ii1 = torch.cat([self.ii, self.ii_bad, self.ii_inac], 0)
        jj1 = torch.cat([self.jj, self.jj_bad, self.jj_inac], 0)
        for i, j in zip(ii1.tolist(), jj1.tolist()):
            suppress_around(i, j)

        es = []
        for i in range(t0, t):
            if self.video.stereo:
                es.append((i, i))
                suppress(i, i)

            for j in range(max(i - rad - 1, 0), i):
                es.append((i, j))
                es.append((j, i))
                suppress(i, j)

        for k in torch.argsort(d).tolist():
            if d[k].item() > thresh:
                continue

            # max_factors <= 0 means "no limit" (same convention as add_factors)
            if self.max_factors > 0 and len(es) > self.max_factors:
                break

            i = int(ii[k])
            j = int(jj[k])

            # bidirectional
            es.append((i, j))
            es.append((j, i))
            suppress_around(i, j)

        if not es:
            return

        ii, jj = torch.as_tensor(es, dtype=torch.long, device=self.device).unbind(dim=-1)
        self.add_factors(ii, jj, remove)