File size: 3,401 Bytes
a930e1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import sys
from pathlib import Path

from .. import MODEL_REPO_ID, logger
from ..utils.base_model import BaseModel

# Append the third_party/LightGlue directory (located relative to this file) to
# sys.path. This has to run before the `from lightglue import ...` line that
# follows, which is why that import is not grouped with the ones at the top.
lightglue_path = Path(__file__).parent / "../../third_party/LightGlue"
sys.path.append(str(lightglue_path))
from lightglue import LightGlue as LG
import torch
import os


class LightGlue(BaseModel):
    """Adapter that runs the third-party LightGlue network (``LG``) as a ``BaseModel``.

    ``_init`` resolves a checkpoint file for ``conf["model_name"]`` and builds
    the network from the configuration; ``_forward`` repacks the flat input
    dictionary into the two per-image dictionaries the network is called with.
    """

    default_conf = {
        "match_threshold": 0.0,
        # NOTE: _init overwrites this entry with "match_threshold" before the
        # network is built, so the value written here never reaches the network.
        "filter_threshold": 0.1,
        "width_confidence": 0.99,  # for point pruning
        "depth_confidence": 0.95,  # for early stopping,
        "features": "superpoint",
        "model_name": "superpoint_lightglue.pth",
        "flash": True,  # enable FlashAttention if available.
        "mp": False,  # enable mixed precision
        "add_scale_ori": False,
    }
    required_inputs = [
        "image0",
        "keypoints0",
        "scores0",
        "descriptors0",
        "image1",
        "keypoints1",
        "scores1",
        "descriptors1",
    ]

    # The MINIMA checkpoint is fetched from a GitHub release rather than
    # through self._download_model like every other model name.
    _MINIMA_MODEL_NAME = "superpoint_minima_lightglue.pth"
    _MINIMA_URL = (
        "https://github.com/LSXI7/storage/releases/download/MINIMA/"
        "minima_lightglue.pth"
    )
    _MINIMA_FILENAME = "minima_lightglue.pth"

    def _init(self, conf):
        """Resolve the checkpoint for ``conf["model_name"]`` and build ``self.net``.

        Args:
            conf: configuration dictionary. It is modified in place:
                ``"weights"`` is set to the resolved checkpoint path,
                ``"filter_threshold"`` is replaced by ``"match_threshold"``,
                and for the MINIMA model ``"MINIMA"`` / ``"MINIMA_path"`` are
                added. The whole dictionary is then passed to ``LG`` as
                keyword arguments.
        """
        logger.info("Loading lightglue model, {}".format(conf["model_name"]))
        if conf["model_name"] == self._MINIMA_MODEL_NAME:
            model_path = self._fetch_minima_checkpoint()
            conf["MINIMA"] = True
            conf["MINIMA_path"] = model_path
        else:
            # Use the same `conf` entry the branch above was decided on.
            model_path = self._download_model(
                repo_id=MODEL_REPO_ID,
                filename="{}/{}".format(
                    Path(__file__).stem, conf["model_name"]
                ),
            )
        conf["weights"] = str(model_path)
        logger.debug("Lightglue weights: {}".format(conf["weights"]))
        conf["filter_threshold"] = conf["match_threshold"]
        self.net = LG(**conf)
        logger.info("Load lightglue model done.")

    @classmethod
    def _fetch_minima_checkpoint(cls):
        """Return the local path of the MINIMA checkpoint, downloading it if absent.

        The file is stored at ``<torch.hub.get_dir()>/checkpoints/minima_lightglue.pth``.
        Only the file is fetched; it is not deserialised here.
        """
        checkpoint_dir = os.path.join(torch.hub.get_dir(), "checkpoints")
        checkpoint_path = os.path.join(checkpoint_dir, cls._MINIMA_FILENAME)
        if not os.path.isfile(checkpoint_path):
            os.makedirs(checkpoint_dir, exist_ok=True)
            logger.info(
                "Downloading {} to {}".format(cls._MINIMA_URL, checkpoint_path)
            )
            torch.hub.download_url_to_file(cls._MINIMA_URL, checkpoint_path)
        return checkpoint_path

    @staticmethod
    def _pack_view(data, suffix):
        """Collect the entries of ``data`` that belong to one image.

        Args:
            data: flat input dictionary with keys such as ``"image0"``.
            suffix: ``"0"`` or ``"1"``, selecting which image's keys to read.

        Returns:
            Dictionary with ``"image"``, ``"keypoints"`` and ``"descriptors"``,
            plus ``"scales"`` / ``"oris"`` when the matching keys exist in
            ``data``.
        """
        view = {
            "image": data["image" + suffix],
            "keypoints": data["keypoints" + suffix],
            # Swap the last two axes of the rank-3 descriptor tensor
            # (presumably (B, D, N) -> (B, N, D) -- TODO confirm with caller).
            "descriptors": data["descriptors" + suffix].permute(0, 2, 1),
        }
        for optional in ("scales", "oris"):
            if optional + suffix in data:
                view[optional] = data[optional + suffix]
        return view

    def _forward(self, data):
        """Run the network on one image pair.

        Args:
            data: dictionary holding the keys listed in ``required_inputs``
                and, optionally, ``scales0``/``oris0``/``scales1``/``oris1``.
                The ``scores*`` entries are not forwarded to the network.

        Returns:
            Whatever ``self.net`` returns for the repacked inputs.
        """
        pair = {
            "image0": self._pack_view(data, "0"),
            "image1": self._pack_view(data, "1"),
        }
        return self.net(pair)