Realcat committed e6ac593 (1 parent: 1b369eb)

Note: this view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed (50):
  1. README.md +2 -1
  2. config/config.yaml +13 -2
  3. imcui/hloc/extract_features.py +11 -0
  4. imcui/hloc/extractors/ripe.py +46 -0
  5. imcui/third_party/RIPE/.gitignore +179 -0
  6. imcui/third_party/RIPE/LICENSE +35 -0
  7. imcui/third_party/RIPE/LICENSE_DALF_DISK +201 -0
  8. imcui/third_party/RIPE/README.md +367 -0
  9. imcui/third_party/RIPE/app.py +272 -0
  10. imcui/third_party/RIPE/assets/all_souls_000013.jpg +3 -0
  11. imcui/third_party/RIPE/assets/all_souls_000055.jpg +3 -0
  12. imcui/third_party/RIPE/assets/teaser_image.png +3 -0
  13. imcui/third_party/RIPE/conda_env.yml +26 -0
  14. imcui/third_party/RIPE/conf/backbones/resnet.yaml +6 -0
  15. imcui/third_party/RIPE/conf/backbones/vgg.yaml +5 -0
  16. imcui/third_party/RIPE/conf/data/disk_megadepth.yaml +12 -0
  17. imcui/third_party/RIPE/conf/data/megadepth+acdc.yaml +33 -0
  18. imcui/third_party/RIPE/conf/data/megadepth+tokyo.yaml +29 -0
  19. imcui/third_party/RIPE/conf/descriptor_loss/contrastive_loss.yaml +3 -0
  20. imcui/third_party/RIPE/conf/inl_th/constant.yaml +2 -0
  21. imcui/third_party/RIPE/conf/inl_th/exp_decay.yaml +4 -0
  22. imcui/third_party/RIPE/conf/matcher/concurrent_mnn_poselib.yaml +8 -0
  23. imcui/third_party/RIPE/conf/train.yaml +89 -0
  24. imcui/third_party/RIPE/conf/upsampler/hypercolumn_features.yaml +2 -0
  25. imcui/third_party/RIPE/conf/upsampler/interpolate_sparse2D.yaml +1 -0
  26. imcui/third_party/RIPE/data/download_disk_data.sh +43 -0
  27. imcui/third_party/RIPE/demo.py +51 -0
  28. imcui/third_party/RIPE/ripe/__init__.py +1 -0
  29. imcui/third_party/RIPE/ripe/benchmarks/imw_2020.py +320 -0
  30. imcui/third_party/RIPE/ripe/data/__init__.py +0 -0
  31. imcui/third_party/RIPE/ripe/data/data_transforms.py +204 -0
  32. imcui/third_party/RIPE/ripe/data/datasets/__init__.py +0 -0
  33. imcui/third_party/RIPE/ripe/data/datasets/acdc.py +154 -0
  34. imcui/third_party/RIPE/ripe/data/datasets/dataset_combinator.py +88 -0
  35. imcui/third_party/RIPE/ripe/data/datasets/disk_imw.py +160 -0
  36. imcui/third_party/RIPE/ripe/data/datasets/disk_megadepth.py +157 -0
  37. imcui/third_party/RIPE/ripe/data/datasets/tokyo247.py +134 -0
  38. imcui/third_party/RIPE/ripe/losses/__init__.py +0 -0
  39. imcui/third_party/RIPE/ripe/losses/contrastive_loss.py +88 -0
  40. imcui/third_party/RIPE/ripe/matcher/__init__.py +0 -0
  41. imcui/third_party/RIPE/ripe/matcher/concurrent_matcher.py +97 -0
  42. imcui/third_party/RIPE/ripe/matcher/pose_estimator_poselib.py +31 -0
  43. imcui/third_party/RIPE/ripe/model_zoo/__init__.py +1 -0
  44. imcui/third_party/RIPE/ripe/model_zoo/vgg_hyper.py +39 -0
  45. imcui/third_party/RIPE/ripe/models/__init__.py +0 -0
  46. imcui/third_party/RIPE/ripe/models/backbones/__init__.py +0 -0
  47. imcui/third_party/RIPE/ripe/models/backbones/backbone_base.py +61 -0
  48. imcui/third_party/RIPE/ripe/models/backbones/vgg.py +99 -0
  49. imcui/third_party/RIPE/ripe/models/backbones/vgg_utils.py +143 -0
  50. imcui/third_party/RIPE/ripe/models/ripe.py +303 -0
README.md CHANGED
@@ -44,8 +44,9 @@ The tool currently supports various popular image matching algorithms, namely:
 
 | Algorithm | Supported | Conference/Journal | Year | GitHub Link |
 |------------------|-----------|--------------------|------|-------------|
-| LiftFeat | ✅ | ICRA | 2025 | [Link](https://github.com/lyp-deeplearning/LiftFeat) |
+| RIPE | ✅ | ICCV | 2025 | [Link](https://github.com/fraunhoferhhi/RIPE) |
 | RDD | ✅ | CVPR | 2025 | [Link](https://github.com/xtcpete/rdd) |
+| LiftFeat | ✅ | ICRA | 2025 | [Link](https://github.com/lyp-deeplearning/LiftFeat) |
 | DaD | ✅ | ARXIV | 2025 | [Link](https://github.com/Parskatt/dad) |
 | MINIMA | ✅ | ARXIV | 2024 | [Link](https://github.com/LSXI7/MINIMA) |
 | XoFTR | ✅ | CVPR | 2024 | [Link](https://github.com/OnderT/XoFTR) |
config/config.yaml CHANGED
@@ -267,6 +267,17 @@ matcher_zoo:
       paper: https://arxiv.org/abs/2505.0342
       project: null
       display: true
+  ripe(+mnn):
+    matcher: NN-mutual
+    feature: ripe
+    dense: false
+    info:
+      name: RIPE #dispaly name
+      source: "ICCV 2025"
+      github: https://github.com/fraunhoferhhi/RIPE
+      paper: https://arxiv.org/abs/2507.04839
+      project: https://fraunhoferhhi.github.io/RIPE
+      display: true
   rdd(sparse):
     matcher: NN-mutual
     feature: rdd
@@ -274,7 +285,7 @@ matcher_zoo:
     info:
       name: RDD(sparse) #dispaly name
       source: "CVPR 2025"
-      github: hhttps://github.com/xtcpete/rdd
+      github: https://github.com/xtcpete/rdd
       paper: https://arxiv.org/abs/2505.08013
       project: https://xtcpete.github.io/rdd
       display: true
@@ -284,7 +295,7 @@ matcher_zoo:
     info:
       name: RDD(dense) #dispaly name
       source: "CVPR 2025"
-      github: hhttps://github.com/xtcpete/rdd
+      github: https://github.com/xtcpete/rdd
       paper: https://arxiv.org/abs/2505.08013
       project: https://xtcpete.github.io/rdd
       display: true
imcui/hloc/extract_features.py CHANGED
@@ -236,6 +236,17 @@ confs = {
             "resize_max": 1600,
         },
     },
+    "ripe": {
+        "output": "feats-ripe-n2048-r1600",
+        "model": {
+            "name": "ripe",
+            "max_keypoints": 2048,
+        },
+        "preprocessing": {
+            "grayscale": False,
+            "resize_max": 1600,
+        },
+    },
     "aliked-n16-rot": {
         "output": "feats-aliked-n16-rot",
         "model": {
imcui/hloc/extractors/ripe.py ADDED
@@ -0,0 +1,46 @@
+import sys
+from pathlib import Path
+from ..utils.base_model import BaseModel
+from .. import logger, MODEL_REPO_ID
+
+ripe_path = Path(__file__).parent / "../../third_party/RIPE"
+sys.path.append(str(ripe_path))
+
+from ripe import vgg_hyper
+
+
+class RIPE(BaseModel):
+    default_conf = {
+        "keypoint_threshold": 0.05,
+        "max_keypoints": 5000,
+        "model_name": "weights_ripe.pth",
+    }
+
+    required_inputs = ["image"]
+
+    def _init(self, conf):
+        logger.info("Loading RIPE model...")
+        model_path = self._download_model(
+            repo_id=MODEL_REPO_ID,
+            filename="{}/{}".format(Path(__file__).stem, self.conf["model_name"]),
+        )
+        self.net = vgg_hyper(Path(model_path))
+        logger.info("Loading RIPE model done!")
+
+    def _forward(self, data):
+        keypoints, descriptors, scores = self.net.detectAndCompute(
+            data["image"], threshold=0.5, top_k=2048
+        )
+
+        if self.conf["max_keypoints"] < len(keypoints):
+            idxs = scores.argsort()[-self.conf["max_keypoints"] or None :]
+            keypoints = keypoints[idxs, :2]
+            descriptors = descriptors[idxs]
+            scores = scores[idxs]
+
+        pred = {
+            "keypoints": keypoints[None],
+            "descriptors": descriptors[None].permute(0, 2, 1),
+            "scores": scores[None],
+        }
+        return pred
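The top-k selection in `_forward` leans on a compact slicing idiom. A standalone sketch with made-up values (not from the repository) spells out what it does:

```python
import torch

scores = torch.rand(5000)  # hypothetical detection scores
max_keypoints = 2048

# argsort() orders indices from lowest to highest score, so the last
# `max_keypoints` entries point at the highest-scoring keypoints.
# `-max_keypoints or None` degrades to None (keep everything) when
# max_keypoints is 0, mirroring the expression used in the extractor.
idxs = scores.argsort()[-max_keypoints or None :]
assert idxs.shape[0] == max_keypoints
```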
imcui/third_party/RIPE/.gitignore ADDED
@@ -0,0 +1,179 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ pip-wheel-metadata/
24
+ share/python-wheels/
25
+ *.egg-info/
26
+ .installed.cfg
27
+ *.egg
28
+ MANIFEST
29
+
30
+ # PyInstaller
31
+ # Usually these files are written by a python script from a template
32
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
33
+ *.manifest
34
+ *.spec
35
+
36
+ # Installer logs
37
+ pip-log.txt
38
+ pip-delete-this-directory.txt
39
+
40
+ # Unit test / coverage reports
41
+ htmlcov/
42
+ .tox/
43
+ .nox/
44
+ .coverage
45
+ .coverage.*
46
+ .cache
47
+ nosetests.xml
48
+ coverage.xml
49
+ *.cover
50
+ *.py,cover
51
+ .hypothesis/
52
+ .pytest_cache/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ target/
76
+
77
+ # Jupyter Notebook
78
+ .ipynb_checkpoints
79
+
80
+ # IPython
81
+ profile_default/
82
+ ipython_config.py
83
+
84
+ # pyenv
85
+ .python-version
86
+
87
+ # pipenv
88
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
89
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
90
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
91
+ # install all needed dependencies.
92
+ #Pipfile.lock
93
+
94
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
95
+ __pypackages__/
96
+
97
+ # Celery stuff
98
+ celerybeat-schedule
99
+ celerybeat.pid
100
+
101
+ # SageMath parsed files
102
+ *.sage.py
103
+
104
+ # Environments
105
+ .venv
106
+ env/
107
+ venv/
108
+ ENV/
109
+ env.bak/
110
+ venv.bak/
111
+
112
+ # Spyder project settings
113
+ .spyderproject
114
+ .spyproject
115
+
116
+ # Rope project settings
117
+ .ropeproject
118
+
119
+ # mkdocs documentation
120
+ /site
121
+
122
+ # mypy
123
+ .mypy_cache/
124
+ .dmypy.json
125
+ dmypy.json
126
+
127
+ # Pyre type checker
128
+ .pyre/
129
+
130
+ ### VisualStudioCode
131
+ .vscode/*
132
+ !.vscode/settings.json
133
+ !.vscode/tasks.json
134
+ !.vscode/launch.json
135
+ !.vscode/extensions.json
136
+ *.code-workspace
137
+ **/.vscode
138
+
139
+ # JetBrains
140
+ .idea/
141
+
142
+ # ignore outputs
143
+ /outputs/
144
+
145
+ # ignore logs
146
+ /logs/
147
+ tmp.py
148
+ .env
149
+
150
+ # ignore pretrained pytorch models
151
+ *.pth
152
+
153
+ # ignore lightning_logs
154
+ /lightning_logs/*
155
+
156
+ # ignore built apptainer images
157
+ *.sif
158
+
159
+ # ignore the outputs server on the cluster
160
+ /output/*
161
+ # ignore .out files generated from the cluster
162
+ *.out
163
+ # ignore hparams_search folder
164
+ /hparams_search_configs/*
165
+
166
+ *.o
167
+ *.pkl
168
+ *.ninja_deps
169
+ *.ninja_log
170
+ *.ninja
171
+
172
+ /misc/*
173
+ /tmp/*
174
+ /apptainer_env.box/*
175
+ /scripts/tmp_build/*
176
+ /checkpoints
177
+ /pretrained_weights
178
+ /results_supple_cvpr
179
+ /ext_files
imcui/third_party/RIPE/LICENSE ADDED
@@ -0,0 +1,35 @@
1
+ Software Copyright License for Academic Use of RIPE, Version 2.0
2
+
3
+ © Copyright (2025) Fraunhofer-Gesellschaft zur Förderung der angewandten Forschung e.V.
4
+
5
+ 1. INTRODUCTION
6
+
7
+ RIPE which means any source code, object code or binary files provided by Fraunhofer excluding third party software and materials, is made available under this Software Copyright License.
8
+
9
+ 2. COPYRIGHT LICENSE
10
+
11
+ Internal use of RIPE, in source and binary forms, with or without modification, is permitted without payment of copyright license fees for non-commercial purposes of evaluation, testing and academic research.
12
+
13
+ No right or license, express or implied, is granted to any part of RIPE except and solely to the extent as expressly set forth herein. Any commercial use or exploitation of RIPE and/or any modifications thereto under this license are prohibited.
14
+
15
+ For any other use of RIPE than permitted by this software copyright license You need another license from Fraunhofer. In such case please contact Fraunhofer under the CONTACT INFORMATION below.
16
+
17
+ 3. LIMITED PATENT LICENSE
18
+
19
+ If Fraunhofer patents are implemented by RIPE their use is permitted for internal non-commercial purposes of evaluation, testing and academic research. No patent grant is provided for any other use, including but not limited to commercial use or exploitation.
20
+
21
+ Fraunhofer provides no warranty of patent non-infringement with respect to RIPE.
22
+
23
+ 4. PLACE OF JURISDICTION
24
+
25
+ German law shall apply to all disputes arising from the use of the licensed software. A court in Munich shall have local jurisdiction.
26
+
27
+ 5. DISCLAIMER
28
+
29
+ RIPE is provided by Fraunhofer "AS IS" and WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES, including but not limited to the implied warranties of fitness for a particular purpose. IN NO EVENT SHALL FRAUNHOFER BE LIABLE for any direct, indirect, incidental, special, exemplary, or consequential damages, including but not limited to procurement of substitute goods or services; loss of use, data, or profits, or business interruption, however caused and on any theory of liability, whether in contract, strict liability, or tort (including negligence), arising in any way out of the use of the Fraunhofer Software, even if advised of the possibility of such damage.
30
+
31
+ 6. CONTACT INFORMATION
32
+
33
+ Fraunhofer-Institut für Nachrichtentechnik, Heinrich-Hertz-Institut, HHI
34
+ Einsteinufer 37, 10587 Berlin, Germany
35
imcui/third_party/RIPE/LICENSE_DALF_DISK ADDED
@@ -0,0 +1,201 @@
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
imcui/third_party/RIPE/README.md ADDED
@@ -0,0 +1,367 @@
1
+ #
2
+ <p align="center">
3
+ <h1 align="center"> <ins>RIPE</ins>:<br> Reinforcement Learning on Unlabeled Image Pairs for Robust Keypoint Extraction <br><br>🌊🌺 ICCV 2025 🌺🌊</h1>
4
+ <p align="center">
5
+ <a href="https://scholar.google.com/citations?user=ybMR38kAAAAJ">Johannes Künzel</a>
6
+ ·
7
+ <a href="https://scholar.google.com/citations?user=5yTuyGIAAAAJ">Anna Hilsmann</a>
8
+ ·
9
+ <a href="https://scholar.google.com/citations?user=BCElyCkAAAAJ">Peter Eisert</a>
10
+ </p>
11
+ <h2 align="center"><p>
12
+ <a href="https://arxiv.org/abs/2507.04839" align="center">Arxiv</a> |
13
+ <a href="https://fraunhoferhhi.github.io/RIPE/" align="center">Project Page</a> |
14
+ <a href="https://huggingface.co/spaces/JohannesK14/RIPE" align="center">🤗Demo🤗</a>
15
+ </p></h2>
16
+ <div align="center"></div>
17
+ </p>
18
+ <br/>
19
+ <p align="center">
20
+ <img src="assets/teaser_image.png" alt="example" width=80%>
21
+ <br>
22
+ <em>RIPE demonstrates that keypoint detection and description can be learned from image pairs only - no depth, no pose, no artificial augmentation required.</em>
23
+ </p>
24
+
25
+ ## Setup
26
+
27
+ 💡**Alternative**💡 Install nothing locally and try our Hugging Face demo: [🤗Demo🤗](https://huggingface.co/spaces/JohannesK14/RIPE)
28
+
29
+ 1. Install mamba by following the instructions given here: [Mamba Installation](https://mamba.readthedocs.io/en/latest/installation/mamba-installation.html)
30
+
31
+ 2. Create a new environment with:
32
+ ```bash
33
+ mamba create -f conda_env.yml
34
+ mamba activate ripe-env
35
+ ```
36
+
37
+ ## How to use
38
+
39
+ Or just check [demo.py](demo.py)
40
+
41
+ ```python
42
+ import cv2
43
+ import kornia.feature as KF
44
+ import kornia.geometry as KG
45
+ import matplotlib.pyplot as plt
46
+ import numpy as np
47
+ import torch
48
+ from torchvision.io import decode_image
49
+
50
+ from ripe import vgg_hyper
51
+ from ripe.utils.utils import cv2_matches_from_kornia, resize_image, to_cv_kpts
52
+
53
+ dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
54
+
55
+ model = vgg_hyper().to(dev)
56
+ model.eval()
57
+
58
+ image1 = resize_image(decode_image("assets/all_souls_000013.jpg").float().to(dev) / 255.0)
59
+ image2 = resize_image(decode_image("assets/all_souls_000055.jpg").float().to(dev) / 255.0)
60
+
61
+ kpts_1, desc_1, score_1 = model.detectAndCompute(image1, threshold=0.5, top_k=2048)
62
+ kpts_2, desc_2, score_2 = model.detectAndCompute(image2, threshold=0.5, top_k=2048)
63
+
64
+ matcher = KF.DescriptorMatcher("mnn") # threshold is not used with mnn
65
+ match_dists, match_idxs = matcher(desc_1, desc_2)
66
+
67
+ matched_pts_1 = kpts_1[match_idxs[:, 0]]
68
+ matched_pts_2 = kpts_2[match_idxs[:, 1]]
69
+
70
+ H, mask = KG.ransac.RANSAC(model_type="fundamental", inl_th=1.0)(matched_pts_1, matched_pts_2)
71
+ matchesMask = mask.int().ravel().tolist()
72
+
73
+ result_ransac = cv2.drawMatches(
74
+ (image1.cpu().permute(1, 2, 0).numpy() * 255.0).astype(np.uint8),
75
+ to_cv_kpts(kpts_1, score_1),
76
+ (image2.cpu().permute(1, 2, 0).numpy() * 255.0).astype(np.uint8),
77
+ to_cv_kpts(kpts_2, score_2),
78
+ cv2_matches_from_kornia(match_dists, match_idxs),
79
+ None,
80
+ matchColor=(0, 255, 0),
81
+ matchesMask=matchesMask,
82
+ # matchesMask=None, # without RANSAC filtering
83
+ singlePointColor=(0, 0, 255),
84
+ flags=cv2.DrawMatchesFlags_DEFAULT,
85
+ )
86
+
87
+ plt.imshow(result_ransac)
88
+ plt.axis("off")
89
+ plt.tight_layout()
90
+
91
+ plt.show()
92
+ # plt.savefig("result_ransac.png")
93
+ ```
94
+
95
+ ## Reproduce the results
96
+
97
+ ### MegaDepth 1500 & HPatches
98
+
99
+ 1. Download and install [Glue Factory](https://github.com/cvg/glue-factory)
100
+ 2. Add this repo as a submodule to Glue Factory:
101
+ ```bash
102
+ cd glue-factory
103
+ git submodule add https://github.com/fraunhoferhhi/RIPE.git thirdparty/ripe
104
+ ```
105
+ 3. Create the new file ripe.py under gluefactory/models/extractors/ with the following content:
106
+
107
+ <details>
108
+ <summary>ripe.py</summary>
109
+
110
+ ```python
111
+ import sys
112
+ from pathlib import Path
113
+
114
+ import torch
115
+ import torchvision.transforms as transforms
116
+
117
+ from ..base_model import BaseModel
118
+
119
+ ripe_path = Path(__file__).parent / "../../../thirdparty/ripe"
120
+
121
+ print(f"RIPE Path: {ripe_path.resolve()}")
122
+ # check if the path exists
123
+ if not ripe_path.exists():
124
+ raise RuntimeError(f"RIPE path not found: {ripe_path}")
125
+
126
+ sys.path.append(str(ripe_path))
127
+
128
+ from ripe import vgg_hyper
129
+
130
+
131
+ class RIPE(BaseModel):
132
+ default_conf = {
133
+ "name": "RIPE",
134
+ "model_path": None,
135
+ "chunk": 4,
136
+ "dense_outputs": False,
137
+ "threshold": 1.0,
138
+ "top_k": 2048,
139
+ }
140
+
141
+ required_data_keys = ["image"]
142
+
143
+ # Initialize the line matcher
144
+ def _init(self, conf):
145
+ self.normalizer = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
146
+ self.model = vgg_hyper(model_path=conf.model_path)
147
+ self.model.eval()
148
+
149
+ self.set_initialized()
150
+
151
+ def _forward(self, data):
152
+ image = data["image"]
153
+
154
+ keypoints, scores, descriptors = [], [], []
155
+
156
+ chunk = self.conf.chunk
157
+
158
+ for i in range(0, image.shape[0], chunk):
159
+ if self.conf.dense_outputs:
160
+ raise NotImplementedError("Dense outputs are not supported")
161
+ else:
162
+ im = image[: min(image.shape[0], i + chunk)]
163
+ im = self.normalizer(im)
164
+
165
+ H, W = im.shape[-2:]
166
+
167
+ kpt, desc, score = self.model.detectAndCompute(
168
+ im,
169
+ threshold=self.conf.threshold,
170
+ top_k=self.conf.top_k,
171
+ )
172
+ keypoints += [kpt.squeeze(0)]
173
+ scores += [score.squeeze(0)]
174
+ descriptors += [desc.squeeze(0)]
175
+
176
+ del kpt
177
+ del desc
178
+ del score
179
+
180
+ keypoints = torch.stack(keypoints, 0)
181
+ scores = torch.stack(scores, 0)
182
+ descriptors = torch.stack(descriptors, 0)
183
+
184
+ pred = {
185
+ # "keypoints": keypoints.to(image) + 0.5,
186
+ "keypoints": keypoints.to(image),
187
+ "keypoint_scores": scores.to(image),
188
+ "descriptors": descriptors.to(image),
189
+ }
190
+
191
+ return pred
192
+
193
+ def loss(self, pred, data):
194
+ raise NotImplementedError
195
+ ```
196
+
197
+ </details>
198
+
199
+ 4. Create ripe+NN.yaml in gluefactory/configs with the following content:
200
+
201
+ <details>
202
+ <summary>ripe+NN.yaml</summary>
203
+
204
+ ```yaml
205
+ model:
206
+ name: two_view_pipeline
207
+ extractor:
208
+ name: extractors.ripe
209
+ threshold: 1.0
210
+ top_k: 2048
211
+ matcher:
212
+ name: matchers.nearest_neighbor_matcher
213
+ benchmarks:
214
+ megadepth1500:
215
+ data:
216
+ preprocessing:
217
+ side: long
218
+ resize: 1600
219
+ eval:
220
+ estimator: poselib
221
+ ransac_th: 0.5
222
+ hpatches:
223
+ eval:
224
+ estimator: poselib
225
+ ransac_th: 0.5
226
+ model:
227
+ extractor:
228
+ top_k: 1024 # overwrite config above
229
+ ```
230
+
231
+ 5. Run the MegaDepth 1500 evaluation script:
232
+
233
+ ```bash
234
+ python -m gluefactory.eval.megadepth1500 --conf ripe+NN # for MegaDepth 1500
235
+ ```
236
+
237
+ Should result in:
238
+
239
+ ```bash
240
+ 'rel_pose_error@10°': 0.6834,
241
+ 'rel_pose_error@20°': 0.7803,
242
+ 'rel_pose_error@5°': 0.5511,
243
+ ```
244
+
245
+ 6. Run the HPatches evaluation script:
246
+
247
+ ```bash
248
+ python -m gluefactory.eval.hpatches --conf ripe+NN # for HPatches
249
+ ```
250
+
251
+ Should result in:
252
+
253
+ ```bash
254
+ 'H_error_ransac@1px': 0.3793,
255
+ 'H_error_ransac@3px': 0.5893,
256
+ 'H_error_ransac@5px': 0.692,
257
+ ```
258
+
259
+
260
+
261
+ ## Training
262
+
263
+ 1. Create a .env file with the following content:
264
+ ```bash
265
+ OUTPUT_DIR="/output"
266
+ DATA_DIR="/data"
267
+ ```
268
+
269
+ 2. Download the required datasets:
270
+
271
+ <details>
272
+ <summary>DISK Megadepth subset</summary>
273
+
274
+ To download the dataset used by [DISK](https://github.com/cvlab-epfl/disk) execute the following commands:
275
+
276
+ ```bash
277
+ cd data
278
+ bash download_disk_data.sh
279
+ ```
280
+
281
+ </details>
282
+
283
+ <details>
284
+ <summary>Tokyo 24/7</summary>
285
+
286
+ - ⚠️**Optional**⚠️: Only if you are interest in the model used in Section 4.6 of the paper!
287
+ - Download the Tokyo 24/7 query images from here: [Tokyo 24/7 Query Images V3](http://www.ok.ctrl.titech.ac.jp/~torii/project/247/download/247query_v3.zip) from the official [website](http://www.ok.ctrl.titech.ac.jp/~torii/project/247/_).
288
+ - extract them into data/Tolyo_Query_V3
289
+
290
+ ```bash
291
+ Tokyo_Query_V3/
292
+ ├── 00001.csv
293
+ ├── 00001.jpg
294
+ ├── 00002.csv
295
+ ├── 00002.jpg
296
+ ├── ...
297
+ ├── 01125.csv
298
+ ├── 01125.jpg
299
+ ├── Readme.txt
300
+ └── Readme.txt~
301
+ ```
302
+
303
+ </details>
304
+
305
+ <details>
306
+ <summary>ACDC</summary>
307
+
308
+ - ⚠️**Optional**⚠️: Only if you are interest in the model used in Section 6.1 (supplementary) of the paper!
309
+ - Download the RGB images from here: [ACDC RGB Images](https://acdc.vision.ee.ethz.ch/rgb_anon_trainvaltest.zip)
310
+ - extract them into data/ACDC
311
+
312
+ ```bash
313
+ ACDC/
314
+ rgb_anon
315
+ ├── fog
316
+ │   ├── test
317
+ │   │   ├── GOPR0475
318
+ │   │   ├── GOPR0477
319
+ │   ├── test_ref
320
+ │   │   ├── GOPR0475
321
+ │   │   ├── GOPR0477
322
+ │   ├── train
323
+ │   │   ├── GOPR0475
324
+ │   │   ├── GOPR0476
325
+ ├── night
326
+ ```
327
+
328
+ </details>
329
+
330
+ 3. Run the training script:
331
+
332
+ ```bash
333
+ python ripe/train.py --config-name train project_name=train name=reproduce wandb_mode=offline
334
+ ```
335
+
336
+ You can also easily switch setting from the command line, e.g. to addionally train on the Tokyo 24/7 dataset:
337
+ ```bash
338
+ python ripe/train.py --config-name train project_name=train name=reproduce wandb_mode=offline data=megadepth+tokyo
339
+ ```
340
+
341
+ ## Acknowledgements
342
+
343
+ Our code is partly based on the following repositories:
344
+ - [DALF](https://github.com/verlab/DALF_CVPR_2023) Apache License 2.0
345
+ - [DeDoDe](https://github.com/Parskatt/DeDoDe) MIT License
346
+ - [DISK](https://github.com/cvlab-epfl/disk) Apache License 2.0
347
+
348
+ Our evaluation was based on the following repositories:
349
+ - [Glue Factory](https://github.com/cvg/glue-factory)
350
+ - [hloc](https://github.com/cvg/Hierarchical-Localization)
351
+
352
+ We would like to thank the authors of these repositories for their great work and for making their code available.
353
+
354
+ Our project webpage is based on the [Acadamic Project Page Template](https://github.com/eliahuhorwitz/Academic-project-page-template) by Eliahu Horwitz.
355
+
356
+ ## BibTex Citation
357
+
358
+ ```
359
+
360
+ @article{ripe2025,
361
+ year = {2025},
362
+ title = {{RIPE: Reinforcement Learning on Unlabeled Image Pairs for Robust Keypoint Extraction}},
363
+ author = {Künzel, Johannes and Hilsmann, Anna and Eisert, Peter},
364
+ journal = {arXiv},
365
+ eprint = {2507.04839},
366
+ }
367
+ ```
imcui/third_party/RIPE/app.py ADDED
@@ -0,0 +1,272 @@
1
+ # This is a small gradio interface to access our RIPE keypoint extractor.
2
+ # You can either upload two images or use one of the example image pairs.
3
+
4
+ import os
5
+
6
+ import gradio as gr
7
+ from PIL import Image
8
+
9
+ from ripe import vgg_hyper
10
+
11
+ SEED = 32000
12
+ os.environ["PYTHONHASHSEED"] = str(SEED)
13
+
14
+ import random
15
+ from pathlib import Path
16
+
17
+ import numpy as np
18
+ import torch
19
+
20
+ torch.manual_seed(SEED)
21
+ np.random.seed(SEED)
22
+ random.seed(SEED)
23
+ import cv2
24
+ import kornia.feature as KF
25
+ import kornia.geometry as KG
26
+
27
+ from ripe.utils.utils import cv2_matches_from_kornia, to_cv_kpts
28
+
29
+ MIN_SIZE = 512
30
+ MAX_SIZE = 768
31
+
32
+ description_text = """
33
+ <p align='center'>
34
+ <h1 align='center'>🌊🌺 ICCV 2025 🌺🌊</h1>
35
+ <p align='center'>
36
+ <a href='https://scholar.google.com/citations?user=ybMR38kAAAAJ'>Johannes Künzel</a> ·
37
+ <a href='https://scholar.google.com/citations?user=5yTuyGIAAAAJ'>Anna Hilsmann</a> ·
38
+ <a href='https://scholar.google.com/citations?user=BCElyCkAAAAJ'>Peter Eisert</a>
39
+ </p>
40
+ <h2 align='center'>
41
+ <a href='???'>Arxiv</a> |
42
+ <a href='???'>Project Page</a> |
43
+ <a href='???'>Code</a>
44
+ </h2>
45
+ </p>
46
+
47
+ <br/>
48
+ <div align='center'>
49
+
50
+ ### This demo showcases our new keypoint extractor model, RIPE (Reinforcement Learning on Unlabeled Image Pairs for Robust Keypoint Extraction).
51
+
52
+ ### RIPE is trained without requiring pose or depth supervision or artificial augmentations. By leveraging reinforcement learning, it learns to extract keypoints solely based on whether an image pair depicts the same scene or not.
53
+
54
+ ### For more detailed information, please refer to our [paper](link to be added).
55
+
56
+ The demo code extracts the 2048-top keypoints from the two input images. It uses the mutual nearest neighbor (MNN) descriptor matcher from kornia to find matches between the two images.
57
+ If the number of matches is greater than 8, it applies RANSAC to filter out outliers based on the inlier threshold provided by the user.
58
+ Images are resized to fit within a maximum size of 2048x2048 pixels with maintained aspect ratio.
59
+
60
+ </div>
61
+ """
62
+
63
+ path_weights = Path(
64
+ "/media/jwkuenzel/work/projects/CVG_Reinforced_Keypoints/output/train/ablation_iccv/inlier_threshold/1571243/2025-02-19/14-00-10_789013/model_inlier_threshold_best.pth"
65
+ )
66
+
67
+ model = vgg_hyper(path_weights)
68
+
69
+
70
+ def get_new_image_size(image, min_size=1600, max_size=2048):
71
+ """
72
+ Get a new size for the image that is scaled to fit between min_size and max_size while maintaining the aspect ratio.
73
+
74
+ Args:
75
+ image (PIL.Image): Input image.
76
+ min_size (int): Minimum allowed size for width and height.
77
+ max_size (int): Maximum allowed size for width and height.
78
+
79
+ Returns:
80
+ tuple: New size (width, height) for the image.
81
+ """
82
+ width, height = image.size
83
+
84
+ aspect_ratio = width / height
85
+ if width > height:
86
+ new_width = max(min_size, min(max_size, width))
87
+ new_height = int(new_width / aspect_ratio)
88
+ else:
89
+ new_height = max(min_size, min(max_size, height))
90
+ new_width = int(new_height * aspect_ratio)
91
+
92
+ new_size = (new_width, new_height)
93
+
94
+ return new_size
95
+
96
+
97
+ def extract_keypoints(image1, image2, inl_th):
98
+ """
99
+ Extract keypoints from two input images using the RIPE model.
100
+
101
+ Args:
102
+ image1 (PIL.Image): First input image.
103
+ image2 (PIL.Image): Second input image.
104
+ inl_th (float): RANSAC inlier threshold.
105
+
106
+ Returns:
107
+ dict: A dictionary containing keypoints and matches.
108
+ """
109
+ log_text = "Extracting keypoints and matches with RIPE\n"
110
+
111
+ log_text += f"Image 1 size: {image1.size}\n"
112
+ log_text += f"Image 2 size: {image2.size}\n"
113
+
114
+ # check not larger than 2048x2048
115
+ new_size = get_new_image_size(image1, min_size=MIN_SIZE, max_size=MAX_SIZE)
116
+ image1 = image1.resize(new_size)
117
+
118
+ new_size = get_new_image_size(image2, min_size=MIN_SIZE, max_size=MAX_SIZE)
119
+ image2 = image2.resize(new_size)
120
+
121
+ log_text += f"Resized Image 1 size: {image1.size}\n"
122
+ log_text += f"Resized Image 2 size: {image2.size}\n"
123
+
124
+ dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
125
+ model.to(dev)
126
+
127
+ image1 = image1.convert("RGB")
128
+ image2 = image2.convert("RGB")
129
+
130
+ image1_original = image1.copy()
131
+ image2_original = image2.copy()
132
+
133
+ # convert PIL images to numpy arrays
134
+ image1_original = np.array(image1_original)
135
+ image2_original = np.array(image2_original)
136
+
137
+ # convert PIL images to tensors
138
+ image1 = torch.tensor(np.array(image1)).permute(2, 0, 1).float() / 255.0
139
+ image2 = torch.tensor(np.array(image2)).permute(2, 0, 1).float() / 255.0
140
+
141
+ image1 = image1.to(dev).unsqueeze(0) # Add batch dimension
142
+ image2 = image2.to(dev).unsqueeze(0) # Add batch dimension
143
+
144
+ kpts_1, desc_1, score_1 = model.detectAndCompute(image1, threshold=0.5, top_k=2048)
145
+ kpts_2, desc_2, score_2 = model.detectAndCompute(image2, threshold=0.5, top_k=2048)
146
+
147
+ log_text += f"Number of keypoints in image 1: {kpts_1.shape[0]}\n"
148
+ log_text += f"Number of keypoints in image 2: {kpts_2.shape[0]}\n"
149
+
150
+ matcher = KF.DescriptorMatcher("mnn") # threshold is not used with mnn
151
+ match_dists, match_idxs = matcher(desc_1, desc_2)
152
+
153
+ log_text += f"Number of MNN matches: {match_idxs.shape[0]}\n"
154
+
155
+ cv2_matches = cv2_matches_from_kornia(match_dists, match_idxs)
156
+
157
+ do_ransac = match_idxs.shape[0] > 8
158
+
159
+ if do_ransac:
160
+ matched_pts_1 = kpts_1[match_idxs[:, 0]]
161
+ matched_pts_2 = kpts_2[match_idxs[:, 1]]
162
+
163
+ H, mask = KG.ransac.RANSAC(model_type="fundamental", inl_th=inl_th)(matched_pts_1, matched_pts_2)
164
+ matchesMask = mask.int().ravel().tolist()
165
+
166
+ log_text += f"RANSAC found {mask.sum().item()} inliers out of {mask.shape[0]} matches with an inlier threshold of {inl_th}.\n"
167
+ else:
168
+ log_text += "Not enough matches for RANSAC, skipping RANSAC step.\n"
169
+
170
+ kpts_1 = to_cv_kpts(kpts_1, score_1)
171
+ kpts_2 = to_cv_kpts(kpts_2, score_2)
172
+
173
+ keypoints_raw_1 = cv2.drawKeypoints(image1_original, kpts_1, image1_original, color=(0, 255, 0))
174
+ keypoints_raw_2 = cv2.drawKeypoints(image2_original, kpts_2, image2_original, color=(0, 255, 0))
175
+
176
+ # pad height smaller image to match the height of the larger image
177
+ if keypoints_raw_1.shape[0] < keypoints_raw_2.shape[0]:
178
+ pad_height = keypoints_raw_2.shape[0] - keypoints_raw_1.shape[0]
179
+ keypoints_raw_1 = np.pad(
180
+ keypoints_raw_1, ((0, pad_height), (0, 0), (0, 0)), mode="constant", constant_values=255
181
+ )
182
+ elif keypoints_raw_1.shape[0] > keypoints_raw_2.shape[0]:
183
+ pad_height = keypoints_raw_1.shape[0] - keypoints_raw_2.shape[0]
184
+ keypoints_raw_2 = np.pad(
185
+ keypoints_raw_2, ((0, pad_height), (0, 0), (0, 0)), mode="constant", constant_values=255
186
+ )
187
+
188
+ # concatenate keypoints images horizontally
189
+ keypoints_raw = np.concatenate((keypoints_raw_1, keypoints_raw_2), axis=1)
190
+ keypoints_raw_pil = Image.fromarray(keypoints_raw)
191
+
192
+ result_raw = cv2.drawMatches(
193
+ image1_original,
194
+ kpts_1,
195
+ image2_original,
196
+ kpts_2,
197
+ cv2_matches,
198
+ None,
199
+ matchColor=(0, 255, 0),
200
+ matchesMask=None,
201
+ # matchesMask=None,
202
+ flags=cv2.DrawMatchesFlags_NOT_DRAW_SINGLE_POINTS,
203
+ )
204
+
205
+ if not do_ransac:
206
+ result_ransac = None
207
+ else:
208
+ result_ransac = cv2.drawMatches(
209
+ image1_original,
210
+ kpts_1,
211
+ image2_original,
212
+ kpts_2,
213
+ cv2_matches,
214
+ None,
215
+ matchColor=(0, 255, 0),
216
+ matchesMask=matchesMask,
217
+ singlePointColor=(0, 0, 255),
218
+ flags=cv2.DrawMatchesFlags_DEFAULT,
219
+ )
220
+
221
+ # result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB) # Convert BGR to RGB for display
222
+
223
+ # convert to PIL Image
224
+ result_raw_pil = Image.fromarray(result_raw)
225
+ if result_ransac is not None:
226
+ result_ransac_pil = Image.fromarray(result_ransac)
227
+ else:
228
+ result_ransac_pil = None
229
+
230
+ return log_text, result_ransac_pil, result_raw_pil, keypoints_raw_pil
231
+
232
+
233
+ demo = gr.Interface(
234
+ fn=extract_keypoints,
235
+ inputs=[
236
+ gr.Image(type="pil", label="Image 1"),
237
+ gr.Image(type="pil", label="Image 2"),
238
+ gr.Slider(
239
+ minimum=0.1,
240
+ maximum=3.0,
241
+ step=0.1,
242
+ value=0.5,
243
+ label="RANSAC inlier threshold",
244
+ info="Threshold for RANSAC inlier detection. Lower values may yield fewer inliers but more robust matches.",
245
+ ),
246
+ ],
247
+ outputs=[
248
+ gr.Textbox(type="text", label="Log"),
249
+ gr.Image(type="pil", label="Keypoints and Matches (RANSAC)"),
250
+ gr.Image(type="pil", label="Keypoints and Matches"),
251
+ gr.Image(type="pil", label="Keypoint Detection Results"),
252
+ ],
253
+ title="RIPE: Reinforcement Learning on Unlabeled Image Pairs for Robust Keypoint Extraction",
254
+ description=description_text,
255
+ examples=[
256
+ [
257
+ "assets_gradio/all_souls_000013.jpg",
258
+ "assets_gradio/all_souls_000055.jpg",
259
+ ],
260
+ [
261
+ "assets_gradio/167170681_0e5c42fd21_o.jpg",
262
+ "assets_gradio/170804731_6bf4fbecd4_o.jpg",
263
+ ],
264
+ [
265
+ "assets_gradio/4171014767_0fe879b783_o.jpg",
266
+ "assets_gradio/4174108353_20422632d6_o.jpg",
267
+ ],
268
+ ],
269
+ flagging_mode="never",
270
+ theme="default",
271
+ )
272
+ demo.launch()
imcui/third_party/RIPE/assets/all_souls_000013.jpg ADDED

Git LFS Details

  • SHA256: 60fd73963102f86baf08325631f8912db34acba7fb46cc9a41b818099276187e
  • Pointer size: 131 Bytes
  • Size of remote file: 440 kB
imcui/third_party/RIPE/assets/all_souls_000055.jpg ADDED

Git LFS Details

  • SHA256: e11c06ae78103c2dbb90737e2bab6aa47f2000948ece5bfe9a1e7eb1cacac53a
  • Pointer size: 131 Bytes
  • Size of remote file: 368 kB
imcui/third_party/RIPE/assets/teaser_image.png ADDED

Git LFS Details

  • SHA256: bd636ae0eb42927792cba0f04243c2ec65226a6f5e1287ab4ee015353b01c208
  • Pointer size: 132 Bytes
  • Size of remote file: 1.69 MB
imcui/third_party/RIPE/conda_env.yml ADDED
@@ -0,0 +1,26 @@
+name: ripe-env
+channels:
+  - conda-forge
+dependencies:
+  - python
+  - cmake
+  - eigen # for poselib
+  - pytorch=2.6=*cuda*
+  - torchvision
+  - pip
+  # others
+  - pudb # debugger
+  - pip:
+    - lightning>=2.0.0
+    - setuptools
+    - poselib @ git+https://github.com/PoseLib/PoseLib.git@56d158f744d3561b0b70174e6d8ca9a7fc9bd9c1
+    - hydra-core
+    - opencv-python
+    - torchmetrics
+    - pyrootutils # standardizing the project root setup
+    - rich
+    - matplotlib
+    - kornia
+    - numpy
+    - wandb
+    - h5py
imcui/third_party/RIPE/conf/backbones/resnet.yaml ADDED
@@ -0,0 +1,6 @@
+_target_: ripe.models.backbones.resnet.ResNet
+nchannels: 3
+pretrained: True
+use_instance_norm: False
+mode: dect
+num_layers: 4
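The conf files in this tree follow Hydra's `_target_` convention: each node names the class to instantiate and the remaining keys become constructor arguments. A minimal sketch of how such a file could be turned into an object outside the training entry point (assuming the `ripe` package is importable and Hydra/OmegaConf are installed):

```python
from hydra.utils import instantiate
from omegaconf import OmegaConf

# load the backbone config shown above and build the object it describes;
# nested _target_ nodes (e.g. in the data configs) are instantiated recursively
cfg = OmegaConf.load("conf/backbones/resnet.yaml")
backbone = instantiate(cfg)
```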
imcui/third_party/RIPE/conf/backbones/vgg.yaml ADDED
@@ -0,0 +1,5 @@
+_target_: ripe.models.backbones.vgg.VGG
+nchannels: 3
+pretrained: True
+use_instance_norm: False
+mode: dect
imcui/third_party/RIPE/conf/data/disk_megadepth.yaml ADDED
@@ -0,0 +1,12 @@
+_target_: ripe.data.datasets.disk_megadepth.DISK_Megadepth
+root: ${oc.env:DATA_DIR}/disk-data
+stage: train
+max_scene_size: 10000
+transforms:
+  _target_: ripe.data.data_transforms.Compose
+  transforms:
+    - _target_: ripe.data.data_transforms.Normalize
+      mean: [0.485, 0.456, 0.406]
+      std: [0.229, 0.224, 0.225]
+    - _target_: ripe.data.data_transforms.ResizeAndPadWithHomography
+      target_size_longer_side: 560
imcui/third_party/RIPE/conf/data/megadepth+acdc.yaml ADDED
@@ -0,0 +1,33 @@
+_target_: ripe.data.datasets.dataset_combinator.DatasetCombinator
+mode: custom
+weights:
+  - 0.2
+  - 0.8
+datasets:
+  - _target_: ripe.data.datasets.acdc.ACDC
+    root: ${oc.env:DATA_DIR}/ACDC
+    stage: train
+    condition: all
+    transforms:
+      _target_: ripe.data.data_transforms.Compose
+      transforms:
+        - _target_: ripe.data.data_transforms.Normalize
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+        - _target_: ripe.data.data_transforms.Crop # to remove the car hood from some images
+          crop_height: 896
+          crop_width: 1920
+        - _target_: ripe.data.data_transforms.ResizeAndPadWithHomography
+          target_size_longer_side: 560
+  - _target_: ripe.data.datasets.disk_megadepth.DISK_Megadepth
+    root: ${oc.env:DATA_DIR}/disk-data
+    stage: train
+    max_scene_size: 10000
+    transforms:
+      _target_: ripe.data.data_transforms.Compose
+      transforms:
+        - _target_: ripe.data.data_transforms.Normalize
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+        - _target_: ripe.data.data_transforms.ResizeAndPadWithHomography
+          target_size_longer_side: 560
imcui/third_party/RIPE/conf/data/megadepth+tokyo.yaml ADDED
@@ -0,0 +1,29 @@
+_target_: ripe.data.datasets.dataset_combinator.DatasetCombinator
+mode: custom
+weights:
+  - 0.2
+  - 0.8
+datasets:
+  - _target_: ripe.data.datasets.tokyo_query_v3.TokyoQueryV3
+    root: ${oc.env:DATA_DIR}/Tokyo_Query_V3
+    stage: train
+    transforms:
+      _target_: ripe.data.data_transforms.Compose
+      transforms:
+        - _target_: ripe.data.data_transforms.Normalize
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+        - _target_: ripe.data.data_transforms.ResizeAndPadWithHomography
+          target_size_longer_side: 560 # like DeDoDe
+  - _target_: ripe.data.datasets.disk_megadepth.DISK_Megadepth
+    root: ${oc.env:DATA_DIR}/disk-data
+    stage: train
+    max_scene_size: 10000
+    transforms:
+      _target_: ripe.data.data_transforms.Compose
+      transforms:
+        - _target_: ripe.data.data_transforms.Normalize
+          mean: [0.485, 0.456, 0.406]
+          std: [0.229, 0.224, 0.225]
+        - _target_: ripe.data.data_transforms.ResizeAndPadWithHomography
+          target_size_longer_side: 560
imcui/third_party/RIPE/conf/descriptor_loss/contrastive_loss.yaml ADDED
@@ -0,0 +1,3 @@
+_target_: ripe.losses.contrastive_loss.ContrastiveLoss
+pos_margin: 0.2
+neg_margin: 0.2
imcui/third_party/RIPE/conf/inl_th/constant.yaml ADDED
@@ -0,0 +1,2 @@
+_target_: ripe.scheduler.constant.ConstantScheduler
+value: 1.0
imcui/third_party/RIPE/conf/inl_th/exp_decay.yaml ADDED
@@ -0,0 +1,4 @@
+_target_: ripe.scheduler.expDecay.ExpDecay
+a: 2.5
+b: 0.0005
+c: 0.5
imcui/third_party/RIPE/conf/matcher/concurrent_mnn_poselib.yaml ADDED
@@ -0,0 +1,8 @@
+_target_: ripe.matcher.concurrent_matcher.ConcurrentMatcher
+min_num_matches: 8
+matcher:
+  _target_: kornia.feature.DescriptorMatcher
+  match_mode: "mnn"
+  th: 0.8
+robust_estimator:
+  _target_: ripe.matcher.pose_estimator_poselib.PoseLibRelativePoseEstimator
imcui/third_party/RIPE/conf/train.yaml ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ defaults:
2
+ - data: disk_megadepth # megadepth+acdc or megadepth+tokyo
3
+ - backbones: vgg
4
+ - upsampler: hypercolumn_features # interpolate_sparse2D
5
+ - matcher: concurrent_mnn_poselib
6
+ - descriptor_loss: contrastive_loss # none to deactivate
7
+ - inl_th: constant # exp_decay
8
+ - _self_
9
+
10
+ project_name: ???
11
+ name: ???
12
+
13
+ hydra:
14
+ run:
15
+ dir: ${oc.env:OUTPUT_DIR}/${project_name}/${name}/${oc.env:SLURM_JOB_ID}/${now:%Y-%m-%d}/${now:%H-%M-%S}
16
+ output_dir: ${hydra:runtime.output_dir}
17
+
18
+ num_gpus: 1
19
+ # precision: "32-true"
20
+ precision: "bf16-mixed" # numerically more stable
21
+ # precision: "16-mixed"
22
+
23
+ log_interval: 50 # log every N steps/ batches
24
+ wandb_mode: online
25
+ val_interval: 2000
26
+ conf_inference:
27
+ threshold: 0.5
28
+ top_k: 2048
29
+
30
+ desc_loss_weight: 5.0 # 0.0 to deactivate, also deactivates 1x1 conv
31
+
32
+ num_workers: 8
33
+ batch_size: 6
34
+
35
+ transformation_model: fundamental
36
+
37
+ network:
38
+ _target_: ripe.models.ripe.RIPE
39
+ _partial_: true
40
+ window_size: 8
41
+ non_linearity_dect:
42
+ _target_: torch.nn.Identity
43
+ # _target_: torch.nn.ReLU
44
+ desc_shares:
45
+ null
46
+ # - 64
47
+ # - 64
48
+ # - 64
49
+ # - 64
50
+
51
+ lr: 0.001 # 0.001 makes it somewhat unstable
52
+ fp_penalty: -1e-7 # -1e-7
53
+ kp_penalty: -7e-7 # -7e-7
54
+ num_grad_accs: 4
55
+ reward_type: inlier # inlier_ratio , inlier+inlier_ratio
56
+ no_filtering_negatives: False
57
+ descriptor_dim: 256
58
+
59
+ lr_scheduler:
60
+ _partial_: true
61
+ _target_: ripe.scheduler.linearLR.StepLinearLR
62
+ num_steps: ${num_steps}
63
+ initial_lr: ${lr}
64
+ final_lr: 1e-6
65
+
66
+ use_whitening: false
67
+
68
+ selected_only: False
69
+
70
+ padding_filter_mode: ignore
71
+ # padding_filter_mode: punish
72
+
73
+ num_steps: 80000
74
+
75
+ alpha_scheduler: # 1.0 after 1/3 of the steps
76
+ _target_: ripe.scheduler.linear_with_plateaus.LinearWithPlateaus
77
+ start_val: 0.0
78
+ end_val: 1.0
79
+ steps_total: ${num_steps}
80
+ rel_length_start_plateau: 0.0
81
+ rel_length_end_plateu: 0.6666666
82
+
83
+ beta_scheduler: # linear increase over all steps
84
+ _target_: ripe.scheduler.linear_with_plateaus.LinearWithPlateaus
85
+ start_val: 0.0
86
+ end_val: 1.0
87
+ steps_total: ${num_steps}
88
+ rel_length_start_plateau: 0.0
89
+ rel_length_end_plateu: 0.0
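The defaults list above composes the data, backbone, upsampler, matcher, descriptor-loss and inlier-threshold groups into one Hydra config. The training entry point itself is not part of this excerpt, so the following is a hypothetical minimal sketch of how such a config is composed and how its `_target_` nodes can be instantiated. It assumes hydra-core >= 1.2, omegaconf, and the `ripe` package with its dependencies on the PYTHONPATH; `project_name`/`name` are mandatory and the run dir interpolates `OUTPUT_DIR` and `SLURM_JOB_ID`, so those must be supplied at launch.

```python
# Hypothetical sketch, not the repository's train script: compose conf/train.yaml
# with Hydra and build a few of its _target_ nodes via instantiate().
import hydra
from hydra.utils import instantiate
from omegaconf import DictConfig


@hydra.main(config_path="conf", config_name="train", version_base=None)
def main(cfg: DictConfig) -> None:
    matcher = instantiate(cfg.matcher)            # ConcurrentMatcher (kornia MNN matcher + poselib estimator)
    desc_loss = instantiate(cfg.descriptor_loss)  # ContrastiveLoss(pos_margin=0.2, neg_margin=0.2)
    inl_th = instantiate(cfg.inl_th)              # inlier-threshold scheduler (constant or exp_decay)
    print(type(matcher).__name__, type(desc_loss).__name__, type(inl_th).__name__)


if __name__ == "__main__":
    # e.g.: OUTPUT_DIR=/tmp SLURM_JOB_ID=0 python sketch.py project_name=demo name=test
    main()
```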
imcui/third_party/RIPE/conf/upsampler/hypercolumn_features.yaml ADDED
@@ -0,0 +1,2 @@
1
+ _target_: ripe.models.upsampler.hypercolumn_features.HyperColumnFeatures
2
+ mode: bilinear
imcui/third_party/RIPE/conf/upsampler/interpolate_sparse2D.yaml ADDED
@@ -0,0 +1 @@
1
+ _target_: ripe.models.upsampler.interpolate_sparse2d.InterpolateSparse2d
imcui/third_party/RIPE/data/download_disk_data.sh ADDED
@@ -0,0 +1,43 @@
1
+ #!/usr/bin/env bash
2
+
3
+ # get the data (zipped)
4
+ # wget -r https://datasets.epfl.ch/disk-data/index.html
5
+
6
+ cd datasets.epfl.ch/disk-data;
7
+
8
+ # check for MD5 match
9
+ # md5sum -c md5sum.txt;
10
+ # if [ $? ]; then
11
+ # echo "MD5 mismatch (corrupt download)";
12
+ # return 1;
13
+ # fi
14
+
15
+ # create a crude progress counter
16
+ ITER=1;
17
+ TOTAL=138;
18
+ # unzip test scenes
19
+ cd imw2020-val/scenes;
20
+ for SCENE_TAR in *.tar.gz; do
21
+ echo "Unzipping $SCENE_TAR ($ITER / $TOTAL)";
22
+ tar -xz --strip-components=3 -f $SCENE_TAR;
23
+ rm $SCENE_TAR;
24
+ ITER=$(($ITER+1));
25
+ done
26
+
27
+ # unzip megadepth scenes
28
+ cd ../../megadepth/scenes;
29
+ for SCENE_TAR in *.tar; do
30
+ echo "Unzipping $SCENE_TAR ($ITER / $TOTAL)";
31
+ tar -x --strip-components=3 -f $SCENE_TAR;
32
+ rm $SCENE_TAR;
33
+ ITER=$(($ITER+1));
34
+ done
35
+
36
+ cd ../../../../
37
+
38
+ mv datasets.epfl.ch/disk-data ./
39
+ rm -rf datasets.epfl.ch
40
+
41
+
42
+
43
+
imcui/third_party/RIPE/demo.py ADDED
@@ -0,0 +1,51 @@
1
+ import cv2
2
+ import kornia.feature as KF
3
+ import kornia.geometry as KG
4
+ import matplotlib.pyplot as plt
5
+ import numpy as np
6
+ import torch
7
+ from torchvision.io import decode_image
8
+
9
+ from ripe import vgg_hyper
10
+ from ripe.utils.utils import cv2_matches_from_kornia, resize_image, to_cv_kpts
11
+
12
+ dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+ model = vgg_hyper().to(dev)
15
+ model.eval()
16
+
17
+ image1 = resize_image(decode_image("assets/all_souls_000013.jpg").float().to(dev) / 255.0)
18
+ image2 = resize_image(decode_image("assets/all_souls_000055.jpg").float().to(dev) / 255.0)
19
+
20
+ kpts_1, desc_1, score_1 = model.detectAndCompute(image1, threshold=0.5, top_k=2048)
21
+ kpts_2, desc_2, score_2 = model.detectAndCompute(image2, threshold=0.5, top_k=2048)
22
+
23
+ matcher = KF.DescriptorMatcher("mnn") # threshold is not used with mnn
24
+ match_dists, match_idxs = matcher(desc_1, desc_2)
25
+
26
+ matched_pts_1 = kpts_1[match_idxs[:, 0]]
27
+ matched_pts_2 = kpts_2[match_idxs[:, 1]]
28
+
29
+ H, mask = KG.ransac.RANSAC(model_type="fundamental", inl_th=1.0)(matched_pts_1, matched_pts_2)
30
+ matchesMask = mask.int().ravel().tolist()
31
+
32
+ result_ransac = cv2.drawMatches(
33
+ (image1.cpu().permute(1, 2, 0).numpy() * 255.0).astype(np.uint8),
34
+ to_cv_kpts(kpts_1, score_1),
35
+ (image2.cpu().permute(1, 2, 0).numpy() * 255.0).astype(np.uint8),
36
+ to_cv_kpts(kpts_2, score_2),
37
+ cv2_matches_from_kornia(match_dists, match_idxs),
38
+ None,
39
+ matchColor=(0, 255, 0),
40
+ matchesMask=matchesMask,
41
+ # matchesMask=None, # without RANSAC filtering
42
+ singlePointColor=(0, 0, 255),
43
+ flags=cv2.DrawMatchesFlags_DEFAULT,
44
+ )
45
+
46
+ plt.imshow(result_ransac)
47
+ plt.axis("off")
48
+ plt.tight_layout()
49
+
50
+ # plt.show()
51
+ plt.savefig("result_ransac.png")
imcui/third_party/RIPE/ripe/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .model_zoo import vgg_hyper # noqa: F401
imcui/third_party/RIPE/ripe/benchmarks/imw_2020.py ADDED
@@ -0,0 +1,320 @@
1
+ import os
2
+ from pathlib import Path
3
+
4
+ import cv2
5
+ import kornia.feature as KF
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import poselib
9
+ import torch
10
+ from tqdm import tqdm
11
+
12
+ from ripe import utils
13
+ from ripe.data.data_transforms import Compose, Normalize, Resize
14
+ from ripe.data.datasets.disk_imw import DISK_IMW
15
+ from ripe.utils.pose_error import AUCMetric, relative_pose_error
16
+ from ripe.utils.utils import (
17
+ cv2_matches_from_kornia,
18
+ cv_resize_and_pad_to_shape,
19
+ to_cv_kpts,
20
+ )
21
+
22
+ log = utils.get_pylogger(__name__)
23
+
24
+
25
+ class IMW_2020_Benchmark:
26
+ def __init__(
27
+ self,
28
+ use_predefined_subset: bool = True,
29
+ conf_inference=None,
30
+ edge_input_divisible_by=None,
31
+ ):
32
+ data_dir = os.getenv("DATA_DIR")
33
+ if data_dir is None:
34
+ raise ValueError("Environment variable DATA_DIR is not set.")
35
+ root_path = Path(data_dir) / "disk-data"
36
+
37
+ self.data = DISK_IMW(
38
+ str(
39
+ root_path
40
+ ), # Resize only to ensure that the input size is divisible by the value of edge_input_divisible_by
41
+ transforms=Compose(
42
+ [
43
+ Resize(None, edge_input_divisible_by),
44
+ Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
45
+ ]
46
+ ),
47
+ )
48
+ self.ids_subset = None
49
+ self.results = []
50
+ self.conf_inference = conf_inference
51
+
52
+ # fmt: off
53
+ if use_predefined_subset:
54
+ self.ids_subset = [4921, 3561, 3143, 6040, 802, 6828, 5338, 9275, 10764, 10085, 5124, 11355, 7, 10027, 2161, 4433, 6887, 3311, 10766,
55
+ 11451, 11433, 8539, 2581, 10300, 10562, 1723, 8803, 6275, 10140, 11487, 6238, 638, 8092, 9979, 201, 10394, 3414,
56
+ 9002, 7456, 2431, 632, 6589, 9265, 9889, 3139, 7890, 10619, 4899, 675, 176, 4309, 4814, 3833, 3519, 148, 4560, 10705,
57
+ 3744, 1441, 4049, 1791, 5106, 575, 1540, 1105, 6791, 1383, 9344, 501, 2504, 4335, 8992, 10970, 10786, 10405, 9317,
58
+ 5279, 1396, 5044, 9408, 11125, 10417, 7627, 7480, 1358, 7738, 5461, 10178, 9226, 8106, 2766, 6216, 4032, 7298, 259,
59
+ 3021, 2645, 8756, 7513, 3163, 2510, 6701, 6684, 3159, 9689, 7425, 6066, 1904, 6382, 3052, 777, 6277, 7409, 5997, 2987,
60
+ 11316, 2894, 4528, 1927, 10366, 8605, 2726, 1886, 2416, 2164, 3352, 2997, 6636, 6765, 5609, 3679, 76, 10956, 3612, 6699,
61
+ 1741, 8811, 3755, 1285, 9520, 2476, 3977, 370, 9823, 1834, 7551, 6227, 7303, 6399, 4758, 10713, 5050, 380, 11056, 7620,
62
+ 4826, 6090, 9011, 7523, 7355, 8021, 9801, 1801, 6522, 7138, 10017, 8732, 6402, 3116, 4031, 6088, 3975, 9841, 9082, 9412,
63
+ 5406, 217, 2385, 8791, 8361, 494, 4319, 5275, 3274, 335, 6731, 207, 10095, 3068, 5996, 3951, 2808, 5877, 6134, 7772, 10042,
64
+ 8574, 5501, 10885, 7871]
65
+ # self.ids_subset = self.ids_subset[:10]
66
+ # fmt: on
67
+
68
+ def evaluate_sample(self, model, sample, dev):
69
+ img_1 = sample["src_image"].unsqueeze(0).to(dev)
70
+ img_2 = sample["trg_image"].unsqueeze(0).to(dev)
71
+
72
+ scale_h_1, scale_w_1 = (
73
+ sample["orig_size_src"][0] / img_1.shape[2],
74
+ sample["orig_size_src"][1] / img_1.shape[3],
75
+ )
76
+ scale_h_2, scale_w_2 = (
77
+ sample["orig_size_trg"][0] / img_2.shape[2],
78
+ sample["orig_size_trg"][1] / img_2.shape[3],
79
+ )
80
+
81
+ M = None
82
+ info = {}
83
+ kpts_1, desc_1, score_1 = None, None, None
84
+ kpts_2, desc_2, score_2 = None, None, None
85
+ match_dists, match_idxs = None, None
86
+
87
+ try:
88
+ kpts_1, desc_1, score_1 = model.detectAndCompute(img_1, **self.conf_inference)
89
+ kpts_2, desc_2, score_2 = model.detectAndCompute(img_2, **self.conf_inference)
90
+
91
+ if kpts_1.dim() == 3:
92
+ assert kpts_1.shape[0] == 1 and kpts_2.shape[0] == 1, "Batch size must be 1"
93
+
94
+ kpts_1, desc_1, score_1 = (
95
+ kpts_1.squeeze(0),
96
+ desc_1[0].squeeze(0),
97
+ score_1[0].squeeze(0),
98
+ )
99
+ kpts_2, desc_2, score_2 = (
100
+ kpts_2.squeeze(0),
101
+ desc_2[0].squeeze(0),
102
+ score_2[0].squeeze(0),
103
+ )
104
+
105
+ scale_1 = torch.tensor([scale_w_1, scale_h_1], dtype=torch.float).to(dev)
106
+ scale_2 = torch.tensor([scale_w_2, scale_h_2], dtype=torch.float).to(dev)
107
+
108
+ kpts_1 = kpts_1 * scale_1
109
+ kpts_2 = kpts_2 * scale_2
110
+
111
+ matcher = KF.DescriptorMatcher("mnn") # threshold is not used with mnn
112
+ match_dists, match_idxs = matcher(desc_1, desc_2)
113
+
114
+ matched_pts_1 = kpts_1[match_idxs[:, 0]]
115
+ matched_pts_2 = kpts_2[match_idxs[:, 1]]
116
+
117
+ camera_1 = sample["src_camera"]
118
+ camera_2 = sample["trg_camera"]
119
+
120
+ M, info = poselib.estimate_relative_pose(
121
+ matched_pts_1.cpu().numpy(),
122
+ matched_pts_2.cpu().numpy(),
123
+ camera_1.to_cameradict(),
124
+ camera_2.to_cameradict(),
125
+ {
126
+ "max_epipolar_error": 0.5,
127
+ },
128
+ {},
129
+ )
130
+ except RuntimeError as e:
131
+ if "No keypoints detected" in str(e):
132
+ pass
133
+ else:
134
+ raise e
135
+
136
+ success = M is not None
137
+ if success:
138
+ M = {
139
+ "R": torch.tensor(M.R, dtype=torch.float),
140
+ "t": torch.tensor(M.t, dtype=torch.float),
141
+ }
142
+ inl = info["inliers"]
143
+ else:
144
+ M = {
145
+ "R": torch.eye(3, dtype=torch.float),
146
+ "t": torch.zeros((3), dtype=torch.float),
147
+ }
148
+ inl = np.zeros((0,)).astype(bool)
149
+
150
+ t_err, r_err = relative_pose_error(sample["s2t_R"].cpu(), sample["s2t_T"].cpu(), M["R"], M["t"])
151
+
152
+ rel_pose_error = max(t_err.item(), r_err.item()) if success else np.inf
153
+ ransac_inl = np.sum(inl)
154
+ ransac_inl_ratio = np.mean(inl)
155
+
156
+ if success:
157
+ assert match_dists is not None and match_idxs is not None, "Matches must be computed"
158
+ cv_keypoints_src = to_cv_kpts(kpts_1, score_1)
159
+ cv_keypoints_trg = to_cv_kpts(kpts_2, score_2)
160
+ cv_matches = cv2_matches_from_kornia(match_dists, match_idxs)
161
+ cv_mask = [int(m) for m in inl]
162
+ else:
163
+ cv_keypoints_src, cv_keypoints_trg = [], []
164
+ cv_matches, cv_mask = [], []
165
+
166
+ estimation = {
167
+ "success": success,
168
+ "M_0to1": M,
169
+ "inliers": torch.tensor(inl).to(img_1),
170
+ "rel_pose_error": rel_pose_error,
171
+ "ransac_inl": ransac_inl,
172
+ "ransac_inl_ratio": ransac_inl_ratio,
173
+ "path_src_image": sample["src_path"],
174
+ "path_trg_image": sample["trg_path"],
175
+ "cv_keypoints_src": cv_keypoints_src,
176
+ "cv_keypoints_trg": cv_keypoints_trg,
177
+ "cv_matches": cv_matches,
178
+ "cv_mask": cv_mask,
179
+ }
180
+
181
+ return estimation
182
+
183
+ def evaluate(self, model, dev, progress_bar=False):
184
+ model.eval()
185
+
186
+ # reset results
187
+ self.results = []
188
+
189
+ for idx in tqdm(
190
+ self.ids_subset if self.ids_subset is not None else range(len(self.data)),
191
+ disable=not progress_bar,
192
+ ):
193
+ sample = self.data[idx]
194
+ self.results.append(self.evaluate_sample(model, sample, dev))
195
+
196
+ def get_auc(self, threshold=5, downsampled=False):
197
+ if len(self.results) == 0:
198
+ raise ValueError("No results to log. Run evaluate first.")
199
+
200
+ summary_results = self.calc_auc(downsampled=downsampled)
201
+
202
+ return summary_results[f"rel_pose_error@{threshold}°{'__original' if not downsampled else '__downsampled'}"]
203
+
204
+ def plot_results(self, num_samples=10, logger=None, step=None, downsampled=False):
205
+ if len(self.results) == 0:
206
+ raise ValueError("No results to plot. Run evaluate first.")
207
+
208
+ plot_data = []
209
+
210
+ for result in self.results[:num_samples]:
211
+ img1 = cv2.imread(result["path_src_image"])
212
+ img2 = cv2.imread(result["path_trg_image"])
213
+
214
+ # from BGR to RGB
215
+ img1 = cv2.cvtColor(img1, cv2.COLOR_BGR2RGB)
216
+ img2 = cv2.cvtColor(img2, cv2.COLOR_BGR2RGB)
217
+
218
+ plt_matches = cv2.drawMatches(
219
+ img1,
220
+ result["cv_keypoints_src"],
221
+ img2,
222
+ result["cv_keypoints_trg"],
223
+ result["cv_matches"],
224
+ None,
225
+ matchColor=None,
226
+ matchesMask=result["cv_mask"],
227
+ flags=cv2.DrawMatchesFlags_DEFAULT,
228
+ )
229
+ file_name = (
230
+ Path(result["path_src_image"]).parent.parent.name
231
+ + "_"
232
+ + Path(result["path_src_image"]).stem
233
+ + Path(result["path_trg_image"]).stem
234
+ + ("_downsampled" if downsampled else "")
235
+ + ".png"
236
+ )
237
+ # print rel_pose_error on image
238
+ plt_matches = cv2.putText(
239
+ plt_matches,
240
+ f"rel_pose_error: {result['rel_pose_error']:.2f} num_inliers: {result['ransac_inl']} inl_ratio: {result['ransac_inl_ratio']:.2f} num_matches: {len(result['cv_matches'])} num_keypoints: {len(result['cv_keypoints_src'])}/{len(result['cv_keypoints_trg'])}",
241
+ (10, 30),
242
+ cv2.FONT_HERSHEY_SIMPLEX,
243
+ 1,
244
+ (0, 0, 0),
245
+ 2,
246
+ cv2.LINE_8,
247
+ )
248
+
249
+ plot_data.append({"file_name": file_name, "image": plt_matches})
250
+
251
+ if logger is None:
252
+ log.info("No logger provided. Using plt to plot results.")
253
+ for image in plot_data:
254
+ plt.imsave(
255
+ image["file_name"],
256
+ cv_resize_and_pad_to_shape(image["image"], (1024, 2048)),
257
+ )
258
+ plt.close()
259
+ else:
260
+ import wandb
261
+
262
+ log.info(f"Logging images to wandb with step={step}")
263
+ if not downsampled:
264
+ logger.log(
265
+ {
266
+ "examples": [
267
+ wandb.Image(cv_resize_and_pad_to_shape(image["image"], (1024, 2048))) for image in plot_data
268
+ ]
269
+ },
270
+ step=step,
271
+ )
272
+ else:
273
+ logger.log(
274
+ {
275
+ "examples_downsampled": [
276
+ wandb.Image(cv_resize_and_pad_to_shape(image["image"], (1024, 2048))) for image in plot_data
277
+ ]
278
+ },
279
+ step=step,
280
+ )
281
+
282
+ def log_results(self, logger=None, step=None, downsampled=False):
283
+ if len(self.results) == 0:
284
+ raise ValueError("No results to log. Run evaluate first.")
285
+
286
+ summary_results = self.calc_auc(downsampled=downsampled)
287
+
288
+ if logger is not None:
289
+ logger.log(summary_results, step=step)
290
+ else:
291
+ log.warning("No logger provided. Printing results instead.")
292
+ print(self.calc_auc())
293
+
294
+ def print_results(self):
295
+ if len(self.results) == 0:
296
+ raise ValueError("No results to print. Run evaluate first.")
297
+
298
+ print(self.calc_auc())
299
+
300
+ def calc_auc(self, auc_thresholds=None, downsampled=False):
301
+ if auc_thresholds is None:
302
+ auc_thresholds = [5, 10, 20]
303
+ if not isinstance(auc_thresholds, list):
304
+ auc_thresholds = [auc_thresholds]
305
+
306
+ if len(self.results) == 0:
307
+ raise ValueError("No results to calculate auc. Run evaluate first.")
308
+
309
+ rel_pose_errors = [r["rel_pose_error"] for r in self.results]
310
+
311
+ pose_aucs = AUCMetric(auc_thresholds, rel_pose_errors).compute()
312
+ assert isinstance(pose_aucs, list) and len(pose_aucs) == len(auc_thresholds)
313
+
314
+ ext = "_downsampled" if downsampled else "_original"
315
+
316
+ summary = {}
317
+ for i, ath in enumerate(auc_thresholds):
318
+ summary[f"rel_pose_error@{ath}°_{ext}"] = pose_aucs[i]
319
+
320
+ return summary
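A hedged sketch of driving this benchmark outside the training loop. It assumes the DISK `disk-data` layout has been downloaded (see `data/download_disk_data.sh`) and that `DATA_DIR` points at its parent directory; the `conf_inference` dict mirrors the values in `conf/train.yaml`, and the data path below is a placeholder.

```python
# Minimal usage sketch for IMW_2020_Benchmark; paths are placeholders.
import os

import torch

from ripe import vgg_hyper
from ripe.benchmarks.imw_2020 import IMW_2020_Benchmark

os.environ.setdefault("DATA_DIR", "/path/to/data")  # must contain disk-data/imw2020-val

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = vgg_hyper().to(dev)

benchmark = IMW_2020_Benchmark(
    use_predefined_subset=True,
    conf_inference={"threshold": 0.5, "top_k": 2048},
)
benchmark.evaluate(model, dev, progress_bar=True)
benchmark.print_results()              # pose AUC at 5/10/20 degrees
print(benchmark.get_auc(threshold=5))  # single AUC value on the original-resolution images
```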
imcui/third_party/RIPE/ripe/data/__init__.py ADDED
File without changes
imcui/third_party/RIPE/ripe/data/data_transforms.py ADDED
@@ -0,0 +1,204 @@
1
+ import collections
2
+ import collections.abc
3
+
4
+ import kornia.geometry as KG
5
+ import numpy as np
6
+ import torch
7
+ from torchvision.transforms import functional as TF
8
+
9
+
10
+ class Compose:
11
+ """Composes several transforms together. The transforms are applied in the order they are passed in.
12
+ Args: transforms (list): A list of transforms to be applied.
13
+ """
14
+
15
+ def __init__(self, transforms):
16
+ self.transforms = transforms
17
+
18
+ def __call__(self, src, trg, src_mask, trg_mask, h):
19
+ for t in self.transforms:
20
+ src, trg, src_mask, trg_mask, h = t(src, trg, src_mask, trg_mask, h)
21
+
22
+ return src, trg, src_mask, trg_mask, h
23
+
24
+
25
+ class Transform:
26
+ """Base class for all transforms. It provides a method to apply a transformation function to the input images and masks.
27
+ Args:
28
+ src (torch.Tensor): The source image tensor.
29
+ trg (torch.Tensor): The target image tensor.
30
+ src_mask (torch.Tensor): The source image mask tensor.
31
+ trg_mask (torch.Tensor): The target image mask tensor.
32
+ h (torch.Tensor): The homography matrix tensor.
33
+ Returns:
34
+ tuple: A tuple containing the transformed source image, the transformed target image, the transformed source mask,
35
+ the transformed target mask and the updated homography matrix.
36
+ """
37
+
38
+ def __init__(self):
39
+ pass
40
+
41
+ def apply_transform(self, src, trg, src_mask, trg_mask, h, transfrom_function):
42
+ src, trg, src_mask, trg_mask, h = transfrom_function(src, trg, src_mask, trg_mask, h)
43
+ return src, trg, src_mask, trg_mask, h
44
+
45
+
46
+ class Normalize(Transform):
47
+ def __init__(self, mean, std):
48
+ self.mean = mean
49
+ self.std = std
50
+
51
+ def __call__(self, src, trg, src_mask, trg_mask, h):
52
+ return self.apply_transform(src, trg, src_mask, trg_mask, h, self.transform_function)
53
+
54
+ def transform_function(self, src, trg, src_mask, trg_mask, h):
55
+ src = TF.normalize(src, mean=self.mean, std=self.std)
56
+ trg = TF.normalize(trg, mean=self.mean, std=self.std)
57
+ return src, trg, src_mask, trg_mask, h
58
+
59
+
60
+ class ResizeAndPadWithHomography(Transform):
61
+ def __init__(self, target_size_longer_side=768):
62
+ self.target_size = target_size_longer_side
63
+
64
+ def __call__(self, src, trg, src_mask, trg_mask, h):
65
+ return self.apply_transform(src, trg, src_mask, trg_mask, h, self.transform_function)
66
+
67
+ def transform_function(self, src, trg, src_mask, trg_mask, h):
68
+ src_w, src_h = src.shape[-1], src.shape[-2]
69
+ trg_w, trg_h = trg.shape[-1], trg.shape[-2]
70
+
71
+ # Resizing logic for both images
72
+ scale_src, new_src_w, new_src_h = self.compute_resize(src_w, src_h)
73
+ scale_trg, new_trg_w, new_trg_h = self.compute_resize(trg_w, trg_h)
74
+
75
+ # Resize both images
76
+ src_resized = TF.resize(src, [new_src_h, new_src_w])
77
+ trg_resized = TF.resize(trg, [new_trg_h, new_trg_w])
78
+
79
+ src_mask_resized = TF.resize(src_mask, [new_src_h, new_src_w])
80
+ trg_mask_resized = TF.resize(trg_mask, [new_trg_h, new_trg_w])
81
+
82
+ # Pad the resized images to be square (768x768)
83
+ src_padded, src_padding = self.apply_padding(src_resized, new_src_w, new_src_h)
84
+ trg_padded, trg_padding = self.apply_padding(trg_resized, new_trg_w, new_trg_h)
85
+
86
+ src_mask_padded, _ = self.apply_padding(src_mask_resized, new_src_w, new_src_h)
87
+ trg_mask_padded, _ = self.apply_padding(trg_mask_resized, new_trg_w, new_trg_h)
88
+
89
+ # Update the homography matrix
90
+ h = self.update_homography(h, scale_src, src_padding, scale_trg, trg_padding)
91
+
92
+ return src_padded, trg_padded, src_mask_padded, trg_mask_padded, h
93
+
94
+ def compute_resize(self, w, h):
95
+ if w > h:
96
+ scale = self.target_size / w
97
+ new_w = self.target_size
98
+ new_h = int(h * scale)
99
+ else:
100
+ scale = self.target_size / h
101
+ new_h = self.target_size
102
+ new_w = int(w * scale)
103
+ return scale, new_w, new_h
104
+
105
+ def apply_padding(self, img, new_w, new_h):
106
+ pad_w = (self.target_size - new_w) // 2
107
+ pad_h = (self.target_size - new_h) // 2
108
+ padding = [
109
+ pad_w,
110
+ pad_h,
111
+ self.target_size - new_w - pad_w,
112
+ self.target_size - new_h - pad_h,
113
+ ]
114
+ img_padded = TF.pad(img, padding, fill=0) # Zero-pad
115
+ return img_padded, padding
116
+
117
+ def update_homography(self, h, scale_src, padding_src, scale_trg, padding_trg):
118
+ # Create the scaling matrices
119
+ scale_matrix_src = np.array([[scale_src, 0, 0], [0, scale_src, 0], [0, 0, 1]])
120
+ scale_matrix_trg = np.array([[scale_trg, 0, 0], [0, scale_trg, 0], [0, 0, 1]])
121
+
122
+ # Create the padding translation matrices
123
+ pad_matrix_src = np.array([[1, 0, padding_src[0]], [0, 1, padding_src[1]], [0, 0, 1]])
124
+ pad_matrix_trg = np.array([[1, 0, -padding_trg[0]], [0, 1, -padding_trg[1]], [0, 0, 1]])
125
+
126
+ # Update the homography: apply scaling and translation
127
+ h_updated = (
128
+ pad_matrix_trg
129
+ @ scale_matrix_trg
130
+ @ h.numpy()
131
+ @ np.linalg.inv(scale_matrix_src)
132
+ @ np.linalg.inv(pad_matrix_src)
133
+ )
134
+
135
+ return torch.from_numpy(h_updated).float()
136
+
137
+
138
+ class Resize(Transform):
139
+ def __init__(self, output_size, edge_divisible_by=None, side="long", antialias=True):
140
+ self.output_size = output_size
141
+ self.edge_divisible_by = edge_divisible_by
142
+ self.side = side
143
+ self.antialias = antialias
144
+
145
+ def __call__(self, src, trg, src_mask, trg_mask, h):
146
+ return self.apply_transform(src, trg, src_mask, trg_mask, h, self.transform_function)
147
+
148
+ def transform_function(self, src, trg, src_mask, trg_mask, h):
149
+ new_size_src = self.get_new_image_size(src)
150
+ new_size_trg = self.get_new_image_size(trg)
151
+
152
+ src, T_src = self.resize(src, new_size_src)
153
+ trg, T_trg = self.resize(trg, new_size_trg)
154
+
155
+ src_mask, _ = self.resize(src_mask, new_size_src)
156
+ trg_mask, _ = self.resize(trg_mask, new_size_trg)
157
+
158
+ h = torch.from_numpy(T_trg @ h.numpy() @ T_src).float()
159
+
160
+ return src, trg, src_mask, trg_mask, h
161
+
162
+ def resize(self, img, size):
163
+ h, w = img.shape[-2:]
164
+
165
+ img = KG.transform.resize(
166
+ img,
167
+ size,
168
+ side=self.side,
169
+ antialias=self.antialias,
170
+ align_corners=None,
171
+ interpolation="bilinear",
172
+ )
173
+
174
+ scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
175
+ T = np.diag([scale[0].item(), scale[1].item(), 1])
176
+
177
+ return img, T
178
+
179
+ def get_new_image_size(self, img):
180
+ h, w = img.shape[-2:]
181
+
182
+ if isinstance(self.output_size, collections.abc.Iterable):
183
+ assert len(self.output_size) == 2
184
+ return tuple(self.output_size)
185
+ if self.output_size is None: # keep the original size, but possibly make it divisible by edge_divisible_by
186
+ size = (h, w)
187
+ else:
188
+ side_size = self.output_size
189
+ aspect_ratio = w / h
190
+ if self.side not in ("short", "long", "vert", "horz"):
191
+ raise ValueError(f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{self.side}'")
192
+ if self.side == "vert":
193
+ size = side_size, int(side_size * aspect_ratio)
194
+ elif self.side == "horz":
195
+ size = int(side_size / aspect_ratio), side_size
196
+ elif (self.side == "short") ^ (aspect_ratio < 1.0):
197
+ size = side_size, int(side_size * aspect_ratio)
198
+ else:
199
+ size = int(side_size / aspect_ratio), side_size
200
+
201
+ if self.edge_divisible_by is not None:
202
+ df = self.edge_divisible_by
203
+ size = list(map(lambda x: int(x // df * df), size))
204
+ return size
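All transforms above take and return the five-tuple (src, trg, src_mask, trg_mask, h), so the homography stays consistent with any resizing or padding. A small synthetic sketch with random tensors and illustrative values:

```python
# Synthetic sketch of the paired-transform pipeline defined above.
import torch

from ripe.data.data_transforms import Compose, Normalize, Resize

transforms = Compose(
    [
        Resize(None, edge_divisible_by=16),  # keep the size, round edges down to a multiple of 16
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

src = torch.rand(3, 480, 640)
trg = torch.rand(3, 600, 800)
src_mask = torch.ones(1, 480, 640)
trg_mask = torch.ones(1, 600, 800)
h = torch.eye(3)  # identity homography for a positive pair

src, trg, src_mask, trg_mask, h = transforms(src, trg, src_mask, trg_mask, h)
print(src.shape, trg.shape)  # (3, 480, 640) and (3, 592, 800)
print(h)                     # homography rescaled to the new target resolution
```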
imcui/third_party/RIPE/ripe/data/datasets/__init__.py ADDED
File without changes
imcui/third_party/RIPE/ripe/data/datasets/acdc.py ADDED
@@ -0,0 +1,154 @@
1
+ from pathlib import Path
2
+ from typing import Any, Callable, Dict, Optional
3
+
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from torchvision.io import read_image
7
+
8
+ from ripe import utils
9
+ from ripe.data.data_transforms import Compose
10
+ from ripe.utils.utils import get_other_random_id
11
+
12
+ log = utils.get_pylogger(__name__)
13
+
14
+
15
+ class ACDC(Dataset):
16
+ def __init__(
17
+ self,
18
+ root: Path,
19
+ stage: str = "train",
20
+ condition: str = "rain",
21
+ transforms: Optional[Callable] = None,
22
+ positive_only: bool = False,
23
+ ) -> None:
24
+ self.root = root
25
+ self.stage = stage
26
+ self.condition = condition
27
+ self.transforms = transforms
28
+ self.positive_only = positive_only
29
+
30
+ if isinstance(self.root, str):
31
+ self.root = Path(self.root)
32
+
33
+ if not self.root.exists():
34
+ raise FileNotFoundError(f"Dataset not found at {self.root}")
35
+
36
+ if transforms is None:
37
+ self.transforms = Compose([])
38
+ else:
39
+ self.transforms = transforms
40
+
41
+ if self.stage not in ["train", "val", "test", "pred"]:
42
+ raise RuntimeError(
43
+ "Unknown option "
44
+ + self.stage
45
+ + " as training stage variable. Valid options: 'train', 'val', 'test' and 'pred'"
46
+ )
47
+
48
+ if self.stage == "pred": # prediction uses the test set
49
+ self.stage = "test"
50
+
51
+ if self.stage in ["val", "test", "pred"]:
52
+ self.positive_only = True
53
+ log.info(f"{self.stage} stage: Using only positive pairs!")
54
+
55
+ weather_conditions = ["fog", "night", "rain", "snow"]
56
+
57
+ if self.condition not in weather_conditions + ["all"]:
58
+ raise RuntimeError(
59
+ "Unknown option "
60
+ + self.condition
61
+ + " as weather condition variable. Valid options: 'fog', 'night', 'rain', 'snow' and 'all'"
62
+ )
63
+
64
+ self.weather_condition_query = weather_conditions if self.condition == "all" else [self.condition]
65
+
66
+ self._read_sample_files()
67
+
68
+ if positive_only:
69
+ log.warning("Using only positive pairs!")
70
+ log.info(f"Found {len(self.src_images)} source images and {len(self.trg_images)} target images.")
71
+
72
+ def _read_sample_files(self):
73
+ file_name_pattern_ref = "_ref_anon.png"
74
+ file_name_pattern = "_rgb_anon.png"
75
+
76
+ self.trg_images = []
77
+ self.src_images = []
78
+
79
+ for weather_condition in self.weather_condition_query:
80
+ rgb_files = sorted(
81
+ list(self.root.glob("rgb_anon/" + weather_condition + "/" + self.stage + "/**/*" + file_name_pattern)),
82
+ key=lambda i: i.stem[:21],
83
+ )
84
+
85
+ src_images = sorted(
86
+ list(
87
+ self.root.glob(
88
+ "rgb_anon/" + weather_condition + "/" + self.stage + "_ref" + "/**/*" + file_name_pattern_ref
89
+ )
90
+ ),
91
+ key=lambda i: i.stem[:21],
92
+ )
93
+
94
+ self.trg_images += rgb_files
95
+ self.src_images += src_images
96
+
97
+ def __len__(self) -> int:
98
+ if self.positive_only:
99
+ return len(self.trg_images)
100
+ return 2 * len(self.trg_images)
101
+
102
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
103
+ sample: Any = {}
104
+
105
+ positive_sample = (idx % 2 == 0) or (self.positive_only)
106
+ if not self.positive_only:
107
+ idx = idx // 2
108
+
109
+ sample["label"] = positive_sample
110
+
111
+ if positive_sample:
112
+ sample["src_path"] = str(self.src_images[idx])
113
+ sample["trg_path"] = str(self.trg_images[idx])
114
+
115
+ assert self.src_images[idx].stem[:21] == self.trg_images[idx].stem[:21], (
116
+ f"Source and target image mismatch: {self.src_images[idx]} vs {self.trg_images[idx]}"
117
+ )
118
+
119
+ src_img = read_image(sample["src_path"])
120
+ trg_img = read_image(sample["trg_path"])
121
+
122
+ homography = torch.eye(3, dtype=torch.float32)
123
+ else:
124
+ sample["src_path"] = str(self.src_images[idx])
125
+ idx_other = get_other_random_id(idx, len(self) // 2)
126
+ sample["trg_path"] = str(self.trg_images[idx_other])
127
+
128
+ assert self.src_images[idx].stem[:21] != self.trg_images[idx_other].stem[:21], (
129
+ f"Source and target image match for negative sample: {self.src_images[idx]} vs {self.trg_images[idx_other]}"
130
+ )
131
+
132
+ src_img = read_image(sample["src_path"])
133
+ trg_img = read_image(sample["trg_path"])
134
+
135
+ homography = torch.zeros((3, 3), dtype=torch.float32)
136
+
137
+ src_img = src_img / 255.0
138
+ trg_img = trg_img / 255.0
139
+
140
+ _, H, W = src_img.shape
141
+
142
+ src_mask = torch.ones((1, H, W), dtype=torch.uint8)
143
+ trg_mask = torch.ones((1, H, W), dtype=torch.uint8)
144
+
145
+ if self.transforms:
146
+ src_img, trg_img, src_mask, trg_mask, _ = self.transforms(src_img, trg_img, src_mask, trg_mask, homography)
147
+
148
+ sample["src_image"] = src_img
149
+ sample["trg_image"] = trg_img
150
+ sample["src_mask"] = src_mask.to(torch.bool)
151
+ sample["trg_mask"] = trg_mask.to(torch.bool)
152
+ sample["homography"] = homography
153
+
154
+ return sample
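For reference, a hypothetical instantiation of this dataset with the transform stack from `conf/data/megadepth+acdc.yaml`; the root path is a placeholder and the ACDC images must already be on disk.

```python
# Hypothetical ACDC usage sketch; root is a placeholder path.
from torch.utils.data import DataLoader

from ripe.data.data_transforms import Compose, Normalize, ResizeAndPadWithHomography
from ripe.data.datasets.acdc import ACDC

dataset = ACDC(
    root="/path/to/ACDC",   # placeholder
    stage="train",
    condition="all",        # 'fog', 'night', 'rain', 'snow' or 'all'
    transforms=Compose(
        [
            Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ResizeAndPadWithHomography(target_size_longer_side=560),
        ]
    ),
)

loader = DataLoader(dataset, batch_size=6, num_workers=8, shuffle=True)
sample = dataset[0]  # dict with src/trg images, masks, homography and the positive/negative label
```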
imcui/third_party/RIPE/ripe/data/datasets/dataset_combinator.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch
2
+
3
+ from ripe import utils
4
+
5
+ log = utils.get_pylogger(__name__)
6
+
7
+
8
+ class DatasetCombinator:
9
+ """Combines multiple datasets into one. Length of the combined dataset is the length of the
10
+ longest dataset. Shorter datasets are looped over.
11
+
12
+ Args:
13
+ datasets: List of datasets to combine.
14
+ mode: How to sample from the datasets. Can be either "uniform" or "weighted".
15
+ In "uniform" mode, each dataset is sampled with equal probability.
16
+ In "weighted" mode, each dataset is sampled with probability proportional to its length.
17
+ """
18
+
19
+ def __init__(self, datasets, mode="uniform", weights=None):
20
+ self.datasets = datasets
21
+
22
+ names_datasets = [type(ds).__name__ for ds in self.datasets]
23
+ self.lengths = [len(ds) for ds in datasets]
24
+
25
+ if mode == "weighted":
26
+ self.probs_datasets = [length / sum(self.lengths) for length in self.lengths]
27
+ elif mode == "uniform":
28
+ self.probs_datasets = [1 / len(self.datasets) for _ in self.datasets]
29
+ elif mode == "custom":
30
+ assert weights is not None, "Weights must be provided in custom mode"
31
+ assert len(weights) == len(datasets), "Number of weights must match number of datasets"
32
+ assert sum(weights) == 1.0, "Weights must sum to 1"
33
+ self.probs_datasets = weights
34
+ else:
35
+ raise ValueError(f"Unknown mode {mode}")
36
+
37
+ log.info("Got the following datasets: ")
38
+
39
+ for name, length, prob in zip(names_datasets, self.lengths, self.probs_datasets):
40
+ log.info(f"{name} with {length} samples and probability {prob}")
41
+ log.info(f"Total number of samples: {sum(self.lengths)}")
42
+
43
+ self.num_samples = max(self.lengths)
44
+
45
+ self.dataset_dist = torch.distributions.Categorical(probs=torch.tensor(self.probs_datasets))
46
+
47
+ def __len__(self):
48
+ return self.num_samples
49
+
50
+ def __getitem__(self, idx: int):
51
+ positive_sample = idx % 2 == 0
52
+
53
+ if positive_sample:
54
+ dataset_idx = self.dataset_dist.sample().item()
55
+
56
+ idx = torch.randint(0, self.lengths[dataset_idx], (1,)).item()
57
+ while idx % 2 == 1:
58
+ idx = torch.randint(0, self.lengths[dataset_idx], (1,)).item()
59
+
60
+ return self.datasets[dataset_idx][idx]
61
+ else:
62
+ dataset_idx_1 = self.dataset_dist.sample().item()
63
+ dataset_idx_2 = self.dataset_dist.sample().item()
64
+
65
+ if dataset_idx_1 == dataset_idx_2:
66
+ idx = torch.randint(0, self.lengths[dataset_idx_1], (1,)).item()
67
+ while idx % 2 == 0:
68
+ idx = torch.randint(0, self.lengths[dataset_idx_1], (1,)).item()
69
+ return self.datasets[dataset_idx_1][idx]
70
+
71
+ else:
72
+ idx_1 = torch.randint(0, self.lengths[dataset_idx_1], (1,)).item()
73
+ idx_2 = torch.randint(0, self.lengths[dataset_idx_2], (1,)).item()
74
+
75
+ sample_1 = self.datasets[dataset_idx_1][idx_1]
76
+ sample_2 = self.datasets[dataset_idx_2][idx_2]
77
+
78
+ sample = {
79
+ "label": False,
80
+ "src_path": sample_1["src_path"],
81
+ "trg_path": sample_2["trg_path"],
82
+ "src_image": sample_1["src_image"],
83
+ "trg_image": sample_2["trg_image"],
84
+ "src_mask": sample_1["src_mask"],
85
+ "trg_mask": sample_2["trg_mask"],
86
+ "homography": sample_2["homography"],
87
+ }
88
+ return sample
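A self-contained sketch of the combinator with two toy datasets standing in for the real MegaDepth/ACDC/Tokyo datasets (which need data on disk); `ToyPairs` is a hypothetical stand-in that only mimics the sample dict the combinator expects.

```python
# Hedged sketch: combine two stand-in datasets and sample from them.
import torch
from torch.utils.data import Dataset

from ripe.data.datasets.dataset_combinator import DatasetCombinator


class ToyPairs(Dataset):
    """Stand-in dataset producing the sample dict the combinator expects (even idx = positive pair)."""

    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        positive = idx % 2 == 0
        return {
            "label": positive,
            "src_path": f"src_{idx}.png",
            "trg_path": f"trg_{idx}.png",
            "src_image": torch.rand(3, 560, 560),
            "trg_image": torch.rand(3, 560, 560),
            "src_mask": torch.ones(1, 560, 560, dtype=torch.bool),
            "trg_mask": torch.ones(1, 560, 560, dtype=torch.bool),
            "homography": torch.eye(3) if positive else torch.zeros(3, 3),
        }


combined = DatasetCombinator([ToyPairs(100), ToyPairs(40)], mode="weighted")
print(len(combined))    # length of the longest dataset (100)
positive = combined[0]  # even index: positive pair drawn from one dataset
negative = combined[1]  # odd index: negative pair, possibly mixing the two datasets
```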
imcui/third_party/RIPE/ripe/data/datasets/disk_imw.py ADDED
@@ -0,0 +1,160 @@
1
+ import json
2
+ import random
3
+ from itertools import accumulate
4
+ from pathlib import Path
5
+ from typing import Any, Callable, Dict, Optional, Tuple
6
+
7
+ import torch
8
+ from torch.utils.data import Dataset
9
+ from torchvision.io import read_image
10
+
11
+ from ripe import utils
12
+ from ripe.data.data_transforms import Compose
13
+ from ripe.utils.image_utils import Camera, cameras2F
14
+
15
+ log = utils.get_pylogger(__name__)
16
+
17
+
18
+ class DISK_IMW(Dataset):
19
+ def __init__(
20
+ self,
21
+ root: str,
22
+ stage: str = "val",
23
+ # condition: str = "rain",
24
+ transforms: Optional[Callable] = None,
25
+ ) -> None:
26
+ self.root = root
27
+ self.stage = stage
28
+ self.transforms = transforms
29
+
30
+ if isinstance(self.root, str):
31
+ self.root = Path(self.root)
32
+
33
+ if not self.root.exists():
34
+ raise FileNotFoundError(f"Dataset not found at {self.root}")
35
+
36
+ if transforms is None:
37
+ self.transforms = Compose([])
38
+ else:
39
+ self.transforms = transforms
40
+
41
+ if self.stage not in ["val"]:
42
+ raise RuntimeError("Unknown option " + self.stage + " as training stage variable. Valid options: 'val'")
43
+
44
+ json_path = self.root / "imw2020-val" / "dataset.json"
45
+ with open(json_path) as json_file:
46
+ json_data = json.load(json_file)
47
+
48
+ self.scenes = []
49
+
50
+ for scene in json_data:
51
+ self.scenes.append(Scene(self.root / "imw2020-val", json_data[scene]))
52
+
53
+ self.tuples_per_scene = [len(scene) for scene in self.scenes]
54
+
55
+ def __len__(self) -> int:
56
+ return sum(self.tuples_per_scene)
57
+
58
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
59
+ sample: Any = {}
60
+
61
+ i_scene, i_image = self._get_scene_and_image_id_from_idx(idx)
62
+
63
+ sample["src_path"], sample["trg_path"], path_calib_src, path_calib_trg = self.scenes[i_scene][i_image]
64
+
65
+ cam_src = Camera.from_calibration_file(path_calib_src)
66
+ cam_trg = Camera.from_calibration_file(path_calib_trg)
67
+
68
+ F = self.get_F(cam_src, cam_trg)
69
+ s2t_R, s2t_T = self.get_relative_pose(cam_src, cam_trg)
70
+
71
+ src_img = read_image(sample["src_path"]) / 255.0
72
+ trg_img = read_image(sample["trg_path"]) / 255.0
73
+
74
+ _, H_src, W_src = src_img.shape
75
+ _, H_trg, W_trg = trg_img.shape
76
+
77
+ src_mask = torch.ones((1, H_src, W_src), dtype=torch.uint8)
78
+ trg_mask = torch.ones((1, H_trg, W_trg), dtype=torch.uint8)
79
+
80
+ H = torch.eye(3)
81
+ if self.transforms:
82
+ src_img, trg_img, src_mask, trg_mask, _ = self.transforms(src_img, trg_img, src_mask, trg_mask, H)
83
+
84
+ # check if transformations in self.transforms. Only Normalize is allowed
85
+ for t in self.transforms.transforms:
86
+ if t.__class__.__name__ not in ["Normalize", "Resize"]:
87
+ raise ValueError(f"Transform {t.__class__.__name__} not allowed in DISK_IMW dataset")
88
+
89
+ sample["src_image"] = src_img
90
+ sample["trg_image"] = trg_img
91
+ sample["orig_size_src"] = (H_src, W_src)
92
+ sample["orig_size_trg"] = (H_trg, W_trg)
93
+ sample["src_mask"] = src_mask.to(torch.bool)
94
+ sample["trg_mask"] = trg_mask.to(torch.bool)
95
+ sample["F"] = F
96
+ sample["s2t_R"] = s2t_R
97
+ sample["s2t_T"] = s2t_T
98
+ sample["src_camera"] = cam_src
99
+ sample["trg_camera"] = cam_trg
100
+
101
+ return sample
102
+
103
+ def get_relative_pose(self, cam_src: Camera, cam_trg: Camera) -> Tuple[torch.Tensor, torch.Tensor]:
104
+ R = cam_trg.R @ cam_src.R.T
105
+ T = cam_trg.t - R @ cam_src.t
106
+
107
+ return R, T
108
+
109
+ def get_F(self, cam_src: Camera, cam_trg: Camera) -> torch.Tensor:
110
+ F = cameras2F(cam_src, cam_trg)
111
+
112
+ return F
113
+
114
+ def _get_scene_and_image_id_from_idx(self, idx: int) -> Tuple[int, int]:
115
+ accumulated_tuples = accumulate(self.tuples_per_scene)
116
+
117
+ if idx >= sum(self.tuples_per_scene):
118
+ raise IndexError(f"Index {idx} out of bounds")
119
+
120
+ idx_scene = None
121
+ for i, accumulated_tuple in enumerate(accumulated_tuples):
122
+ idx_scene = i
123
+ if idx < accumulated_tuple:
124
+ break
125
+
126
+ idx_image = idx - sum(self.tuples_per_scene[:idx_scene])
127
+
128
+ return idx_scene, idx_image
129
+
130
+ def _get_other_random_scene_and_image_id(self, scene_id_to_exclude: int) -> Tuple[int, int]:
131
+ possible_scene_ids = list(range(len(self.scenes)))
132
+ possible_scene_ids.remove(scene_id_to_exclude)
133
+
134
+ idx_scene = random.choice(possible_scene_ids)
135
+ idx_image = random.randint(0, len(self.scenes[idx_scene]) - 1)
136
+
137
+ return idx_scene, idx_image
138
+
139
+
140
+ class Scene:
141
+ def __init__(self, root_path, scene_data: Dict[str, Any]) -> None:
142
+ self.root_path = root_path
143
+ self.image_path = Path(scene_data["image_path"])
144
+ self.calib_path = Path(scene_data["calib_path"])
145
+ self.image_names = scene_data["images"]
146
+ self.tuples = scene_data["tuples"]
147
+
148
+ def __len__(self) -> int:
149
+ return len(self.tuples)
150
+
151
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
152
+ idx_1 = self.tuples[idx][0]
153
+ idx_2 = self.tuples[idx][1]
154
+
155
+ path_image_1 = str(self.root_path / self.image_path / self.image_names[idx_1]) + ".jpg"
156
+ path_image_2 = str(self.root_path / self.image_path / self.image_names[idx_2]) + ".jpg"
157
+ path_calib_1 = str(self.root_path / self.calib_path / ("calibration_" + self.image_names[idx_1])) + ".h5"
158
+ path_calib_2 = str(self.root_path / self.calib_path / ("calibration_" + self.image_names[idx_2])) + ".h5"
159
+
160
+ return path_image_1, path_image_2, path_calib_1, path_calib_2
imcui/third_party/RIPE/ripe/data/datasets/disk_megadepth.py ADDED
@@ -0,0 +1,157 @@
1
+ import json
2
+ import random
3
+ from itertools import accumulate
4
+ from pathlib import Path
5
+ from typing import Any, Callable, Dict, Optional, Tuple
6
+
7
+ import torch
8
+ from torch.utils.data import Dataset
9
+ from torchvision.io import read_image
10
+
11
+ from ripe import utils
12
+ from ripe.data.data_transforms import Compose
13
+
14
+ log = utils.get_pylogger(__name__)
15
+
16
+
17
+ class DISK_Megadepth(Dataset):
18
+ def __init__(
19
+ self,
20
+ root: str,
21
+ max_scene_size: int,
22
+ stage: str = "train",
23
+ # condition: str = "rain",
24
+ transforms: Optional[Callable] = None,
25
+ positive_only: bool = False,
26
+ ) -> None:
27
+ self.root = root
28
+ self.stage = stage
29
+ self.transforms = transforms
30
+ self.positive_only = positive_only
31
+
32
+ if isinstance(self.root, str):
33
+ self.root = Path(self.root)
34
+
35
+ if not self.root.exists():
36
+ raise FileNotFoundError(f"Dataset not found at {self.root}")
37
+
38
+ if transforms is None:
39
+ self.transforms = Compose([])
40
+ else:
41
+ self.transforms = transforms
42
+
43
+ if self.stage not in ["train"]:
44
+ raise RuntimeError("Unknown option " + self.stage + " as training stage variable. Valid options: 'train'")
45
+
46
+ json_path = self.root / "megadepth" / "dataset.json"
47
+ with open(json_path) as json_file:
48
+ json_data = json.load(json_file)
49
+
50
+ self.scenes = []
51
+
52
+ for scene in json_data:
53
+ self.scenes.append(Scene(self.root / "megadepth", json_data[scene], max_scene_size))
54
+
55
+ self.tuples_per_scene = [len(scene) for scene in self.scenes]
56
+
57
+ if positive_only:
58
+ log.warning("Using only positive pairs!")
59
+
60
+ def __len__(self) -> int:
61
+ if self.positive_only:
62
+ return sum(self.tuples_per_scene)
63
+ return 2 * sum(self.tuples_per_scene)
64
+
65
+ def __getitem__(self, idx: int) -> Dict[str, Any]:
66
+ sample: Any = {}
67
+
68
+ positive_sample = idx % 2 == 0 or self.positive_only
69
+ if not self.positive_only:
70
+ idx = idx // 2
71
+
72
+ sample["label"] = positive_sample
73
+
74
+ i_scene, i_image = self._get_scene_and_image_id_from_idx(idx)
75
+
76
+ if positive_sample:
77
+ sample["src_path"], sample["trg_path"] = self.scenes[i_scene][i_image]
78
+
79
+ homography = torch.eye(3, dtype=torch.float32)
80
+ else:
81
+ sample["src_path"], _ = self.scenes[i_scene][i_image]
82
+
83
+ i_scene_other, i_image_other = self._get_other_random_scene_and_image_id(i_scene)
84
+
85
+ sample["trg_path"], _ = self.scenes[i_scene_other][i_image_other]
86
+
87
+ homography = torch.zeros((3, 3), dtype=torch.float32)
88
+
89
+ src_img = read_image(sample["src_path"]) / 255.0
90
+ trg_img = read_image(sample["trg_path"]) / 255.0
91
+
92
+ _, H_src, W_src = src_img.shape
93
+ _, H_trg, W_trg = trg_img.shape
94
+
95
+ src_mask = torch.ones((1, H_src, W_src), dtype=torch.uint8)
96
+ trg_mask = torch.ones((1, H_trg, W_trg), dtype=torch.uint8)
97
+
98
+ if self.transforms:
99
+ src_img, trg_img, src_mask, trg_mask, _ = self.transforms(src_img, trg_img, src_mask, trg_mask, homography)
100
+
101
+ sample["src_image"] = src_img
102
+ sample["trg_image"] = trg_img
103
+ sample["src_mask"] = src_mask.to(torch.bool)
104
+ sample["trg_mask"] = trg_mask.to(torch.bool)
105
+ sample["homography"] = homography
106
+
107
+ return sample
108
+
109
+ def _get_scene_and_image_id_from_idx(self, idx: int) -> Tuple[int, int]:
110
+ accumulated_tuples = accumulate(self.tuples_per_scene)
111
+
112
+ if idx >= sum(self.tuples_per_scene):
113
+ raise IndexError(f"Index {idx} out of bounds")
114
+
115
+ idx_scene = None
116
+ for i, accumulated_tuple in enumerate(accumulated_tuples):
117
+ idx_scene = i
118
+ if idx < accumulated_tuple:
119
+ break
120
+
121
+ idx_image = idx - sum(self.tuples_per_scene[:idx_scene])
122
+
123
+ return idx_scene, idx_image
124
+
125
+ def _get_other_random_scene_and_image_id(self, scene_id_to_exclude: int) -> Tuple[int, int]:
126
+ possible_scene_ids = list(range(len(self.scenes)))
127
+ possible_scene_ids.remove(scene_id_to_exclude)
128
+
129
+ idx_scene = random.choice(possible_scene_ids)
130
+ idx_image = random.randint(0, len(self.scenes[idx_scene]) - 1)
131
+
132
+ return idx_scene, idx_image
133
+
134
+
135
+ class Scene:
136
+ def __init__(self, root_path, scene_data: Dict[str, Any], max_size_scene) -> None:
137
+ self.root_path = root_path
138
+ self.image_path = Path(scene_data["image_path"])
139
+ self.image_names = scene_data["images"]
140
+
141
+ # randomly sample tuples
142
+ if max_size_scene > 0:
143
+ self.tuples = random.sample(scene_data["tuples"], min(max_size_scene, len(scene_data["tuples"])))
+ else: # no cap given: keep all tuples of the scene
+ self.tuples = scene_data["tuples"]
144
+
145
+ def __len__(self) -> int:
146
+ return len(self.tuples)
147
+
148
+ def __getitem__(self, idx: int) -> Tuple[str, str]:
149
+ idx_1, idx_2 = random.sample([0, 1, 2], 2)
150
+
151
+ idx_1 = self.tuples[idx][idx_1]
152
+ idx_2 = self.tuples[idx][idx_2]
153
+
154
+ path_image_1 = str(self.root_path / self.image_path / self.image_names[idx_1])
155
+ path_image_2 = str(self.root_path / self.image_path / self.image_names[idx_2])
156
+
157
+ return path_image_1, path_image_2
imcui/third_party/RIPE/ripe/data/datasets/tokyo247.py ADDED
@@ -0,0 +1,134 @@
1
+ import os
2
+ import random
3
+ from glob import glob
4
+ from typing import Any, Callable, Optional
5
+
6
+ import torch
7
+ from torch.utils.data import Dataset
8
+ from torchvision.io import read_image
9
+
10
+ from ripe import utils
11
+ from ripe.data.data_transforms import Compose
12
+
13
+ log = utils.get_pylogger(__name__)
14
+
15
+
16
+ class Tokyo247(Dataset):
17
+ def __init__(
18
+ self,
19
+ root: str,
20
+ stage: str = "train",
21
+ transforms: Optional[Callable] = None,
22
+ positive_only: bool = False,
23
+ ):
24
+ if stage != "train":
25
+ raise ValueError("Tokyo247Dataset only supports the 'train' stage.")
26
+
27
+ # check if the root directory exists
28
+ if not os.path.isdir(root):
29
+ raise FileNotFoundError(f"Directory {root} does not exist.")
30
+
31
+ self.root_dir = root
32
+ self.transforms = transforms if transforms is not None else Compose([])
33
+ self.positive_only = positive_only
34
+
35
+ self.image_paths = []
36
+ self.positive_pairs = []
37
+
38
+ # Collect images grouped by location folder
39
+ self.locations = {}
40
+ for location_rough in sorted(os.listdir(self.root_dir)):
41
+ location_rough_path = os.path.join(self.root_dir, location_rough)
42
+
43
+ # check if the location_rough_path is a directory
44
+ if not os.path.isdir(location_rough_path):
45
+ continue
46
+
47
+ for location_fine in sorted(os.listdir(location_rough_path)):
48
+ location_fine_path = os.path.join(self.root_dir, location_rough, location_fine)
49
+
50
+ if os.path.isdir(location_fine_path):
51
+ images = sorted(
52
+ glob(os.path.join(location_fine_path, "*.png")),
53
+ key=lambda i: int(i[-7:-4]),
54
+ )
55
+ if len(images) >= 12:
56
+ self.locations[location_fine] = images
57
+ self.image_paths.extend(images)
58
+
59
+ # Generate positive pairs
60
+ for _, images in self.locations.items():
61
+ for i in range(len(images) - 1):
62
+ self.positive_pairs.append((images[i], images[i + 1]))
63
+ self.positive_pairs.append((images[-1], images[0]))
64
+
65
+ if positive_only:
66
+ log.warning("Using only positive pairs!")
67
+
68
+ log.info(f"Found {len(self.positive_pairs)} image pairs.")
69
+
70
+ def __len__(self):
71
+ if self.positive_only:
72
+ return len(self.positive_pairs)
73
+ return 2 * len(self.positive_pairs)
74
+
75
+ def __getitem__(self, idx):
76
+ sample: Any = {}
77
+
78
+ positive_sample = (idx % 2 == 0) or (self.positive_only)
79
+ if not self.positive_only:
80
+ idx = idx // 2
81
+
82
+ sample["label"] = positive_sample
83
+
84
+ if positive_sample: # Positive pair
85
+ img1_path, img2_path = self.positive_pairs[idx]
86
+
87
+ assert os.path.dirname(img1_path) == os.path.dirname(img2_path), (
88
+ f"Source and target image mismatch: {img1_path} vs {img2_path}"
89
+ )
90
+
91
+ homography = torch.eye(3, dtype=torch.float32)
92
+ else: # Negative pair
93
+ img1_path = random.choice(self.image_paths)
94
+ img2_path = random.choice(self.image_paths)
95
+
96
+ # Ensure images are from different folders
97
+ esc = 0
98
+ while os.path.dirname(img1_path) == os.path.dirname(img2_path):
99
+ img2_path = random.choice(self.image_paths)
100
+
101
+ esc += 1
102
+ if esc > 100:
103
+ raise RuntimeError("Could not find a negative pair.")
104
+
105
+ assert os.path.dirname(img1_path) != os.path.dirname(img2_path), (
106
+ f"Source and target image match for negative pair: {img1_path} vs {img2_path}"
107
+ )
108
+
109
+ homography = torch.zeros((3, 3), dtype=torch.float32)
110
+
111
+ sample["src_path"] = img1_path
112
+ sample["trg_path"] = img2_path
113
+
114
+ # Load images
115
+ src_img = read_image(sample["src_path"]) / 255.0
116
+ trg_img = read_image(sample["trg_path"]) / 255.0
117
+
118
+ _, H_src, W_src = src_img.shape
119
+ _, H_trg, W_trg = trg_img.shape
120
+
121
+ src_mask = torch.ones((1, H_src, W_src), dtype=torch.uint8)
122
+ trg_mask = torch.ones((1, H_trg, W_trg), dtype=torch.uint8)
123
+
124
+ # Apply transformations
125
+ if self.transforms:
126
+ src_img, trg_img, src_mask, trg_mask, _ = self.transforms(src_img, trg_img, src_mask, trg_mask, homography)
127
+
128
+ sample["src_image"] = src_img
129
+ sample["trg_image"] = trg_img
130
+ sample["src_mask"] = src_mask.to(torch.bool)
131
+ sample["trg_mask"] = trg_mask.to(torch.bool)
132
+ sample["homography"] = homography
133
+
134
+ return sample
imcui/third_party/RIPE/ripe/losses/__init__.py ADDED
File without changes
imcui/third_party/RIPE/ripe/losses/contrastive_loss.py ADDED
@@ -0,0 +1,88 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ def second_nearest_neighbor(desc1, desc2):
7
+ if desc2.shape[0] < 2: # the SNN check needs at least 2 descriptors in desc2
8
+ raise ValueError("desc2 should have at least 2 descriptors")
9
+
10
+ dist = torch.cdist(desc1, desc2, p=2)
11
+
12
+ vals, idxs = torch.topk(dist, 2, dim=1, largest=False)
13
+ idxs_in_2 = idxs[:, 1]
14
+ idxs_in_1 = torch.arange(0, idxs_in_2.size(0), device=dist.device)
15
+
16
+ matches_idxs = torch.cat([idxs_in_1.view(-1, 1), idxs_in_2.view(-1, 1)], 1)
17
+
18
+ return vals[:, 1].view(-1, 1), matches_idxs
19
+
20
+
21
+ def contrastive_loss(
22
+ desc1,
23
+ desc2,
24
+ matches,
25
+ inliers,
26
+ label,
27
+ logits_1,
28
+ logits_2,
29
+ pos_margin=1.0,
30
+ neg_margin=1.0,
31
+ ):
32
+ if inliers.sum() < 8: # if there are too few inliers, calculate loss on all matches
33
+ inliers = torch.ones_like(inliers)
34
+
35
+ matched_inliers_descs1 = desc1[matches[:, 0][inliers]]
36
+ matched_inliers_descs2 = desc2[matches[:, 1][inliers]]
37
+
38
+ if logits_1 is not None and logits_2 is not None:
39
+ matched_inliers_logits1 = logits_1[matches[:, 0][inliers]]
40
+ matched_inliers_logits2 = logits_2[matches[:, 1][inliers]]
41
+ logits = torch.minimum(matched_inliers_logits1, matched_inliers_logits2)
42
+ else:
43
+ logits = torch.ones_like(matches[:, 0][inliers])
44
+
45
+ if label:
46
+ snn_match_dists_1, idx1 = second_nearest_neighbor(matched_inliers_descs1, desc2)
47
+ snn_match_dists_2, idx2 = second_nearest_neighbor(matched_inliers_descs2, desc1)
48
+
49
+ dists = torch.hstack((snn_match_dists_1, snn_match_dists_2))
50
+ min_dists_idx = torch.min(dists, dim=1).indices.unsqueeze(1)
51
+
52
+ dists_hard = torch.gather(dists, 1, min_dists_idx).squeeze(-1)
53
+ dists_pos = F.pairwise_distance(matched_inliers_descs1, matched_inliers_descs2)
54
+
55
+ contrastive_loss = torch.clamp(pos_margin + dists_pos - dists_hard, min=0.0)
56
+
57
+ contrastive_loss = contrastive_loss * logits
58
+
59
+ contrastive_loss = contrastive_loss.sum() / (logits.sum() + 1e-8) # small epsilon to avoid division by zero
60
+ else:
61
+ dists = F.pairwise_distance(matched_inliers_descs1, matched_inliers_descs2)
62
+ contrastive_loss = torch.clamp(neg_margin - dists, min=0.0)
63
+
64
+ contrastive_loss = contrastive_loss * logits
65
+
66
+ contrastive_loss = contrastive_loss.sum() / (logits.sum() + 1e-8) # small epsilon to avoid division by zero
67
+
68
+ return contrastive_loss
69
+
70
+
71
+ class ContrastiveLoss(nn.Module):
72
+ def __init__(self, pos_margin=1.0, neg_margin=1.0):
73
+ super().__init__()
74
+ self.pos_margin = pos_margin
75
+ self.neg_margin = neg_margin
76
+
77
+ def forward(self, desc1, desc2, matches, inliers, label, logits_1=None, logits_2=None):
78
+ return contrastive_loss(
79
+ desc1,
80
+ desc2,
81
+ matches,
82
+ inliers,
83
+ label,
84
+ logits_1,
85
+ logits_2,
86
+ self.pos_margin,
87
+ self.neg_margin,
88
+ )
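A synthetic sketch of calling this loss: the descriptors, match indices and inlier mask are random stand-ins for the matcher output, and the margins mirror `conf/descriptor_loss/contrastive_loss.yaml`.

```python
# Hedged sketch of ContrastiveLoss on synthetic data.
import torch

from ripe.losses.contrastive_loss import ContrastiveLoss

criterion = ContrastiveLoss(pos_margin=0.2, neg_margin=0.2)

desc1 = torch.nn.functional.normalize(torch.randn(100, 256), dim=1)
desc2 = torch.nn.functional.normalize(torch.randn(120, 256), dim=1)

matches = torch.stack(
    [torch.randint(0, 100, (32,)), torch.randint(0, 120, (32,))], dim=1
)  # (N, 2) absolute indices into desc1 / desc2
inliers = torch.rand(32) > 0.5  # boolean mask over the matches

loss_pos = criterion(desc1, desc2, matches, inliers, label=True)   # positive image pair
loss_neg = criterion(desc1, desc2, matches, inliers, label=False)  # negative image pair
print(loss_pos.item(), loss_neg.item())
```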
imcui/third_party/RIPE/ripe/matcher/__init__.py ADDED
File without changes
imcui/third_party/RIPE/ripe/matcher/concurrent_matcher.py ADDED
@@ -0,0 +1,97 @@
1
+ import concurrent.futures
2
+
3
+ import torch
4
+
5
+
6
+ class ConcurrentMatcher:
7
+ """A class that performs matching and geometric filtering in parallel using a thread pool executor.
8
+ It matches keypoints from two sets of descriptors and applies a robust estimator to filter the matches based on geometric constraints.
9
+
10
+ Args:
11
+ matcher (callable): A callable that takes two sets of descriptors and returns distances and indices of matches.
12
+ robust_estimator (callable): A callable that estimates a geometric transformation and returns inliers.
13
+ min_num_matches (int, optional): Minimum number of matches required to perform geometric filtering. Defaults to 8.
14
+ max_workers (int, optional): Maximum number of threads in the thread pool executor. Defaults to 12.
15
+ """
16
+
17
+ def __init__(self, matcher, robust_estimator, min_num_matches=8, max_workers=12):
18
+ self.matcher = matcher
19
+ self.robust_estimator = robust_estimator
20
+ self.min_num_matches = min_num_matches
21
+
22
+ self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
23
+
24
+ @torch.no_grad()
25
+ def __call__(
26
+ self,
27
+ kpts1,
28
+ kpts2,
29
+ pdesc1,
30
+ pdesc2,
31
+ selected_mask1,
32
+ selected_mask2,
33
+ inl_th,
34
+ label=None,
35
+ ):
36
+ dev = pdesc1.device
37
+ B = pdesc1.shape[0]
38
+
39
+ batch_rel_idx_matches = [None] * B
40
+ batch_idx_matches = [None] * B
41
+ future_results = [None] * B
42
+
43
+ for b in range(B):
44
+ if selected_mask1[b].sum() < 16 or selected_mask2[b].sum() < 16:
45
+ continue
46
+
47
+ dists, idx_matches = self.matcher(pdesc1[b][selected_mask1[b]], pdesc2[b][selected_mask2[b]])
48
+
49
+ batch_rel_idx_matches[b] = idx_matches.clone()
50
+
51
+ # calculate ABSOLUTE indexes
52
+ idx_matches[:, 0] = torch.nonzero(selected_mask1[b], as_tuple=False)[idx_matches[:, 0]].squeeze()
53
+ idx_matches[:, 1] = torch.nonzero(selected_mask2[b], as_tuple=False)[idx_matches[:, 1]].squeeze()
54
+
55
+ batch_idx_matches[b] = idx_matches
56
+
57
+ # if not enough matches
58
+ if idx_matches.shape[0] < self.min_num_matches:
59
+ ransac_inliers = torch.zeros((idx_matches.shape[0]), device=dev).bool()
60
+ future_results[b] = (None, ransac_inliers)
61
+ continue
62
+
63
+ # use label information to exclude negative pairs from geometric filtering process -> enforces more descriminative descriptors
64
+ if label is not None and label[b] == 0:
65
+ ransac_inliers = torch.ones((idx_matches.shape[0]), device=dev).bool()
66
+ future_results[b] = (None, ransac_inliers)
67
+ continue
68
+
69
+ mkpts1 = kpts1[b][idx_matches[:, 0]]
70
+ mkpts2 = kpts2[b][idx_matches[:, 1]]
71
+
72
+ future_results[b] = self.executor.submit(self.robust_estimator, mkpts1, mkpts2, inl_th)
73
+
74
+ batch_ransac_inliers = [None] * B
75
+ batch_Fm = [None] * B
76
+
77
+ for b in range(B):
78
+ future_result = future_results[b]
79
+ if future_result is None:
80
+ ransac_inliers = None
81
+ Fm = None
82
+ elif isinstance(future_result, tuple):
83
+ Fm, ransac_inliers = future_result
84
+ else:
85
+ Fm, ransac_inliers = future_result.result()
86
+
87
+ # if no inliers
88
+ if ransac_inliers.sum() == 0:
89
+ ransac_inliers = ransac_inliers.squeeze(
90
+ -1
91
+ ) # kornia.geometry.ransac.RANSAC returns (N, 1) tensor if no inliers and (N,) tensor if inliers
92
+ Fm = None
93
+
94
+ batch_ransac_inliers[b] = ransac_inliers
95
+ batch_Fm[b] = Fm
96
+
97
+ return batch_rel_idx_matches, batch_idx_matches, batch_ransac_inliers, batch_Fm
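The following hedged sketch wires this class up as `conf/matcher/concurrent_mnn_poselib.yaml` describes, but with random keypoints and descriptors instead of detector output; it assumes `kornia` and `poselib` are installed. With random inputs the robust estimator will usually find no consistent geometry, which also exercises the empty-inlier path above.

```python
# Sketch: ConcurrentMatcher with kornia MNN matching and the poselib estimator.
import kornia.feature as KF
import torch

from ripe.matcher.concurrent_matcher import ConcurrentMatcher
from ripe.matcher.pose_estimator_poselib import PoseLibRelativePoseEstimator

matcher = ConcurrentMatcher(
    matcher=KF.DescriptorMatcher(match_mode="mnn", th=0.8),  # th is unused in mnn mode
    robust_estimator=PoseLibRelativePoseEstimator(),
    min_num_matches=8,
)

B, N, D = 2, 512, 256
kpts1 = torch.rand(B, N, 2) * 640                   # random keypoint locations in pixels
kpts2 = torch.rand(B, N, 2) * 640
desc1 = torch.nn.functional.normalize(torch.randn(B, N, D), dim=-1)
desc2 = torch.nn.functional.normalize(torch.randn(B, N, D), dim=-1)
mask1 = torch.ones(B, N, dtype=torch.bool)          # all keypoints selected
mask2 = torch.ones(B, N, dtype=torch.bool)

rel_idx, abs_idx, inliers, Fms = matcher(kpts1, kpts2, desc1, desc2, mask1, mask2, inl_th=1.0)
print([0 if inl is None else int(inl.sum()) for inl in inliers])  # inlier counts per batch element
```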
imcui/third_party/RIPE/ripe/matcher/pose_estimator_poselib.py ADDED
@@ -0,0 +1,31 @@
1
+ import poselib
2
+ import torch
3
+
4
+
5
+ class PoseLibRelativePoseEstimator:
6
+ """PoseLibRelativePoseEstimator estimates the fundamental matrix using poselib library.
7
+ It uses the poselib's estimate_fundamental function to compute the fundamental matrix and inliers based on the provided points.
8
+ Args:
9
+ None
10
+ """
11
+
12
+ def __init__(self):
13
+ pass
14
+
15
+ def __call__(self, pts0, pts1, inl_th):
16
+ F, info = poselib.estimate_fundamental(
17
+ pts0.cpu().numpy(),
18
+ pts1.cpu().numpy(),
19
+ {
20
+ "max_epipolar_error": inl_th,
21
+ },
22
+ )
23
+
24
+ success = F is not None
25
+ if success:
26
+ inliers = info.pop("inliers")
27
+ inliers = torch.tensor(inliers, dtype=torch.bool, device=pts0.device)
28
+ else:
29
+ inliers = torch.zeros(pts0.shape[0], dtype=torch.bool, device=pts0.device)
30
+
31
+ return F, inliers
imcui/third_party/RIPE/ripe/model_zoo/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .vgg_hyper import vgg_hyper # noqa: F401
imcui/third_party/RIPE/ripe/model_zoo/vgg_hyper.py ADDED
@@ -0,0 +1,39 @@
1
+ from pathlib import Path
2
+
3
+ import torch
4
+
5
+ from ripe.models.backbones.vgg import VGG
6
+ from ripe.models.ripe import RIPE
7
+ from ripe.models.upsampler.hypercolumn_features import HyperColumnFeatures
8
+
9
+
10
+ def vgg_hyper(model_path: Path = None, desc_shares=None):
11
+ if model_path is None:
12
+ # check if the weights file exists in the current directory
13
+ model_path = Path("/tmp/ripe_weights.pth")
14
+
15
+ if model_path.exists():
16
+ print(f"Using existing weights from {model_path}")
17
+ else:
18
+ print("Weights file not found. Downloading ...")
19
+ torch.hub.download_url_to_file(
20
+ "https://cvg.hhi.fraunhofer.de/RIPE/ripe_weights.pth",
21
+ "/tmp/ripe_weights.pth",
22
+ )
23
+ else:
24
+ if not model_path.exists():
25
+ print(f"Error: {model_path} does not exist.")
26
+ raise FileNotFoundError(f"Error: {model_path} does not exist.")
27
+
28
+ backbone = VGG(pretrained=False)
29
+ upsampler = HyperColumnFeatures()
30
+
31
+ extractor = RIPE(
32
+ net=backbone,
33
+ upsampler=upsampler,
34
+ desc_shares=desc_shares,
35
+ )
36
+
37
+ extractor.load_state_dict(torch.load(model_path, map_location="cpu"))
38
+
39
+ return extractor
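Putting the factory together with the `detectAndCompute` method defined in `ripe/models/ripe.py` further down gives the usual extraction loop. A sketch, assuming the RIPE sources are importable; note that a purely random tensor may hit the "No keypoints detected" guard, so use a real image in practice.

```python
import torch
from ripe.model_zoo import vgg_hyper

model = vgg_hyper()              # downloads ripe_weights.pth to /tmp on first use
model.eval()

image = torch.rand(3, 480, 640)  # stand-in for a real RGB image in [0, 1]
kpts, descs, scores = model.detectAndCompute(image, threshold=0.5, top_k=2048)

print(kpts.shape, descs.shape, scores.shape)  # (N, 2), (N, 256), (N,)
```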
imcui/third_party/RIPE/ripe/models/__init__.py ADDED
File without changes
imcui/third_party/RIPE/ripe/models/backbones/__init__.py ADDED
File without changes
imcui/third_party/RIPE/ripe/models/backbones/backbone_base.py ADDED
@@ -0,0 +1,61 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
+ class BackboneBase(nn.Module):
6
+ """Base class for backbone networks. Provides a standard interface for preprocessing inputs and
7
+ defining encoder dimensions.
8
+
9
+ Args:
10
+ nchannels (int): Number of input channels.
11
+ use_instance_norm (bool): Whether to apply instance normalization.
12
+ """
13
+
14
+ def __init__(self, nchannels=3, use_instance_norm=False):
15
+ super().__init__()
16
+ assert nchannels > 0, "Number of channels must be positive."
17
+ self.nchannels = nchannels
18
+ self.use_instance_norm = use_instance_norm
19
+ self.norm = nn.InstanceNorm2d(nchannels) if use_instance_norm else None
20
+
21
+ def get_dim_layers_encoder(self):
22
+ """Get dimensions of encoder layers."""
23
+ raise NotImplementedError("Subclasses must implement this method.")
24
+
25
+ def _forward(self, x):
26
+ """Define the forward pass for the backbone."""
27
+ raise NotImplementedError("Subclasses must implement this method.")
28
+
29
+ def forward(self, x: torch.Tensor, preprocess=True):
30
+ """Forward pass with optional preprocessing.
31
+
32
+ Args:
33
+ x (Tensor): Input tensor.
34
+ preprocess (bool): Whether to apply channel reduction.
35
+ """
36
+ if preprocess:
37
+ if x.dim() != 4:
38
+ if x.dim() == 2 and x.shape[0] > 3 and x.shape[1] > 3:
39
+ x = x.unsqueeze(0).unsqueeze(0)
40
+ elif x.dim() == 3:
41
+ x = x.unsqueeze(0)
42
+ else:
43
+ raise ValueError(f"Unexpected input shape: {x.shape}")
44
+
45
+ if self.nchannels == 1 and x.shape[1] != 1:
46
+ if len(x.shape) == 4: # Assumes (batch, channel, height, width)
47
+ x = torch.mean(x, axis=1, keepdim=True)
48
+ else:
49
+ raise ValueError(f"Unexpected input shape: {x.shape}")
50
+
51
+ # replicate a single-channel input to three channels if the backbone expects RGB
52
+ if self.nchannels == 3 and x.shape[1] == 1:
53
+ if len(x.shape) == 4:
54
+ x = x.repeat(1, 3, 1, 1)
55
+ else:
56
+ raise ValueError(f"Unexpected input shape: {x.shape}")
57
+
58
+ if self.use_instance_norm:
59
+ x = self.norm(x)
60
+
61
+ return self._forward(x)
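The preprocessing in `forward` normalises arbitrary input shapes before dispatching to `_forward`. A minimal sketch with a trivial, purely illustrative subclass shows how a raw `(H, W)` grayscale tensor is promoted to a batched 3-channel input.

```python
import torch
from ripe.models.backbones.backbone_base import BackboneBase

class IdentityBackbone(BackboneBase):
    """Illustrative subclass: returns the preprocessed tensor unchanged."""

    def get_dim_layers_encoder(self):
        return [self.nchannels]

    def _forward(self, x):
        return x

net = IdentityBackbone(nchannels=3, use_instance_norm=False)

gray = torch.rand(480, 640)       # (H, W) grayscale
print(net(gray).shape)            # torch.Size([1, 3, 480, 640]): unsqueezed and channel-replicated

rgb = torch.rand(2, 3, 480, 640)  # already (B, 3, H, W): passed through untouched
print(net(rgb).shape)             # torch.Size([2, 3, 480, 640])
```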
imcui/third_party/RIPE/ripe/models/backbones/vgg.py ADDED
@@ -0,0 +1,99 @@
1
+ # adapted from: https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/encoder.py and https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/decoder.py
2
+
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .backbone_base import BackboneBase
7
+ from .vgg_utils import VGG19, ConvRefiner, Decoder
8
+
9
+
10
+ class VGG(BackboneBase):
11
+ def __init__(self, nchannels=3, pretrained=True, use_instance_norm=True, mode="dect"):
12
+ super().__init__(nchannels=nchannels, use_instance_norm=use_instance_norm)
13
+
14
+ self.nchannels = nchannels
15
+ self.mode = mode
16
+
17
+ if self.mode not in ["dect", "desc", "dect+desc"]:
18
+ raise ValueError("mode should be 'dect', 'desc' or 'dect+desc'")
19
+
20
+ NUM_OUTPUT_CHANNELS, hidden_blocks = self._get_mode_params(mode)
21
+ conv_refiner = self._create_conv_refiner(NUM_OUTPUT_CHANNELS, hidden_blocks)
22
+
23
+ self.encoder = VGG19(pretrained=pretrained, num_input_channels=nchannels)
24
+ self.decoder = Decoder(conv_refiner, num_prototypes=NUM_OUTPUT_CHANNELS)
25
+
26
+ def _get_mode_params(self, mode):
27
+ """Get the number of output channels and the number of hidden blocks for the ConvRefiner.
28
+
29
+ Depending on the mode, the ConvRefiner will have a different number of output channels.
30
+ """
31
+
32
+ if mode == "dect":
33
+ return 1, 8
34
+ elif mode == "desc":
35
+ return 256, 5
36
+ elif mode == "dect+desc":
37
+ return 256 + 1, 8
38
+
39
+ def _create_conv_refiner(self, num_output_channels, hidden_blocks):
40
+ return nn.ModuleDict(
41
+ {
42
+ "8": ConvRefiner(
43
+ 512,
44
+ 512,
45
+ 256 + num_output_channels,
46
+ hidden_blocks=hidden_blocks,
47
+ residual=True,
48
+ ),
49
+ "4": ConvRefiner(
50
+ 256 + 256,
51
+ 256,
52
+ 128 + num_output_channels,
53
+ hidden_blocks=hidden_blocks,
54
+ residual=True,
55
+ ),
56
+ "2": ConvRefiner(
57
+ 128 + 128,
58
+ 128,
59
+ 64 + num_output_channels,
60
+ hidden_blocks=hidden_blocks,
61
+ residual=True,
62
+ ),
63
+ "1": ConvRefiner(
64
+ 64 + 64,
65
+ 64,
66
+ 1 + num_output_channels,
67
+ hidden_blocks=hidden_blocks,
68
+ residual=True,
69
+ ),
70
+ }
71
+ )
72
+
73
+ def get_dim_layers_encoder(self):
74
+ return self.encoder.get_dim_layers()
75
+
76
+ def _forward(self, x):
77
+ features, sizes = self.encoder(x)
78
+ output = 0
79
+ context = None
80
+ scales = self.decoder.scales
81
+ for idx, (feature_map, scale) in enumerate(zip(reversed(features), scales)):
82
+ delta_descriptor, context = self.decoder(feature_map, scale=scale, context=context)
83
+ output = output + delta_descriptor
84
+ if idx < len(scales) - 1:
85
+ size = sizes[-(idx + 2)]
86
+ output = F.interpolate(output, size=size, mode="bilinear", align_corners=False)
87
+ context = F.interpolate(context, size=size, mode="bilinear", align_corners=False)
88
+
89
+ if self.mode == "dect":
90
+ return {"heatmap": output, "coarse_descs": features}
91
+ elif self.mode == "desc":
92
+ return {"fine_descs": output, "coarse_descs": features}
93
+ elif self.mode == "dect+desc":
94
+ logits = output[:, :1].contiguous()
95
+ descs = output[:, 1:].contiguous()
96
+
97
+ return {"heatmap": logits, "fine_descs": descs, "coarse_descs": features}
98
+ else:
99
+ raise ValueError("mode should be 'dect', 'desc' or 'dect+desc'")
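The `mode` argument only changes how many channels the decoder predicts and how its output is split. A sketch that prints the resulting dictionary for each mode, assuming the RIPE sources are importable; shapes are for a 64x64 input.

```python
import torch
from ripe.models.backbones.vgg import VGG

x = torch.rand(1, 3, 64, 64)

for mode in ["dect", "desc", "dect+desc"]:
    net = VGG(pretrained=False, mode=mode).eval()
    with torch.no_grad():
        out = net(x)
    print(mode, {k: (tuple(v.shape) if torch.is_tensor(v) else f"{len(v)} maps") for k, v in out.items()})

# dect       -> heatmap (1, 1, 64, 64) plus 4 coarse feature maps
# desc       -> fine_descs (1, 256, 64, 64) plus 4 coarse feature maps
# dect+desc  -> heatmap (1, 1, 64, 64) and fine_descs (1, 256, 64, 64) plus 4 coarse feature maps
```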
imcui/third_party/RIPE/ripe/models/backbones/vgg_utils.py ADDED
@@ -0,0 +1,143 @@
1
+ # adapted from: https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/encoder.py and https://github.com/Parskatt/DeDoDe/blob/main/DeDoDe/decoder.py
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.models as tvm
6
+
7
+ from ripe import utils
8
+
9
+ log = utils.get_pylogger(__name__)
10
+
11
+
12
+ class Decoder(nn.Module):
13
+ def __init__(self, layers, *args, super_resolution=False, num_prototypes=1, **kwargs) -> None:
14
+ super().__init__(*args, **kwargs)
15
+ self.layers = layers
16
+ self.scales = self.layers.keys()
17
+ self.super_resolution = super_resolution
18
+ self.num_prototypes = num_prototypes
19
+
20
+ def forward(self, features, context=None, scale=None):
21
+ if context is not None:
22
+ features = torch.cat((features, context), dim=1)
23
+ stuff = self.layers[scale](features)
24
+ logits, context = (
25
+ stuff[:, : self.num_prototypes],
26
+ stuff[:, self.num_prototypes :],
27
+ )
28
+ return logits, context
29
+
30
+
31
+ class ConvRefiner(nn.Module):
32
+ def __init__(
33
+ self,
34
+ in_dim=6,
35
+ hidden_dim=16,
36
+ out_dim=2,
37
+ dw=True,
38
+ kernel_size=5,
39
+ hidden_blocks=5,
40
+ residual=False,
41
+ ):
42
+ super().__init__()
43
+ self.block1 = self.create_block(
44
+ in_dim,
45
+ hidden_dim,
46
+ dw=False,
47
+ kernel_size=1,
48
+ )
49
+ self.hidden_blocks = nn.Sequential(
50
+ *[
51
+ self.create_block(
52
+ hidden_dim,
53
+ hidden_dim,
54
+ dw=dw,
55
+ kernel_size=kernel_size,
56
+ )
57
+ for hb in range(hidden_blocks)
58
+ ]
59
+ )
60
+ self.hidden_blocks = self.hidden_blocks
61
+ self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
62
+ self.residual = residual
63
+
64
+ def create_block(
65
+ self,
66
+ in_dim,
67
+ out_dim,
68
+ dw=True,
69
+ kernel_size=5,
70
+ bias=True,
71
+ norm_type=nn.BatchNorm2d,
72
+ ):
73
+ num_groups = 1 if not dw else in_dim
74
+ if dw:
75
+ assert out_dim % in_dim == 0, "outdim must be divisible by indim for depthwise"
76
+ conv1 = nn.Conv2d(
77
+ in_dim,
78
+ out_dim,
79
+ kernel_size=kernel_size,
80
+ stride=1,
81
+ padding=kernel_size // 2,
82
+ groups=num_groups,
83
+ bias=bias,
84
+ )
85
+ norm = norm_type(out_dim) if norm_type is nn.BatchNorm2d else norm_type(num_channels=out_dim)
86
+ relu = nn.ReLU(inplace=True)
87
+ conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
88
+ return nn.Sequential(conv1, norm, relu, conv2)
89
+
90
+ def forward(self, feats):
91
+ b, c, hs, ws = feats.shape
92
+ x0 = self.block1(feats)
93
+ x = self.hidden_blocks(x0)
94
+ if self.residual:
95
+ x = (x + x0) / 1.4
96
+ x = self.out_conv(x)
97
+ return x
98
+
99
+
100
+ class VGG19(nn.Module):
101
+ def __init__(self, pretrained=False, num_input_channels=3) -> None:
102
+ super().__init__()
103
+ self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
104
+ # Maxpool layers: 6, 13, 26, 39
105
+
106
+ if num_input_channels != 3:
107
+ log.info(f"Changing input channels from 3 to {num_input_channels}")
108
+ self.layers[0] = nn.Conv2d(num_input_channels, 64, 3, 1, 1)
109
+
110
+ def get_dim_layers(self):
111
+ return [64, 128, 256, 512]
112
+
113
+ def forward(self, x, **kwargs):
114
+ feats = []
115
+ sizes = []
116
+ for layer in self.layers:
117
+ if isinstance(layer, nn.MaxPool2d):
118
+ feats.append(x)
119
+ sizes.append(x.shape[-2:])
120
+ x = layer(x)
121
+ return feats, sizes
122
+
123
+
124
+ class VGG(nn.Module):
125
+ def __init__(self, size="19", pretrained=False) -> None:
126
+ super().__init__()
127
+ if size == "11":
128
+ self.layers = nn.ModuleList(tvm.vgg11_bn(pretrained=pretrained).features[:22])
129
+ elif size == "13":
130
+ self.layers = nn.ModuleList(tvm.vgg13_bn(pretrained=pretrained).features[:28])
131
+ elif size == "19":
132
+ self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
133
+ # Maxpool layers: 6, 13, 26, 39
134
+
135
+ def forward(self, x, **kwargs):
136
+ feats = []
137
+ sizes = []
138
+ for layer in self.layers:
139
+ if isinstance(layer, nn.MaxPool2d):
140
+ feats.append(x)
141
+ sizes.append(x.shape[-2:])
142
+ x = layer(x)
143
+ return feats, sizes
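`VGG19` taps the torchvision features right before each max-pool, so the encoder returns a four-level pyramid whose channel widths match `get_dim_layers()`. A small shape-check sketch, assuming the RIPE sources are importable:

```python
import torch
from ripe.models.backbones.vgg_utils import VGG19

enc = VGG19(pretrained=False)
feats, sizes = enc(torch.rand(1, 3, 64, 64))

for f, s in zip(feats, sizes):
    print(tuple(f.shape), tuple(s))
# (1, 64, 64, 64)   (64, 64)   stride 1
# (1, 128, 32, 32)  (32, 32)   stride 2
# (1, 256, 16, 16)  (16, 16)   stride 4
# (1, 512, 8, 8)    (8, 8)     stride 8
```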
imcui/third_party/RIPE/ripe/models/ripe.py ADDED
@@ -0,0 +1,303 @@
1
+ from typing import List, Optional
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from ripe import utils
9
+ from ripe.utils.utils import gridify
10
+
11
+ log = utils.get_pylogger(__name__)
12
+
13
+
14
+ class KeypointSampler(nn.Module):
15
+ """
16
+ Sample keypoints according to a Heatmap
17
+ Adapted from: https://github.com/verlab/DALF_CVPR_2023/blob/main/modules/models/DALF.py
18
+ """
19
+
20
+ def __init__(self, window_size=8):
21
+ super().__init__()
22
+ self.window_size = window_size
23
+ self.idx_cells = None # Cache for meshgrid indices
24
+
25
+ def sample(self, grid):
26
+ """
27
+ Sample keypoints given a grid where each cell has logits stacked in last dimension
28
+ Input
29
+ grid: [B, C, H//w, W//w, w*w]
30
+
31
+ Returns
32
+ log_probs: [B, C, H//w, W//w ] - logprobs of selected samples
33
+ choices: [B, C, H//w, W//w] indices of choices
34
+ accept_mask: [B, C, H//w, W//w] mask of accepted keypoints
35
+
36
+ """
37
+ chooser = torch.distributions.Categorical(logits=grid)
38
+ choices = chooser.sample()
39
+ logits_selected = torch.gather(grid, -1, choices.unsqueeze(-1)).squeeze(-1)
40
+
41
+ flipper = torch.distributions.Bernoulli(logits=logits_selected)
42
+ accepted_choices = flipper.sample()
43
+
44
+ # Summing log-probabilities is equivalent to multiplying the probabilities
45
+ log_probs = chooser.log_prob(choices) + flipper.log_prob(accepted_choices)
46
+
47
+ accept_mask = accepted_choices.gt(0)
48
+
49
+ return (
50
+ log_probs.squeeze(1),
51
+ choices,
52
+ accept_mask.squeeze(1),
53
+ logits_selected.squeeze(1),
54
+ )
55
+
56
+ def precompute_idx_cells(self, H, W, device):
57
+ idx_cells = gridify(
58
+ torch.dstack(
59
+ torch.meshgrid(
60
+ torch.arange(H, dtype=torch.float32, device=device),
61
+ torch.arange(W, dtype=torch.float32, device=device),
62
+ )
63
+ )
64
+ .permute(2, 0, 1)
65
+ .unsqueeze(0)
66
+ .expand(1, -1, -1, -1),
67
+ window_size=self.window_size,
68
+ )
69
+
70
+ return idx_cells
71
+
72
+ def forward(self, x, mask_padding=None):
73
+ """
74
+ Sample keypoints from a heatmap
75
+ Input
76
+ x: [B, C, H, W] Heatmap
77
+ mask_padding: [B, 1, H, W] Mask for padding (optional)
78
+ Returns
79
+ keypoints: [B, H//w, W//w, 2] Keypoints in (x, y) format
80
+ log_probs: [B, H//w, W//w] Log probabilities of selected keypoints
81
+ mask: [B, H//w, W//w] Mask of accepted keypoints
82
+ mask_padding: [B, 1, H//w, W//w] Mask of padding (optional)
83
+ logits_selected: [B, H//w, W//w] Logits of selected keypoints
84
+ """
85
+
86
+ B, C, H, W = x.shape
87
+
88
+ keypoint_cells = gridify(x, self.window_size)
89
+
90
+ mask_padding = (
91
+ (torch.min(gridify(mask_padding, self.window_size), dim=4).values) if mask_padding is not None else None
92
+ )
93
+
94
+ if self.idx_cells is None or self.idx_cells.shape[2:4] != (
95
+ H // self.window_size,
96
+ W // self.window_size,
97
+ ):
98
+ self.idx_cells = self.precompute_idx_cells(H, W, x.device)
99
+
100
+ log_probs, idx, mask, logits_selected = self.sample(keypoint_cells)
101
+
102
+ keypoints = (
103
+ torch.gather(
104
+ self.idx_cells.expand(B, -1, -1, -1, -1),
105
+ -1,
106
+ idx.repeat(1, 2, 1, 1).unsqueeze(-1),
107
+ )
108
+ .squeeze(-1)
109
+ .permute(0, 2, 3, 1)
110
+ )
111
+
112
+ # flip keypoints to (x, y) format
113
+ return keypoints.flip(-1), log_probs, mask, mask_padding, logits_selected
114
+
115
+
116
+ class RIPE(nn.Module):
117
+ """
118
+ Base class for extracting keypoints and descriptors
119
+ Input
120
+ x: [B, C, H, W] Images
121
+
122
+ Returns
123
+ kpts:
124
+ list of size [B] with detected keypoints
125
+ descs:
126
+ list of size [B] with descriptors
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ net,
132
+ upsampler,
133
+ window_size: int = 8,
134
+ non_linearity_dect=None,
135
+ desc_shares: Optional[List[int]] = None,
136
+ descriptor_dim: int = 256,
137
+ device=None,
138
+ ):
139
+ super().__init__()
140
+ self.net = net
141
+
142
+ self.detector = KeypointSampler(window_size)
143
+ self.upsampler = upsampler
144
+ self.sampler = None
145
+ self.window_size = window_size
146
+ self.non_linearity_dect = non_linearity_dect if non_linearity_dect is not None else nn.Identity()
147
+
148
+ log.info(f"Training with window size {window_size}.")
149
+ log.info(f"Use {non_linearity_dect} as final non-linearity before the detection heatmap.")
150
+
151
+ dim_coarse_desc = self.get_dim_raw_desc()
152
+
153
+ if desc_shares is not None:
154
+ assert upsampler.name == "HyperColumnFeatures", (
155
+ "Individual descriptor convolutions are only supported with HyperColumnFeatures"
156
+ )
157
+ assert len(desc_shares) == 4, "desc_shares should have 4 elements"
158
+ assert sum(desc_shares) == descriptor_dim, f"sum of desc_shares should be {descriptor_dim}"
159
+
160
+ self.conv_dim_reduction_coarse_desc = nn.ModuleList()
161
+
162
+ for dim_in, dim_out in zip(dim_coarse_desc, desc_shares):
163
+ log.info(f"Training dim reduction descriptor with {dim_in} -> {dim_out} 1x1 conv")
164
+ self.conv_dim_reduction_coarse_desc.append(
165
+ nn.Conv1d(dim_in, dim_out, kernel_size=1, stride=1, padding=0)
166
+ )
167
+ else:
168
+ if descriptor_dim is not None:
169
+ log.info(f"Training dim reduction descriptor with {sum(dim_coarse_desc)} -> {descriptor_dim} 1x1 conv")
170
+ self.conv_dim_reduction_coarse_desc = nn.Conv1d(
171
+ sum(dim_coarse_desc),
172
+ descriptor_dim,
173
+ kernel_size=1,
174
+ stride=1,
175
+ padding=0,
176
+ )
177
+ else:
178
+ log.warning(
179
+ f"No descriptor dimension specified, no 1x1 conv will be applied! Direct usage of {sum(dim_coarse_desc)}-dimensional raw descriptor"
180
+ )
181
+ self.conv_dim_reduction_coarse_desc = nn.Identity()
182
+
183
+ def get_dim_raw_desc(self):
184
+ layers_dims_encoder = self.net.get_dim_layers_encoder()
185
+
186
+ if self.upsampler.name == "InterpolateSparse2d":
187
+ return [layers_dims_encoder[-1]]
188
+ elif self.upsampler.name == "HyperColumnFeatures":
189
+ return layers_dims_encoder
190
+ else:
191
+ raise ValueError(f"Unknown interpolator {self.upsampler.name}")
192
+
193
+ @torch.inference_mode()
194
+ def detectAndCompute(self, img, threshold=0.5, top_k=2048, output_aux=False):
195
+ self.train(False)
196
+
197
+ if img.dim() == 3:
198
+ img = img.unsqueeze(0)
199
+
200
+ out = self(img, training=False)
201
+ B, K, H, W = out["heatmap"].shape
202
+
203
+ assert B == 1, "Batch size should be 1"
204
+
205
+ kpts = [{"xy": self.NMS(out["heatmap"][b], threshold)} for b in range(B)]
206
+
207
+ if top_k is not None:
208
+ for b in range(B):
209
+ scores = out["heatmap"][b].squeeze(0)[kpts[b]["xy"][:, 1].long(), kpts[b]["xy"][:, 0].long()]
210
+ sorted_idx = torch.argsort(-scores)
211
+ kpts[b]["xy"] = kpts[b]["xy"][sorted_idx[:top_k]]
212
+ if "logprobs" in kpts[b]:
213
+ kpts[b]["logprobs"] = kpts[b]["logprobs"][sorted_idx[:top_k]]
214
+
215
+ if kpts[0]["xy"].shape[0] == 0:
216
+ raise RuntimeError("No keypoints detected")
217
+
218
+ # the following works for batch size 1 only
219
+
220
+ descs = self.get_descs(out["coarse_descs"], img, kpts[0]["xy"].unsqueeze(0), H, W)
221
+ descs = descs.squeeze(0)
222
+
223
+ score_map = out["heatmap"][0].squeeze(0)
224
+
225
+ kpts = kpts[0]["xy"]
226
+
227
+ scores = score_map[kpts[:, 1], kpts[:, 0]]
228
+ scores /= score_map.max()
229
+
230
+ sort_idx = torch.argsort(-scores)
231
+ kpts, descs, scores = kpts[sort_idx], descs[sort_idx], scores[sort_idx]
232
+
233
+ if output_aux:
234
+ return (
235
+ kpts.float(),
236
+ descs,
237
+ scores,
238
+ {
239
+ "heatmap": out["heatmap"],
240
+ "descs": out["coarse_descs"],
241
+ "conv": self.conv_dim_reduction_coarse_desc,
242
+ },
243
+ )
244
+
245
+ return kpts.float(), descs, scores
246
+
247
+ def NMS(self, x, threshold=3.0, kernel_size=3):
248
+ pad = kernel_size // 2
249
+ local_max = nn.MaxPool2d(kernel_size=kernel_size, stride=1, padding=pad)(x)
250
+
251
+ pos = (x == local_max) & (x > threshold)
252
+ return pos.nonzero()[..., 1:].flip(-1)
253
+
254
+ def get_descs(self, feature_map, guidance, kpts, H, W):
255
+ descs = self.upsampler(feature_map, kpts, H, W)
256
+
257
+ if isinstance(self.conv_dim_reduction_coarse_desc, nn.ModuleList):
258
+ # individual descriptor convolutions for each layer
259
+ desc_conv = []
260
+ for desc, conv in zip(descs, self.conv_dim_reduction_coarse_desc):
261
+ desc_conv.append(conv(desc.permute(0, 2, 1)).permute(0, 2, 1))
262
+ desc = torch.cat(desc_conv, dim=-1)
263
+ else:
264
+ desc = torch.cat(descs, dim=-1)
265
+ desc = self.conv_dim_reduction_coarse_desc(desc.permute(0, 2, 1)).permute(0, 2, 1)
266
+
267
+ desc = F.normalize(desc, dim=2)
268
+
269
+ return desc
270
+
271
+ def forward(self, x, mask_padding=None, training=False):
272
+ B, C, H, W = x.shape
273
+ out = self.net(x)
274
+ out["heatmap"] = self.non_linearity_dect(out["heatmap"])
275
+ # print(out['map'].shape, out['descr'].shape)
276
+ if training:
277
+ kpts, log_probs, mask, mask_padding, logits_selected = self.detector(out["heatmap"], mask_padding)
278
+
279
+ filter_A = kpts[:, :, :, 0] >= 16
280
+ filter_B = kpts[:, :, :, 1] >= 16
281
+ filter_C = kpts[:, :, :, 0] < W - 16
282
+ filter_D = kpts[:, :, :, 1] < H - 16
283
+ filter_all = filter_A * filter_B * filter_C * filter_D
284
+
285
+ mask = mask * filter_all
286
+
287
+ return (
288
+ kpts.view(B, -1, 2),
289
+ log_probs.view(B, -1),
290
+ mask.view(B, -1),
291
+ mask_padding.view(B, -1),
292
+ logits_selected.view(B, -1),
293
+ out,
294
+ )
295
+ else:
296
+ return out
297
+
298
+
299
+ def output_number_trainable_params(model):
300
+ model_parameters = filter(lambda p: p.requires_grad, model.parameters())
301
+ nb_params = sum([np.prod(p.size()) for p in model_parameters])
302
+
303
+ print(f"Number of trainable parameters: {nb_params:d}")
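For reference, the two-stage draw inside `KeypointSampler.sample` (a `Categorical` picks one position per window, a `Bernoulli` on the selected logit accepts or rejects it, and the two log-probabilities are summed) can be reproduced standalone on dummy logits; the names and shapes below are illustrative only.

```python
import torch

torch.manual_seed(0)
w = 8
grid = torch.randn(1, 1, 4, 4, w * w)  # [B, C, H//w, W//w, w*w] logits per cell

chooser = torch.distributions.Categorical(logits=grid)
choices = chooser.sample()              # index of the chosen pixel within each cell
logits_selected = torch.gather(grid, -1, choices.unsqueeze(-1)).squeeze(-1)

flipper = torch.distributions.Bernoulli(logits=logits_selected)
accepted = flipper.sample()             # 1 = keep the cell's keypoint, 0 = reject it

# summing the log-probabilities of both decisions = multiplying the probabilities
log_probs = chooser.log_prob(choices) + flipper.log_prob(accepted)
print(choices.shape, int(accepted.sum()), log_probs.shape)
```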