import torch import lietorch import numpy as np import matplotlib.pyplot as plt from lietorch import SE3 from modules.corr import CorrBlock, AltCorrBlock import geom.projective_ops as pops from glob import glob class FactorGraph: def __init__(self, video, update_op, device="cuda:0", corr_impl="volume", max_factors=-1, upsample=False): self.video = video self.update_op = update_op self.device = device self.max_factors = max_factors self.corr_impl = corr_impl self.upsample = upsample # operator at 1/8 resolution self.ht = ht = video.ht // 8 self.wd = wd = video.wd // 8 self.coords0 = pops.coords_grid(ht, wd, device=device) self.ii = torch.as_tensor([], dtype=torch.long, device=device) self.jj = torch.as_tensor([], dtype=torch.long, device=device) self.age = torch.as_tensor([], dtype=torch.long, device=device) self.corr, self.net, self.inp = None, None, None self.damping = 1e-6 * torch.ones_like(self.video.disps) self.target = torch.zeros([1, 0, ht, wd, 2], device=device, dtype=torch.float) self.weight = torch.zeros([1, 0, ht, wd, 2], device=device, dtype=torch.float) # inactive factors self.ii_inac = torch.as_tensor([], dtype=torch.long, device=device) self.jj_inac = torch.as_tensor([], dtype=torch.long, device=device) self.ii_bad = torch.as_tensor([], dtype=torch.long, device=device) self.jj_bad = torch.as_tensor([], dtype=torch.long, device=device) self.target_inac = torch.zeros([1, 0, ht, wd, 2], device=device, dtype=torch.float) self.weight_inac = torch.zeros([1, 0, ht, wd, 2], device=device, dtype=torch.float) def __filter_repeated_edges(self, ii, jj): """ remove duplicate edges """ keep = torch.zeros(ii.shape[0], dtype=torch.bool, device=ii.device) eset = set( [(i.item(), j.item()) for i, j in zip(self.ii, self.jj)] + [(i.item(), j.item()) for i, j in zip(self.ii_inac, self.jj_inac)]) for k, (i, j) in enumerate(zip(ii, jj)): keep[k] = (i.item(), j.item()) not in eset return ii[keep], jj[keep] def print_edges(self): ii = self.ii.cpu().numpy() jj = self.jj.cpu().numpy() ix = np.argsort(ii) ii = ii[ix] jj = jj[ix] w = torch.mean(self.weight, dim=[0,2,3,4]).cpu().numpy() w = w[ix] for e in zip(ii, jj, w): print(e) print() def filter_edges(self): """ remove bad edges """ conf = torch.mean(self.weight, dim=[0,2,3,4]) mask = (torch.abs(self.ii-self.jj) > 2) & (conf < 0.001) self.ii_bad = torch.cat([self.ii_bad, self.ii[mask]]) self.jj_bad = torch.cat([self.jj_bad, self.jj[mask]]) self.rm_factors(mask, store=False) def clear_edges(self): self.rm_factors(self.ii >= 0) self.net = None self.inp = None @torch.cuda.amp.autocast(enabled=True) def add_factors(self, ii, jj, remove=False): """ add edges to factor graph """ if not isinstance(ii, torch.Tensor): ii = torch.as_tensor(ii, dtype=torch.long, device=self.device) if not isinstance(jj, torch.Tensor): jj = torch.as_tensor(jj, dtype=torch.long, device=self.device) # remove duplicate edges ii, jj = self.__filter_repeated_edges(ii, jj) if ii.shape[0] == 0: return # place limit on number of factors if self.max_factors > 0 and self.ii.shape[0] + ii.shape[0] > self.max_factors \ and self.corr is not None and remove: ix = torch.arange(len(self.age))[torch.argsort(self.age).cpu()] self.rm_factors(ix >= self.max_factors - ii.shape[0], store=True) net = self.video.nets[ii].to(self.device).unsqueeze(0) # correlation volume for new edges if self.corr_impl == "volume": c = (ii == jj).long() fmap1 = self.video.fmaps[ii,0].to(self.device).unsqueeze(0) fmap2 = self.video.fmaps[jj,c].to(self.device).unsqueeze(0) corr = CorrBlock(fmap1, fmap2) self.corr = corr if self.corr is None else self.corr.cat(corr) inp = self.video.inps[ii].to(self.device).unsqueeze(0) self.inp = inp if self.inp is None else torch.cat([self.inp, inp], 1) with torch.cuda.amp.autocast(enabled=False): target, _ = self.video.reproject(ii, jj) weight = torch.zeros_like(target) self.ii = torch.cat([self.ii, ii], 0) self.jj = torch.cat([self.jj, jj], 0) self.age = torch.cat([self.age, torch.zeros_like(ii)], 0) # reprojection factors self.net = net if self.net is None else torch.cat([self.net, net], 1) self.target = torch.cat([self.target, target], 1) self.weight = torch.cat([self.weight, weight], 1) @torch.cuda.amp.autocast(enabled=True) def rm_factors(self, mask, store=False): """ drop edges from factor graph """ # store estimated factors if store: self.ii_inac = torch.cat([self.ii_inac, self.ii[mask]], 0) self.jj_inac = torch.cat([self.jj_inac, self.jj[mask]], 0) self.target_inac = torch.cat([self.target_inac, self.target[:,mask]], 1) self.weight_inac = torch.cat([self.weight_inac, self.weight[:,mask]], 1) self.ii = self.ii[~mask] self.jj = self.jj[~mask] self.age = self.age[~mask] if self.corr_impl == "volume": self.corr = self.corr[~mask] if self.net is not None: self.net = self.net[:,~mask] if self.inp is not None: self.inp = self.inp[:,~mask] self.target = self.target[:,~mask] self.weight = self.weight[:,~mask] @torch.cuda.amp.autocast(enabled=True) def rm_keyframe(self, ix): """ drop edges from factor graph """ with self.video.get_lock(): self.video.images[ix] = self.video.images[ix+1] self.video.poses[ix] = self.video.poses[ix+1] self.video.disps[ix] = self.video.disps[ix+1] self.video.disps_sens[ix] = self.video.disps_sens[ix+1] self.video.intrinsics[ix] = self.video.intrinsics[ix+1] self.video.nets[ix] = self.video.nets[ix+1] self.video.inps[ix] = self.video.inps[ix+1] self.video.fmaps[ix] = self.video.fmaps[ix+1] self.video.tstamp[ix] = self.video.tstamp[ix+1] self.video.masks[ix] = self.video.masks[ix+1] m = (self.ii_inac == ix) | (self.jj_inac == ix) self.ii_inac[self.ii_inac >= ix] -= 1 self.jj_inac[self.jj_inac >= ix] -= 1 if torch.any(m): self.ii_inac = self.ii_inac[~m] self.jj_inac = self.jj_inac[~m] self.target_inac = self.target_inac[:,~m] self.weight_inac = self.weight_inac[:,~m] m = (self.ii == ix) | (self.jj == ix) self.ii[self.ii >= ix] -= 1 self.jj[self.jj >= ix] -= 1 self.rm_factors(m, store=False) @torch.cuda.amp.autocast(enabled=True) def update(self, t0=None, t1=None, itrs=3, use_inactive=False, EP=1e-7, motion_only=False): """ run update operator on factor graph """ # motion features with torch.cuda.amp.autocast(enabled=False): coords1, mask = self.video.reproject(self.ii, self.jj) motn = torch.cat([coords1 - self.coords0, self.target - coords1], dim=-1) motn = motn.permute(0,1,4,2,3).clamp(-64.0, 64.0) # correlation features corr = self.corr(coords1) self.net, delta, weight, damping, upmask = \ self.update_op(self.net, self.inp, corr, motn, self.ii, self.jj) ##### save confidecnce weight for vis ##### # for k in range(len(self.ii)): # w = weight[:, k].detach().cpu().numpy() # idx_i = self.ii[k] # idx_j = self.jj[k] # np.save(f'pred_conf/{idx_i:04d}_{idx_j:04d}.npy', w) ############################################# # Shapes: # weight: [1, k, h//8, w//8, 2] # self.ii: [k]; self.jj: [k] msk = self.video.masks[self.ii] > 0 weight[:,msk] = 0.0 if t0 is None: t0 = max(1, self.ii.min().item()+1) with torch.cuda.amp.autocast(enabled=False): self.target = coords1 + delta.to(dtype=torch.float) self.weight = weight.to(dtype=torch.float) ht, wd = self.coords0.shape[0:2] self.damping[torch.unique(self.ii)] = damping if use_inactive: m = (self.ii_inac >= t0 - 3) & (self.jj_inac >= t0 - 3) ii = torch.cat([self.ii_inac[m], self.ii], 0) jj = torch.cat([self.jj_inac[m], self.jj], 0) target = torch.cat([self.target_inac[:,m], self.target], 1) weight = torch.cat([self.weight_inac[:,m], self.weight], 1) else: ii, jj, target, weight = self.ii, self.jj, self.target, self.weight damping = .2 * self.damping[torch.unique(ii)].contiguous() + EP target = target.view(-1, ht, wd, 2).permute(0,3,1,2).contiguous() weight = weight.view(-1, ht, wd, 2).permute(0,3,1,2).contiguous() # dense bundle adjustment self.video.ba(target, weight, damping, ii, jj, t0, t1, itrs=itrs, lm=1e-4, ep=0.1, motion_only=motion_only) if self.upsample: self.video.upsample(torch.unique(self.ii), upmask) self.age += 1 @torch.cuda.amp.autocast(enabled=False) def update_lowmem(self, t0=None, t1=None, itrs=2, use_inactive=False, EP=1e-7, steps=8): """ run update operator on factor graph - reduced memory implementation """ # alternate corr implementation t = self.video.counter.value num, rig, ch, ht, wd = self.video.fmaps.shape corr_op = AltCorrBlock(self.video.fmaps.view(1, num*rig, ch, ht, wd)) print("Global BA Iteration with {} steps".format(steps)) for step in range(steps): # print("Global BA Iteration #{}".format(step+1)) with torch.cuda.amp.autocast(enabled=False): coords1, mask = self.video.reproject(self.ii, self.jj) motn = torch.cat([coords1 - self.coords0, self.target - coords1], dim=-1) motn = motn.permute(0,1,4,2,3).clamp(-64.0, 64.0) s = 8 for i in range(0, self.jj.max()+1, s): v = (self.ii >= i) & (self.ii < i + s) iis = self.ii[v] jjs = self.jj[v] ht, wd = self.coords0.shape[0:2] corr1 = corr_op(coords1[:,v], rig * iis, rig * jjs + (iis == jjs).long()) with torch.cuda.amp.autocast(enabled=True): net, delta, weight, damping, upmask = \ self.update_op(self.net[:,v], self.video.inps[None,iis], corr1, motn[:,v], iis, jjs) if self.upsample: self.video.upsample(torch.unique(iis), upmask) # Shapes: # weight: [1, k, h//8, w//8, 2] # self.ii: [k]; self.jj: [k] msk = self.video.masks[iis] > 0 weight[:,msk] = 0.0 self.net[:,v] = net self.target[:,v] = coords1[:,v] + delta.float() self.weight[:,v] = weight.float() self.damping[torch.unique(iis)] = damping damping = .2 * self.damping[torch.unique(self.ii)].contiguous() + EP target = self.target.view(-1, ht, wd, 2).permute(0,3,1,2).contiguous() weight = self.weight.view(-1, ht, wd, 2).permute(0,3,1,2).contiguous() # dense bundle adjustment self.video.ba(target, weight, damping, self.ii, self.jj, 1, t, itrs=itrs, lm=1e-5, ep=1e-2, motion_only=False) self.video.dirty[:t] = True def add_neighborhood_factors(self, t0, t1, r=3): """ add edges between neighboring frames within radius r """ ii, jj = torch.meshgrid(torch.arange(t0,t1), torch.arange(t0,t1), indexing='ij') ii = ii.reshape(-1).to(dtype=torch.long, device=self.device) jj = jj.reshape(-1).to(dtype=torch.long, device=self.device) c = 1 if self.video.stereo else 0 keep = ((ii - jj).abs() > c) & ((ii - jj).abs() <= r) self.add_factors(ii[keep], jj[keep]) def add_proximity_factors(self, t0=0, t1=0, rad=2, nms=2, beta=0.25, thresh=16.0, remove=False): """ add edges to the factor graph based on distance """ t = self.video.counter.value ix = torch.arange(t0, t) jx = torch.arange(t1, t) ii, jj = torch.meshgrid(ix, jx, indexing='ij') ii = ii.reshape(-1) jj = jj.reshape(-1) d = self.video.distance(ii, jj, beta=beta) d[ii - rad < jj] = np.inf d[d > 100] = np.inf ii1 = torch.cat([self.ii, self.ii_bad, self.ii_inac], 0) jj1 = torch.cat([self.jj, self.jj_bad, self.jj_inac], 0) for i, j in zip(ii1.cpu().numpy(), jj1.cpu().numpy()): for di in range(-nms, nms+1): for dj in range(-nms, nms+1): if abs(di) + abs(dj) <= max(min(abs(i-j)-2, nms), 0): i1 = i + di j1 = j + dj if (t0 <= i1 < t) and (t1 <= j1 < t): d[(i1-t0)*(t-t1) + (j1-t1)] = np.inf es = [] for i in range(t0, t): if self.video.stereo: es.append((i, i)) d[(i-t0)*(t-t1) + (i-t1)] = np.inf for j in range(max(i-rad-1,0), i): es.append((i,j)) es.append((j,i)) d[(i-t0)*(t-t1) + (j-t1)] = np.inf ix = torch.argsort(d) for k in ix: if d[k].item() > thresh: continue if len(es) > self.max_factors: break i = ii[k] j = jj[k] # bidirectional es.append((i, j)) es.append((j, i)) for di in range(-nms, nms+1): for dj in range(-nms, nms+1): if abs(di) + abs(dj) <= max(min(abs(i-j)-2, nms), 0): i1 = i + di j1 = j + dj if (t0 <= i1 < t) and (t1 <= j1 < t): d[(i1-t0)*(t-t1) + (j1-t1)] = np.inf ii, jj = torch.as_tensor(es, device=self.device).unbind(dim=-1) self.add_factors(ii, jj, remove)