# Uploaded by veichta using huggingface_hub (commit 205a7af, verified).
"""Wrapper for DUSt3R model to estimate focal length.
DUSt3R: Geometric 3D Vision Made Easy, https://arxiv.org/abs/2312.14132
"""
import sys
sys.path.append("third_party/dust3r")
import torch
from dust3r.cloud_opt import GlobalAlignerMode, global_aligner
from dust3r.image_pairs import make_pairs
from dust3r.inference import inference, load_model
from dust3r.utils.image import load_images
from siclib.geometry.base_camera import BaseCamera
from siclib.geometry.gravity import Gravity
from siclib.models import BaseModel
# mypy: ignore-errors
class Dust3R(BaseModel):
    """DUSt3R model for focal length estimation.

    Only the focal length is estimated. The returned gravity is a fixed
    placeholder with zero roll and pitch.
    """

    default_conf = {
        "model_path": "weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
        "device": "cuda",
        "batch_size": 1,  # batch size passed to dust3r's `inference`
        "schedule": "cosine",  # schedule passed to `compute_global_alignment`
        "lr": 0.01,  # learning rate passed to `compute_global_alignment`
        "niter": 300,  # iterations passed to `compute_global_alignment`
        "show_scene": False,  # if True, call `scene.show()` after alignment
    }

    required_data_keys = ["path"]

    def _init(self, conf):
        """Load the DUSt3R checkpoint ``conf["model_path"]`` onto ``conf["device"]``."""
        self.model = load_model(conf["model_path"], conf["device"])

    def _forward(self, data):
        """Forward pass of the DUSt3R model.

        Args:
            data: dict whose ``"path"`` entry is a sequence holding exactly one
                image path.

        Returns:
            dict with ``"camera"`` (a ``BaseCamera`` built from the estimated
            focal length and the height/width of the loaded image) and
            ``"gravity"`` (a placeholder with zero roll and pitch).

        Raises:
            AssertionError: unless exactly one path is given.
        """
        batch_size = len(data["path"])
        assert batch_size == 1, f"Only batch size of 1 is supported (bs={batch_size})"
        path = data["path"][0]

        # The DUSt3R pipeline is pair-based (make_pairs / inference), so the
        # single input image is paired with itself.
        images = [path] * 2

        # Global alignment is an iterative optimization (lr / niter / schedule),
        # so gradients are enabled explicitly in case the caller runs this
        # under torch.no_grad().
        with torch.enable_grad():
            # NOTE(review): load_images resizes to size=512, so the focal and
            # true_shape below presumably refer to the resized image, not the
            # original file. Confirm if callers expect original resolution.
            images = load_images(images, size=512)
            pairs = make_pairs(images, scene_graph="complete", prefilter=None, symmetrize=True)
            output = inference(
                pairs, self.model, self.conf["device"], batch_size=self.conf["batch_size"]
            )
            scene = global_aligner(
                output, device=self.conf["device"], mode=GlobalAlignerMode.PointCloudOptimizer
            )
            _ = scene.compute_global_alignment(
                init="mst",
                niter=self.conf["niter"],
                schedule=self.conf["schedule"],
                lr=self.conf["lr"],
            )

        # retrieve useful values from scene:
        # Average the per-view focal estimates (both views are the same image).
        focals = scene.get_focals().mean(dim=0)
        true_shape = images[0]["true_shape"]
        h, w = true_shape[:, 0], true_shape[:, 1]
        # Match dtype/device of the focal estimate.
        h, w = focals.new_tensor(h), focals.new_tensor(w)
        camera = BaseCamera.from_dict({"height": h, "width": w, "f": focals})

        # Gravity is not estimated here; return zero roll and pitch.
        gravity = Gravity.from_rp([0.0], [0.0])

        if self.conf["show_scene"]:
            scene.show()

        return {"camera": camera, "gravity": gravity}

    def loss(self, pred, data):
        """Loss function for DUSt3R model.

        This baseline is not trained, so it returns empty loss and metric dicts.
        """
        return {}, {}