import torch
import lietorch
import numpy as np
from droid_net import DroidNet
from depth_video import DepthVideo
from motion_filter import MotionFilter
from droid_frontend import DroidFrontend
from droid_backend import DroidBackend
from trajectory_filler import PoseTrajectoryFiller
from collections import OrderedDict
from torch.multiprocessing import Process
class Droid:
    """DROID-SLAM system wrapper.

    Wires together the shared depth-video buffer, a motion filter that drops
    near-static frames, the frontend (local BA) and backend (global BA)
    optimizers, an optional headless visualizer process, and a trajectory
    filler that recovers poses for non-keyframes at shutdown.
    """

    def __init__(self, args):
        super(Droid, self).__init__()
        # weights must be loaded first: self.net is consumed by every
        # component constructed below
        self.load_weights(args.weights)
        self.args = args
        self.disable_vis = args.disable_vis

        # store images, depth, poses, intrinsics (shared between processes)
        self.video = DepthVideo(args.image_size, args.buffer, stereo=args.stereo)

        # filter incoming frames so that there is enough motion
        self.filterx = MotionFilter(self.net, self.video, thresh=args.filter_thresh)

        # frontend: local (sliding-window) bundle adjustment
        self.frontend = DroidFrontend(self.net, self.video, self.args)

        # backend: global bundle adjustment, run at termination
        self.backend = DroidBackend(self.net, self.video, self.args)

        # visualizer runs in a separate process reading the shared video buffer
        if not self.disable_vis:
            from vis_headless import droid_visualization
            print('Using headless ...')
            self.visualizer = Process(target=droid_visualization, args=(self.video, '.'))
            self.visualizer.start()

        # post processor - fill in poses for non-keyframes
        self.traj_filler = PoseTrajectoryFiller(self.net, self.video)

    def load_weights(self, weights):
        """ load trained model weights

        Args:
            weights: path to a DroidNet checkpoint (a state_dict saved from a
                DataParallel-wrapped model, hence the "module." prefixes).
        """
        self.net = DroidNet()
        # map_location="cpu" so loading never depends on the device the
        # checkpoint was saved from; the model is moved to cuda:0 below
        state_dict = OrderedDict([
            (k.replace("module.", ""), v)
            for (k, v) in torch.load(weights, map_location="cpu").items()])

        # keep only the first 2 output channels of the update heads
        # (checkpoint may carry extra channels, e.g. from stereo training)
        state_dict["update.weight.2.weight"] = state_dict["update.weight.2.weight"][:2]
        state_dict["update.weight.2.bias"] = state_dict["update.weight.2.bias"][:2]
        state_dict["update.delta.2.weight"] = state_dict["update.delta.2.weight"][:2]
        state_dict["update.delta.2.bias"] = state_dict["update.delta.2.bias"][:2]

        self.net.load_state_dict(state_dict)
        self.net.to("cuda:0").eval()

    def track(self, tstamp, image, depth=None, intrinsics=None, mask=None):
        """ main thread - update map

        Feeds one frame through the motion filter (which decides whether it
        becomes a keyframe) and then runs local bundle adjustment. Global BA
        is deferred to terminate().
        """
        with torch.no_grad():
            # check there is enough motion
            self.filterx.track(tstamp, image, depth, intrinsics, mask)

            # local bundle adjustment
            self.frontend()

    def terminate(self, stream=None, backend=True):
        """ terminate the visualization process, return poses [t, q]

        Args:
            stream: frame iterator passed to the trajectory filler to recover
                poses for non-keyframes.
            backend: if True, run two rounds of global bundle adjustment
                before filling the trajectory.

        Returns:
            numpy array of inverse camera poses, shape [N, 7] (t, q).
        """
        # guard the delete so terminate() and compute_error() can be called
        # in either order without raising AttributeError
        if hasattr(self, "frontend"):
            del self.frontend

        if backend:
            torch.cuda.empty_cache()
            self.backend(7)

            torch.cuda.empty_cache()
            self.backend(12)

        camera_trajectory = self.traj_filler(stream)
        return camera_trajectory.inv().data.cpu().numpy()

    def compute_error(self):
        """ compute slam reprojection error

        Runs a final round of global bundle adjustment and returns the last
        recorded backend error.
        """
        # frontend may already have been deleted by terminate()
        if hasattr(self, "frontend"):
            del self.frontend

        torch.cuda.empty_cache()
        self.backend(12)

        return self.backend.errors[-1]