Realcat committed on
Commit 13760e8 · 1 Parent(s): 7fd062e

add: liftfeat

Files changed (36)
  1. README.md +1 -0
  2. config/config.yaml +11 -0
  3. imcui/hloc/extract_features.py +11 -0
  4. imcui/hloc/extractors/liftfeat.py +57 -0
  5. imcui/third_party/LiftFeat/.gitignore +4 -0
  6. imcui/third_party/LiftFeat/README.md +141 -0
  7. imcui/third_party/LiftFeat/assert/achitecture.png +3 -0
  8. imcui/third_party/LiftFeat/assert/demo_liftfeat.gif +3 -0
  9. imcui/third_party/LiftFeat/assert/demo_sp.gif +3 -0
  10. imcui/third_party/LiftFeat/assert/query.jpg +3 -0
  11. imcui/third_party/LiftFeat/assert/ref.jpg +3 -0
  12. imcui/third_party/LiftFeat/data/megadepth_1500.json +0 -0
  13. imcui/third_party/LiftFeat/dataset/__init__.py +0 -0
  14. imcui/third_party/LiftFeat/dataset/coco_augmentor.py +298 -0
  15. imcui/third_party/LiftFeat/dataset/coco_wrapper.py +175 -0
  16. imcui/third_party/LiftFeat/dataset/dataset_utils.py +183 -0
  17. imcui/third_party/LiftFeat/dataset/megadepth.py +177 -0
  18. imcui/third_party/LiftFeat/dataset/megadepth_wrapper.py +167 -0
  19. imcui/third_party/LiftFeat/demo.py +68 -0
  20. imcui/third_party/LiftFeat/evaluation/HPatch_evaluation.py +182 -0
  21. imcui/third_party/LiftFeat/evaluation/MegaDepth1500_evaluation.py +105 -0
  22. imcui/third_party/LiftFeat/evaluation/eval_utils.py +127 -0
  23. imcui/third_party/LiftFeat/loss/loss.py +291 -0
  24. imcui/third_party/LiftFeat/models/interpolator.py +34 -0
  25. imcui/third_party/LiftFeat/models/liftfeat.py +190 -0
  26. imcui/third_party/LiftFeat/models/liftfeat_wrapper.py +173 -0
  27. imcui/third_party/LiftFeat/models/model.py +419 -0
  28. imcui/third_party/LiftFeat/requirements.txt +18 -0
  29. imcui/third_party/LiftFeat/train.py +365 -0
  30. imcui/third_party/LiftFeat/train.sh +11 -0
  31. imcui/third_party/LiftFeat/utils/__init__.py +0 -0
  32. imcui/third_party/LiftFeat/utils/alike_wrapper.py +45 -0
  33. imcui/third_party/LiftFeat/utils/config.py +16 -0
  34. imcui/third_party/LiftFeat/utils/depth_anything_wrapper.py +150 -0
  35. imcui/third_party/LiftFeat/utils/featurebooster.py +247 -0
  36. imcui/third_party/LiftFeat/weights/LiftFeat.pth +3 -0
README.md CHANGED
@@ -45,6 +45,7 @@ The tool currently supports various popular image matching algorithms, namely:
45
  | Algorithm | Supported | Conference/Journal | Year | GitHub Link |
46
  |------------------|-----------|--------------------|------|-------------|
47
  | DaD | ✅ | ARXIV | 2025 | [Link](https://github.com/Parskatt/dad) |
48
+ | LiftFeat | ✅ | ICRA | 2025 | [Link](https://github.com/lyp-deeplearning/LiftFeat) |
49
  | MINIMA | ✅ | ARXIV | 2024 | [Link](https://github.com/LSXI7/MINIMA) |
50
  | XoFTR | ✅ | CVPR | 2024 | [Link](https://github.com/OnderT/XoFTR) |
51
  | EfficientLoFTR | ✅ | CVPR | 2024 | [Link](https://github.com/zju3dv/EfficientLoFTR) |
config/config.yaml CHANGED
@@ -256,6 +256,17 @@ matcher_zoo:
256
  paper: https://arxiv.org/abs/2404.19174
257
  project: null
258
  display: false
259
+ liftfeat(sparse):
260
+ matcher: NN-mutual
261
+ feature: liftfeat
262
+ dense: false
263
+ info:
264
+ name: LiftFeat # display name
265
+ source: "ICRA 2025"
266
+ github: https://github.com/lyp-deeplearning/LiftFeat
267
+ paper: https://arxiv.org/abs/2505.03422
268
+ project: null
269
+ display: true
270
  dedode:
271
  matcher: Dual-Softmax
272
  feature: dedode
imcui/hloc/extract_features.py CHANGED
@@ -214,6 +214,17 @@ confs = {
214
  "resize_max": 1600,
215
  },
216
  },
217
+ "liftfeat": {
218
+ "output": "feats-liftfeat-n5000-r1600",
219
+ "model": {
220
+ "name": "liftfeat",
221
+ "max_keypoints": 5000,
222
+ },
223
+ "preprocessing": {
224
+ "grayscale": False,
225
+ "resize_max": 1600,
226
+ },
227
+ },
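For reference, a minimal sketch of how this new `liftfeat` configuration could be driven through hloc's feature-extraction entry point. This assumes hloc's usual `extract_features.main(conf, image_dir, export_dir)` interface; the dataset and output paths below are hypothetical.

```python
from pathlib import Path

from imcui.hloc import extract_features

# Hypothetical paths -- adjust to your own dataset layout.
images = Path("datasets/sample_scene/images")
outputs = Path("outputs/sample_scene")

# Select the LiftFeat configuration added above
# (up to 5000 keypoints, longest side resized to 1600).
conf = extract_features.confs["liftfeat"]

# Run extraction; hloc writes keypoints and descriptors to a feature file
# under `outputs` and returns its path (sketch, not verified against this repo).
feature_path = extract_features.main(conf, images, outputs)
print(feature_path)
```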
228
  "aliked-n16-rot": {
229
  "output": "feats-aliked-n16-rot",
230
  "model": {
imcui/hloc/extractors/liftfeat.py ADDED
@@ -0,0 +1,57 @@
1
+ import logging
2
+ import sys
3
+ from pathlib import Path
4
+ import torch
5
+ import random
6
+ from ..utils.base_model import BaseModel
7
+ from .. import logger
8
+
9
+ fire_path = Path(__file__).parent / "../../third_party/LiftFeat"
10
+ sys.path.append(str(fire_path))
11
+
12
+ from models.liftfeat_wrapper import LiftFeat, MODEL_PATH
13
+
14
+
15
+ def select_idx(N, M):
16
+ numbers = list(range(N))  # 0-based indices so sampled keypoints stay in range
17
+ selected = random.sample(numbers, M)
18
+ return selected
19
+
20
+
21
+ class Liftfeat(BaseModel):
22
+ default_conf = {
23
+ "keypoint_threshold": 0.05,
24
+ "max_keypoints": 5000,
25
+ }
26
+
27
+ required_inputs = ["image"]
28
+
29
+ def _init(self, conf):
30
+ logger.info("Loading LiftFeat model...")
31
+ self.net = LiftFeat(
32
+ weight=MODEL_PATH,
33
+ detect_threshold=self.conf["keypoint_threshold"],
34
+ top_k=self.conf["max_keypoints"],
35
+ )
36
+ logger.info("Loading LiftFeat model done!")
37
+
38
+ def _forward(self, data):
39
+ image = data["image"].cpu().numpy().squeeze() * 255
40
+ image = image.transpose(1, 2, 0)
41
+ pred = self.net.extract(image)
42
+
43
+ keypoints = pred["keypoints"]
44
+ descriptors = pred["descriptors"]
45
+ scores = torch.ones_like(pred["keypoints"][:, 0])
46
+ if self.conf["max_keypoints"] < len(keypoints):
47
+ idxs = select_idx(len(keypoints), self.conf["max_keypoints"])
48
+ keypoints = keypoints[idxs, :2]
49
+ descriptors = descriptors[idxs]
50
+ scores = scores[idxs]
51
+
52
+ pred = {
53
+ "keypoints": keypoints[None],
54
+ "descriptors": descriptors[None].permute(0, 2, 1),
55
+ "scores": scores[None],
56
+ }
57
+ return pred
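The `_forward` method above delegates to the bundled wrapper's `extract` call. Below is a minimal sketch of using that wrapper directly with the weights and sample image shipped in this commit (assuming it is run from `imcui/third_party/LiftFeat` so that `models.liftfeat_wrapper` is importable; the exact input dtype the wrapper expects is an assumption — the hloc extractor above feeds a float HxWx3 array in [0, 255]).

```python
import cv2

# Wrapper added in this commit: models/liftfeat_wrapper.py
from models.liftfeat_wrapper import LiftFeat, MODEL_PATH

# Same constructor arguments that Liftfeat._init passes above.
net = LiftFeat(weight=MODEL_PATH, detect_threshold=0.05, top_k=5000)

# HxWx3 image read with OpenCV, as in demo.py.
image = cv2.imread("assert/ref.jpg")

# extract() returns a dict with "keypoints" and "descriptors";
# Liftfeat._forward above reshapes these into hloc's batched format.
pred = net.extract(image)
print(pred["keypoints"].shape, pred["descriptors"].shape)
```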
imcui/third_party/LiftFeat/.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ visualize
2
+ trained_weights
3
+ data/HPatch
4
+ data/megadepth_test_1500
imcui/third_party/LiftFeat/README.md ADDED
@@ -0,0 +1,141 @@
1
+ ## LiftFeat: 3D Geometry-Aware Local Feature Matching
2
+ <div align="center" style="display: flex; justify-content: center; align-items: center; flex-direction: column;">
3
+ <div style="display: flex; justify-content: space-around; width: 100%;">
4
+ <img src='./assert/demo_sp.gif' width="400"/>
5
+ <img src='./assert/demo_liftfeat.gif' width="400"/>
6
+ </div>
7
+
8
+ Real-time SuperPoint demonstration (left) compared to LiftFeat (right) on a textureless scene.
9
+
10
+ </div>
11
+
12
+ - 🎉 **New!** Training code is now available 🚀
13
+ - 🎉 **New!** The test code and pretrained model have been released. 🚀
14
+
15
+ ## Table of Contents
16
+ - [Introduction](#introduction)
17
+ - [Installation](#installation)
18
+ - [Usage](#usage)
19
+ - [Inference](#inference)
20
+ - [Training](#training)
21
+ - [Evaluation](#evaluation)
22
+ - [Citation](#citation)
23
+ - [License](#license)
24
+
25
+ ## Introduction
26
+ This repository contains the official implementation of the paper:
27
+ **[LiftFeat: 3D Geometry-Aware Local Feature Matching](https://www.arxiv.org/abs/2505.03422)**, to be presented at *ICRA 2025*.
28
+
29
+ **Overview of LiftFeat's architecture**
30
+ <div style="background-color:white">
31
+ <img align="center" src="./assert/achitecture.png" width=1000 />
32
+ </div>
33
+
34
+ LiftFeat is a lightweight and robust local feature matching network designed to handle challenging scenarios such as drastic lighting changes, low-texture regions, and repetitive patterns. By incorporating 3D geometric cues through surface normals predicted from monocular depth, LiftFeat enhances the discriminative power of 2D descriptors. Our proposed 3D geometry-aware feature lifting module effectively fuses these cues, leading to significant improvements in tasks like relative pose estimation, homography estimation, and visual localization.
35
+
36
+ ## Installation
37
+ If you use conda as your virtual environment, you can create a new environment with:
38
+ ```bash
39
+ git clone https://github.com/lyp-deeplearning/LiftFeat.git
40
+ cd LiftFeat
41
+ conda create -n LiftFeat python=3.8
42
+ conda activate LiftFeat
43
+ pip install -r requirements.txt
44
+ ```
45
+
46
+ ## Usage
47
+ ### Inference
48
+ To run LiftFeat on a pair of images, simply run:
49
+ ```bash
50
+ python demo.py --img1=<reference image> --img2=<query image>
51
+ ```
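
For a quick Python-level sketch of the same matching step (this mirrors what `demo.py` does with the wrapper in `models/liftfeat_wrapper.py`; the image paths are the sample images under `assert/`):

```python
import cv2
from models.liftfeat_wrapper import LiftFeat, MODEL_PATH

# Load the pretrained model shipped in weights/LiftFeat.pth.
liftfeat = LiftFeat(weight=MODEL_PATH, detect_threshold=0.05)

img1 = cv2.imread("./assert/ref.jpg")    # reference image
img2 = cv2.imread("./assert/query.jpg")  # query image

# Matched keypoint coordinates in img1 and img2 (N x 2 each).
mkpts1, mkpts2 = liftfeat.match_liftfeat(img1, img2)
print(len(mkpts1), "matches")
```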
52
+
53
+ ### Training
54
+ To train LiftFeat as described in the paper, you will need the MegaDepth dataset and the COCO_20k subset of COCO2017, as described in the paper *[XFeat: Accelerated Features for Lightweight Image Matching](https://arxiv.org/abs/2404.19174)*.
55
+ You can obtain the full COCO2017 train data at https://cocodataset.org/.
56
+ However, we [make available](https://drive.google.com/file/d/1ijYsPq7dtLQSl-oEsUOGH1fAy21YLc7H/view?usp=drive_link) a subset of COCO for convenience. We simply selected a subset of 20k images according to image resolution. Please check COCO [terms of use](https://cocodataset.org/#termsofuse) before using the data.
57
+
58
+ To reproduce the training setup from the paper, please follow these steps:
59
+ 1. Download [COCO_20k](https://drive.google.com/file/d/1ijYsPq7dtLQSl-oEsUOGH1fAy21YLc7H/view?usp=drive_link) containing a subset of COCO2017;
60
+ 2. Download the MegaDepth dataset. You can follow the [LoFTR instructions](https://github.com/zju3dv/LoFTR/blob/master/docs/TRAINING.md#download-datasets); we use the same setup as LoFTR. Then put the MegaDepth indices inside the MegaDepth root folder following the layout below:
61
+ ```bash
62
+ {megadepth_root_path}/train_data/megadepth_indices #indices
63
+ {megadepth_root_path}/MegaDepth_v1 #images & depth maps & poses
64
+ ```
65
+ 3. Finally, you can start training (a data-loading sketch follows this step):
66
+ ```bash
67
+ python train.py --megadepth_root_path <path_to>/MegaDepth --synthetic_root_path <path_to>/coco_20k --ckpt_save_path /path/to/ckpts
68
+ ```
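
For orientation, the MegaDepth side of training is handled by the `MegaDepthDataset` class added in `dataset/megadepth.py`. A minimal, hedged sketch of instantiating it for a single scene is below; the index file name and root directory are placeholders following the LoFTR-style layout above, not files shipped with this commit.

```python
from torch.utils.data import DataLoader

from dataset.megadepth import MegaDepthDataset

# Placeholder paths following the layout described above.
dataset = MegaDepthDataset(
    root_dir="<path_to>/MegaDepth",
    npz_path="<path_to>/megadepth_indices/scene_info_0.1_0.7/0000_0.1_0.3.npz",
    mode="train",
    img_resize=(800, 608),   # default resize used by the class
    load_depth=True,         # also loads depth maps, poses and intrinsics
)

loader = DataLoader(dataset, batch_size=1, shuffle=True)
batch = next(iter(loader))
print(batch["image0"].shape, batch["K0"].shape)  # image tensor and intrinsics
```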
69
+
70
+ ### Evaluation
71
+ All evaluation code is in *evaluation*. You can download the **HPatches** dataset following [D2-Net](https://github.com/mihaidusmanu/d2-net/tree/master) and the **MegaDepth** test dataset following [LoFTR](https://github.com/zju3dv/LoFTR/tree/master).
72
+
73
+ **Download and process HPatch**
74
+ ```bash
75
+ cd /data
76
+
77
+ # Download the dataset
78
+ wget https://huggingface.co/datasets/vbalnt/hpatches/resolve/main/hpatches-sequences-release.zip
79
+
80
+ # Extract the dataset
81
+ unzip hpatches-sequences-release.zip
82
+
83
+ # Remove the high-resolution sequences
84
+ cd hpatches-sequences-release
85
+ rm -rf i_contruction i_crownnight i_dc i_pencils i_whitebuilding v_artisans v_astronautis v_talent
86
+
87
+ cd <LiftFeat>/data
88
+
89
+ ln -s /data/hpatches-sequences-release ./HPatch
90
+ ```
91
+
92
+ **Download and process MegaDepth1500**
93
+ We provide a download link for [megadepth_test_1500](https://drive.google.com/drive/folders/1nTkK1485FuwqA0DbZrK2Cl0WnXadUZdc).
94
+ ```bash
95
+ tar xvf <path to megadepth_test_1500.tar>
96
+
97
+ cd <LiftFeat>/data
98
+
99
+ ln -s <path to megadepth_test_1500> ./megadepth_test_1500
100
+ ```
101
+
102
+
103
+ **Homography Estimation**
104
+ ```bash
105
+ python evaluation/HPatch_evaluation.py
106
+ ```
107
+
108
+ **Relative Pose Estimation**
109
+
110
+ For *Megadepth1500* dataset:
111
+ ```bash
112
+ python evaluation/MegaDepth1500_evaluation.py
113
+ ```
114
+
115
+
116
+ ## Citation
117
+ If you find this code useful for your research, please cite the paper:
118
+ ```bibtex
119
+ @misc{liu2025liftfeat3dgeometryawarelocal,
120
+ title={LiftFeat: 3D Geometry-Aware Local Feature Matching},
121
+ author={Yepeng Liu and Wenpeng Lai and Zhou Zhao and Yuxuan Xiong and Jinchi Zhu and Jun Cheng and Yongchao Xu},
122
+ year={2025},
123
+ eprint={2505.03422},
124
+ archivePrefix={arXiv},
125
+ primaryClass={cs.CV},
126
+ url={https://arxiv.org/abs/2505.03422},
127
+ }
128
+ ```
129
+
130
+ ## License
131
+ [![License](https://img.shields.io/badge/License-Apache_2.0-blue.svg)](LICENSE)
132
+
133
+
134
+ ## Acknowledgements
135
+ We would like to thank the authors of the following open-source repositories for their valuable contributions, which have inspired or supported this work:
136
+
137
+ - [verlab/accelerated_features](https://github.com/verlab/accelerated_features)
138
+ - [zju3dv/LoFTR](https://github.com/zju3dv/LoFTR)
139
+ - [rpautrat/SuperPoint](https://github.com/rpautrat/SuperPoint)
140
+
141
+ We deeply appreciate the efforts of the research community in releasing high-quality codebases.
imcui/third_party/LiftFeat/assert/achitecture.png ADDED

Git LFS Details

  • SHA256: a00df3202b47a4dfb12a8b57e40ac36cea77fe4c1fe671c12ba7d785db85da1b
  • Pointer size: 131 Bytes
  • Size of remote file: 529 kB
imcui/third_party/LiftFeat/assert/demo_liftfeat.gif ADDED

Git LFS Details

  • SHA256: b2370fceb92f3f4cc8cd1def7af870469c7a0345c6e5502d618f80b6aa7322d8
  • Pointer size: 132 Bytes
  • Size of remote file: 4.11 MB
imcui/third_party/LiftFeat/assert/demo_sp.gif ADDED

Git LFS Details

  • SHA256: d5f6613eb69830ed1a6c4a09d1bf548f8acdb3d500c743bb73ba939347017892
  • Pointer size: 132 Bytes
  • Size of remote file: 4.7 MB
imcui/third_party/LiftFeat/assert/query.jpg ADDED

Git LFS Details

  • SHA256: d49dd4628d36baaa8f47eedace5c6c45fd67e47bbf5aecf8ff8427fd82e5e463
  • Pointer size: 132 Bytes
  • Size of remote file: 1.22 MB
imcui/third_party/LiftFeat/assert/ref.jpg ADDED

Git LFS Details

  • SHA256: 140a4bb3b353c215e20d1abfa18e264669df3950e1933df8a26448c0d9900838
  • Pointer size: 132 Bytes
  • Size of remote file: 1.09 MB
imcui/third_party/LiftFeat/data/megadepth_1500.json ADDED
The diff for this file is too large to render. See raw diff
 
imcui/third_party/LiftFeat/dataset/__init__.py ADDED
File without changes
imcui/third_party/LiftFeat/dataset/coco_augmentor.py ADDED
@@ -0,0 +1,298 @@
1
+ """
2
+ "LiftFeat: 3D Geometry-Aware Local Feature Matching"
3
+ COCO_20k image augmentor
4
+ """
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.utils.data import Dataset
9
+ import torch.utils.data as data
10
+ from torchvision import transforms
11
+ import torch.nn.functional as F
12
+
13
+ import cv2
14
+ import kornia
15
+ import kornia.augmentation as K
16
+ from kornia.geometry.transform import get_tps_transform as findTPS
17
+ from kornia.geometry.transform import warp_points_tps, warp_image_tps
18
+
19
+ import glob
20
+ import random
21
+ import tqdm
22
+
23
+ import numpy as np
24
+ import pdb
25
+ import time
26
+
27
+ random.seed(0)
28
+ torch.manual_seed(0)
29
+
30
+ def generateRandomTPS(shape,grid=(8,6),GLOBAL_MULTIPLIER=0.3,prob=0.5):
31
+
32
+ h, w = shape
33
+ sh, sw = h/grid[0], w/grid[1]
34
+ src = torch.dstack(torch.meshgrid(torch.arange(0, h + sh , sh), torch.arange(0, w + sw , sw), indexing='ij'))
35
+
36
+ offsets = torch.rand(grid[0]+1, grid[1]+1, 2) - 0.5
37
+ offsets *= torch.tensor([ sh/2, sw/2 ]).view(1, 1, 2) * min(0.97, 2.0 * GLOBAL_MULTIPLIER)
38
+ dst = src + offsets if np.random.uniform() < prob else src
39
+
40
+ src, dst = src.view(1, -1, 2), dst.view(1, -1, 2)
41
+ src = (src / torch.tensor([h,w]).view(1,1,2) ) * 2 - 1.
42
+ dst = (dst / torch.tensor([h,w]).view(1,1,2) ) * 2 - 1.
43
+ weights, A = findTPS(dst, src)
44
+
45
+ return src, weights, A
46
+
47
+
48
+ def generateRandomHomography(shape,GLOBAL_MULTIPLIER=0.3):
49
+ #Generate random in-plane rotation [-theta,+theta]
50
+ theta = np.radians(np.random.uniform(-30, 30))
51
+
52
+ #Generate random scale in both x and y
53
+ scale_x, scale_y = np.random.uniform(0.35, 1.2, 2)
54
+
55
+ #Generate random translation shift
56
+ tx , ty = -shape[1]/2.0 , -shape[0]/2.0
57
+ txn, tyn = np.random.normal(0, 120.0*GLOBAL_MULTIPLIER, 2)
58
+
59
+ c, s = np.cos(theta), np.sin(theta)
60
+
61
+ #Affine coeffs
62
+ sx , sy = np.random.normal(0,0.6*GLOBAL_MULTIPLIER,2)
63
+
64
+ #Projective coeffs
65
+ p1 , p2 = np.random.normal(0,0.006*GLOBAL_MULTIPLIER,2)
66
+
67
+
68
+ # Build homography from parameterizations
69
+ H_t = np.array(((1,0, tx), (0, 1, ty), (0,0,1))) #t
70
+ H_r = np.array(((c,-s, 0), (s, c, 0), (0,0,1))) #rotation,
71
+ H_a = np.array(((1,sy, 0), (sx, 1, 0), (0,0,1))) # affine
72
+ H_p = np.array(((1, 0, 0), (0 , 1, 0), (p1,p2,1))) # projective
73
+ H_s = np.array(((scale_x,0, 0), (0, scale_y, 0), (0,0,1))) #scale
74
+ H_b = np.array(((1.0,0,-tx +txn), (0, 1, -ty + tyn), (0,0,1))) #t_back,
75
+
76
+ #H = H_e * H_s * H_a * H_p
77
+ H = np.dot(np.dot(np.dot(np.dot(np.dot(H_b,H_s),H_p),H_a),H_r),H_t)
78
+
79
+ return H
80
+
81
+
82
+ class COCOAugmentor(nn.Module):
83
+
84
+ def __init__(self,device,load_dataset=True,
85
+ img_dir="/home/yepeng_liu/code_python/dataset/coco_20k",
86
+ warp_resolution=(1200, 900),
87
+ out_resolution=(400, 300),
88
+ sides_crop=0.2,
89
+ max_num_imgs=50,
90
+ num_test_imgs=10,
91
+ batch_size=1,
92
+ photometric=True,
93
+ geometric=True,
94
+ reload_step=1_000
95
+ ):
96
+ super(COCOAugmentor,self).__init__()
97
+ self.half=16
98
+ self.device=device
99
+
100
+ self.dims=warp_resolution
101
+ self.batch_size=batch_size
102
+ self.out_resolution=out_resolution
103
+ self.sides_crop=sides_crop
104
+ self.max_num_imgs=max_num_imgs
105
+ self.num_test_imgs=num_test_imgs
106
+ self.dims_t=torch.tensor([int(self.dims[0]*(1. - self.sides_crop)) - int(self.dims[0]*self.sides_crop) -1,
107
+ int(self.dims[1]*(1. - self.sides_crop)) - int(self.dims[1]*self.sides_crop) -1]).float().to(device).view(1,1,2)
108
+ self.dims_s=torch.tensor([self.dims_t[0,0,0] / out_resolution[0],
109
+ self.dims_t[0,0,1] / out_resolution[1]]).float().to(device).view(1,1,2)
110
+
111
+ self.all_imgs=glob.glob(img_dir+'/*.jpg')+glob.glob(img_dir+'/*.png')
112
+
113
+ self.photometric=photometric
114
+ self.geometric=geometric
115
+ self.cnt=1
116
+ self.reload_step=reload_step
117
+
118
+ list_augmentation=[
119
+ kornia.augmentation.ColorJitter(0.15,0.15,0.15,0.15,p=1.),
120
+ kornia.augmentation.RandomEqualize(p=0.4),
121
+ kornia.augmentation.RandomGaussianBlur(p=0.3,sigma=(2.0,2.0),kernel_size=(7,7))
122
+ ]
123
+
124
+ if photometric is False:
125
+ list_augmentation = []
126
+
127
+ self.aug_list=kornia.augmentation.ImageSequential(*list_augmentation)
128
+
129
+ if len(self.all_imgs)<10:
130
+ raise RuntimeError('Could not find enough images to train. Please check the path: ', img_dir)
131
+
132
+ if load_dataset:
133
+ print('[COCO]: ',len(self.all_imgs),' images for training..')
134
+ if len(self.all_imgs) - num_test_imgs < max_num_imgs:
135
+ raise RuntimeError('Error: test set overlaps with training set! Decrease number of test imgs')
136
+
137
+ self.load_imgs()
138
+
139
+ self.TPS = True
140
+
141
+
142
+ def load_imgs(self):
143
+ random.shuffle(self.all_imgs)
144
+ train = []
145
+ for p in tqdm.tqdm(self.all_imgs[:self.max_num_imgs],desc='loading train'):
146
+ im=cv2.imread(p)
147
+ halfH,halfW=im.shape[0]//2,im.shape[1]//2
148
+ if halfH>halfW:
149
+ im=np.rot90(im)
150
+ halfH,halfW=halfW,halfH
151
+
152
+ if im.shape[0]!=self.dims[1] or im.shape[1]!=self.dims[0]:
153
+ im = cv2.resize(im, self.dims)
154
+
155
+ train.append(np.copy(im))
156
+
157
+ self.train=train
158
+ self.test=[
159
+ cv2.resize(cv2.imread(p),self.dims)
160
+ for p in tqdm.tqdm(self.all_imgs[-self.num_test_imgs:],desc='loading test')
161
+ ]
162
+
163
+ def norm_pts_grid(self, x):
164
+ if len(x.size()) == 2:
165
+ return (x.view(1,-1,2) * self.dims_s / self.dims_t) * 2. - 1
166
+ return (x * self.dims_s / self.dims_t) * 2. - 1
167
+
168
+ def denorm_pts_grid(self, x):
169
+ if len(x.size()) == 2:
170
+ return ((x.view(1,-1,2) + 1) / 2.) / self.dims_s * self.dims_t
171
+ return ((x+1) / 2.) / self.dims_s * self.dims_t
172
+
173
+ def rnd_kps(self, shape, n = 256):
174
+ h, w = shape
175
+ kps = torch.rand(size = (3,n)).to(self.device)
176
+ kps[0,:]*=w
177
+ kps[1,:]*=h
178
+ kps[2,:] = 1.0
179
+
180
+ return kps
181
+
182
+ def warp_points(self, H, pts):
183
+ scale = self.dims_s.view(-1,2)
184
+ offset = torch.tensor([int(self.dims[0]*self.sides_crop), int(self.dims[1]*self.sides_crop)], device = pts.device).float()
185
+ pts = pts*scale + offset
186
+ pts = torch.vstack( [pts.t(), torch.ones(1, pts.shape[0], device = pts.device)])
187
+ warped = torch.matmul(H, pts)
188
+ warped = warped / warped[2,...]
189
+ warped = warped.t()[:, :2]
190
+ return (warped - offset) / scale
191
+
192
+ @torch.inference_mode()
193
+ def forward(self, x, difficulty = 0.3, TPS = False, prob_deformation = 0.5, test = False):
194
+ """
195
+ Perform augmentation to a batch of images.
196
+
197
+ input:
198
+ x -> torch.Tensor(B, C, H, W): rgb images
199
+ difficulty -> float: level of difficulty, 0.1 is medium, 0.3 is already pretty hard
200
+ TPS -> bool: Whether to apply non-rigid deformations to the images
201
+ prob_deformation -> float: probability to apply a deformation
202
+
203
+ return:
204
+ 'output' -> torch.Tensor(B, C, H, W): rgb images
205
+ Tuple:
206
+ 'H' -> torch.Tensor(3,3): homography matrix
207
+ 'mask' -> torch.Tensor(B, H, W): mask of valid pixels after warp
208
+ (deformation only)
209
+ src, weights, A are parameters from a TPS warp (all torch.Tensors)
210
+
211
+ """
212
+
213
+ if self.cnt % self.reload_step == 0:
214
+ self.load_imgs()
215
+
216
+ if self.geometric is False:
217
+ difficulty = 0.
218
+
219
+ with torch.no_grad():
220
+ x = (x/255.).to(self.device)
221
+ b, c, h, w = x.shape
222
+ shape = (h, w)
223
+
224
+ ######## Geometric Transformations
225
+
226
+ H = torch.tensor(np.array([generateRandomHomography(shape,difficulty) for b in range(self.batch_size)]),dtype=torch.float32).to(self.device)
227
+
228
+ output = kornia.geometry.transform.warp_perspective(x,H,dsize=shape,padding_mode='zeros')
229
+
230
+ #crop % of image boundaries each side to reduce invalid pixels after warps
231
+ low_h = int(h * self.sides_crop); low_w = int(w*self.sides_crop)
232
+ high_h = int(h*(1. - self.sides_crop)); high_w= int(w * (1. - self.sides_crop))
233
+ output = output[..., low_h:high_h, low_w:high_w]
234
+ x = x[..., low_h:high_h, low_w:high_w]
235
+
236
+ #apply TPS if desired:
237
+ if TPS:
238
+ src, weights, A = None, None, None
239
+ for b in range(self.batch_size):
240
+ b_src, b_weights, b_A = generateRandomTPS(shape, (8,6), difficulty, prob = prob_deformation)
241
+ b_src, b_weights, b_A = b_src.to(self.device), b_weights.to(self.device), b_A.to(self.device)
242
+
243
+ if src is None:
244
+ src, weights, A = b_src, b_weights, b_A
245
+ else:
246
+ src = torch.cat((b_src, src))
247
+ weights = torch.cat((b_weights, weights))
248
+ A = torch.cat((b_A, A))
249
+
250
+ output = warp_image_tps(output, src, weights, A)
251
+
252
+ output = F.interpolate(output, self.out_resolution[::-1], mode = 'nearest')
253
+ x = F.interpolate(x, self.out_resolution[::-1], mode = 'nearest')
254
+
255
+ mask = ~torch.all(output == 0, dim=1, keepdim=True)
256
+ mask = mask.expand(-1,3,-1,-1)
257
+
258
+ # Make-up invalid regions with texture from the batch
259
+ rv = 1 if not TPS else 2
260
+ output_shifted = torch.roll(x, rv, 0)
261
+ output[~mask] = output_shifted[~mask]
262
+ mask = mask[:, 0, :, :]
263
+
264
+ ######## Photometric Transformations
265
+ output = self.aug_list(output)
266
+
267
+ b, c, h, w = output.shape
268
+ #Correlated Gaussian Noise
269
+ if np.random.uniform() > 0.5 and self.photometric:
270
+ noise = F.interpolate(torch.randn_like(output)*(10/255), (h//2, w//2))
271
+ noise = F.interpolate(noise, (h, w), mode = 'bicubic')
272
+ output = torch.clip( output + noise, 0., 1.)
273
+
274
+ #Random shadows
275
+ if np.random.uniform() > 0.6 and self.photometric:
276
+ noise = torch.rand((b, 1, h//64, w//64), device = self.device) * 1.3
277
+ noise = torch.clip(noise, 0.25, 1.0)
278
+ noise = F.interpolate(noise, (h, w), mode = 'bicubic')
279
+ noise = noise.expand(-1, 3, -1, -1)
280
+ output *= noise
281
+ output = torch.clip( output, 0., 1.)
282
+
283
+ self.cnt+=1
284
+
285
+ if TPS:
286
+ return output, (H, src, weights, A, mask)
287
+ else:
288
+ return output, (H, mask)
289
+
290
+ def get_correspondences(self, kps_target, T):
291
+ H, H2, src, W, A = T
292
+ undeformed = self.denorm_pts_grid(
293
+ warp_points_tps(self.norm_pts_grid(kps_target),
294
+ src, W, A) ).view(-1,2)
295
+
296
+ warped_to_src = self.warp_points([email protected](H2), undeformed)
297
+
298
+ return warped_to_src
imcui/third_party/LiftFeat/dataset/coco_wrapper.py ADDED
@@ -0,0 +1,175 @@
1
+ import torch
2
+ import numpy as np
3
+ import pdb
4
+
5
+ debug_cnt = -1
6
+
7
+ def make_batch(augmentor, difficulty = 0.3, train = True):
8
+ Hs = []
9
+ img_list = augmentor.train if train else augmentor.test
10
+ dev = augmentor.device
11
+ batch_images = []
12
+
13
+ with torch.no_grad(): # we dont require grads in the augmentation
14
+ for b in range(augmentor.batch_size):
15
+ rdidx = np.random.randint(len(img_list))
16
+ img = torch.tensor(img_list[rdidx], dtype=torch.float32).permute(2,0,1).to(augmentor.device).unsqueeze(0)
17
+ batch_images.append(img)
18
+
19
+ batch_images = torch.cat(batch_images)
20
+
21
+ p1, H1 = augmentor(batch_images, difficulty)
22
+ p2, H2 = augmentor(batch_images, difficulty, TPS = True, prob_deformation = 0.7)
23
+ # p2, H2 = augmentor(batch_images, difficulty, TPS = False, prob_deformation = 0.7)
24
+
25
+ return p1, p2, H1, H2
26
+
27
+
28
+ def plot_corrs(p1, p2, src_pts, tgt_pts):
29
+ import matplotlib.pyplot as plt
30
+ p1 = p1.cpu()
31
+ p2 = p2.cpu()
32
+ src_pts = src_pts.cpu() ; tgt_pts = tgt_pts.cpu()
33
+ rnd_idx = np.random.randint(len(src_pts), size=200)
34
+ src_pts = src_pts[rnd_idx, ...]
35
+ tgt_pts = tgt_pts[rnd_idx, ...]
36
+
37
+ #Plot ground-truth correspondences
38
+ fig, ax = plt.subplots(1,2,figsize=(18, 12))
39
+ colors = np.random.uniform(size=(len(tgt_pts),3))
40
+ #Src image
41
+ img = p1
42
+ for i, p in enumerate(src_pts):
43
+ ax[0].scatter(p[0],p[1],color=colors[i])
44
+ ax[0].imshow(img.permute(1,2,0).numpy()[...,::-1])
45
+
46
+ #Target img
47
+ img2 = p2
48
+ for i, p in enumerate(tgt_pts):
49
+ ax[1].scatter(p[0],p[1],color=colors[i])
50
+ ax[1].imshow(img2.permute(1,2,0).numpy()[...,::-1])
51
+ plt.show()
52
+
53
+
54
+ def get_corresponding_pts(p1, p2, H, H2, augmentor, h, w, crop = None):
55
+ '''
56
+ Get dense corresponding points
57
+ '''
58
+ global debug_cnt
59
+ negatives, positives = [], []
60
+
61
+ with torch.no_grad():
62
+ #real input res of samples
63
+ rh, rw = p1.shape[-2:]
64
+ ratio = torch.tensor([rw/w, rh/h], device = p1.device)
65
+
66
+ (H, mask1) = H
67
+ (H2, src, W, A, mask2) = H2
68
+
69
+ #Generate meshgrid of target pts
70
+ x, y = torch.meshgrid(torch.arange(w, device=p1.device), torch.arange(h, device=p1.device), indexing ='xy')
71
+ mesh = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1)], dim=-1)
72
+ target_pts = mesh.view(-1, 2) * ratio
73
+
74
+ #Pack all transformations into T
75
+ for batch_idx in range(len(p1)):
76
+ with torch.no_grad():
77
+ T = (H[batch_idx], H2[batch_idx],
78
+ src[batch_idx].unsqueeze(0), W[batch_idx].unsqueeze(0), A[batch_idx].unsqueeze(0))
79
+ #We now warp the target points to src image
80
+ src_pts = (augmentor.get_correspondences(target_pts, T) ) #target to src
81
+ tgt_pts = (target_pts)
82
+
83
+ #Check out of bounds points
84
+ mask_valid = (src_pts[:, 0] >=0) & (src_pts[:, 1] >=0) & \
85
+ (src_pts[:, 0] < rw) & (src_pts[:, 1] < rh)
86
+
87
+ negatives.append( tgt_pts[~mask_valid] )
88
+ tgt_pts = tgt_pts[mask_valid]
89
+ src_pts = src_pts[mask_valid]
90
+
91
+
92
+ #Remove invalid pixels
93
+ mask_valid = mask1[batch_idx, src_pts[:,1].long(), src_pts[:,0].long()] & \
94
+ mask2[batch_idx, tgt_pts[:,1].long(), tgt_pts[:,0].long()]
95
+ tgt_pts = tgt_pts[mask_valid]
96
+ src_pts = src_pts[mask_valid]
97
+
98
+ # limit nb of matches if desired
99
+ if crop is not None:
100
+ rnd_idx = torch.randperm(len(src_pts), device=src_pts.device)[:crop]
101
+ src_pts = src_pts[rnd_idx]
102
+ tgt_pts = tgt_pts[rnd_idx]
103
+
104
+ if debug_cnt >=0 and debug_cnt < 4:
105
+ plot_corrs(p1[batch_idx], p2[batch_idx], src_pts , tgt_pts )
106
+ debug_cnt +=1
107
+
108
+ src_pts = (src_pts / ratio)
109
+ tgt_pts = (tgt_pts / ratio)
110
+
111
+ #Check out of bounds points
112
+ padto = 10 if crop is not None else 2
113
+ mask_valid1 = (src_pts[:, 0] >= (0 + padto)) & (src_pts[:, 1] >= (0 + padto)) & \
114
+ (src_pts[:, 0] < (w - padto)) & (src_pts[:, 1] < (h - padto))
115
+ mask_valid2 = (tgt_pts[:, 0] >= (0 + padto)) & (tgt_pts[:, 1] >= (0 + padto)) & \
116
+ (tgt_pts[:, 0] < (w - padto)) & (tgt_pts[:, 1] < (h - padto))
117
+ mask_valid = mask_valid1 & mask_valid2
118
+ tgt_pts = tgt_pts[mask_valid]
119
+ src_pts = src_pts[mask_valid]
120
+
121
+ #Remove repeated correspondences
122
+ lut_mat = torch.ones((h, w, 4), device = src_pts.device, dtype = src_pts.dtype) * -1
123
+ # src_pts_np = src_pts.cpu().numpy()
124
+ # tgt_pts_np = tgt_pts.cpu().numpy()
125
+ try:
126
+ lut_mat[src_pts[:,1].long(), src_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1)
127
+ mask_valid = torch.all(lut_mat >= 0, dim=-1)
128
+ points = lut_mat[mask_valid]
129
+ positives.append(points)
130
+ except:
131
+ pdb.set_trace()
132
+ print('..')
133
+
134
+ return negatives, positives
135
+
136
+
137
+ def crop_patches(tensor, coords, size = 7):
138
+ '''
139
+ Crop [size x size] patches around 2D coordinates from a tensor.
140
+ '''
141
+ B, C, H, W = tensor.shape
142
+
143
+ x, y = coords[:, 0], coords[:, 1]
144
+ y = y.view(-1, 1, 1)
145
+ x = x.view(-1, 1, 1)
146
+ halfsize = size // 2
147
+ # Create meshgrid for indexing
148
+ x_offset, y_offset = torch.meshgrid(torch.arange(-halfsize, halfsize+1), torch.arange(-halfsize, halfsize+1), indexing='xy')
149
+ y_offset = y_offset.to(tensor.device)
150
+ x_offset = x_offset.to(tensor.device)
151
+
152
+ # Compute indices around each coordinate
153
+ y_indices = (y + y_offset.view(1, size, size)).squeeze(0) + halfsize
154
+ x_indices = (x + x_offset.view(1, size, size)).squeeze(0) + halfsize
155
+
156
+ # Handle out-of-boundary indices with padding
157
+ tensor_padded = torch.nn.functional.pad(tensor, (halfsize, halfsize, halfsize, halfsize), mode='constant')
158
+
159
+ # Index tensor to get patches
160
+ patches = tensor_padded[:, :, y_indices, x_indices] # [B, C, N, H, W]
161
+ return patches
162
+
163
+ def subpix_softmax2d(heatmaps, temp = 0.25):
164
+ N, H, W = heatmaps.shape
165
+ heatmaps = torch.softmax(temp * heatmaps.view(-1, H*W), -1).view(-1, H, W)
166
+ x, y = torch.meshgrid(torch.arange(W, device = heatmaps.device ), torch.arange(H, device = heatmaps.device ), indexing = 'xy')
167
+ x = x - (W//2)
168
+ y = y - (H//2)
169
+ #pdb.set_trace()
170
+ coords_x = (x[None, ...] * heatmaps)
171
+ coords_y = (y[None, ...] * heatmaps)
172
+ coords = torch.cat([coords_x[..., None], coords_y[..., None]], -1).view(N, H*W, 2)
173
+ coords = coords.sum(1)
174
+
175
+ return coords
imcui/third_party/LiftFeat/dataset/dataset_utils.py ADDED
@@ -0,0 +1,183 @@
1
+ """
2
+ "LiftFeat: 3D Geometry-Aware Local Feature Matching"
3
+
4
+ MegaDepth data handling was adapted from
5
+ LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py
6
+ """
7
+
8
+ import io
9
+ import cv2
10
+ import numpy as np
11
+ import h5py
12
+ import torch
13
+ from numpy.linalg import inv
14
+
15
+
16
+ try:
17
+ # for internal use only
18
+ from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT
19
+ except Exception:
20
+ MEGADEPTH_CLIENT = SCANNET_CLIENT = None
21
+
22
+ # --- DATA IO ---
23
+
24
+ def load_array_from_s3(
25
+ path, client, cv_type,
26
+ use_h5py=False,
27
+ ):
28
+ byte_str = client.Get(path)
29
+ try:
30
+ if not use_h5py:
31
+ raw_array = np.fromstring(byte_str, np.uint8)
32
+ data = cv2.imdecode(raw_array, cv_type)
33
+ else:
34
+ f = io.BytesIO(byte_str)
35
+ data = np.array(h5py.File(f, 'r')['/depth'])
36
+ except Exception as ex:
37
+ print(f"==> Data loading failure: {path}")
38
+ raise ex
39
+
40
+ assert data is not None
41
+ return data
42
+
43
+
44
+ def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
45
+ cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \
46
+ else cv2.IMREAD_COLOR
47
+ if str(path).startswith('s3://'):
48
+ image = load_array_from_s3(str(path), client, cv_type)
49
+ else:
50
+ image = cv2.imread(str(path), 1)
51
+
52
+ if augment_fn is not None:
53
+ image = cv2.imread(str(path), cv2.IMREAD_COLOR)
54
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
55
+ image = augment_fn(image)
56
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
57
+ return image # (h, w)
58
+
59
+
60
+ def get_resized_wh(w, h, resize=None):
61
+ if resize is not None: # resize the longer edge
62
+ scale = resize / max(h, w)
63
+ w_new, h_new = int(round(w*scale)), int(round(h*scale))
64
+ else:
65
+ w_new, h_new = w, h
66
+ return w_new, h_new
67
+
68
+
69
+ def get_divisible_wh(w, h, df=None):
70
+ if df is not None:
71
+ w_new, h_new = map(lambda x: int(x // df * df), [w, h])
72
+ else:
73
+ w_new, h_new = w, h
74
+ return w_new, h_new
75
+
76
+
77
+ def pad_bottom_right(inp, pad_size, ret_mask=False):
78
+ assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
79
+ mask = None
80
+ if inp.ndim == 2:
81
+ padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
82
+ padded[:inp.shape[0], :inp.shape[1]] = inp
83
+ if ret_mask:
84
+ mask = np.zeros((pad_size, pad_size), dtype=bool)
85
+ mask[:inp.shape[0], :inp.shape[1]] = True
86
+ elif inp.ndim == 3:
87
+ padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
88
+ padded[:, :inp.shape[1], :inp.shape[2]] = inp
89
+ if ret_mask:
90
+ mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
91
+ mask[:, :inp.shape[1], :inp.shape[2]] = True
92
+ else:
93
+ raise NotImplementedError()
94
+ return padded, mask
95
+
96
+
97
+ # --- MEGADEPTH ---
98
+
99
+ def fix_path_from_d2net(path):
100
+ if not path:
101
+ return None
102
+
103
+ path = path.replace('Undistorted_SfM/', '')
104
+ path = path.replace('images', 'dense0/imgs')
105
+ path = path.replace('phoenix/S6/zl548/MegaDepth_v1/', '')
106
+
107
+ return path
108
+
109
+ def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
110
+ """
111
+ Args:
112
+ resize (int, optional): the longer edge of resized images. None for no resize.
113
+ padding (bool): If set to 'True', zero-pad resized images to squared size.
114
+ augment_fn (callable, optional): augments images with pre-defined visual effects
115
+ Returns:
116
+ image (torch.tensor): (1, h, w)
117
+ mask (torch.tensor): (h, w)
118
+ scale (torch.tensor): [w/w_new, h/h_new]
119
+ """
120
+ # read image
121
+ image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)
122
+
123
+ # resize image
124
+ w, h = image.shape[1], image.shape[0]
125
+
126
+ if resize is not None:
127
+ if len(resize) == 2:
128
+ w_new, h_new = resize
129
+ else:
130
+ resize = resize[0]
131
+ w_new, h_new = get_resized_wh(w, h, resize)
132
+ w_new, h_new = get_divisible_wh(w_new, h_new, df)
133
+
134
+
135
+ image = cv2.resize(image, (w_new, h_new))
136
+ scale = torch.tensor([w/w_new, h/h_new], dtype=torch.float)
137
+
138
+ if padding: # padding
139
+ pad_to = max(h_new, w_new)
140
+ image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
141
+ else:
142
+ mask = None
143
+ else:
144
+ scale=torch.tensor([1.0,1.0],dtype=torch.float)
145
+
146
+ if padding:
147
+ pad_to=max(w,h)
148
+ image,mask=pad_bottom_right(image,pad_to,ret_mask=True)
149
+ else:
150
+ mask=None
151
+
152
+ #image = torch.from_numpy(image).float()[None] / 255 # (h, w) -> (1, h, w) and normalized
153
+ image_t = torch.from_numpy(image).float().permute(2,0,1) / 255 # (h, w) -> (1, h, w) and normalized
154
+ mask = torch.from_numpy(mask) if mask is not None else None
155
+
156
+ return image, image_t, mask, scale
157
+
158
+
159
+ def read_megadepth_depth(path, pad_to=None):
160
+
161
+ if str(path).startswith('s3://'):
162
+ depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True)
163
+ else:
164
+ depth = np.array(h5py.File(path, 'r')['depth'])
165
+ if pad_to is not None:
166
+ depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
167
+ depth = torch.from_numpy(depth).float() # (h, w)
168
+ return depth
169
+
170
+
171
+ def imread_bgr(path, augment_fn=None, client=SCANNET_CLIENT):
172
+ cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None else cv2.IMREAD_COLOR
173
+ if str(path).startswith('s3://'):
174
+ image = load_array_from_s3(str(path), client, cv_type)
175
+ else:
176
+ image = cv2.imread(str(path), 1)
177
+
178
+ if augment_fn is not None:
179
+ image = cv2.imread(str(path), cv2.IMREAD_COLOR)
180
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
181
+ image = augment_fn(image)
182
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
183
+ return image # (h, w)
imcui/third_party/LiftFeat/dataset/megadepth.py ADDED
@@ -0,0 +1,177 @@
1
+ """
2
+ "LiftFeat: 3D Geometry-Aware Local Feature Matching"
3
+
4
+ MegaDepth data handling was adapted from
5
+ LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py
6
+ """
7
+
8
+ import os.path as osp
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from torch.utils.data import Dataset
13
+ import glob
14
+ import numpy.random as rnd
15
+
16
+ import os
17
+ import sys
18
+ sys.path.append(os.path.join(os.path.dirname(__file__),'..'))
19
+ from dataset.dataset_utils import read_megadepth_gray, read_megadepth_depth, fix_path_from_d2net
20
+
21
+ import pdb, tqdm, os
22
+
23
+
24
+ class MegaDepthDataset(Dataset):
25
+ def __init__(self,
26
+ root_dir,
27
+ npz_path,
28
+ mode='train',
29
+ min_overlap_score = 0.3, #0.3,
30
+ max_overlap_score = 1.0, #1,
31
+ load_depth = True,
32
+ img_resize = (800,608), #or None
33
+ df=32,
34
+ img_padding=False,
35
+ depth_padding=True,
36
+ augment_fn=None,
37
+ **kwargs):
38
+ """
39
+ Manage one scene(npz_path) of MegaDepth dataset.
40
+
41
+ Args:
42
+ root_dir (str): megadepth root directory that has `phoenix`.
43
+ npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
44
+ mode (str): options are ['train', 'val', 'test']
45
+ min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing.
46
+ img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
47
+ This is useful during training with batches and testing with memory intensive algorithms.
48
+ df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
49
+ img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training.
50
+ depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training.
51
+ augment_fn (callable, optional): augments images with pre-defined visual effects.
52
+ """
53
+ super().__init__()
54
+ self.root_dir = root_dir
55
+ self.mode = mode
56
+ self.scene_id = npz_path.split('.')[0]
57
+ self.load_depth = load_depth
58
+ # prepare scene_info and pair_info
59
+ if mode == 'test' and min_overlap_score != 0:
60
+ min_overlap_score = 0
61
+ self.scene_info = np.load(npz_path, allow_pickle=True)
62
+ self.pair_infos = self.scene_info['pair_infos'].copy()
63
+ del self.scene_info['pair_infos']
64
+ self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score and pair_info[1] < max_overlap_score]
65
+
66
+ # parameters for image resizing, padding and depthmap padding
67
+ if mode == 'train':
68
+ assert img_resize is not None #and img_padding and depth_padding
69
+
70
+ self.img_resize = img_resize
71
+ self.df = df
72
+ self.img_padding = img_padding
73
+ self.depth_max_size = 2000 if depth_padding else None # the upperbound of depthmaps size in megadepth.
74
+
75
+ # for training LoFTR
76
+ self.augment_fn = augment_fn if mode == 'train' else None
77
+ self.coarse_scale = kwargs.get('coarse_scale', 0.125)
78
+ #pdb.set_trace()
79
+ for idx in range(len(self.scene_info['image_paths'])):
80
+ self.scene_info['image_paths'][idx] = fix_path_from_d2net(self.scene_info['image_paths'][idx])
81
+
82
+ for idx in range(len(self.scene_info['depth_paths'])):
83
+ self.scene_info['depth_paths'][idx] = fix_path_from_d2net(self.scene_info['depth_paths'][idx])
84
+
85
+
86
+ def __len__(self):
87
+ return len(self.pair_infos)
88
+
89
+ def __getitem__(self, idx):
90
+ (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx % len(self)]
91
+
92
+ # read grayscale image and mask. (1, h, w) and (h, w)
93
+ img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0])
94
+ img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1])
95
+
96
+ # TODO: Support augmentation & handle seeds for each worker correctly.
97
+ image0, image0_t, mask0, scale0 = read_megadepth_gray(img_name0, self.img_resize, self.df, self.img_padding, None)
98
+ # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
99
+ image1, image1_t, mask1, scale1 = read_megadepth_gray(img_name1, self.img_resize, self.df, self.img_padding, None)
100
+ # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
101
+
102
+ if self.load_depth:
103
+ # read depth. shape: (h, w)
104
+ if self.mode in ['train', 'val']:
105
+ depth0 = read_megadepth_depth(
106
+ osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size)
107
+ depth1 = read_megadepth_depth(
108
+ osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size)
109
+ else:
110
+ depth0 = depth1 = torch.tensor([])
111
+
112
+ # read intrinsics of original size
113
+ K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
114
+ K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)
115
+
116
+ # read and compute relative poses
117
+ T0 = self.scene_info['poses'][idx0]
118
+ T1 = self.scene_info['poses'][idx1]
119
+ T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4)
120
+ T_1to0 = T_0to1.inverse()
121
+
122
+ data = {
123
+ 'image0': image0_t, # (1, h, w)
124
+ 'image0_np': image0,
125
+ 'depth0': depth0, # (h, w)
126
+ 'image1': image1_t,
127
+ 'image1_np': image1,
128
+ 'depth1': depth1,
129
+ 'T_0to1': T_0to1, # (4, 4)
130
+ 'T_1to0': T_1to0,
131
+ 'K0': K_0, # (3, 3)
132
+ 'K1': K_1,
133
+ 'scale0': scale0, # [scale_w, scale_h]
134
+ 'scale1': scale1,
135
+ 'dataset_name': 'MegaDepth',
136
+ 'scene_id': self.scene_id,
137
+ 'pair_id': idx,
138
+ 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
139
+ }
140
+
141
+ # for LoFTR training
142
+ if mask0 is not None: # img_padding is True
143
+ if self.coarse_scale:
144
+ [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
145
+ scale_factor=self.coarse_scale,
146
+ mode='nearest',
147
+ recompute_scale_factor=False)[0].bool()
148
+ data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})
149
+
150
+ else:
151
+
152
+ # read intrinsics of original size
153
+ K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
154
+ K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)
155
+
156
+ # read and compute relative poses
157
+ T0 = self.scene_info['poses'][idx0]
158
+ T1 = self.scene_info['poses'][idx1]
159
+ T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4] # (4, 4)
160
+ T_1to0 = T_0to1.inverse()
161
+
162
+ data = {
163
+ 'image0': image0, # (1, h, w)
164
+ 'image1': image1,
165
+ 'T_0to1': T_0to1, # (4, 4)
166
+ 'T_1to0': T_1to0,
167
+ 'K0': K_0, # (3, 3)
168
+ 'K1': K_1,
169
+ 'scale0': scale0, # [scale_w, scale_h]
170
+ 'scale1': scale1,
171
+ 'dataset_name': 'MegaDepth',
172
+ 'scene_id': self.scene_id,
173
+ 'pair_id': idx,
174
+ 'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
175
+ }
176
+
177
+ return data
imcui/third_party/LiftFeat/dataset/megadepth_wrapper.py ADDED
@@ -0,0 +1,167 @@
1
+ """
2
+ "LiftFeat: 3D Geometry-Aware Local Feature Matching"
3
+
4
+ MegaDepth data handling was adapted from
5
+ LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py
6
+ """
7
+
8
+ import torch
9
+ from kornia.utils import create_meshgrid
10
+ import matplotlib.pyplot as plt
11
+ import pdb
12
+ import cv2
13
+
14
+ @torch.no_grad()
15
+ def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
16
+ """ Warp kpts0 from I0 to I1 with depth, K and Rt
17
+ Also check covisibility and depth consistency.
18
+ Depth is consistent if relative error < 0.2 (hard-coded).
19
+
20
+ Args:
21
+ kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
22
+ depth0 (torch.Tensor): [N, H, W],
23
+ depth1 (torch.Tensor): [N, H, W],
24
+ T_0to1 (torch.Tensor): [N, 3, 4],
25
+ K0 (torch.Tensor): [N, 3, 3],
26
+ K1 (torch.Tensor): [N, 3, 3],
27
+ Returns:
28
+ calculable_mask (torch.Tensor): [N, L]
29
+ warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
30
+ """
31
+ kpts0_long = kpts0.round().long().clip(0, 2000-1)
32
+
33
+ depth0[:, 0, :] = 0 ; depth1[:, 0, :] = 0
34
+ depth0[:, :, 0] = 0 ; depth1[:, :, 0] = 0
35
+
36
+ # Sample depth, get calculable_mask on depth != 0
37
+ kpts0_depth = torch.stack(
38
+ [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
39
+ ) # (N, L)
40
+ nonzero_mask = kpts0_depth > 0
41
+
42
+ # Draw cross marks on the image for each keypoint
43
+ # for b in range(len(kpts0)):
44
+ # fig, ax = plt.subplots(1,2)
45
+ # depth_np = depth0.numpy()[b]
46
+ # depth_np_plot = depth_np.copy()
47
+ # for x, y in kpts0_long[b, nonzero_mask[b], :].numpy():
48
+ # cv2.drawMarker(depth_np_plot, (x, y), (255), cv2.MARKER_CROSS, markerSize=10, thickness=2)
49
+ # ax[0].imshow(depth_np)
50
+ # ax[1].imshow(depth_np_plot)
51
+
52
+ # Unproject
53
+ kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3)
54
+ kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L)
55
+
56
+ # Rigid Transform
57
+ w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L)
58
+ w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
59
+
60
+ # Project
61
+ w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3)
62
+ w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-5) # (N, L, 2), +1e-5 to avoid zero depth
63
+
64
+ # Covisible Check
65
+ # h, w = depth1.shape[1:3]
66
+ # covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
67
+ # (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
68
+ # w_kpts0_long = w_kpts0.long()
69
+ # w_kpts0_long[~covisible_mask, :] = 0
70
+
71
+ # w_kpts0_depth = torch.stack(
72
+ # [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
73
+ # ) # (N, L)
74
+ # consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
75
+
76
+
77
+ valid_mask = nonzero_mask #* consistent_mask* covisible_mask
78
+
79
+ return valid_mask, w_kpts0
80
+
81
+
82
+ @torch.no_grad()
83
+ def spvs_coarse(data, scale = 8):
84
+ """
85
+ Supervise corresp with dense depth & camera poses
86
+ """
87
+
88
+ # 1. misc
89
+ device = data['image0'].device
90
+ N, _, H0, W0 = data['image0'].shape
91
+ _, _, H1, W1 = data['image1'].shape
92
+ #scale = 8
93
+ scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
94
+ scale1 = scale * data['scale1'][:, None] if 'scale1' in data else scale
95
+ h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
96
+
97
+ # 2. warp grids
98
+ # create kpts in meshgrid and resize them to image resolution
99
+ grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1) # [N, hw, 2]
100
+ grid_pt1_i = scale1 * grid_pt1_c
101
+
102
+ # warp kpts bi-directionally and check reproj error
103
+ nonzero_m1, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
104
+ nonzero_m2, w_pt1_og = warp_kpts( w_pt1_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
105
+
106
+
107
+ dist = torch.linalg.norm( grid_pt1_i - w_pt1_og, dim=-1)
108
+ mask_mutual = (dist < 1.5) & nonzero_m1 & nonzero_m2
109
+
110
+ #_, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
111
+ batched_corrs = [ torch.cat([w_pt1_i[i, mask_mutual[i]] / data['scale0'][i],
112
+ grid_pt1_i[i, mask_mutual[i]] / data['scale1'][i]],dim=-1) for i in range(len(mask_mutual))]
113
+
114
+
115
+ #Remove repeated correspondences - this is important for network convergence
116
+ corrs = []
117
+ for pts in batched_corrs:
118
+ lut_mat12 = torch.ones((h1, w1, 4), device = device, dtype = torch.float32) * -1
119
+ lut_mat21 = torch.clone(lut_mat12)
120
+ src_pts = pts[:, :2] / scale
121
+ tgt_pts = pts[:, 2:] / scale
122
+ try:
123
+ lut_mat12[src_pts[:,1].long(), src_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1)
124
+ mask_valid12 = torch.all(lut_mat12 >= 0, dim=-1)
125
+ points = lut_mat12[mask_valid12]
126
+
127
+ #Target-src check
128
+ src_pts, tgt_pts = points[:, :2], points[:, 2:]
129
+ lut_mat21[tgt_pts[:,1].long(), tgt_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1)
130
+ mask_valid21 = torch.all(lut_mat21 >= 0, dim=-1)
131
+ points = lut_mat21[mask_valid21]
132
+
133
+ corrs.append(points)
134
+ except:
135
+ pdb.set_trace()
136
+ print('..')
137
+
138
+ #Plot for debug purposes
139
+ # for i in range(len(corrs)):
140
+ # plot_corrs(data['image0'][i], data['image1'][i], corrs[i][:, :2]*8, corrs[i][:, 2:]*8)
141
+
142
+ return corrs
143
+
144
+ @torch.no_grad()
145
+ def get_correspondences(pts2, data, idx):
146
+ device = data['image0'].device
147
+ N, _, H0, W0 = data['image0'].shape
148
+ _, _, H1, W1 = data['image1'].shape
149
+
150
+ pts2 = pts2[None, ...]
151
+
152
+ scale0 = data['scale0'][idx, None][None, ...] if 'scale0' in data else 1
153
+ scale1 = data['scale1'][idx, None][None, ...] if 'scale1' in data else 1
154
+
155
+ pts2 = scale1 * pts2 * 8
156
+
157
+ # warp kpts bi-directionally and check reproj error
158
+ nonzero_m1, pts1 = warp_kpts(pts2, data['depth1'][idx][None, ...], data['depth0'][idx][None, ...], data['T_1to0'][idx][None, ...],
159
+ data['K1'][idx][None, ...], data['K0'][idx][None, ...])
160
+
161
+ corrs = torch.cat([pts1[0, :] / data['scale0'][idx],
162
+ pts2[0, :] / data['scale1'][idx]],dim=-1)
163
+
164
+ #plot_corrs(data['image0'][idx], data['image1'][idx], corrs[:, :2], corrs[:, 2:])
165
+
166
+ return corrs
167
+
imcui/third_party/LiftFeat/demo.py ADDED
@@ -0,0 +1,68 @@
1
+ import os
2
+ import sys
3
+ import torch
4
+ import numpy as np
5
+ import math
6
+ import cv2
7
+
8
+ os.environ['CUDA_VISIBLE_DEVICES']='1'
9
+
10
+ from models.liftfeat_wrapper import LiftFeat,MODEL_PATH
11
+
12
+ import argparse
13
+
14
+ parser=argparse.ArgumentParser(description='LiftFeat image matching demo')
15
+ parser.add_argument('--name',type=str,default='LiftFeat',help='experiment name')
16
+ parser.add_argument('--img1',type=str,default='./assert/ref.jpg',help='reference image path')
17
+ parser.add_argument('--img2',type=str,default='./assert/query.jpg',help='query image path')
18
+ parser.add_argument('--gpu',type=str,default='0',help='GPU ID')
19
+ args=parser.parse_args()
20
+
21
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
22
+
23
+
24
+ def warp_corners_and_draw_matches(ref_points, dst_points, img1, img2):
25
+ # Calculate the Homography matrix
26
+ H, mask = cv2.findHomography(ref_points, dst_points, cv2.USAC_MAGSAC, 3.5, maxIters=1_000, confidence=0.999)
27
+ mask = mask.flatten()
28
+
29
+ # Get corners of the first image (image1)
30
+ h, w = img1.shape[:2]
31
+ corners_img1 = np.array([[0, 0], [w-1, 0], [w-1, h-1], [0, h-1]], dtype=np.float32).reshape(-1, 1, 2)
32
+
33
+ # Warp corners to the second image (image2) space
34
+ warped_corners = cv2.perspectiveTransform(corners_img1, H)
35
+
36
+ # Draw the warped corners in image2
37
+ img2_with_corners = img2.copy()
38
+
39
+ # Prepare keypoints and matches for drawMatches function
40
+ keypoints1 = [cv2.KeyPoint(float(p[0]), float(p[1]), 5) for p in ref_points]
41
+ keypoints2 = [cv2.KeyPoint(float(p[0]), float(p[1]), 5) for p in dst_points]
42
+ matches = [cv2.DMatch(i,i,0) for i in range(len(mask)) if mask[i]]
43
+
44
+ # Draw inlier matches
45
+ img_matches = cv2.drawMatches(img1, keypoints1, img2_with_corners, keypoints2, matches, None,
46
+ matchColor=(0, 255, 0), flags=2)
47
+
48
+ return img_matches
49
+
50
+
51
+ if __name__=="__main__":
52
+ liftfeat=LiftFeat(weight=MODEL_PATH,detect_threshold=0.05)
53
+
54
+ img1=cv2.imread(args.img1)
55
+ img2=cv2.imread(args.img2)
56
+
57
+ # import pdb;pdb.set_trace()
58
+ mkpts1,mkpts2=liftfeat.match_liftfeat(img1,img2)
59
+ canvas=warp_corners_and_draw_matches(mkpts1,mkpts2,img1,img2)
60
+
61
+ import matplotlib.pyplot as plt
62
+ plt.figure(figsize=[12,12])
63
+ plt.imshow(canvas[...,::-1])
64
+
65
+ plt.savefig(os.path.join(os.path.dirname(__file__),'match.jpg'), dpi=300, bbox_inches='tight')
66
+
67
+ plt.show()
68
+
imcui/third_party/LiftFeat/evaluation/HPatch_evaluation.py ADDED
@@ -0,0 +1,182 @@
1
+ import cv2
2
+ import os
3
+ from tqdm import tqdm
4
+ import torch
5
+ import numpy as np
6
+ import sys
7
+ import poselib
8
+
9
+ sys.path.append(os.path.join(os.path.dirname(__file__),'..'))
10
+
11
+ import argparse
12
+ import datetime
13
+
14
+ parser=argparse.ArgumentParser(description='HPatch dataset evaluation script')
15
+ parser.add_argument('--name',type=str,default='LiftFeat',help='experiment name')
16
+ parser.add_argument('--gpu',type=str,default='0',help='GPU ID')
17
+ args=parser.parse_args()
18
+
19
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
20
+
21
+ use_cuda = torch.cuda.is_available()
22
+ device = torch.device("cuda" if use_cuda else "cpu")
23
+
24
+ top_k = None
25
+ n_i = 52
26
+ n_v = 56
27
+
28
+ DATASET_ROOT = os.path.join(os.path.dirname(__file__),'../data/HPatch')
29
+
30
+ from evaluation.eval_utils import *
31
+ from models.liftfeat_wrapper import LiftFeat
32
+
33
+
34
+ poselib_config = {"ransac_th": 3.0, "options": {}}
35
+
36
+ class PoseLibHomographyEstimator:
37
+ def __init__(self, conf):
38
+ self.conf = conf
39
+
40
+ def estimate(self, mkpts0,mkpts1):
41
+ M, info = poselib.estimate_homography(
42
+ mkpts0,
43
+ mkpts1,
44
+ {
45
+ "max_reproj_error": self.conf["ransac_th"],
46
+ **self.conf["options"],
47
+ },
48
+ )
49
+ success = M is not None
50
+ if not success:
51
+ M = np.eye(3,dtype=np.float32)
52
+ inl = np.zeros(mkpts0.shape[0],dtype=np.bool_)
53
+ else:
54
+ inl = info["inliers"]
55
+
56
+ estimation = {
57
+ "success": success,
58
+ "M_0to1": M,
59
+ "inliers": inl,
60
+ }
61
+
62
+ return estimation
63
+
64
+
65
+ estimator=PoseLibHomographyEstimator(poselib_config)
66
+
67
+
68
+ def poselib_homography_estimate(mkpts0,mkpts1):
69
+ data=estimator.estimate(mkpts0,mkpts1)
70
+ return data
71
+
72
+
73
+ def generate_standard_image(img,target_size=(1920,1080)):
74
+ sh,sw=img.shape[0],img.shape[1]
75
+ rh,rw=float(target_size[1])/float(sh),float(target_size[0])/float(sw)
76
+ ratio=min(rh,rw)
77
+ nh,nw=int(ratio*sh),int(ratio*sw)
78
+ ph,pw=target_size[1]-nh,target_size[0]-nw
79
+ nimg=cv2.resize(img,(nw,nh))
80
+ nimg=cv2.copyMakeBorder(nimg,0,ph,0,pw,cv2.BORDER_CONSTANT,value=(0,0,0))
81
+
82
+ return nimg,ratio,ph,pw
83
+
84
+
85
+ def benchmark_features(match_fn):
86
+ lim = [1, 9]
87
+ rng = np.arange(lim[0], lim[1] + 1)
88
+
89
+ seq_names = sorted(os.listdir(DATASET_ROOT))
90
+
91
+ n_feats = []
92
+ n_matches = []
93
+ seq_type = []
94
+ i_err = {thr: 0 for thr in rng}
95
+ v_err = {thr: 0 for thr in rng}
96
+
97
+ i_err_homo = {thr: 0 for thr in rng}
98
+ v_err_homo = {thr: 0 for thr in rng}
99
+
100
+ for seq_idx, seq_name in tqdm(enumerate(seq_names), total=len(seq_names)):
101
+ # load reference image
102
+ ref_img = cv2.imread(os.path.join(DATASET_ROOT, seq_name, "1.ppm"))
103
+ ref_img_shape=ref_img.shape
104
+
105
+ # load query images
106
+ for im_idx in range(2, 7):
107
+ # read ground-truth homography
108
+ homography = np.loadtxt(os.path.join(DATASET_ROOT, seq_name, "H_1_" + str(im_idx)))
109
+ query_img = cv2.imread(os.path.join(DATASET_ROOT, seq_name, f"{im_idx}.ppm"))
110
+
111
+ mkpts_a,mkpts_b=match_fn(ref_img,query_img)
112
+
113
+ pos_a = mkpts_a
114
+ pos_a_h = np.concatenate([pos_a, np.ones([pos_a.shape[0], 1])], axis=1)
115
+ pos_b_proj_h = np.transpose(np.dot(homography, np.transpose(pos_a_h)))
116
+ pos_b_proj = pos_b_proj_h[:, :2] / pos_b_proj_h[:, 2:]
117
+
118
+ pos_b = mkpts_b
119
+
120
+ dist = np.sqrt(np.sum((pos_b - pos_b_proj) ** 2, axis=1))
121
+
122
+ n_matches.append(pos_a.shape[0])
123
+ seq_type.append(seq_name[0])
124
+
125
+ if dist.shape[0] == 0:
126
+ dist = np.array([float("inf")])
127
+
128
+ for thr in rng:
129
+ if seq_name[0] == "i":
130
+ i_err[thr] += np.mean(dist <= thr)
131
+ else:
132
+ v_err[thr] += np.mean(dist <= thr)
133
+
134
+ # estimate homography
135
+ gt_homo = homography
136
+ pred_homo, _ = cv2.findHomography(mkpts_a,mkpts_b,cv2.USAC_MAGSAC)
137
+ if pred_homo is None:
138
+ homo_dist = np.array([float("inf")])
139
+ else:
140
+ corners = np.array(
141
+ [
142
+ [0, 0],
143
+ [ref_img_shape[1] - 1, 0],
144
+ [0, ref_img_shape[0] - 1],
145
+ [ref_img_shape[1] - 1, ref_img_shape[0] - 1],
146
+ ]
147
+ )
148
+ real_warped_corners = homo_trans(corners, gt_homo)
149
+ warped_corners = homo_trans(corners, pred_homo)
150
+ homo_dist = np.mean(np.linalg.norm(real_warped_corners - warped_corners, axis=1))
151
+
152
+ for thr in rng:
153
+ if seq_name[0] == "i":
154
+ i_err_homo[thr] += np.mean(homo_dist <= thr)
155
+ else:
156
+ v_err_homo[thr] += np.mean(homo_dist <= thr)
157
+
158
+ seq_type = np.array(seq_type)
159
+ n_feats = np.array(n_feats)
160
+ n_matches = np.array(n_matches)
161
+
162
+ return i_err, v_err, i_err_homo, v_err_homo, [seq_type, n_feats, n_matches]
163
+
164
+
165
+ if __name__ == "__main__":
166
+ errors = {}
167
+
168
+ weights=os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.pth')
169
+ liftfeat=LiftFeat(weight=weights)
170
+
171
+ errors = benchmark_features(liftfeat.match_liftfeat)
172
+
173
+ i_err, v_err, i_err_hom, v_err_hom, _ = errors
174
+
175
+ cur_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
176
+
177
+ print(f'\n==={cur_time}==={args.name}===')
178
+ print(f"MHA@3 MHA@5 MHA@7")
179
+ for thr in [3, 5, 7]:
180
+ ill_err_hom = i_err_hom[thr] / (n_i * 5)
181
+ view_err_hom = v_err_hom[thr] / (n_v * 5)
182
+ print(f"{ill_err_hom * 100:.2f}%-{view_err_hom * 100:.2f}%")
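
For reference, the MHA@k values printed above are the fraction of image pairs whose estimated homography keeps the mean warped-corner error within k pixels, accumulated separately over the 52 illumination (n_i) and 56 viewpoint (n_v) sequences with 5 pairs each. A small sketch of that aggregation on made-up per-pair errors:

import numpy as np

# Hypothetical mean corner errors (pixels) for five pairs of one sequence type.
corner_errors = np.array([1.2, 2.8, 6.5, 0.9, 3.4])

for thr in [3, 5, 7]:
    mha = np.mean(corner_errors <= thr)  # same thresholding as i_err_homo / v_err_homo above
    print(f'MHA@{thr}: {mha * 100:.2f}%')
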
imcui/third_party/LiftFeat/evaluation/MegaDepth1500_evaluation.py ADDED
@@ -0,0 +1,105 @@
 
 
1
+ import os
2
+ import sys
3
+ import cv2
4
+ from pathlib import Path
5
+ import numpy as np
6
+ import torch
7
+ import torch.utils.data as data
8
+ import tqdm
9
+ from copy import deepcopy
10
+ from torchvision.transforms import ToTensor
11
+ import torch.nn.functional as F
12
+ import json
13
+
14
+ import scipy.io as scio
15
+ import poselib
16
+
17
+ import argparse
18
+ import datetime
19
+
20
+ parser=argparse.ArgumentParser(description='MegaDepth dataset evaluation script')
21
+ parser.add_argument('--name',type=str,default='LiftFeat',help='experiment name')
22
+ parser.add_argument('--gpu',type=str,default='0',help='GPU ID')
23
+ args=parser.parse_args()
24
+
25
+ os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
26
+
27
+ sys.path.append(os.path.join(os.path.dirname(__file__),'../'))
28
+ from models.liftfeat_wrapper import LiftFeat
29
+ from evaluation.eval_utils import *
30
+
31
+ from torch.utils.data import Dataset,DataLoader
32
+
33
+ use_cuda = torch.cuda.is_available()
34
+ device = "cuda" if use_cuda else "cpu"
35
+
36
+ DATASET_ROOT = os.path.join(os.path.dirname(__file__),'../data/megadepth_test_1500')
37
+ DATASET_JSON = os.path.join(os.path.dirname(__file__),'../data/megadepth_1500.json')
38
+
39
+ class MegaDepth1500(Dataset):
40
+ """
41
+ Streamlined MegaDepth-1500 dataloader. The camera poses & metadata are stored in a formatted json for facilitating
42
+ the download of the dataset and to keep the setup as simple as possible.
43
+ """
44
+ def __init__(self, json_file, root_dir):
45
+ # Load the info & calibration from the JSON
46
+ with open(json_file, 'r') as f:
47
+ self.data = json.load(f)
48
+
49
+ self.root_dir = root_dir
50
+
51
+ if not os.path.exists(self.root_dir):
52
+ raise RuntimeError(
53
+ f"Dataset {self.root_dir} does not exist! \n \
54
+ > If you didn't download the dataset, use the downloader tool: python3 -m modules.dataset.download -h")
55
+
56
+ def __len__(self):
57
+ return len(self.data)
58
+
59
+ def __getitem__(self, idx):
60
+ data = deepcopy(self.data[idx])
61
+
62
+ h1, w1 = data['size0_hw']
63
+ h2, w2 = data['size1_hw']
64
+
65
+ # Here we resize the images to max_dim = 1200, as described in the paper, and adjust the image such that it is divisible by 32
66
+ # following the protocol of the LoFTR's Dataloader (intrinsics are corrected accordingly).
67
+ # For adapting this with different resolution, you would need to re-scale intrinsics below.
68
+ image0 = cv2.resize(cv2.imread(f"{self.root_dir}/{data['pair_names'][0]}"),(w1, h1))
69
+
70
+ image1 = cv2.resize(cv2.imread(f"{self.root_dir}/{data['pair_names'][1]}"),(w2, h2))
71
+
72
+ data['image0'] = torch.tensor(image0.astype(np.float32)/255).permute(2,0,1)
73
+ data['image1'] = torch.tensor(image1.astype(np.float32)/255).permute(2,0,1)
74
+
75
+ for k,v in data.items():
76
+ if k not in ('dataset_name', 'scene_id', 'pair_id', 'pair_names', 'size0_hw', 'size1_hw', 'image0', 'image1'):
77
+ data[k] = torch.tensor(np.array(v, dtype=np.float32))
78
+
79
+ return data
80
+
81
+ if __name__ == "__main__":
82
+ weights=os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.pth')
83
+ liftfeat=LiftFeat(weight=weights)
84
+
85
+ dataset = MegaDepth1500(json_file = DATASET_JSON, root_dir = DATASET_ROOT)
86
+
87
+ loader = DataLoader(dataset, batch_size=1, shuffle=False)
88
+
89
+ metrics = {}
90
+ R_errs = []
91
+ t_errs = []
92
+ inliers = []
93
+
94
+ results=[]
95
+
96
+ cur_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
97
+
98
+ for d in tqdm.tqdm(loader, desc="processing"):
99
+ error_infos = compute_pose_error(liftfeat.match_liftfeat,d)
100
+ results.append(error_infos)
101
+
102
+ print(f'\n==={cur_time}==={args.name}===')
103
+ d_err_auc,errors=compute_maa(results)
104
+ for s_k,s_v in d_err_auc.items():
105
+ print(f'{s_k}: {s_v*100}')
imcui/third_party/LiftFeat/evaluation/eval_utils.py ADDED
@@ -0,0 +1,127 @@
 
 
1
+ import numpy as np
2
+ import torch
3
+ import poselib
4
+
5
+
6
+ def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
7
+ # angle error between 2 vectors
8
+ t_gt = T_0to1[:3, 3]
9
+ n = np.linalg.norm(t) * np.linalg.norm(t_gt)
10
+ t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0)))
11
+ t_err = np.minimum(t_err, 180 - t_err) # handle E ambiguity
12
+ if np.linalg.norm(t_gt) < ignore_gt_t_thr: # pure rotation is challenging
13
+ t_err = 0
14
+
15
+ # angle error between 2 rotation matrices
16
+ R_gt = T_0to1[:3, :3]
17
+ cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
18
+ cos = np.clip(cos, -1.0, 1.0) # handle numerical errors
19
+ R_err = np.rad2deg(np.abs(np.arccos(cos)))
20
+
21
+ return t_err, R_err
22
+
23
+ def intrinsics_to_camera(K):
24
+ px, py = K[0, 2], K[1, 2]
25
+ fx, fy = K[0, 0], K[1, 1]
26
+ return {
27
+ "model": "PINHOLE",
28
+ "width": int(2 * px),
29
+ "height": int(2 * py),
30
+ "params": [fx, fy, px, py],
31
+ }
32
+
33
+
34
+ def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
35
+ M, info = poselib.estimate_relative_pose(
36
+ kpts0, kpts1,
37
+ intrinsics_to_camera(K0),
38
+ intrinsics_to_camera(K1),
39
+ {"max_epipolar_error": thresh,
40
+ "success_prob": conf,
41
+ "min_iterations": 20,
42
+ "max_iterations": 1_000},
43
+ )
44
+
45
+ R, t, inl = M.R, M.t, info["inliers"]
46
+ inl = np.array(inl)
47
+ ret = (R, t, inl)
48
+
49
+ return ret
50
+
51
+ def tensor2bgr(t):
52
+ return (t.cpu()[0].permute(1,2,0).numpy()*255).astype(np.uint8)
53
+
54
+ def compute_pose_error(match_fn,data):
55
+ result = {}
56
+
57
+ with torch.no_grad():
58
+ mkpts0,mkpts1=match_fn(tensor2bgr(data["image0"]),tensor2bgr(data["image1"]))
59
+
60
+ mkpts0=mkpts0 * data["scale0"].numpy()
61
+ mkpts1=mkpts1 * data["scale1"].numpy()
62
+
63
+ K0, K1 = data["K0"][0].numpy(), data["K1"][0].numpy()
64
+ T_0to1 = data["T_0to1"][0].numpy()
65
+ T_1to0 = data["T_1to0"][0].numpy()
66
+
67
+ result={}
68
+ conf = 0.99999
69
+
70
+ ret = estimate_pose(mkpts0,mkpts1,K0,K1,4.0,conf)
71
+ if ret is not None:
72
+ R, t, inliers = ret
73
+ t_err, R_err = relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0)
74
+ result['R_err'] = R_err
75
+ result['t_err'] = t_err
76
+
77
+ return result
78
+
79
+
80
+ def error_auc(errors, thresholds=[5, 10, 20]):
81
+ """
82
+ Args:
83
+ errors (list): [N,]
84
+ thresholds (list)
85
+ """
86
+ errors = [0] + sorted(list(errors))
87
+ recall = list(np.linspace(0, 1, len(errors)))
88
+
89
+ aucs = []
90
+
91
+ for thr in thresholds:
92
+ last_index = np.searchsorted(errors, thr)
93
+ y = recall[:last_index] + [recall[last_index-1]]
94
+ x = errors[:last_index] + [thr]
95
+ aucs.append(np.trapz(y, x) / thr)
96
+
97
+ return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
98
+
99
+ def compute_maa(pairs, thresholds=[5, 10, 20]):
100
+ # print("auc / mAcc on %d pairs" % (len(pairs)))
101
+ errors = []
102
+
103
+ for p in pairs:
104
+ et = p['t_err']
105
+ er = p['R_err']
106
+ errors.append(max(et, er))
107
+
108
+ d_err_auc = error_auc(errors)
109
+
110
+ # for k,v in d_err_auc.items():
111
+ # print(k, ': ', '%.1f'%(v*100))
112
+
113
+ errors = np.array(errors)
114
+
115
+ for t in thresholds:
116
+ acc = (errors <= t).sum() / len(errors)
117
+ # print("mAcc@%d: %.1f "%(t, acc*100))
118
+
119
+ return d_err_auc,errors
120
+
121
+ def homo_trans(coord, H):
122
+ kpt_num = coord.shape[0]
123
+ homo_coord = np.concatenate((coord, np.ones((kpt_num, 1))), axis=-1)
124
+ proj_coord = np.matmul(H, homo_coord.T).T
125
+ proj_coord = proj_coord / proj_coord[:, 2][..., None]
126
+ proj_coord = proj_coord[:, 0:2]
127
+ return proj_coord
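
A short, self-contained illustration of how these helpers summarize per-pair pose errors, assuming the LiftFeat repo root is on sys.path and using made-up numbers:

import numpy as np
from evaluation.eval_utils import error_auc, compute_maa

# Hypothetical per-pair results in the format produced by compute_pose_error above.
pairs = [{'R_err': 1.0, 't_err': 2.5},
         {'R_err': 4.0, 't_err': 3.0},
         {'R_err': 12.0, 't_err': 8.0},
         {'R_err': 25.0, 't_err': 30.0}]

aucs, errors = compute_maa(pairs)  # errors[i] = max(R_err, t_err) for pair i
print(aucs)                        # {'auc@5': ..., 'auc@10': ..., 'auc@20': ...}
print(error_auc(errors, thresholds=[5, 10]))  # the AUC helper also accepts custom thresholds
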
imcui/third_party/LiftFeat/loss/loss.py ADDED
@@ -0,0 +1,291 @@
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import time
7
+
8
+
9
+ def dual_softmax_loss(X, Y, temp = 0.2):
10
+ if X.size() != Y.size() or X.dim() != 2 or Y.dim() != 2:
11
+ raise RuntimeError('Error: X and Y shapes must match and be 2D matrices')
12
+
13
+ dist_mat = (X @ Y.t()) * temp
14
+ conf_matrix12 = F.log_softmax(dist_mat, dim=1)
15
+ conf_matrix21 = F.log_softmax(dist_mat.t(), dim=1)
16
+
17
+ with torch.no_grad():
18
+ conf12 = torch.exp( conf_matrix12 ).max(dim=-1)[0]
19
+ conf21 = torch.exp( conf_matrix21 ).max(dim=-1)[0]
20
+ conf = conf12 * conf21
21
+
22
+ target = torch.arange(len(X), device = X.device)
23
+
24
+ loss = F.nll_loss(conf_matrix12, target) + \
25
+ F.nll_loss(conf_matrix21, target)
26
+
27
+ return loss, conf
28
+
29
+
30
+ class LiftFeatLoss(nn.Module):
31
+ def __init__(self,device,lam_descs=1,lam_fb_descs=1,lam_kpts=1,lam_heatmap=1,lam_normals=1,lam_coordinates=1,lam_fb_coordinates=1,depth_spvs=False):
32
+ super().__init__()
33
+
34
+ # loss parameters
35
+ self.lam_descs=lam_descs
36
+ self.lam_fb_descs=lam_fb_descs
37
+ self.lam_kpts=lam_kpts
38
+ self.lam_heatmap=lam_heatmap
39
+ self.lam_normals=lam_normals
40
+ self.lam_coordinates=lam_coordinates
41
+ self.lam_fb_coordinates=lam_fb_coordinates
42
+ self.depth_spvs=depth_spvs
43
+ self.running_descs_loss=0
44
+ self.running_kpts_loss=0
45
+ self.running_heatmaps_loss=0
46
+ self.loss_descs=0
47
+ self.loss_fb_descs=0
48
+ self.loss_kpts=0
49
+ self.loss_heatmaps=0
50
+ self.loss_normals=0
51
+ self.loss_coordinates=0
52
+ self.loss_fb_coordinates=0
53
+ self.acc_coarse=0
54
+ self.acc_fb_coarse=0
55
+ self.acc_kpt=0
56
+ self.acc_coordinates=0
57
+ self.acc_fb_coordinates=0
58
+
59
+ # device
60
+ self.dev=device
61
+
62
+
63
+ def check_accuracy(self,m1,m2,pts1=None,pts2=None,plot=False):
64
+ with torch.no_grad():
65
+ #dist_mat = torch.cdist(X,Y)
66
+ dist_mat = m1 @ m2.t()
67
+ nn = torch.argmax(dist_mat, dim=1)
68
+ #nn = torch.argmin(dist_mat, dim=1)
69
+ correct = nn == torch.arange(len(m1), device = m1.device)
70
+
71
+ if pts1 is not None and plot:
72
+ import matplotlib.pyplot as plt
73
+ canvas = torch.zeros((60, 80),device=m1.device)
74
+ pts1 = pts1[~correct]
75
+ canvas[pts1[:,1].long(), pts1[:,0].long()] = 1
76
+ canvas = canvas.cpu().numpy()
77
+ plt.imshow(canvas), plt.show()
78
+
79
+ acc = correct.sum().item() / len(m1)
80
+ return acc
81
+
82
+ def compute_descriptors_loss(self,descs1,descs2,pts):
83
+ loss=[]
84
+ acc=0
85
+ B,_,H,W=descs1.shape
86
+ conf_list=[]
87
+
88
+ for b in range(B):
89
+ pts1,pts2=pts[b][:,:2],pts[b][:,2:]
90
+ m1=descs1[b,:,pts1[:,1].long(),pts1[:,0].long()].permute(1,0)
91
+ m2=descs2[b,:,pts2[:,1].long(),pts2[:,0].long()].permute(1,0)
92
+
93
+ loss_per,conf_per=dual_softmax_loss(m1,m2)
94
+ loss.append(loss_per.unsqueeze(0))
95
+ conf_list.append(conf_per)
96
+
97
+ acc_coarse_per=self.check_accuracy(m1,m2)
98
+ acc += acc_coarse_per
99
+
100
+ loss=torch.cat(loss,dim=-1).mean()
101
+ acc /= B
102
+ return loss,acc,conf_list
103
+
104
+
105
+ def alike_distill_loss(self,kpts,alike_kpts):
106
+ C, H, W = kpts.shape
107
+ kpts = kpts.permute(1,2,0)
108
+ # get ALike keypoints
109
+ with torch.no_grad():
110
+ labels = torch.ones((H, W), dtype = torch.long, device = kpts.device) * 64 # -> Default is non-keypoint (bin 64)
111
+ offsets = (((alike_kpts/8) - (alike_kpts/8).long())*8).long()
112
+ offsets = offsets[:, 0] + 8*offsets[:, 1] # Linear IDX
113
+ labels[(alike_kpts[:,1]/8).long(), (alike_kpts[:,0]/8).long()] = offsets
114
+
115
+ kpts = kpts.view(-1,C)
116
+ labels = labels.view(-1)
117
+
118
+ mask = labels < 64
119
+ idxs_pos = mask.nonzero().flatten()
120
+ idxs_neg = (~mask).nonzero().flatten()
121
+ perm = torch.randperm(idxs_neg.size(0))[:len(idxs_pos)//32]
122
+ idxs_neg = idxs_neg[perm]
123
+ idxs = torch.cat([idxs_pos, idxs_neg])
124
+
125
+ kpts = kpts[idxs]
126
+ labels = labels[idxs]
127
+
128
+ with torch.no_grad():
129
+ predicted = kpts.max(dim=-1)[1]
130
+ acc = (labels == predicted)
131
+ acc = acc.sum() / len(acc)
132
+
133
+ kpts = F.log_softmax(kpts,dim=-1)
134
+ loss = F.nll_loss(kpts, labels, reduction = 'mean')
135
+
136
+ return loss, acc
137
+
138
+
139
+ def compute_keypoints_loss(self,kpts1,kpts2,alike_kpts1,alike_kpts2):
140
+ loss=[]
141
+ acc=0
142
+ B,_,H,W=kpts1.shape
143
+
144
+ for b in range(B):
145
+ loss_per1,acc_per1=self.alike_distill_loss(kpts1[b],alike_kpts1[b])
146
+ loss_per2,acc_per2=self.alike_distill_loss(kpts2[b],alike_kpts2[b])
147
+ loss_per=(loss_per1+loss_per2)
148
+ acc_per=(acc_per1+acc_per2)/2
149
+ loss.append(loss_per.unsqueeze(0))
150
+ acc += acc_per
151
+
152
+ loss=torch.cat(loss,dim=-1).mean()
153
+ acc /= B
154
+ return loss,acc
155
+
156
+
157
+ def compute_heatmaps_loss(self,heatmaps1,heatmaps2,pts,conf_list):
158
+ loss=[]
159
+ B,_,H,W=heatmaps1.shape
160
+
161
+ for b in range(B):
162
+ pts1,pts2=pts[b][:,:2],pts[b][:,2:]
163
+ h1=heatmaps1[b,0,pts1[:,1].long(),pts1[:,0].long()]
164
+ h2=heatmaps2[b,0,pts2[:,1].long(),pts2[:,0].long()]
165
+
166
+ conf=conf_list[b]
167
+ loss_per1=F.l1_loss(h1,conf)
168
+ loss_per2=F.l1_loss(h2,conf)
169
+ loss_per=(loss_per1+loss_per2)
170
+ loss.append(loss_per.unsqueeze(0))
171
+
172
+ loss=torch.cat(loss,dim=-1).mean()
173
+ return loss
174
+
175
+
176
+ def normal_loss(self,normal,target_normal):
177
+ # import pdb;pdb.set_trace()
178
+ normal = normal.permute(1, 2, 0)
179
+ target_normal = target_normal.permute(1,2,0)
180
+ # loss = F.l1_loss(d_feat, depth_anything_normal_feat)
181
+ dot = torch.cosine_similarity(normal, target_normal, dim=2)
182
+ valid_mask = target_normal[:, :, 0].float() \
183
+ * (dot.detach() < 0.999).float() \
184
+ * (dot.detach() > -0.999).float()
185
+ valid_mask = valid_mask > 0.0
186
+ al = torch.acos(dot[valid_mask])
187
+ loss = torch.mean(al)
188
+ return loss
189
+
190
+
191
+ def compute_normals_loss(self,normals1,normals2,DA_normals1,DA_normals2,megadepth_batch_size,coco_batch_size):
192
+ loss=[]
193
+
194
+ # import pdb;pdb.set_trace()
195
+
196
+ # only MegaDepth image need depth-normal
197
+ normals1=normals1[coco_batch_size:,...]
198
+ normals2=normals2[coco_batch_size:,...]
199
+ for b in range(len(DA_normals1)):
200
+ normal1,normal2=normals1[b],normals2[b]
201
+ loss_per1=self.normal_loss(normal1,DA_normals1[b].permute(2,0,1))
202
+ loss_per2=self.normal_loss(normal2,DA_normals2[b].permute(2,0,1))
203
+ loss_per=(loss_per1+loss_per2)
204
+ loss.append(loss_per.unsqueeze(0))
205
+
206
+ loss=torch.cat(loss,dim=-1).mean()
207
+ return loss
208
+
209
+
210
+ def coordinate_loss(self,coordinate,conf,pts1):
211
+ with torch.no_grad():
212
+ coordinate_detached = pts1 * 8
213
+ offset_detached = (coordinate_detached/8) - (coordinate_detached/8).long()
214
+ offset_detached = (offset_detached * 8).long()
215
+ label = offset_detached[:, 0] + 8*offset_detached[:, 1]
216
+
217
+ #pdb.set_trace()
218
+ coordinate_log = F.log_softmax(coordinate, dim=-1)
219
+
220
+ predicted = coordinate.max(dim=-1)[1]
221
+ acc = (label == predicted)
222
+ acc = acc[conf > 0.1]
223
+ acc = acc.sum() / len(acc)
224
+
225
+ loss = F.nll_loss(coordinate_log, label, reduction = 'none')
226
+
227
+ #Weight loss by confidence, giving more emphasis on reliable matches
228
+ conf = conf / conf.sum()
229
+ loss = (loss * conf).sum()
230
+
231
+ return loss*2., acc
232
+
233
+ def compute_coordinates_loss(self,coordinates,pts,conf_list):
234
+ loss=[]
235
+ acc=0
236
+ B,_,H,W=coordinates.shape
237
+
238
+ for b in range(B):
239
+ pts1,pts2=pts[b][:,:2],pts[b][:,2:]
240
+ coordinate=coordinates[b,:,pts1[:,1].long(),pts1[:,0].long()].permute(1,0)
241
+ conf=conf_list[b]
242
+
243
+ loss_per,acc_per=self.coordinate_loss(coordinate,conf,pts1)
244
+ loss.append(loss_per.unsqueeze(0))
245
+ acc += acc_per
246
+
247
+ loss=torch.cat(loss,dim=-1).mean()
248
+ acc /= B
249
+
250
+ return loss,acc
251
+
252
+
253
+ def forward(self,
254
+ descs1,fb_descs1,kpts1,normals1,
255
+ descs2,fb_descs2,kpts2,normals2,
256
+ pts,coordinates,fb_coordinates,
257
+ alike_kpts1,alike_kpts2,
258
+ DA_normals1,DA_normals2,
259
+ megadepth_batch_size,coco_batch_size
260
+ ):
261
+ # import pdb;pdb.set_trace()
262
+ self.loss_descs,self.acc_coarse,conf_list=self.compute_descriptors_loss(descs1,descs2,pts)
263
+ self.loss_fb_descs,self.acc_fb_coarse,fb_conf_list=self.compute_descriptors_loss(fb_descs1,fb_descs2,pts)
264
+
265
+ # start=time.perf_counter()
266
+ self.loss_kpts,self.acc_kpt=self.compute_keypoints_loss(kpts1,kpts2,alike_kpts1,alike_kpts2)
267
+ # end=time.perf_counter()
268
+ # print(f"kpts loss cost {end-start} seconds")
269
+
270
+ # start=time.perf_counter()
271
+ self.loss_normals=self.compute_normals_loss(normals1,normals2,DA_normals1,DA_normals2,megadepth_batch_size,coco_batch_size)
272
+ # end=time.perf_counter()
273
+ # print(f"normal loss cost {end-start} seconds")
274
+
275
+ self.loss_coordinates,self.acc_coordinates=self.compute_coordinates_loss(coordinates,pts,conf_list)
276
+ self.loss_fb_coordinates,self.acc_fb_coordinates=self.compute_coordinates_loss(fb_coordinates,pts,fb_conf_list)
277
+
278
+ return {
279
+ 'loss_descs':self.lam_descs*self.loss_descs,
280
+ 'acc_coarse':self.acc_coarse,
281
+ 'loss_coordinates':self.lam_coordinates*self.loss_coordinates,
282
+ 'acc_coordinates':self.acc_coordinates,
283
+ 'loss_fb_descs':self.lam_fb_descs*self.loss_fb_descs,
284
+ 'acc_fb_coarse':self.acc_fb_coarse,
285
+ 'loss_fb_coordinates':self.lam_fb_coordinates*self.loss_fb_coordinates,
286
+ 'acc_fb_coordinates':self.acc_fb_coordinates,
287
+ 'loss_kpts':self.lam_kpts*self.loss_kpts,
288
+ 'acc_kpt':self.acc_kpt,
289
+ 'loss_normals':self.lam_normals*self.loss_normals,
290
+ }
291
+
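
The dual-softmax term above expects two row-aligned descriptor sets (row i of m1 corresponds to row i of m2). A quick sanity-check sketch on random, slightly perturbed descriptors, assuming the LiftFeat repo root is on sys.path:

import torch
import torch.nn.functional as F
from loss.loss import dual_softmax_loss

torch.manual_seed(0)
m1 = F.normalize(torch.randn(256, 64), dim=-1)               # 256 descriptor pairs, 64-D
m2 = F.normalize(m1 + 0.05 * torch.randn(256, 64), dim=-1)   # perturbed copies of the matches

loss, conf = dual_softmax_loss(m1, m2)
print(loss.item())   # summed NLL over both softmax directions
print(conf.shape)    # per-pair confidence used to weight the heatmap/coordinate losses, (256,)
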
imcui/third_party/LiftFeat/models/interpolator.py ADDED
@@ -0,0 +1,34 @@
 
 
1
+ """
2
+ "LiftFeat: 3D Geometry-Aware Local Feature Matching"
3
+
4
+ This script is used to interpolate rough descriptors from LiftFeat
5
+ """
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ class InterpolateSparse2d(nn.Module):
12
+ """ Efficiently interpolate tensor at given sparse 2D positions. """
13
+ def __init__(self, mode = 'bicubic', align_corners = False):
14
+ super().__init__()
15
+ self.mode = mode
16
+ self.align_corners = align_corners
17
+
18
+ def normgrid(self, x, H, W):
19
+ """ Normalize coords to [-1,1]. """
20
+ return 2. * (x/(torch.tensor([W-1, H-1], device = x.device, dtype = x.dtype))) - 1.
21
+
22
+ def forward(self, x, pos, H, W):
23
+ """
24
+ Input
25
+ x: [B, C, H, W] feature tensor
26
+ pos: [B, N, 2] tensor of positions
27
+ H, W: int, original resolution of input 2d positions -- used in normalization [-1,1]
28
+
29
+ Returns
30
+ [B, N, C] sampled channels at 2d positions
31
+ """
32
+ grid = self.normgrid(pos, H, W).unsqueeze(-2).to(x.dtype)
33
+ x = F.grid_sample(x, grid, mode=self.mode, align_corners=self.align_corners)
34
+ return x.permute(0,2,3,1).squeeze(-2)
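
A brief usage sketch: sampling the dense 1/8-resolution descriptor map at full-resolution keypoint coordinates, with shapes following the docstring above (repo root assumed to be on sys.path):

import torch
from models.interpolator import InterpolateSparse2d

sampler = InterpolateSparse2d('bicubic')

H, W = 608, 800                                          # original image resolution
feats = torch.randn(1, 64, H // 8, W // 8)               # dense descriptor map at 1/8 resolution
kpts = torch.tensor([[[100.0, 50.0], [400.0, 300.0]]])   # (B, N, 2) keypoints in image pixels

descs = sampler(feats, kpts, H, W)                       # (B, N, 64) descriptors at those locations
print(descs.shape)
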
imcui/third_party/LiftFeat/models/liftfeat.py ADDED
@@ -0,0 +1,190 @@
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import math
7
+ import cv2
8
+
9
+ os.environ['CUDA_VISIBLE_DEVICES']='1'
10
+
11
+ import kornia as K
12
+
13
+ sys.path.append(os.path.join(os.path.dirname(__file__),'..'))
14
+
15
+ from models.model import LiftFeatSPModel
16
+ from models.interpolator import InterpolateSparse2d
17
+ from utils.config import featureboost_config
18
+
19
+
20
+ class NonMaxSuppression(torch.nn.Module):
21
+ def __init__(self, rep_thr=0.1, top_k=4096):
22
+ super(NonMaxSuppression,self).__init__()
23
+ self.max_filter = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
24
+ self.rep_thr = rep_thr
25
+ self.top_k=top_k
26
+
27
+
28
+ def NMS(self, x, threshold = 0.05, kernel_size = 5):
29
+ B, _, H, W = x.shape
30
+ pad=kernel_size//2
31
+ local_max = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)(x)
32
+ pos = (x == local_max) & (x > threshold)
33
+ pos_batched = [k.nonzero()[..., 1:].flip(-1) for k in pos]
34
+
35
+ pad_val = max([len(x) for x in pos_batched])
36
+ pos = torch.zeros((B, pad_val, 2), dtype=torch.long, device=x.device)
37
+
38
+ #Pad kpts and build (B, N, 2) tensor
39
+ for b in range(len(pos_batched)):
40
+ pos[b, :len(pos_batched[b]), :] = pos_batched[b]
41
+
42
+ return pos
43
+
44
+ def forward(self, score):
45
+ pos = self.NMS(score,self.rep_thr)
46
+
47
+ return pos
48
+
49
+ def load_model(model, weight_path):
50
+ pretrained_weights = torch.load(weight_path)
51
+
52
+ model_keys = set(model.state_dict().keys())
53
+ pretrained_keys = set(pretrained_weights.keys())
54
+
55
+ missing_keys = model_keys - pretrained_keys
56
+ unexpected_keys = pretrained_keys - model_keys
57
+
58
+ if missing_keys:
59
+ print("Missing keys in pretrained weights:", missing_keys)
60
+ else:
61
+ print("No missing keys in pretrained weights.")
62
+
63
+ if unexpected_keys:
64
+ print("Unexpected keys in pretrained weights:", unexpected_keys)
65
+ else:
66
+ print("No unexpected keys in pretrained weights.")
67
+
68
+ if not missing_keys and not unexpected_keys:
69
+ model.load_state_dict(pretrained_weights)
70
+ print("Pretrained weights loaded successfully.")
71
+ else:
72
+ model.load_state_dict(pretrained_weights, strict=False)
73
+ print("There were issues with the keys.")
74
+ return model
75
+
76
+
77
+ def load_torch_image(fname, device=torch.device('cpu')):
78
+ img = K.image_to_tensor(cv2.imread(fname), False).float() / 255.
79
+ img = K.color.bgr_to_rgb(img.to(device))
80
+
81
+ image=cv2.imread(fname)
82
+ H,W,C=image.shape[0],image.shape[1],image.shape[2]
83
+
84
+ _H=math.ceil(H/32)*32
85
+ _W=math.ceil(W/32)*32
86
+
87
+ pad_h=_H-H
88
+ pad_w=_W-W
89
+
90
+ image=cv2.copyMakeBorder(image,0,pad_h,0,pad_w,cv2.BORDER_CONSTANT,None,(0, 0, 0))
91
+
92
+ pad_info=[0,pad_h,0,pad_w]
93
+
94
+ image = K.image_to_tensor(image, False).float() / 255.
95
+ image = image.to(device)
96
+
97
+ return image,pad_info
98
+
99
+
100
+ class LiftFeat(nn.Module):
101
+ def __init__(self,weight,top_k=4096,detect_threshold=0.1):
102
+ super().__init__()
103
+ self.net=LiftFeatSPModel(featureboost_config)
104
+ self.top_k=top_k
105
+ self.sampler=InterpolateSparse2d('bicubic')
106
+ self.net=load_model(self.net,weight)
107
+ self.detector=NonMaxSuppression(rep_thr=detect_threshold)
108
+
109
+ @torch.inference_mode()
110
+ def extract(self,image,pad_info):
111
+ B,_,_H1,_W1=image.shape
112
+ M1,K1,D1=self.net.forward1(image)
113
+ refine_M=self.net.forward2(M1,K1,D1)
114
+
115
+ refine_M=refine_M.reshape(M1.shape[0],M1.shape[2],M1.shape[3],-1).permute(0,3,1,2)
116
+ refine_M=torch.nn.functional.normalize(refine_M,2,dim=1)
117
+
118
+ descs_map=refine_M
119
+ # descs_map=M1
120
+
121
+ scores=torch.softmax(K1,dim=1)[:,:64]
122
+ heatmap=scores.permute(0,2,3,1).reshape(scores.shape[0],scores.shape[2],scores.shape[3],8,8)
123
+ heatmap=heatmap.permute(0,1,3,2,4).reshape(scores.shape[0],1,scores.shape[2]*8,scores.shape[3]*8)
124
+
125
+ pos=self.detector(heatmap)
126
+ kpts=pos.squeeze(0)
127
+ mask_w=kpts[...,0]<(_W1-pad_info[-1])
128
+ kpts=kpts[mask_w]
129
+ mask_h=kpts[..., 1]<(_H1-pad_info[1])
130
+ kpts=kpts[mask_h]
131
+
132
+ descs=self.sampler(descs_map,kpts.unsqueeze(0),_H1,_W1)
133
+ descs=torch.nn.functional.normalize(descs,p=2,dim=1)
134
+ descs=descs.squeeze(0)
135
+
136
+ return {
137
+ 'descriptors':descs,
138
+ 'keypoints':kpts
139
+ }
140
+
141
+ def match_liftfeat(self, img1, pad_info1, img2, pad_info2, min_cossim=-1):
142
+ # import pdb;pdb.set_trace()
143
+ data1=self.extract(img1, pad_info1)
144
+ data2=self.extract(img2, pad_info2)
145
+
146
+ kpts1,feats1=data1['keypoints'],data1['descriptors']
147
+ kpts2,feats2=data2['keypoints'],data2['descriptors']
148
+
149
+ cossim = feats1 @ feats2.t()
150
+ cossim_t = feats2 @ feats1.t()
151
+
152
+ _, match12 = cossim.max(dim=1)
153
+ _, match21 = cossim_t.max(dim=1)
154
+
155
+ idx0 = torch.arange(len(match12), device=match12.device)
156
+ mutual = match21[match12] == idx0
157
+
158
+ if min_cossim > 0:
159
+ cossim, _ = cossim.max(dim=1)
160
+ good = cossim > min_cossim
161
+ idx0 = idx0[mutual & good]
162
+ idx1 = match12[mutual & good]
163
+ else:
164
+ idx0 = idx0[mutual]
165
+ idx1 = match12[mutual]
166
+
167
+ mkpts1,mkpts2=kpts1[idx0],kpts2[idx1]
168
+
169
+ return mkpts1, mkpts2
170
+
171
+ weight=os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.pth')
172
+
173
+ liftfeat=LiftFeat(weight)
174
+
175
+ save_file=os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.pt')
176
+
177
+ liftfeat_script=torch.jit.script(liftfeat)
178
+ liftfeat_script.save(save_file)
179
+
180
+ # checkpoint = {
181
+ # 'model_name': 'LiftFeat',
182
+ # 'model_args': {
183
+ # 'top_k': 4096,
184
+ # 'detect_threshold': 0.1
185
+ # },
186
+ # 'state_dict': liftfeat.state_dict()
187
+ # }
188
+
189
+ # torch.save(checkpoint,os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.ckpt'))
190
+
imcui/third_party/LiftFeat/models/liftfeat_wrapper.py ADDED
@@ -0,0 +1,173 @@
 
 
1
+ import os
2
+ import sys
3
+ import torch
4
+ import numpy as np
5
+ import math
6
+ import cv2
7
+
8
+ from models.model import LiftFeatSPModel
9
+ from models.interpolator import InterpolateSparse2d
10
+ from utils.config import featureboost_config
11
+
12
+ device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
13
+
14
+ MODEL_PATH=os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.pth')
15
+
16
+
17
+ class NonMaxSuppression(torch.nn.Module):
18
+ def __init__(self, rep_thr=0.1, top_k=4096):
19
+ super(NonMaxSuppression,self).__init__()
20
+ self.max_filter = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
21
+ self.rep_thr = rep_thr
22
+ self.top_k=top_k
23
+
24
+
25
+ def NMS(self, x, threshold = 0.05, kernel_size = 5):
26
+ B, _, H, W = x.shape
27
+ pad=kernel_size//2
28
+ local_max = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)(x)
29
+ pos = (x == local_max) & (x > threshold)
30
+ pos_batched = [k.nonzero()[..., 1:].flip(-1) for k in pos]
31
+
32
+ pad_val = max([len(x) for x in pos_batched])
33
+ pos = torch.zeros((B, pad_val, 2), dtype=torch.long, device=x.device)
34
+
35
+ #Pad kpts and build (B, N, 2) tensor
36
+ for b in range(len(pos_batched)):
37
+ pos[b, :len(pos_batched[b]), :] = pos_batched[b]
38
+
39
+ return pos
40
+
41
+ def forward(self, score):
42
+ pos = self.NMS(score,self.rep_thr)
43
+
44
+ return pos
45
+
46
+ def load_model(model, weight_path):
47
+ pretrained_weights = torch.load(weight_path, map_location="cpu")
48
+
49
+ model_keys = set(model.state_dict().keys())
50
+ pretrained_keys = set(pretrained_weights.keys())
51
+
52
+ missing_keys = model_keys - pretrained_keys
53
+ unexpected_keys = pretrained_keys - model_keys
54
+
55
+ # if missing_keys:
56
+ # print("Missing keys in pretrained weights:", missing_keys)
57
+ # else:
58
+ # print("No missing keys in pretrained weights.")
59
+
60
+ # if unexpected_keys:
61
+ # print("Unexpected keys in pretrained weights:", unexpected_keys)
62
+ # else:
63
+ # print("No unexpected keys in pretrained weights.")
64
+
65
+ if not missing_keys and not unexpected_keys:
66
+ model.load_state_dict(pretrained_weights)
67
+ print("load weight successfully.")
68
+ else:
69
+ model.load_state_dict(pretrained_weights, strict=False)
70
+ # print("There were issues with the keys.")
71
+ return model
72
+
73
+
74
+ import torch.nn as nn
75
+ class LiftFeat(nn.Module):
76
+ def __init__(self,weight=MODEL_PATH,top_k=4096,detect_threshold=0.1):
77
+ super().__init__()
78
+ self.device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
79
+ self.net=LiftFeatSPModel(featureboost_config).to(self.device).eval()
80
+ self.top_k=top_k
81
+ self.sampler=InterpolateSparse2d('bicubic')
82
+ self.net=load_model(self.net,weight)
83
+ self.detector=NonMaxSuppression(rep_thr=detect_threshold)
84
+ self.net=self.net.to(self.device)
85
+ self.detector=self.detector.to(self.device)
86
+ self.sampler=self.sampler.to(self.device)
87
+
88
+ def image_preprocess(self,image: np.ndarray):
89
+ H,W,C=image.shape[0],image.shape[1],image.shape[2]
90
+
91
+ _H=math.ceil(H/32)*32
92
+ _W=math.ceil(W/32)*32
93
+
94
+ pad_h=_H-H
95
+ pad_w=_W-W
96
+
97
+ image=cv2.copyMakeBorder(image,0,pad_h,0,pad_w,cv2.BORDER_CONSTANT,None,(0, 0, 0))
98
+
99
+ pad_info=[0,pad_h,0,pad_w]
100
+
101
+ if len(image.shape)==3:
102
+ image=image[None,...]
103
+
104
+ image=torch.tensor(image).permute(0,3,1,2)/255
105
+ image=image.to(device)
106
+
107
+ return image, pad_info
108
+
109
+ @torch.inference_mode()
110
+ def extract(self,image: np.ndarray):
111
+ image,pad_info=self.image_preprocess(image)
112
+ B,_,_H1,_W1=image.shape
113
+
114
+ M1,K1,D1=self.net.forward1(image)
115
+ refine_M=self.net.forward2(M1,K1,D1)
116
+
117
+ refine_M=refine_M.reshape(M1.shape[0],M1.shape[2],M1.shape[3],-1).permute(0,3,1,2)
118
+ refine_M=torch.nn.functional.normalize(refine_M,2,dim=1)
119
+
120
+ descs_map=refine_M
121
+ # descs_map=M1
122
+
123
+ scores=torch.softmax(K1,dim=1)[:,:64]
124
+ heatmap=scores.permute(0,2,3,1).reshape(scores.shape[0],scores.shape[2],scores.shape[3],8,8)
125
+ heatmap=heatmap.permute(0,1,3,2,4).reshape(scores.shape[0],1,scores.shape[2]*8,scores.shape[3]*8)
126
+
127
+ pos=self.detector(heatmap)
128
+ kpts=pos.squeeze(0)
129
+ mask_w=kpts[...,0]<(_W1-pad_info[-1])
130
+ kpts=kpts[mask_w]
131
+ mask_h=kpts[..., 1]<(_H1-pad_info[1])
132
+ kpts=kpts[mask_h]
133
+
134
+ descs=self.sampler(descs_map,kpts.unsqueeze(0),_H1,_W1)
135
+ descs=torch.nn.functional.normalize(descs,p=2,dim=1)
136
+ descs=descs.squeeze(0)
137
+
138
+ return {
139
+ 'descriptors':descs,
140
+ 'keypoints':kpts
141
+ }
142
+
143
+ def match_liftfeat(self, img1, img2, min_cossim=-1):
144
+ # import pdb;pdb.set_trace()
145
+ data1=self.extract(img1)
146
+ data2=self.extract(img2)
147
+
148
+ kpts1,feats1=data1['keypoints'],data1['descriptors']
149
+ kpts2,feats2=data2['keypoints'],data2['descriptors']
150
+
151
+ cossim = feats1 @ feats2.t()
152
+ cossim_t = feats2 @ feats1.t()
153
+
154
+ _, match12 = cossim.max(dim=1)
155
+ _, match21 = cossim_t.max(dim=1)
156
+
157
+ idx0 = torch.arange(len(match12), device=match12.device)
158
+ mutual = match21[match12] == idx0
159
+
160
+ if min_cossim > 0:
161
+ cossim, _ = cossim.max(dim=1)
162
+ good = cossim > min_cossim
163
+ idx0 = idx0[mutual & good]
164
+ idx1 = match12[mutual & good]
165
+ else:
166
+ idx0 = idx0[mutual]
167
+ idx1 = match12[mutual]
168
+
169
+ mkpts1,mkpts2=kpts1[idx0],kpts2[idx1]
170
+ mkpts1,mkpts2=mkpts1.cpu().numpy(),mkpts2.cpu().numpy()
171
+
172
+ return mkpts1, mkpts2
173
+
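
When only keypoints and descriptors are needed (e.g. to plug LiftFeat into an external matcher), extract() can be called on its own; the image is a BGR numpy array and the paths below are placeholders:

import cv2
from models.liftfeat_wrapper import LiftFeat, MODEL_PATH

extractor = LiftFeat(weight=MODEL_PATH)   # default top_k=4096, detect_threshold=0.1
image = cv2.imread('./assert/ref.jpg')

out = extractor.extract(image)
print(out['keypoints'].shape)     # (N, 2) keypoint pixel coordinates
print(out['descriptors'].shape)   # (N, 64) descriptors sampled from the refined map
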
imcui/third_party/LiftFeat/models/model.py ADDED
@@ -0,0 +1,419 @@
 
 
1
+
2
+ """
3
+ "LiftFeat: 3D Geometry-Aware Local Feature Matching"
4
+ """
5
+
6
+ import numpy as np
7
+ import os
8
+ import torch
9
+ from torch import nn
10
+ import torch.nn.functional as F
11
+
12
+ import tqdm
13
+ import math
14
+ import cv2
15
+
16
+ import sys
17
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
18
+ from utils.featurebooster import FeatureBooster
19
+ from utils.config import featureboost_config
20
+
21
+ # from models.model_dfb import LiftFeatModel
22
+ # from models.interpolator import InterpolateSparse2d
23
+ # from third_party.config import featureboost_config
24
+
25
+ """
26
+ foundational functions
27
+ """
28
+ def simple_nms(scores, radius):
29
+ """Perform non maximum suppression on the heatmap using max-pooling.
30
+ This method does not suppress contiguous points that have the same score.
31
+ Args:
32
+ scores: the score heatmap of size `(B, H, W)`.
33
+ radius: an integer scalar, the radius of the NMS window.
34
+ """
35
+
36
+ def max_pool(x):
37
+ return torch.nn.functional.max_pool2d(
38
+ x, kernel_size=radius * 2 + 1, stride=1, padding=radius
39
+ )
40
+
41
+ zeros = torch.zeros_like(scores)
42
+ max_mask = scores == max_pool(scores)
43
+ for _ in range(2):
44
+ supp_mask = max_pool(max_mask.float()) > 0
45
+ supp_scores = torch.where(supp_mask, zeros, scores)
46
+ new_max_mask = supp_scores == max_pool(supp_scores)
47
+ max_mask = max_mask | (new_max_mask & (~supp_mask))
48
+ return torch.where(max_mask, scores, zeros)
49
+
50
+
51
+ def top_k_keypoints(keypoints, scores, k):
52
+ if k >= len(keypoints):
53
+ return keypoints, scores
54
+ scores, indices = torch.topk(scores, k, dim=0, sorted=True)
55
+ return keypoints[indices], scores
56
+
57
+
58
+ def sample_k_keypoints(keypoints, scores, k):
59
+ if k >= len(keypoints):
60
+ return keypoints, scores
61
+ indices = torch.multinomial(scores, k, replacement=False)
62
+ return keypoints[indices], scores[indices]
63
+
64
+
65
+ def soft_argmax_refinement(keypoints, scores, radius: int):
66
+ width = 2 * radius + 1
67
+ sum_ = torch.nn.functional.avg_pool2d(
68
+ scores[:, None], width, 1, radius, divisor_override=1
69
+ )
70
+ ar = torch.arange(-radius, radius + 1).to(scores)
71
+ kernel_x = ar[None].expand(width, -1)[None, None]
72
+ dx = torch.nn.functional.conv2d(scores[:, None], kernel_x, padding=radius)
73
+ dy = torch.nn.functional.conv2d(
74
+ scores[:, None], kernel_x.transpose(2, 3), padding=radius
75
+ )
76
+ dydx = torch.stack([dy[:, 0], dx[:, 0]], -1) / sum_[:, 0, :, :, None]
77
+ refined_keypoints = []
78
+ for i, kpts in enumerate(keypoints):
79
+ delta = dydx[i][tuple(kpts.t())]
80
+ refined_keypoints.append(kpts.float() + delta)
81
+ return refined_keypoints
82
+
83
+
84
+ # Legacy (broken) sampling of the descriptors
85
+ def sample_descriptors(keypoints, descriptors, s):
86
+ b, c, h, w = descriptors.shape
87
+ keypoints = keypoints - s / 2 + 0.5
88
+ keypoints /= torch.tensor(
89
+ [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
90
+ ).to(
91
+ keypoints
92
+ )[None]
93
+ keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
94
+ args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
95
+ descriptors = torch.nn.functional.grid_sample(
96
+ descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
97
+ )
98
+ descriptors = torch.nn.functional.normalize(
99
+ descriptors.reshape(b, c, -1), p=2, dim=1
100
+ )
101
+ return descriptors
102
+
103
+
104
+ # The original keypoint sampling is incorrect. We patch it here but
105
+ # keep the original one above for legacy.
106
+ def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8):
107
+ """Interpolate descriptors at keypoint locations"""
108
+ b, c, h, w = descriptors.shape
109
+ keypoints = keypoints / (keypoints.new_tensor([w, h]) * s)
110
+ keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
111
+ descriptors = torch.nn.functional.grid_sample(
112
+ descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", align_corners=False
113
+ )
114
+ descriptors = torch.nn.functional.normalize(
115
+ descriptors.reshape(b, c, -1), p=2, dim=1
116
+ )
117
+ return descriptors
118
+
119
+
120
+ class UpsampleLayer(nn.Module):
121
+ def __init__(self, in_channels):
122
+ super().__init__()
123
+ # Feature-extraction conv: halves the channel count while strengthening feature extraction
124
+ self.conv = nn.Conv2d(in_channels, in_channels//2, kernel_size=3, stride=1, padding=1)
125
+ # BatchNorm layer
126
+ self.bn = nn.BatchNorm2d(in_channels//2)
127
+ # LeakyReLU activation
128
+ self.leaky_relu = nn.LeakyReLU(0.1)
129
+
130
+ def forward(self, x):
131
+ x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
132
+ x = self.leaky_relu(self.bn(self.conv(x)))
133
+
134
+ return x
135
+
136
+
137
+ class KeypointHead(nn.Module):
138
+ def __init__(self,in_channels,out_channels):
139
+ super().__init__()
140
+ self.layer1=BaseLayer(in_channels,32)
141
+ self.layer2=BaseLayer(32,32)
142
+ self.layer3=BaseLayer(32,64)
143
+ self.layer4=BaseLayer(64,64)
144
+ self.layer5=BaseLayer(64,128)
145
+
146
+ self.conv=nn.Conv2d(128,out_channels,kernel_size=3,stride=1,padding=1)
147
+ self.bn=nn.BatchNorm2d(65)
148
+
149
+ def forward(self,x):
150
+ x=self.layer1(x)
151
+ x=self.layer2(x)
152
+ x=self.layer3(x)
153
+ x=self.layer4(x)
154
+ x=self.layer5(x)
155
+ x=self.bn(self.conv(x))
156
+ return x
157
+
158
+
159
+ class DescriptorHead(nn.Module):
160
+ def __init__(self,in_channels,out_channels):
161
+ super().__init__()
162
+ self.layer=nn.Sequential(
163
+ BaseLayer(in_channels,32),
164
+ BaseLayer(32,32,activation=False),
165
+ BaseLayer(32,64,activation=False),
166
+ BaseLayer(64,out_channels,activation=False)
167
+ )
168
+
169
+ def forward(self,x):
170
+ x=self.layer(x)
171
+ # x=nn.functional.softmax(x,dim=1)
172
+ return x
173
+
174
+
175
+ class HeatmapHead(nn.Module):
176
+ def __init__(self,in_channels,mid_channels,out_channels):
177
+ super().__init__()
178
+ self.convHa = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
179
+ self.bnHa = nn.BatchNorm2d(mid_channels)
180
+ self.convHb = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
181
+ self.bnHb = nn.BatchNorm2d(out_channels)
182
+ self.leaky_relu = nn.LeakyReLU(0.1)
183
+
184
+ def forward(self,x):
185
+ x = self.leaky_relu(self.bnHa(self.convHa(x)))
186
+ x = self.leaky_relu(self.bnHb(self.convHb(x)))
187
+
188
+ x = torch.sigmoid(x)
189
+ return x
190
+
191
+
192
+ class DepthHead(nn.Module):
193
+ def __init__(self, in_channels):
194
+ super().__init__()
195
+ self.upsampleDa = UpsampleLayer(in_channels)
196
+ self.upsampleDb = UpsampleLayer(in_channels//2)
197
+ self.upsampleDc = UpsampleLayer(in_channels//4)
198
+
199
+ self.convDepa = nn.Conv2d(in_channels//2+in_channels, in_channels//2, kernel_size=3, stride=1, padding=1)
200
+ self.bnDepa = nn.BatchNorm2d(in_channels//2)
201
+ self.convDepb = nn.Conv2d(in_channels//4+in_channels//2, in_channels//4, kernel_size=3, stride=1, padding=1)
202
+ self.bnDepb = nn.BatchNorm2d(in_channels//4)
203
+ self.convDepc = nn.Conv2d(in_channels//8+in_channels//4, 3, kernel_size=3, stride=1, padding=1)
204
+ self.bnDepc = nn.BatchNorm2d(3)
205
+
206
+ self.leaky_relu = nn.LeakyReLU(0.1)
207
+
208
+ def forward(self, x):
209
+ x0 = F.interpolate(x, scale_factor=2,mode='bilinear',align_corners=False)
210
+ x1 = self.upsampleDa(x)
211
+ x1 = torch.cat([x0,x1],dim=1)
212
+ x1 = self.leaky_relu(self.bnDepa(self.convDepa(x1)))
213
+
214
+ x1_0 = F.interpolate(x1,scale_factor=2,mode='bilinear',align_corners=False)
215
+ x2 = self.upsampleDb(x1)
216
+ x2 = torch.cat([x1_0,x2],dim=1)
217
+ x2 = self.leaky_relu(self.bnDepb(self.convDepb(x2)))
218
+
219
+ x2_0 = F.interpolate(x2,scale_factor=2,mode='bilinear',align_corners=False)
220
+ x3 = self.upsampleDc(x2)
221
+ x3 = torch.cat([x2_0,x3],dim=1)
222
+ x = self.leaky_relu(self.bnDepc(self.convDepc(x3)))
223
+
224
+ x = F.normalize(x,p=2,dim=1)
225
+ return x
226
+
227
+
228
+ class BaseLayer(nn.Module):
229
+ def __init__(self,in_channels,out_channels,kernel_size=3,stride=1,padding=1,bias=False,activation=True):
230
+ super().__init__()
231
+ if activation:
232
+ self.layer=nn.Sequential(
233
+ nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=bias),
234
+ nn.BatchNorm2d(out_channels,affine=False),
235
+ nn.ReLU(inplace=True)
236
+ )
237
+ else:
238
+ self.layer=nn.Sequential(
239
+ nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=bias),
240
+ nn.BatchNorm2d(out_channels,affine=False)
241
+ )
242
+
243
+ def forward(self,x):
244
+ return self.layer(x)
245
+
246
+
247
+ class LiftFeatSPModel(nn.Module):
248
+ default_conf = {
249
+ "has_detector": True,
250
+ "has_descriptor": True,
251
+ "descriptor_dim": 64,
252
+ # Inference
253
+ "sparse_outputs": True,
254
+ "dense_outputs": False,
255
+ "nms_radius": 4,
256
+ "refinement_radius": 0,
257
+ "detection_threshold": 0.005,
258
+ "max_num_keypoints": -1,
259
+ "max_num_keypoints_val": None,
260
+ "force_num_keypoints": False,
261
+ "randomize_keypoints_training": False,
262
+ "remove_borders": 4,
263
+ "legacy_sampling": True, # True to use the old broken sampling
264
+ }
265
+
266
+ def __init__(self, featureboost_config, use_kenc=False, use_normal=True, use_cross=True):
267
+ super().__init__()
268
+ self.device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
269
+ self.descriptor_dim = 64
270
+
271
+ self.norm = nn.InstanceNorm2d(1)
272
+
273
+ self.relu = nn.ReLU(inplace=True)
274
+ self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
275
+ c1,c2,c3,c4,c5 = 24,24,64,64,128
276
+
277
+ self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
278
+ self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
279
+ self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
280
+ self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
281
+ self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
282
+ self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
283
+ self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
284
+ self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
285
+ self.conv5a = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
286
+ self.conv5b = nn.Conv2d(c5, c5, kernel_size=3, stride=1, padding=1)
287
+
288
+ self.upsample4 = UpsampleLayer(c4)
289
+ self.upsample5 = UpsampleLayer(c5)
290
+ self.conv_fusion45 = nn.Conv2d(c5//2+c4,c4,kernel_size=3,stride=1,padding=1)
291
+ self.conv_fusion34 = nn.Conv2d(c4//2+c3,c3,kernel_size=3,stride=1,padding=1)
292
+
293
+ # detector
294
+ self.keypoint_head = KeypointHead(in_channels=c3,out_channels=65)
295
+ # descriptor
296
+ self.descriptor_head = DescriptorHead(in_channels=c3,out_channels=self.descriptor_dim)
297
+ # # heatmap
298
+ # self.heatmap_head = HeatmapHead(in_channels=c3,mid_channels=c3,out_channels=1)
299
+ # depth
300
+ self.depth_head = DepthHead(c3)
301
+
302
+ self.fine_matcher = nn.Sequential(
303
+ nn.Linear(128, 512),
304
+ nn.BatchNorm1d(512, affine=False),
305
+ nn.ReLU(inplace = True),
306
+ nn.Linear(512, 512),
307
+ nn.BatchNorm1d(512, affine=False),
308
+ nn.ReLU(inplace = True),
309
+ nn.Linear(512, 512),
310
+ nn.BatchNorm1d(512, affine=False),
311
+ nn.ReLU(inplace = True),
312
+ nn.Linear(512, 512),
313
+ nn.BatchNorm1d(512, affine=False),
314
+ nn.ReLU(inplace = True),
315
+ nn.Linear(512, 64),
316
+ )
317
+
318
+ # feature_booster
319
+ self.feature_boost = FeatureBooster(featureboost_config, use_kenc=use_kenc, use_cross=use_cross, use_normal=use_normal)
320
+
321
+ def feature_extract(self, x):
322
+ x1 = self.relu(self.conv1a(x))
323
+ x1 = self.relu(self.conv1b(x1))
324
+ x1 = self.pool(x1)
325
+ x2 = self.relu(self.conv2a(x1))
326
+ x2 = self.relu(self.conv2b(x2))
327
+ x2 = self.pool(x2)
328
+ x3 = self.relu(self.conv3a(x2))
329
+ x3 = self.relu(self.conv3b(x3))
330
+ x3 = self.pool(x3)
331
+ x4 = self.relu(self.conv4a(x3))
332
+ x4 = self.relu(self.conv4b(x4))
333
+ x4 = self.pool(x4)
334
+ x5 = self.relu(self.conv5a(x4))
335
+ x5 = self.relu(self.conv5b(x5))
336
+ x5 = self.pool(x5)
337
+ return x3,x4,x5
338
+
339
+ def fuse_multi_features(self,x3,x4,x5):
340
+ # upsample x5 feature
341
+ x5 = self.upsample5(x5)
342
+ x4 = torch.cat([x4,x5],dim=1)
343
+ x4 = self.conv_fusion45(x4)
344
+
345
+ # upsample x4 feature
346
+ x4 = self.upsample4(x4)
347
+ x3 = torch.cat([x3,x4],dim=1)
348
+ x = self.conv_fusion34(x3)
349
+ return x
350
+
351
+ def _unfold2d(self, x, ws = 2):
352
+ """
353
+ Unfolds tensor in 2D with desired ws (window size) and concat the channels
354
+ """
355
+ B, C, H, W = x.shape
356
+ x = x.unfold(2, ws , ws).unfold(3, ws,ws).reshape(B, C, H//ws, W//ws, ws**2)
357
+ return x.permute(0, 1, 4, 2, 3).reshape(B, -1, H//ws, W//ws)
358
+
359
+
360
+ def forward1(self, x):
361
+ """
362
+ input:
363
+ x -> torch.Tensor(B, C, H, W) grayscale or rgb images
364
+ return:
365
+ feats -> torch.Tensor(B, 64, H/8, W/8) dense local features
366
+ keypoints -> torch.Tensor(B, 65, H/8, W/8) keypoint logit map
367
+ heatmap -> torch.Tensor(B, 1, H/8, W/8) reliability map
368
+
369
+ """
370
+ with torch.no_grad():
371
+ x = x.mean(dim=1, keepdim = True)
372
+ x = self.norm(x)
373
+
374
+ x3,x4,x5 = self.feature_extract(x)
375
+
376
+ # features fusion
377
+ x = self.fuse_multi_features(x3,x4,x5)
378
+
379
+ # keypoint
380
+ keypoint_map = self.keypoint_head(x)
381
+ # descriptor
382
+ des_map = self.descriptor_head(x)
383
+ # # heatmap
384
+ # heatmap = self.heatmap_head(x)
385
+
386
+ # import pdb;pdb.set_trace()
387
+ # depth
388
+ d_feats = self.depth_head(x)
389
+
390
+ return des_map, keypoint_map, d_feats
391
+ # return des_map, keypoint_map, heatmap, d_feats
392
+
393
+ def forward2(self, descs, kpts, normals):
394
+ # import pdb;pdb.set_trace()
395
+ normals_feat=self._unfold2d(normals, ws=8)
396
+ normals_v=normals_feat.squeeze(0).permute(1,2,0).reshape(-1,normals_feat.shape[1])
397
+ descs_v=descs.squeeze(0).permute(1,2,0).reshape(-1,descs.shape[1])
398
+ kpts_v=kpts.squeeze(0).permute(1,2,0).reshape(-1,kpts.shape[1])
399
+ descs_refine = self.feature_boost(descs_v, kpts_v, normals_v)
400
+ return descs_refine
401
+
402
+ def forward(self,x):
403
+ M1,K1,D1=self.forward1(x)
404
+ descs_refine=self.forward2(M1,K1,D1)
405
+ return descs_refine,M1,K1,D1
406
+
407
+
408
+ if __name__ == "__main__":
409
+ img_path=os.path.join(os.path.dirname(__file__),'../assert/ref.jpg')
410
+ img=cv2.imread(img_path,cv2.IMREAD_GRAYSCALE)
411
+ img=cv2.resize(img,(800,608))
412
+ import pdb;pdb.set_trace()
413
+ img=torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()/255.0
414
+ img=img.cuda() if torch.cuda.is_available() else img
415
+ liftfeat_sp=LiftFeatSPModel(featureboost_config).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
416
+ des_map, keypoint_map, d_feats=liftfeat_sp.forward1(img)
417
+ des_fine=liftfeat_sp.forward2(des_map,keypoint_map,d_feats)
418
+ print(des_map.shape)
419
+
imcui/third_party/LiftFeat/requirements.txt ADDED
@@ -0,0 +1,18 @@
 
 
1
+ torch==1.13.1
2
+ torchvision==0.14.1
3
+ einops==0.8.0
4
+ kornia==0.7.3
5
+ timm==1.0.15
6
+ albumentations==1.4.12
7
+ imgaug==0.4.0
8
+ opencv-python==4.10.0.84
9
+ matplotlib==3.7.5
10
+ numpy==1.24.4
11
+ scikit-image==0.21.0
12
+ scipy==1.10.1
13
+ pillow==10.3.0
14
+ tensorboard==2.14.0
15
+ tqdm==4.66.4
16
+ omegaconf==2.3.0
17
+ thop==0.1.1.post2209072238
18
+ poselib
imcui/third_party/LiftFeat/train.py ADDED
@@ -0,0 +1,365 @@
 
 
1
+ """
2
+ "LiftFeat: 3D Geometry-Aware Local Feature Matching"
3
+ training script
4
+ """
5
+
6
+ import argparse
7
+ import os
8
+ import time
9
+ import sys
10
+ sys.path.append(os.path.dirname(__file__))
11
+
12
+ def parse_arguments():
13
+ parser = argparse.ArgumentParser(description="LiftFeat training script.")
14
+ parser.add_argument('--name',type=str,default='LiftFeat',help='set process name')
15
+
16
+ # MegaDepth dataset setting
17
+ parser.add_argument('--use_megadepth',action='store_true')
18
+ parser.add_argument('--megadepth_root_path', type=str,
19
+ default='/home/yepeng_liu/code_python/dataset/MegaDepth/phoenix/S6/zl548',
20
+ help='Path to the MegaDepth dataset root directory.')
21
+ parser.add_argument('--megadepth_batch_size', type=int, default=6)
22
+
23
+ # COCO20k dataset setting
24
+ parser.add_argument('--use_coco',action='store_true')
25
+ parser.add_argument('--coco_root_path', type=str, default='/home/yepeng_liu/code_python/dataset/coco_20k',
26
+ help='Path to the COCO20k dataset root directory.')
27
+ parser.add_argument('--coco_batch_size',type=int,default=4)
28
+
29
+ parser.add_argument('--ckpt_save_path', type=str, default='/home/yepeng_liu/code_python/LiftFeat/trained_weights/test',
30
+ help='Path to save the checkpoints.')
31
+ parser.add_argument('--n_steps', type=int, default=160_000,
32
+ help='Number of training steps. Default is 160000.')
33
+ parser.add_argument('--lr', type=float, default=3e-4,
34
+ help='Learning rate. Default is 0.0003.')
35
+ parser.add_argument('--gamma_steplr', type=float, default=0.5,
36
+ help='Gamma value for StepLR scheduler. Default is 0.5.')
37
+ parser.add_argument('--training_res', type=lambda s: tuple(map(int, s.split(','))),
38
+ default=(800, 608), help='Training resolution as width,height. Default is (800, 608).')
39
+ parser.add_argument('--device_num', type=str, default='0',
40
+ help='Device number to use for training. Default is "0".')
41
+ parser.add_argument('--dry_run', action='store_true',
42
+ help='If set, perform a dry run training with a mini-batch for sanity check.')
43
+ parser.add_argument('--save_ckpt_every', type=int, default=500,
44
+ help='Save checkpoints every N steps. Default is 500.')
45
+
46
+ args = parser.parse_args()
47
+
48
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.device_num
49
+
50
+ return args
51
+
52
+ args = parse_arguments()
53
+
54
+ import torch
55
+ from torch import nn
56
+ from torch import optim
57
+ import torch.nn.functional as F
58
+ from torch.utils.tensorboard import SummaryWriter
59
+ from torch.utils.data import Dataset, DataLoader
60
+
61
+ import numpy as np
62
+ import tqdm
63
+ import glob
64
+
65
+ from models.model import LiftFeatSPModel
66
+ from loss.loss import LiftFeatLoss
67
+ from utils.config import featureboost_config
68
+ from models.interpolator import InterpolateSparse2d
69
+ from utils.depth_anything_wrapper import DepthAnythingExtractor
70
+ from utils.alike_wrapper import ALikeExtractor
71
+
72
+ from dataset import megadepth_wrapper
73
+ from dataset import coco_wrapper
74
+ from dataset.megadepth import MegaDepthDataset
75
+ from dataset.coco_augmentor import COCOAugmentor
76
+
77
+ import setproctitle
78
+
79
+
80
+ class Trainer():
81
+ def __init__(self, megadepth_root_path,use_megadepth,megadepth_batch_size,
82
+ coco_root_path,use_coco,coco_batch_size,
83
+ ckpt_save_path,
84
+ model_name = 'LiftFeat',
85
+ n_steps = 160_000, lr= 3e-4, gamma_steplr=0.5,
86
+ training_res = (800, 608), device_num="0", dry_run = False,
87
+ save_ckpt_every = 500):
88
+ print(f'MegaDepth: {use_megadepth}-{megadepth_batch_size}')
89
+ print(f'COCO20k: {use_coco}-{coco_batch_size}')
90
+ self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
91
+
92
+ # training model
93
+ self.net = LiftFeatSPModel(featureboost_config, use_kenc=False, use_normal=True, use_cross=True).to(self.dev)
94
+ self.loss_fn=LiftFeatLoss(self.dev,lam_descs=1,lam_kpts=2,lam_heatmap=1)
95
+
96
+ # depth-anything model
97
+ self.depth_net=DepthAnythingExtractor('vits',self.dev,256)
98
+
99
+ # alike model
100
+ self.alike_net=ALikeExtractor('alike-t',self.dev)
101
+
102
+ #Setup optimizer
103
+ self.steps = n_steps
104
+ self.opt = optim.Adam(filter(lambda x: x.requires_grad, self.net.parameters()) , lr = lr)
105
+ self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=10_000, gamma=gamma_steplr)
106
+
107
+ ##################### COCO INIT ##########################
108
+ self.use_coco=use_coco
109
+ self.coco_batch_size=coco_batch_size
110
+ if self.use_coco:
111
+ self.augmentor=COCOAugmentor(
112
+ img_dir=coco_root_path,
113
+ device=self.dev,load_dataset=True,
114
+ batch_size=self.coco_batch_size,
115
+ out_resolution=training_res,
116
+ warp_resolution=training_res,
117
+ sides_crop=0.1,
118
+ max_num_imgs=3000,
119
+ num_test_imgs=5,
120
+ photometric=True,
121
+ geometric=True,
122
+ reload_step=4000
123
+ )
124
+ ##################### COCO END #######################
125
+
126
+
127
+ ##################### MEGADEPTH INIT ##########################
128
+ self.use_megadepth=use_megadepth
129
+ self.megadepth_batch_size=megadepth_batch_size
130
+ if self.use_megadepth:
131
+ TRAIN_BASE_PATH = f"{megadepth_root_path}/train_data/megadepth_indices"
132
+ TRAINVAL_DATA_SOURCE = f"{megadepth_root_path}/MegaDepth_v1"
133
+
134
+ TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"
135
+
136
+ npz_paths = glob.glob(TRAIN_NPZ_ROOT + '/*.npz')[:]
137
+ megadepth_dataset = torch.utils.data.ConcatDataset( [MegaDepthDataset(root_dir = TRAINVAL_DATA_SOURCE,
138
+ npz_path = path) for path in tqdm.tqdm(npz_paths, desc="[MegaDepth] Loading metadata")] )
139
+
140
+ self.megadepth_dataloader = DataLoader(megadepth_dataset, batch_size=megadepth_batch_size, shuffle=True)
141
+ self.megadepth_data_iter = iter(self.megadepth_dataloader)
142
+ ##################### MEGADEPTH INIT END #######################
143
+
144
+ os.makedirs(ckpt_save_path, exist_ok=True)
145
+ os.makedirs(ckpt_save_path + '/logdir', exist_ok=True)
146
+
147
+ self.dry_run = dry_run
148
+ self.save_ckpt_every = save_ckpt_every
149
+ self.ckpt_save_path = ckpt_save_path
150
+ self.writer = SummaryWriter(ckpt_save_path + f'/logdir/{model_name}_' + time.strftime("%Y_%m_%d-%H_%M_%S"))
151
+ self.model_name = model_name
152
+
153
+
154
+ def generate_train_data(self):
155
+ imgs1_t,imgs2_t=[],[]
156
+ imgs1_np,imgs2_np=[],[]
157
+ # norms0,norms1=[],[]
158
+ positives_coarse=[]
159
+
160
+ if self.use_coco:
161
+ coco_imgs1, coco_imgs2, H1, H2 = coco_wrapper.make_batch(self.augmentor, 0.1)
162
+ h_coarse, w_coarse = coco_imgs1[0].shape[-2] // 8, coco_imgs1[0].shape[-1] // 8
163
+ _ , positives_coco_coarse = coco_wrapper.get_corresponding_pts(coco_imgs1, coco_imgs2, H1, H2, self.augmentor, h_coarse, w_coarse)
164
+ coco_imgs1=coco_imgs1.mean(1,keepdim=True);coco_imgs2=coco_imgs2.mean(1,keepdim=True)
165
+ imgs1_t.append(coco_imgs1);imgs2_t.append(coco_imgs2)
166
+ positives_coarse += positives_coco_coarse
167
+
168
+ if self.use_megadepth:
169
+ try:
170
+ megadepth_data=next(self.megadepth_data_iter)
171
+ except StopIteration:
172
+ print('End of MD DATASET')
173
+ self.megadepth_data_iter=iter(self.megadepth_dataloader)
174
+ megadepth_data=next(self.megadepth_data_iter)
175
+ if megadepth_data is not None:
176
+ for k in megadepth_data.keys():
177
+ if isinstance(megadepth_data[k],torch.Tensor):
178
+ megadepth_data[k]=megadepth_data[k].to(self.dev)
179
+ megadepth_imgs1_t,megadepth_imgs2_t=megadepth_data['image0'],megadepth_data['image1']
180
+ megadepth_imgs1_t=megadepth_imgs1_t.mean(1,keepdim=True);megadepth_imgs2_t=megadepth_imgs2_t.mean(1,keepdim=True)
181
+ imgs1_t.append(megadepth_imgs1_t);imgs2_t.append(megadepth_imgs2_t)
182
+ megadepth_imgs1_np,megadepth_imgs2_np=megadepth_data['image0_np'],megadepth_data['image1_np']
183
+ for np_idx in range(megadepth_imgs1_np.shape[0]):
184
+ img1_np,img2_np=megadepth_imgs1_np[np_idx].squeeze(0).cpu().numpy(),megadepth_imgs2_np[np_idx].squeeze(0).cpu().numpy()
185
+ imgs1_np.append(img1_np);imgs2_np.append(img2_np)
186
+ positives_megadepth_coarse=megadepth_wrapper.spvs_coarse(megadepth_data,8)
187
+ positives_coarse += positives_megadepth_coarse
188
+
189
+ with torch.no_grad():
190
+ imgs1_t=torch.cat(imgs1_t,dim=0)
191
+ imgs2_t=torch.cat(imgs2_t,dim=0)
192
+
193
+ return imgs1_t,imgs2_t,imgs1_np,imgs2_np,positives_coarse
194
+
195
+
196
+ def train(self):
197
+ self.net.train()
198
+
199
+ with tqdm.tqdm(total=self.steps) as pbar:
200
+ for i in range(self.steps):
201
+ # import pdb;pdb.set_trace()
202
+ imgs1_t,imgs2_t,imgs1_np,imgs2_np,positives_coarse=self.generate_train_data()
203
+
204
+ #Check if batch is corrupted with too few correspondences
205
+ is_corrupted = False
206
+ for p in positives_coarse:
207
+ if len(p) < 30:
208
+ is_corrupted = True
209
+
210
+ if is_corrupted:
211
+ continue
212
+
213
+ # import pdb;pdb.set_trace()
214
+ #Forward pass
215
+ # start=time.perf_counter()
216
+ feats1,kpts1,normals1 = self.net.forward1(imgs1_t)
217
+ feats2,kpts2,normals2 = self.net.forward1(imgs2_t)
218
+
219
+ coordinates,fb_coordinates=[],[]
220
+ alike_kpts1,alike_kpts2=[],[]
221
+ DA_normals1,DA_normals2=[],[]
222
+
223
+ # import pdb;pdb.set_trace()
224
+
225
+ fb_feats1,fb_feats2=[],[]
226
+ for b in range(feats1.shape[0]):
227
+ feat1=feats1[b].permute(1,2,0).reshape(-1,feats1.shape[1])
228
+ feat2=feats2[b].permute(1,2,0).reshape(-1,feats2.shape[1])
229
+
230
+ coordinate=self.net.fine_matcher(torch.cat([feat1,feat2],dim=-1))
231
+ coordinates.append(coordinate)
232
+
233
+ fb_feat1=self.net.forward2(feats1[b].unsqueeze(0),kpts1[b].unsqueeze(0),normals1[b].unsqueeze(0))
234
+ fb_feat2=self.net.forward2(feats2[b].unsqueeze(0),kpts2[b].unsqueeze(0),normals2[b].unsqueeze(0))
235
+
236
+ fb_coordinate=self.net.fine_matcher(torch.cat([fb_feat1,fb_feat2],dim=-1))
237
+ fb_coordinates.append(fb_coordinate)
238
+
239
+ fb_feats1.append(fb_feat1.unsqueeze(0));fb_feats2.append(fb_feat2.unsqueeze(0))
240
+
241
+ img1,img2=imgs1_t[b],imgs2_t[b]
242
+ img1=img1.permute(1,2,0).expand(-1,-1,3).cpu().numpy() * 255
243
+ img2=img2.permute(1,2,0).expand(-1,-1,3).cpu().numpy() * 255
244
+ alike_kpt1=torch.tensor(self.alike_net.extract_alike_kpts(img1),device=self.dev)
245
+ alike_kpt2=torch.tensor(self.alike_net.extract_alike_kpts(img2),device=self.dev)
246
+ alike_kpts1.append(alike_kpt1);alike_kpts2.append(alike_kpt2)
247
+
248
+ # import pdb;pdb.set_trace()
249
+ for b in range(len(imgs1_np)):
250
+ megadepth_depth1,megadepth_norm1=self.depth_net.extract(imgs1_np[b])
251
+ megadepth_depth2,megadepth_norm2=self.depth_net.extract(imgs2_np[b])
252
+ DA_normals1.append(megadepth_norm1);DA_normals2.append(megadepth_norm2)
253
+
254
+ # import pdb;pdb.set_trace()
255
+ fb_feats1=torch.cat(fb_feats1,dim=0)
256
+ fb_feats2=torch.cat(fb_feats2,dim=0)
257
+ fb_feats1=fb_feats1.reshape(feats1.shape[0],feats1.shape[2],feats1.shape[3],-1).permute(0,3,1,2)
258
+ fb_feats2=fb_feats2.reshape(feats2.shape[0],feats2.shape[2],feats2.shape[3],-1).permute(0,3,1,2)
259
+
260
+ coordinates=torch.cat(coordinates,dim=0)
261
+ coordinates=coordinates.reshape(feats1.shape[0],feats1.shape[2],feats1.shape[3],-1).permute(0,3,1,2)
262
+
263
+ fb_coordinates=torch.cat(fb_coordinates,dim=0)
264
+ fb_coordinates=fb_coordinates.reshape(feats1.shape[0],feats1.shape[2],feats1.shape[3],-1).permute(0,3,1,2)
265
+
266
+ # end=time.perf_counter()
267
+ # print(f"forward1 cost {end-start} seconds")
268
+
269
+ loss_items = []
270
+
271
+ # import pdb;pdb.set_trace()
272
+ loss_info=self.loss_fn(
273
+ feats1,fb_feats1,kpts1,normals1,
274
+ feats2,fb_feats2,kpts2,normals2,
275
+ positives_coarse,
276
+ coordinates,fb_coordinates,
277
+ alike_kpts1,alike_kpts2,
278
+ DA_normals1,DA_normals2,
279
+ self.megadepth_batch_size,self.coco_batch_size)
280
+
281
+ loss_descs,acc_coarse=loss_info['loss_descs'],loss_info['acc_coarse']
282
+ loss_coordinates,acc_coordinates=loss_info['loss_coordinates'],loss_info['acc_coordinates']
283
+ loss_fb_descs,acc_fb_coarse=loss_info['loss_fb_descs'],loss_info['acc_fb_coarse']
284
+ loss_fb_coordinates,acc_fb_coordinates=loss_info['loss_fb_coordinates'],loss_info['acc_fb_coordinates']
285
+ loss_kpts,acc_kpt=loss_info['loss_kpts'],loss_info['acc_kpt']
286
+ loss_normals=loss_info['loss_normals']
287
+
288
+ # loss_items.append(loss_descs.unsqueeze(0))
289
+ # loss_items.append(loss_coordinates.unsqueeze(0))
290
+ loss_items.append(loss_fb_descs.unsqueeze(0))
291
+ loss_items.append(loss_fb_coordinates.unsqueeze(0))
292
+ loss_items.append(loss_kpts.unsqueeze(0))
293
+ loss_items.append(loss_normals.unsqueeze(0))
294
+
295
+ # nb_coarse = len(m1)
296
+ # nb_coarse = len(fb_m1)
297
+ loss = torch.cat(loss_items, -1).mean()
298
+
299
+ # Compute Backward Pass
300
+ loss.backward()
301
+ torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.)
302
+ self.opt.step()
303
+ self.opt.zero_grad()
304
+ self.scheduler.step()
305
+
306
+ # import pdb;pdb.set_trace()
307
+ if (i+1) % self.save_ckpt_every == 0:
308
+ print('saving iter ', i+1)
309
+ torch.save(self.net.state_dict(), self.ckpt_save_path + f'/{self.model_name}_{i+1}.pth')
310
+
311
+ pbar.set_description(
312
+ 'Loss: {:.4f} \
313
+ loss_descs: {:.3f} acc_coarse: {:.3f} \
314
+ loss_coordinates: {:.3f} acc_coordinates: {:.3f} \
315
+ loss_fb_descs: {:.3f} acc_fb_coarse: {:.3f} \
316
+ loss_fb_coordinates: {:.3f} acc_fb_coordinates: {:.3f} \
317
+ loss_kpts: {:.3f} acc_kpts: {:.3f} \
318
+ loss_normals: {:.3f}'.format( \
319
+ loss.item(), \
320
+ loss_descs.item(), acc_coarse, \
321
+ loss_coordinates.item(), acc_coordinates, \
322
+ loss_fb_descs.item(), acc_fb_coarse, \
323
+ loss_fb_coordinates.item(), acc_fb_coordinates, \
324
+ loss_kpts.item(), acc_kpt, \
325
+ loss_normals.item()) )
326
+ pbar.update(1)
327
+
328
+ # Log metrics
329
+ self.writer.add_scalar('Loss/total', loss.item(), i)
330
+ self.writer.add_scalar('Accuracy/acc_coarse', acc_coarse, i)
331
+ self.writer.add_scalar('Accuracy/acc_coordinates', acc_coordinates, i)
332
+ self.writer.add_scalar('Accuracy/acc_fb_coarse', acc_fb_coarse, i)
333
+ self.writer.add_scalar('Accuracy/acc_fb_coordinates', acc_fb_coordinates, i)
334
+ self.writer.add_scalar('Loss/descs', loss_descs.item(), i)
335
+ self.writer.add_scalar('Loss/coordinates', loss_coordinates.item(), i)
336
+ self.writer.add_scalar('Loss/fb_descs', loss_fb_descs.item(), i)
337
+ self.writer.add_scalar('Loss/fb_coordinates', loss_fb_coordinates.item(), i)
338
+ self.writer.add_scalar('Loss/kpts', loss_kpts.item(), i)
339
+ self.writer.add_scalar('Loss/normals', loss_normals.item(), i)
340
+
341
+
342
+
343
+ if __name__ == '__main__':
344
+
345
+ setproctitle.setproctitle(args.name)
346
+
347
+ trainer = Trainer(
348
+ megadepth_root_path=args.megadepth_root_path,
349
+ use_megadepth=args.use_megadepth,
350
+ megadepth_batch_size=args.megadepth_batch_size,
351
+ coco_root_path=args.coco_root_path,
352
+ use_coco=args.use_coco,
353
+ coco_batch_size=args.coco_batch_size,
354
+ ckpt_save_path=args.ckpt_save_path,
355
+ n_steps=args.n_steps,
356
+ lr=args.lr,
357
+ gamma_steplr=args.gamma_steplr,
358
+ training_res=args.training_res,
359
+ device_num=args.device_num,
360
+ dry_run=args.dry_run,
361
+ save_ckpt_every=args.save_ckpt_every
362
+ )
363
+
364
+ #The most fun part
365
+ trainer.train()
imcui/third_party/LiftFeat/train.sh ADDED
@@ -0,0 +1,11 @@
1
+ # default training
2
+ nohup python /home/yepeng_liu/code_python/LiftFeat/train.py \
3
+ --name LiftFeat_test \
4
+ --ckpt_save_path /home/yepeng_liu/code_python/LiftFeat/trained_weights/test \
5
+ --device_num 1 \
6
+ --use_megadepth \
7
+ --megadepth_batch_size 8 \
8
+ --use_coco \
9
+ --coco_batch_size 4 \
10
+ --save_ckpt_every 1000 \
11
+ > /home/yepeng_liu/code_python/LiftFeat/trained_weights/test/training.log 2>&1 &
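For interactive debugging, a foreground variant of the same invocation can be used (a sketch: the paths are placeholders for your own checkout, the nohup/log redirection is dropped, and every flag below is defined in train.py above):

    python train.py \
        --name LiftFeat_debug \
        --ckpt_save_path ./trained_weights/debug \
        --device_num 0 \
        --use_megadepth --megadepth_batch_size 2 \
        --use_coco --coco_batch_size 2 \
        --save_ckpt_every 500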
imcui/third_party/LiftFeat/utils/__init__.py ADDED
File without changes
imcui/third_party/LiftFeat/utils/alike_wrapper.py ADDED
@@ -0,0 +1,45 @@
1
+ """
2
+ "LiftFeat: 3D Geometry-Aware Local Feature Matching"
3
+ """
4
+
5
+
6
+ import sys
7
+ import os
8
+
9
+ ALIKE_PATH = '/home/yepeng_liu/code_python/multimodal_remote/ALIKE'
10
+ sys.path.append(ALIKE_PATH)
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ from alike import ALike
15
+ import cv2
16
+ import numpy as np
17
+
18
+ import pdb
19
+
20
+ dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
+
22
+ configs = {
23
+ 'alike-t': {'c1': 8, 'c2': 16, 'c3': 32, 'c4': 64, 'dim': 64, 'single_head': True, 'radius': 2,
24
+ 'model_path': os.path.join(ALIKE_PATH, 'models', 'alike-t.pth')},
25
+ 'alike-s': {'c1': 8, 'c2': 16, 'c3': 48, 'c4': 96, 'dim': 96, 'single_head': True, 'radius': 2,
26
+ 'model_path': os.path.join(ALIKE_PATH, 'models', 'alike-s.pth')},
27
+ 'alike-n': {'c1': 16, 'c2': 32, 'c3': 64, 'c4': 128, 'dim': 128, 'single_head': True, 'radius': 2,
28
+ 'model_path': os.path.join(ALIKE_PATH, 'models', 'alike-n.pth')},
29
+ 'alike-l': {'c1': 32, 'c2': 64, 'c3': 128, 'c4': 128, 'dim': 128, 'single_head': False, 'radius': 2,
30
+ 'model_path': os.path.join(ALIKE_PATH, 'models', 'alike-l.pth')},
31
+ }
32
+
33
+
34
+ class ALikeExtractor(nn.Module):
35
+ def __init__(self,model_type,device) -> None:
36
+ super().__init__()
37
+ self.net=ALike(**configs[model_type],device=device,top_k=4096,scores_th=0.1,n_limit=8000)
38
+
39
+
40
+ @torch.inference_mode()
41
+ def extract_alike_kpts(self,img):
42
+ pred0=self.net(img,sub_pixel=True)
43
+ return pred0['keypoints']
44
+
45
+
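A minimal usage sketch of the wrapper, mirroring how train.py calls it on images with values in [0, 255]; it assumes the hard-coded ALIKE_PATH above points to a local ALIKE checkout with its pretrained weights and that the LiftFeat root is on PYTHONPATH (the image path reuses the bundled assert/ref.jpg):

    import cv2
    import torch
    from utils.alike_wrapper import ALikeExtractor

    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    alike_net = ALikeExtractor('alike-t', dev)

    img = cv2.imread('assert/ref.jpg')        # HxWx3 array with values in [0, 255]
    kpts = alike_net.extract_alike_kpts(img)  # keypoint coordinates, one (x, y) pair per row
    print(kpts.shape)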
imcui/third_party/LiftFeat/utils/config.py ADDED
@@ -0,0 +1,16 @@
1
+ import os
2
+ import sys
3
+ import numpy as np
4
+
5
+ featureboost_config = {
6
+ "keypoint_dim": 65,
7
+ "keypoint_encoder": [128, 64, 64],
8
+ "normal_dim": 192,
9
+ "normal_encoder": [128, 64, 64],
10
+ "descriptor_encoder": [64, 64],
11
+ "descriptor_dim": 64,
12
+ "Attentional_layers": 3,
13
+ "last_activation": None,
14
+ "l2_normalization": None,
15
+ "output_dim": 64,
16
+ }
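These keys line up with the constructor arguments of FeatureBooster in utils/featurebooster.py (the dict is also passed through LiftFeatSPModel in models/model.py and train.py). A minimal sketch of how it is consumed, assuming the LiftFeat root is on PYTHONPATH; the tensor widths come directly from the dimensions above:

    import torch
    from utils.config import featureboost_config
    from utils.featurebooster import FeatureBooster

    booster = FeatureBooster(featureboost_config, use_kenc=False, use_normal=True, use_cross=True)

    descs = torch.randn(1900, featureboost_config['descriptor_dim'])  # 64-d raw descriptors
    kpts = torch.randn(1900, featureboost_config['keypoint_dim'])     # 65-d keypoint features (unused when use_kenc=False)
    normals = torch.randn(1900, featureboost_config['normal_dim'])    # 192-d normal features
    refined = booster(descs, kpts, normals)                           # boosted descriptors, shape (1900, 64)
    print(refined.shape)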
imcui/third_party/LiftFeat/utils/depth_anything_wrapper.py ADDED
@@ -0,0 +1,150 @@
1
+ import argparse
2
+ import cv2
3
+ import glob
4
+ import matplotlib
5
+ import numpy as np
6
+ import os
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torchvision.transforms import Compose
11
+ import sys
12
+
13
+ sys.path.append("/home/yepeng_liu/code_python/third_repos/Depth-Anything-V2")
14
+ from depth_anything_v2.dpt_opt import DepthAnythingV2
15
+ from depth_anything_v2.util.transform import Resize, NormalizeImage, PrepareForNet
16
+
17
+ import time
18
+
19
+ VITS_MODEL_PATH = "/home/yepeng_liu/code_python/third_repos/Depth-Anything-V2/checkpoints/depth_anything_v2_vits.pth"
20
+ VITB_MODEL_PATH = "/home/yepeng_liu/code_python/third_repos/Depth-Anything-V2/checkpoints/depth_anything_v2_vitb.pth"
21
+ VITL_MODEL_PATH = "/home/yepeng_liu/code_python/third_repos/Depth-Anything-V2/checkpoints/depth_anything_v2_vitl.pth"
22
+
23
+ model_configs = {
24
+ "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
25
+ "vitb": {
26
+ "encoder": "vitb",
27
+ "features": 128,
28
+ "out_channels": [96, 192, 384, 768],
29
+ },
30
+ "vitl": {
31
+ "encoder": "vitl",
32
+ "features": 256,
33
+ "out_channels": [256, 512, 1024, 1024],
34
+ },
35
+ "vitg": {
36
+ "encoder": "vitg",
37
+ "features": 384,
38
+ "out_channels": [1536, 1536, 1536, 1536],
39
+ },
40
+ }
41
+
42
+ class DepthAnythingExtractor(nn.Module):
43
+ def __init__(self, encoder_type, device, input_size, process_size=(608,800)):
44
+ super().__init__()
45
+ self.net = DepthAnythingV2(**model_configs[encoder_type])
46
+ self.device = device
47
+ if encoder_type == "vits":
48
+ print(f"loading {VITS_MODEL_PATH}")
49
+ self.net.load_state_dict(torch.load(VITS_MODEL_PATH, map_location="cpu"))
50
+ elif encoder_type == "vitb":
51
+ print(f"loading {VITB_MODEL_PATH}")
52
+ self.net.load_state_dict(torch.load(VITB_MODEL_PATH, map_location="cpu"))
53
+ elif encoder_type == "vitl":
54
+ print(f"loading {VITL_MODEL_PATH}")
55
+ self.net.load_state_dict(torch.load(VITL_MODEL_PATH, map_location="cpu"))
56
+ else:
57
+ raise RuntimeError("unsupported encoder type")
58
+ self.net.to(self.device).eval()
59
+ self.transform = Compose([
60
+ Resize(
61
+ width=input_size,
62
+ height=input_size,
63
+ resize_target=False,
64
+ keep_aspect_ratio=True,
65
+ ensure_multiple_of=14,
66
+ resize_method='lower_bound',
67
+ image_interpolation_method=cv2.INTER_CUBIC,
68
+ ),
69
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
70
+ PrepareForNet(),
71
+ ])
72
+ self.process_size=process_size
73
+ self.input_size=input_size
74
+
75
+ @torch.inference_mode()
76
+ def infer_image(self,img):
77
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0
78
+
79
+ img = self.transform({'image': img})['image']
80
+
81
+ img = torch.from_numpy(img).unsqueeze(0)
82
+
83
+ img = img.to(self.device)
84
+
85
+ with torch.no_grad():
86
+ depth = self.net.forward(img)
87
+
88
+ depth = F.interpolate(depth[:, None], self.process_size, mode="bilinear", align_corners=True)[0, 0]
89
+
90
+ return depth.cpu().numpy()
91
+
92
+ @torch.inference_mode()
93
+ def compute_normal_map_torch(self, depth_map, scale=1.0):
94
+ """
95
+ Compute surface normals from a depth map (PyTorch implementation).
96
+
97
+ Args:
98
+ depth_map (torch.Tensor): depth map of shape (H, W)
99
+ scale (float): scale factor applied to the depth gradients
100
+
101
+ Returns:
102
+ torch.Tensor: normal map of shape (H, W, 3)
103
+ """
104
+ if depth_map.ndim != 2:
105
+ raise ValueError("depth_map must be a 2-D tensor.")
106
+
107
+ # Compute depth-map gradients
108
+ dzdx = torch.diff(depth_map, dim=1, append=depth_map[:, -1:]) * scale
109
+ dzdy = torch.diff(depth_map, dim=0, append=depth_map[-1:, :]) * scale
110
+
111
+ # Initialize the normal map
112
+ H, W = depth_map.shape
113
+ normal_map = torch.zeros((H, W, 3), dtype=depth_map.dtype, device=depth_map.device)
114
+ normal_map[:, :, 0] = -dzdx # x component
115
+ normal_map[:, :, 1] = -dzdy # y component
116
+ normal_map[:, :, 2] = 1.0 # z component
117
+
118
+ # Normalize the normal vectors
119
+ norm = torch.linalg.norm(normal_map, dim=2, keepdim=True)
120
+ norm = torch.where(norm == 0, torch.tensor(1.0, device=depth_map.device), norm) # avoid division by zero
121
+ normal_map /= norm
122
+
123
+ return normal_map
124
+
125
+ @torch.inference_mode()
126
+ def extract(self, img):
127
+ depth = self.infer_image(img)
128
+ depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
129
+ depth_t=torch.from_numpy(depth).float().to(self.device)
130
+ normal_map = self.compute_normal_map_torch(depth_t,1.0)
131
+ return depth_t,normal_map
132
+
133
+
134
+ if __name__=="__main__":
135
+ img_path=os.path.join(os.path.dirname(__file__),'../assert/ref.jpg')
136
+ img=cv2.imread(img_path)
137
+ img=cv2.resize(img,(800,608))
138
+ import pdb;pdb.set_trace()
139
+ DAExtractor=DepthAnythingExtractor('vitb',torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'),256)
140
+ depth_t,norm=DAExtractor.extract(img)
141
+ norm=norm.cpu().numpy()
142
+ norm=(norm+1)/2*255
143
+ norm=norm.astype(np.uint8)
144
+ cv2.imwrite(os.path.join(os.path.dirname(__file__),"norm.png"),norm)
145
+ start=time.perf_counter()
146
+ for i in range(20):
147
+ depth_t,norm=DAExtractor.extract(img)
148
+ end=time.perf_counter()
149
+ print(f"cost {end-start} seconds")
150
+
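For reference, compute_normal_map_torch above implements the standard finite-difference normal estimate: with depth z(x, y) and scale factor s, each pixel gets

    n(x, y) = (-s·∂z/∂x, -s·∂z/∂y, 1) / ‖(-s·∂z/∂x, -s·∂z/∂y, 1)‖

and extract() rescales the raw Depth Anything prediction to [0, 255] before differencing, so the magnitude of the x/y components depends on that normalization.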
imcui/third_party/LiftFeat/utils/featurebooster.py ADDED
@@ -0,0 +1,247 @@
1
+ from typing import List
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def MLP(channels: List[int], do_bn: bool = False) -> nn.Module:
9
+ """ Multi-layer perceptron """
10
+ n = len(channels)
11
+ layers = []
12
+ for i in range(1, n):
13
+ layers.append(nn.Linear(channels[i - 1], channels[i]))
14
+ if i < (n-1):
15
+ if do_bn:
16
+ layers.append(nn.BatchNorm1d(channels[i]))
17
+ layers.append(nn.ReLU())
18
+ return nn.Sequential(*layers)
19
+
20
+ def MLP_no_ReLU(channels: List[int], do_bn: bool = False) -> nn.Module:
21
+ """ Multi-layer perceptron """
22
+ n = len(channels)
23
+ layers = []
24
+ for i in range(1, n):
25
+ layers.append(nn.Linear(channels[i - 1], channels[i]))
26
+ if i < (n-1):
27
+ if do_bn:
28
+ layers.append(nn.BatchNorm1d(channels[i]))
29
+ return nn.Sequential(*layers)
30
+
31
+
32
+ class KeypointEncoder(nn.Module):
33
+ """ Encoding of geometric properties using MLP """
34
+ def __init__(self, keypoint_dim: int, feature_dim: int, layers: List[int], dropout: bool = False, p: float = 0.1) -> None:
35
+ super().__init__()
36
+ self.encoder = MLP([keypoint_dim] + layers + [feature_dim])
37
+ self.use_dropout = dropout
38
+ self.dropout = nn.Dropout(p=p)
39
+
40
+ def forward(self, kpts):
41
+ if self.use_dropout:
42
+ return self.dropout(self.encoder(kpts))
43
+ return self.encoder(kpts)
44
+
45
+ class NormalEncoder(nn.Module):
46
+ """ Encoding of geometric properties using MLP """
47
+ def __init__(self, normal_dim: int, feature_dim: int, layers: List[int], dropout: bool = False, p: float = 0.1) -> None:
48
+ super().__init__()
49
+ self.encoder = MLP_no_ReLU([normal_dim] + layers + [feature_dim])
50
+ self.use_dropout = dropout
51
+ self.dropout = nn.Dropout(p=p)
52
+
53
+ def forward(self, kpts):
54
+ if self.use_dropout:
55
+ return self.dropout(self.encoder(kpts))
56
+ return self.encoder(kpts)
57
+
58
+
59
+ class DescriptorEncoder(nn.Module):
60
+ """ Encoding of visual descriptor using MLP """
61
+ def __init__(self, feature_dim: int, layers: List[int], dropout: bool = False, p: float = 0.1) -> None:
62
+ super().__init__()
63
+ self.encoder = MLP([feature_dim] + layers + [feature_dim])
64
+ self.use_dropout = dropout
65
+ self.dropout = nn.Dropout(p=p)
66
+
67
+ def forward(self, descs):
68
+ residual = descs
69
+ if self.use_dropout:
70
+ return residual + self.dropout(self.encoder(descs))
71
+ return residual + self.encoder(descs)
72
+
73
+
74
+ class AFTAttention(nn.Module):
75
+ """ Attention-free attention """
76
+ def __init__(self, d_model: int, dropout: bool = False, p: float = 0.1) -> None:
77
+ super().__init__()
78
+ self.dim = d_model
79
+ self.query = nn.Linear(d_model, d_model)
80
+ self.key = nn.Linear(d_model, d_model)
81
+ self.value = nn.Linear(d_model, d_model)
82
+ self.proj = nn.Linear(d_model, d_model)
83
+ # self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
84
+ self.use_dropout = dropout
85
+ self.dropout = nn.Dropout(p=p)
86
+
87
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
88
+ residual = x
89
+ q = self.query(x)
90
+ k = self.key(x)
91
+ v = self.value(x)
92
+ # q = torch.sigmoid(q)
93
+ k = k.T
94
+ k = torch.softmax(k, dim=-1)
95
+ k = k.T
96
+ kv = (k * v).sum(dim=-2, keepdim=True)
97
+ x = q * kv
98
+ x = self.proj(x)
99
+ if self.use_dropout:
100
+ x = self.dropout(x)
101
+ x += residual
102
+ # x = self.layer_norm(x)
103
+ return x
104
+
105
+
106
+ class PositionwiseFeedForward(nn.Module):
107
+ def __init__(self, feature_dim: int, dropout: bool = False, p: float = 0.1) -> None:
108
+ super().__init__()
109
+ self.mlp = MLP([feature_dim, feature_dim*2, feature_dim])
110
+ # self.layer_norm = nn.LayerNorm(feature_dim, eps=1e-6)
111
+ self.use_dropout = dropout
112
+ self.dropout = nn.Dropout(p=p)
113
+
114
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
115
+ residual = x
116
+ x = self.mlp(x)
117
+ if self.use_dropout:
118
+ x = self.dropout(x)
119
+ x += residual
120
+ # x = self.layer_norm(x)
121
+ return x
122
+
123
+
124
+ class AttentionalLayer(nn.Module):
125
+ def __init__(self, feature_dim: int, dropout: bool = False, p: float = 0.1):
126
+ super().__init__()
127
+ self.attn = AFTAttention(feature_dim, dropout=dropout, p=p)
128
+ self.ffn = PositionwiseFeedForward(feature_dim, dropout=dropout, p=p)
129
+
130
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
131
+ # import pdb;pdb.set_trace()
132
+ x = self.attn(x)
133
+ x = self.ffn(x)
134
+ return x
135
+
136
+
137
+ class AttentionalNN(nn.Module):
138
+ def __init__(self, feature_dim: int, layer_num: int, dropout: bool = False, p: float = 0.1) -> None:
139
+ super().__init__()
140
+ self.layers = nn.ModuleList([
141
+ AttentionalLayer(feature_dim, dropout=dropout, p=p)
142
+ for _ in range(layer_num)])
143
+
144
+ def forward(self, desc: torch.Tensor) -> torch.Tensor:
145
+ for layer in self.layers:
146
+ desc = layer(desc)
147
+ return desc
148
+
149
+
150
+ class FeatureBooster(nn.Module):
151
+ default_config = {
152
+ 'descriptor_dim': 128,
153
+ 'keypoint_encoder': [32, 64, 128],
154
+ 'Attentional_layers': 3,
155
+ 'last_activation': 'relu',
156
+ 'l2_normalization': True,
157
+ 'output_dim': 128
158
+ }
159
+
160
+ def __init__(self, config, dropout=False, p=0.1, use_kenc=True, use_normal=True, use_cross=True):
161
+ super().__init__()
162
+ self.config = {**self.default_config, **config}
163
+ self.use_kenc = use_kenc
164
+ self.use_cross = use_cross
165
+ self.use_normal = use_normal
166
+
167
+ if use_kenc:
168
+ self.kenc = KeypointEncoder(self.config['keypoint_dim'], self.config['descriptor_dim'], self.config['keypoint_encoder'], dropout=dropout)
169
+
170
+ if use_normal:
171
+ self.nenc = NormalEncoder(self.config['normal_dim'], self.config['descriptor_dim'], self.config['normal_encoder'], dropout=dropout)
172
+
173
+ if self.config.get('descriptor_encoder', False):
174
+ self.denc = DescriptorEncoder(self.config['descriptor_dim'], self.config['descriptor_encoder'], dropout=dropout)
175
+ else:
176
+ self.denc = None
177
+
178
+ if self.use_cross:
179
+ self.attn_proj = AttentionalNN(feature_dim=self.config['descriptor_dim'], layer_num=self.config['Attentional_layers'], dropout=dropout)
180
+
181
+ # self.final_proj = nn.Linear(self.config['descriptor_dim'], self.config['output_dim'])
182
+
183
+ self.use_dropout = dropout
184
+ self.dropout = nn.Dropout(p=p)
185
+
186
+ # self.layer_norm = nn.LayerNorm(self.config['descriptor_dim'], eps=1e-6)
187
+
188
+ if self.config.get('last_activation', False):
189
+ if self.config['last_activation'].lower() == 'relu':
190
+ self.last_activation = nn.ReLU()
191
+ elif self.config['last_activation'].lower() == 'sigmoid':
192
+ self.last_activation = nn.Sigmoid()
193
+ elif self.config['last_activation'].lower() == 'tanh':
194
+ self.last_activation = nn.Tanh()
195
+ else:
196
+ raise Exception('Not supported activation "%s".' % self.config['last_activation'])
197
+ else:
198
+ self.last_activation = None
199
+
200
+ def forward(self, desc, kpts, normals):
201
+ # import pdb;pdb.set_trace()
202
+ ## Self boosting
203
+ # Descriptor MLP encoder
204
+ if self.denc is not None:
205
+ desc = self.denc(desc)
206
+ # Geometric MLP encoder
207
+ if self.use_kenc:
208
+ desc = desc + self.kenc(kpts)
209
+ if self.use_dropout:
210
+ desc = self.dropout(desc)
211
+
212
+ # Normal MLP encoder
213
+ if self.use_normal:
214
+ desc = desc + self.nenc(normals)
215
+ if self.use_dropout:
216
+ desc = self.dropout(desc)
217
+
218
+ ## Cross boosting
219
+ # Multi-layer Transformer network.
220
+ if self.use_cross:
221
+ # desc = self.attn_proj(self.layer_norm(desc))
222
+ desc = self.attn_proj(desc)
223
+
224
+ ## Post processing
225
+ # Final MLP projection
226
+ # desc = self.final_proj(desc)
227
+ if self.last_activation is not None:
228
+ desc = self.last_activation(desc)
229
+ # L2 normalization
230
+ if self.config['l2_normalization']:
231
+ desc = F.normalize(desc, dim=-1)
232
+
233
+ return desc
234
+
235
+ if __name__ == "__main__":
236
+ from config import featureboost_config
237
+ fb_net = FeatureBooster(featureboost_config)
238
+
239
+ descs=torch.randn([1900,64]) # descriptor_dim
240
+ kpts=torch.randn([1900,65]) # keypoint_dim
241
+ normals=torch.randn([1900,192]) # normal_dim
242
+
243
+ import pdb;pdb.set_trace()
244
+
245
+ descs_refine=fb_net(descs,kpts,normals)
246
+
247
+ print(descs_refine.shape)
imcui/third_party/LiftFeat/weights/LiftFeat.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0da33b2216bde964989f3d13e9b9c9cbdd65c98fb05fb4d4771b7d2f3a807c8b
3
+ size 8086947