add: liftfeat
Files changed:
- README.md +1 -0
- config/config.yaml +11 -0
- imcui/hloc/extract_features.py +11 -0
- imcui/hloc/extractors/liftfeat.py +57 -0
- imcui/third_party/LiftFeat/.gitignore +4 -0
- imcui/third_party/LiftFeat/README.md +141 -0
- imcui/third_party/LiftFeat/assert/achitecture.png +3 -0
- imcui/third_party/LiftFeat/assert/demo_liftfeat.gif +3 -0
- imcui/third_party/LiftFeat/assert/demo_sp.gif +3 -0
- imcui/third_party/LiftFeat/assert/query.jpg +3 -0
- imcui/third_party/LiftFeat/assert/ref.jpg +3 -0
- imcui/third_party/LiftFeat/data/megadepth_1500.json +0 -0
- imcui/third_party/LiftFeat/dataset/__init__.py +0 -0
- imcui/third_party/LiftFeat/dataset/coco_augmentor.py +298 -0
- imcui/third_party/LiftFeat/dataset/coco_wrapper.py +175 -0
- imcui/third_party/LiftFeat/dataset/dataset_utils.py +183 -0
- imcui/third_party/LiftFeat/dataset/megadepth.py +177 -0
- imcui/third_party/LiftFeat/dataset/megadepth_wrapper.py +167 -0
- imcui/third_party/LiftFeat/demo.py +68 -0
- imcui/third_party/LiftFeat/evaluation/HPatch_evaluation.py +182 -0
- imcui/third_party/LiftFeat/evaluation/MegaDepth1500_evaluation.py +105 -0
- imcui/third_party/LiftFeat/evaluation/eval_utils.py +127 -0
- imcui/third_party/LiftFeat/loss/loss.py +291 -0
- imcui/third_party/LiftFeat/models/interpolator.py +34 -0
- imcui/third_party/LiftFeat/models/liftfeat.py +190 -0
- imcui/third_party/LiftFeat/models/liftfeat_wrapper.py +173 -0
- imcui/third_party/LiftFeat/models/model.py +419 -0
- imcui/third_party/LiftFeat/requirements.txt +18 -0
- imcui/third_party/LiftFeat/train.py +365 -0
- imcui/third_party/LiftFeat/train.sh +11 -0
- imcui/third_party/LiftFeat/utils/__init__.py +0 -0
- imcui/third_party/LiftFeat/utils/alike_wrapper.py +45 -0
- imcui/third_party/LiftFeat/utils/config.py +16 -0
- imcui/third_party/LiftFeat/utils/depth_anything_wrapper.py +150 -0
- imcui/third_party/LiftFeat/utils/featurebooster.py +247 -0
- imcui/third_party/LiftFeat/weights/LiftFeat.pth +3 -0
README.md
CHANGED

```diff
@@ -45,6 +45,7 @@ The tool currently supports various popular image matching algorithms, namely:
 | Algorithm        | Supported | Conference/Journal | Year | GitHub Link |
 |------------------|-----------|--------------------|------|-------------|
 | DaD              | ✅        | ARXIV              | 2025 | [Link](https://github.com/Parskatt/dad) |
+| LiftFeat         | ✅        | ICRA               | 2025 | [Link](https://github.com/lyp-deeplearning/LiftFeat) |
 | MINIMA           | ✅        | ARXIV              | 2024 | [Link](https://github.com/LSXI7/MINIMA) |
 | XoFTR            | ✅        | CVPR               | 2024 | [Link](https://github.com/OnderT/XoFTR) |
 | EfficientLoFTR   | ✅        | CVPR               | 2024 | [Link](https://github.com/zju3dv/EfficientLoFTR) |
```
config/config.yaml
CHANGED

```diff
@@ -256,6 +256,17 @@ matcher_zoo:
       paper: https://arxiv.org/abs/2404.19174
       project: null
       display: false
+  liftfeat(sparse):
+    matcher: NN-mutual
+    feature: liftfeat
+    dense: false
+    info:
+      name: LiftFeat  # display name
+      source: "ICRA 2025"
+      github: https://github.com/lyp-deeplearning/LiftFeat
+      paper: https://arxiv.org/abs/2505.03422
+      project: null
+      display: true
   dedode:
     matcher: Dual-Softmax
     feature: dedode
```
imcui/hloc/extract_features.py
CHANGED

```diff
@@ -214,6 +214,17 @@ confs = {
             "resize_max": 1600,
         },
     },
+    "liftfeat": {
+        "output": "feats-liftfeat-n5000-r1600",
+        "model": {
+            "name": "liftfeat",
+            "max_keypoints": 5000,
+        },
+        "preprocessing": {
+            "grayscale": False,
+            "resize_max": 1600,
+        },
+    },
     "aliked-n16-rot": {
         "output": "feats-aliked-n16-rot",
         "model": {
```
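With this conf registered, LiftFeat extraction should plug into hloc's standard pipeline. A minimal sketch, assuming hloc's usual `extract_features.main(conf, image_dir, export_dir)` entry point and placeholder paths:

```python
# Hedged sketch, not part of the commit; paths are placeholders.
from pathlib import Path
from imcui.hloc import extract_features

images = Path("datasets/my_scene/images")  # hypothetical image folder
outputs = Path("outputs/my_scene")

conf = extract_features.confs["liftfeat"]  # the conf added above
# Expected to write feats-liftfeat-n5000-r1600.h5 with per-image
# keypoints, descriptors, and scores under the export dir.
feature_path = extract_features.main(conf, images, outputs)
```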
imcui/hloc/extractors/liftfeat.py
ADDED

```python
import logging
import sys
from pathlib import Path
import torch
import random
from ..utils.base_model import BaseModel
from .. import logger

# Make the vendored LiftFeat repo importable.
fire_path = Path(__file__).parent / "../../third_party/LiftFeat"
sys.path.append(str(fire_path))

from models.liftfeat_wrapper import LiftFeat, MODEL_PATH


def select_idx(N, M):
    # Sample M distinct 0-based indices from [0, N) so they are valid tensor
    # indices (the original range(1, N + 1) could index one past the end).
    numbers = list(range(N))
    selected = random.sample(numbers, M)
    return selected


class Liftfeat(BaseModel):
    default_conf = {
        "keypoint_threshold": 0.05,
        "max_keypoints": 5000,
    }

    required_inputs = ["image"]

    def _init(self, conf):
        logger.info("Loading LiftFeat model...")
        self.net = LiftFeat(
            weight=MODEL_PATH,
            detect_threshold=self.conf["keypoint_threshold"],
            top_k=self.conf["max_keypoints"],
        )
        logger.info("Loading LiftFeat model done!")

    def _forward(self, data):
        # hloc supplies a float tensor in [0, 1]; LiftFeat expects an HWC image
        # in the 0-255 range.
        image = data["image"].cpu().numpy().squeeze() * 255
        image = image.transpose(1, 2, 0)
        pred = self.net.extract(image)

        keypoints = pred["keypoints"]
        descriptors = pred["descriptors"]
        scores = torch.ones_like(pred["keypoints"][:, 0])
        # Randomly subsample if the detector returned more than max_keypoints.
        if self.conf["max_keypoints"] < len(keypoints):
            idxs = select_idx(len(keypoints), self.conf["max_keypoints"])
            keypoints = keypoints[idxs, :2]
            descriptors = descriptors[idxs]
            scores = scores[idxs]

        pred = {
            "keypoints": keypoints[None],
            "descriptors": descriptors[None].permute(0, 2, 1),
            "scores": scores[None],
        }
        return pred
```
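A quick smoke test of the extractor above; a sketch under assumptions: the pretrained weights resolve via MODEL_PATH, and hloc's BaseModel routes a plain call to `_forward`:

```python
# Hedged smoke test, not part of the commit.
import torch
from imcui.hloc.extractors.liftfeat import Liftfeat

model = Liftfeat({"keypoint_threshold": 0.05, "max_keypoints": 1024})
data = {"image": torch.rand(1, 3, 480, 640)}  # RGB in [0, 1], as hloc supplies
pred = model(data)
print(pred["keypoints"].shape)    # (1, N, 2)
print(pred["descriptors"].shape)  # (1, D, N) after the permute in _forward
```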
imcui/third_party/LiftFeat/.gitignore
ADDED

```
visualize
trained_weights
data/HPatch
data/megadepth_test_1500
```
imcui/third_party/LiftFeat/README.md
ADDED

## LiftFeat: 3D Geometry-Aware Local Feature Matching
<div align="center" style="display: flex; justify-content: center; align-items: center; flex-direction: column;">
  <div style="display: flex; justify-content: space-around; width: 100%;">
    <img src='./assert/demo_sp.gif' width="400"/>
    <img src='./assert/demo_liftfeat.gif' width="400"/>
  </div>

Real-time SuperPoint demonstration (left) compared to LiftFeat (right) on a textureless scene.

</div>

- 🎉 **New!** Training code is now available 🚀
- 🎉 **New!** The test code and pretrained model have been released. 🚀

## Table of Contents
- [Introduction](#introduction)
- [Installation](#installation)
- [Usage](#usage)
  - [Inference](#inference)
  - [Training](#training)
  - [Evaluation](#evaluation)
- [Citation](#citation)
- [License](#license)

## Introduction
This repository contains the official implementation of the paper:
**[LiftFeat: 3D Geometry-Aware Local Feature Matching](https://www.arxiv.org/abs/2505.03422)**, to be presented at *ICRA 2025*.

**Overview of LiftFeat's architecture**
<div style="background-color:white">
<img align="center" src="./assert/achitecture.png" width=1000 />
</div>

LiftFeat is a lightweight and robust local feature matching network designed to handle challenging scenarios such as drastic lighting changes, low-texture regions, and repetitive patterns. By incorporating 3D geometric cues through surface normals predicted from monocular depth, LiftFeat enhances the discriminative power of 2D descriptors. Our proposed 3D geometry-aware feature lifting module effectively fuses these cues, leading to significant improvements in tasks like relative pose estimation, homography estimation, and visual localization.

## Installation
If you use conda as your virtual environment manager, you can create a new env with:
```bash
git clone https://github.com/lyp-deeplearning/LiftFeat.git
cd LiftFeat
conda create -n LiftFeat python=3.8
conda activate LiftFeat
pip install -r requirements.txt
```

## Usage
### Inference
To run LiftFeat on an image pair, simply run:
```bash
python demo.py --img1=<reference image> --img2=<query image>
```

### Training
To train LiftFeat as described in the paper, you will need MegaDepth and the COCO_20k subset of the COCO2017 dataset, as described in the paper *[XFeat: Accelerated Features for Lightweight Image Matching](https://arxiv.org/abs/2404.19174)*.
You can obtain the full COCO2017 train data at https://cocodataset.org/.
However, we [make available](https://drive.google.com/file/d/1ijYsPq7dtLQSl-oEsUOGH1fAy21YLc7H/view?usp=drive_link) a subset of COCO for convenience. We simply selected a subset of 20k images according to image resolution. Please check the COCO [terms of use](https://cocodataset.org/#termsofuse) before using the data.

To reproduce the training setup from the paper, please follow these steps:
1. Download [COCO_20k](https://drive.google.com/file/d/1ijYsPq7dtLQSl-oEsUOGH1fAy21YLc7H/view?usp=drive_link), containing a subset of COCO2017;
2. Download the MegaDepth dataset. You can follow the [LoFTR instructions](https://github.com/zju3dv/LoFTR/blob/master/docs/TRAINING.md#download-datasets); we use the same standard as LoFTR. Then put the megadepth indices inside the MegaDepth root folder following the layout below:
```bash
{megadepth_root_path}/train_data/megadepth_indices  # indices
{megadepth_root_path}/MegaDepth_v1                  # images & depth maps & poses
```
3. Finally, you can start training:
```bash
python train.py --megadepth_root_path <path_to>/MegaDepth --synthetic_root_path <path_to>/coco_20k --ckpt_save_path /path/to/ckpts
```

### Evaluation
All evaluation code is in *evaluation*. You can download the **HPatch** dataset following [D2-Net](https://github.com/mihaidusmanu/d2-net/tree/master) and the **MegaDepth** test dataset following [LoFTR](https://github.com/zju3dv/LoFTR/tree/master).

**Download and process HPatch**
```bash
cd /data

# Download the dataset
wget https://huggingface.co/datasets/vbalnt/hpatches/resolve/main/hpatches-sequences-release.zip

# Extract the dataset
unzip hpatches-sequences-release.zip

# Remove the high-resolution sequences
cd hpatches-sequences-release
rm -rf i_contruction i_crownnight i_dc i_pencils i_whitebuilding v_artisans v_astronautis v_talent

cd <LiftFeat>/data

ln -s /data/hpatches-sequences-release ./HPatch
```

**Download and process MegaDepth1500**
We provide a download link to [megadepth_test_1500](https://drive.google.com/drive/folders/1nTkK1485FuwqA0DbZrK2Cl0WnXadUZdc).
```bash
tar xvf <path to megadepth_test_1500.tar>

cd <LiftFeat>/data

ln -s <path to megadepth_test_1500> ./megadepth_test_1500
```

**Homography Estimation**
```bash
python evaluation/HPatch_evaluation.py
```

**Relative Pose Estimation**

For the *MegaDepth1500* dataset:
```bash
python evaluation/MegaDepth1500_evaluation.py
```

## Citation
If you find this code useful for your research, please cite the paper:
```bibtex
@misc{liu2025liftfeat3dgeometryawarelocal,
      title={LiftFeat: 3D Geometry-Aware Local Feature Matching},
      author={Yepeng Liu and Wenpeng Lai and Zhou Zhao and Yuxuan Xiong and Jinchi Zhu and Jun Cheng and Yongchao Xu},
      year={2025},
      eprint={2505.03422},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2505.03422},
}
```

## License
[License](LICENSE)

## Acknowledgements
We would like to thank the authors of the following open-source repositories for their valuable contributions, which have inspired or supported this work:

- [verlab/accelerated_features](https://github.com/verlab/accelerated_features)
- [zju3dv/LoFTR](https://github.com/zju3dv/LoFTR)
- [rpautrat/SuperPoint](https://github.com/rpautrat/SuperPoint)

We deeply appreciate the efforts of the research community in releasing high-quality codebases.
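The README's Introduction leans on surface normals predicted from monocular depth. As a generic illustration of that cue (our sketch, not LiftFeat's actual module), per-pixel normals can be derived from a depth map with finite differences:

```python
import numpy as np

def normals_from_depth(depth: np.ndarray) -> np.ndarray:
    """Per-pixel unit normals from an (H, W) depth map via finite differences."""
    dz_dy, dz_dx = np.gradient(depth)  # depth gradients along y and x
    normals = np.dstack([-dz_dx, -dz_dy, np.ones_like(depth)])
    normals /= np.linalg.norm(normals, axis=2, keepdims=True)  # unit length
    return normals  # (H, W, 3)
```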
imcui/third_party/LiftFeat/assert/achitecture.png
ADDED (binary, stored with Git LFS)
imcui/third_party/LiftFeat/assert/demo_liftfeat.gif
ADDED (binary, stored with Git LFS)
imcui/third_party/LiftFeat/assert/demo_sp.gif
ADDED (binary, stored with Git LFS)
imcui/third_party/LiftFeat/assert/query.jpg
ADDED (binary, stored with Git LFS)
imcui/third_party/LiftFeat/assert/ref.jpg
ADDED (binary, stored with Git LFS)
imcui/third_party/LiftFeat/data/megadepth_1500.json
ADDED
The diff for this file is too large to render. See raw diff.
imcui/third_party/LiftFeat/dataset/__init__.py
ADDED (empty file)
imcui/third_party/LiftFeat/dataset/coco_augmentor.py
ADDED

```python
"""
"LiftFeat: 3D Geometry-Aware Local Feature Matching"
COCO_20k image augmentor
"""

import torch
from torch import nn
from torch.utils.data import Dataset
import torch.utils.data as data
from torchvision import transforms
import torch.nn.functional as F

import cv2
import kornia
import kornia.augmentation as K
from kornia.geometry.transform import get_tps_transform as findTPS
from kornia.geometry.transform import warp_points_tps, warp_image_tps

import glob
import random
import tqdm

import numpy as np
import pdb
import time

random.seed(0)
torch.manual_seed(0)


def generateRandomTPS(shape, grid=(8, 6), GLOBAL_MULTIPLIER=0.3, prob=0.5):
    h, w = shape
    sh, sw = h / grid[0], w / grid[1]
    src = torch.dstack(torch.meshgrid(torch.arange(0, h + sh, sh), torch.arange(0, w + sw, sw), indexing='ij'))

    offsets = torch.rand(grid[0] + 1, grid[1] + 1, 2) - 0.5
    offsets *= torch.tensor([sh / 2, sw / 2]).view(1, 1, 2) * min(0.97, 2.0 * GLOBAL_MULTIPLIER)
    dst = src + offsets if np.random.uniform() < prob else src

    src, dst = src.view(1, -1, 2), dst.view(1, -1, 2)
    src = (src / torch.tensor([h, w]).view(1, 1, 2)) * 2 - 1.
    dst = (dst / torch.tensor([h, w]).view(1, 1, 2)) * 2 - 1.
    weights, A = findTPS(dst, src)

    return src, weights, A


def generateRandomHomography(shape, GLOBAL_MULTIPLIER=0.3):
    # Generate random in-plane rotation in [-theta, +theta]
    theta = np.radians(np.random.uniform(-30, 30))

    # Generate random scale in both x and y
    scale_x, scale_y = np.random.uniform(0.35, 1.2, 2)

    # Generate random translation shift
    tx, ty = -shape[1] / 2.0, -shape[0] / 2.0
    txn, tyn = np.random.normal(0, 120.0 * GLOBAL_MULTIPLIER, 2)

    c, s = np.cos(theta), np.sin(theta)

    # Affine coeffs
    sx, sy = np.random.normal(0, 0.6 * GLOBAL_MULTIPLIER, 2)

    # Projective coeffs
    p1, p2 = np.random.normal(0, 0.006 * GLOBAL_MULTIPLIER, 2)

    # Build the homography from the parameterizations above
    H_t = np.array(((1, 0, tx), (0, 1, ty), (0, 0, 1)))                  # translate to center
    H_r = np.array(((c, -s, 0), (s, c, 0), (0, 0, 1)))                   # rotation
    H_a = np.array(((1, sy, 0), (sx, 1, 0), (0, 0, 1)))                  # affine
    H_p = np.array(((1, 0, 0), (0, 1, 0), (p1, p2, 1)))                  # projective
    H_s = np.array(((scale_x, 0, 0), (0, scale_y, 0), (0, 0, 1)))        # scale
    H_b = np.array(((1.0, 0, -tx + txn), (0, 1, -ty + tyn), (0, 0, 1)))  # translate back (+ random shift)

    # H = H_b @ H_s @ H_p @ H_a @ H_r @ H_t
    H = np.dot(np.dot(np.dot(np.dot(np.dot(H_b, H_s), H_p), H_a), H_r), H_t)

    return H


class COCOAugmentor(nn.Module):

    def __init__(self, device, load_dataset=True,
                 img_dir="/home/yepeng_liu/code_python/dataset/coco_20k",
                 warp_resolution=(1200, 900),
                 out_resolution=(400, 300),
                 sides_crop=0.2,
                 max_num_imgs=50,
                 num_test_imgs=10,
                 batch_size=1,
                 photometric=True,
                 geometric=True,
                 reload_step=1_000
                 ):
        super(COCOAugmentor, self).__init__()
        self.half = 16
        self.device = device

        self.dims = warp_resolution
        self.batch_size = batch_size
        self.out_resolution = out_resolution
        self.sides_crop = sides_crop
        self.max_num_imgs = max_num_imgs
        self.num_test_imgs = num_test_imgs
        self.dims_t = torch.tensor([int(self.dims[0] * (1. - self.sides_crop)) - int(self.dims[0] * self.sides_crop) - 1,
                                    int(self.dims[1] * (1. - self.sides_crop)) - int(self.dims[1] * self.sides_crop) - 1]).float().to(device).view(1, 1, 2)
        self.dims_s = torch.tensor([self.dims_t[0, 0, 0] / out_resolution[0],
                                    self.dims_t[0, 0, 1] / out_resolution[1]]).float().to(device).view(1, 1, 2)

        self.all_imgs = glob.glob(img_dir + '/*.jpg') + glob.glob(img_dir + '/*.png')

        self.photometric = photometric
        self.geometric = geometric
        self.cnt = 1
        self.reload_step = reload_step

        list_augmentation = [
            kornia.augmentation.ColorJitter(0.15, 0.15, 0.15, 0.15, p=1.),
            kornia.augmentation.RandomEqualize(p=0.4),
            kornia.augmentation.RandomGaussianBlur(p=0.3, sigma=(2.0, 2.0), kernel_size=(7, 7))
        ]

        if photometric is False:
            list_augmentation = []

        self.aug_list = kornia.augmentation.ImageSequential(*list_augmentation)

        if len(self.all_imgs) < 10:
            raise RuntimeError("Couldn't find enough images to train. Please check the path: ", img_dir)

        if load_dataset:
            print('[COCO]: ', len(self.all_imgs), ' images for training..')
            if len(self.all_imgs) - num_test_imgs < max_num_imgs:
                raise RuntimeError('Error: test set overlaps with training set! Decrease number of test imgs')

            self.load_imgs()

        self.TPS = True

    def load_imgs(self):
        random.shuffle(self.all_imgs)
        train = []
        for p in tqdm.tqdm(self.all_imgs[:self.max_num_imgs], desc='loading train'):
            im = cv2.imread(p)
            halfH, halfW = im.shape[0] // 2, im.shape[1] // 2
            if halfH > halfW:
                im = np.rot90(im)
                halfH, halfW = halfW, halfH

            if im.shape[0] != self.dims[1] or im.shape[1] != self.dims[0]:
                im = cv2.resize(im, self.dims)

            train.append(np.copy(im))

        self.train = train
        self.test = [
            cv2.resize(cv2.imread(p), self.dims)
            for p in tqdm.tqdm(self.all_imgs[-self.num_test_imgs:], desc='loading test')
        ]

    def norm_pts_grid(self, x):
        if len(x.size()) == 2:
            return (x.view(1, -1, 2) * self.dims_s / self.dims_t) * 2. - 1
        return (x * self.dims_s / self.dims_t) * 2. - 1

    def denorm_pts_grid(self, x):
        if len(x.size()) == 2:
            return ((x.view(1, -1, 2) + 1) / 2.) / self.dims_s * self.dims_t
        return ((x + 1) / 2.) / self.dims_s * self.dims_t

    def rnd_kps(self, shape, n=256):
        h, w = shape
        kps = torch.rand(size=(3, n)).to(self.device)
        kps[0, :] *= w
        kps[1, :] *= h
        kps[2, :] = 1.0

        return kps

    def warp_points(self, H, pts):
        scale = self.dims_s.view(-1, 2)
        offset = torch.tensor([int(self.dims[0] * self.sides_crop), int(self.dims[1] * self.sides_crop)], device=pts.device).float()
        pts = pts * scale + offset
        pts = torch.vstack([pts.t(), torch.ones(1, pts.shape[0], device=pts.device)])
        warped = torch.matmul(H, pts)
        warped = warped / warped[2, ...]
        warped = warped.t()[:, :2]
        return (warped - offset) / scale

    @torch.inference_mode()
    def forward(self, x, difficulty=0.3, TPS=False, prob_deformation=0.5, test=False):
        """
        Perform augmentation on a batch of images.

        input:
            x -> torch.Tensor(B, C, H, W): rgb images
            difficulty -> float: level of difficulty, 0.1 is medium, 0.3 is already pretty hard
            TPS -> bool: whether to apply non-rigid deformations to the images
            prob_deformation -> float: probability to apply a deformation

        return:
            'output' -> torch.Tensor(B, C, H, W): rgb images
            Tuple:
                'H' -> torch.Tensor(3,3): homography matrix
                'mask' -> torch.Tensor(B, H, W): mask of valid pixels after warp
                (deformation only)
                src, weights, A are parameters from a TPS warp (all torch.Tensors)
        """
        if self.cnt % self.reload_step == 0:
            self.load_imgs()

        if self.geometric is False:
            difficulty = 0.

        with torch.no_grad():
            x = (x / 255.).to(self.device)
            b, c, h, w = x.shape
            shape = (h, w)

            # ######## Geometric transformations
            H = torch.tensor(np.array([generateRandomHomography(shape, difficulty) for b in range(self.batch_size)]), dtype=torch.float32).to(self.device)

            output = kornia.geometry.transform.warp_perspective(x, H, dsize=shape, padding_mode='zeros')

            # Crop a fraction of the image boundaries on each side to reduce invalid pixels after warps
            low_h = int(h * self.sides_crop); low_w = int(w * self.sides_crop)
            high_h = int(h * (1. - self.sides_crop)); high_w = int(w * (1. - self.sides_crop))
            output = output[..., low_h:high_h, low_w:high_w]
            x = x[..., low_h:high_h, low_w:high_w]

            # Apply TPS if desired:
            if TPS:
                src, weights, A = None, None, None
                for b in range(self.batch_size):
                    b_src, b_weights, b_A = generateRandomTPS(shape, (8, 6), difficulty, prob=prob_deformation)
                    b_src, b_weights, b_A = b_src.to(self.device), b_weights.to(self.device), b_A.to(self.device)

                    if src is None:
                        src, weights, A = b_src, b_weights, b_A
                    else:
                        src = torch.cat((b_src, src))
                        weights = torch.cat((b_weights, weights))
                        A = torch.cat((b_A, A))

                output = warp_image_tps(output, src, weights, A)

            output = F.interpolate(output, self.out_resolution[::-1], mode='nearest')
            x = F.interpolate(x, self.out_resolution[::-1], mode='nearest')

            mask = ~torch.all(output == 0, dim=1, keepdim=True)
            mask = mask.expand(-1, 3, -1, -1)

            # Fill invalid regions with texture from other images in the batch
            rv = 1 if not TPS else 2
            output_shifted = torch.roll(x, rv, 0)
            output[~mask] = output_shifted[~mask]
            mask = mask[:, 0, :, :]

            # ######## Photometric transformations
            output = self.aug_list(output)

            b, c, h, w = output.shape
            # Correlated Gaussian noise
            if np.random.uniform() > 0.5 and self.photometric:
                noise = F.interpolate(torch.randn_like(output) * (10 / 255), (h // 2, w // 2))
                noise = F.interpolate(noise, (h, w), mode='bicubic')
                output = torch.clip(output + noise, 0., 1.)

            # Random shadows
            if np.random.uniform() > 0.6 and self.photometric:
                noise = torch.rand((b, 1, h // 64, w // 64), device=self.device) * 1.3
                noise = torch.clip(noise, 0.25, 1.0)
                noise = F.interpolate(noise, (h, w), mode='bicubic')
                noise = noise.expand(-1, 3, -1, -1)
                output *= noise
                output = torch.clip(output, 0., 1.)

            self.cnt += 1

            if TPS:
                return output, (H, src, weights, A, mask)
            else:
                return output, (H, mask)

    def get_correspondences(self, kps_target, T):
        H, H2, src, W, A = T
        undeformed = self.denorm_pts_grid(
            warp_points_tps(self.norm_pts_grid(kps_target),
                            src, W, A)).view(-1, 2)

        warped_to_src = self.warp_points(H @ torch.inverse(H2), undeformed)

        return warped_to_src
```
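A hedged usage sketch of `generateRandomHomography` above (the image path is a placeholder):

```python
# Hypothetical usage, not part of the commit.
import cv2

img = cv2.imread("assert/ref.jpg")  # assumption: any test image
H = generateRandomHomography(img.shape[:2], GLOBAL_MULTIPLIER=0.3)
warped = cv2.warpPerspective(img, H, (img.shape[1], img.shape[0]))
cv2.imwrite("warped.jpg", warped)
```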
imcui/third_party/LiftFeat/dataset/coco_wrapper.py
ADDED

```python
import torch
import numpy as np
import pdb

debug_cnt = -1

def make_batch(augmentor, difficulty=0.3, train=True):
    Hs = []
    img_list = augmentor.train if train else augmentor.test
    dev = augmentor.device
    batch_images = []

    with torch.no_grad():  # we don't require grads in the augmentation
        for b in range(augmentor.batch_size):
            rdidx = np.random.randint(len(img_list))
            img = torch.tensor(img_list[rdidx], dtype=torch.float32).permute(2, 0, 1).to(augmentor.device).unsqueeze(0)
            batch_images.append(img)

        batch_images = torch.cat(batch_images)

        p1, H1 = augmentor(batch_images, difficulty)
        p2, H2 = augmentor(batch_images, difficulty, TPS=True, prob_deformation=0.7)
        # p2, H2 = augmentor(batch_images, difficulty, TPS=False, prob_deformation=0.7)

    return p1, p2, H1, H2


def plot_corrs(p1, p2, src_pts, tgt_pts):
    import matplotlib.pyplot as plt
    p1 = p1.cpu()
    p2 = p2.cpu()
    src_pts = src_pts.cpu(); tgt_pts = tgt_pts.cpu()
    rnd_idx = np.random.randint(len(src_pts), size=200)
    src_pts = src_pts[rnd_idx, ...]
    tgt_pts = tgt_pts[rnd_idx, ...]

    # Plot ground-truth correspondences
    fig, ax = plt.subplots(1, 2, figsize=(18, 12))
    colors = np.random.uniform(size=(len(tgt_pts), 3))
    # Src image
    img = p1
    for i, p in enumerate(src_pts):
        ax[0].scatter(p[0], p[1], color=colors[i])
    ax[0].imshow(img.permute(1, 2, 0).numpy()[..., ::-1])

    # Target image
    img2 = p2
    for i, p in enumerate(tgt_pts):
        ax[1].scatter(p[0], p[1], color=colors[i])
    ax[1].imshow(img2.permute(1, 2, 0).numpy()[..., ::-1])
    plt.show()


def get_corresponding_pts(p1, p2, H, H2, augmentor, h, w, crop=None):
    '''
    Get dense corresponding points
    '''
    global debug_cnt
    negatives, positives = [], []

    with torch.no_grad():
        # real input res of samples
        rh, rw = p1.shape[-2:]
        ratio = torch.tensor([rw / w, rh / h], device=p1.device)

        (H, mask1) = H
        (H2, src, W, A, mask2) = H2

        # Generate meshgrid of target pts
        x, y = torch.meshgrid(torch.arange(w, device=p1.device), torch.arange(h, device=p1.device), indexing='xy')
        mesh = torch.cat([x.unsqueeze(-1), y.unsqueeze(-1)], dim=-1)
        target_pts = mesh.view(-1, 2) * ratio

        # Pack all transformations into T
        for batch_idx in range(len(p1)):
            with torch.no_grad():
                T = (H[batch_idx], H2[batch_idx],
                     src[batch_idx].unsqueeze(0), W[batch_idx].unsqueeze(0), A[batch_idx].unsqueeze(0))
                # We now warp the target points to the src image
                src_pts = augmentor.get_correspondences(target_pts, T)  # target to src
                tgt_pts = target_pts

                # Check out-of-bounds points
                mask_valid = (src_pts[:, 0] >= 0) & (src_pts[:, 1] >= 0) & \
                             (src_pts[:, 0] < rw) & (src_pts[:, 1] < rh)

                negatives.append(tgt_pts[~mask_valid])
                tgt_pts = tgt_pts[mask_valid]
                src_pts = src_pts[mask_valid]

                # Remove invalid pixels
                mask_valid = mask1[batch_idx, src_pts[:, 1].long(), src_pts[:, 0].long()] & \
                             mask2[batch_idx, tgt_pts[:, 1].long(), tgt_pts[:, 0].long()]
                tgt_pts = tgt_pts[mask_valid]
                src_pts = src_pts[mask_valid]

                # Limit the number of matches if desired
                if crop is not None:
                    rnd_idx = torch.randperm(len(src_pts), device=src_pts.device)[:crop]
                    src_pts = src_pts[rnd_idx]
                    tgt_pts = tgt_pts[rnd_idx]

                if debug_cnt >= 0 and debug_cnt < 4:
                    plot_corrs(p1[batch_idx], p2[batch_idx], src_pts, tgt_pts)
                    debug_cnt += 1

                src_pts = src_pts / ratio
                tgt_pts = tgt_pts / ratio

                # Check out-of-bounds points
                padto = 10 if crop is not None else 2
                mask_valid1 = (src_pts[:, 0] >= (0 + padto)) & (src_pts[:, 1] >= (0 + padto)) & \
                              (src_pts[:, 0] < (w - padto)) & (src_pts[:, 1] < (h - padto))
                mask_valid2 = (tgt_pts[:, 0] >= (0 + padto)) & (tgt_pts[:, 1] >= (0 + padto)) & \
                              (tgt_pts[:, 0] < (w - padto)) & (tgt_pts[:, 1] < (h - padto))
                mask_valid = mask_valid1 & mask_valid2
                tgt_pts = tgt_pts[mask_valid]
                src_pts = src_pts[mask_valid]

                # Remove repeated correspondences
                lut_mat = torch.ones((h, w, 4), device=src_pts.device, dtype=src_pts.dtype) * -1
                try:
                    lut_mat[src_pts[:, 1].long(), src_pts[:, 0].long()] = torch.cat([src_pts, tgt_pts], dim=1)
                    mask_valid = torch.all(lut_mat >= 0, dim=-1)
                    points = lut_mat[mask_valid]
                    positives.append(points)
                except Exception:
                    pdb.set_trace()
                    print('..')

    return negatives, positives


def crop_patches(tensor, coords, size=7):
    '''
    Crop [size x size] patches around 2D coordinates from a tensor.
    '''
    B, C, H, W = tensor.shape

    x, y = coords[:, 0], coords[:, 1]
    y = y.view(-1, 1, 1)
    x = x.view(-1, 1, 1)
    halfsize = size // 2
    # Create meshgrid for indexing
    x_offset, y_offset = torch.meshgrid(torch.arange(-halfsize, halfsize + 1), torch.arange(-halfsize, halfsize + 1), indexing='xy')
    y_offset = y_offset.to(tensor.device)
    x_offset = x_offset.to(tensor.device)

    # Compute indices around each coordinate
    y_indices = (y + y_offset.view(1, size, size)).squeeze(0) + halfsize
    x_indices = (x + x_offset.view(1, size, size)).squeeze(0) + halfsize

    # Handle out-of-boundary indices with padding
    tensor_padded = torch.nn.functional.pad(tensor, (halfsize, halfsize, halfsize, halfsize), mode='constant')

    # Index tensor to get patches
    patches = tensor_padded[:, :, y_indices, x_indices]  # [B, C, N, H, W]
    return patches


def subpix_softmax2d(heatmaps, temp=0.25):
    N, H, W = heatmaps.shape
    heatmaps = torch.softmax(temp * heatmaps.view(-1, H * W), -1).view(-1, H, W)
    x, y = torch.meshgrid(torch.arange(W, device=heatmaps.device), torch.arange(H, device=heatmaps.device), indexing='xy')
    x = x - (W // 2)
    y = y - (H // 2)

    coords_x = (x[None, ...] * heatmaps)
    coords_y = (y[None, ...] * heatmaps)
    coords = torch.cat([coords_x[..., None], coords_y[..., None]], -1).view(N, H * W, 2)
    coords = coords.sum(1)

    return coords
```
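A sketch of how these helpers appear to chain together for synthetic supervision (assumed wiring; paths and sizes are placeholders):

```python
# Hypothetical wiring, not part of the commit.
augmentor = COCOAugmentor(device="cuda", img_dir="/path/to/coco_20k", batch_size=4)
p1, p2, H1, H2 = make_batch(augmentor, difficulty=0.3)

# Dense GT matches on a 1/8-resolution grid of the augmented views
h, w = p1.shape[-2] // 8, p1.shape[-1] // 8
negatives, positives = get_corresponding_pts(p1, p2, H1, H2, augmentor, h, w)
print(positives[0].shape)  # per-image (K, 4): [x_src, y_src, x_tgt, y_tgt]
```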
imcui/third_party/LiftFeat/dataset/dataset_utils.py
ADDED

```python
"""
"LiftFeat: 3D Geometry-Aware Local Feature Matching"

MegaDepth data handling was adapted from
LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py
"""

import io
import cv2
import numpy as np
import h5py
import torch
from numpy.linalg import inv


try:
    # for internal use only
    from .client import MEGADEPTH_CLIENT, SCANNET_CLIENT
except Exception:
    MEGADEPTH_CLIENT = SCANNET_CLIENT = None

# --- DATA IO ---

def load_array_from_s3(
    path, client, cv_type,
    use_h5py=False,
):
    byte_str = client.Get(path)
    try:
        if not use_h5py:
            raw_array = np.frombuffer(byte_str, np.uint8)  # np.fromstring is deprecated/removed
            data = cv2.imdecode(raw_array, cv_type)
        else:
            f = io.BytesIO(byte_str)
            data = np.array(h5py.File(f, 'r')['/depth'])
    except Exception as ex:
        print(f"==> Data loading failure: {path}")
        raise ex

    assert data is not None
    return data


def imread_gray(path, augment_fn=None, client=SCANNET_CLIENT):
    cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None \
        else cv2.IMREAD_COLOR
    if str(path).startswith('s3://'):
        image = load_array_from_s3(str(path), client, cv_type)
    else:
        image = cv2.imread(str(path), 1)

    if augment_fn is not None:
        image = cv2.imread(str(path), cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = augment_fn(image)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return image  # (h, w)


def get_resized_wh(w, h, resize=None):
    if resize is not None:  # resize the longer edge
        scale = resize / max(h, w)
        w_new, h_new = int(round(w * scale)), int(round(h * scale))
    else:
        w_new, h_new = w, h
    return w_new, h_new


def get_divisible_wh(w, h, df=None):
    if df is not None:
        w_new, h_new = map(lambda x: int(x // df * df), [w, h])
    else:
        w_new, h_new = w, h
    return w_new, h_new


def pad_bottom_right(inp, pad_size, ret_mask=False):
    assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
    mask = None
    if inp.ndim == 2:
        padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
        padded[:inp.shape[0], :inp.shape[1]] = inp
        if ret_mask:
            mask = np.zeros((pad_size, pad_size), dtype=bool)
            mask[:inp.shape[0], :inp.shape[1]] = True
    elif inp.ndim == 3:
        padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
        padded[:, :inp.shape[1], :inp.shape[2]] = inp
        if ret_mask:
            mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
            mask[:, :inp.shape[1], :inp.shape[2]] = True
    else:
        raise NotImplementedError()
    return padded, mask


# --- MEGADEPTH ---

def fix_path_from_d2net(path):
    if not path:
        return None

    path = path.replace('Undistorted_SfM/', '')
    path = path.replace('images', 'dense0/imgs')
    path = path.replace('phoenix/S6/zl548/MegaDepth_v1/', '')

    return path


def read_megadepth_gray(path, resize=None, df=None, padding=False, augment_fn=None):
    """
    Args:
        resize (int, optional): the longer edge of resized images. None for no resize.
        padding (bool): If set to 'True', zero-pad resized images to squared size.
        augment_fn (callable, optional): augments images with pre-defined visual effects
    Returns:
        image (torch.tensor): (1, h, w)
        mask (torch.tensor): (h, w)
        scale (torch.tensor): [w/w_new, h/h_new]
    """
    # read image
    image = imread_gray(path, augment_fn, client=MEGADEPTH_CLIENT)

    # resize image
    w, h = image.shape[1], image.shape[0]

    if resize is not None:
        if len(resize) == 2:
            w_new, h_new = resize
        else:
            resize = resize[0]
            w_new, h_new = get_resized_wh(w, h, resize)
            w_new, h_new = get_divisible_wh(w_new, h_new, df)

        image = cv2.resize(image, (w_new, h_new))
        scale = torch.tensor([w / w_new, h / h_new], dtype=torch.float)

        if padding:  # padding
            pad_to = max(h_new, w_new)
            image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
        else:
            mask = None
    else:
        scale = torch.tensor([1.0, 1.0], dtype=torch.float)

        if padding:
            pad_to = max(w, h)
            image, mask = pad_bottom_right(image, pad_to, ret_mask=True)
        else:
            mask = None

    # image = torch.from_numpy(image).float()[None] / 255  # (h, w) -> (1, h, w) and normalized
    image_t = torch.from_numpy(image).float().permute(2, 0, 1) / 255  # (h, w, 3) -> (3, h, w) and normalized
    mask = torch.from_numpy(mask) if mask is not None else None

    return image, image_t, mask, scale


def read_megadepth_depth(path, pad_to=None):
    if str(path).startswith('s3://'):
        depth = load_array_from_s3(path, MEGADEPTH_CLIENT, None, use_h5py=True)
    else:
        depth = np.array(h5py.File(path, 'r')['depth'])
    if pad_to is not None:
        depth, _ = pad_bottom_right(depth, pad_to, ret_mask=False)
    depth = torch.from_numpy(depth).float()  # (h, w)
    return depth


def imread_bgr(path, augment_fn=None, client=SCANNET_CLIENT):
    cv_type = cv2.IMREAD_GRAYSCALE if augment_fn is None else cv2.IMREAD_COLOR
    if str(path).startswith('s3://'):
        image = load_array_from_s3(str(path), client, cv_type)
    else:
        image = cv2.imread(str(path), 1)

    if augment_fn is not None:
        image = cv2.imread(str(path), cv2.IMREAD_COLOR)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image = augment_fn(image)
        image = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    return image  # (h, w)
```
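A hedged example of the reader above (the path is a placeholder):

```python
# Hypothetical call, not part of the commit.
image_np, image_t, mask, scale = read_megadepth_gray(
    "/path/to/MegaDepth_v1/0000/dense0/imgs/some_image.jpg",
    resize=(800, 608), df=32, padding=False)
print(image_t.shape)  # torch.Size([3, 608, 800]) for a (800, 608) resize
print(scale)          # tensor([w / 800, h / 608])
```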
imcui/third_party/LiftFeat/dataset/megadepth.py
ADDED

```python
"""
"LiftFeat: 3D Geometry-Aware Local Feature Matching"

MegaDepth data handling was adapted from
LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py
"""

import os.path as osp
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
import glob
import numpy.random as rnd

import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from dataset.dataset_utils import read_megadepth_gray, read_megadepth_depth, fix_path_from_d2net

import pdb, tqdm


class MegaDepthDataset(Dataset):
    def __init__(self,
                 root_dir,
                 npz_path,
                 mode='train',
                 min_overlap_score=0.3,
                 max_overlap_score=1.0,
                 load_depth=True,
                 img_resize=(800, 608),  # or None
                 df=32,
                 img_padding=False,
                 depth_padding=True,
                 augment_fn=None,
                 **kwargs):
        """
        Manage one scene (npz_path) of the MegaDepth dataset.

        Args:
            root_dir (str): megadepth root directory that has `phoenix`.
            npz_path (str): {scene_id}.npz path. This contains image pair information of a scene.
            mode (str): options are ['train', 'val', 'test']
            min_overlap_score (float): how much a pair should have in common. In range of [0, 1]. Set to 0 when testing.
            img_resize (int, optional): the longer edge of resized images. None for no resize. 640 is recommended.
                This is useful during training with batches and testing with memory-intensive algorithms.
            df (int, optional): image size division factor. NOTE: this will change the final image size after img_resize.
            img_padding (bool): If set to 'True', zero-pad the image to squared size. This is useful during training.
            depth_padding (bool): If set to 'True', zero-pad depthmap to (2000, 2000). This is useful during training.
            augment_fn (callable, optional): augments images with pre-defined visual effects.
        """
        super().__init__()
        self.root_dir = root_dir
        self.mode = mode
        self.scene_id = npz_path.split('.')[0]
        self.load_depth = load_depth
        # prepare scene_info and pair_info
        if mode == 'test' and min_overlap_score != 0:
            min_overlap_score = 0
        self.scene_info = np.load(npz_path, allow_pickle=True)
        self.pair_infos = self.scene_info['pair_infos'].copy()
        del self.scene_info['pair_infos']
        self.pair_infos = [pair_info for pair_info in self.pair_infos if pair_info[1] > min_overlap_score and pair_info[1] < max_overlap_score]

        # parameters for image resizing, padding and depthmap padding
        if mode == 'train':
            assert img_resize is not None  # and img_padding and depth_padding

        self.img_resize = img_resize
        self.df = df
        self.img_padding = img_padding
        self.depth_max_size = 2000 if depth_padding else None  # the upper bound of depthmap size in megadepth.

        # for training LoFTR
        self.augment_fn = augment_fn if mode == 'train' else None
        self.coarse_scale = kwargs.get('coarse_scale', 0.125)  # was getattr(kwargs, ...), which never reads a dict
        for idx in range(len(self.scene_info['image_paths'])):
            self.scene_info['image_paths'][idx] = fix_path_from_d2net(self.scene_info['image_paths'][idx])

        for idx in range(len(self.scene_info['depth_paths'])):
            self.scene_info['depth_paths'][idx] = fix_path_from_d2net(self.scene_info['depth_paths'][idx])

    def __len__(self):
        return len(self.pair_infos)

    def __getitem__(self, idx):
        (idx0, idx1), overlap_score, central_matches = self.pair_infos[idx % len(self)]

        # read grayscale image and mask. (1, h, w) and (h, w)
        img_name0 = osp.join(self.root_dir, self.scene_info['image_paths'][idx0])
        img_name1 = osp.join(self.root_dir, self.scene_info['image_paths'][idx1])

        # TODO: Support augmentation & handle seeds for each worker correctly.
        image0, image0_t, mask0, scale0 = read_megadepth_gray(img_name0, self.img_resize, self.df, self.img_padding, None)
        # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))
        image1, image1_t, mask1, scale1 = read_megadepth_gray(img_name1, self.img_resize, self.df, self.img_padding, None)
        # np.random.choice([self.augment_fn, None], p=[0.5, 0.5]))

        if self.load_depth:
            # read depth. shape: (h, w)
            if self.mode in ['train', 'val']:
                depth0 = read_megadepth_depth(
                    osp.join(self.root_dir, self.scene_info['depth_paths'][idx0]), pad_to=self.depth_max_size)
                depth1 = read_megadepth_depth(
                    osp.join(self.root_dir, self.scene_info['depth_paths'][idx1]), pad_to=self.depth_max_size)
            else:
                depth0 = depth1 = torch.tensor([])

            # read intrinsics of original size
            K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
            K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)

            # read and compute relative poses
            T0 = self.scene_info['poses'][idx0]
            T1 = self.scene_info['poses'][idx1]
            T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4]  # (4, 4)
            T_1to0 = T_0to1.inverse()

            data = {
                'image0': image0_t,  # (3, h, w)
                'image0_np': image0,
                'depth0': depth0,  # (h, w)
                'image1': image1_t,
                'image1_np': image1,
                'depth1': depth1,
                'T_0to1': T_0to1,  # (4, 4)
                'T_1to0': T_1to0,
                'K0': K_0,  # (3, 3)
                'K1': K_1,
                'scale0': scale0,  # [scale_w, scale_h]
                'scale1': scale1,
                'dataset_name': 'MegaDepth',
                'scene_id': self.scene_id,
                'pair_id': idx,
                'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
            }

            # for LoFTR training
            if mask0 is not None:  # img_padding is True
                if self.coarse_scale:
                    [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
                                                           scale_factor=self.coarse_scale,
                                                           mode='nearest',
                                                           recompute_scale_factor=False)[0].bool()
                    data.update({'mask0': ts_mask_0, 'mask1': ts_mask_1})

        else:
            # read intrinsics of original size
            K_0 = torch.tensor(self.scene_info['intrinsics'][idx0].copy(), dtype=torch.float).reshape(3, 3)
            K_1 = torch.tensor(self.scene_info['intrinsics'][idx1].copy(), dtype=torch.float).reshape(3, 3)

            # read and compute relative poses
            T0 = self.scene_info['poses'][idx0]
            T1 = self.scene_info['poses'][idx1]
            T_0to1 = torch.tensor(np.matmul(T1, np.linalg.inv(T0)), dtype=torch.float)[:4, :4]  # (4, 4)
            T_1to0 = T_0to1.inverse()

            data = {
                'image0': image0,  # (1, h, w)
                'image1': image1,
                'T_0to1': T_0to1,  # (4, 4)
                'T_1to0': T_1to0,
                'K0': K_0,  # (3, 3)
                'K1': K_1,
                'scale0': scale0,  # [scale_w, scale_h]
                'scale1': scale1,
                'dataset_name': 'MegaDepth',
                'scene_id': self.scene_id,
                'pair_id': idx,
                'pair_names': (self.scene_info['image_paths'][idx0], self.scene_info['image_paths'][idx1]),
            }

        return data
```
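A minimal sketch of consuming the dataset (the scene npz path is a placeholder following the megadepth_indices layout from the README):

```python
# Hypothetical usage, not part of the commit.
from torch.utils.data import DataLoader

dataset = MegaDepthDataset(
    root_dir="/path/to/MegaDepth",
    npz_path="/path/to/megadepth_indices/scene_info_0.1_0.7/0015_0.1_0.3.npz",
    mode="train", img_resize=(800, 608))
loader = DataLoader(dataset, batch_size=1, shuffle=True)
batch = next(iter(loader))
print(batch["image0"].shape, batch["T_0to1"].shape)  # (1, 3, 608, 800), (1, 4, 4)
```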
imcui/third_party/LiftFeat/dataset/megadepth_wrapper.py
ADDED
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
"LiftFeat: 3D Geometry-Aware Local Feature Matching"
|
3 |
+
|
4 |
+
MegaDepth data handling was adapted from
|
5 |
+
LoFTR official code: https://github.com/zju3dv/LoFTR/blob/master/src/datasets/megadepth.py
|
6 |
+
"""
|
7 |
+
|
8 |
+
import torch
|
9 |
+
from kornia.utils import create_meshgrid
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+import pdb
+import cv2
+
+@torch.no_grad()
+def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1):
+    """ Warp kpts0 from I0 to I1 with depth, K and Rt
+    Also check covisibility and depth consistency.
+    Depth is consistent if relative error < 0.2 (hard-coded).
+
+    Args:
+        kpts0 (torch.Tensor): [N, L, 2] - <x, y>,
+        depth0 (torch.Tensor): [N, H, W],
+        depth1 (torch.Tensor): [N, H, W],
+        T_0to1 (torch.Tensor): [N, 3, 4],
+        K0 (torch.Tensor): [N, 3, 3],
+        K1 (torch.Tensor): [N, 3, 3],
+    Returns:
+        calculable_mask (torch.Tensor): [N, L]
+        warped_keypoints0 (torch.Tensor): [N, L, 2] <x0_hat, y1_hat>
+    """
+    kpts0_long = kpts0.round().long().clip(0, 2000-1)
+
+    depth0[:, 0, :] = 0 ; depth1[:, 0, :] = 0
+    depth0[:, :, 0] = 0 ; depth1[:, :, 0] = 0
+
+    # Sample depth, get calculable_mask on depth != 0
+    kpts0_depth = torch.stack(
+        [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0
+    )  # (N, L)
+    nonzero_mask = kpts0_depth > 0
+
+    # Draw cross marks on the image for each keypoint
+    # for b in range(len(kpts0)):
+    #     fig, ax = plt.subplots(1,2)
+    #     depth_np = depth0.numpy()[b]
+    #     depth_np_plot = depth_np.copy()
+    #     for x, y in kpts0_long[b, nonzero_mask[b], :].numpy():
+    #         cv2.drawMarker(depth_np_plot, (x, y), (255), cv2.MARKER_CROSS, markerSize=10, thickness=2)
+    #     ax[0].imshow(depth_np)
+    #     ax[1].imshow(depth_np_plot)
+
+    # Unproject
+    kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None]  # (N, L, 3)
+    kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1)  # (N, 3, L)
+
+    # Rigid Transform
+    w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]]  # (N, 3, L)
+    w_kpts0_depth_computed = w_kpts0_cam[:, 2, :]
+
+    # Project
+    w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1)  # (N, L, 3)
+    w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-5)  # (N, L, 2), +1e-5 to avoid zero depth
+
+    # Covisible Check
+    # h, w = depth1.shape[1:3]
+    # covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \
+    #     (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1)
+    # w_kpts0_long = w_kpts0.long()
+    # w_kpts0_long[~covisible_mask, :] = 0
+
+    # w_kpts0_depth = torch.stack(
+    #     [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0
+    # )  # (N, L)
+    # consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2
+
+    valid_mask = nonzero_mask  # * consistent_mask * covisible_mask
+
+    return valid_mask, w_kpts0
+
+
+@torch.no_grad()
+def spvs_coarse(data, scale = 8):
+    """
+    Supervise corresp with dense depth & camera poses
+    """
+
+    # 1. misc
+    device = data['image0'].device
+    N, _, H0, W0 = data['image0'].shape
+    _, _, H1, W1 = data['image1'].shape
+    #scale = 8
+    scale0 = scale * data['scale0'][:, None] if 'scale0' in data else scale
+    scale1 = scale * data['scale1'][:, None] if 'scale1' in data else scale
+    h0, w0, h1, w1 = map(lambda x: x // scale, [H0, W0, H1, W1])
+
+    # 2. warp grids
+    # create kpts in meshgrid and resize them to image resolution
+    grid_pt1_c = create_meshgrid(h1, w1, False, device).reshape(1, h1*w1, 2).repeat(N, 1, 1)  # [N, hw, 2]
+    grid_pt1_i = scale1 * grid_pt1_c
+
+    # warp kpts bi-directionally and check reproj error
+    nonzero_m1, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
+    nonzero_m2, w_pt1_og = warp_kpts(w_pt1_i, data['depth0'], data['depth1'], data['T_0to1'], data['K0'], data['K1'])
+
+    dist = torch.linalg.norm(grid_pt1_i - w_pt1_og, dim=-1)
+    mask_mutual = (dist < 1.5) & nonzero_m1 & nonzero_m2
+
+    #_, w_pt1_i = warp_kpts(grid_pt1_i, data['depth1'], data['depth0'], data['T_1to0'], data['K1'], data['K0'])
+    batched_corrs = [torch.cat([w_pt1_i[i, mask_mutual[i]] / data['scale0'][i],
+                                grid_pt1_i[i, mask_mutual[i]] / data['scale1'][i]], dim=-1) for i in range(len(mask_mutual))]
+
+    # Remove repeated correspondences - this is important for network convergence
+    corrs = []
+    for pts in batched_corrs:
+        lut_mat12 = torch.ones((h1, w1, 4), device = device, dtype = torch.float32) * -1
+        lut_mat21 = torch.clone(lut_mat12)
+        src_pts = pts[:, :2] / scale
+        tgt_pts = pts[:, 2:] / scale
+        try:
+            lut_mat12[src_pts[:,1].long(), src_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1)
+            mask_valid12 = torch.all(lut_mat12 >= 0, dim=-1)
+            points = lut_mat12[mask_valid12]
+
+            # Target-src check
+            src_pts, tgt_pts = points[:, :2], points[:, 2:]
+            lut_mat21[tgt_pts[:,1].long(), tgt_pts[:,0].long()] = torch.cat([src_pts, tgt_pts], dim=1)
+            mask_valid21 = torch.all(lut_mat21 >= 0, dim=-1)
+            points = lut_mat21[mask_valid21]
+
+            corrs.append(points)
+        except:
+            pdb.set_trace()
+            print('..')
+
+    # Plot for debug purposes
+    # for i in range(len(corrs)):
+    #     plot_corrs(data['image0'][i], data['image1'][i], corrs[i][:, :2]*8, corrs[i][:, 2:]*8)
+
+    return corrs
+
+@torch.no_grad()
+def get_correspondences(pts2, data, idx):
+    device = data['image0'].device
+    N, _, H0, W0 = data['image0'].shape
+    _, _, H1, W1 = data['image1'].shape
+
+    pts2 = pts2[None, ...]
+
+    scale0 = data['scale0'][idx, None][None, ...] if 'scale0' in data else 1
+    scale1 = data['scale1'][idx, None][None, ...] if 'scale1' in data else 1
+
+    pts2 = scale1 * pts2 * 8
+
+    # warp kpts bi-directionally and check reproj error
+    nonzero_m1, pts1 = warp_kpts(pts2, data['depth1'][idx][None, ...], data['depth0'][idx][None, ...], data['T_1to0'][idx][None, ...],
+                                 data['K1'][idx][None, ...], data['K0'][idx][None, ...])
+
+    corrs = torch.cat([pts1[0, :] / data['scale0'][idx],
+                       pts2[0, :] / data['scale1'][idx]], dim=-1)
+
+    # plot_corrs(data['image0'][idx], data['image1'][idx], corrs[:, :2], corrs[:, 2:])
+
+    return corrs
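For reference, the core of `warp_kpts` above is the classic unproject → rigid-transform → project chain. A minimal single-point sketch of that chain follows; all camera values here are invented for illustration, and the identity pose is used so the input pixel is recovered:

```python
# Minimal sketch of the warp used by warp_kpts (hypothetical values).
import torch

K = torch.tensor([[500., 0., 320.],
                  [0., 500., 240.],
                  [0., 0., 1.]])
T = torch.eye(4)[:3]                 # identity relative pose [R | t], shape (3, 4)
kpt = torch.tensor([100., 80.])      # pixel in image 0
depth = torch.tensor(2.0)            # sampled depth at that pixel

# Unproject to camera space using the sampled depth
p_cam = K.inverse() @ (torch.cat([kpt, torch.ones(1)]) * depth)
# Apply the relative pose
p_cam1 = T[:, :3] @ p_cam + T[:, 3]
# Project into the second camera (same intrinsics here for simplicity)
p_img1 = K @ p_cam1
p_img1 = p_img1[:2] / (p_img1[2] + 1e-5)
print(p_img1)                        # identity pose: recovers (100, 80)
```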
imcui/third_party/LiftFeat/demo.py
ADDED
@@ -0,0 +1,68 @@
+import os
+import sys
+import torch
+import numpy as np
+import math
+import cv2
+
+os.environ['CUDA_VISIBLE_DEVICES']='1'
+
+from models.liftfeat_wrapper import LiftFeat,MODEL_PATH
+
+import argparse
+
+parser=argparse.ArgumentParser(description='LiftFeat matching demo script')
+parser.add_argument('--name',type=str,default='LiftFeat',help='experiment name')
+parser.add_argument('--img1',type=str,default='./assert/ref.jpg',help='reference image path')
+parser.add_argument('--img2',type=str,default='./assert/query.jpg',help='query image path')
+parser.add_argument('--gpu',type=str,default='0',help='GPU ID')
+args=parser.parse_args()
+
+os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
+
+
+def warp_corners_and_draw_matches(ref_points, dst_points, img1, img2):
+    # Calculate the Homography matrix
+    H, mask = cv2.findHomography(ref_points, dst_points, cv2.USAC_MAGSAC, 3.5, maxIters=1_000, confidence=0.999)
+    mask = mask.flatten()
+
+    # Get corners of the first image (image1)
+    h, w = img1.shape[:2]
+    corners_img1 = np.array([[0, 0], [w-1, 0], [w-1, h-1], [0, h-1]], dtype=np.float32).reshape(-1, 1, 2)
+
+    # Warp corners to the second image (image2) space
+    warped_corners = cv2.perspectiveTransform(corners_img1, H)
+
+    # Draw the warped corners in image2
+    img2_with_corners = img2.copy()
+
+    # Prepare keypoints and matches for drawMatches function
+    keypoints1 = [cv2.KeyPoint(float(p[0]), float(p[1]), 5) for p in ref_points]
+    keypoints2 = [cv2.KeyPoint(float(p[0]), float(p[1]), 5) for p in dst_points]
+    matches = [cv2.DMatch(i,i,0) for i in range(len(mask)) if mask[i]]
+
+    # Draw inlier matches
+    img_matches = cv2.drawMatches(img1, keypoints1, img2_with_corners, keypoints2, matches, None,
+                                  matchColor=(0, 255, 0), flags=2)
+
+    return img_matches
+
+
+if __name__=="__main__":
+    liftfeat=LiftFeat(weight=MODEL_PATH,detect_threshold=0.05)
+
+    img1=cv2.imread(args.img1)
+    img2=cv2.imread(args.img2)
+
+    # import pdb;pdb.set_trace()
+    mkpts1,mkpts2=liftfeat.match_liftfeat(img1,img2)
+    canvas=warp_corners_and_draw_matches(mkpts1,mkpts2,img1,img2)
+
+    import matplotlib.pyplot as plt
+    plt.figure(figsize=[12,12])
+    plt.imshow(canvas[...,::-1])
+
+    plt.savefig(os.path.join(os.path.dirname(__file__),'match.jpg'), dpi=300, bbox_inches='tight')
+
+    plt.show()
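If the defaults above hold, the demo should be runnable from the LiftFeat root as `python demo.py --img1 ./assert/ref.jpg --img2 ./assert/query.jpg`; it estimates a homography with MAGSAC, draws the inlier matches, and saves the visualization to `match.jpg` next to the script.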
imcui/third_party/LiftFeat/evaluation/HPatch_evaluation.py
ADDED
@@ -0,0 +1,182 @@
+import cv2
+import os
+from tqdm import tqdm
+import torch
+import numpy as np
+import sys
+import poselib
+
+sys.path.append(os.path.join(os.path.dirname(__file__),'..'))
+
+import argparse
+import datetime
+
+parser=argparse.ArgumentParser(description='HPatch dataset evaluation script')
+parser.add_argument('--name',type=str,default='LiftFeat',help='experiment name')
+parser.add_argument('--gpu',type=str,default='0',help='GPU ID')
+args=parser.parse_args()
+
+os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
+
+use_cuda = torch.cuda.is_available()
+device = torch.device("cuda" if use_cuda else "cpu")
+
+top_k = None
+n_i = 52
+n_v = 56
+
+DATASET_ROOT = os.path.join(os.path.dirname(__file__),'../data/HPatch')
+
+from evaluation.eval_utils import *
+from models.liftfeat_wrapper import LiftFeat
+
+
+poselib_config = {"ransac_th": 3.0, "options": {}}
+
+class PoseLibHomographyEstimator:
+    def __init__(self, conf):
+        self.conf = conf
+
+    def estimate(self, mkpts0, mkpts1):
+        M, info = poselib.estimate_homography(
+            mkpts0,
+            mkpts1,
+            {
+                "max_reproj_error": self.conf["ransac_th"],
+                **self.conf["options"],
+            },
+        )
+        success = M is not None
+        if not success:
+            M = np.eye(3,dtype=np.float32)
+            inl = np.zeros(mkpts0.shape[0],dtype=np.bool_)
+        else:
+            inl = info["inliers"]
+
+        estimation = {
+            "success": success,
+            "M_0to1": M,
+            "inliers": inl,
+        }
+
+        return estimation
+
+
+estimator=PoseLibHomographyEstimator(poselib_config)
+
+
+def poselib_homography_estimate(mkpts0,mkpts1):
+    data=estimator.estimate(mkpts0,mkpts1)
+    return data
+
+
+def generate_standard_image(img,target_size=(1920,1080)):
+    sh,sw=img.shape[0],img.shape[1]
+    rh,rw=float(target_size[1])/float(sh),float(target_size[0])/float(sw)
+    ratio=min(rh,rw)
+    nh,nw=int(ratio*sh),int(ratio*sw)
+    ph,pw=target_size[1]-nh,target_size[0]-nw
+    nimg=cv2.resize(img,(nw,nh))
+    nimg=cv2.copyMakeBorder(nimg,0,ph,0,pw,cv2.BORDER_CONSTANT,value=(0,0,0))
+
+    return nimg,ratio,ph,pw
+
+
+def benchmark_features(match_fn):
+    lim = [1, 9]
+    rng = np.arange(lim[0], lim[1] + 1)
+
+    seq_names = sorted(os.listdir(DATASET_ROOT))
+
+    n_feats = []
+    n_matches = []
+    seq_type = []
+    i_err = {thr: 0 for thr in rng}
+    v_err = {thr: 0 for thr in rng}
+
+    i_err_homo = {thr: 0 for thr in rng}
+    v_err_homo = {thr: 0 for thr in rng}
+
+    for seq_idx, seq_name in tqdm(enumerate(seq_names), total=len(seq_names)):
+        # load reference image
+        ref_img = cv2.imread(os.path.join(DATASET_ROOT, seq_name, "1.ppm"))
+        ref_img_shape=ref_img.shape
+
+        # load query images
+        for im_idx in range(2, 7):
+            # read ground-truth homography
+            homography = np.loadtxt(os.path.join(DATASET_ROOT, seq_name, "H_1_" + str(im_idx)))
+            query_img = cv2.imread(os.path.join(DATASET_ROOT, seq_name, f"{im_idx}.ppm"))
+
+            mkpts_a,mkpts_b=match_fn(ref_img,query_img)
+
+            pos_a = mkpts_a
+            pos_a_h = np.concatenate([pos_a, np.ones([pos_a.shape[0], 1])], axis=1)
+            pos_b_proj_h = np.transpose(np.dot(homography, np.transpose(pos_a_h)))
+            pos_b_proj = pos_b_proj_h[:, :2] / pos_b_proj_h[:, 2:]
+
+            pos_b = mkpts_b
+
+            dist = np.sqrt(np.sum((pos_b - pos_b_proj) ** 2, axis=1))
+
+            n_matches.append(pos_a.shape[0])
+            seq_type.append(seq_name[0])
+
+            if dist.shape[0] == 0:
+                dist = np.array([float("inf")])
+
+            for thr in rng:
+                if seq_name[0] == "i":
+                    i_err[thr] += np.mean(dist <= thr)
+                else:
+                    v_err[thr] += np.mean(dist <= thr)
+
+            # estimate homography
+            gt_homo = homography
+            pred_homo, _ = cv2.findHomography(mkpts_a,mkpts_b,cv2.USAC_MAGSAC)
+            if pred_homo is None:
+                homo_dist = np.array([float("inf")])
+            else:
+                corners = np.array(
+                    [
+                        [0, 0],
+                        [ref_img_shape[1] - 1, 0],
+                        [0, ref_img_shape[0] - 1],
+                        [ref_img_shape[1] - 1, ref_img_shape[0] - 1],
+                    ]
+                )
+                real_warped_corners = homo_trans(corners, gt_homo)
+                warped_corners = homo_trans(corners, pred_homo)
+                homo_dist = np.mean(np.linalg.norm(real_warped_corners - warped_corners, axis=1))
+
+            for thr in rng:
+                if seq_name[0] == "i":
+                    i_err_homo[thr] += np.mean(homo_dist <= thr)
+                else:
+                    v_err_homo[thr] += np.mean(homo_dist <= thr)
+
+    seq_type = np.array(seq_type)
+    n_feats = np.array(n_feats)
+    n_matches = np.array(n_matches)
+
+    return i_err, v_err, i_err_homo, v_err_homo, [seq_type, n_feats, n_matches]
+
+
+if __name__ == "__main__":
+    errors = {}
+
+    weights=os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.pth')
+    liftfeat=LiftFeat(weight=weights)
+
+    errors = benchmark_features(liftfeat.match_liftfeat)
+
+    i_err, v_err, i_err_hom, v_err_hom, _ = errors
+
+    cur_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+
+    print(f'\n==={cur_time}==={args.name}===')
+    print(f"MHA@3 MHA@5 MHA@7")
+    for thr in [3, 5, 7]:
+        ill_err_hom = i_err_hom[thr] / (n_i * 5)
+        view_err_hom = v_err_hom[thr] / (n_v * 5)
+        print(f"{ill_err_hom * 100:.2f}%-{view_err_hom * 100:.2f}%")
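The MHA@k numbers printed above count an image pair as correct when the mean distance between the four reference-image corners warped by the ground-truth homography and by the estimated one is at most k pixels. A toy sketch of that corner-error computation, with made-up homographies:

```python
# Corner-error metric behind MHA@k, on invented homographies.
import numpy as np

def homo_trans(coord, H):
    homo = np.concatenate([coord, np.ones((coord.shape[0], 1))], axis=-1)
    proj = (H @ homo.T).T
    return proj[:, :2] / proj[:, 2:]

h, w = 480, 640
corners = np.array([[0, 0], [w - 1, 0], [0, h - 1], [w - 1, h - 1]], dtype=np.float64)
H_gt = np.array([[1., 0., 5.], [0., 1., -3.], [0., 0., 1.]])
H_pred = np.array([[1., 0., 4.], [0., 1., -2.], [0., 0., 1.]])

err = np.mean(np.linalg.norm(homo_trans(corners, H_gt) - homo_trans(corners, H_pred), axis=1))
print(err)  # ~1.41 px, so this pair would count as correct for MHA@3/5/7
```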
imcui/third_party/LiftFeat/evaluation/MegaDepth1500_evaluation.py
ADDED
@@ -0,0 +1,105 @@
+import os
+import sys
+import cv2
+from pathlib import Path
+import numpy as np
+import torch
+import torch.utils.data as data
+import tqdm
+from copy import deepcopy
+from torchvision.transforms import ToTensor
+import torch.nn.functional as F
+import json
+
+import scipy.io as scio
+import poselib
+
+import argparse
+import datetime
+
+parser=argparse.ArgumentParser(description='MegaDepth dataset evaluation script')
+parser.add_argument('--name',type=str,default='LiftFeat',help='experiment name')
+parser.add_argument('--gpu',type=str,default='0',help='GPU ID')
+args=parser.parse_args()
+
+os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
+
+sys.path.append(os.path.join(os.path.dirname(__file__),'../'))
+from models.liftfeat_wrapper import LiftFeat
+from evaluation.eval_utils import *
+
+from torch.utils.data import Dataset,DataLoader
+
+use_cuda = torch.cuda.is_available()
+device = "cuda" if use_cuda else "cpu"
+
+DATASET_ROOT = os.path.join(os.path.dirname(__file__),'../data/megadepth_test_1500')
+DATASET_JSON = os.path.join(os.path.dirname(__file__),'../data/megadepth_1500.json')
+
+class MegaDepth1500(Dataset):
+    """
+    Streamlined MegaDepth-1500 dataloader. The camera poses & metadata are stored in a formatted json for facilitating
+    the download of the dataset and to keep the setup as simple as possible.
+    """
+    def __init__(self, json_file, root_dir):
+        # Load the info & calibration from the JSON
+        with open(json_file, 'r') as f:
+            self.data = json.load(f)
+
+        self.root_dir = root_dir
+
+        if not os.path.exists(self.root_dir):
+            raise RuntimeError(
+                f"Dataset {self.root_dir} does not exist! \n \
+                > If you didn't download the dataset, use the downloader tool: python3 -m modules.dataset.download -h")
+
+    def __len__(self):
+        return len(self.data)
+
+    def __getitem__(self, idx):
+        data = deepcopy(self.data[idx])
+
+        h1, w1 = data['size0_hw']
+        h2, w2 = data['size1_hw']
+
+        # Here we resize the images to max_dim = 1200, as described in the paper, and adjust the image such that it is divisible by 32
+        # following the protocol of the LoFTR's Dataloader (intrinsics are corrected accordingly).
+        # For adapting this with different resolution, you would need to re-scale intrinsics below.
+        image0 = cv2.resize(cv2.imread(f"{self.root_dir}/{data['pair_names'][0]}"),(w1, h1))
+
+        image1 = cv2.resize(cv2.imread(f"{self.root_dir}/{data['pair_names'][1]}"),(w2, h2))
+
+        data['image0'] = torch.tensor(image0.astype(np.float32)/255).permute(2,0,1)
+        data['image1'] = torch.tensor(image1.astype(np.float32)/255).permute(2,0,1)
+
+        for k,v in data.items():
+            if k not in ('dataset_name', 'scene_id', 'pair_id', 'pair_names', 'size0_hw', 'size1_hw', 'image0', 'image1'):
+                data[k] = torch.tensor(np.array(v, dtype=np.float32))
+
+        return data
+
+if __name__ == "__main__":
+    weights=os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.pth')
+    liftfeat=LiftFeat(weight=weights)
+
+    dataset = MegaDepth1500(json_file = DATASET_JSON, root_dir = DATASET_ROOT)
+
+    loader = DataLoader(dataset, batch_size=1, shuffle=False)
+
+    metrics = {}
+    R_errs = []
+    t_errs = []
+    inliers = []
+
+    results=[]
+
+    cur_time = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+
+    for d in tqdm.tqdm(loader, desc="processing"):
+        error_infos = compute_pose_error(liftfeat.match_liftfeat,d)
+        results.append(error_infos)
+
+    print(f'\n==={cur_time}==={args.name}===')
+    d_err_auc,errors=compute_maa(results)
+    for s_k,s_v in d_err_auc.items():
+        print(f'{s_k}: {s_v*100}')
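Assuming the MegaDepth-1500 images sit under `data/megadepth_test_1500` and the metadata JSON is in place, the script above should be runnable as `python evaluation/MegaDepth1500_evaluation.py --name LiftFeat --gpu 0`; it prints the pose-error AUC at 5/10/20 degrees (`auc@5`, `auc@10`, `auc@20`) computed by `compute_maa` from `eval_utils`.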
imcui/third_party/LiftFeat/evaluation/eval_utils.py
ADDED
@@ -0,0 +1,127 @@
+import numpy as np
+import torch
+import poselib
+
+
+def relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0):
+    # angle error between 2 vectors
+    t_gt = T_0to1[:3, 3]
+    n = np.linalg.norm(t) * np.linalg.norm(t_gt)
+    t_err = np.rad2deg(np.arccos(np.clip(np.dot(t, t_gt) / n, -1.0, 1.0)))
+    t_err = np.minimum(t_err, 180 - t_err)  # handle E ambiguity
+    if np.linalg.norm(t_gt) < ignore_gt_t_thr:  # pure rotation is challenging
+        t_err = 0
+
+    # angle error between 2 rotation matrices
+    R_gt = T_0to1[:3, :3]
+    cos = (np.trace(np.dot(R.T, R_gt)) - 1) / 2
+    cos = np.clip(cos, -1.0, 1.0)  # handle numerical errors
+    R_err = np.rad2deg(np.abs(np.arccos(cos)))
+
+    return t_err, R_err
+
+def intrinsics_to_camera(K):
+    px, py = K[0, 2], K[1, 2]
+    fx, fy = K[0, 0], K[1, 1]
+    return {
+        "model": "PINHOLE",
+        "width": int(2 * px),
+        "height": int(2 * py),
+        "params": [fx, fy, px, py],
+    }
+
+
+def estimate_pose(kpts0, kpts1, K0, K1, thresh, conf=0.99999):
+    M, info = poselib.estimate_relative_pose(
+        kpts0, kpts1,
+        intrinsics_to_camera(K0),
+        intrinsics_to_camera(K1),
+        {"max_epipolar_error": thresh,
+         "success_prob": conf,
+         "min_iterations": 20,
+         "max_iterations": 1_000},
+    )
+
+    R, t, inl = M.R, M.t, info["inliers"]
+    inl = np.array(inl)
+    ret = (R, t, inl)
+
+    return ret
+
+def tensor2bgr(t):
+    return (t.cpu()[0].permute(1,2,0).numpy()*255).astype(np.uint8)
+
+def compute_pose_error(match_fn,data):
+    result = {}
+
+    with torch.no_grad():
+        mkpts0,mkpts1=match_fn(tensor2bgr(data["image0"]),tensor2bgr(data["image1"]))
+
+    mkpts0=mkpts0 * data["scale0"].numpy()
+    mkpts1=mkpts1 * data["scale1"].numpy()
+
+    K0, K1 = data["K0"][0].numpy(), data["K1"][0].numpy()
+    T_0to1 = data["T_0to1"][0].numpy()
+    T_1to0 = data["T_1to0"][0].numpy()
+
+    result={}
+    conf = 0.99999
+
+    ret = estimate_pose(mkpts0,mkpts1,K0,K1,4.0,conf)
+    if ret is not None:
+        R, t, inliers = ret
+        t_err, R_err = relative_pose_error(T_0to1, R, t, ignore_gt_t_thr=0.0)
+        result['R_err'] = R_err
+        result['t_err'] = t_err
+
+    return result
+
+
+def error_auc(errors, thresholds=[5, 10, 20]):
+    """
+    Args:
+        errors (list): [N,]
+        thresholds (list)
+    """
+    errors = [0] + sorted(list(errors))
+    recall = list(np.linspace(0, 1, len(errors)))
+
+    aucs = []
+
+    for thr in thresholds:
+        last_index = np.searchsorted(errors, thr)
+        y = recall[:last_index] + [recall[last_index-1]]
+        x = errors[:last_index] + [thr]
+        aucs.append(np.trapz(y, x) / thr)
+
+    return {f'auc@{t}': auc for t, auc in zip(thresholds, aucs)}
+
+def compute_maa(pairs, thresholds=[5, 10, 20]):
+    # print("auc / mAcc on %d pairs" % (len(pairs)))
+    errors = []
+
+    for p in pairs:
+        et = p['t_err']
+        er = p['R_err']
+        errors.append(max(et, er))
+
+    d_err_auc = error_auc(errors)
+
+    # for k,v in d_err_auc.items():
+    #     print(k, ': ', '%.1f'%(v*100))
+
+    errors = np.array(errors)
+
+    for t in thresholds:
+        acc = (errors <= t).sum() / len(errors)
+        # print("mAcc@%d: %.1f "%(t, acc*100))
+
+    return d_err_auc,errors
+
+def homo_trans(coord, H):
+    kpt_num = coord.shape[0]
+    homo_coord = np.concatenate((coord, np.ones((kpt_num, 1))), axis=-1)
+    proj_coord = np.matmul(H, homo_coord.T).T
+    proj_coord = proj_coord / proj_coord[:, 2][..., None]
+    proj_coord = proj_coord[:, 0:2]
+    return proj_coord
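For a concrete feel of `error_auc`, here is a small self-contained run of the same recall-integration scheme on an invented error list: the sorted errors form the x-axis, recall the y-axis, and the curve is integrated up to each threshold and normalized by it.

```python
# Self-contained check of the AUC computation used by error_auc (toy errors).
import numpy as np

errors = [0] + sorted([1.0, 2.0, 8.0, 15.0, 30.0])
recall = list(np.linspace(0, 1, len(errors)))

for thr in [5, 10, 20]:
    last = np.searchsorted(errors, thr)
    y = recall[:last] + [recall[last - 1]]   # hold recall flat up to the threshold
    x = errors[:last] + [thr]
    print(f"auc@{thr}: {np.trapz(y, x) / thr:.3f}")
```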
imcui/third_party/LiftFeat/loss/loss.py
ADDED
@@ -0,0 +1,291 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import time
+
+
+def dual_softmax_loss(X, Y, temp = 0.2):
+    if X.size() != Y.size() or X.dim() != 2 or Y.dim() != 2:
+        raise RuntimeError('Error: X and Y shapes must match and be 2D matrices')
+
+    dist_mat = (X @ Y.t()) * temp
+    conf_matrix12 = F.log_softmax(dist_mat, dim=1)
+    conf_matrix21 = F.log_softmax(dist_mat.t(), dim=1)
+
+    with torch.no_grad():
+        conf12 = torch.exp( conf_matrix12 ).max(dim=-1)[0]
+        conf21 = torch.exp( conf_matrix21 ).max(dim=-1)[0]
+        conf = conf12 * conf21
+
+    target = torch.arange(len(X), device = X.device)
+
+    loss = F.nll_loss(conf_matrix12, target) + \
+           F.nll_loss(conf_matrix21, target)
+
+    return loss, conf
+
+
+class LiftFeatLoss(nn.Module):
+    def __init__(self,device,lam_descs=1,lam_fb_descs=1,lam_kpts=1,lam_heatmap=1,lam_normals=1,lam_coordinates=1,lam_fb_coordinates=1,depth_spvs=False):
+        super().__init__()
+
+        # loss parameters
+        self.lam_descs=lam_descs
+        self.lam_fb_descs=lam_fb_descs
+        self.lam_kpts=lam_kpts
+        self.lam_heatmap=lam_heatmap
+        self.lam_normals=lam_normals
+        self.lam_coordinates=lam_coordinates
+        self.lam_fb_coordinates=lam_fb_coordinates
+        self.depth_spvs=depth_spvs
+        self.running_descs_loss=0
+        self.running_kpts_loss=0
+        self.running_heatmaps_loss=0
+        self.loss_descs=0
+        self.loss_fb_descs=0
+        self.loss_kpts=0
+        self.loss_heatmaps=0
+        self.loss_normals=0
+        self.loss_coordinates=0
+        self.loss_fb_coordinates=0
+        self.acc_coarse=0
+        self.acc_fb_coarse=0
+        self.acc_kpt=0
+        self.acc_coordinates=0
+        self.acc_fb_coordinates=0
+
+        # device
+        self.dev=device
+
+
+    def check_accuracy(self,m1,m2,pts1=None,pts2=None,plot=False):
+        with torch.no_grad():
+            #dist_mat = torch.cdist(X,Y)
+            dist_mat = m1 @ m2.t()
+            nn = torch.argmax(dist_mat, dim=1)
+            #nn = torch.argmin(dist_mat, dim=1)
+            correct = nn == torch.arange(len(m1), device = m1.device)
+
+            if pts1 is not None and plot:
+                import matplotlib.pyplot as plt
+                canvas = torch.zeros((60, 80),device=m1.device)
+                pts1 = pts1[~correct]
+                canvas[pts1[:,1].long(), pts1[:,0].long()] = 1
+                canvas = canvas.cpu().numpy()
+                plt.imshow(canvas), plt.show()
+
+            acc = correct.sum().item() / len(m1)
+            return acc
+
+    def compute_descriptors_loss(self,descs1,descs2,pts):
+        loss=[]
+        acc=0
+        B,_,H,W=descs1.shape
+        conf_list=[]
+
+        for b in range(B):
+            pts1,pts2=pts[b][:,:2],pts[b][:,2:]
+            m1=descs1[b,:,pts1[:,1].long(),pts1[:,0].long()].permute(1,0)
+            m2=descs2[b,:,pts2[:,1].long(),pts2[:,0].long()].permute(1,0)
+
+            loss_per,conf_per=dual_softmax_loss(m1,m2)
+            loss.append(loss_per.unsqueeze(0))
+            conf_list.append(conf_per)
+
+            acc_coarse_per=self.check_accuracy(m1,m2)
+            acc += acc_coarse_per
+
+        loss=torch.cat(loss,dim=-1).mean()
+        acc /= B
+        return loss,acc,conf_list
+
+
+    def alike_distill_loss(self,kpts,alike_kpts):
+        C, H, W = kpts.shape
+        kpts = kpts.permute(1,2,0)
+        # get ALike keypoints
+        with torch.no_grad():
+            labels = torch.ones((H, W), dtype = torch.long, device = kpts.device) * 64  # -> Default is non-keypoint (bin 64)
+            offsets = (((alike_kpts/8) - (alike_kpts/8).long())*8).long()
+            offsets = offsets[:, 0] + 8*offsets[:, 1]  # Linear IDX
+            labels[(alike_kpts[:,1]/8).long(), (alike_kpts[:,0]/8).long()] = offsets
+
+        kpts = kpts.view(-1,C)
+        labels = labels.view(-1)
+
+        mask = labels < 64
+        idxs_pos = mask.nonzero().flatten()
+        idxs_neg = (~mask).nonzero().flatten()
+        perm = torch.randperm(idxs_neg.size(0))[:len(idxs_pos)//32]
+        idxs_neg = idxs_neg[perm]
+        idxs = torch.cat([idxs_pos, idxs_neg])
+
+        kpts = kpts[idxs]
+        labels = labels[idxs]
+
+        with torch.no_grad():
+            predicted = kpts.max(dim=-1)[1]
+            acc = (labels == predicted)
+            acc = acc.sum() / len(acc)
+
+        kpts = F.log_softmax(kpts,dim=-1)
+        loss = F.nll_loss(kpts, labels, reduction = 'mean')
+
+        return loss, acc
+
+
+    def compute_keypoints_loss(self,kpts1,kpts2,alike_kpts1,alike_kpts2):
+        loss=[]
+        acc=0
+        B,_,H,W=kpts1.shape
+
+        for b in range(B):
+            loss_per1,acc_per1=self.alike_distill_loss(kpts1[b],alike_kpts1[b])
+            loss_per2,acc_per2=self.alike_distill_loss(kpts2[b],alike_kpts2[b])
+            loss_per=(loss_per1+loss_per2)
+            acc_per=(acc_per1+acc_per2)/2
+            loss.append(loss_per.unsqueeze(0))
+            acc += acc_per
+
+        loss=torch.cat(loss,dim=-1).mean()
+        acc /= B
+        return loss,acc
+
+
+    def compute_heatmaps_loss(self,heatmaps1,heatmaps2,pts,conf_list):
+        loss=[]
+        B,_,H,W=heatmaps1.shape
+
+        for b in range(B):
+            pts1,pts2=pts[b][:,:2],pts[b][:,2:]
+            h1=heatmaps1[b,0,pts1[:,1].long(),pts1[:,0].long()]
+            h2=heatmaps2[b,0,pts2[:,1].long(),pts2[:,0].long()]
+
+            conf=conf_list[b]
+            loss_per1=F.l1_loss(h1,conf)
+            loss_per2=F.l1_loss(h2,conf)
+            loss_per=(loss_per1+loss_per2)
+            loss.append(loss_per.unsqueeze(0))
+
+        loss=torch.cat(loss,dim=-1).mean()
+        return loss
+
+
+    def normal_loss(self,normal,target_normal):
+        # import pdb;pdb.set_trace()
+        normal = normal.permute(1, 2, 0)
+        target_normal = target_normal.permute(1,2,0)
+        # loss = F.l1_loss(d_feat, depth_anything_normal_feat)
+        dot = torch.cosine_similarity(normal, target_normal, dim=2)
+        valid_mask = target_normal[:, :, 0].float() \
+                     * (dot.detach() < 0.999).float() \
+                     * (dot.detach() > -0.999).float()
+        valid_mask = valid_mask > 0.0
+        al = torch.acos(dot[valid_mask])
+        loss = torch.mean(al)
+        return loss
+
+
+    def compute_normals_loss(self,normals1,normals2,DA_normals1,DA_normals2,megadepth_batch_size,coco_batch_size):
+        loss=[]
+
+        # import pdb;pdb.set_trace()
+
+        # only MegaDepth images need depth-normal supervision
+        normals1=normals1[coco_batch_size:,...]
+        normals2=normals2[coco_batch_size:,...]
+        for b in range(len(DA_normals1)):
+            normal1,normal2=normals1[b],normals2[b]
+            loss_per1=self.normal_loss(normal1,DA_normals1[b].permute(2,0,1))
+            loss_per2=self.normal_loss(normal2,DA_normals2[b].permute(2,0,1))
+            loss_per=(loss_per1+loss_per2)
+            loss.append(loss_per.unsqueeze(0))
+
+        loss=torch.cat(loss,dim=-1).mean()
+        return loss
+
+
+    def coordinate_loss(self,coordinate,conf,pts1):
+        with torch.no_grad():
+            coordinate_detached = pts1 * 8
+            offset_detached = (coordinate_detached/8) - (coordinate_detached/8).long()
+            offset_detached = (offset_detached * 8).long()
+            label = offset_detached[:, 0] + 8*offset_detached[:, 1]
+
+        #pdb.set_trace()
+        coordinate_log = F.log_softmax(coordinate, dim=-1)
+
+        predicted = coordinate.max(dim=-1)[1]
+        acc = (label == predicted)
+        acc = acc[conf > 0.1]
+        acc = acc.sum() / len(acc)
+
+        loss = F.nll_loss(coordinate_log, label, reduction = 'none')
+
+        # Weight loss by confidence, giving more emphasis on reliable matches
+        conf = conf / conf.sum()
+        loss = (loss * conf).sum()
+
+        return loss*2., acc
+
+    def compute_coordinates_loss(self,coordinates,pts,conf_list):
+        loss=[]
+        acc=0
+        B,_,H,W=coordinates.shape
+
+        for b in range(B):
+            pts1,pts2=pts[b][:,:2],pts[b][:,2:]
+            coordinate=coordinates[b,:,pts1[:,1].long(),pts1[:,0].long()].permute(1,0)
+            conf=conf_list[b]
+
+            loss_per,acc_per=self.coordinate_loss(coordinate,conf,pts1)
+            loss.append(loss_per.unsqueeze(0))
+            acc += acc_per
+
+        loss=torch.cat(loss,dim=-1).mean()
+        acc /= B
+
+        return loss,acc
+
+
+    def forward(self,
+                descs1,fb_descs1,kpts1,normals1,
+                descs2,fb_descs2,kpts2,normals2,
+                pts,coordinates,fb_coordinates,
+                alike_kpts1,alike_kpts2,
+                DA_normals1,DA_normals2,
+                megadepth_batch_size,coco_batch_size
+                ):
+        # import pdb;pdb.set_trace()
+        self.loss_descs,self.acc_coarse,conf_list=self.compute_descriptors_loss(descs1,descs2,pts)
+        self.loss_fb_descs,self.acc_fb_coarse,fb_conf_list=self.compute_descriptors_loss(fb_descs1,fb_descs2,pts)
+
+        # start=time.perf_counter()
+        self.loss_kpts,self.acc_kpt=self.compute_keypoints_loss(kpts1,kpts2,alike_kpts1,alike_kpts2)
+        # end=time.perf_counter()
+        # print(f"kpts loss cost {end-start} seconds")
+
+        # start=time.perf_counter()
+        self.loss_normals=self.compute_normals_loss(normals1,normals2,DA_normals1,DA_normals2,megadepth_batch_size,coco_batch_size)
+        # end=time.perf_counter()
+        # print(f"normal loss cost {end-start} seconds")
+
+        self.loss_coordinates,self.acc_coordinates=self.compute_coordinates_loss(coordinates,pts,conf_list)
+        self.loss_fb_coordinates,self.acc_fb_coordinates=self.compute_coordinates_loss(fb_coordinates,pts,fb_conf_list)
+
+        return {
+            'loss_descs':self.lam_descs*self.loss_descs,
+            'acc_coarse':self.acc_coarse,
+            'loss_coordinates':self.lam_coordinates*self.loss_coordinates,
+            'acc_coordinates':self.acc_coordinates,
+            'loss_fb_descs':self.lam_fb_descs*self.loss_fb_descs,
+            'acc_fb_coarse':self.acc_fb_coarse,
+            'loss_fb_coordinates':self.lam_fb_coordinates*self.loss_fb_coordinates,
+            'acc_fb_coordinates':self.acc_fb_coordinates,
+            'loss_kpts':self.lam_kpts*self.loss_kpts,
+            'acc_kpt':self.acc_kpt,
+            'loss_normals':self.lam_normals*self.loss_normals,
+        }
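A minimal sketch of the dual-softmax objective defined at the top of this file: the rows of two matched descriptor sets are pushed to be each other's softmax argmax in both directions. Descriptors here are random and purely for shape illustration; the temperature factor mirrors the source's `temp = 0.2`.

```python
# Dual-softmax objective on toy matched descriptors (invented data).
import torch
import torch.nn.functional as F

X = F.normalize(torch.randn(128, 64), dim=-1)            # 128 matched pairs
Y = F.normalize(X + 0.05 * torch.randn_like(X), dim=-1)  # noisy copies of X

dist = (X @ Y.t()) * 0.2                                  # temp = 0.2, as in the source
target = torch.arange(len(X))                             # row i should match column i
loss = F.nll_loss(F.log_softmax(dist, dim=1), target) \
     + F.nll_loss(F.log_softmax(dist.t(), dim=1), target)
print(loss.item())
```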
imcui/third_party/LiftFeat/models/interpolator.py
ADDED
@@ -0,0 +1,34 @@
+"""
+"LiftFeat: 3D Geometry-Aware Local Feature Matching"
+
+This script is used to interpolate rough descriptors from LiftFeat
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class InterpolateSparse2d(nn.Module):
+    """ Efficiently interpolate tensor at given sparse 2D positions. """
+    def __init__(self, mode = 'bicubic', align_corners = False):
+        super().__init__()
+        self.mode = mode
+        self.align_corners = align_corners
+
+    def normgrid(self, x, H, W):
+        """ Normalize coords to [-1,1]. """
+        return 2. * (x/(torch.tensor([W-1, H-1], device = x.device, dtype = x.dtype))) - 1.
+
+    def forward(self, x, pos, H, W):
+        """
+        Input
+            x: [B, C, H, W] feature tensor
+            pos: [B, N, 2] tensor of positions
+            H, W: int, original resolution of input 2d positions -- used in normalization [-1,1]
+
+        Returns
+            [B, N, C] sampled channels at 2d positions
+        """
+        grid = self.normgrid(pos, H, W).unsqueeze(-2).to(x.dtype)
+        x = F.grid_sample(x, grid, mode = self.mode, align_corners = False)
+        return x.permute(0,2,3,1).squeeze(-2)
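A quick usage sketch for `InterpolateSparse2d`, with invented shapes: sample a 64-channel coarse feature map at three keypoint positions given in original image coordinates.

```python
# Sampling a coarse descriptor map at sparse keypoints (toy shapes).
import torch
from models.interpolator import InterpolateSparse2d

sampler = InterpolateSparse2d('bicubic')
feats = torch.randn(1, 64, 60, 80)                             # [B, C, H/8, W/8]-style map
pos = torch.tensor([[[10., 20.], [100., 50.], [300., 200.]]])  # [B, N, 2] pixel coords
out = sampler(feats, pos, 480, 640)                            # H, W of the original image
print(out.shape)                                               # torch.Size([1, 3, 64])
```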
imcui/third_party/LiftFeat/models/liftfeat.py
ADDED
@@ -0,0 +1,190 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import numpy as np
+import math
+import cv2
+
+os.environ['CUDA_VISIBLE_DEVICES']='1'
+
+import kornia as K
+
+sys.path.append(os.path.join(os.path.dirname(__file__),'..'))
+
+from models.model import LiftFeatSPModel
+from models.interpolator import InterpolateSparse2d
+from utils.config import featureboost_config
+
+
+class NonMaxSuppression(torch.nn.Module):
+    def __init__(self, rep_thr=0.1, top_k=4096):
+        super(NonMaxSuppression,self).__init__()
+        self.max_filter = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
+        self.rep_thr = rep_thr
+        self.top_k=top_k
+
+
+    def NMS(self, x, threshold = 0.05, kernel_size = 5):
+        B, _, H, W = x.shape
+        pad=kernel_size//2
+        local_max = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)(x)
+        pos = (x == local_max) & (x > threshold)
+        pos_batched = [k.nonzero()[..., 1:].flip(-1) for k in pos]
+
+        pad_val = max([len(x) for x in pos_batched])
+        pos = torch.zeros((B, pad_val, 2), dtype=torch.long, device=x.device)
+
+        # Pad kpts and build (B, N, 2) tensor
+        for b in range(len(pos_batched)):
+            pos[b, :len(pos_batched[b]), :] = pos_batched[b]
+
+        return pos
+
+    def forward(self, score):
+        pos = self.NMS(score,self.rep_thr)
+
+        return pos
+
+def load_model(model, weight_path):
+    pretrained_weights = torch.load(weight_path)
+
+    model_keys = set(model.state_dict().keys())
+    pretrained_keys = set(pretrained_weights.keys())
+
+    missing_keys = model_keys - pretrained_keys
+    unexpected_keys = pretrained_keys - model_keys
+
+    if missing_keys:
+        print("Missing keys in pretrained weights:", missing_keys)
+    else:
+        print("No missing keys in pretrained weights.")
+
+    if unexpected_keys:
+        print("Unexpected keys in pretrained weights:", unexpected_keys)
+    else:
+        print("No unexpected keys in pretrained weights.")
+
+    if not missing_keys and not unexpected_keys:
+        model.load_state_dict(pretrained_weights)
+        print("Pretrained weights loaded successfully.")
+    else:
+        model.load_state_dict(pretrained_weights, strict=False)
+        print("There were issues with the keys.")
+    return model
+
+
+def load_torch_image(fname, device=torch.device('cpu')):
+    img = K.image_to_tensor(cv2.imread(fname), False).float() / 255.
+    img = K.color.bgr_to_rgb(img.to(device))
+
+    image=cv2.imread(fname)
+    H,W,C=image.shape[0],image.shape[1],image.shape[2]
+
+    _H=math.ceil(H/32)*32
+    _W=math.ceil(W/32)*32
+
+    pad_h=_H-H
+    pad_w=_W-W
+
+    image=cv2.copyMakeBorder(image,0,pad_h,0,pad_w,cv2.BORDER_CONSTANT,None,(0, 0, 0))
+
+    pad_info=[0,pad_h,0,pad_w]
+
+    image = K.image_to_tensor(image, False).float() / 255.
+    image = image.to(device)
+
+    return image,pad_info
+
+
+class LiftFeat(nn.Module):
+    def __init__(self,weight,top_k=4096,detect_threshold=0.1):
+        super().__init__()
+        self.net=LiftFeatSPModel(featureboost_config)
+        self.top_k=top_k
+        self.sampler=InterpolateSparse2d('bicubic')
+        self.net=load_model(self.net,weight)
+        self.detector=NonMaxSuppression(rep_thr=detect_threshold)
+
+    @torch.inference_mode()
+    def extract(self,image,pad_info):
+        B,_,_H1,_W1=image.shape
+        M1,K1,D1=self.net.forward1(image)
+        refine_M=self.net.forward2(M1,K1,D1)
+
+        refine_M=refine_M.reshape(M1.shape[0],M1.shape[2],M1.shape[3],-1).permute(0,3,1,2)
+        refine_M=torch.nn.functional.normalize(refine_M,2,dim=1)
+
+        descs_map=refine_M
+        # descs_map=M1
+
+        scores=torch.softmax(K1,dim=1)[:,:64]
+        heatmap=scores.permute(0,2,3,1).reshape(scores.shape[0],scores.shape[2],scores.shape[3],8,8)
+        heatmap=heatmap.permute(0,1,3,2,4).reshape(scores.shape[0],1,scores.shape[2]*8,scores.shape[3]*8)
+
+        pos=self.detector(heatmap)
+        kpts=pos.squeeze(0)
+        mask_w=kpts[...,0]<(_W1-pad_info[-1])
+        kpts=kpts[mask_w]
+        mask_h=kpts[..., 1]<(_H1-pad_info[1])
+        kpts=kpts[mask_h]
+
+        descs=self.sampler(descs_map,kpts.unsqueeze(0),_H1,_W1)
+        descs=torch.nn.functional.normalize(descs,p=2,dim=1)
+        descs=descs.squeeze(0)
+
+        return {
+            'descriptors':descs,
+            'keypoints':kpts
+        }
+
+    def match_liftfeat(self, img1, pad_info1, img2, pad_info2, min_cossim=-1):
+        # import pdb;pdb.set_trace()
+        data1=self.extract(img1, pad_info1)
+        data2=self.extract(img2, pad_info2)
+
+        kpts1,feats1=data1['keypoints'],data1['descriptors']
+        kpts2,feats2=data2['keypoints'],data2['descriptors']
+
+        cossim = feats1 @ feats2.t()
+        cossim_t = feats2 @ feats1.t()
+
+        _, match12 = cossim.max(dim=1)
+        _, match21 = cossim_t.max(dim=1)
+
+        idx0 = torch.arange(len(match12), device=match12.device)
+        mutual = match21[match12] == idx0
+
+        if min_cossim > 0:
+            cossim, _ = cossim.max(dim=1)
+            good = cossim > min_cossim
+            idx0 = idx0[mutual & good]
+            idx1 = match12[mutual & good]
+        else:
+            idx0 = idx0[mutual]
+            idx1 = match12[mutual]
+
+        mkpts1,mkpts2=kpts1[idx0],kpts2[idx1]
+
+        return mkpts1, mkpts2
+
+weight=os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.pth')
+
+liftfeat=LiftFeat(weight)
+
+save_file=os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.pt')
+
+liftfeat_script=torch.jit.script(liftfeat)
+liftfeat_script.save(save_file)
+
+# checkpoint = {
+#     'model_name': 'LiftFeat',
+#     'model_args': {
+#         'top_k': 4096,
+#         'detect_threshold': 0.1
+#     },
+#     'state_dict': liftfeat.state_dict()
+# }
+
+# torch.save(checkpoint,os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.ckpt'))
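Note that the TorchScript export at the bottom of this file sits at module level, so it runs (loading weights and writing `weights/LiftFeat.pt`) whenever `models.liftfeat` is imported. A guarded variant, as a sketch rather than part of the diff, would avoid that side effect:

```python
# Sketch (not in the diff): guard the export so plain imports have no side effects.
if __name__ == "__main__":
    weight = os.path.join(os.path.dirname(__file__), '../weights/LiftFeat.pth')
    liftfeat = LiftFeat(weight)
    save_file = os.path.join(os.path.dirname(__file__), '../weights/LiftFeat.pt')
    torch.jit.script(liftfeat).save(save_file)
```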
imcui/third_party/LiftFeat/models/liftfeat_wrapper.py
ADDED
@@ -0,0 +1,173 @@
+import os
+import sys
+import torch
+import torch.nn as nn
+import numpy as np
+import math
+import cv2
+
+from models.model import LiftFeatSPModel
+from models.interpolator import InterpolateSparse2d
+from utils.config import featureboost_config
+
+device=torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+
+MODEL_PATH=os.path.join(os.path.dirname(__file__),'../weights/LiftFeat.pth')
+
+
+class NonMaxSuppression(torch.nn.Module):
+    def __init__(self, rep_thr=0.1, top_k=4096):
+        super(NonMaxSuppression,self).__init__()
+        self.max_filter = torch.nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
+        self.rep_thr = rep_thr
+        self.top_k=top_k
+
+
+    def NMS(self, x, threshold = 0.05, kernel_size = 5):
+        B, _, H, W = x.shape
+        pad=kernel_size//2
+        local_max = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)(x)
+        pos = (x == local_max) & (x > threshold)
+        pos_batched = [k.nonzero()[..., 1:].flip(-1) for k in pos]
+
+        pad_val = max([len(x) for x in pos_batched])
+        pos = torch.zeros((B, pad_val, 2), dtype=torch.long, device=x.device)
+
+        # Pad kpts and build (B, N, 2) tensor
+        for b in range(len(pos_batched)):
+            pos[b, :len(pos_batched[b]), :] = pos_batched[b]
+
+        return pos
+
+    def forward(self, score):
+        pos = self.NMS(score,self.rep_thr)
+
+        return pos
+
+def load_model(model, weight_path):
+    pretrained_weights = torch.load(weight_path, map_location="cpu")
+
+    model_keys = set(model.state_dict().keys())
+    pretrained_keys = set(pretrained_weights.keys())
+
+    missing_keys = model_keys - pretrained_keys
+    unexpected_keys = pretrained_keys - model_keys
+
+    # if missing_keys:
+    #     print("Missing keys in pretrained weights:", missing_keys)
+    # else:
+    #     print("No missing keys in pretrained weights.")
+
+    # if unexpected_keys:
+    #     print("Unexpected keys in pretrained weights:", unexpected_keys)
+    # else:
+    #     print("No unexpected keys in pretrained weights.")
+
+    if not missing_keys and not unexpected_keys:
+        model.load_state_dict(pretrained_weights)
+        print("load weight successfully.")
+    else:
+        model.load_state_dict(pretrained_weights, strict=False)
+        # print("There were issues with the keys.")
+    return model
+
+
+class LiftFeat(nn.Module):
+    def __init__(self,weight=MODEL_PATH,top_k=4096,detect_threshold=0.1):
+        super().__init__()
+        self.device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+        self.net=LiftFeatSPModel(featureboost_config).to(self.device).eval()
+        self.top_k=top_k
+        self.sampler=InterpolateSparse2d('bicubic')
+        self.net=load_model(self.net,weight)
+        self.detector=NonMaxSuppression(rep_thr=detect_threshold)
+        self.net=self.net.to(self.device)
+        self.detector=self.detector.to(self.device)
+        self.sampler=self.sampler.to(self.device)
+
+    def image_preprocess(self,image: np.ndarray):
+        H,W,C=image.shape[0],image.shape[1],image.shape[2]
+
+        _H=math.ceil(H/32)*32
+        _W=math.ceil(W/32)*32
+
+        pad_h=_H-H
+        pad_w=_W-W
+
+        image=cv2.copyMakeBorder(image,0,pad_h,0,pad_w,cv2.BORDER_CONSTANT,None,(0, 0, 0))
+
+        pad_info=[0,pad_h,0,pad_w]
+
+        if len(image.shape)==3:
+            image=image[None,...]
+
+        image=torch.tensor(image).permute(0,3,1,2)/255
+        image=image.to(device)
+
+        return image, pad_info
+
+    @torch.inference_mode()
+    def extract(self,image: np.ndarray):
+        image,pad_info=self.image_preprocess(image)
+        B,_,_H1,_W1=image.shape
+
+        M1,K1,D1=self.net.forward1(image)
+        refine_M=self.net.forward2(M1,K1,D1)
+
+        refine_M=refine_M.reshape(M1.shape[0],M1.shape[2],M1.shape[3],-1).permute(0,3,1,2)
+        refine_M=torch.nn.functional.normalize(refine_M,2,dim=1)
+
+        descs_map=refine_M
+        # descs_map=M1
+
+        scores=torch.softmax(K1,dim=1)[:,:64]
+        heatmap=scores.permute(0,2,3,1).reshape(scores.shape[0],scores.shape[2],scores.shape[3],8,8)
+        heatmap=heatmap.permute(0,1,3,2,4).reshape(scores.shape[0],1,scores.shape[2]*8,scores.shape[3]*8)
+
+        pos=self.detector(heatmap)
+        kpts=pos.squeeze(0)
+        mask_w=kpts[...,0]<(_W1-pad_info[-1])
+        kpts=kpts[mask_w]
+        mask_h=kpts[..., 1]<(_H1-pad_info[1])
+        kpts=kpts[mask_h]
+
+        descs=self.sampler(descs_map,kpts.unsqueeze(0),_H1,_W1)
+        descs=torch.nn.functional.normalize(descs,p=2,dim=1)
+        descs=descs.squeeze(0)
+
+        return {
+            'descriptors':descs,
+            'keypoints':kpts
+        }
+
+    def match_liftfeat(self, img1, img2, min_cossim=-1):
+        # import pdb;pdb.set_trace()
+        data1=self.extract(img1)
+        data2=self.extract(img2)
+
+        kpts1,feats1=data1['keypoints'],data1['descriptors']
+        kpts2,feats2=data2['keypoints'],data2['descriptors']
+
+        cossim = feats1 @ feats2.t()
+        cossim_t = feats2 @ feats1.t()
+
+        _, match12 = cossim.max(dim=1)
+        _, match21 = cossim_t.max(dim=1)
+
+        idx0 = torch.arange(len(match12), device=match12.device)
+        mutual = match21[match12] == idx0
+
+        if min_cossim > 0:
+            cossim, _ = cossim.max(dim=1)
+            good = cossim > min_cossim
+            idx0 = idx0[mutual & good]
+            idx1 = match12[mutual & good]
+        else:
+            idx0 = idx0[mutual]
+            idx1 = match12[mutual]
+
+        mkpts1,mkpts2=kpts1[idx0],kpts2[idx1]
+        mkpts1,mkpts2=mkpts1.cpu().numpy(),mkpts2.cpu().numpy()
+
+        return mkpts1, mkpts2
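Putting the wrapper together, a minimal matching sketch mirroring demo.py (image paths are placeholders): load two BGR images, extract keypoints and descriptors for each, and keep the mutual-nearest-neighbor matches as numpy arrays.

```python
# End-to-end matching with the wrapper above (paths are placeholders).
import cv2
from models.liftfeat_wrapper import LiftFeat, MODEL_PATH

liftfeat = LiftFeat(weight=MODEL_PATH, detect_threshold=0.05)
img1 = cv2.imread('path/to/ref.jpg')
img2 = cv2.imread('path/to/query.jpg')
mkpts1, mkpts2 = liftfeat.match_liftfeat(img1, img2)
print(mkpts1.shape, mkpts2.shape)   # (N, 2) matched keypoints in each image
```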
imcui/third_party/LiftFeat/models/model.py
ADDED
@@ -0,0 +1,419 @@
+
+"""
+"LiftFeat: 3D Geometry-Aware Local Feature Matching"
+"""
+
+import numpy as np
+import os
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+import tqdm
+import math
+import cv2
+
+import sys
+sys.path.append('/home/yepeng_liu/code_python/laiwenpeng/LiftFeat')
+from utils.featurebooster import FeatureBooster
+from utils.config import featureboost_config
+
+# from models.model_dfb import LiftFeatModel
+# from models.interpolator import InterpolateSparse2d
+# from third_party.config import featureboost_config
+
+"""
+foundational functions
+"""
+def simple_nms(scores, radius):
+    """Perform non maximum suppression on the heatmap using max-pooling.
+    This method does not suppress contiguous points that have the same score.
+    Args:
+        scores: the score heatmap of size `(B, H, W)`.
+        radius: an integer scalar, the radius of the NMS window.
+    """
+
+    def max_pool(x):
+        return torch.nn.functional.max_pool2d(
+            x, kernel_size=radius * 2 + 1, stride=1, padding=radius
+        )
+
+    zeros = torch.zeros_like(scores)
+    max_mask = scores == max_pool(scores)
+    for _ in range(2):
+        supp_mask = max_pool(max_mask.float()) > 0
+        supp_scores = torch.where(supp_mask, zeros, scores)
+        new_max_mask = supp_scores == max_pool(supp_scores)
+        max_mask = max_mask | (new_max_mask & (~supp_mask))
+    return torch.where(max_mask, scores, zeros)
+
+
+def top_k_keypoints(keypoints, scores, k):
+    if k >= len(keypoints):
+        return keypoints, scores
+    scores, indices = torch.topk(scores, k, dim=0, sorted=True)
+    return keypoints[indices], scores
+
+
+def sample_k_keypoints(keypoints, scores, k):
+    if k >= len(keypoints):
+        return keypoints, scores
+    indices = torch.multinomial(scores, k, replacement=False)
+    return keypoints[indices], scores[indices]
+
+
+def soft_argmax_refinement(keypoints, scores, radius: int):
+    width = 2 * radius + 1
+    sum_ = torch.nn.functional.avg_pool2d(
+        scores[:, None], width, 1, radius, divisor_override=1
+    )
+    ar = torch.arange(-radius, radius + 1).to(scores)
+    kernel_x = ar[None].expand(width, -1)[None, None]
+    dx = torch.nn.functional.conv2d(scores[:, None], kernel_x, padding=radius)
+    dy = torch.nn.functional.conv2d(
+        scores[:, None], kernel_x.transpose(2, 3), padding=radius
+    )
+    dydx = torch.stack([dy[:, 0], dx[:, 0]], -1) / sum_[:, 0, :, :, None]
+    refined_keypoints = []
+    for i, kpts in enumerate(keypoints):
+        delta = dydx[i][tuple(kpts.t())]
+        refined_keypoints.append(kpts.float() + delta)
+    return refined_keypoints
+
+
+# Legacy (broken) sampling of the descriptors
+def sample_descriptors(keypoints, descriptors, s):
+    b, c, h, w = descriptors.shape
+    keypoints = keypoints - s / 2 + 0.5
+    keypoints /= torch.tensor(
+        [(w * s - s / 2 - 0.5), (h * s - s / 2 - 0.5)],
+    ).to(
+        keypoints
+    )[None]
+    keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
+    args = {"align_corners": True} if torch.__version__ >= "1.3" else {}
+    descriptors = torch.nn.functional.grid_sample(
+        descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", **args
+    )
+    descriptors = torch.nn.functional.normalize(
+        descriptors.reshape(b, c, -1), p=2, dim=1
+    )
+    return descriptors
+
+
+# The original keypoint sampling is incorrect. We patch it here but
+# keep the original one above for legacy.
+def sample_descriptors_fix_sampling(keypoints, descriptors, s: int = 8):
+    """Interpolate descriptors at keypoint locations"""
+    b, c, h, w = descriptors.shape
+    keypoints = keypoints / (keypoints.new_tensor([w, h]) * s)
+    keypoints = keypoints * 2 - 1  # normalize to (-1, 1)
+    descriptors = torch.nn.functional.grid_sample(
+        descriptors, keypoints.view(b, 1, -1, 2), mode="bilinear", align_corners=False
+    )
+    descriptors = torch.nn.functional.normalize(
+        descriptors.reshape(b, c, -1), p=2, dim=1
+    )
+    return descriptors
+
+
+class UpsampleLayer(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        # Feature extraction layer: halve the channel count while strengthening feature extraction
+        self.conv = nn.Conv2d(in_channels, in_channels//2, kernel_size=3, stride=1, padding=1)
+        # Batch normalization
+        self.bn = nn.BatchNorm2d(in_channels//2)
+        # LeakyReLU activation
+        self.leaky_relu = nn.LeakyReLU(0.1)
+
+    def forward(self, x):
+        x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
+        x = self.leaky_relu(self.bn(self.conv(x)))
+
+        return x
+
+
+class KeypointHead(nn.Module):
+    def __init__(self,in_channels,out_channels):
+        super().__init__()
+        self.layer1=BaseLayer(in_channels,32)
+        self.layer2=BaseLayer(32,32)
+        self.layer3=BaseLayer(32,64)
+        self.layer4=BaseLayer(64,64)
+        self.layer5=BaseLayer(64,128)
+
+        self.conv=nn.Conv2d(128,out_channels,kernel_size=3,stride=1,padding=1)
+        self.bn=nn.BatchNorm2d(65)
+
+    def forward(self,x):
+        x=self.layer1(x)
+        x=self.layer2(x)
+        x=self.layer3(x)
+        x=self.layer4(x)
+        x=self.layer5(x)
+        x=self.bn(self.conv(x))
+        return x
+
+
+class DescriptorHead(nn.Module):
+    def __init__(self,in_channels,out_channels):
+        super().__init__()
+        self.layer=nn.Sequential(
+            BaseLayer(in_channels,32),
+            BaseLayer(32,32,activation=False),
+            BaseLayer(32,64,activation=False),
+            BaseLayer(64,out_channels,activation=False)
+        )
+
+    def forward(self,x):
+        x=self.layer(x)
+        # x=nn.functional.softmax(x,dim=1)
+        return x
+
+
+class HeatmapHead(nn.Module):
+    def __init__(self,in_channels,mid_channels,out_channels):
+        super().__init__()
+        self.convHa = nn.Conv2d(in_channels, mid_channels, kernel_size=3, stride=1, padding=1)
+        self.bnHa = nn.BatchNorm2d(mid_channels)
+        self.convHb = nn.Conv2d(mid_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.bnHb = nn.BatchNorm2d(out_channels)
+        self.leaky_relu = nn.LeakyReLU(0.1)
+
+    def forward(self,x):
+        x = self.leaky_relu(self.bnHa(self.convHa(x)))
+        x = self.leaky_relu(self.bnHb(self.convHb(x)))
+
+        x = torch.sigmoid(x)
+        return x
+
+
+class DepthHead(nn.Module):
+    def __init__(self, in_channels):
+        super().__init__()
+        self.upsampleDa = UpsampleLayer(in_channels)
+        self.upsampleDb = UpsampleLayer(in_channels//2)
+        self.upsampleDc = UpsampleLayer(in_channels//4)
+
+        self.convDepa = nn.Conv2d(in_channels//2+in_channels, in_channels//2, kernel_size=3, stride=1, padding=1)
|
200 |
+
self.bnDepa = nn.BatchNorm2d(in_channels//2)
|
201 |
+
self.convDepb = nn.Conv2d(in_channels//4+in_channels//2, in_channels//4, kernel_size=3, stride=1, padding=1)
|
202 |
+
self.bnDepb = nn.BatchNorm2d(in_channels//4)
|
203 |
+
self.convDepc = nn.Conv2d(in_channels//8+in_channels//4, 3, kernel_size=3, stride=1, padding=1)
|
204 |
+
self.bnDepc = nn.BatchNorm2d(3)
|
205 |
+
|
206 |
+
self.leaky_relu = nn.LeakyReLU(0.1)
|
207 |
+
|
208 |
+
def forward(self, x):
|
209 |
+
x0 = F.interpolate(x, scale_factor=2,mode='bilinear',align_corners=False)
|
210 |
+
x1 = self.upsampleDa(x)
|
211 |
+
x1 = torch.cat([x0,x1],dim=1)
|
212 |
+
x1 = self.leaky_relu(self.bnDepa(self.convDepa(x1)))
|
213 |
+
|
214 |
+
x1_0 = F.interpolate(x1,scale_factor=2,mode='bilinear',align_corners=False)
|
215 |
+
x2 = self.upsampleDb(x1)
|
216 |
+
x2 = torch.cat([x1_0,x2],dim=1)
|
217 |
+
x2 = self.leaky_relu(self.bnDepb(self.convDepb(x2)))
|
218 |
+
|
219 |
+
x2_0 = F.interpolate(x2,scale_factor=2,mode='bilinear',align_corners=False)
|
220 |
+
x3 = self.upsampleDc(x2)
|
221 |
+
x3 = torch.cat([x2_0,x3],dim=1)
|
222 |
+
x = self.leaky_relu(self.bnDepc(self.convDepc(x3)))
|
223 |
+
|
224 |
+
x = F.normalize(x,p=2,dim=1)
|
225 |
+
return x
|
226 |
+
|
227 |
+
|
228 |
+
class BaseLayer(nn.Module):
|
229 |
+
def __init__(self,in_channels,out_channels,kernel_size=3,stride=1,padding=1,bias=False,activation=True):
|
230 |
+
super().__init__()
|
231 |
+
if activation:
|
232 |
+
self.layer=nn.Sequential(
|
233 |
+
nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=bias),
|
234 |
+
nn.BatchNorm2d(out_channels,affine=False),
|
235 |
+
nn.ReLU(inplace=True)
|
236 |
+
)
|
237 |
+
else:
|
238 |
+
self.layer=nn.Sequential(
|
239 |
+
nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=bias),
|
240 |
+
nn.BatchNorm2d(out_channels,affine=False)
|
241 |
+
)
|
242 |
+
|
243 |
+
def forward(self,x):
|
244 |
+
return self.layer(x)
|
245 |
+
|
246 |
+
|
247 |
+
class LiftFeatSPModel(nn.Module):
|
248 |
+
default_conf = {
|
249 |
+
"has_detector": True,
|
250 |
+
"has_descriptor": True,
|
251 |
+
"descriptor_dim": 64,
|
252 |
+
# Inference
|
253 |
+
"sparse_outputs": True,
|
254 |
+
"dense_outputs": False,
|
255 |
+
"nms_radius": 4,
|
256 |
+
"refinement_radius": 0,
|
257 |
+
"detection_threshold": 0.005,
|
258 |
+
"max_num_keypoints": -1,
|
259 |
+
"max_num_keypoints_val": None,
|
260 |
+
"force_num_keypoints": False,
|
261 |
+
"randomize_keypoints_training": False,
|
262 |
+
"remove_borders": 4,
|
263 |
+
"legacy_sampling": True, # True to use the old broken sampling
|
264 |
+
}
|
265 |
+
|
266 |
+
def __init__(self, featureboost_config, use_kenc=False, use_normal=True, use_cross=True):
|
267 |
+
super().__init__()
|
268 |
+
self.device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
269 |
+
self.descriptor_dim = 64
|
270 |
+
|
271 |
+
self.norm = nn.InstanceNorm2d(1)
|
272 |
+
|
273 |
+
self.relu = nn.ReLU(inplace=True)
|
274 |
+
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
275 |
+
c1,c2,c3,c4,c5 = 24,24,64,64,128
|
276 |
+
|
277 |
+
self.conv1a = nn.Conv2d(1, c1, kernel_size=3, stride=1, padding=1)
|
278 |
+
self.conv1b = nn.Conv2d(c1, c1, kernel_size=3, stride=1, padding=1)
|
279 |
+
self.conv2a = nn.Conv2d(c1, c2, kernel_size=3, stride=1, padding=1)
|
280 |
+
self.conv2b = nn.Conv2d(c2, c2, kernel_size=3, stride=1, padding=1)
|
281 |
+
self.conv3a = nn.Conv2d(c2, c3, kernel_size=3, stride=1, padding=1)
|
282 |
+
self.conv3b = nn.Conv2d(c3, c3, kernel_size=3, stride=1, padding=1)
|
283 |
+
self.conv4a = nn.Conv2d(c3, c4, kernel_size=3, stride=1, padding=1)
|
284 |
+
self.conv4b = nn.Conv2d(c4, c4, kernel_size=3, stride=1, padding=1)
|
285 |
+
self.conv5a = nn.Conv2d(c4, c5, kernel_size=3, stride=1, padding=1)
|
286 |
+
self.conv5b = nn.Conv2d(c5, c5, kernel_size=3, stride=1, padding=1)
|
287 |
+
|
288 |
+
self.upsample4 = UpsampleLayer(c4)
|
289 |
+
self.upsample5 = UpsampleLayer(c5)
|
290 |
+
self.conv_fusion45 = nn.Conv2d(c5//2+c4,c4,kernel_size=3,stride=1,padding=1)
|
291 |
+
self.conv_fusion34 = nn.Conv2d(c4//2+c3,c3,kernel_size=3,stride=1,padding=1)
|
292 |
+
|
293 |
+
# detector
|
294 |
+
self.keypoint_head = KeypointHead(in_channels=c3,out_channels=65)
|
295 |
+
# descriptor
|
296 |
+
self.descriptor_head = DescriptorHead(in_channels=c3,out_channels=self.descriptor_dim)
|
297 |
+
# # heatmap
|
298 |
+
# self.heatmap_head = HeatmapHead(in_channels=c3,mid_channels=c3,out_channels=1)
|
299 |
+
# depth
|
300 |
+
self.depth_head = DepthHead(c3)
|
301 |
+
|
302 |
+
self.fine_matcher = nn.Sequential(
|
303 |
+
nn.Linear(128, 512),
|
304 |
+
nn.BatchNorm1d(512, affine=False),
|
305 |
+
nn.ReLU(inplace = True),
|
306 |
+
nn.Linear(512, 512),
|
307 |
+
nn.BatchNorm1d(512, affine=False),
|
308 |
+
nn.ReLU(inplace = True),
|
309 |
+
nn.Linear(512, 512),
|
310 |
+
nn.BatchNorm1d(512, affine=False),
|
311 |
+
nn.ReLU(inplace = True),
|
312 |
+
nn.Linear(512, 512),
|
313 |
+
nn.BatchNorm1d(512, affine=False),
|
314 |
+
nn.ReLU(inplace = True),
|
315 |
+
nn.Linear(512, 64),
|
316 |
+
)
|
317 |
+
|
318 |
+
# feature_booster
|
319 |
+
self.feature_boost = FeatureBooster(featureboost_config, use_kenc=use_kenc, use_cross=use_cross, use_normal=use_normal)
|
320 |
+
|
321 |
+
def feature_extract(self, x):
|
322 |
+
x1 = self.relu(self.conv1a(x))
|
323 |
+
x1 = self.relu(self.conv1b(x1))
|
324 |
+
x1 = self.pool(x1)
|
325 |
+
x2 = self.relu(self.conv2a(x1))
|
326 |
+
x2 = self.relu(self.conv2b(x2))
|
327 |
+
x2 = self.pool(x2)
|
328 |
+
x3 = self.relu(self.conv3a(x2))
|
329 |
+
x3 = self.relu(self.conv3b(x3))
|
330 |
+
x3 = self.pool(x3)
|
331 |
+
x4 = self.relu(self.conv4a(x3))
|
332 |
+
x4 = self.relu(self.conv4b(x4))
|
333 |
+
x4 = self.pool(x4)
|
334 |
+
x5 = self.relu(self.conv5a(x4))
|
335 |
+
x5 = self.relu(self.conv5b(x5))
|
336 |
+
x5 = self.pool(x5)
|
337 |
+
return x3,x4,x5
|
338 |
+
|
339 |
+
def fuse_multi_features(self,x3,x4,x5):
|
340 |
+
# upsample x5 feature
|
341 |
+
x5 = self.upsample5(x5)
|
342 |
+
x4 = torch.cat([x4,x5],dim=1)
|
343 |
+
x4 = self.conv_fusion45(x4)
|
344 |
+
|
345 |
+
# upsample x4 feature
|
346 |
+
x4 = self.upsample4(x4)
|
347 |
+
x3 = torch.cat([x3,x4],dim=1)
|
348 |
+
x = self.conv_fusion34(x3)
|
349 |
+
return x
|
350 |
+
|
351 |
+
def _unfold2d(self, x, ws = 2):
|
352 |
+
"""
|
353 |
+
Unfolds tensor in 2D with desired ws (window size) and concat the channels
|
354 |
+
"""
|
355 |
+
B, C, H, W = x.shape
|
356 |
+
x = x.unfold(2, ws , ws).unfold(3, ws,ws).reshape(B, C, H//ws, W//ws, ws**2)
|
357 |
+
return x.permute(0, 1, 4, 2, 3).reshape(B, -1, H//ws, W//ws)
|
358 |
+
|
359 |
+
|
360 |
+
def forward1(self, x):
|
361 |
+
"""
|
362 |
+
input:
|
363 |
+
x -> torch.Tensor(B, C, H, W) grayscale or rgb images
|
364 |
+
return:
|
365 |
+
feats -> torch.Tensor(B, 64, H/8, W/8) dense local features
|
366 |
+
keypoints -> torch.Tensor(B, 65, H/8, W/8) keypoint logit map
|
367 |
+
heatmap -> torch.Tensor(B, 1, H/8, W/8) reliability map
|
368 |
+
|
369 |
+
"""
|
370 |
+
with torch.no_grad():
|
371 |
+
x = x.mean(dim=1, keepdim = True)
|
372 |
+
x = self.norm(x)
|
373 |
+
|
374 |
+
x3,x4,x5 = self.feature_extract(x)
|
375 |
+
|
376 |
+
# features fusion
|
377 |
+
x = self.fuse_multi_features(x3,x4,x5)
|
378 |
+
|
379 |
+
# keypoint
|
380 |
+
keypoint_map = self.keypoint_head(x)
|
381 |
+
# descriptor
|
382 |
+
des_map = self.descriptor_head(x)
|
383 |
+
# # heatmap
|
384 |
+
# heatmap = self.heatmap_head(x)
|
385 |
+
|
386 |
+
# import pdb;pdb.set_trace()
|
387 |
+
# depth
|
388 |
+
d_feats = self.depth_head(x)
|
389 |
+
|
390 |
+
return des_map, keypoint_map, d_feats
|
391 |
+
# return des_map, keypoint_map, heatmap, d_feats
|
392 |
+
|
393 |
+
def forward2(self, descs, kpts, normals):
|
394 |
+
# import pdb;pdb.set_trace()
|
395 |
+
normals_feat=self._unfold2d(normals, ws=8)
|
396 |
+
normals_v=normals_feat.squeeze(0).permute(1,2,0).reshape(-1,normals_feat.shape[1])
|
397 |
+
descs_v=descs.squeeze(0).permute(1,2,0).reshape(-1,descs.shape[1])
|
398 |
+
kpts_v=kpts.squeeze(0).permute(1,2,0).reshape(-1,kpts.shape[1])
|
399 |
+
descs_refine = self.feature_boost(descs_v, kpts_v, normals_v)
|
400 |
+
return descs_refine
|
401 |
+
|
402 |
+
def forward(self,x):
|
403 |
+
M1,K1,D1=self.forward1(x)
|
404 |
+
descs_refine=self.forward2(M1,K1,D1)
|
405 |
+
return descs_refine,M1,K1,D1
|
406 |
+
|
407 |
+
|
408 |
+
if __name__ == "__main__":
|
409 |
+
img_path=os.path.join(os.path.dirname(__file__),'../assert/ref.jpg')
|
410 |
+
img=cv2.imread(img_path,cv2.IMREAD_GRAYSCALE)
|
411 |
+
img=cv2.resize(img,(800,608))
|
412 |
+
import pdb;pdb.set_trace()
|
413 |
+
img=torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()/255.0
|
414 |
+
img=img.cuda() if torch.cuda.is_available() else img
|
415 |
+
liftfeat_sp=LiftFeatSPModel(featureboost_config).to(torch.device('cuda' if torch.cuda.is_available() else 'cpu'))
|
416 |
+
des_map, keypoint_map, d_feats=liftfeat_sp.forward1(img)
|
417 |
+
des_fine=liftfeat_sp.forward2(des_map,keypoint_map,d_feats)
|
418 |
+
print(des_map.shape)
|
419 |
+
|
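
For reference, a minimal inference sketch of the model defined above (illustrative only, not part of the commit; it assumes the LiftFeat repo root is on PYTHONPATH so that `models.model` and `utils.config` resolve, and an input whose sides are divisible by 8):

import torch

from models.model import LiftFeatSPModel
from utils.config import featureboost_config

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = LiftFeatSPModel(featureboost_config).to(device).eval()

# Grayscale batch at the training resolution (H=608, W=800)
img = torch.rand(1, 1, 608, 800, device=device)
with torch.no_grad():
    des_map, keypoint_map, d_feats = model.forward1(img)        # dense maps
    des_fine = model.forward2(des_map, keypoint_map, d_feats)   # boosted descriptors

print(des_map.shape)       # torch.Size([1, 64, 76, 100])  descriptors at 1/8 resolution
print(keypoint_map.shape)  # torch.Size([1, 65, 76, 100])  keypoint logits
print(d_feats.shape)       # torch.Size([1, 3, 608, 800])  normal features at full resolution
print(des_fine.shape)      # torch.Size([7600, 64])        one refined descriptor per 8x8 cell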
imcui/third_party/LiftFeat/requirements.txt
ADDED
@@ -0,0 +1,18 @@
torch==1.13.1
torchvision==0.14.1
einops==0.8.0
kornia==0.7.3
timm==1.0.15
albumentations==1.4.12
imgaug==0.4.0
opencv-python==4.10.0.84
matplotlib==3.7.5
numpy==1.24.4
scikit-image==0.21.0
scipy==1.10.1
pillow==10.3.0
tensorboard==2.14.0
tqdm==4.66.4
omegaconf==2.3.0
thop==0.1.1.post2209072238
poselib
imcui/third_party/LiftFeat/train.py
ADDED
@@ -0,0 +1,365 @@
"""
"LiftFeat: 3D Geometry-Aware Local Feature Matching"
training script
"""

import argparse
import os
import time
import sys
sys.path.append(os.path.dirname(__file__))

def parse_arguments():
    parser = argparse.ArgumentParser(description="LiftFeat training script.")
    parser.add_argument('--name', type=str, default='LiftFeat', help='set process name')

    # MegaDepth dataset setting
    parser.add_argument('--use_megadepth', action='store_true')
    parser.add_argument('--megadepth_root_path', type=str,
                        default='/home/yepeng_liu/code_python/dataset/MegaDepth/phoenix/S6/zl548',
                        help='Path to the MegaDepth dataset root directory.')
    parser.add_argument('--megadepth_batch_size', type=int, default=6)

    # COCO20k dataset setting
    parser.add_argument('--use_coco', action='store_true')
    parser.add_argument('--coco_root_path', type=str, default='/home/yepeng_liu/code_python/dataset/coco_20k',
                        help='Path to the COCO20k dataset root directory.')
    parser.add_argument('--coco_batch_size', type=int, default=4)

    parser.add_argument('--ckpt_save_path', type=str, default='/home/yepeng_liu/code_python/LiftFeat/trained_weights/test',
                        help='Path to save the checkpoints.')
    parser.add_argument('--n_steps', type=int, default=160_000,
                        help='Number of training steps. Default is 160000.')
    parser.add_argument('--lr', type=float, default=3e-4,
                        help='Learning rate. Default is 0.0003.')
    parser.add_argument('--gamma_steplr', type=float, default=0.5,
                        help='Gamma value for StepLR scheduler. Default is 0.5.')
    parser.add_argument('--training_res', type=lambda s: tuple(map(int, s.split(','))),
                        default=(800, 608), help='Training resolution as width,height. Default is (800, 608).')
    parser.add_argument('--device_num', type=str, default='0',
                        help='Device number to use for training. Default is "0".')
    parser.add_argument('--dry_run', action='store_true',
                        help='If set, perform a dry run training with a mini-batch for sanity check.')
    parser.add_argument('--save_ckpt_every', type=int, default=500,
                        help='Save checkpoints every N steps. Default is 500.')

    args = parser.parse_args()

    os.environ['CUDA_VISIBLE_DEVICES'] = args.device_num

    return args

args = parse_arguments()

import torch
from torch import nn
from torch import optim
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import Dataset, DataLoader

import numpy as np
import tqdm
import glob

from models.model import LiftFeatSPModel
from loss.loss import LiftFeatLoss
from utils.config import featureboost_config
from models.interpolator import InterpolateSparse2d
from utils.depth_anything_wrapper import DepthAnythingExtractor
from utils.alike_wrapper import ALikeExtractor

from dataset import megadepth_wrapper
from dataset import coco_wrapper
from dataset.megadepth import MegaDepthDataset
from dataset.coco_augmentor import COCOAugmentor

import setproctitle


class Trainer():
    def __init__(self, megadepth_root_path, use_megadepth, megadepth_batch_size,
                 coco_root_path, use_coco, coco_batch_size,
                 ckpt_save_path,
                 model_name='LiftFeat',
                 n_steps=160_000, lr=3e-4, gamma_steplr=0.5,
                 training_res=(800, 608), device_num="0", dry_run=False,
                 save_ckpt_every=500):
        print(f'MegaDepth: {use_megadepth}-{megadepth_batch_size}')
        print(f'COCO20k: {use_coco}-{coco_batch_size}')
        self.dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # training model
        self.net = LiftFeatSPModel(featureboost_config, use_kenc=False, use_normal=True, use_cross=True).to(self.dev)
        self.loss_fn = LiftFeatLoss(self.dev, lam_descs=1, lam_kpts=2, lam_heatmap=1)

        # depth-anything model
        self.depth_net = DepthAnythingExtractor('vits', self.dev, 256)

        # alike model
        self.alike_net = ALikeExtractor('alike-t', self.dev)

        # Setup optimizer
        self.steps = n_steps
        self.opt = optim.Adam(filter(lambda x: x.requires_grad, self.net.parameters()), lr=lr)
        self.scheduler = torch.optim.lr_scheduler.StepLR(self.opt, step_size=10_000, gamma=gamma_steplr)

        ##################### COCO INIT ##########################
        self.use_coco = use_coco
        self.coco_batch_size = coco_batch_size
        if self.use_coco:
            self.augmentor = COCOAugmentor(
                img_dir=coco_root_path,
                device=self.dev, load_dataset=True,
                batch_size=self.coco_batch_size,
                out_resolution=training_res,
                warp_resolution=training_res,
                sides_crop=0.1,
                max_num_imgs=3000,
                num_test_imgs=5,
                photometric=True,
                geometric=True,
                reload_step=4000
            )
        ##################### COCO END #######################

        ##################### MEGADEPTH INIT ##########################
        self.use_megadepth = use_megadepth
        self.megadepth_batch_size = megadepth_batch_size
        if self.use_megadepth:
            TRAIN_BASE_PATH = f"{megadepth_root_path}/train_data/megadepth_indices"
            TRAINVAL_DATA_SOURCE = f"{megadepth_root_path}/MegaDepth_v1"

            TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7"

            npz_paths = glob.glob(TRAIN_NPZ_ROOT + '/*.npz')[:]
            megadepth_dataset = torch.utils.data.ConcatDataset(
                [MegaDepthDataset(root_dir=TRAINVAL_DATA_SOURCE, npz_path=path)
                 for path in tqdm.tqdm(npz_paths, desc="[MegaDepth] Loading metadata")])

            self.megadepth_dataloader = DataLoader(megadepth_dataset, batch_size=megadepth_batch_size, shuffle=True)
            self.megadepth_data_iter = iter(self.megadepth_dataloader)
        ##################### MEGADEPTH INIT END #######################

        os.makedirs(ckpt_save_path, exist_ok=True)
        os.makedirs(ckpt_save_path + '/logdir', exist_ok=True)

        self.dry_run = dry_run
        self.save_ckpt_every = save_ckpt_every
        self.ckpt_save_path = ckpt_save_path
        self.writer = SummaryWriter(ckpt_save_path + f'/logdir/{model_name}_' + time.strftime("%Y_%m_%d-%H_%M_%S"))
        self.model_name = model_name

    def generate_train_data(self):
        imgs1_t, imgs2_t = [], []
        imgs1_np, imgs2_np = [], []
        positives_coarse = []

        if self.use_coco:
            coco_imgs1, coco_imgs2, H1, H2 = coco_wrapper.make_batch(self.augmentor, 0.1)
            h_coarse, w_coarse = coco_imgs1[0].shape[-2] // 8, coco_imgs1[0].shape[-1] // 8
            _, positives_coco_coarse = coco_wrapper.get_corresponding_pts(coco_imgs1, coco_imgs2, H1, H2, self.augmentor, h_coarse, w_coarse)
            coco_imgs1 = coco_imgs1.mean(1, keepdim=True)
            coco_imgs2 = coco_imgs2.mean(1, keepdim=True)
            imgs1_t.append(coco_imgs1)
            imgs2_t.append(coco_imgs2)
            positives_coarse += positives_coco_coarse

        if self.use_megadepth:
            try:
                megadepth_data = next(self.megadepth_data_iter)
            except StopIteration:
                print('End of MD DATASET')
                self.megadepth_data_iter = iter(self.megadepth_dataloader)
                megadepth_data = next(self.megadepth_data_iter)
            if megadepth_data is not None:
                for k in megadepth_data.keys():
                    if isinstance(megadepth_data[k], torch.Tensor):
                        megadepth_data[k] = megadepth_data[k].to(self.dev)
                megadepth_imgs1_t, megadepth_imgs2_t = megadepth_data['image0'], megadepth_data['image1']
                megadepth_imgs1_t = megadepth_imgs1_t.mean(1, keepdim=True)
                megadepth_imgs2_t = megadepth_imgs2_t.mean(1, keepdim=True)
                imgs1_t.append(megadepth_imgs1_t)
                imgs2_t.append(megadepth_imgs2_t)
                megadepth_imgs1_np, megadepth_imgs2_np = megadepth_data['image0_np'], megadepth_data['image1_np']
                for np_idx in range(megadepth_imgs1_np.shape[0]):
                    img1_np = megadepth_imgs1_np[np_idx].squeeze(0).cpu().numpy()
                    img2_np = megadepth_imgs2_np[np_idx].squeeze(0).cpu().numpy()
                    imgs1_np.append(img1_np)
                    imgs2_np.append(img2_np)
                positives_megadepth_coarse = megadepth_wrapper.spvs_coarse(megadepth_data, 8)
                positives_coarse += positives_megadepth_coarse

        with torch.no_grad():
            imgs1_t = torch.cat(imgs1_t, dim=0)
            imgs2_t = torch.cat(imgs2_t, dim=0)

        return imgs1_t, imgs2_t, imgs1_np, imgs2_np, positives_coarse

    def train(self):
        self.net.train()

        with tqdm.tqdm(total=self.steps) as pbar:
            for i in range(self.steps):
                imgs1_t, imgs2_t, imgs1_np, imgs2_np, positives_coarse = self.generate_train_data()

                # Check if batch is corrupted with too few correspondences
                is_corrupted = False
                for p in positives_coarse:
                    if len(p) < 30:
                        is_corrupted = True

                if is_corrupted:
                    continue

                # Forward pass
                feats1, kpts1, normals1 = self.net.forward1(imgs1_t)
                feats2, kpts2, normals2 = self.net.forward1(imgs2_t)

                coordinates, fb_coordinates = [], []
                alike_kpts1, alike_kpts2 = [], []
                DA_normals1, DA_normals2 = [], []

                fb_feats1, fb_feats2 = [], []
                for b in range(feats1.shape[0]):
                    feat1 = feats1[b].permute(1, 2, 0).reshape(-1, feats1.shape[1])
                    feat2 = feats2[b].permute(1, 2, 0).reshape(-1, feats2.shape[1])

                    coordinate = self.net.fine_matcher(torch.cat([feat1, feat2], dim=-1))
                    coordinates.append(coordinate)

                    fb_feat1 = self.net.forward2(feats1[b].unsqueeze(0), kpts1[b].unsqueeze(0), normals1[b].unsqueeze(0))
                    fb_feat2 = self.net.forward2(feats2[b].unsqueeze(0), kpts2[b].unsqueeze(0), normals2[b].unsqueeze(0))

                    fb_coordinate = self.net.fine_matcher(torch.cat([fb_feat1, fb_feat2], dim=-1))
                    fb_coordinates.append(fb_coordinate)

                    fb_feats1.append(fb_feat1.unsqueeze(0))
                    fb_feats2.append(fb_feat2.unsqueeze(0))

                    img1, img2 = imgs1_t[b], imgs2_t[b]
                    img1 = img1.permute(1, 2, 0).expand(-1, -1, 3).cpu().numpy() * 255
                    img2 = img2.permute(1, 2, 0).expand(-1, -1, 3).cpu().numpy() * 255
                    alike_kpt1 = torch.tensor(self.alike_net.extract_alike_kpts(img1), device=self.dev)
                    alike_kpt2 = torch.tensor(self.alike_net.extract_alike_kpts(img2), device=self.dev)
                    alike_kpts1.append(alike_kpt1)
                    alike_kpts2.append(alike_kpt2)

                for b in range(len(imgs1_np)):
                    megadepth_depth1, megadepth_norm1 = self.depth_net.extract(imgs1_np[b])
                    megadepth_depth2, megadepth_norm2 = self.depth_net.extract(imgs2_np[b])
                    DA_normals1.append(megadepth_norm1)
                    DA_normals2.append(megadepth_norm2)

                fb_feats1 = torch.cat(fb_feats1, dim=0)
                fb_feats2 = torch.cat(fb_feats2, dim=0)
                fb_feats1 = fb_feats1.reshape(feats1.shape[0], feats1.shape[2], feats1.shape[3], -1).permute(0, 3, 1, 2)
                fb_feats2 = fb_feats2.reshape(feats2.shape[0], feats2.shape[2], feats2.shape[3], -1).permute(0, 3, 1, 2)

                coordinates = torch.cat(coordinates, dim=0)
                coordinates = coordinates.reshape(feats1.shape[0], feats1.shape[2], feats1.shape[3], -1).permute(0, 3, 1, 2)

                fb_coordinates = torch.cat(fb_coordinates, dim=0)
                fb_coordinates = fb_coordinates.reshape(feats1.shape[0], feats1.shape[2], feats1.shape[3], -1).permute(0, 3, 1, 2)

                loss_items = []

                loss_info = self.loss_fn(
                    feats1, fb_feats1, kpts1, normals1,
                    feats2, fb_feats2, kpts2, normals2,
                    positives_coarse,
                    coordinates, fb_coordinates,
                    alike_kpts1, alike_kpts2,
                    DA_normals1, DA_normals2,
                    self.megadepth_batch_size, self.coco_batch_size)

                loss_descs, acc_coarse = loss_info['loss_descs'], loss_info['acc_coarse']
                loss_coordinates, acc_coordinates = loss_info['loss_coordinates'], loss_info['acc_coordinates']
                loss_fb_descs, acc_fb_coarse = loss_info['loss_fb_descs'], loss_info['acc_fb_coarse']
                loss_fb_coordinates, acc_fb_coordinates = loss_info['loss_fb_coordinates'], loss_info['acc_fb_coordinates']
                loss_kpts, acc_kpt = loss_info['loss_kpts'], loss_info['acc_kpt']
                loss_normals = loss_info['loss_normals']

                # loss_items.append(loss_descs.unsqueeze(0))
                # loss_items.append(loss_coordinates.unsqueeze(0))
                loss_items.append(loss_fb_descs.unsqueeze(0))
                loss_items.append(loss_fb_coordinates.unsqueeze(0))
                loss_items.append(loss_kpts.unsqueeze(0))
                loss_items.append(loss_normals.unsqueeze(0))

                loss = torch.cat(loss_items, -1).mean()

                # Compute backward pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.net.parameters(), 1.)
                self.opt.step()
                self.opt.zero_grad()
                self.scheduler.step()

                if (i + 1) % self.save_ckpt_every == 0:
                    print('saving iter ', i + 1)
                    torch.save(self.net.state_dict(), self.ckpt_save_path + f'/{self.model_name}_{i+1}.pth')

                pbar.set_description(
                    'Loss: {:.4f} '
                    'loss_descs: {:.3f} acc_coarse: {:.3f} '
                    'loss_coordinates: {:.3f} acc_coordinates: {:.3f} '
                    'loss_fb_descs: {:.3f} acc_fb_coarse: {:.3f} '
                    'loss_fb_coordinates: {:.3f} acc_fb_coordinates: {:.3f} '
                    'loss_kpts: {:.3f} acc_kpts: {:.3f} '
                    'loss_normals: {:.3f}'.format(
                        loss.item(),
                        loss_descs.item(), acc_coarse,
                        loss_coordinates.item(), acc_coordinates,
                        loss_fb_descs.item(), acc_fb_coarse,
                        loss_fb_coordinates.item(), acc_fb_coordinates,
                        loss_kpts.item(), acc_kpt,
                        loss_normals.item()))
                pbar.update(1)

                # Log metrics
                self.writer.add_scalar('Loss/total', loss.item(), i)
                self.writer.add_scalar('Accuracy/acc_coarse', acc_coarse, i)
                self.writer.add_scalar('Accuracy/acc_coordinates', acc_coordinates, i)
                self.writer.add_scalar('Accuracy/acc_fb_coarse', acc_fb_coarse, i)
                self.writer.add_scalar('Accuracy/acc_fb_coordinates', acc_fb_coordinates, i)
                self.writer.add_scalar('Loss/descs', loss_descs.item(), i)
                self.writer.add_scalar('Loss/coordinates', loss_coordinates.item(), i)
                self.writer.add_scalar('Loss/fb_descs', loss_fb_descs.item(), i)
                self.writer.add_scalar('Loss/fb_coordinates', loss_fb_coordinates.item(), i)
                self.writer.add_scalar('Loss/kpts', loss_kpts.item(), i)
                self.writer.add_scalar('Loss/normals', loss_normals.item(), i)


if __name__ == '__main__':

    setproctitle.setproctitle(args.name)

    trainer = Trainer(
        megadepth_root_path=args.megadepth_root_path,
        use_megadepth=args.use_megadepth,
        megadepth_batch_size=args.megadepth_batch_size,
        coco_root_path=args.coco_root_path,
        use_coco=args.use_coco,
        coco_batch_size=args.coco_batch_size,
        ckpt_save_path=args.ckpt_save_path,
        n_steps=args.n_steps,
        lr=args.lr,
        gamma_steplr=args.gamma_steplr,
        training_res=args.training_res,
        device_num=args.device_num,
        dry_run=args.dry_run,
        save_ckpt_every=args.save_ckpt_every
    )

    # The most fun part
    trainer.train()
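
A note on the optimization schedule wired above: Adam starts at 3e-4 and StepLR(step_size=10_000, gamma=0.5) halves the learning rate every 10k steps across the 160k-step run, i.e. lr(i) = 3e-4 * 0.5 ** (i // 10_000). A quick, illustrative sanity check:

def lr_at_step(i, base_lr=3e-4, step_size=10_000, gamma=0.5):
    # Mirrors torch.optim.lr_scheduler.StepLR: decay by `gamma` once per `step_size` steps
    return base_lr * gamma ** (i // step_size)

assert lr_at_step(0) == 3e-4                    # first 10k steps run at the base rate
assert lr_at_step(159_999) == 3e-4 * 0.5 ** 15  # rate during the final 10k steps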
imcui/third_party/LiftFeat/train.sh
ADDED
@@ -0,0 +1,11 @@
# default training
nohup python /home/yepeng_liu/code_python/LiftFeat/train.py \
    --name LiftFeat_test \
    --ckpt_save_path /home/yepeng_liu/code_python/LiftFeat/trained_weights/test \
    --device_num 1 \
    --use_megadepth \
    --megadepth_batch_size 8 \
    --use_coco \
    --coco_batch_size 4 \
    --save_ckpt_every 1000 \
    > /home/yepeng_liu/code_python/LiftFeat/trained_weights/test/training.log 2>&1 &
imcui/third_party/LiftFeat/utils/__init__.py
ADDED
File without changes
imcui/third_party/LiftFeat/utils/alike_wrapper.py
ADDED
@@ -0,0 +1,45 @@
"""
"LiftFeat: 3D Geometry-Aware Local Feature Matching"
"""


import sys
import os

ALIKE_PATH = '/home/yepeng_liu/code_python/multimodal_remote/ALIKE'
sys.path.append(ALIKE_PATH)

import torch
import torch.nn as nn
from alike import ALike
import cv2
import numpy as np

dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

configs = {
    'alike-t': {'c1': 8, 'c2': 16, 'c3': 32, 'c4': 64, 'dim': 64, 'single_head': True, 'radius': 2,
                'model_path': os.path.join(ALIKE_PATH, 'models', 'alike-t.pth')},
    'alike-s': {'c1': 8, 'c2': 16, 'c3': 48, 'c4': 96, 'dim': 96, 'single_head': True, 'radius': 2,
                'model_path': os.path.join(ALIKE_PATH, 'models', 'alike-s.pth')},
    'alike-n': {'c1': 16, 'c2': 32, 'c3': 64, 'c4': 128, 'dim': 128, 'single_head': True, 'radius': 2,
                'model_path': os.path.join(ALIKE_PATH, 'models', 'alike-n.pth')},
    'alike-l': {'c1': 32, 'c2': 64, 'c3': 128, 'c4': 128, 'dim': 128, 'single_head': False, 'radius': 2,
                'model_path': os.path.join(ALIKE_PATH, 'models', 'alike-l.pth')},
}


class ALikeExtractor(nn.Module):
    def __init__(self, model_type, device) -> None:
        super().__init__()
        self.net = ALike(**configs[model_type], device=device, top_k=4096, scores_th=0.1, n_limit=8000)

    @torch.inference_mode()
    def extract_alike_kpts(self, img):
        pred0 = self.net(img, sub_pixel=True)
        return pred0['keypoints']
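
A hedged usage sketch for the wrapper above (illustrative; it assumes the ALIKE repository and its pretrained weights are present at ALIKE_PATH, and follows the HxWx3 image convention used when this wrapper is called from train.py):

import cv2
import torch

extractor = ALikeExtractor('alike-t', torch.device('cpu'))
img = cv2.imread('assert/ref.jpg')          # HxWx3 image, as in the train.py call site
kpts = extractor.extract_alike_kpts(img)    # keypoint coordinates returned by ALike
print(len(kpts))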
imcui/third_party/LiftFeat/utils/config.py
ADDED
@@ -0,0 +1,16 @@
import os
import sys
import numpy as np

featureboost_config = {
    "keypoint_dim": 65,
    "keypoint_encoder": [128, 64, 64],
    "normal_dim": 192,
    "normal_encoder": [128, 64, 64],
    "descriptor_encoder": [64, 64],
    "descriptor_dim": 64,
    "Attentional_layers": 3,
    "last_activation": None,
    "l2_normalization": None,
    "output_dim": 64,
}
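
These dimensions are tied to the model heads: keypoint_dim 65 matches the 65-channel keypoint logit map (presumably 8x8 cell bins plus a dustbin, SuperPoint-style), and normal_dim 192 corresponds to the 3 normal components unfolded over an 8x8 window in LiftFeatSPModel.forward2 (3 * 64 = 192). An illustrative consistency check:

assert featureboost_config["keypoint_dim"] == 8 * 8 + 1  # keypoint head channels
assert featureboost_config["normal_dim"] == 3 * 8 * 8    # normals unfolded with ws=8
assert featureboost_config["descriptor_dim"] == featureboost_config["output_dim"] == 64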
imcui/third_party/LiftFeat/utils/depth_anything_wrapper.py
ADDED
@@ -0,0 +1,150 @@
import argparse
import cv2
import glob
import matplotlib
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose
import sys

sys.path.append("/home/yepeng_liu/code_python/third_repos/Depth-Anything-V2")
from depth_anything_v2.dpt_opt import DepthAnythingV2
from depth_anything_v2.util.transform import Resize, NormalizeImage, PrepareForNet

import time

VITS_MODEL_PATH = "/home/yepeng_liu/code_python/third_repos/Depth-Anything-V2/checkpoints/depth_anything_v2_vits.pth"
VITB_MODEL_PATH = "/home/yepeng_liu/code_python/third_repos/Depth-Anything-V2/checkpoints/depth_anything_v2_vitb.pth"
VITL_MODEL_PATH = "/home/yepeng_liu/code_python/third_repos/Depth-Anything-V2/checkpoints/depth_anything_v2_vitl.pth"

model_configs = {
    "vits": {"encoder": "vits", "features": 64, "out_channels": [48, 96, 192, 384]},
    "vitb": {
        "encoder": "vitb",
        "features": 128,
        "out_channels": [96, 192, 384, 768],
    },
    "vitl": {
        "encoder": "vitl",
        "features": 256,
        "out_channels": [256, 512, 1024, 1024],
    },
    "vitg": {
        "encoder": "vitg",
        "features": 384,
        "out_channels": [1536, 1536, 1536, 1536],
    },
}

class DepthAnythingExtractor(nn.Module):
    def __init__(self, encoder_type, device, input_size, process_size=(608, 800)):
        super().__init__()
        self.net = DepthAnythingV2(**model_configs[encoder_type])
        self.device = device
        if encoder_type == "vits":
            print(f"loading {VITS_MODEL_PATH}")
            self.net.load_state_dict(torch.load(VITS_MODEL_PATH, map_location="cpu"))
        elif encoder_type == "vitb":
            print(f"loading {VITB_MODEL_PATH}")
            self.net.load_state_dict(torch.load(VITB_MODEL_PATH, map_location="cpu"))
        elif encoder_type == "vitl":
            print(f"loading {VITL_MODEL_PATH}")
            self.net.load_state_dict(torch.load(VITL_MODEL_PATH, map_location="cpu"))
        else:
            raise RuntimeError("unsupported encoder type")
        self.net.to(self.device).eval()
        self.transform = Compose([
            Resize(
                width=input_size,
                height=input_size,
                resize_target=False,
                keep_aspect_ratio=True,
                ensure_multiple_of=14,
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ])
        self.process_size = process_size
        self.input_size = input_size

    @torch.inference_mode()
    def infer_image(self, img):
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0

        img = self.transform({'image': img})['image']

        img = torch.from_numpy(img).unsqueeze(0)

        img = img.to(self.device)

        with torch.no_grad():
            depth = self.net.forward(img)

        depth = F.interpolate(depth[:, None], self.process_size, mode="bilinear", align_corners=True)[0, 0]

        return depth.cpu().numpy()

    @torch.inference_mode()
    def compute_normal_map_torch(self, depth_map, scale=1.0):
        """
        Compute surface normals from a depth map (PyTorch implementation).

        Args:
            depth_map (torch.Tensor): depth map of shape (H, W)
            scale (float): scale factor applied to the depth gradients

        Returns:
            torch.Tensor: normal map of shape (H, W, 3)
        """
        if depth_map.ndim != 2:
            raise ValueError("depth_map must be a 2D tensor.")

        # Depth gradients along x and y
        dzdx = torch.diff(depth_map, dim=1, append=depth_map[:, -1:]) * scale
        dzdy = torch.diff(depth_map, dim=0, append=depth_map[-1:, :]) * scale

        # Initialize the normal map
        H, W = depth_map.shape
        normal_map = torch.zeros((H, W, 3), dtype=depth_map.dtype, device=depth_map.device)
        normal_map[:, :, 0] = -dzdx  # x component
        normal_map[:, :, 1] = -dzdy  # y component
        normal_map[:, :, 2] = 1.0    # z component

        # Normalize the normal vectors
        norm = torch.linalg.norm(normal_map, dim=2, keepdim=True)
        norm = torch.where(norm == 0, torch.tensor(1.0, device=depth_map.device), norm)  # avoid division by zero
        normal_map /= norm

        return normal_map

    @torch.inference_mode()
    def extract(self, img):
        depth = self.infer_image(img)
        depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
        depth_t = torch.from_numpy(depth).float().to(self.device)
        normal_map = self.compute_normal_map_torch(depth_t, 1.0)
        return depth_t, normal_map


if __name__ == "__main__":
    img_path = os.path.join(os.path.dirname(__file__), '../assert/ref.jpg')
    img = cv2.imread(img_path)
    img = cv2.resize(img, (800, 608))
    DAExtractor = DepthAnythingExtractor('vitb', torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu'), 256)
    depth_t, norm = DAExtractor.extract(img)
    norm = norm.cpu().numpy()
    norm = (norm + 1) / 2 * 255
    norm = norm.astype(np.uint8)
    cv2.imwrite(os.path.join(os.path.dirname(__file__), "norm.png"), norm)
    start = time.perf_counter()
    for i in range(20):
        depth_t, norm = DAExtractor.extract(img)
    end = time.perf_counter()
    print(f"cost {end-start} seconds")
imcui/third_party/LiftFeat/utils/featurebooster.py
ADDED
@@ -0,0 +1,247 @@
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F


def MLP(channels: List[int], do_bn: bool = False) -> nn.Module:
    """ Multi-layer perceptron """
    n = len(channels)
    layers = []
    for i in range(1, n):
        layers.append(nn.Linear(channels[i - 1], channels[i]))
        if i < (n - 1):
            if do_bn:
                layers.append(nn.BatchNorm1d(channels[i]))
            layers.append(nn.ReLU())
    return nn.Sequential(*layers)

def MLP_no_ReLU(channels: List[int], do_bn: bool = False) -> nn.Module:
    """ Multi-layer perceptron without ReLU activations """
    n = len(channels)
    layers = []
    for i in range(1, n):
        layers.append(nn.Linear(channels[i - 1], channels[i]))
        if i < (n - 1):
            if do_bn:
                layers.append(nn.BatchNorm1d(channels[i]))
    return nn.Sequential(*layers)


class KeypointEncoder(nn.Module):
    """ Encoding of geometric properties using MLP """
    def __init__(self, keypoint_dim: int, feature_dim: int, layers: List[int], dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.encoder = MLP([keypoint_dim] + layers + [feature_dim])
        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)

    def forward(self, kpts):
        if self.use_dropout:
            return self.dropout(self.encoder(kpts))
        return self.encoder(kpts)

class NormalEncoder(nn.Module):
    """ Encoding of geometric properties using MLP """
    def __init__(self, normal_dim: int, feature_dim: int, layers: List[int], dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.encoder = MLP_no_ReLU([normal_dim] + layers + [feature_dim])
        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)

    def forward(self, kpts):
        if self.use_dropout:
            return self.dropout(self.encoder(kpts))
        return self.encoder(kpts)


class DescriptorEncoder(nn.Module):
    """ Encoding of visual descriptor using MLP """
    def __init__(self, feature_dim: int, layers: List[int], dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.encoder = MLP([feature_dim] + layers + [feature_dim])
        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)

    def forward(self, descs):
        residual = descs
        if self.use_dropout:
            return residual + self.dropout(self.encoder(descs))
        return residual + self.encoder(descs)


class AFTAttention(nn.Module):
    """ Attention-free attention """
    def __init__(self, d_model: int, dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.dim = d_model
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.proj = nn.Linear(d_model, d_model)
        # self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        # q = torch.sigmoid(q)
        k = k.T
        k = torch.softmax(k, dim=-1)
        k = k.T
        kv = (k * v).sum(dim=-2, keepdim=True)
        x = q * kv
        x = self.proj(x)
        if self.use_dropout:
            x = self.dropout(x)
        x += residual
        # x = self.layer_norm(x)
        return x


class PositionwiseFeedForward(nn.Module):
    def __init__(self, feature_dim: int, dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.mlp = MLP([feature_dim, feature_dim * 2, feature_dim])
        # self.layer_norm = nn.LayerNorm(feature_dim, eps=1e-6)
        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.mlp(x)
        if self.use_dropout:
            x = self.dropout(x)
        x += residual
        # x = self.layer_norm(x)
        return x


class AttentionalLayer(nn.Module):
    def __init__(self, feature_dim: int, dropout: bool = False, p: float = 0.1):
        super().__init__()
        self.attn = AFTAttention(feature_dim, dropout=dropout, p=p)
        self.ffn = PositionwiseFeedForward(feature_dim, dropout=dropout, p=p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.attn(x)
        x = self.ffn(x)
        return x


class AttentionalNN(nn.Module):
    def __init__(self, feature_dim: int, layer_num: int, dropout: bool = False, p: float = 0.1) -> None:
        super().__init__()
        self.layers = nn.ModuleList([
            AttentionalLayer(feature_dim, dropout=dropout, p=p)
            for _ in range(layer_num)])

    def forward(self, desc: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            desc = layer(desc)
        return desc


class FeatureBooster(nn.Module):
    default_config = {
        'descriptor_dim': 128,
        'keypoint_encoder': [32, 64, 128],
        'Attentional_layers': 3,
        'last_activation': 'relu',
        'l2_normalization': True,
        'output_dim': 128
    }

    def __init__(self, config, dropout=False, p=0.1, use_kenc=True, use_normal=True, use_cross=True):
        super().__init__()
        self.config = {**self.default_config, **config}
        self.use_kenc = use_kenc
        self.use_cross = use_cross
        self.use_normal = use_normal

        if use_kenc:
            self.kenc = KeypointEncoder(self.config['keypoint_dim'], self.config['descriptor_dim'], self.config['keypoint_encoder'], dropout=dropout)

        if use_normal:
            self.nenc = NormalEncoder(self.config['normal_dim'], self.config['descriptor_dim'], self.config['normal_encoder'], dropout=dropout)

        if self.config.get('descriptor_encoder', False):
            self.denc = DescriptorEncoder(self.config['descriptor_dim'], self.config['descriptor_encoder'], dropout=dropout)
        else:
            self.denc = None

        if self.use_cross:
            self.attn_proj = AttentionalNN(feature_dim=self.config['descriptor_dim'], layer_num=self.config['Attentional_layers'], dropout=dropout)

        # self.final_proj = nn.Linear(self.config['descriptor_dim'], self.config['output_dim'])

        self.use_dropout = dropout
        self.dropout = nn.Dropout(p=p)

        # self.layer_norm = nn.LayerNorm(self.config['descriptor_dim'], eps=1e-6)

        if self.config.get('last_activation', False):
            if self.config['last_activation'].lower() == 'relu':
                self.last_activation = nn.ReLU()
            elif self.config['last_activation'].lower() == 'sigmoid':
                self.last_activation = nn.Sigmoid()
            elif self.config['last_activation'].lower() == 'tanh':
                self.last_activation = nn.Tanh()
            else:
                raise Exception('Not supported activation "%s".' % self.config['last_activation'])
        else:
            self.last_activation = None

    def forward(self, desc, kpts, normals):
        ## Self boosting
        # Descriptor MLP encoder
        if self.denc is not None:
            desc = self.denc(desc)
        # Geometric MLP encoder
        if self.use_kenc:
            desc = desc + self.kenc(kpts)
            if self.use_dropout:
                desc = self.dropout(desc)

        # Surface-normal feature encoder
        if self.use_normal:
            desc = desc + self.nenc(normals)
            if self.use_dropout:
                desc = self.dropout(desc)

        ## Cross boosting
        # Multi-layer Transformer network.
        if self.use_cross:
            # desc = self.attn_proj(self.layer_norm(desc))
            desc = self.attn_proj(desc)

        ## Post processing
        # Final MLP projection
        # desc = self.final_proj(desc)
        if self.last_activation is not None:
            desc = self.last_activation(desc)
        # L2 normalization
        if self.config['l2_normalization']:
            desc = F.normalize(desc, dim=-1)

        return desc

if __name__ == "__main__":
    # config.py defines `featureboost_config`; with it, normals must be 192-dim
    # (3 components unfolded over an 8x8 window) to match the normal encoder.
    from config import featureboost_config
    fb_net = FeatureBooster(featureboost_config)

    descs = torch.randn([1900, 64])
    kpts = torch.randn([1900, 65])
    normals = torch.randn([1900, 192])

    descs_refine = fb_net(descs, kpts, normals)

    print(descs_refine.shape)
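
The AFTAttention block above follows the AFT-simple variant of the Attention Free Transformer: keys are softmax-normalized over the token dimension, pooled against the values into a single global context vector, and that context is gated element-wise by the queries (the sigmoid gate on Q from the original formulation is commented out here):

Y_i = Q_i \odot \sum_{j} \operatorname{softmax}_j(K)_j \odot V_j

followed by a linear projection and a residual connection, so the cost is linear in the number of descriptors rather than quadratic.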
imcui/third_party/LiftFeat/weights/LiftFeat.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0da33b2216bde964989f3d13e9b9c9cbdd65c98fb05fb4d4771b7d2f3a807c8b
size 8086947