|
"""Wrapper for DUSt3R model to estimate focal length. |
|
|
|
DUSt3R: Geometric 3D Vision Made Easy, https://arxiv.org/abs/2312.14132 |
|
""" |
|
|
|
import sys |
|
|
|
sys.path.append("third_party/dust3r") |
|
|
|
import torch |
|
from dust3r.cloud_opt import GlobalAlignerMode, global_aligner |
|
from dust3r.image_pairs import make_pairs |
|
from dust3r.inference import inference, load_model |
|
from dust3r.utils.image import load_images |
|
|
|
from siclib.geometry.base_camera import BaseCamera |
|
from siclib.geometry.gravity import Gravity |
|
from siclib.models import BaseModel |
|
|
|
|
|
|
|
|
|
class Dust3R(BaseModel):
    """DUSt3R model for focal length estimation.

    Configuration keys (see ``default_conf``):
        model_path: checkpoint path handed to ``load_model``.
        device: device string handed to ``load_model``, ``inference`` and
            ``global_aligner``.
        batch_size: batch size handed to ``inference``.
        schedule, lr, niter: forwarded unchanged to
            ``scene.compute_global_alignment``.
        show_scene: when truthy, ``scene.show()`` is called after alignment.

    The only required input is ``data["path"]``, a sequence holding exactly
    one image path.
    """

    default_conf = {
        "model_path": "weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
        "device": "cuda",
        "batch_size": 1,
        "schedule": "cosine",
        "lr": 0.01,
        "niter": 300,
        "show_scene": False,
    }

    required_data_keys = ["path"]

    def _init(self, conf):
        """Load the DUSt3R network from ``conf["model_path"]`` onto ``conf["device"]``."""
        self.model = load_model(conf["model_path"], conf["device"])

    def _forward(self, data):
        """Estimate the focal length of a single image with DUSt3R.

        Args:
            data: mapping with a ``"path"`` entry containing exactly one
                image path.

        Returns:
            A dict with ``"camera"`` (a ``BaseCamera`` built from the image
            height/width and the estimated focal length) and ``"gravity"``
            (a ``Gravity`` built with ``from_rp([0.0], [0.0])``; gravity is
            not estimated here).

        Raises:
            AssertionError: if ``data["path"]`` does not hold exactly one
                element.
        """
        batch_size = len(data["path"])
        assert batch_size == 1, f"Only batch size of 1 is supported (bs={batch_size})"

        path = data["path"][0]
        # The pipeline below operates on image pairs, so the single input
        # image is paired with itself.
        images = [path] * 2

        # NOTE(review): gradients are explicitly re-enabled here, presumably
        # because the global alignment is a gradient-based optimisation and the
        # caller may run this forward pass under ``torch.no_grad()`` -- confirm.
        with torch.enable_grad():
            images = load_images(images, size=512)
            pairs = make_pairs(images, scene_graph="complete", prefilter=None, symmetrize=True)
            output = inference(
                pairs, self.model, self.conf["device"], batch_size=self.conf["batch_size"]
            )
            scene = global_aligner(
                output, device=self.conf["device"], mode=GlobalAlignerMode.PointCloudOptimizer
            )
            # The return value of the optimisation is not needed; the aligned
            # state is read back from ``scene`` below.
            scene.compute_global_alignment(
                init="mst",
                niter=self.conf["niter"],
                schedule=self.conf["schedule"],
                lr=self.conf["lr"],
            )

        # Average the focal estimates over the first dimension (one entry per
        # view; both views are the same image).
        focals = scene.get_focals().mean(dim=0)

        # Image size as reported by the loader for the first view; converted
        # with ``new_tensor`` so it shares dtype/device with ``focals``.
        h, w = images[0]["true_shape"][:, 0], images[0]["true_shape"][:, 1]
        h, w = focals.new_tensor(h), focals.new_tensor(w)

        camera = BaseCamera.from_dict({"height": h, "width": w, "f": focals})
        gravity = Gravity.from_rp([0.0], [0.0])

        if self.conf["show_scene"]:
            scene.show()

        return {"camera": camera, "gravity": gravity}

    def loss(self, pred, data):
        """Return empty loss and metric dicts; this wrapper defines no training loss."""
        return {}, {}
|
|