File size: 2,541 Bytes
205a7af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Wrapper for DUSt3R model to estimate focal length.

DUSt3R: Geometric 3D Vision Made Easy, https://arxiv.org/abs/2312.14132
"""

import sys

sys.path.append("third_party/dust3r")

import torch
from dust3r.cloud_opt import GlobalAlignerMode, global_aligner
from dust3r.image_pairs import make_pairs
from dust3r.inference import inference, load_model
from dust3r.utils.image import load_images

from siclib.geometry.base_camera import BaseCamera
from siclib.geometry.gravity import Gravity
from siclib.models import BaseModel

# mypy: ignore-errors


class Dust3R(BaseModel):
    """DUSt3R model for focal length estimation."""

    default_conf = {
        "model_path": "weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
        "device": "cuda",
        "batch_size": 1,
        "schedule": "cosine",
        "lr": 0.01,
        "niter": 300,
        "show_scene": False,
    }

    required_data_keys = ["path"]

    def _init(self, conf):
        """Initialize the DUSt3R model."""
        self.model = load_model(conf["model_path"], conf["device"])

    def _forward(self, data):
        """Forward pass of the DUSt."""
        assert len(data["path"]) == 1, f"Only batch size of 1 is supported (bs={len(data['path'])}"

        path = data["path"][0]
        images = [path] * 2

        with torch.enable_grad():
            images = load_images(images, size=512)
            pairs = make_pairs(images, scene_graph="complete", prefilter=None, symmetrize=True)
            output = inference(
                pairs, self.model, self.conf["device"], batch_size=self.conf["batch_size"]
            )
            scene = global_aligner(
                output, device=self.conf["device"], mode=GlobalAlignerMode.PointCloudOptimizer
            )
            _ = scene.compute_global_alignment(
                init="mst",
                niter=self.conf["niter"],
                schedule=self.conf["schedule"],
                lr=self.conf["lr"],
            )

        # retrieve useful values from scene:
        focals = scene.get_focals().mean(dim=0)

        h, w = images[0]["true_shape"][:, 0], images[0]["true_shape"][:, 1]
        h, w = focals.new_tensor(h), focals.new_tensor(w)

        camera = BaseCamera.from_dict({"height": h, "width": w, "f": focals})
        gravity = Gravity.from_rp([0.0], [0.0])

        if self.conf["show_scene"]:
            scene.show()

        return {"camera": camera, "gravity": gravity}

    def loss(self, pred, data):
        """Loss function for DUSt3R model."""
        return {}, {}