diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..41d770a277b3c57c15453c4d17d11c015a5de85a 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/fisheye-skyline.jpg filter=lfs diff=lfs merge=lfs -text
+assets/teaser.gif filter=lfs diff=lfs merge=lfs -text
+demo.ipynb filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..7a404431fe7e7e9e10bb8fb485f39331c86553aa
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,138 @@
+# folders
+data/**
+outputs/**
+weights/**
+**.DS_Store
+.vscode/**
+wandb/**
+third_party/**
+
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d369a9455e7874673bb022ccd0023c3739dc9e5b
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,35 @@
+default_stages: [commit]
+default_language_version:
+ python: python3.10
+repos:
+ - repo: https://github.com/psf/black
+ rev: 23.9.1
+ hooks:
+ - id: black
+ args: [--line-length=100]
+ exclude: ^(venv/|docs/)
+ types: [python]
+ - repo: https://github.com/PyCQA/flake8
+ rev: 6.1.0
+ hooks:
+ - id: flake8
+ additional_dependencies: [flake8-docstrings]
+ args:
+ [
+ --max-line-length=100,
+ --docstring-convention=google,
+ --ignore=E203 W503 E402 E731,
+ ]
+ exclude: ^(venv/|docs/|.*__init__.py)
+ types: [python]
+
+ - repo: https://github.com/pycqa/isort
+ rev: 5.12.0
+ hooks:
+ - id: isort
+ args: [--line-length=100, --profile=black, --atomic]
+
+ - repo: https://github.com/pre-commit/mirrors-mypy
+ rev: v1.1.1
+ hooks:
+ - id: mypy
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..797795d81ba8a5d06fefa772bc5b4d0b4bb94dc4
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,190 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ Copyright 2024 ETH Zurich
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index 631be781f03494b75818fa6ea30ebc99ba6e3cca..b68d36ad523c16e239d8ca94735a8821f8f721e2 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,587 @@
---
title: GeoCalib
-emoji: 👀
-colorFrom: pink
-colorTo: yellow
+app_file: gradio_app.py
sdk: gradio
-sdk_version: 4.43.0
-app_file: app.py
-pinned: false
+sdk_version: 4.38.1
---
+
+# GeoCalib 📸
+
+### Single-image Calibration with Geometric Optimization
+
+Alexander Veicht · Paul-Edouard Sarlin · Philipp Lindenberger · Marc Pollefeys
+
+**ECCV 2024**
+
+Paper | Colab | Demo 🤗
+
+GeoCalib accurately estimates the camera intrinsics and gravity direction from a single image
+by combining geometric optimization with deep learning.
+
+##
+
+GeoCalib is an algorithm for single-image calibration: it estimates the camera intrinsics and gravity direction from a single image. By combining geometric optimization with deep learning, GeoCalib provides a more flexible and accurate calibration compared to previous approaches. This repository hosts the [inference](#setup-and-demo), [evaluation](#evaluation), and [training](#training) code for GeoCalib and instructions to download our training set [OpenPano](#openpano-dataset).
+
+
+## Setup and demo
+
+[](https://colab.research.google.com/drive/1oMzgPGppAPAIQxe-s7SRd_q8r7dVfnqo#scrollTo=etdzQZQzoo-K)
+[](https://huggingface.co/spaces/veichta/GeoCalib)
+
+We provide a small inference package [`geocalib`](geocalib) that requires only minimal dependencies and Python >= 3.9. First clone the repository and install the dependencies:
+
+```bash
+git clone https://github.com/cvg/GeoCalib.git && cd GeoCalib
+python -m pip install -e .
+# OR
+python -m pip install -e "git+https://github.com/cvg/GeoCalib#egg=geocalib"
+```
+
+Here is a minimal usage example:
+
+```python
+import torch
+
+from geocalib import GeoCalib
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = GeoCalib().to(device)
+
+# load image as tensor in range [0, 1] with shape [C, H, W]
+img = model.load_image("path/to/image.jpg").to(device)
+result = model.calibrate(img)
+
+print("camera:", result["camera"])
+print("gravity:", result["gravity"])
+```
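+
+The entries of `result` expose the estimated parameters. As a minimal sketch (assuming `result["camera"]` is an instance of the camera classes defined in [`geocalib/camera.py`](geocalib/camera.py)), the intrinsics could be read out like this:
+
+```python
+import torch
+
+# assumption: the result dict holds a Camera object from geocalib.camera
+camera = result["camera"]
+
+print("focal lengths (fx, fy) in pixels:", camera.f)
+print("intrinsic matrix K:", camera.K)
+print("vertical FoV in degrees:", torch.rad2deg(camera.vfov))
+```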
+
+When either the intrinsics or the gravity are already known, they can be provided:
+
+```python
+# known intrinsics:
+result = model.calibrate(img, priors={"focal": focal_length_tensor})
+
+# known gravity:
+result = model.calibrate(img, priors={"gravity": gravity_direction_tensor})
+```
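+
+For example, a focal-length prior can be derived from a known vertical field of view. This is only a sketch: `fov2focal` is the helper from `geocalib.utils` used internally by the camera models, while the 65° FoV, the 1080 px image height, and passing the focal length in pixels are assumptions for illustration:
+
+```python
+import torch
+
+from geocalib.utils import fov2focal
+
+# hypothetical camera: 65 degree vertical FoV, image height of 1080 pixels
+vfov = torch.deg2rad(torch.tensor(65.0))
+focal_length_tensor = fov2focal(vfov, 1080)  # assumption: prior is the focal length in pixels
+
+result = model.calibrate(img, priors={"focal": focal_length_tensor})
+```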
+
+The default model is optimized for pinhole images. To handle lens distortion, use the following:
+
+```python
+model = GeoCalib(weights="distorted") # default is "pinhole"
+result = model.calibrate(img, camera_model="simple_radial") # or pinhole, simple_divisional
+```
+
+Check out our [demo notebook](demo.ipynb) for a full working example.
+
+
+[Interactive demo for your webcam - click to expand]
+Run the following command:
+
+```bash
+python -m geocalib.interactive_demo --camera_id 0
+```
+
+The demo will open a window showing the camera feed and the calibration results. If `--camera_id` is not provided, the demo will ask for the IP address of a [droidcam](https://droidcam.app) camera.
+
+Controls:
+
+>Toggle the different features using the following keys:
+>
+>- ```h```: Show the estimated horizon line
+>- ```u```: Show the estimated up-vectors
+>- ```l```: Show the estimated latitude heatmap
+>- ```c```: Show the confidence heatmap for the up-vectors and latitudes
+>- ```d```: Show undistorted image, will overwrite the other features
+>- ```g```: Shows a virtual grid of points
+>- ```b```: Shows a virtual box object
+>
+>Change the camera model using the following keys:
+>
+>- ```1```: Pinhole -> Simple and fast
+>- ```2```: Simple Radial -> For small distortions
+>- ```3```: Simple Divisional -> For large distortions
+>
+>Press ```q``` to quit the demo.
+
+
+
+
+
+[Load GeoCalib with torch hub - click to expand]
+
+```python
+model = torch.hub.load("cvg/GeoCalib", "GeoCalib", trust_repo=True)
+```
+
+
+
+## Evaluation
+
+The full evaluation and training code is provided in the single-image calibration library [`siclib`](siclib), which can be installed as:
+```bash
+python -m pip install -e siclib
+```
+
+Running the evaluation commands will write the results to `outputs/results/`.
+
+### LaMAR
+
+Running the evaluation commands will download the dataset to ```data/lamar2k``` which will take around 400 MB of disk space.
+
+
+[Evaluate GeoCalib]
+
+To evaluate GeoCalib trained on the OpenPano dataset, run:
+
+```bash
+python -m siclib.eval.lamar2k --conf geocalib-pinhole --tag geocalib --overwrite
+```
+
+
+
+
+[Evaluate DeepCalib]
+
+To evaluate DeepCalib trained on the OpenPano dataset, run:
+
+```bash
+python -m siclib.eval.lamar2k --conf deepcalib --tag deepcalib --overwrite
+```
+
+
+
+
+[Evaluate Perspective Fields]
+
+Coming soon!
+
+
+
+
+[Evaluate UVP]
+
+To evaluate UVP, install the [VP-Estimation-with-Prior-Gravity](https://github.com/cvg/VP-Estimation-with-Prior-Gravity) under ```third_party/VP-Estimation-with-Prior-Gravity```. Then run:
+
+```bash
+python -m siclib.eval.lamar2k --conf uvp --tag uvp --overwrite data.preprocessing.edge_divisible_by=null
+```
+
+
+
+
+[Evaluate your own model]
+
+If you have trained your own model, you can evaluate it by running:
+
+```bash
+python -m siclib.eval.lamar2k --checkpoint <experiment name> --tag <your tag> --overwrite
+```
+
+
+
+
+
+[Results]
+
+Here are the results for the Area Under the Curve (AUC) for the roll, pitch and field of view (FoV) errors at 1/5/10 degrees for the different methods:
+
+| Approach | Roll | Pitch | FoV |
+| --------- | ------------------ | ------------------ | ------------------ |
+| DeepCalib | 44.1 / 73.9 / 84.8 | 10.8 / 28.3 / 49.8 | 0.7 / 13.0 / 24.0 |
+| ParamNet  | 51.7 / 77.0 / 86.0 | 27.0 / 52.7 / 70.2 | 2.8 / 6.8 / 14.3   |
+| UVP | 72.7 / 81.8 / 85.7 | 42.3 / 59.9 / 69.4 | 15.6 / 30.6 / 43.5 |
+| GeoCalib | 86.4 / 92.5 / 95.0 | 55.0 / 76.9 / 86.2 | 19.1 / 41.5 / 60.0 |
+
+
+### MegaDepth
+
+Running the evaluation commands will download the dataset to ```data/megadepth2k``` or ```data/megadepth2k-radial``` which will take around 2.1 GB and 1.47 GB of disk space respectively.
+
+
+[Evaluate GeoCalib]
+
+To evaluate GeoCalib trained on the OpenPano dataset, run:
+
+```bash
+python -m siclib.eval.megadepth2k --conf geocalib-pinhole --tag geocalib --overwrite
+```
+
+To run the eval on the radial distorted images, run:
+
+```bash
+python -m siclib.eval.megadepth2k_radial --conf geocalib-pinhole --tag geocalib --overwrite model.camera_model=simple_radial
+```
+
+
+
+
+[Evaluate DeepCalib]
+
+To evaluate DeepCalib trained on the OpenPano dataset, run:
+
+```bash
+python -m siclib.eval.megadepth2k --conf deepcalib --tag deepcalib --overwrite
+```
+
+
+
+
+[Evaluate Perspective Fields]
+
+Coming soon!
+
+
+
+
+[Evaluate UVP]
+
+To evaluate UVP, install the [VP-Estimation-with-Prior-Gravity](https://github.com/cvg/VP-Estimation-with-Prior-Gravity) under ```third_party/VP-Estimation-with-Prior-Gravity```. Then run:
+
+```bash
+python -m siclib.eval.megadepth2k --conf uvp --tag uvp --overwrite data.preprocessing.edge_divisible_by=null
+```
+
+
+
+
+[Evaluate your own model]
+
+If you have trained your own model, you can evaluate it by running:
+
+```bash
+python -m siclib.eval.megadepth2k --checkpoint <experiment name> --tag <your tag> --overwrite
+```
+
+
+
+
+[Results]
+
+Here are the results for the Area Under the Curve (AUC) for the roll, pitch and field of view (FoV) errors at 1/5/10 degrees for the different methods:
+
+| Approach | Roll | Pitch | FoV |
+| --------- | ------------------ | ------------------ | ------------------ |
+| DeepCalib | 34.6 / 65.4 / 79.4 | 11.9 / 27.8 / 44.8 | 5.6 / 12.1 / 22.9 |
+| ParamNet | 43.4 / 70.7 / 82.2 | 15.4 / 34.5 / 53.3 | 3.2 / 10.1 / 21.3 |
+| UVP | 69.2 / 81.6 / 86.9 | 21.6 / 36.2 / 47.4 | 8.2 / 18.7 / 29.8 |
+| GeoCalib | 82.6 / 90.6 / 94.0 | 32.4 / 53.3 / 67.5 | 13.6 / 31.7 / 48.2 |
+
+
+### TartanAir
+
+Running the evaluation commands will download the dataset to ```data/tartanair``` which will take around 1.85 GB of disk space.
+
+
+[Evaluate GeoCalib]
+
+To evaluate GeoCalib trained on the OpenPano dataset, run:
+
+```bash
+python -m siclib.eval.tartanair --conf geocalib-pinhole --tag geocalib --overwrite
+```
+
+
+
+
+[Evaluate DeepCalib]
+
+To evaluate DeepCalib trained on the OpenPano dataset, run:
+
+```bash
+python -m siclib.eval.tartanair --conf deepcalib --tag deepcalib --overwrite
+```
+
+
+
+
+[Evaluate Perspective Fields]
+
+Coming soon!
+
+
+
+
+[Evaluate UVP]
+
+To evaluate UVP, install the [VP-Estimation-with-Prior-Gravity](https://github.com/cvg/VP-Estimation-with-Prior-Gravity) under ```third_party/VP-Estimation-with-Prior-Gravity```. Then run:
+
+```bash
+python -m siclib.eval.tartanair --conf uvp --tag uvp --overwrite data.preprocessing.edge_divisible_by=null
+```
+
+
+
+
+[Evaluate your own model]
+
+If you have trained your own model, you can evaluate it by running:
+
+```bash
+python -m siclib.eval.tartanair --checkpoint <experiment name> --tag <your tag> --overwrite
+```
+
+
+
+
+[Results]
+
+Here are the results for the Area Under the Curve (AUC) for the roll, pitch and field of view (FoV) errors at 1/5/10 degrees for the different methods:
+
+| Approach | Roll | Pitch | FoV |
+| --------- | ------------------ | ------------------ | ------------------ |
+| DeepCalib | 24.7 / 55.4 / 71.5 | 16.3 / 38.8 / 58.5 | 1.5 / 8.8 / 27.2 |
+| ParamNet | 34.5 / 59.2 / 73.9 | 19.4 / 42.0 / 60.3 | 6.0 / 16.8 / 31.6 |
+| UVP | 52.1 / 64.8 / 71.9 | 36.2 / 48.8 / 58.6 | 15.8 / 25.8 / 35.7 |
+| GeoCalib | 71.3 / 83.8 / 89.8 | 38.2 / 62.9 / 76.6 | 14.1 / 30.4 / 47.6 |
+
+
+### Stanford2D3D
+
+Before downloading and running the evaluation, you will need to agree to the [terms of use](https://docs.google.com/forms/d/e/1FAIpQLScFR0U8WEUtb7tgjOhhnl31OrkEs73-Y8bQwPeXgebqVKNMpQ/viewform?c=0&w=1) for the Stanford2D3D dataset.
+Running the evaluation commands will download the dataset to ```data/stanford2d3d``` which will take around 885 MB of disk space.
+
+
+[Evaluate GeoCalib]
+
+To evaluate GeoCalib trained on the OpenPano dataset, run:
+
+```bash
+python -m siclib.eval.stanford2d3d --conf geocalib-pinhole --tag geocalib --overwrite
+```
+
+
+
+
+[Evaluate DeepCalib]
+
+To evaluate DeepCalib trained on the OpenPano dataset, run:
+
+```bash
+python -m siclib.eval.stanford2d3d --conf deepcalib --tag deepcalib --overwrite
+```
+
+
+
+
+[Evaluate Perspective Fields]
+
+Coming soon!
+
+
+
+
+[Evaluate UVP]
+
+To evaluate UVP, install the [VP-Estimation-with-Prior-Gravity](https://github.com/cvg/VP-Estimation-with-Prior-Gravity) under ```third_party/VP-Estimation-with-Prior-Gravity```. Then run:
+
+```bash
+python -m siclib.eval.stanford2d3d --conf uvp --tag uvp --overwrite data.preprocessing.edge_divisible_by=null
+```
+
+
+
+
+[Evaluate your own model]
+
+If you have trained your own model, you can evaluate it by running:
+
+```bash
+python -m siclib.eval.stanford2d3d --checkpoint <experiment name> --tag <your tag> --overwrite
+```
+
+
+
+
+[Results]
+
+Here are the results for the Area Under the Curve (AUC) for the roll, pitch and field of view (FoV) errors at 1/5/10 degrees for the different methods:
+
+| Approach | Roll | Pitch | FoV |
+| --------- | ------------------ | ------------------ | ------------------ |
+| DeepCalib | 33.8 / 63.9 / 79.2 | 21.6 / 46.9 / 65.7 | 8.1 / 20.6 / 37.6 |
+| ParamNet | 44.6 / 73.9 / 84.8 | 29.2 / 56.7 / 73.1 | 5.8 / 14.3 / 27.8 |
+| UVP | 65.3 / 74.6 / 79.1 | 51.2 / 63.0 / 69.2 | 22.2 / 39.5 / 51.3 |
+| GeoCalib | 83.1 / 91.8 / 94.8 | 52.3 / 74.8 / 84.6 | 17.4 / 40.0 / 59.4 |
+
+
+
+### Evaluation options
+
+If you want to provide priors during the evaluation, you can add one or multiple of the following flags:
+
+```bash
+python -m siclib.eval.<benchmark> --conf <config> \
+ --tag <tag> \
+ data.use_prior_focal=true \
+ data.use_prior_gravity=true \
+ data.use_prior_k1=true
+```
+
+
+[Visual inspection]
+
+To visually inspect the results of the evaluation, you can run the following command:
+
+```bash
+python -m siclib.eval.inspect <benchmark> <one or multiple tags>
+```
+
+For example, to inspect the results of the evaluation of the GeoCalib model on the LaMAR dataset, you can run:
+```bash
+python -m siclib.eval.inspect lamar2k geocalib
+```
+
+
+## OpenPano Dataset
+
+The OpenPano dataset is a new dataset for single-image calibration, containing about 2.8k panoramas from various sources, namely [HDRMAPS](https://hdrmaps.com/hdris/), [PolyHaven](https://polyhaven.com/hdris), and the [Laval Indoor HDR dataset](http://hdrdb.com/indoor/#presentation). While this dataset is smaller than previous ones, it is publicly available and provides a better balance between indoor and outdoor scenes.
+
+
+[Downloading and preparing the dataset]
+
+In order to assemble the training set, first download the Laval dataset following the instructions on [the corresponding project page](http://hdrdb.com/indoor/#presentation) and place the panoramas in ```data/indoorDatasetCalibrated```. Then, tonemap the HDR images using the following command:
+
+```bash
+python -m siclib.datasets.utils.tonemapping --hdr_dir data/indoorDatasetCalibrated --out_dir data/laval-tonemap
+```
+
+We provide a script to download the PolyHaven and HDRMAPS panos. The script will create folders ```data/openpano/panoramas/{split}``` containing the panoramas specified by the ```{split}_panos.txt``` files. To run the script, execute the following command:
+
+```bash
+python -m siclib.datasets.utils.download_openpano --name openpano --laval_dir data/laval-tonemap
+```
+Alternatively, you can download the PolyHaven and HDRMAPS panos from [here](https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/).
+
+
+After downloading the panoramas, you can create the training set by running the following command:
+
+```bash
+python -m siclib.datasets.create_dataset_from_pano --config-name openpano
+```
+
+The dataset creation can be sped up by using multiple workers and a GPU. To do so, add the following arguments to the command:
+
+```bash
+python -m siclib.datasets.create_dataset_from_pano --config-name openpano n_workers=10 device=cuda
+```
+
+This will create the training set in ```data/openpano/openpano``` with about 37k images for training, 2.1k for validation, and 2.1k for testing.
+
+
+[Distorted OpenPano]
+
+To create the OpenPano dataset with radial distortion, run the following command:
+
+```bash
+python -m siclib.datasets.create_dataset_from_pano --config-name openpano_radial
+```
+
+
+
+
+
+## Training
+
+As for the evaluation, the training code is provided in the single-image calibration library [`siclib`](siclib), which can be installed by:
+
+```bash
+python -m pip install -e siclib
+```
+
+Once the [OpenPano Dataset](#openpano-dataset) has been downloaded and prepared, we can train GeoCalib with it:
+
+First download the pre-trained weights for the [MSCAN-B](https://cloud.tsinghua.edu.cn/d/c15b25a6745946618462/) backbone:
+
+```bash
+mkdir weights
+wget "https://cloud.tsinghua.edu.cn/d/c15b25a6745946618462/files/?p=%2Fmscan_b.pth&dl=1" -O weights/mscan_b.pth
+```
+
+Then, start the training with the following command:
+
+```bash
+python -m siclib.train geocalib-pinhole-openpano --conf geocalib --distributed
+```
+
+Feel free to use any other experiment name. By default, the checkpoints will be written to ```outputs/training/```. The default batch size is 24, which requires 2x 4090 GPUs with 24 GB of VRAM each. Configurations are managed by [Hydra](https://hydra.cc/) and can be overwritten from the command line.
+For example, to train GeoCalib on a single GPU with a batch size of 5, run:
+
+```bash
+python -m siclib.train geocalib-pinhole-openpano \
+ --conf geocalib \
+ data.train_batch_size=5 # for 1x 2080 GPU
+```
+
+Be aware that this can impact the overall performance. You might need to adjust the learning rate and number of training steps accordingly.
+
+If you want to log the training progress to [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://wandb.ai/), you can set the ```train.writer``` option:
+
+```bash
+python -m siclib.train geocalib-pinhole-openpano \
+ --conf geocalib \
+ --distributed \
+ train.writer=tensorboard
+```
+
+The model can then be evaluated using its experiment name:
+
+```bash
+python -m siclib.eval.<benchmark> --checkpoint geocalib-pinhole-openpano \
+ --tag geocalib-retrained
+```
+
+
+[Training DeepCalib]
+
+To train DeepCalib on the OpenPano dataset, run:
+
+```bash
+python -m siclib.train deepcalib-openpano --conf deepcalib --distributed
+```
+
+Make sure that you have generated the [OpenPano Dataset](#openpano-dataset) with radial distortion or add
+the flag ```data=openpano``` to the command to train on the pinhole images.
+
+
+
+
+[Training Perspective Fields]
+
+Coming soon!
+
+
+
+## BibTeX citation
+
+If you use any ideas from the paper or code from this repo, please consider citing:
+
+```bibtex
+@inproceedings{veicht2024geocalib,
+ author = {Alexander Veicht and
+ Paul-Edouard Sarlin and
+ Philipp Lindenberger and
+ Marc Pollefeys},
+ title = {{GeoCalib: Single-image Calibration with Geometric Optimization}},
+ booktitle = {ECCV},
+ year = {2024}
+}
+```
+
+## License
+
+The code is provided under the [Apache-2.0 License](LICENSE) while the weights of the trained model are provided under the [Creative Commons Attribution 4.0 International Public License](https://creativecommons.org/licenses/by/4.0/legalcode). Thanks to the authors of the [Laval Indoor HDR dataset](http://hdrdb.com/indoor/#presentation) for allowing this.
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/assets/fisheye-dog-pool.jpg b/assets/fisheye-dog-pool.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..02278a47628312cbb922e5c60d17e1f0ea6f1876
Binary files /dev/null and b/assets/fisheye-dog-pool.jpg differ
diff --git a/assets/fisheye-skyline.jpg b/assets/fisheye-skyline.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..00b83edbf3168fd1a46e615323c558489e870b71
--- /dev/null
+++ b/assets/fisheye-skyline.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1f32ee048e0e9f931f9d42b19269419201834322f40a172367ebbca4752826a2
+size 1353673
diff --git a/assets/pinhole-church.jpg b/assets/pinhole-church.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ae9ad103798098af7f7e26ea13096ab8aabd4d5d
Binary files /dev/null and b/assets/pinhole-church.jpg differ
diff --git a/assets/pinhole-garden.jpg b/assets/pinhole-garden.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ecdcd0eee7c5ec1486bf93485c20d382f5c056f5
Binary files /dev/null and b/assets/pinhole-garden.jpg differ
diff --git a/assets/teaser.gif b/assets/teaser.gif
new file mode 100644
index 0000000000000000000000000000000000000000..a3e6e19e9c64f17138d09004c56fd9c353d857c3
--- /dev/null
+++ b/assets/teaser.gif
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:2944bad746c92415438f65adb4362bfd4b150db50729f118266d86f638ac3d21
+size 11997917
diff --git a/demo.ipynb b/demo.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..980f3eed9b70945fe68758485cd19ef621ba39e1
--- /dev/null
+++ b/demo.ipynb
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:115886d06182eca375251eab9e301180e86465fd0ed152e917d43d7eb4cbd722
+size 13275966
diff --git a/geocalib/__init__.py b/geocalib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a034045d7e3c9933383b6d5fcb1a7ce2dd89037
--- /dev/null
+++ b/geocalib/__init__.py
@@ -0,0 +1,17 @@
+import logging
+
+from geocalib.extractor import GeoCalib # noqa
+
+formatter = logging.Formatter(
+ fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
+)
+handler = logging.StreamHandler()
+handler.setFormatter(formatter)
+handler.setLevel(logging.INFO)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+logger.addHandler(handler)
+logger.propagate = False
+
+__module_name__ = __name__
diff --git a/geocalib/camera.py b/geocalib/camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..1988ebf017d6a6d048cc0c9ca688b5dae35f50c4
--- /dev/null
+++ b/geocalib/camera.py
@@ -0,0 +1,774 @@
+"""Implementation of the pinhole, simple radial, and simple divisional camera models."""
+
+from abc import abstractmethod
+from typing import Dict, Optional, Tuple, Union
+
+import torch
+from torch.func import jacfwd, vmap
+from torch.nn import functional as F
+
+from geocalib.gravity import Gravity
+from geocalib.misc import TensorWrapper, autocast
+from geocalib.utils import deg2rad, focal2fov, fov2focal, rad2rotmat
+
+# flake8: noqa: E741
+# mypy: ignore-errors
+
+
+class BaseCamera(TensorWrapper):
+ """Camera tensor class."""
+
+ eps = 1e-3
+
+ @autocast
+ def __init__(self, data: torch.Tensor):
+ """Camera parameters with shape (..., {w, h, fx, fy, cx, cy, *dist}).
+
+ Tensor convention: (..., {w, h, fx, fy, cx, cy, *dist}) where
+ - w, h: image size in pixels
+ - fx, fy: focal lengths in pixels
+ - cx, cy: principal point in pixels
+ - dist: distortion parameters
+
+ Args:
+ data (torch.Tensor): Camera parameters with shape (..., {6, 7, 8}).
+ """
+ # w, h, fx, fy, cx, cy, dist
+ assert data.shape[-1] in {6, 7, 8}, data.shape
+
+ pad = data.new_zeros(data.shape[:-1] + (8 - data.shape[-1],))
+ data = torch.cat([data, pad], -1) if data.shape[-1] != 8 else data
+ super().__init__(data)
+
+ @classmethod
+ def from_dict(cls, param_dict: Dict[str, torch.Tensor]) -> "BaseCamera":
+ """Create a Camera object from a dictionary of parameters.
+
+ Args:
+ param_dict (Dict[str, torch.Tensor]): Dictionary of parameters.
+
+ Returns:
+ Camera: Camera object.
+ """
+ for key, value in param_dict.items():
+ if not isinstance(value, torch.Tensor):
+ param_dict[key] = torch.tensor(value)
+
+ h, w = param_dict["height"], param_dict["width"]
+ cx, cy = param_dict.get("cx", w / 2), param_dict.get("cy", h / 2)
+
+ if "f" in param_dict:
+ f = param_dict["f"]
+ elif "vfov" in param_dict:
+ vfov = param_dict["vfov"]
+ f = fov2focal(vfov, h)
+ else:
+ raise ValueError("Focal length or vertical field of view must be provided.")
+
+ if "dist" in param_dict:
+ k1, k2 = param_dict["dist"][..., 0], param_dict["dist"][..., 1]
+ elif "k1_hat" in param_dict:
+ k1 = param_dict["k1_hat"] * (f / h) ** 2
+
+ k2 = param_dict.get("k2", torch.zeros_like(k1))
+ else:
+ k1 = param_dict.get("k1", torch.zeros_like(f))
+ k2 = param_dict.get("k2", torch.zeros_like(f))
+
+ fx, fy = f, f
+ if "scales" in param_dict:
+ fx = fx * param_dict["scales"][..., 0] / param_dict["scales"][..., 1]
+
+ params = torch.stack([w, h, fx, fy, cx, cy, k1, k2], dim=-1)
+ return cls(params)
+
+ def pinhole(self):
+ """Return the pinhole camera model."""
+ return self.__class__(self._data[..., :6])
+
+ @property
+ def size(self) -> torch.Tensor:
+ """Size (width, height) of the images, with shape (..., 2)."""
+ return self._data[..., :2]
+
+ @property
+ def f(self) -> torch.Tensor:
+ """Focal lengths (fx, fy) with shape (..., 2)."""
+ return self._data[..., 2:4]
+
+ @property
+ def vfov(self) -> torch.Tensor:
+ """Vertical field of view in radians."""
+ return focal2fov(self.f[..., 1], self.size[..., 1])
+
+ @property
+ def hfov(self) -> torch.Tensor:
+ """Horizontal field of view in radians."""
+ return focal2fov(self.f[..., 0], self.size[..., 0])
+
+ @property
+ def c(self) -> torch.Tensor:
+ """Principal points (cx, cy) with shape (..., 2)."""
+ return self._data[..., 4:6]
+
+ @property
+ def K(self) -> torch.Tensor:
+ """Returns the self intrinsic matrix with shape (..., 3, 3)."""
+ shape = self.shape + (3, 3)
+ K = self._data.new_zeros(shape)
+ K[..., 0, 0] = self.f[..., 0]
+ K[..., 1, 1] = self.f[..., 1]
+ K[..., 0, 2] = self.c[..., 0]
+ K[..., 1, 2] = self.c[..., 1]
+ K[..., 2, 2] = 1
+ return K
+
+ def update_focal(self, delta: torch.Tensor, as_log: bool = False):
+ """Update the self parameters after changing the focal length."""
+ f = torch.exp(torch.log(self.f) + delta) if as_log else self.f + delta
+
+ # clamp focal length to a reasonable range for stability during training
+ min_f = fov2focal(self.new_ones(self.shape[0]) * deg2rad(150), self.size[..., 1])
+ max_f = fov2focal(self.new_ones(self.shape[0]) * deg2rad(5), self.size[..., 1])
+ min_f = min_f.unsqueeze(-1).expand(-1, 2)
+ max_f = max_f.unsqueeze(-1).expand(-1, 2)
+ f = f.clamp(min=min_f, max=max_f)
+
+ # make sure focal ratio stays the same (avoid inplace operations)
+ fx = f[..., 1] * self.f[..., 0] / self.f[..., 1]
+ f = torch.stack([fx, f[..., 1]], -1)
+
+ dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
+ return self.__class__(torch.cat([self.size, f, self.c, dist], -1))
+
+ def scale(self, scales: Union[float, int, Tuple[Union[float, int]]]):
+ """Update the self parameters after resizing an image."""
+ scales = (scales, scales) if isinstance(scales, (int, float)) else scales
+ s = scales if isinstance(scales, torch.Tensor) else self.new_tensor(scales)
+
+ dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
+ return self.__class__(torch.cat([self.size * s, self.f * s, self.c * s, dist], -1))
+
+ def crop(self, pad: Tuple[float]):
+ """Update the self parameters after cropping an image."""
+ pad = pad if isinstance(pad, torch.Tensor) else self.new_tensor(pad)
+ size = self.size + pad.to(self.size)
+ c = self.c + pad.to(self.c) / 2
+
+ dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
+ return self.__class__(torch.cat([size, self.f, c, dist], -1))
+
+ @autocast
+ def in_image(self, p2d: torch.Tensor):
+ """Check if 2D points are within the image boundaries."""
+ assert p2d.shape[-1] == 2
+ size = self.size.unsqueeze(-2)
+ return torch.all((p2d >= 0) & (p2d <= (size - 1)), -1)
+
+ @autocast
+ def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Project 3D points into the self plane and check for visibility."""
+ z = p3d[..., -1]
+ valid = z > self.eps
+ z = z.clamp(min=self.eps)
+ p2d = p3d[..., :-1] / z.unsqueeze(-1)
+ return p2d, valid
+
+ def J_project(self, p3d: torch.Tensor):
+ """Jacobian of the projection function."""
+ x, y, z = p3d[..., 0], p3d[..., 1], p3d[..., 2]
+ zero = torch.zeros_like(z)
+ z = z.clamp(min=self.eps)
+ J = torch.stack([1 / z, zero, -x / z**2, zero, 1 / z, -y / z**2], dim=-1)
+ J = J.reshape(p3d.shape[:-1] + (2, 3))
+ return J # N x 2 x 3
+
+ def undo_scale_crop(self, data: Dict[str, torch.Tensor]):
+ """Undo transforms done during scaling and cropping."""
+ camera = self.crop(-data["crop_pad"]) if "crop_pad" in data else self
+ return camera.scale(1.0 / data["scales"])
+
+ @abstractmethod
+ def distort(self, pts: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
+ """Distort normalized 2D coordinates and check for validity of the distortion model."""
+ raise NotImplementedError("distort() must be implemented.")
+
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
+ """Jacobian of the distortion function."""
+ if wrt == "scale2pts": # (..., 2)
+ J = [
+ vmap(jacfwd(lambda x: self[idx].distort(x, return_scale=True)[0]))(p2d[idx])[None]
+ for idx in range(p2d.shape[0])
+ ]
+
+ return torch.cat(J, dim=0).squeeze(-3, -2)
+
+ elif wrt == "scale2dist": # (..., 1)
+ J = []
+ for idx in range(p2d.shape[0]): # loop to batch pts dimension
+
+ def func(x):
+ params = torch.cat([self._data[idx, :6], x[None]], -1)
+ return self.__class__(params).distort(p2d[idx], return_scale=True)[0]
+
+ J.append(vmap(jacfwd(func))(self[idx].dist))
+
+ return torch.cat(J, dim=0)
+
+ else:
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
+
+ @abstractmethod
+ def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Undistort normalized 2D coordinates and check for validity of the distortion model."""
+ raise NotImplementedError("undistort() must be implemented.")
+
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
+ """Jacobian of the undistortion function."""
+ if wrt == "pts": # (..., 2, 2)
+ J = [
+ vmap(jacfwd(lambda x: self[idx].undistort(x)[0]))(p2d[idx])[None]
+ for idx in range(p2d.shape[0])
+ ]
+
+ return torch.cat(J, dim=0).squeeze(-3)
+
+ elif wrt == "dist": # (..., 1)
+ J = []
+ for batch_idx in range(p2d.shape[0]): # loop to batch pts dimension
+
+ def func(x):
+ params = torch.cat([self._data[batch_idx, :6], x[None]], -1)
+ return self.__class__(params).undistort(p2d[batch_idx])[0]
+
+ J.append(vmap(jacfwd(func))(self[batch_idx].dist))
+
+ return torch.cat(J, dim=0)
+ else:
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
+
+ @autocast
+ def up_projection_offset(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Compute the offset for the up-projection."""
+ return self.J_distort(p2d, wrt="scale2pts") # (B, N, 2)
+
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
+ """Jacobian of the distortion offset for up-projection."""
+ if wrt == "uv": # (B, N, 2, 2)
+ J = [
+ vmap(jacfwd(lambda x: self[idx].up_projection_offset(x)[0, 0]))(p2d[idx])[None]
+ for idx in range(p2d.shape[0])
+ ]
+
+ return torch.cat(J, dim=0)
+
+ elif wrt == "dist": # (B, N, 2)
+ J = []
+ for batch_idx in range(p2d.shape[0]): # loop to batch pts dimension
+
+ def func(x):
+ params = torch.cat([self._data[batch_idx, :6], x[None]], -1)[None]
+ return self.__class__(params).up_projection_offset(p2d[batch_idx][None])
+
+ J.append(vmap(jacfwd(func))(self[batch_idx].dist))
+
+ return torch.cat(J, dim=0).squeeze(1)
+ else:
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
+
+ @autocast
+ def denormalize(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Convert normalized 2D coordinates into pixel coordinates."""
+ return p2d * self.f.unsqueeze(-2) + self.c.unsqueeze(-2)
+
+ def J_denormalize(self):
+ """Jacobian of the denormalization function."""
+ return torch.diag_embed(self.f) # ..., 2 x 2
+
+ @autocast
+ def normalize(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Convert pixel coordinates into normalized 2D coordinates."""
+ return (p2d - self.c.unsqueeze(-2)) / (self.f.unsqueeze(-2))
+
+ def J_normalize(self, p2d: torch.Tensor, wrt: str = "f"):
+ """Jacobian of the normalization function."""
+ # ... x N x 2 x 2
+ if wrt == "f":
+ J_f = -(p2d - self.c.unsqueeze(-2)) / ((self.f.unsqueeze(-2)) ** 2)
+ return torch.diag_embed(J_f)
+ elif wrt == "pts":
+ J_pts = 1 / self.f
+ return torch.diag_embed(J_pts)
+ else:
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
+
+ def pixel_coordinates(self) -> torch.Tensor:
+ """Pixel coordinates in self frame.
+
+ Returns:
+ torch.Tensor: Pixel coordinates as a tensor of shape (B, h * w, 2).
+ """
+ w, h = self.size[0].unbind(-1)
+ h, w = h.round().to(int), w.round().to(int)
+
+ # create grid
+ x = torch.arange(0, w, dtype=self.dtype, device=self.device)
+ y = torch.arange(0, h, dtype=self.dtype, device=self.device)
+ x, y = torch.meshgrid(x, y, indexing="xy")
+ xy = torch.stack((x, y), dim=-1).reshape(-1, 2) # shape (h * w, 2)
+
+ # add batch dimension (normalize() would broadcast but we make it explicit)
+ B = self.shape[0]
+ xy = xy.unsqueeze(0).expand(B, -1, -1) # if B > 0 else xy
+
+ return xy.to(self.device).to(self.dtype)
+
+ @autocast
+ def pixel_bearing_many(self, p3d: torch.Tensor) -> torch.Tensor:
+ """Get the bearing vectors of pixel coordinates by normalizing them."""
+ return F.normalize(p3d, dim=-1)
+
+ @autocast
+ def world2image(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Transform 3D points into 2D pixel coordinates."""
+ p2d, visible = self.project(p3d)
+ p2d, mask = self.distort(p2d)
+ p2d = self.denormalize(p2d)
+ valid = visible & mask & self.in_image(p2d)
+ return p2d, valid
+
+ @autocast
+ def J_world2image(self, p3d: torch.Tensor):
+ """Jacobian of the world2image function."""
+ p2d_proj, valid = self.project(p3d)
+
+ J_dnorm = self.J_denormalize()
+ J_dist = self.J_distort(p2d_proj)
+ J_proj = self.J_project(p3d)
+
+ J = torch.einsum("...ij,...jk,...kl->...il", J_dnorm, J_dist, J_proj)
+ return J, valid
+
+ @autocast
+ def image2world(self, p2d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Transform point in the image plane to 3D world coordinates."""
+ p2d = self.normalize(p2d)
+ p2d, valid = self.undistort(p2d)
+ ones = p2d.new_ones(p2d.shape[:-1] + (1,))
+ p3d = torch.cat([p2d, ones], -1)
+ return p3d, valid
+
+ @autocast
+ def J_image2world(self, p2d: torch.Tensor, wrt: str = "f") -> Tuple[torch.Tensor, torch.Tensor]:
+ """Jacobian of the image2world function."""
+ if wrt == "dist":
+ p2d_norm = self.normalize(p2d)
+ return self.J_undistort(p2d_norm, wrt)
+ elif wrt == "f":
+ J_norm2f = self.J_normalize(p2d, wrt)
+ p2d_norm = self.normalize(p2d)
+ J_dist2norm = self.J_undistort(p2d_norm, "pts")
+
+ return torch.einsum("...ij,...jk->...ik", J_dist2norm, J_norm2f)
+ else:
+ raise ValueError(f"Unknown wrt: {wrt}")
+
+ @autocast
+ def undistort_image(self, img: torch.Tensor) -> torch.Tensor:
+ """Undistort an image using the distortion model."""
+ assert self.shape[0] == 1, "Batch size must be 1."
+ W, H = self.size.unbind(-1)
+ H, W = H.int().item(), W.int().item()
+
+ x, y = torch.meshgrid(torch.arange(0, W), torch.arange(0, H), indexing="xy")
+ coords = torch.stack((x, y), dim=-1).reshape(-1, 2)
+
+ p3d, _ = self.pinhole().image2world(coords.to(self.device).to(self.dtype))
+ p2d, _ = self.world2image(p3d)
+
+ mapx, mapy = p2d[..., 0].reshape((1, H, W)), p2d[..., 1].reshape((1, H, W))
+ grid = torch.stack((mapx, mapy), dim=-1)
+ grid = 2.0 * grid / torch.tensor([W - 1, H - 1]).to(grid) - 1
+ return F.grid_sample(img, grid, align_corners=True)
+
+ def get_img_from_pano(
+ self,
+ pano_img: torch.Tensor,
+ gravity: Gravity,
+ yaws: torch.Tensor = 0.0,
+ resize_factor: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Render an image from a panorama.
+
+ Args:
+ pano_img (torch.Tensor): Panorama image of shape (3, H, W) in [0, 1].
+ gravity (Gravity): Gravity direction of the camera.
+ yaws (torch.Tensor | list, optional): Yaw angle in radians. Defaults to 0.0.
+ resize_factor (torch.Tensor, optional): Resize the panorama to be a multiple of the
+ field of view. Defaults to 1.
+
+ Returns:
+ torch.Tensor: Image rendered from the panorama.
+ """
+ B = self.shape[0]
+ if B > 0:
+ assert self.size[..., 0].unique().shape[0] == 1, "All images must have the same width."
+ assert self.size[..., 1].unique().shape[0] == 1, "All images must have the same height."
+
+ w, h = self.size[0].unbind(-1)
+ h, w = h.round().to(int), w.round().to(int)
+
+ if isinstance(yaws, (int, float)):
+ yaws = [yaws]
+ if isinstance(resize_factor, (int, float)):
+ resize_factor = [resize_factor]
+
+ yaws = (
+ yaws.to(self.dtype).to(self.device)
+ if isinstance(yaws, torch.Tensor)
+ else self.new_tensor(yaws)
+ )
+
+ if isinstance(resize_factor, torch.Tensor):
+ resize_factor = resize_factor.to(self.dtype).to(self.device)
+ elif resize_factor is not None:
+ resize_factor = self.new_tensor(resize_factor)
+
+ assert isinstance(pano_img, torch.Tensor), "Panorama image must be a torch.Tensor."
+ pano_img = pano_img if pano_img.dim() == 4 else pano_img.unsqueeze(0) # B x H x W x 3
+
+ pano_imgs = []
+ for i, yaw in enumerate(yaws):
+ if resize_factor is not None:
+ # resize the panorama such that the fov of the panorama has the same height as the
+ # image
+ vfov = self.vfov[i] if B != 0 else self.vfov
+ scale = torch.pi / float(vfov) * float(h) / pano_img.shape[0] * resize_factor[i]
+ pano_shape = (int(pano_img.shape[0] * scale), int(pano_img.shape[1] * scale))
+
+ mode = "bicubic" if scale >= 1 else "area"
+ resized_pano = F.interpolate(pano_img, size=pano_shape, mode=mode)
+ else:
+ # make sure to copy: resized_pano = pano_img
+ resized_pano = pano_img
+ pano_shape = pano_img.shape[-2:][::-1]
+
+ pano_imgs.append((resized_pano, pano_shape))
+
+ xy = self.pixel_coordinates()
+ uv1, _ = self.image2world(xy)
+ bearings = self.pixel_bearing_many(uv1)
+
+ # rotate bearings
+ R_yaw = rad2rotmat(self.new_zeros(yaw.shape), self.new_zeros(yaw.shape), yaws)
+ rotated_bearings = bearings @ gravity.R @ R_yaw
+
+ # spherical coordinates
+ lon = torch.atan2(rotated_bearings[..., 0], rotated_bearings[..., 2])
+ lat = torch.atan2(
+ rotated_bearings[..., 1], torch.norm(rotated_bearings[..., [0, 2]], dim=-1)
+ )
+
+ images = []
+ for idx, (resized_pano, pano_shape) in enumerate(pano_imgs):
+ min_lon, max_lon = -torch.pi, torch.pi
+ min_lat, max_lat = -torch.pi / 2.0, torch.pi / 2.0
+ min_x, max_x = 0, pano_shape[0] - 1.0
+ min_y, max_y = 0, pano_shape[1] - 1.0
+
+ # map Spherical Coordinates to Panoramic Coordinates
+ nx = (lon[idx] - min_lon) / (max_lon - min_lon) * (max_x - min_x) + min_x
+ ny = (lat[idx] - min_lat) / (max_lat - min_lat) * (max_y - min_y) + min_y
+
+ # reshape into sampling maps of shape (1, h, w) for grid_sample
+ mapx, mapy = nx.reshape((1, h, w)), ny.reshape((1, h, w))
+
+ grid = torch.stack((mapx, mapy), dim=-1) # sampling grid of shape (1, h, w, 2)
+ # Normalize to [-1, 1]
+ grid = 2.0 * grid / torch.tensor([pano_shape[-2] - 1, pano_shape[-1] - 1]).to(grid) - 1
+ # Apply grid sample
+ image = F.grid_sample(resized_pano, grid, align_corners=True)
+ images.append(image)
+
+ return torch.concatenate(images, 0) if B > 0 else images[0]
+
+ def __repr__(self):
+ """Print the Camera object."""
+ return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}"
+
+
+class Pinhole(BaseCamera):
+ """Implementation of the pinhole camera model.
+
+ Use this model for undistorted images.
+ """
+
+ def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
+ """Distort normalized 2D coordinates."""
+ if return_scale:
+ return p2d.new_ones(p2d.shape[:-1] + (1,))
+
+ return p2d, p2d.new_ones((p2d.shape[0], 1)).bool()
+
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
+ """Jacobian of the distortion function."""
+ if wrt == "pts":
+ return torch.eye(2, device=p2d.device, dtype=p2d.dtype).expand(p2d.shape[:-1] + (2, 2))
+
+ raise ValueError(f"Unknown wrt: {wrt}")
+
+ def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Undistort normalized 2D coordinates."""
+ return pts, pts.new_ones((pts.shape[0], 1)).bool()
+
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
+ """Jacobian of the undistortion function."""
+ if wrt == "pts":
+ return torch.eye(2, device=p2d.device, dtype=p2d.dtype).expand(p2d.shape[:-1] + (2, 2))
+
+ raise ValueError(f"Unknown wrt: {wrt}")
+
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
+ """Jacobian of the up-projection offset."""
+ if wrt == "uv":
+ return torch.zeros(p2d.shape[:-1] + (2, 2), device=p2d.device, dtype=p2d.dtype)
+
+ raise ValueError(f"Unknown wrt: {wrt}")
+
+
+class SimpleRadial(BaseCamera):
+ """Implementation of the simple radial camera model.
+
+ Use this model for weakly distorted images.
+
+ The distortion model is 1 + k1 * r^2 where r^2 = x^2 + y^2.
+ The undistortion model is 1 - k1 * r^2 estimated as in
+ "An Exact Formula for Calculating Inverse Radial Lens Distortions" by Pierre Drap.
+ """
+
+ @property
+ def dist(self) -> torch.Tensor:
+ """Distortion parameters, with shape (..., 1)."""
+ return self._data[..., 6:]
+
+ @property
+ def k1(self) -> torch.Tensor:
+ """Distortion parameters, with shape (...)."""
+ return self._data[..., 6]
+
+ def update_dist(self, delta: torch.Tensor, dist_range: Tuple[float, float] = (-0.7, 0.7)):
+ """Update the self parameters after changing the k1 distortion parameter."""
+ delta_dist = self.new_ones(self.dist.shape) * delta
+ dist = (self.dist + delta_dist).clamp(*dist_range)
+ data = torch.cat([self.size, self.f, self.c, dist], -1)
+ return self.__class__(data)
+
+ @autocast
+ def check_valid(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Check if the distorted points are valid."""
+ return p2d.new_ones(p2d.shape[:-1]).bool()
+
+ def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
+ """Distort normalized 2D coordinates and check for validity of the distortion model."""
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ radial = 1 + self.k1[..., None, None] * r2
+
+ if return_scale:
+ return radial, None
+
+ return p2d * radial, self.check_valid(p2d)
+
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts"):
+ """Jacobian of the distortion function."""
+ if wrt == "scale2dist": # (..., 1)
+ return torch.sum(p2d**2, -1, keepdim=True)
+ elif wrt == "scale2pts": # (..., 2)
+ return 2 * self.k1[..., None, None] * p2d
+ else:
+ return super().J_distort(p2d, wrt)
+
+ @autocast
+ def undistort(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Undistort normalized 2D coordinates and check for validity of the distortion model."""
+ b1 = -self.k1[..., None, None]
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ radial = 1 + b1 * r2
+ return p2d * radial, self.check_valid(p2d)
+
+ @autocast
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
+ """Jacobian of the undistortion function."""
+ b1 = -self.k1[..., None, None]
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ if wrt == "dist":
+ return -r2 * p2d
+ elif wrt == "pts":
+ radial = 1 + b1 * r2
+ radial_diag = torch.diag_embed(radial.expand(radial.shape[:-1] + (2,)))
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
+ return (2 * b1[..., None] * ppT) + radial_diag
+ else:
+ return super().J_undistort(p2d, wrt)
+
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
+ """Jacobian of the up-projection offset."""
+ if wrt == "uv": # (..., 2, 2)
+ return torch.diag_embed((2 * self.k1[..., None, None]).expand(p2d.shape[:-1] + (2,)))
+ elif wrt == "dist":
+ return 2 * p2d # (..., 2)
+ else:
+ return super().J_up_projection_offset(p2d, wrt)
+
+
+class SimpleDivisional(BaseCamera):
+ """Implementation of the simple divisional camera model.
+
+ Use this model for strongly distorted images.
+
+ The distortion model is (1 - sqrt(1 - 4 * k1 * r^2)) / (2 * k1 * r^2) where r^2 = x^2 + y^2.
+ The undistortion model is 1 / (1 + k1 * r^2).
+ """
+
+ @property
+ def dist(self) -> torch.Tensor:
+ """Distortion parameters, with shape (..., 1)."""
+ return self._data[..., 6:]
+
+ @property
+ def k1(self) -> torch.Tensor:
+ """Distortion parameters, with shape (...)."""
+ return self._data[..., 6]
+
+ def update_dist(self, delta: torch.Tensor, dist_range: Tuple[float, float] = (-3.0, 3.0)):
+ """Update the self parameters after changing the k1 distortion parameter."""
+ delta_dist = self.new_ones(self.dist.shape) * delta
+ dist = (self.dist + delta_dist).clamp(*dist_range)
+ data = torch.cat([self.size, self.f, self.c, dist], -1)
+ return self.__class__(data)
+
+ @autocast
+ def check_valid(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Check if the distorted points are valid."""
+ return p2d.new_ones(p2d.shape[:-1]).bool()
+
+ def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
+ """Distort normalized 2D coordinates and check for validity of the distortion model."""
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ radial = 1 - torch.sqrt((1 - 4 * self.k1[..., None, None] * r2).clamp(min=0))
+ denom = 2 * self.k1[..., None, None] * r2
+
+ ones = radial.new_ones(radial.shape)
+ radial = torch.where(denom == 0, ones, radial / denom.masked_fill(denom == 0, 1e6))
+
+ if return_scale:
+ return radial, None
+
+ return p2d * radial, self.check_valid(p2d)
+
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts"):
+ """Jacobian of the distortion function."""
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ t0 = torch.sqrt((1 - 4 * self.k1[..., None, None] * r2).clamp(min=1e-6))
+ if wrt == "scale2pts": # (B, N, 2)
+ d1 = t0 * 2 * r2
+ d2 = self.k1[..., None, None] * r2**2
+ denom = d1 * d2
+ return p2d * (4 * d2 - (1 - t0) * d1) / denom.masked_fill(denom == 0, 1e6)
+
+ elif wrt == "scale2dist":
+ d1 = 2 * self.k1[..., None, None] * t0
+ d2 = 2 * r2 * self.k1[..., None, None] ** 2
+ denom = d1 * d2
+ return (2 * d2 - (1 - t0) * d1) / denom.masked_fill(denom == 0, 1e6)
+
+ else:
+ return super().J_distort(p2d, wrt)
+
+ @autocast
+ def undistort(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Undistort normalized 2D coordinates and check for validity of the distortion model."""
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ denom = 1 + self.k1[..., None, None] * r2
+ radial = 1 / denom.masked_fill(denom == 0, 1e6)
+ return p2d * radial, self.check_valid(p2d)
+
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
+ """Jacobian of the undistortion function."""
+ # return super().J_undistort(p2d, wrt)
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ k1 = self.k1[..., None, None]
+ if wrt == "dist":
+ denom = (1 + k1 * r2) ** 2
+ return -r2 / denom.masked_fill(denom == 0, 1e6) * p2d
+ elif wrt == "pts":
+ t0 = 1 + k1 * r2
+ t0 = t0.masked_fill(t0 == 0, 1e6)
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
+ J = torch.diag_embed((1 / t0).expand(p2d.shape[:-1] + (2,)))
+ return J - 2 * k1[..., None] * ppT / t0[..., None] ** 2 # (..., N, 2, 2)
+
+ else:
+ return super().J_undistort(p2d, wrt)
+
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
+ """Jacobian of the up-projection offset.
+
+ func(uv, dist) = 4 / (2 * norm2(uv)^2 * (1-4*k1*norm2(uv)^2)^0.5) * uv
+ - (1-(1-4*k1*norm2(uv)^2)^0.5) / (k1 * norm2(uv)^4) * uv
+ """
+ k1 = self.k1[..., None, None]
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ t0 = (1 - 4 * k1 * r2).clamp(min=1e-6)
+ t1 = torch.sqrt(t0)
+ if wrt == "dist":
+ denom = 4 * t0 ** (3 / 2)
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = 16 / denom
+
+ denom = r2 * t1 * k1
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = J - 2 / denom
+
+ denom = (r2 * k1) ** 2
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = J + (1 - t1) / denom
+
+ return J * p2d
+ elif wrt == "uv":
+ # ! unstable (gradient checker might fail), rewrite to use single division (by denom)
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
+
+ denom = 2 * r2 * t1
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = torch.diag_embed((4 / denom).expand(p2d.shape[:-1] + (2,)))
+
+ denom = 4 * t1 * r2**2
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = J - 16 / denom[..., None] * ppT
+
+ denom = 4 * r2 * t0 ** (3 / 2)
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = J + (32 * k1[..., None]) / denom[..., None] * ppT
+
+ denom = r2**2 * t1
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = J - 4 / denom[..., None] * ppT
+
+ denom = k1 * r2**3
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = J + (4 * (1 - t1) / denom)[..., None] * ppT
+
+ denom = k1 * r2**2
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = J - torch.diag_embed(((1 - t1) / denom).expand(p2d.shape[:-1] + (2,)))
+
+ return J
+ else:
+ return super().J_up_projection_offset(p2d, wrt)
+
+
+camera_models = {
+ "pinhole": Pinhole,
+ "simple_radial": SimpleRadial,
+ "simple_divisional": SimpleDivisional,
+}
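+
+# Minimal usage sketch (illustrative only): the `from_dict` keys below mirror those used
+# by the LM optimizer in this repository, and the concrete values are hypothetical.
+#
+#   cam_cls = camera_models["simple_radial"]
+#   params = {"width": torch.tensor([640.0]), "height": torch.tensor([480.0]),
+#             "vfov": torch.tensor([1.0]), "k1": torch.tensor([-0.1])}
+#   camera = cam_cls.from_dict(params)
+#   distorted_pts, valid = camera.distort(torch.rand(1, 10, 2))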
diff --git a/geocalib/extractor.py b/geocalib/extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fdda6ad30136385bcfe5fd0cf606cb1f860c4f9
--- /dev/null
+++ b/geocalib/extractor.py
@@ -0,0 +1,126 @@
+"""Simple interface for GeoCalib model."""
+
+from pathlib import Path
+from typing import Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+from torch.nn.functional import interpolate
+
+from geocalib.camera import BaseCamera
+from geocalib.geocalib import GeoCalib as Model
+from geocalib.utils import ImagePreprocessor, load_image
+
+
+class GeoCalib(nn.Module):
+ """Simple interface for GeoCalib model."""
+
+ def __init__(self, weights: str = "pinhole"):
+ """Initialize the model with optional config overrides.
+
+ Args:
+ weights (str): trained variant, "pinhole" (default) or "distorted".
+ """
+ super().__init__()
+ if weights == "pinhole":
+ url = "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-pinhole.tar"
+ elif weights == "distorted":
+ url = (
+ "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-simple_radial.tar"
+ )
+ else:
+ raise ValueError(f"Unknown weights: {weights}")
+
+ # load checkpoint
+ model_dir = f"{torch.hub.get_dir()}/geocalib"
+ state_dict = torch.hub.load_state_dict_from_url(
+ url, model_dir, map_location="cpu", file_name=f"{weights}.tar"
+ )
+
+ self.model = Model()
+ self.model.flexible_load(state_dict["model"])
+ self.model.eval()
+
+ self.image_processor = ImagePreprocessor({"resize": 320, "edge_divisible_by": 32})
+
+ def load_image(self, path: Path) -> torch.Tensor:
+ """Load image from path."""
+ return load_image(path)
+
+ def _post_process(
+ self, camera: BaseCamera, img_data: Dict[str, torch.Tensor], out: Dict[str, torch.Tensor]
+ ) -> Tuple[BaseCamera, Dict[str, torch.Tensor]]:
+ """Post-process model output by undoing scaling and cropping."""
+ camera = camera.undo_scale_crop(img_data)
+
+ w, h = camera.size.unbind(-1)
+ h = h[0].round().int().item()
+ w = w[0].round().int().item()
+
+ for k in ["latitude_field", "up_field"]:
+ out[k] = interpolate(out[k], size=(h, w), mode="bilinear")
+ for k in ["up_confidence", "latitude_confidence"]:
+ out[k] = interpolate(out[k][:, None], size=(h, w), mode="bilinear")[:, 0]
+
+ inverse_scales = 1.0 / img_data["scales"]
+ zero = camera.new_zeros(camera.f.shape[0])
+ out["focal_uncertainty"] = out.get("focal_uncertainty", zero) * inverse_scales[1]
+ return camera, out
+
+ @torch.no_grad()
+ def calibrate(
+ self,
+ img: torch.Tensor,
+ camera_model: str = "pinhole",
+ priors: Optional[Dict[str, torch.Tensor]] = None,
+ shared_intrinsics: bool = False,
+ ) -> Dict[str, torch.Tensor]:
+ """Perform calibration with online resizing.
+
+ Assumes input image is in range [0, 1] and in RGB format.
+
+ Args:
+ img (torch.Tensor): Input image, shape (C, H, W) or (1, C, H, W)
+ camera_model (str, optional): Camera model. Defaults to "pinhole".
+ priors (Dict[str, torch.Tensor], optional): Prior parameters. Defaults to None.
+ shared_intrinsics (bool, optional): Whether to share intrinsics. Defaults to False.
+
+ Returns:
+ Dict[str, torch.Tensor]: camera and gravity vectors and uncertainties.
+ """
+ if len(img.shape) == 3:
+ img = img[None] # add batch dim
+ if not shared_intrinsics:
+ assert len(img.shape) == 4 and img.shape[0] == 1
+
+ img_data = self.image_processor(img)
+
+ if priors is None:
+ priors = {}
+
+ prior_values = {}
+ if (prior_focal := priors.get("focal")) is not None:
+ prior_focal = prior_focal[None] if len(prior_focal.shape) == 0 else prior_focal
+ prior_values["prior_focal"] = prior_focal * img_data["scales"][1]
+
+ if "gravity" in priors:
+ prior_gravity = priors["gravity"]
+ prior_gravity = prior_gravity[None] if len(prior_gravity.shape) == 0 else prior_gravity
+ prior_values["prior_gravity"] = prior_gravity
+
+ self.model.optimizer.set_camera_model(camera_model)
+ self.model.optimizer.shared_intrinsics = shared_intrinsics
+
+ out = self.model(img_data | prior_values)
+
+ camera, gravity = out["camera"], out["gravity"]
+ camera, out = self._post_process(camera, img_data, out)
+
+ return {
+ "camera": camera,
+ "gravity": gravity,
+ "covariance": out["covariance"],
+ **{k: out[k] for k in out.keys() if "field" in k},
+ **{k: out[k] for k in out.keys() if "confidence" in k},
+ **{k: out[k] for k in out.keys() if "uncertainty" in k},
+ }
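+
+
+# Minimal usage sketch (illustrative; the image path is a placeholder):
+#
+#   model = GeoCalib(weights="pinhole")
+#   image = model.load_image("path/to/image.jpg")
+#   result = model.calibrate(image, camera_model="pinhole")
+#   camera, gravity = result["camera"], result["gravity"]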
diff --git a/geocalib/geocalib.py b/geocalib/geocalib.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2dc70fdb0ed56774e920fbeeb64cb150e23e0f1
--- /dev/null
+++ b/geocalib/geocalib.py
@@ -0,0 +1,150 @@
+"""GeoCalib model definition."""
+
+import logging
+from typing import Dict
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from geocalib.lm_optimizer import LMOptimizer
+from geocalib.modules import MSCAN, ConvModule, LightHamHead
+
+# mypy: ignore-errors
+
+logger = logging.getLogger(__name__)
+
+
+class LowLevelEncoder(nn.Module):
+ """Very simple low-level encoder."""
+
+ def __init__(self):
+ """Simple low-level encoder."""
+ super().__init__()
+ self.in_channel = 3
+ self.feat_dim = 64
+
+ self.conv1 = ConvModule(self.in_channel, self.feat_dim, kernel_size=3, padding=1)
+ self.conv2 = ConvModule(self.feat_dim, self.feat_dim, kernel_size=3, padding=1)
+
+ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """Forward pass."""
+ x = data["image"]
+
+ assert (
+ x.shape[-1] % 32 == 0 and x.shape[-2] % 32 == 0
+ ), "Image size must be multiple of 32 if not using single image input."
+
+ c1 = self.conv1(x)
+ c2 = self.conv2(c1)
+
+ return {"features": c2}
+
+
+class UpDecoder(nn.Module):
+ """Minimal implementation of UpDecoder."""
+
+ def __init__(self):
+ """Up decoder."""
+ super().__init__()
+ self.decoder = LightHamHead()
+ self.linear_pred_up = nn.Conv2d(self.decoder.out_channels, 2, kernel_size=1)
+
+ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """Forward pass."""
+ x, log_confidence = self.decoder(data["features"])
+ up = self.linear_pred_up(x)
+ return {"up_field": F.normalize(up, dim=1), "up_confidence": torch.sigmoid(log_confidence)}
+
+
+class LatitudeDecoder(nn.Module):
+ """Minimal implementation of LatitudeDecoder."""
+
+ def __init__(self):
+ """Latitude decoder."""
+ super().__init__()
+ self.decoder = LightHamHead()
+ self.linear_pred_latitude = nn.Conv2d(self.decoder.out_channels, 1, kernel_size=1)
+
+ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """Forward pass."""
+ x, log_confidence = self.decoder(data["features"])
+ eps = 1e-5 # avoid nan in backward of asin
+ lat = torch.tanh(self.linear_pred_latitude(x))
+ lat = torch.asin(torch.clamp(lat, -1 + eps, 1 - eps))
+ return {"latitude_field": lat, "latitude_confidence": torch.sigmoid(log_confidence)}
+
+
+class PerspectiveDecoder(nn.Module):
+ """Minimal implementation of PerspectiveDecoder."""
+
+ def __init__(self):
+ """Perspective decoder wrapping up and latitude decoders."""
+ super().__init__()
+ self.up_head = UpDecoder()
+ self.latitude_head = LatitudeDecoder()
+
+ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """Forward pass."""
+ return self.up_head(data) | self.latitude_head(data)
+
+
+class GeoCalib(nn.Module):
+ """GeoCalib inference model."""
+
+ def __init__(self, **optimizer_options):
+ """Initialize the GeoCalib inference model.
+
+ Args:
+ optimizer_options: Options for the lm optimizer.
+ """
+ super().__init__()
+ self.backbone = MSCAN()
+ self.ll_enc = LowLevelEncoder()
+ self.perspective_decoder = PerspectiveDecoder()
+
+ self.optimizer = LMOptimizer({**optimizer_options})
+
+ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """Forward pass."""
+ features = {"hl": self.backbone(data)["features"], "ll": self.ll_enc(data)["features"]}
+ out = self.perspective_decoder({"features": features})
+
+ out |= {
+ k: data[k]
+ for k in ["image", "scales", "prior_gravity", "prior_focal", "prior_k1"]
+ if k in data
+ }
+
+ out |= self.optimizer(out)
+
+ return out
+
+ def flexible_load(self, state_dict: Dict[str, torch.Tensor]) -> None:
+ """Load a checkpoint with flexible key names."""
+ dict_params = set(state_dict.keys())
+ model_params = set(map(lambda n: n[0], self.named_parameters()))
+
+ if dict_params == model_params: # perfect fit
+ logger.info("Loading all parameters of the checkpoint.")
+ self.load_state_dict(state_dict, strict=True)
+ return
+ elif len(dict_params & model_params) == 0: # perfect mismatch
+ strip_prefix = lambda x: ".".join(x.split(".")[:1] + x.split(".")[2:])
+ state_dict = {strip_prefix(n): p for n, p in state_dict.items()}
+ dict_params = set(state_dict.keys())
+ if len(dict_params & model_params) == 0:
+ raise ValueError(
+ "Could not manage to load the checkpoint with"
+ "parameters:" + "\n\t".join(sorted(dict_params))
+ )
+ common_params = dict_params & model_params
+ left_params = dict_params - model_params
+ left_params = [
+ p for p in left_params if "running" not in p and "num_batches_tracked" not in p
+ ]
+ logger.debug("Loading parameters:\n\t" + "\n\t".join(sorted(common_params)))
+ if left_params:
+ # ignore running stats of batchnorm
+ logger.warning("Could not load parameters:\n\t" + "\n\t".join(sorted(left_params)))
+ self.load_state_dict(state_dict, strict=False)
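+
+
+# Expected input for GeoCalib.forward (sketch): a dict with an "image" tensor of shape
+# (B, 3, H, W) where H and W are multiples of 32, plus optional "scales", "prior_gravity",
+# "prior_focal" and "prior_k1" entries, as read out in forward() above. The extractor
+# module wraps this interface with image loading and preprocessing.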
diff --git a/geocalib/gravity.py b/geocalib/gravity.py
new file mode 100644
index 0000000000000000000000000000000000000000..31f0e91e720ac5cf0a23e8db538e1f76a18f7fee
--- /dev/null
+++ b/geocalib/gravity.py
@@ -0,0 +1,131 @@
+"""Tensor class for gravity vector in camera frame."""
+
+import torch
+from torch.nn import functional as F
+
+from geocalib.misc import EuclideanManifold, SphericalManifold, TensorWrapper, autocast
+from geocalib.utils import rad2rotmat
+
+# mypy: ignore-errors
+
+
+class Gravity(TensorWrapper):
+ """Gravity vector in camera frame."""
+
+ eps = 1e-4
+
+ @autocast
+ def __init__(self, data: torch.Tensor) -> None:
+ """Create gravity vector from data.
+
+ Args:
+ data (torch.Tensor): gravity vector as 3D vector in camera frame.
+ """
+ assert data.shape[-1] == 3, data.shape
+
+ data = F.normalize(data, dim=-1)
+
+ super().__init__(data)
+
+ @classmethod
+ def from_rp(cls, roll: torch.Tensor, pitch: torch.Tensor) -> "Gravity":
+ """Create gravity vector from roll and pitch angles."""
+ if not isinstance(roll, torch.Tensor):
+ roll = torch.tensor(roll)
+ if not isinstance(pitch, torch.Tensor):
+ pitch = torch.tensor(pitch)
+
+ sr, cr = torch.sin(roll), torch.cos(roll)
+ sp, cp = torch.sin(pitch), torch.cos(pitch)
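+ # e.g. roll = pitch = 0 gives (0, -1, 0): gravity along the camera's negative y-axis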
+ return cls(torch.stack([-sr * cp, -cr * cp, sp], dim=-1))
+
+ @property
+ def vec3d(self) -> torch.Tensor:
+ """Return the gravity vector in the representation."""
+ return self._data
+
+ @property
+ def x(self) -> torch.Tensor:
+ """Return first component of the gravity vector."""
+ return self._data[..., 0]
+
+ @property
+ def y(self) -> torch.Tensor:
+ """Return second component of the gravity vector."""
+ return self._data[..., 1]
+
+ @property
+ def z(self) -> torch.Tensor:
+ """Return third component of the gravity vector."""
+ return self._data[..., 2]
+
+ @property
+ def roll(self) -> torch.Tensor:
+ """Return the roll angle of the gravity vector."""
+ roll = torch.asin(-self.x / (torch.sqrt(1 - self.z**2) + self.eps))
+ offset = -torch.pi * torch.sign(self.x)
+ return torch.where(self.y < 0, roll, -roll + offset)
+
+ def J_roll(self) -> torch.Tensor:
+ """Return the Jacobian of the roll angle of the gravity vector."""
+ cp, _ = torch.cos(self.pitch), torch.sin(self.pitch)
+ cr, sr = torch.cos(self.roll), torch.sin(self.roll)
+ Jr = self.new_zeros(self.shape + (3,))
+ Jr[..., 0] = -cr * cp
+ Jr[..., 1] = sr * cp
+ return Jr
+
+ @property
+ def pitch(self) -> torch.Tensor:
+ """Return the pitch angle of the gravity vector."""
+ return torch.asin(self.z)
+
+ def J_pitch(self) -> torch.Tensor:
+ """Return the Jacobian of the pitch angle of the gravity vector."""
+ cp, sp = torch.cos(self.pitch), torch.sin(self.pitch)
+ cr, sr = torch.cos(self.roll), torch.sin(self.roll)
+
+ Jp = self.new_zeros(self.shape + (3,))
+ Jp[..., 0] = sr * sp
+ Jp[..., 1] = cr * sp
+ Jp[..., 2] = cp
+ return Jp
+
+ @property
+ def rp(self) -> torch.Tensor:
+ """Return the roll and pitch angles of the gravity vector."""
+ return torch.stack([self.roll, self.pitch], dim=-1)
+
+ def J_rp(self) -> torch.Tensor:
+ """Return the Jacobian of the roll and pitch angles of the gravity vector."""
+ return torch.stack([self.J_roll(), self.J_pitch()], dim=-1)
+
+ @property
+ def R(self) -> torch.Tensor:
+ """Return the rotation matrix from the gravity vector."""
+ return rad2rotmat(roll=self.roll, pitch=self.pitch)
+
+ def J_R(self) -> torch.Tensor:
+ """Return the Jacobian of the rotation matrix from the gravity vector."""
+ raise NotImplementedError
+
+ def update(self, delta: torch.Tensor, spherical: bool = False) -> "Gravity":
+ """Update the gravity vector by adding a delta."""
+ if spherical:
+ data = SphericalManifold.plus(self.vec3d, delta)
+ return self.__class__(data)
+
+ data = EuclideanManifold.plus(self.rp, delta)
+ return self.from_rp(data[..., 0], data[..., 1])
+
+ def J_update(self, spherical: bool = False) -> torch.Tensor:
+ """Return the Jacobian of the update."""
+ return (
+ SphericalManifold.J_plus(self.vec3d)
+ if spherical
+ else EuclideanManifold.J_plus(self.vec3d)
+ )
+
+ def __repr__(self):
+ """Print the Camera object."""
+ return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}"
diff --git a/geocalib/interactive_demo.py b/geocalib/interactive_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2f0c4f9a0c0fb0982f700f9c1d34a523fc85c66
--- /dev/null
+++ b/geocalib/interactive_demo.py
@@ -0,0 +1,450 @@
+import argparse
+import logging
+import queue
+import threading
+from time import time
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from geocalib.extractor import GeoCalib
+from geocalib.perspective_fields import get_perspective_field
+from geocalib.utils import get_device, rad2deg
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+description = """
+-------------------------
+GeoCalib Interactive Demo
+-------------------------
+
+This script is an interactive demo for GeoCalib. It will open a window showing the camera feed and
+the calibration results.
+
+Arguments:
+- '--camera_id': Camera ID to use. If none, the demo will ask for the IP address of a DroidCam device (https://droidcam.app)
+
+You can toggle different features using the following keys:
+
+- 'h': Toggle horizon line
+- 'u': Toggle up vector field
+- 'l': Toggle latitude heatmap
+- 'c': Toggle confidence heatmap
+- 'd': Toggle undistorted image
+- 'g': Toggle grid of points
+- 'b': Toggle box object
+
+You can also change the camera model using the following keys:
+
+- '1': Pinhole
+- '2': Simple Radial
+- '3': Simple Divisional
+
+Press 'q' to quit the demo.
+"""
+
+
+# Custom VideoCapture class that always returns the most recent frame instead of the buffered FIFO order
+class VideoCapture:
+ def __init__(self, name):
+ self.cap = cv2.VideoCapture(name)
+ self.q = queue.Queue()
+ t = threading.Thread(target=self._reader)
+ t.daemon = True
+ t.start()
+
+ # read frames as soon as they are available, keeping only most recent one
+ def _reader(self):
+ while True:
+ ret, frame = self.cap.read()
+ if not ret:
+ break
+ if not self.q.empty():
+ try:
+ self.q.get_nowait() # discard previous (unprocessed) frame
+ except queue.Empty:
+ pass
+ self.q.put(frame)
+
+ def read(self):
+ return True, self.q.get()
+
+ def isOpened(self):
+ return self.cap.isOpened()
+
+
+def add_text(frame, text, align_left=True, align_top=True):
+ """Add text to a plot."""
+ h, w = frame.shape[:2]
+ sc = min(h / 640.0, 2.0)
+ Ht = int(40 * sc) # text height
+
+ for i, l in enumerate(text.split("\n")):
+ max_line = len(max([l for l in text.split("\n")], key=len))
+ x = int(8 * sc if align_left else w - (max_line) * sc * 18)
+ y = Ht * (i + 1) if align_top else h - Ht * (len(text.split("\n")) - i - 1) - int(8 * sc)
+
+ c_back, c_front = (0, 0, 0), (255, 255, 255)
+ font, style = cv2.FONT_HERSHEY_DUPLEX, cv2.LINE_AA
+ cv2.putText(frame, l, (x, y), font, 1.0 * sc, c_back, int(6 * sc), style)
+ cv2.putText(frame, l, (x, y), font, 1.0 * sc, c_front, int(1 * sc), style)
+ return frame
+
+
+def is_corner(p, h, w):
+ """Check if a point is a corner."""
+ return p in [(0, 0), (0, h - 1), (w - 1, 0), (w - 1, h - 1)]
+
+
+def plot_latitude(frame, latitude):
+ """Plot latitude heatmap."""
+ if not isinstance(latitude, np.ndarray):
+ latitude = latitude.cpu().numpy()
+
+ cmap = plt.get_cmap("seismic")
+ h, w = frame.shape[0], frame.shape[1]
+ sc = min(h / 640.0, 2.0)
+
+ vmin, vmax = -90, 90
+ latitude = (latitude - vmin) / (vmax - vmin)
+
+ colors = (cmap(latitude)[..., :3] * 255).astype(np.uint8)[..., ::-1]
+ frame = cv2.addWeighted(frame, 1 - 0.4, colors, 0.4, 0)
+
+ for contour_line in np.linspace(vmin, vmax, 15):
+ contour_line = (contour_line - vmin) / (vmax - vmin)
+
+ mask = (latitude > contour_line).astype(np.uint8)
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+
+ for contour in contours:
+ color = (np.array(cmap(contour_line))[:3] * 255).astype(np.uint8)[::-1]
+
+ # remove corners
+ contour = [p for p in contour if not is_corner(tuple(p[0]), h, w)]
+ for index, item in enumerate(contour[:-1]):
+ cv2.line(frame, item[0], contour[index + 1][0], color.tolist(), int(5 * sc))
+
+ return frame
+
+
+def draw_horizon_line(frame, heatmap):
+ """Draw a horizon line."""
+ if not isinstance(heatmap, np.ndarray):
+ heatmap = heatmap.cpu().numpy()
+
+ h, w = frame.shape[0], frame.shape[1]
+ sc = min(h / 640.0, 2.0)
+
+ color = (0, 255, 255)
+ vmin, vmax = -90, 90
+ heatmap = (heatmap - vmin) / (vmax - vmin)
+
+ contours, _ = cv2.findContours(
+ (heatmap > 0.5).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+ )
+ if contours:
+ contour = [p for p in contours[0] if not is_corner(tuple(p[0]), h, w)]
+ for index, item in enumerate(contour[:-1]):
+ cv2.line(frame, item[0], contour[index + 1][0], color, int(5 * sc))
+ return frame
+
+
+def plot_confidence(frame, confidence):
+ """Plot confidence heatmap."""
+ if not isinstance(confidence, np.ndarray):
+ confidence = confidence.cpu().numpy()
+
+ confidence = np.log10(confidence.clip(1e-6)).clip(-4)
+ confidence = (confidence - confidence.min()) / (confidence.max() - confidence.min())
+
+ cmap = plt.get_cmap("turbo")
+ colors = (cmap(confidence)[..., :3] * 255).astype(np.uint8)[..., ::-1]
+ return cv2.addWeighted(frame, 1 - 0.4, colors, 0.4, 0)
+
+
+def plot_vector_field(frame, vector_field):
+ """Plot a vector field."""
+ if not isinstance(vector_field, np.ndarray):
+ vector_field = vector_field.cpu().numpy()
+
+ H, W = frame.shape[:2]
+ sc = min(H / 640.0, 2.0)
+
+ subsample = min(W, H) // 10
+ offset_x = ((W % subsample) + subsample) // 2
+ samples_x = np.arange(offset_x, W, subsample)
+ samples_y = np.arange(int(subsample * 0.9), H, subsample)
+
+ vec_len = 40 * sc
+ x_grid, y_grid = np.meshgrid(samples_x, samples_y)
+ x, y = vector_field[:, samples_y][:, :, samples_x]
+ for xi, yi, xi_dir, yi_dir in zip(x_grid.ravel(), y_grid.ravel(), x.ravel(), y.ravel()):
+ start = (xi, yi)
+ end = (int(xi + xi_dir * vec_len), int(yi + yi_dir * vec_len))
+ cv2.arrowedLine(
+ frame, start, end, (0, 255, 0), int(5 * sc), line_type=cv2.LINE_AA, tipLength=0.3
+ )
+
+ return frame
+
+
+def plot_box(frame, gravity, camera):
+ """Plot a box object."""
+ pts = np.array(
+ [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]]
+ )
+ pts = pts - np.array([0.5, 1, 0.5])
+ rotation_vec = cv2.Rodrigues(gravity.R.numpy()[0])[0]
+ t = np.array([0, 0, 1], dtype=float)
+ K = camera.K[0].cpu().numpy().astype(float)
+ dist = np.zeros(4, dtype=float)
+ axis_points, _ = cv2.projectPoints(
+ 0.1 * pts.reshape(-1, 3).astype(float), rotation_vec, t, K, dist
+ )
+
+ h = frame.shape[0]
+ sc = min(h / 640.0, 2.0)
+
+ color = (85, 108, 228)
+ for p in axis_points:
+ center = tuple((int(p[0][0]), int(p[0][1])))
+ frame = cv2.circle(frame, center, 10, color, -1, cv2.LINE_AA)
+
+ for i in range(0, 4):
+ p1 = axis_points[i].astype(int)
+ p2 = axis_points[i + 4].astype(int)
+ frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA)
+
+ p1 = axis_points[i].astype(int)
+ p2 = axis_points[(i + 1) % 4].astype(int)
+ frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA)
+
+ p1 = axis_points[i + 4].astype(int)
+ p2 = axis_points[(i + 1) % 4 + 4].astype(int)
+ frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA)
+
+ return frame
+
+
+def plot_grid(frame, gravity, camera, grid_size=0.2, num_points=5):
+ """Plot a grid of points."""
+ h = frame.shape[0]
+ sc = min(h / 640.0, 2.0)
+
+ samples = np.linspace(-grid_size, grid_size, num_points)
+ xz = np.meshgrid(samples, samples)
+ pts = np.stack((xz[0].ravel(), np.zeros_like(xz[0].ravel()), xz[1].ravel()), axis=-1)
+
+ # project points
+ rotation_vec = cv2.Rodrigues(gravity.R.numpy()[0])[0]
+ t = np.array([0, 0, 1], dtype=float)
+ K = camera.K[0].cpu().numpy().astype(float)
+ dist = np.zeros(4, dtype=float)
+ axis_points, _ = cv2.projectPoints(pts.reshape(-1, 3).astype(float), rotation_vec, t, K, dist)
+
+ color = (192, 77, 58)
+ # draw points
+ for p in axis_points:
+ center = tuple((int(p[0][0]), int(p[0][1])))
+ frame = cv2.circle(frame, center, 10, color, -1, cv2.LINE_AA)
+
+ # draw lines
+ for i in range(num_points):
+ for j in range(num_points - 1):
+ p1 = axis_points[i * num_points + j].astype(int)
+ p2 = axis_points[i * num_points + j + 1].astype(int)
+ frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA)
+
+ p1 = axis_points[j * num_points + i].astype(int)
+ p2 = axis_points[(j + 1) * num_points + i].astype(int)
+ frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA)
+
+ return frame
+
+
+def undistort_image(img, camera, padding=0.3):
+ """Undistort an image."""
+ W, H = camera.size.unbind(-1)
+ H, W = H.int().item(), W.int().item()
+
+ pad_h, pad_w = int(H * padding), int(W * padding)
+ x, y = torch.meshgrid(torch.arange(0, W + pad_w), torch.arange(0, H + pad_h), indexing="xy")
+ coords = torch.stack((x, y), dim=-1).reshape(-1, 2) - torch.tensor([pad_w / 2, pad_h / 2])
+
+ p3d, _ = camera.pinhole().image2world(coords.to(camera.device).to(camera.dtype))
+ p2d, _ = camera.world2image(p3d)
+
+ p2d = p2d.float().numpy().reshape(H + pad_h, W + pad_w, 2)
+ img = cv2.remap(img, p2d[..., 0], p2d[..., 1], cv2.INTER_LINEAR, borderValue=(254, 254, 254))
+ return cv2.resize(img, (W, H))
+
+
+class InteractiveDemo:
+ def __init__(self, capture: VideoCapture, device: str) -> None:
+ self.cap = capture
+
+ self.device = torch.device(device)
+ self.model = GeoCalib().to(device)
+
+ self.up_toggle = False
+ self.lat_toggle = False
+ self.conf_toggle = False
+
+ self.hl_toggle = False
+ self.grid_toggle = False
+ self.box_toggle = False
+
+ self.undist_toggle = False
+
+ self.camera_model = "pinhole"
+
+ def render_frame(self, frame, calibration):
+ """Render the frame with the calibration results."""
+ camera, gravity = calibration["camera"].cpu(), calibration["gravity"].cpu()
+
+ if self.undist_toggle:
+ return undistort_image(frame, camera)
+
+ up, lat = get_perspective_field(camera, gravity)
+
+ if gravity.pitch[0] > 0:
+ frame = plot_box(frame, gravity, camera) if self.box_toggle else frame
+ frame = plot_grid(frame, gravity, camera) if self.grid_toggle else frame
+ else:
+ frame = plot_grid(frame, gravity, camera) if self.grid_toggle else frame
+ frame = plot_box(frame, gravity, camera) if self.box_toggle else frame
+
+ frame = draw_horizon_line(frame, lat[0, 0]) if self.hl_toggle else frame
+
+ if self.conf_toggle and self.up_toggle:
+ frame = plot_confidence(frame, calibration["up_confidence"][0])
+ frame = plot_vector_field(frame, up[0]) if self.up_toggle else frame
+
+ if self.conf_toggle and self.lat_toggle:
+ frame = plot_confidence(frame, calibration["latitude_confidence"][0])
+ frame = plot_latitude(frame, rad2deg(lat)[0, 0]) if self.lat_toggle else frame
+
+ return frame
+
+ def format_results(self, calibration):
+ """Format the calibration results."""
+ camera, gravity = calibration["camera"].cpu(), calibration["gravity"].cpu()
+
+ vfov, focal = camera.vfov[0].item(), camera.f[0, 0].item()
+ fov_unc = rad2deg(calibration["vfov_uncertainty"].item())
+ f_unc = calibration["focal_uncertainty"].item()
+
+ roll, pitch = gravity.rp[0].unbind(-1)
+ roll, pitch, vfov = rad2deg(roll), rad2deg(pitch), rad2deg(vfov)
+ roll_unc = rad2deg(calibration["roll_uncertainty"].item())
+ pitch_unc = rad2deg(calibration["pitch_uncertainty"].item())
+
+ text = f"{self.camera_model.replace('_', ' ').title()}\n"
+ text += f"Roll: {roll:.2f} (+- {roll_unc:.2f})\n"
+ text += f"Pitch: {pitch:.2f} (+- {pitch_unc:.2f})\n"
+ text += f"vFoV: {vfov:.2f} (+- {fov_unc:.2f})\n"
+ text += f"Focal: {focal:.2f} (+- {f_unc:.2f})"
+
+ if hasattr(camera, "k1"):
+ text += f"\nK1: {camera.k1[0].item():.2f}"
+
+ return text
+
+ def update_toggles(self):
+ """Update the toggles."""
+ key = cv2.waitKey(100) & 0xFF
+ if key == ord("h"):
+ self.hl_toggle = not self.hl_toggle
+ elif key == ord("u"):
+ self.up_toggle = not self.up_toggle
+ elif key == ord("l"):
+ self.lat_toggle = not self.lat_toggle
+ elif key == ord("c"):
+ self.conf_toggle = not self.conf_toggle
+ elif key == ord("d"):
+ self.undist_toggle = not self.undist_toggle
+ elif key == ord("g"):
+ self.grid_toggle = not self.grid_toggle
+ elif key == ord("b"):
+ self.box_toggle = not self.box_toggle
+
+ elif key == ord("1"):
+ self.camera_model = "pinhole"
+ elif key == ord("2"):
+ self.camera_model = "simple_radial"
+ elif key == ord("3"):
+ self.camera_model = "simple_divisional"
+
+ elif key == ord("q"):
+ return True
+
+ return False
+
+ def run(self):
+ """Run the interactive demo."""
+ while True:
+ start = time()
+ ret, frame = self.cap.read()
+
+ if not ret:
+ print("Error: Failed to retrieve frame.")
+ break
+
+ # create tensor from frame
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+ img = torch.tensor(img).permute(2, 0, 1) / 255.0
+
+ calibration = self.model.calibrate(img.to(self.device), camera_model=self.camera_model)
+
+ # render results to the frame
+ frame = self.render_frame(frame, calibration)
+ frame = add_text(frame, self.format_results(calibration))
+
+ end = time()
+ frame = add_text(
+ frame, f"FPS: {1 / (end - start):04.1f}", align_left=False, align_top=False
+ )
+
+ cv2.imshow("GeoCalib Demo", frame)
+
+ if self.update_toggles():
+ break
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument(
+ "--camera_id",
+ type=int,
+ default=None,
+ help="Camera ID to use. If none, will ask for ip of droidcam.",
+ )
+ args = parser.parse_args()
+
+ print(description)
+
+ device = get_device()
+ print(f"Running on: {device}")
+
+ # setup video capture
+ if args.camera_id is not None:
+ cap = VideoCapture(args.camera_id)
+ else:
+ ip = input("Enter the IP address of the camera: ")
+ cap = VideoCapture(f"http://{ip}:4747/video/force/1920x1080")
+
+ if not cap.isOpened():
+ raise ValueError("Error: Could not open camera.")
+
+ demo = InteractiveDemo(cap, device)
+ demo.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/geocalib/lm_optimizer.py b/geocalib/lm_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..45c28154aed238b35b5c48f37d04700882621450
--- /dev/null
+++ b/geocalib/lm_optimizer.py
@@ -0,0 +1,642 @@
+"""Implementation of the Levenberg-Marquardt optimizer for camera calibration."""
+
+import logging
+import time
+from types import SimpleNamespace
+from typing import Any, Callable, Dict, Tuple
+
+import torch
+import torch.nn as nn
+
+from geocalib.camera import BaseCamera, camera_models
+from geocalib.gravity import Gravity
+from geocalib.misc import J_focal2fov
+from geocalib.perspective_fields import J_perspective_field, get_perspective_field
+from geocalib.utils import focal2fov, rad2deg
+
+logger = logging.getLogger(__name__)
+
+
+def get_trivial_estimation(
+ data: Dict[str, torch.Tensor], camera_model: BaseCamera
+) -> Tuple[BaseCamera, Gravity]:
+ """Get an initial camera and gravity estimate with roll=0, pitch=0, focal=0.7 * max(h, w).
+
+ Args:
+ data (Dict[str, torch.Tensor]): Input data dictionary.
+ camera_model (BaseCamera): Camera model to use.
+
+ Returns:
+ Tuple[BaseCamera, Gravity]: Initial camera and gravity for the optimization.
+ """
+ ref = data.get("up_field", data["latitude_field"])
+ ref = ref.detach()
+
+ h, w = ref.shape[-2:]
+ batch_h, batch_w = (
+ ref.new_ones((ref.shape[0],)) * h,
+ ref.new_ones((ref.shape[0],)) * w,
+ )
+
+ init_r = ref.new_zeros((ref.shape[0],))
+ init_p = ref.new_zeros((ref.shape[0],))
+
+ focal = data.get("prior_focal", 0.7 * torch.max(batch_h, batch_w))
+ init_vfov = focal2fov(focal, h)
+
+ params = {"width": batch_w, "height": batch_h, "vfov": init_vfov}
+ params |= {"scales": data["scales"]} if "scales" in data else {}
+ params |= {"k1": data["prior_k1"]} if "prior_k1" in data else {}
+ camera = camera_model.from_dict(params)
+ camera = camera.float().to(ref.device)
+
+ gravity = Gravity.from_rp(init_r, init_p).float().to(ref.device)
+
+ if "prior_gravity" in data:
+ gravity = data["prior_gravity"].float().to(ref.device)
+ gravity = Gravity(gravity) if isinstance(gravity, torch.Tensor) else gravity
+
+ return camera, gravity
+
+
+def scaled_loss(
+ x: torch.Tensor, fn: Callable, a: float
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Apply a loss function to a tensor and pre- and post-scale it.
+
+ Args:
+ x: the data tensor, should already be squared: `x = y**2`.
+ fn: the loss function, with signature `fn(x) -> y`.
+ a: the scale parameter.
+
+ Returns:
+ The value of the loss, and its first and second derivatives.
+ """
+ a2 = a**2
+ loss, loss_d1, loss_d2 = fn(x / a2)
+ return loss * a2, loss_d1, loss_d2 / a2
+
+
+def huber_loss(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """The classical robust Huber loss, with first and second derivatives."""
+ mask = x <= 1
+ sx = torch.sqrt(x + 1e-8) # avoid nan in backward pass
+ isx = torch.max(sx.new_tensor(torch.finfo(torch.float).eps), 1 / sx)
+ loss = torch.where(mask, x, 2 * sx - 1)
+ loss_d1 = torch.where(mask, torch.ones_like(x), isx)
+ loss_d2 = torch.where(mask, torch.zeros_like(x), -isx / (2 * x))
+ return loss, loss_d1, loss_d2
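+
+
+# Illustrative values for huber_loss: for x = r^2 = 4 (outside the quadratic region) the
+# loss is 2 * sqrt(4) - 1 = 3 with first derivative 1 / sqrt(4) = 0.5; for x <= 1 the
+# loss is simply x with derivative 1.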
+
+
+def early_stop(new_cost: torch.Tensor, prev_cost: torch.Tensor, atol: float, rtol: float) -> bool:
+ """Early stopping criterion based on cost convergence."""
+ return torch.allclose(new_cost, prev_cost, atol=atol, rtol=rtol)
+
+
+def update_lambda(
+ lamb: torch.Tensor,
+ prev_cost: torch.Tensor,
+ new_cost: torch.Tensor,
+ lambda_min: float = 1e-6,
+ lambda_max: float = 1e2,
+) -> torch.Tensor:
+ """Update damping factor for Levenberg-Marquardt optimization."""
+ new_lamb = lamb * torch.where(new_cost > prev_cost, 10, 0.1)
+ return torch.clamp(new_lamb, lambda_min, lambda_max)
+
+
+def optimizer_step(
+ G: torch.Tensor, H: torch.Tensor, lambda_: torch.Tensor, eps: float = 1e-6
+) -> torch.Tensor:
+ """One optimization step with Gauss-Newton or Levenberg-Marquardt.
+
+ Args:
+ G (torch.Tensor): Batched gradient tensor of size (..., N).
+ H (torch.Tensor): Batched hessian tensor of size (..., N, N).
+ lambda_ (torch.Tensor): Damping factor for LM (use GN if lambda_=0) with shape (B,).
+ eps (float, optional): Epsilon for damping. Defaults to 1e-6.
+
+ Returns:
+ torch.Tensor: Batched update tensor of size (..., N).
+ """
+ diag = H.diagonal(dim1=-2, dim2=-1)
+ diag = diag * lambda_.unsqueeze(-1) # (B, 3)
+
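+ # Marquardt-style damping: add the lambda-scaled diagonal of H to H itself, which
+ # reduces to a plain Gauss-Newton step when lambda_ = 0.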
+ H = H + diag.clamp(min=eps).diag_embed()
+
+ H_, G_ = H.cpu(), G.cpu()
+ try:
+ U = torch.linalg.cholesky(H_)
+ except RuntimeError:
+ logger.warning("Cholesky decomposition failed. Stopping.")
+ delta = H.new_zeros((H.shape[0], H.shape[-1])) # (B, 3)
+ else:
+ delta = torch.cholesky_solve(G_[..., None], U)[..., 0]
+
+ return delta.to(H.device)
+
+
+# mypy: ignore-errors
+class LMOptimizer(nn.Module):
+ """Levenberg-Marquardt optimizer for camera calibration."""
+
+ default_conf = {
+ # Camera model parameters
+ "camera_model": "pinhole", # {"pinhole", "simple_radial", "simple_spherical"}
+ "shared_intrinsics": False, # share focal length across all images in batch
+ # LM optimizer parameters
+ "num_steps": 30,
+ "lambda_": 0.1,
+ "fix_lambda": False,
+ "early_stop": True,
+ "atol": 1e-8,
+ "rtol": 1e-8,
+ "use_spherical_manifold": True, # use spherical manifold for gravity optimization
+ "use_log_focal": True, # use log focal length for optimization
+ # Loss function parameters
+ "up_loss_fn_scale": 1e-2,
+ "lat_loss_fn_scale": 1e-2,
+ # Misc
+ "verbose": False,
+ }
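+
+ # Any of the defaults above can be overridden through the constructor, e.g. (sketch):
+ # LMOptimizer({"num_steps": 10, "camera_model": "simple_radial"})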
+
+ def __init__(self, conf: Dict[str, Any]):
+ """Initialize the LM optimizer."""
+ super().__init__()
+ self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf})
+ self.num_steps = conf.num_steps
+
+ self.set_camera_model(conf.camera_model)
+ self.setup_optimization_and_priors(shared_intrinsics=conf.shared_intrinsics)
+
+ def set_camera_model(self, camera_model: str) -> None:
+ """Set the camera model to use for the optimization.
+
+ Args:
+ camera_model (str): Camera model to use.
+ """
+ assert (
+ camera_model in camera_models.keys()
+ ), f"Unknown camera model: {camera_model} not in {camera_models.keys()}"
+ self.camera_model = camera_models[camera_model]
+ self.camera_has_distortion = hasattr(self.camera_model, "dist")
+
+ logger.debug(
+ f"Using camera model: {camera_model} (with distortion: {self.camera_has_distortion})"
+ )
+
+ def setup_optimization_and_priors(
+ self, data: Dict[str, torch.Tensor] = None, shared_intrinsics: bool = False
+ ) -> None:
+ """Setup the optimization and priors for the LM optimizer.
+
+ Args:
+ data (Dict[str, torch.Tensor], optional): Dict potentially containing priors. Defaults
+ to None.
+ shared_intrinsics (bool, optional): Whether to share the intrinsics across the batch.
+ Defaults to False.
+ """
+ if data is None:
+ data = {}
+ self.shared_intrinsics = shared_intrinsics
+
+ if shared_intrinsics: # si => must use pinhole
+ assert (
+ self.camera_model == camera_models["pinhole"]
+ ), f"Shared intrinsics only supported with pinhole camera model: {self.camera_model}"
+
+ self.estimate_gravity = True
+ if "prior_gravity" in data:
+ self.estimate_gravity = False
+ logger.debug("Using provided gravity as prior.")
+
+ self.estimate_focal = True
+ if "prior_focal" in data:
+ self.estimate_focal = False
+ logger.debug("Using provided focal as prior.")
+
+ self.estimate_k1 = True
+ if "prior_k1" in data:
+ self.estimate_k1 = False
+ logger.debug("Using provided k1 as prior.")
+
+ self.gravity_delta_dims = (0, 1) if self.estimate_gravity else (-1,)
+ self.focal_delta_dims = (
+ (max(self.gravity_delta_dims) + 1,) if self.estimate_focal else (-1,)
+ )
+ self.k1_delta_dims = (max(self.focal_delta_dims) + 1,) if self.estimate_k1 else (-1,)
+
+ logger.debug(f"Camera Model: {self.camera_model}")
+ logger.debug(f"Optimizing gravity: {self.estimate_gravity} ({self.gravity_delta_dims})")
+ logger.debug(f"Optimizing focal: {self.estimate_focal} ({self.focal_delta_dims})")
+ logger.debug(f"Optimizing k1: {self.estimate_k1} ({self.k1_delta_dims})")
+
+ logger.debug(f"Shared intrinsics: {self.shared_intrinsics}")
+
+ def calculate_residuals(
+ self, camera: BaseCamera, gravity: Gravity, data: Dict[str, torch.Tensor]
+ ) -> Dict[str, torch.Tensor]:
+ """Calculate the residuals for the optimization.
+
+ Args:
+ camera (BaseCamera): Optimized camera.
+ gravity (Gravity): Optimized gravity.
+ data (Dict[str, torch.Tensor]): Input data containing the up and latitude fields.
+
+ Returns:
+ Dict[str, torch.Tensor]: Residuals for the optimization.
+ """
+ perspective_up, perspective_lat = get_perspective_field(camera, gravity)
+ perspective_lat = torch.sin(perspective_lat)
+
+ residuals = {}
+ if "up_field" in data:
+ up_residual = (data["up_field"] - perspective_up).permute(0, 2, 3, 1)
+ residuals["up_residual"] = up_residual.reshape(up_residual.shape[0], -1, 2)
+
+ if "latitude_field" in data:
+ target_lat = torch.sin(data["latitude_field"])
+ lat_residual = (target_lat - perspective_lat).permute(0, 2, 3, 1)
+ residuals["latitude_residual"] = lat_residual.reshape(lat_residual.shape[0], -1, 1)
+
+ return residuals
+
+ def calculate_costs(
+ self, residuals: torch.Tensor, data: Dict[str, torch.Tensor]
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
+ """Calculate the costs and weights for the optimization.
+
+ Args:
+ residuals (torch.Tensor): Residuals for the optimization.
+ data (Dict[str, torch.Tensor]): Input data containing the up and latitude confidence.
+
+ Returns:
+ Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: Costs and weights for the
+ optimization.
+ """
+ costs, weights = {}, {}
+
+ if "up_residual" in residuals:
+ up_cost = (residuals["up_residual"] ** 2).sum(dim=-1)
+ up_cost, up_weight, _ = scaled_loss(up_cost, huber_loss, self.conf.up_loss_fn_scale)
+
+ if "up_confidence" in data:
+ up_conf = data["up_confidence"].reshape(up_weight.shape[0], -1)
+ up_weight = up_weight * up_conf
+ up_cost = up_cost * up_conf
+
+ costs["up_cost"] = up_cost
+ weights["up_weights"] = up_weight
+
+ if "latitude_residual" in residuals:
+ lat_cost = (residuals["latitude_residual"] ** 2).sum(dim=-1)
+ lat_cost, lat_weight, _ = scaled_loss(lat_cost, huber_loss, self.conf.lat_loss_fn_scale)
+
+ if "latitude_confidence" in data:
+ lat_conf = data["latitude_confidence"].reshape(lat_weight.shape[0], -1)
+ lat_weight = lat_weight * lat_conf
+ lat_cost = lat_cost * lat_conf
+
+ costs["latitude_cost"] = lat_cost
+ weights["latitude_weights"] = lat_weight
+
+ return costs, weights
+
+ def calculate_gradient_and_hessian(
+ self,
+ J: torch.Tensor,
+ residuals: torch.Tensor,
+ weights: torch.Tensor,
+ shared_intrinsics: bool,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Calculate the gradient and Hessian for given the Jacobian, residuals, and weights.
+
+ Args:
+ J (torch.Tensor): Jacobian.
+ residuals (torch.Tensor): Residuals.
+ weights (torch.Tensor): Weights.
+ shared_intrinsics (bool): Whether to share the intrinsics across the batch.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Gradient and Hessian.
+ """
+ dims = ()
+ if self.estimate_gravity:
+ dims = (0, 1)
+ if self.estimate_focal:
+ dims += (2,)
+ if self.camera_has_distortion and self.estimate_k1:
+ dims += (3,)
+ assert dims, "No parameters to optimize"
+
+ J = J[..., dims]
+
+ Grad = torch.einsum("...Njk,...Nj->...Nk", J, residuals)
+ Grad = weights[..., None] * Grad
+ Grad = Grad.sum(-2) # (B, N_params)
+
+ if shared_intrinsics:
+ # reshape to (1, B * (N_params-1) + 1)
+ Grad_g = Grad[..., :2].reshape(1, -1)
+ Grad_f = Grad[..., 2].reshape(1, -1).sum(-1, keepdim=True)
+ Grad = torch.cat([Grad_g, Grad_f], dim=-1)
+
+ Hess = torch.einsum("...Njk,...Njl->...Nkl", J, J)
+ Hess = weights[..., None, None] * Hess
+ Hess = Hess.sum(-3)
+
+ if shared_intrinsics:
+ H_g = torch.block_diag(*list(Hess[..., :2, :2]))
+ J_fg = Hess[..., :2, 2].flatten()
+ J_gf = Hess[..., 2, :2].flatten()
+ J_f = Hess[..., 2, 2].sum()
+ dims = H_g.shape[-1] + 1
+ Hess = Hess.new_zeros((dims, dims), dtype=torch.float32)
+ Hess[:-1, :-1] = H_g
+ Hess[-1, :-1] = J_gf
+ Hess[:-1, -1] = J_fg
+ Hess[-1, -1] = J_f
+ Hess = Hess.unsqueeze(0)
+
+ return Grad, Hess
+
+ def setup_system(
+ self,
+ camera: BaseCamera,
+ gravity: Gravity,
+ residuals: Dict[str, torch.Tensor],
+ weights: Dict[str, torch.Tensor],
+ as_rpf: bool = False,
+ shared_intrinsics: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Calculate the gradient and Hessian for the optimization.
+
+ Args:
+ camera (BaseCamera): Optimized camera.
+ gravity (Gravity): Optimized gravity.
+ residuals (Dict[str, torch.Tensor]): Residuals for the optimization.
+ weights (Dict[str, torch.Tensor]): Weights for the optimization.
+ as_rpf (bool, optional): Whether to calculate the gradient and Hessian with respect to
+ roll, pitch, and focal length. Defaults to False.
+ shared_intrinsics (bool, optional): Whether to share the intrinsics across the batch.
+ Defaults to False.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Gradient and Hessian for the optimization.
+ """
+ J_up, J_lat = J_perspective_field(
+ camera,
+ gravity,
+ spherical=self.conf.use_spherical_manifold and not as_rpf,
+ log_focal=self.conf.use_log_focal and not as_rpf,
+ )
+
+ J_up = J_up.reshape(J_up.shape[0], -1, J_up.shape[-2], J_up.shape[-1]) # (B, N, 2, 3)
+ J_lat = J_lat.reshape(J_lat.shape[0], -1, J_lat.shape[-2], J_lat.shape[-1]) # (B, N, 1, 3)
+
+ n_params = (
+ 2 * self.estimate_gravity
+ + self.estimate_focal
+ + (self.camera_has_distortion and self.estimate_k1)
+ )
+ Grad = J_up.new_zeros(J_up.shape[0], n_params)
+ Hess = J_up.new_zeros(J_up.shape[0], n_params, n_params)
+
+ if shared_intrinsics:
+ N_params = Grad.shape[0] * (n_params - 1) + 1
+ Grad = Grad.new_zeros(1, N_params)
+ Hess = Hess.new_zeros(1, N_params, N_params)
+
+ if "up_residual" in residuals:
+ Up_Grad, Up_Hess = self.calculate_gradient_and_hessian(
+ J_up, residuals["up_residual"], weights["up_weights"], shared_intrinsics
+ )
+
+ if self.conf.verbose:
+ logger.info(f"Up J:\n{Up_Grad.mean(0)}")
+
+ Grad = Grad + Up_Grad
+ Hess = Hess + Up_Hess
+
+ if "latitude_residual" in residuals:
+ Lat_Grad, Lat_Hess = self.calculate_gradient_and_hessian(
+ J_lat,
+ residuals["latitude_residual"],
+ weights["latitude_weights"],
+ shared_intrinsics,
+ )
+
+ if self.conf.verbose:
+ logger.info(f"Lat J:\n{Lat_Grad.mean(0)}")
+
+ Grad = Grad + Lat_Grad
+ Hess = Hess + Lat_Hess
+
+ return Grad, Hess
+
+ def estimate_uncertainty(
+ self,
+ camera_opt: BaseCamera,
+ gravity_opt: Gravity,
+ errors: Dict[str, torch.Tensor],
+ weights: Dict[str, torch.Tensor],
+ ) -> Dict[str, torch.Tensor]:
+ """Estimate the uncertainty of the optimized camera and gravity at the final step.
+
+ Args:
+ camera_opt (BaseCamera): Final optimized camera.
+ gravity_opt (Gravity): Final optimized gravity.
+ errors (Dict[str, torch.Tensor]): Costs for the optimization.
+ weights (Dict[str, torch.Tensor]): Weights for the optimization.
+
+ Returns:
+ Dict[str, torch.Tensor]: Uncertainty estimates for the optimized camera and gravity.
+ """
+ _, Hess = self.setup_system(
+ camera_opt, gravity_opt, errors, weights, as_rpf=True, shared_intrinsics=False
+ )
+ Cov = torch.inverse(Hess)
+
+ roll_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
+ pitch_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
+ gravity_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
+ if self.estimate_gravity:
+ roll_uncertainty = Cov[..., 0, 0]
+ pitch_uncertainty = Cov[..., 1, 1]
+
+ try:
+ delta_uncertainty = Cov[..., :2, :2]
+ eigenvalues = torch.linalg.eigvalsh(delta_uncertainty.cpu())
+ gravity_uncertainty = torch.max(eigenvalues, dim=-1).values.to(Cov.device)
+ except RuntimeError:
+ logger.warning("Could not calculate gravity uncertainty")
+ gravity_uncertainty = Cov.new_zeros(Cov.shape[0])
+
+ focal_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
+ fov_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
+ if self.estimate_focal:
+ focal_uncertainty = Cov[..., self.focal_delta_dims[0], self.focal_delta_dims[0]]
+ fov_uncertainty = (
+ J_focal2fov(camera_opt.f[..., 1], camera_opt.size[..., 1]) ** 2 * focal_uncertainty
+ )
+
+ return {
+ "covariance": Cov,
+ "roll_uncertainty": torch.sqrt(roll_uncertainty),
+ "pitch_uncertainty": torch.sqrt(pitch_uncertainty),
+ "gravity_uncertainty": torch.sqrt(gravity_uncertainty),
+ "focal_uncertainty": torch.sqrt(focal_uncertainty) / 2,
+ "vfov_uncertainty": torch.sqrt(fov_uncertainty / 2),
+ }
+
+ def update_estimate(
+ self, camera: BaseCamera, gravity: Gravity, delta: torch.Tensor
+ ) -> Tuple[BaseCamera, Gravity]:
+ """Update the camera and gravity estimates with the given delta.
+
+ Args:
+ camera (BaseCamera): Optimized camera.
+ gravity (Gravity): Optimized gravity.
+ delta (torch.Tensor): Delta to update the camera and gravity estimates.
+
+ Returns:
+ Tuple[BaseCamera, Gravity]: Updated camera and gravity estimates.
+ """
+ delta_gravity = (
+ delta[..., self.gravity_delta_dims]
+ if self.estimate_gravity
+ else delta.new_zeros(delta.shape[:-1] + (2,))
+ )
+ new_gravity = gravity.update(delta_gravity, spherical=self.conf.use_spherical_manifold)
+
+ delta_f = (
+ delta[..., self.focal_delta_dims]
+ if self.estimate_focal
+ else delta.new_zeros(delta.shape[:-1] + (1,))
+ )
+ new_camera = camera.update_focal(delta_f, as_log=self.conf.use_log_focal)
+
+ delta_dist = (
+ delta[..., self.k1_delta_dims]
+ if self.camera_has_distortion and self.estimate_k1
+ else delta.new_zeros(delta.shape[:-1] + (1,))
+ )
+ if self.camera_has_distortion:
+ new_camera = new_camera.update_dist(delta_dist)
+
+ return new_camera, new_gravity
+
+ def optimize(
+ self,
+ data: Dict[str, torch.Tensor],
+ camera_opt: BaseCamera,
+ gravity_opt: Gravity,
+ ) -> Tuple[BaseCamera, Gravity, Dict[str, torch.Tensor]]:
+ """Optimize the camera and gravity estimates.
+
+ Args:
+ data (Dict[str, torch.Tensor]): Input data.
+ camera_opt (BaseCamera): Optimized camera.
+ gravity_opt (Gravity): Optimized gravity.
+
+ Returns:
+ Tuple[BaseCamera, Gravity, Dict[str, torch.Tensor]]: Optimized camera, gravity
+ estimates and optimization information.
+ """
+ key = list(data.keys())[0]
+ B = data[key].shape[0]
+
+ lamb = data[key].new_ones(B) * self.conf.lambda_
+ if self.shared_intrinsics:
+ lamb = data[key].new_ones(1) * self.conf.lambda_
+
+ infos = {"stop_at": self.num_steps}
+ for i in range(self.num_steps):
+ if self.conf.verbose:
+ logger.info(f"Step {i+1}/{self.num_steps}")
+
+ errors = self.calculate_residuals(camera_opt, gravity_opt, data)
+ costs, weights = self.calculate_costs(errors, data)
+
+ if i == 0:
+ prev_cost = sum(c.mean(-1) for c in costs.values())
+ for k, c in costs.items():
+ infos[f"initial_{k}"] = c.mean(-1)
+
+ infos["initial_cost"] = prev_cost
+
+ Grad, Hess = self.setup_system(
+ camera_opt,
+ gravity_opt,
+ errors,
+ weights,
+ shared_intrinsics=self.shared_intrinsics,
+ )
+ delta = optimizer_step(Grad, Hess, lamb) # (B, N_params)
+
+ if self.shared_intrinsics:
+ delta_g = delta[..., :-1].reshape(B, 2)
+ delta_f = delta[..., -1].expand(B, 1)
+ delta = torch.cat([delta_g, delta_f], dim=-1)
+
+ # calculate new cost
+ camera_opt, gravity_opt = self.update_estimate(camera_opt, gravity_opt, delta)
+ new_cost, _ = self.calculate_costs(
+ self.calculate_residuals(camera_opt, gravity_opt, data), data
+ )
+ new_cost = sum(c.mean(-1) for c in new_cost.values())
+
+ if not self.conf.fix_lambda and not self.shared_intrinsics:
+ lamb = update_lambda(lamb, prev_cost, new_cost)
+
+ if self.conf.verbose:
+ logger.info(f"Cost:\nPrev: {prev_cost}\nNew: {new_cost}")
+ logger.info(f"Camera:\n{camera_opt._data}")
+
+ if early_stop(new_cost, prev_cost, atol=self.conf.atol, rtol=self.conf.rtol):
+ infos["stop_at"] = min(i + 1, infos["stop_at"])
+
+ if self.conf.early_stop:
+ if self.conf.verbose:
+ logger.info(f"Early stopping at step {i+1}")
+ break
+
+ prev_cost = new_cost
+
+ if i == self.num_steps - 1 and self.conf.early_stop:
+ logger.warning("Reached maximum number of steps without convergence.")
+
+ final_errors = self.calculate_residuals(camera_opt, gravity_opt, data) # (B, N, 3)
+ final_cost, weights = self.calculate_costs(final_errors, data) # (B, N)
+
+ if not self.training:
+ infos |= self.estimate_uncertainty(camera_opt, gravity_opt, final_errors, weights)
+
+ infos["stop_at"] = camera_opt.new_ones(camera_opt.shape[0]) * infos["stop_at"]
+ for k, c in final_cost.items():
+ infos[f"final_{k}"] = c.mean(-1)
+
+ infos["final_cost"] = sum(c.mean(-1) for c in final_cost.values())
+
+ return camera_opt, gravity_opt, infos
+
+ def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """Run the LM optimization."""
+ camera_init, gravity_init = get_trivial_estimation(data, self.camera_model)
+
+ self.setup_optimization_and_priors(data, shared_intrinsics=self.shared_intrinsics)
+
+ start = time.time()
+ camera_opt, gravity_opt, infos = self.optimize(data, camera_init, gravity_init)
+
+ if self.conf.verbose:
+ logger.info(f"Optimization took {(time.time() - start)*1000:.2f} ms")
+
+ logger.info(f"Initial camera:\n{rad2deg(camera_init.vfov)}")
+ logger.info(f"Optimized camera:\n{rad2deg(camera_opt.vfov)}")
+
+ logger.info(f"Initial gravity:\n{rad2deg(gravity_init.rp)}")
+ logger.info(f"Optimized gravity:\n{rad2deg(gravity_opt.rp)}")
+
+ return {"camera": camera_opt, "gravity": gravity_opt, **infos}
diff --git a/geocalib/misc.py b/geocalib/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6c481fee5276a69470559d427497d53b14acf7d
--- /dev/null
+++ b/geocalib/misc.py
@@ -0,0 +1,318 @@
+"""Miscellaneous functions and classes for the geocalib_inference package."""
+
+import functools
+import inspect
+import logging
+from typing import Callable, List
+
+import numpy as np
+import torch
+
+logger = logging.getLogger(__name__)
+
+# mypy: ignore-errors
+
+
+def autocast(func: Callable) -> Callable:
+ """Cast the inputs of a TensorWrapper method to PyTorch tensors if they are numpy arrays.
+
+ Use the device and dtype of the wrapper.
+
+ Args:
+ func (Callable): Method of a TensorWrapper class.
+
+ Returns:
+ Callable: Wrapped method.
+ """
+
+ @functools.wraps(func)
+ def wrap(self, *args):
+ device = torch.device("cpu")
+ dtype = None
+ if isinstance(self, TensorWrapper):
+ if self._data is not None:
+ device = self.device
+ dtype = self.dtype
+ elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
+ raise ValueError(self)
+
+ cast_args = []
+ for arg in args:
+ if isinstance(arg, np.ndarray):
+ arg = torch.from_numpy(arg)
+ arg = arg.to(device=device, dtype=dtype)
+ cast_args.append(arg)
+ return func(self, *cast_args)
+
+ return wrap
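+
+
+# Example of the effect of @autocast (illustrative): a decorated method or constructor can
+# be called with numpy arrays, e.g. Gravity(np.array([0.0, -1.0, 0.0])), and the array is
+# converted to a tensor on the wrapper's device and dtype before the wrapped call runs.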
+
+
+class TensorWrapper:
+ """Wrapper for PyTorch tensors."""
+
+ _data = None
+
+ @autocast
+ def __init__(self, data: torch.Tensor):
+ """Wrapper for PyTorch tensors."""
+ self._data = data
+
+ @property
+ def shape(self) -> torch.Size:
+ """Shape of the underlying tensor."""
+ return self._data.shape[:-1]
+
+ @property
+ def device(self) -> torch.device:
+ """Get the device of the underlying tensor."""
+ return self._data.device
+
+ @property
+ def dtype(self) -> torch.dtype:
+ """Get the dtype of the underlying tensor."""
+ return self._data.dtype
+
+ def __getitem__(self, index) -> "TensorWrapper":
+ """Index the underlying tensor and wrap the result."""
+ return self.__class__(self._data[index])
+
+ def __setitem__(self, index, item):
+ """Set the underlying tensor."""
+ self._data[index] = item.data
+
+ def to(self, *args, **kwargs):
+ """Move the underlying tensor to a new device."""
+ return self.__class__(self._data.to(*args, **kwargs))
+
+ def cpu(self):
+ """Move the underlying tensor to the CPU."""
+ return self.__class__(self._data.cpu())
+
+ def cuda(self):
+ """Move the underlying tensor to the GPU."""
+ return self.__class__(self._data.cuda())
+
+ def pin_memory(self):
+ """Pin the underlying tensor to memory."""
+ return self.__class__(self._data.pin_memory())
+
+ def float(self):
+ """Cast the underlying tensor to float."""
+ return self.__class__(self._data.float())
+
+ def double(self):
+ """Cast the underlying tensor to double."""
+ return self.__class__(self._data.double())
+
+ def detach(self):
+ """Detach the underlying tensor."""
+ return self.__class__(self._data.detach())
+
+ def numpy(self):
+ """Convert the underlying tensor to a numpy array."""
+ return self._data.detach().cpu().numpy()
+
+ def new_tensor(self, *args, **kwargs):
+ """Create a new tensor of the same type and device."""
+ return self._data.new_tensor(*args, **kwargs)
+
+ def new_zeros(self, *args, **kwargs):
+ """Create a new tensor of the same type and device."""
+ return self._data.new_zeros(*args, **kwargs)
+
+ def new_ones(self, *args, **kwargs):
+ """Create a new tensor of the same type and device."""
+ return self._data.new_ones(*args, **kwargs)
+
+ def new_full(self, *args, **kwargs):
+ """Create a new tensor of the same type and device."""
+ return self._data.new_full(*args, **kwargs)
+
+ def new_empty(self, *args, **kwargs):
+ """Create a new tensor of the same type and device."""
+ return self._data.new_empty(*args, **kwargs)
+
+ def unsqueeze(self, *args, **kwargs):
+ """Create a new tensor of the same type and device."""
+ return self.__class__(self._data.unsqueeze(*args, **kwargs))
+
+ def squeeze(self, *args, **kwargs):
+ """Create a new tensor of the same type and device."""
+ return self.__class__(self._data.squeeze(*args, **kwargs))
+
+ @classmethod
+ def stack(cls, objects: List, dim=0, *, out=None):
+ """Stack a list of objects with the same type and shape."""
+ data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
+ return cls(data)
+
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ """Support torch functions."""
+ if kwargs is None:
+ kwargs = {}
+ return cls.stack(*args, **kwargs) if func is torch.stack else NotImplemented
+
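+# Usage sketch (assuming subclasses store their parameters in the last tensor dimension):
+#   w = TensorWrapper(torch.zeros(2, 6))
+#   w.shape              # torch.Size([2]) -- batch shape only
+#   w[0], w.float()      # indexing and casting return wrapped objects of the same class
+#   torch.stack([w, w])  # dispatched to TensorWrapper.stack via __torch_function__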
+
+class EuclideanManifold:
+ """Simple euclidean manifold."""
+
+ @staticmethod
+ def J_plus(x: torch.Tensor) -> torch.Tensor:
+ """Plus operator Jacobian."""
+ return torch.eye(x.shape[-1]).to(x)
+
+ @staticmethod
+ def plus(x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
+ """Plus operator."""
+ return x + delta
+
+
+class SphericalManifold:
+ """Implementation of the spherical manifold.
+
+ Following the derivation from 'Integrating Generic Sensor Fusion Algorithms with Sound State
+ Representations through Encapsulation of Manifolds' by Hertzberg et al. (B.2, p. 25).
+
+ Householder transformation following Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by
+ Golub et al.
+ """
+
+ @staticmethod
+ def householder_vector(x: torch.Tensor) -> torch.Tensor:
+ """Return the Householder vector and beta.
+
+ Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by Golub et al. (Johns Hopkins Studies
+ in Mathematical Sciences) but using the nth element of the input vector as pivot instead of
+ first.
+
+ This computes the vector v with v(n) = 1 and beta such that H = I - beta * v * v^T is
+ orthogonal and H * x = ||x||_2 * e_n.
+
+ Args:
+ x (torch.Tensor): [..., n] tensor.
+
+ Returns:
+ torch.Tensor: v of shape [..., n]
+ torch.Tensor: beta of shape [...]
+ """
+ sigma = torch.sum(x[..., :-1] ** 2, -1)
+ xpiv = x[..., -1]
+ norm = torch.norm(x, dim=-1)
+ if torch.any(sigma < 1e-7):
+ sigma = torch.where(sigma < 1e-7, sigma + 1e-7, sigma)
+ logger.warning("sigma < 1e-7")
+
+ vpiv = torch.where(xpiv < 0, xpiv - norm, -sigma / (xpiv + norm))
+ beta = 2 * vpiv**2 / (sigma + vpiv**2)
+ v = torch.cat([x[..., :-1] / vpiv[..., None], torch.ones_like(vpiv)[..., None]], -1)
+ return v, beta
+
+ @staticmethod
+ def apply_householder(y: torch.Tensor, v: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
+ """Apply Householder transformation.
+
+ Args:
+ y (torch.Tensor): Vector to transform of shape [..., n].
+ v (torch.Tensor): Householder vector of shape [..., n].
+ beta (torch.Tensor): Householder beta of shape [...].
+
+ Returns:
+ torch.Tensor: Transformed vector of shape [..., n].
+ """
+ return y - v * (beta * torch.einsum("...i,...i->...", v, y))[..., None]
+
+ @classmethod
+ def J_plus(cls, x: torch.Tensor) -> torch.Tensor:
+ """Plus operator Jacobian."""
+ v, beta = cls.householder_vector(x)
+ H = -torch.einsum("..., ...k, ...l->...kl", beta, v, v)
+ H = H + torch.eye(H.shape[-1]).to(H)
+ return H[..., :-1] # J
+
+ @classmethod
+ def plus(cls, x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
+ """Plus operator.
+
+ Equation 109 (p. 25) from 'Integrating Generic Sensor Fusion Algorithms with Sound State
+ Representations through Encapsulation of Manifolds' by Hertzberg et al. but using the nth
+ element of the input vector as pivot instead of first.
+
+ Args:
+ x: point on the manifold
+ delta: tangent vector
+ """
+ eps = 1e-7
+ # keep the input norm so that x does not need to be unit-length
+ nx = torch.norm(x, dim=-1, keepdim=True)
+ nd = torch.norm(delta, dim=-1, keepdim=True)
+
+ # make sure we don't divide by zero in backward as torch.where computes grad for both
+ # branches
+ nd_ = torch.where(nd < eps, nd + eps, nd)
+ sinc = torch.where(nd < eps, nd.new_ones(nd.shape), torch.sin(nd_) / nd_)
+
+ # cos is applied to last dim instead of first
+ exp_delta = torch.cat([sinc * delta, torch.cos(nd)], -1)
+
+ v, beta = cls.householder_vector(x)
+ return nx * cls.apply_householder(exp_delta, v, beta)
+
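+# Minimal numerical sketch: `plus` maps a 2D tangent step back onto the unit sphere, e.g.
+# SphericalManifold.plus(torch.tensor([0.6, 0.0, 0.8]), torch.tensor([0.1, 0.0]))
+# has unit norm (up to numerical precision).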
+
+@torch.jit.script
+def J_vecnorm(vec: torch.Tensor) -> torch.Tensor:
+ """Compute the jacobian of vec / norm2(vec).
+
+ Args:
+ vec (torch.Tensor): [..., D] tensor.
+
+ Returns:
+ torch.Tensor: [..., D, D] Jacobian.
+ """
+ D = vec.shape[-1]
+ norm_x = torch.norm(vec, dim=-1, keepdim=True).unsqueeze(-1) # (..., 1, 1)
+
+ if (norm_x == 0).any():
+ norm_x = norm_x + 1e-6
+
+ xxT = torch.einsum("...i,...j->...ij", vec, vec) # (..., D, D)
+ identity = torch.eye(D, device=vec.device, dtype=vec.dtype) # (D, D)
+
+ return identity / norm_x - (xxT / norm_x**3) # (..., D, D)
+
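+# Note: for a unit-norm input this reduces to the tangent-plane projector I - vec @ vec^T.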
+
+@torch.jit.script
+def J_focal2fov(focal: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
+ """Compute the jacobian of the focal2fov function."""
+ return -4 * h / (4 * focal**2 + h**2)
+
+
+@torch.jit.script
+def J_up_projection(uv: torch.Tensor, abc: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
+ """Compute the jacobian of the up-vector projection.
+
+ Args:
+ uv (torch.Tensor): Normalized image coordinates of shape (..., 2).
+ abc (torch.Tensor): Gravity vector of shape (..., 3).
+ wrt (str, optional): Parameter to differentiate with respect to. Defaults to "uv".
+
+ Raises:
+ ValueError: If the wrt parameter is unknown.
+
+ Returns:
+ torch.Tensor: Jacobian with respect to the parameter.
+ """
+ if wrt == "uv":
+ c = abc[..., 2][..., None, None, None]
+ return -c * torch.eye(2, device=uv.device, dtype=uv.dtype).expand(uv.shape[:-1] + (2, 2))
+
+ elif wrt == "abc":
+ J = uv.new_zeros(uv.shape[:-1] + (2, 3))
+ J[..., 0, 0] = 1
+ J[..., 1, 1] = 1
+ J[..., 0, 2] = -uv[..., 0]
+ J[..., 1, 2] = -uv[..., 1]
+ return J
+
+ else:
+ raise ValueError(f"Unknown wrt: {wrt}")
diff --git a/geocalib/modules.py b/geocalib/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..f635e84bc9e577711bd3e508800011a3821ce026
--- /dev/null
+++ b/geocalib/modules.py
@@ -0,0 +1,575 @@
+"""Implementation of MSCAN from SegNeXt: Rethinking Convolutional Attention Design for Semantic
+Segmentation (NeurIPS 2022) adapted from
+
+https://github.com/Visual-Attention-Network/SegNeXt/blob/main/mmseg/models/backbones/mscan.py
+
+
+Light Hamburger Decoder adapted from:
+
+https://github.com/Visual-Attention-Network/SegNeXt/blob/main/mmseg/models/decode_heads/ham_head.py
+"""
+
+from typing import Dict, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.modules.utils import _pair as to_2tuple
+
+# flake8: noqa: E266
+# mypy: ignore-errors
+
+
+class ConvModule(nn.Module):
+ """Replacement for mmcv.cnn.ConvModule to avoid mmcv dependency."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ padding: int = 0,
+ use_norm: bool = False,
+ bias: bool = True,
+ ):
+ """Simple convolution block.
+
+ Args:
+ in_channels (int): Input channels.
+ out_channels (int): Output channels.
+ kernel_size (int): Kernel size.
+ padding (int, optional): Padding. Defaults to 0.
+ use_norm (bool, optional): Whether to use normalization. Defaults to False.
+ bias (bool, optional): Whether to use bias. Defaults to True.
+ """
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
+ self.bn = nn.BatchNorm2d(out_channels) if use_norm else nn.Identity()
+ self.activate = nn.ReLU(inplace=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass."""
+ x = self.conv(x)
+ x = self.bn(x)
+ return self.activate(x)
+
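+# Usage sketch: ConvModule(3, 16, 3, padding=1, use_norm=True) maps (B, 3, H, W) features to
+# (B, 16, H, W) via Conv2d -> BatchNorm2d -> ReLU.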
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module."""
+
+ def __init__(self, features):
+ """Simple residual convolution block.
+
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
+
+ self.relu = torch.nn.ReLU(inplace=True)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass."""
+ out = self.relu(x)
+ out = self.conv1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(self, features: int, unit2only=False, upsample=True):
+ """Feature fusion block.
+
+ Args:
+ features (int): Number of features.
+ unit2only (bool, optional): Whether to use only the second unit. Defaults to False.
+ upsample (bool, optional): Whether to upsample. Defaults to True.
+ """
+ super().__init__()
+ self.upsample = upsample
+
+ if not unit2only:
+ self.resConfUnit1 = ResidualConvUnit(features)
+ self.resConfUnit2 = ResidualConvUnit(features)
+
+ def forward(self, *xs: torch.Tensor) -> torch.Tensor:
+ """Forward pass."""
+ output = xs[0]
+
+ if len(xs) == 2:
+ output = output + self.resConfUnit1(xs[1])
+
+ output = self.resConfUnit2(output)
+
+ if self.upsample:
+ output = F.interpolate(output, scale_factor=2, mode="bilinear", align_corners=False)
+
+ return output
+
+
+###################################################
+########### Light Hamburger Decoder ###############
+###################################################
+
+
+class NMF2D(nn.Module):
+ """Non-negative Matrix Factorization (NMF) for 2D data."""
+
+ def __init__(self):
+ """Non-negative Matrix Factorization (NMF) for 2D data."""
+ super().__init__()
+ self.S, self.D, self.R = 1, 512, 64
+ self.train_steps = 6
+ self.eval_steps = 7
+ self.inv_t = 1
+
+ def _build_bases(self, B: int, S: int, D: int, R: int, device: str = "cpu") -> torch.Tensor:
+ """Build the initial bases as random, L2-normalized columns."""
+ bases = torch.rand((B * S, D, R)).to(device)
+ return F.normalize(bases, dim=1)
+
+ def local_step(
+ self, x: torch.Tensor, bases: torch.Tensor, coef: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Update bases and coefficient."""
+ # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
+ numerator = torch.bmm(x.transpose(1, 2), bases)
+ # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
+ denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
+ # Multiplicative Update
+ coef = coef * numerator / (denominator + 1e-6)
+ # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
+ numerator = torch.bmm(x, coef)
+ # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
+ denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
+ # Multiplicative Update
+ bases = bases * numerator / (denominator + 1e-6)
+ return bases, coef
+
+ def compute_coef(
+ self, x: torch.Tensor, bases: torch.Tensor, coef: torch.Tensor
+ ) -> torch.Tensor:
+ """Compute coefficient."""
+ # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
+ numerator = torch.bmm(x.transpose(1, 2), bases)
+ # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
+ denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
+ # multiplication update
+ return coef * numerator / (denominator + 1e-6)
+
+ def local_inference(
+ self, x: torch.Tensor, bases: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Local inference."""
+ # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
+ coef = torch.bmm(x.transpose(1, 2), bases)
+ coef = F.softmax(self.inv_t * coef, dim=-1)
+
+ steps = self.train_steps if self.training else self.eval_steps
+ for _ in range(steps):
+ bases, coef = self.local_step(x, bases, coef)
+
+ return bases, coef
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass."""
+ B, C, H, W = x.shape
+
+ # (B, C, H, W) -> (B * S, D, N)
+ D = C // self.S
+ N = H * W
+ x = x.view(B * self.S, D, N)
+
+ # (S, D, R) -> (B * S, D, R)
+ bases = self._build_bases(B, self.S, D, self.R, device=x.device)
+ bases, coef = self.local_inference(x, bases)
+ # (B * S, N, R)
+ coef = self.compute_coef(x, bases, coef)
+ # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
+ x = torch.bmm(bases, coef.transpose(1, 2))
+ # (B * S, D, N) -> (B, C, H, W)
+ x = x.view(B, C, H, W)
+ # (B * S, D, R) -> (B, S, D, R)
+ bases = bases.view(B, self.S, D, self.R)
+
+ return x
+
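+# Shape sketch: NMF2D is shape-preserving and reconstructs the input from a rank-R (R=64)
+# factorization, e.g. NMF2D()(torch.rand(2, 512, 32, 32)) has shape (2, 512, 32, 32).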
+
+class Hamburger(nn.Module):
+ """Hamburger Module."""
+
+ def __init__(self, ham_channels: int = 512):
+ """Hambuger Module.
+
+ Args:
+ ham_channels (int, optional): Number of channels in the hamburger module. Defaults to
+ 512.
+ """
+ super().__init__()
+ self.ham_in = ConvModule(ham_channels, ham_channels, 1)
+ self.ham = NMF2D()
+ self.ham_out = ConvModule(ham_channels, ham_channels, 1)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass."""
+ enjoy = self.ham_in(x)
+ enjoy = F.relu(enjoy, inplace=False)
+ enjoy = self.ham(enjoy)
+ enjoy = self.ham_out(enjoy)
+ ham = F.relu(x + enjoy, inplace=False)
+ return ham
+
+
+class LightHamHead(nn.Module):
+ """Is Attention Better Than Matrix Decomposition?
+
+ This head is the implementation of the HamNet decoder.
+ """
+
+ def __init__(self):
+ """Light hamburger decoder head."""
+ super().__init__()
+ self.in_index = [0, 1, 2, 3]
+ self.in_channels = [64, 128, 320, 512]
+ self.out_channels = 64
+ self.ham_channels = 512
+ self.align_corners = False
+
+ self.squeeze = ConvModule(sum(self.in_channels), self.ham_channels, 1)
+
+ self.hamburger = Hamburger(self.ham_channels)
+
+ self.align = ConvModule(self.ham_channels, self.out_channels, 1)
+
+ self.linear_pred_uncertainty = nn.Sequential(
+ ConvModule(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ nn.Conv2d(in_channels=self.out_channels, out_channels=1, kernel_size=1),
+ )
+
+ self.out_conv = ConvModule(self.out_channels, self.out_channels, 3, padding=1, bias=False)
+ self.ll_fusion = FeatureFusionBlock(self.out_channels, upsample=False)
+
+ def forward(self, features: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Forward pass."""
+ inputs = [features["hl"][i] for i in self.in_index]
+
+ inputs = [
+ F.interpolate(
+ level, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
+ )
+ for level in inputs
+ ]
+
+ inputs = torch.cat(inputs, dim=1)
+ x = self.squeeze(inputs)
+
+ x = self.hamburger(x)
+
+ feats = self.align(x)
+
+ assert "ll" in features, "Low-level features are required for this model"
+ feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False)
+ feats = self.out_conv(feats)
+ feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False)
+ feats = self.ll_fusion(feats, features["ll"].clone())
+
+ uncertainty = self.linear_pred_uncertainty(feats).squeeze(1)
+
+ return feats, uncertainty
+
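+# Expected input sketch (inferred from the code above): features["hl"] is a list of four maps
+# with channels [64, 128, 320, 512] (e.g. the MSCAN outputs below) and features["ll"] is a
+# 64-channel low-level map at 4x the resolution of the first "hl" map.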
+
+###################################################
+########### MSCAN ################
+###################################################
+
+
+class DWConv(nn.Module):
+ """Depthwise convolution."""
+
+ def __init__(self, dim: int = 768):
+ """Depthwise convolution.
+
+ Args:
+ dim (int, optional): Number of features. Defaults to 768.
+ """
+ super().__init__()
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass."""
+ return self.dwconv(x)
+
+
+class Mlp(nn.Module):
+ """MLP module."""
+
+ def __init__(
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
+ ):
+ """Initialize the MLP."""
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
+ self.dwconv = DWConv(hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ """Forward pass."""
+ x = self.fc1(x)
+
+ x = self.dwconv(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+
+ return x
+
+
+class StemConv(nn.Module):
+ """Simple stem convolution module."""
+
+ def __init__(self, in_channels: int, out_channels: int):
+ """Simple stem convolution module.
+
+ Args:
+ in_channels (int): Input channels.
+ out_channels (int): Output channels.
+ """
+ super().__init__()
+
+ self.proj = nn.Sequential(
+ nn.Conv2d(
+ in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
+ ),
+ nn.BatchNorm2d(out_channels // 2),
+ nn.GELU(),
+ nn.Conv2d(
+ out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
+ ),
+ nn.BatchNorm2d(out_channels),
+ )
+
+ def forward(self, x):
+ """Forward pass."""
+ x = self.proj(x)
+ _, _, H, W = x.size()
+ x = x.flatten(2).transpose(1, 2)
+ return x, H, W
+
+
+class AttentionModule(nn.Module):
+ """Attention module."""
+
+ def __init__(self, dim: int):
+ """Attention module.
+
+ Args:
+ dim (int): Number of features.
+ """
+ super().__init__()
+ self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
+ self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
+ self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)
+
+ self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
+ self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
+
+ self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
+ self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
+ self.conv3 = nn.Conv2d(dim, dim, 1)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass."""
+ u = x.clone()
+ attn = self.conv0(x)
+
+ attn_0 = self.conv0_1(attn)
+ attn_0 = self.conv0_2(attn_0)
+
+ attn_1 = self.conv1_1(attn)
+ attn_1 = self.conv1_2(attn_1)
+
+ attn_2 = self.conv2_1(attn)
+ attn_2 = self.conv2_2(attn_2)
+ attn = attn + attn_0 + attn_1 + attn_2
+
+ attn = self.conv3(attn)
+ return attn * u
+
+
+class SpatialAttention(nn.Module):
+ """Spatial attention module."""
+
+ def __init__(self, dim: int):
+ """Spatial attention module.
+
+ Args:
+ dim (int): Number of features.
+ """
+ super().__init__()
+ self.d_model = dim
+ self.proj_1 = nn.Conv2d(dim, dim, 1)
+ self.activation = nn.GELU()
+ self.spatial_gating_unit = AttentionModule(dim)
+ self.proj_2 = nn.Conv2d(dim, dim, 1)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """Forward pass."""
+ shorcut = x.clone()
+ x = self.proj_1(x)
+ x = self.activation(x)
+ x = self.spatial_gating_unit(x)
+ x = self.proj_2(x)
+ x = x + shorcut
+ return x
+
+
+class Block(nn.Module):
+ """MSCAN block."""
+
+ def __init__(
+ self, dim: int, mlp_ratio: float = 4.0, drop: float = 0.0, act_layer: nn.Module = nn.GELU
+ ):
+ """MSCAN block.
+
+ Args:
+ dim (int): Number of features.
+ mlp_ratio (float, optional): Ratio of the hidden features in the MLP. Defaults to 4.0.
+ drop (float, optional): Dropout rate. Defaults to 0.0.
+ act_layer (nn.Module, optional): Activation layer. Defaults to nn.GELU.
+ """
+ super().__init__()
+ self.norm1 = nn.BatchNorm2d(dim)
+ self.attn = SpatialAttention(dim)
+ self.drop_path = nn.Identity() # only used in training
+ self.norm2 = nn.BatchNorm2d(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
+ )
+ layer_scale_init_value = 1e-2
+ self.layer_scale_1 = nn.Parameter(
+ layer_scale_init_value * torch.ones((dim)), requires_grad=True
+ )
+ self.layer_scale_2 = nn.Parameter(
+ layer_scale_init_value * torch.ones((dim)), requires_grad=True
+ )
+
+ def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor:
+ """Forward pass."""
+ B, N, C = x.shape
+ x = x.permute(0, 2, 1).view(B, C, H, W)
+ x = x + self.drop_path(self.layer_scale_1[..., None, None] * self.attn(self.norm1(x)))
+ x = x + self.drop_path(self.layer_scale_2[..., None, None] * self.mlp(self.norm2(x)))
+ return x.view(B, C, N).permute(0, 2, 1)
+
+
+class OverlapPatchEmbed(nn.Module):
+ """Image to Patch Embedding"""
+
+ def __init__(
+ self, patch_size: int = 7, stride: int = 4, in_chans: int = 3, embed_dim: int = 768
+ ):
+ """Image to Patch Embedding.
+
+ Args:
+ patch_size (int, optional): Image patch size. Defaults to 7.
+ stride (int, optional): Stride. Defaults to 4.
+ in_chans (int, optional): Number of input channels. Defaults to 3.
+ embed_dim (int, optional): Embedding dimension. Defaults to 768.
+ """
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+
+ self.proj = nn.Conv2d(
+ in_chans,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=stride,
+ padding=(patch_size[0] // 2, patch_size[1] // 2),
+ )
+ self.norm = nn.BatchNorm2d(embed_dim)
+
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]:
+ """Forward pass."""
+ x = self.proj(x)
+ _, _, H, W = x.shape
+ x = self.norm(x)
+ x = x.flatten(2).transpose(1, 2)
+ return x, H, W
+
+
+class MSCAN(nn.Module):
+ """Multi-scale convolutional attention network."""
+
+ def __init__(self):
+ """Multi-scale convolutional attention network."""
+ super().__init__()
+ self.in_channels = 3
+ self.embed_dims = [64, 128, 320, 512]
+ self.mlp_ratios = [8, 8, 4, 4]
+ self.drop_rate = 0.0
+ self.drop_path_rate = 0.1
+ self.depths = [3, 3, 12, 3]
+ self.num_stages = 4
+
+ for i in range(self.num_stages):
+ if i == 0:
+ patch_embed = StemConv(3, self.embed_dims[0])
+ else:
+ patch_embed = OverlapPatchEmbed(
+ patch_size=7 if i == 0 else 3,
+ stride=4 if i == 0 else 2,
+ in_chans=self.in_channels if i == 0 else self.embed_dims[i - 1],
+ embed_dim=self.embed_dims[i],
+ )
+
+ block = nn.ModuleList(
+ [
+ Block(
+ dim=self.embed_dims[i],
+ mlp_ratio=self.mlp_ratios[i],
+ drop=self.drop_rate,
+ )
+ for _ in range(self.depths[i])
+ ]
+ )
+ norm = nn.LayerNorm(self.embed_dims[i])
+
+ setattr(self, f"patch_embed{i + 1}", patch_embed)
+ setattr(self, f"block{i + 1}", block)
+ setattr(self, f"norm{i + 1}", norm)
+
+ def forward(self, data):
+ """Forward pass."""
+ # rgb -> bgr and from [0, 1] to [0, 255]
+ x = data["image"][:, [2, 1, 0], :, :] * 255.0
+ B = x.shape[0]
+
+ outs = []
+ for i in range(self.num_stages):
+ patch_embed = getattr(self, f"patch_embed{i + 1}")
+ block = getattr(self, f"block{i + 1}")
+ norm = getattr(self, f"norm{i + 1}")
+ x, H, W = patch_embed(x)
+ for blk in block:
+ x = blk(x, H, W)
+ x = norm(x)
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+ outs.append(x)
+
+ return {"features": outs}
diff --git a/geocalib/perspective_fields.py b/geocalib/perspective_fields.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6b099ce4e9da1a31b785969c95a4af94029cf76
--- /dev/null
+++ b/geocalib/perspective_fields.py
@@ -0,0 +1,366 @@
+"""Implementation of perspective fields.
+
+Adapted from https://github.com/jinlinyi/PerspectiveFields/blob/main/perspective2d/utils/panocam.py
+"""
+
+from typing import Tuple
+
+import torch
+from torch.nn import functional as F
+
+from geocalib.camera import BaseCamera
+from geocalib.gravity import Gravity
+from geocalib.misc import J_up_projection, J_vecnorm, SphericalManifold
+
+# flake8: noqa: E266
+
+
+def get_horizon_line(camera: BaseCamera, gravity: Gravity, relative: bool = True) -> torch.Tensor:
+ """Get the horizon line from the camera parameters.
+
+ Args:
+ camera (Camera): Camera parameters.
+ gravity (Gravity): Gravity vector.
+ relative (bool, optional): Whether to normalize horizon line by img_h. Defaults to True.
+
+ Returns:
+ torch.Tensor: Intersections of the horizon line with the left/right image borders, as a
+ fraction of the image height if relative=True, otherwise in pixels.
+ """
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
+
+ # project horizon midpoint to image plane
+ horizon_midpoint = camera.new_tensor([0, 0, 1])
+ horizon_midpoint = camera.K @ gravity.R @ horizon_midpoint
+ midpoint = horizon_midpoint[:2] / horizon_midpoint[2]
+
+ # compute left and right offset to borders
+ left_offset = midpoint[0] * torch.tan(gravity.roll)
+ right_offset = (camera.size[0] - midpoint[0]) * torch.tan(gravity.roll)
+ left, right = midpoint[1] + left_offset, midpoint[1] - right_offset
+
+ horizon = camera.new_tensor([left, right])
+ return horizon / camera.size[1] if relative else horizon
+
+
+def get_up_field(camera: BaseCamera, gravity: Gravity, normalize: bool = True) -> torch.Tensor:
+ """Get the up vector field from the camera parameters.
+
+ Args:
+ camera (Camera): Camera parameters.
+ gravity (Gravity): Gravity vector.
+ normalize (bool, optional): Whether to normalize the up vector. Defaults to True.
+
+ Returns:
+ torch.Tensor: up vector field as tensor of shape (..., h, w, 2).
+ """
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
+
+ w, h = camera.size[0].unbind(-1)
+ h, w = h.round().to(int), w.round().to(int)
+
+ uv = camera.normalize(camera.pixel_coordinates())
+
+ # projected up is (a, b) - c * (u, v)
+ abc = gravity.vec3d
+ projected_up2d = abc[..., None, :2] - abc[..., 2, None, None] * uv # (..., N, 2)
+
+ if hasattr(camera, "dist"):
+ d_uv = camera.distort(uv, return_scale=True)[0] # (..., N, 1)
+ d_uv = torch.diag_embed(d_uv.expand(d_uv.shape[:-1] + (2,))) # (..., N, 2, 2)
+ offset = camera.up_projection_offset(uv) # (..., N, 2)
+ offset = torch.einsum("...i,...j->...ij", offset, uv) # (..., N, 2, 2)
+
+ # (..., N, 2)
+ projected_up2d = torch.einsum("...Nij,...Nj->...Ni", d_uv + offset, projected_up2d)
+
+ if normalize:
+ projected_up2d = F.normalize(projected_up2d, dim=-1) # (..., N, 2)
+
+ return projected_up2d.reshape(camera.shape[0], h, w, 2)
+
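+# Usage sketch: get_up_field(camera, gravity) returns per-pixel 2D unit vectors of shape
+# (B, h, w, 2) pointing towards the projected up direction.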
+
+def J_up_field(
+ camera: BaseCamera, gravity: Gravity, spherical: bool = False, log_focal: bool = False
+) -> torch.Tensor:
+ """Get the jacobian of the up field.
+
+ Args:
+ camera (Camera): Camera parameters.
+ gravity (Gravity): Gravity vector.
+ spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
+ log_focal (bool, optional): Whether to use log-focal length. Defaults to False.
+
+ Returns:
+ torch.Tensor: Jacobian of the up field as a tensor of shape (..., h, w, 2, n_params).
+ """
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
+
+ w, h = camera.size[0].unbind(-1)
+ h, w = h.round().to(int), w.round().to(int)
+
+ # Forward
+ xy = camera.pixel_coordinates()
+ uv = camera.normalize(xy)
+
+ projected_up2d = gravity.vec3d[..., None, :2] - gravity.vec3d[..., 2, None, None] * uv
+
+ # Backward
+ J = []
+
+ # (..., N, 2, 2)
+ J_norm2proj = J_vecnorm(
+ get_up_field(camera, gravity, normalize=False).reshape(camera.shape[0], -1, 2)
+ )
+
+ # distortion values
+ if hasattr(camera, "dist"):
+ d_uv = camera.distort(uv, return_scale=True)[0] # (..., N, 1)
+ d_uv = torch.diag_embed(d_uv.expand(d_uv.shape[:-1] + (2,))) # (..., N, 2, 2)
+ offset = camera.up_projection_offset(uv) # (..., N, 2)
+ offset_uv = torch.einsum("...i,...j->...ij", offset, uv) # (..., N, 2, 2)
+
+ ######################
+ ## Gravity Jacobian ##
+ ######################
+
+ J_proj2abc = J_up_projection(uv, gravity.vec3d, wrt="abc") # (..., N, 2, 3)
+
+ if hasattr(camera, "dist"):
+ # (..., N, 2, 3)
+ J_proj2abc = torch.einsum("...Nij,...Njk->...Nik", d_uv + offset_uv, J_proj2abc)
+
+ J_abc2delta = SphericalManifold.J_plus(gravity.vec3d) if spherical else gravity.J_rp()
+ J_proj2delta = torch.einsum("...Nij,...jk->...Nik", J_proj2abc, J_abc2delta)
+ J_up2delta = torch.einsum("...Nij,...Njk->...Nik", J_norm2proj, J_proj2delta)
+ J.append(J_up2delta)
+
+ ######################
+ ### Focal Jacobian ###
+ ######################
+
+ J_proj2uv = J_up_projection(uv, gravity.vec3d, wrt="uv") # (..., N, 2, 2)
+
+ if hasattr(camera, "dist"):
+ J_proj2up = torch.einsum("...Nij,...Njk->...Nik", d_uv + offset_uv, J_proj2uv)
+ J_proj2duv = torch.einsum("...i,...j->...ji", offset, projected_up2d)
+
+ inner = (uv * projected_up2d).sum(-1)[..., None, None]
+ J_proj2offset1 = inner * camera.J_up_projection_offset(uv, wrt="uv")
+ J_proj2offset2 = torch.einsum("...i,...j->...ij", offset, projected_up2d) # (..., N, 2, 2)
+ J_proj2uv = (J_proj2duv + J_proj2offset1 + J_proj2offset2) + J_proj2up
+
+ J_uv2f = camera.J_normalize(xy) # (..., N, 2, 2)
+
+ if log_focal:
+ J_uv2f = J_uv2f * camera.f[..., None, None, :] # (..., N, 2, 2)
+
+ J_uv2f = J_uv2f.sum(-1) # (..., N, 2)
+
+ J_proj2f = torch.einsum("...ij,...j->...i", J_proj2uv, J_uv2f) # (..., N, 2)
+ J_up2f = torch.einsum("...Nij,...Nj->...Ni", J_norm2proj, J_proj2f)[..., None] # (..., N, 2, 1)
+ J.append(J_up2f)
+
+ ######################
+ ##### K1 Jacobian ####
+ ######################
+
+ if hasattr(camera, "dist"):
+ J_duv = camera.J_distort(uv, wrt="scale2dist")
+ J_duv = torch.diag_embed(J_duv.expand(J_duv.shape[:-1] + (2,))) # (..., N, 2, 2)
+ J_offset = torch.einsum(
+ "...i,...j->...ij", camera.J_up_projection_offset(uv, wrt="dist"), uv
+ )
+ J_proj2k1 = torch.einsum("...Nij,...Nj->...Ni", J_duv + J_offset, projected_up2d)
+ J_k1 = torch.einsum("...Nij,...Nj->...Ni", J_norm2proj, J_proj2k1)[..., None]
+ J.append(J_k1)
+
+ n_params = sum(j.shape[-1] for j in J)
+ return torch.cat(J, dim=-1).reshape(camera.shape[0], h, w, 2, n_params)
+
+
+def get_latitude_field(camera: BaseCamera, gravity: Gravity) -> torch.Tensor:
+ """Get the latitudes of the camera pixels in radians.
+
+ Latitudes are defined as the angle between the ray and the horizontal plane.
+
+ Args:
+ camera (Camera): Camera parameters.
+ gravity (Gravity): Gravity vector.
+
+ Returns:
+ torch.Tensor: Latitudes in radians as a tensor of shape (..., h, w, 1).
+ """
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
+
+ w, h = camera.size[0].unbind(-1)
+ h, w = h.round().to(int), w.round().to(int)
+
+ uv1, _ = camera.image2world(camera.pixel_coordinates())
+ rays = camera.pixel_bearing_many(uv1)
+
+ lat = torch.einsum("...Nj,...j->...N", rays, gravity.vec3d)
+
+ eps = 1e-6
+ lat_asin = torch.asin(lat.clamp(min=-1 + eps, max=1 - eps))
+
+ return lat_asin.reshape(camera.shape[0], h, w, 1)
+
+
+def J_latitude_field(
+ camera: BaseCamera, gravity: Gravity, spherical: bool = False, log_focal: bool = False
+) -> torch.Tensor:
+ """Get the jacobian of the latitude field.
+
+ Args:
+ camera (Camera): Camera parameters.
+ gravity (Gravity): Gravity vector.
+ spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
+ log_focal (bool, optional): Whether to use log-focal length. Defaults to False.
+
+ Returns:
+ torch.Tensor: Jacobian of the latitude field as a tensor of shape (..., h, w, 1, n_params).
+ """
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
+
+ w, h = camera.size[0].unbind(-1)
+ h, w = h.round().to(int), w.round().to(int)
+
+ # Forward
+ xy = camera.pixel_coordinates()
+ uv1, _ = camera.image2world(xy)
+ uv1_norm = camera.pixel_bearing_many(uv1) # (..., N, 3)
+
+ # Backward
+ J = []
+ J_norm2w_to_img = J_vecnorm(uv1)[..., :2] # (..., N, 3, 2)
+
+ ######################
+ ## Gravity Jacobian ##
+ ######################
+
+ J_delta = SphericalManifold.J_plus(gravity.vec3d) if spherical else gravity.J_rp()
+ J_delta = torch.einsum("...Ni,...ij->...Nj", uv1_norm, J_delta) # (..., N, 2)
+ J.append(J_delta)
+
+ ######################
+ ### Focal Jacobian ###
+ ######################
+
+ J_w_to_img2f = camera.J_image2world(xy, "f") # (..., N, 2, 2)
+ if log_focal:
+ J_w_to_img2f = J_w_to_img2f * camera.f[..., None, None, :]
+ J_w_to_img2f = J_w_to_img2f.sum(-1) # (..., N, 2)
+
+ J_norm2f = torch.einsum("...Nij,...Nj->...Ni", J_norm2w_to_img, J_w_to_img2f) # (..., N, 3)
+ J_f = torch.einsum("...Ni,...i->...N", J_norm2f, gravity.vec3d).unsqueeze(-1) # (..., N, 1)
+ J.append(J_f)
+
+ ######################
+ ##### K1 Jacobian ####
+ ######################
+
+ if hasattr(camera, "dist"):
+ J_w_to_img2k1 = camera.J_image2world(xy, "dist") # (..., N, 2)
+ # (..., N, 2)
+ J_norm2k1 = torch.einsum("...Nij,...Nj->...Ni", J_norm2w_to_img, J_w_to_img2k1)
+ # (..., N, 1)
+ J_k1 = torch.einsum("...Ni,...i->...N", J_norm2k1, gravity.vec3d).unsqueeze(-1)
+ J.append(J_k1)
+
+ n_params = sum(j.shape[-1] for j in J)
+ return torch.cat(J, dim=-1).reshape(camera.shape[0], h, w, 1, n_params)
+
+
+def get_perspective_field(
+ camera: BaseCamera,
+ gravity: Gravity,
+ use_up: bool = True,
+ use_latitude: bool = True,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Get the perspective field from the camera parameters.
+
+ Args:
+ camera (Camera): Camera parameters.
+ gravity (Gravity): Gravity vector.
+ use_up (bool, optional): Whether to include the up vector field. Defaults to True.
+ use_latitude (bool, optional): Whether to include the latitude field. Defaults to True.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Up and latitude fields as tensors of shape
+ (..., 2, h, w) and (..., 1, h, w).
+ """
+ assert use_up or use_latitude, "At least one of use_up or use_latitude must be True."
+
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
+
+ w, h = camera.size[0].unbind(-1)
+ h, w = h.round().to(int), w.round().to(int)
+
+ if use_up:
+ permute = (0, 3, 1, 2)
+ # (..., 2, h, w)
+ up = get_up_field(camera, gravity).permute(permute)
+ else:
+ shape = (camera.shape[0], 2, h, w)
+ up = camera.new_zeros(shape)
+
+ if use_latitude:
+ permute = (0, 3, 1, 2)
+ # (..., 1, h, w)
+ lat = get_latitude_field(camera, gravity).permute(permute)
+ else:
+ shape = (camera.shape[0], 1, h, w)
+ lat = camera.new_zeros(shape)
+
+ return up, lat
+
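+# Usage sketch:
+#   up, lat = get_perspective_field(camera, gravity)
+#   up.shape   # (B, 2, h, w) -- per-pixel up vectors
+#   lat.shape  # (B, 1, h, w) -- per-pixel latitude in radians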
+
+def J_perspective_field(
+ camera: BaseCamera,
+ gravity: Gravity,
+ use_up: bool = True,
+ use_latitude: bool = True,
+ spherical: bool = False,
+ log_focal: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Get the jacobian of the perspective field.
+
+ Args:
+ camera (Camera): Camera parameters.
+ gravity (Gravity): Gravity vector.
+ use_up (bool, optional): Whether to include the up vector field. Defaults to True.
+ use_latitude (bool, optional): Whether to include the latitude field. Defaults to True.
+ spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
+ log_focal (bool, optional): Whether to use log-focal length. Defaults to False.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Up and latitude jacobians as tensors of shape
+ (..., h, w, 2, 4) and (..., h, w, 1, 4).
+ """
+ assert use_up or use_latitude, "At least one of use_up or use_latitude must be True."
+
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
+
+ w, h = camera.size[0].unbind(-1)
+ h, w = h.round().to(int), w.round().to(int)
+
+ if use_up:
+ J_up = J_up_field(camera, gravity, spherical, log_focal) # (..., h, w, 2, 4)
+ else:
+ shape = (camera.shape[0], h, w, 2, 4)
+ J_up = camera.new_zeros(shape)
+
+ if use_latitude:
+ J_lat = J_latitude_field(camera, gravity, spherical, log_focal) # (..., h, w, 1, 4)
+ else:
+ shape = (camera.shape[0], h, w, 1, 4)
+ J_lat = camera.new_zeros(shape)
+
+ return J_up, J_lat
diff --git a/geocalib/utils.py b/geocalib/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5372e157443502c6f4e5cacf46968839e14af341
--- /dev/null
+++ b/geocalib/utils.py
@@ -0,0 +1,325 @@
+"""Image loading and general conversion utilities."""
+
+import collections.abc as collections
+from pathlib import Path
+from types import SimpleNamespace
+from typing import Dict, Optional, Tuple
+
+import cv2
+import kornia
+import numpy as np
+import torch
+import torchvision
+
+# mypy: ignore-errors
+
+
+def fit_to_multiple(x: torch.Tensor, multiple: int, mode: str = "center", crop: bool = False):
+ """Get padding to make the image size a multiple of the given number.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+ multiple (int): Multiple to fit to.
+ mode (str, optional): Padding alignment, either "center" or "left". Defaults to "center".
+ crop (bool, optional): Whether to crop or pad. Defaults to False.
+
+ Returns:
+ Tuple[int, int, int, int]: Padding as (left, right, top, bottom).
+ """
+ h, w = x.shape[-2:]
+
+ if crop:
+ pad_w = (w // multiple) * multiple - w
+ pad_h = (h // multiple) * multiple - h
+ else:
+ pad_w = (multiple - w % multiple) % multiple
+ pad_h = (multiple - h % multiple) % multiple
+
+ if mode == "center":
+ pad_l = pad_w // 2
+ pad_r = pad_w - pad_l
+ pad_t = pad_h // 2
+ pad_b = pad_h - pad_t
+ elif mode == "left":
+ pad_l, pad_r = 0, pad_w
+ pad_t, pad_b = 0, pad_h
+ else:
+ raise ValueError(f"Unknown mode {mode}")
+
+ return (pad_l, pad_r, pad_t, pad_b)
+
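+# Worked example: for a 479x640 image and multiple=32 only the height needs padding
+# (479 -> 480), so fit_to_multiple(x, 32) returns (0, 0, 0, 1), i.e. (left, right, top, bottom)
+# as expected by torch.nn.functional.pad.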
+
+def fit_features_to_multiple(
+ features: torch.Tensor, multiple: int = 32, crop: bool = False
+) -> Tuple[torch.Tensor, Tuple[int, int]]:
+ """Pad or crop image to a multiple of the given number.
+
+ Args:
+ features (torch.Tensor): Input features.
+ multiple (int, optional): Multiple. Defaults to 32.
+ crop (bool, optional): Whether to crop or pad. Defaults to False.
+
+ Returns:
+ Tuple[torch.Tensor, Tuple[int, int]]: Padded features and padding.
+ """
+ pad = fit_to_multiple(features, multiple, crop=crop)
+ return torch.nn.functional.pad(features, pad, mode="reflect"), pad
+
+
+class ImagePreprocessor:
+ """Preprocess images for calibration."""
+
+ default_conf = {
+ "resize": 320, # target edge length, None for no resizing
+ "edge_divisible_by": None,
+ "side": "short",
+ "interpolation": "bilinear",
+ "align_corners": None,
+ "antialias": True,
+ "square_crop": False,
+ "add_padding_mask": False,
+ "resize_backend": "kornia", # torchvision, kornia
+ }
+
+ def __init__(self, conf) -> None:
+ """Initialize the image preprocessor."""
+ self.conf = {**self.default_conf, **conf}
+ self.conf = SimpleNamespace(**self.conf)
+
+ def __call__(self, img: torch.Tensor, interpolation: Optional[str] = None) -> dict:
+ """Resize and preprocess an image, return image and resize scale."""
+ h, w = img.shape[-2:]
+ size = h, w
+
+ if self.conf.square_crop:
+ min_size = min(h, w)
+ offset = (h - min_size) // 2, (w - min_size) // 2
+ img = img[:, offset[0] : offset[0] + min_size, offset[1] : offset[1] + min_size]
+ size = img.shape[-2:]
+
+ if self.conf.resize is not None:
+ if interpolation is None:
+ interpolation = self.conf.interpolation
+ size = self.get_new_image_size(h, w)
+ img = self.resize(img, size, interpolation)
+
+ scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
+ T = np.diag([scale[0].cpu(), scale[1].cpu(), 1])
+
+ data = {
+ "scales": scale,
+ "image_size": np.array(size[::-1]),
+ "transform": T,
+ "original_image_size": np.array([w, h]),
+ }
+
+ if self.conf.edge_divisible_by is not None:
+ # crop to make the edge divisible by a number
+ w_, h_ = img.shape[-1], img.shape[-2]
+ img, _ = fit_features_to_multiple(img, self.conf.edge_divisible_by, crop=True)
+ crop_pad = torch.Tensor([img.shape[-1] - w_, img.shape[-2] - h_]).to(img)
+ data["crop_pad"] = crop_pad
+ data["image_size"] = np.array([img.shape[-1], img.shape[-2]])
+
+ data["image"] = img
+ return data
+
+ def resize(self, img: torch.Tensor, size: Tuple[int, int], interpolation: str) -> torch.Tensor:
+ """Resize an image using the specified backend."""
+ if self.conf.resize_backend == "kornia":
+ return kornia.geometry.transform.resize(
+ img,
+ size,
+ side=self.conf.side,
+ antialias=self.conf.antialias,
+ align_corners=self.conf.align_corners,
+ interpolation=interpolation,
+ )
+ elif self.conf.resize_backend == "torchvision":
+ return torchvision.transforms.Resize(size, antialias=self.conf.antialias)(img)
+ else:
+ raise ValueError(f"{self.conf.resize_backend} not implemented.")
+
+ def load_image(self, image_path: Path) -> dict:
+ """Load an image from a path and preprocess it."""
+ return self(load_image(image_path))
+
+ def get_new_image_size(self, h: int, w: int) -> Tuple[int, int]:
+ """Get the new image size after resizing."""
+ side = self.conf.side
+ if isinstance(self.conf.resize, collections.Iterable):
+ assert len(self.conf.resize) == 2
+ return tuple(self.conf.resize)
+ side_size = self.conf.resize
+ aspect_ratio = w / h
+ if side not in ("short", "long", "vert", "horz"):
+ raise ValueError(
+ f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{side}'"
+ )
+ return (
+ (side_size, int(side_size * aspect_ratio))
+ if side == "vert" or (side != "horz" and (side == "short") ^ (aspect_ratio < 1.0))
+ else (int(side_size / aspect_ratio), side_size)
+ )
+
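+# Usage sketch (hypothetical image path): with the default configuration the short side is
+# resized to 320 px:
+#   data = ImagePreprocessor({})(load_image("image.jpg"))
+#   data["image"], data["scales"]  # resized tensor and per-axis scale factors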
+
+def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
+ """Normalize the image tensor and reorder the dimensions."""
+ if image.ndim == 3:
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
+ elif image.ndim == 2:
+ image = image[None] # add channel axis
+ else:
+ raise ValueError(f"Not an image: {image.shape}")
+ return torch.tensor(image / 255.0, dtype=torch.float)
+
+
+def torch_image_to_numpy(image: torch.Tensor) -> np.ndarray:
+ """Normalize and reorder the dimensions of an image tensor."""
+ if image.ndim == 3:
+ image = image.permute((1, 2, 0)) # CxHxW to HxWxC
+ elif image.ndim == 2:
+ image = image[None] # add channel axis
+ else:
+ raise ValueError(f"Not an image: {image.shape}")
+ return (image.cpu().detach().numpy() * 255).astype(np.uint8)
+
+
+def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
+ """Read an image from path as RGB or grayscale."""
+ if not Path(path).exists():
+ raise FileNotFoundError(f"No image at path {path}.")
+ mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
+ image = cv2.imread(str(path), mode)
+ if image is None:
+ raise IOError(f"Could not read image at {path}.")
+ if not grayscale:
+ image = image[..., ::-1]
+ return image
+
+
+def write_image(img: torch.Tensor, path: Path):
+ """Write an image tensor to a file."""
+ img = torch_image_to_numpy(img) if isinstance(img, torch.Tensor) else img
+ cv2.imwrite(str(path), img[..., ::-1])
+
+
+def load_image(path: Path, grayscale: bool = False, return_tensor: bool = True) -> torch.Tensor:
+ """Load an image from a path and return as a tensor."""
+ image = read_image(path, grayscale=grayscale)
+ if return_tensor:
+ return numpy_image_to_torch(image)
+
+ assert image.ndim in [2, 3], f"Not an image: {image.shape}"
+ image = image[None] if image.ndim == 2 else image
+ return torch.tensor(image.copy(), dtype=torch.uint8)
+
+
+def skew_symmetric(v: torch.Tensor) -> torch.Tensor:
+ """Create a skew-symmetric matrix from a (batched) vector of size (..., 3).
+
+ Args:
+ v (torch.Tensor): Vector of size (..., 3).
+
+ Returns:
+ (torch.Tensor): Skew-symmetric matrix of size (..., 3, 3).
+ """
+ z = torch.zeros_like(v[..., 0])
+ return torch.stack(
+ [z, -v[..., 2], v[..., 1], v[..., 2], z, -v[..., 0], -v[..., 1], v[..., 0], z], dim=-1
+ ).reshape(v.shape[:-1] + (3, 3))
+
+
+def rad2rotmat(
+ roll: torch.Tensor, pitch: torch.Tensor, yaw: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+ """Convert (batched) roll, pitch, yaw angles (in radians) to rotation matrix.
+
+ Args:
+ roll (torch.Tensor): Roll angle in radians.
+ pitch (torch.Tensor): Pitch angle in radians.
+ yaw (torch.Tensor, optional): Yaw angle in radians. Defaults to None.
+
+ Returns:
+ torch.Tensor: Rotation matrix of shape (..., 3, 3).
+ """
+ if yaw is None:
+ yaw = roll.new_zeros(roll.shape)
+
+ Rx = pitch.new_zeros(pitch.shape + (3, 3))
+ Rx[..., 0, 0] = 1
+ Rx[..., 1, 1] = torch.cos(pitch)
+ Rx[..., 1, 2] = torch.sin(pitch)
+ Rx[..., 2, 1] = -torch.sin(pitch)
+ Rx[..., 2, 2] = torch.cos(pitch)
+
+ Ry = yaw.new_zeros(yaw.shape + (3, 3))
+ Ry[..., 0, 0] = torch.cos(yaw)
+ Ry[..., 0, 2] = -torch.sin(yaw)
+ Ry[..., 1, 1] = 1
+ Ry[..., 2, 0] = torch.sin(yaw)
+ Ry[..., 2, 2] = torch.cos(yaw)
+
+ Rz = roll.new_zeros(roll.shape + (3, 3))
+ Rz[..., 0, 0] = torch.cos(roll)
+ Rz[..., 0, 1] = torch.sin(roll)
+ Rz[..., 1, 0] = -torch.sin(roll)
+ Rz[..., 1, 1] = torch.cos(roll)
+ Rz[..., 2, 2] = 1
+
+ return Rz @ Rx @ Ry
+
+
+def fov2focal(fov: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
+ """Compute focal length from (vertical/horizontal) field of view."""
+ return size / 2 / torch.tan(fov / 2)
+
+
+def focal2fov(focal: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
+ """Compute (vertical/horizontal) field of view from focal length."""
+ return 2 * torch.arctan(size / (2 * focal))
+
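+# Worked example: for a 640 px edge, a 90° field of view corresponds to a focal length of
+# 640 / 2 / tan(45°) = 320 px; fov2focal and focal2fov are inverses of each other.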
+
+def pitch2rho(pitch: torch.Tensor, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
+ """Compute the distance from principal point to the horizon."""
+ return torch.tan(pitch) * f / h
+
+
+def rho2pitch(rho: torch.Tensor, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
+ """Compute the pitch angle from the distance to the horizon."""
+ return torch.atan(rho * h / f)
+
+
+def rad2deg(rad: torch.Tensor) -> torch.Tensor:
+ """Convert radians to degrees."""
+ return rad / torch.pi * 180
+
+
+def deg2rad(deg: torch.Tensor) -> torch.Tensor:
+ """Convert degrees to radians."""
+ return deg / 180 * torch.pi
+
+
+def get_device() -> str:
+ """Get the device (cpu, cuda, mps) available."""
+ device = "cpu"
+ if torch.cuda.is_available():
+ device = "cuda"
+ elif torch.backends.mps.is_available():
+ device = "mps"
+ return device
+
+
+def print_calibration(results: Dict[str, torch.Tensor]) -> None:
+ """Print the calibration results."""
+ camera, gravity = results["camera"], results["gravity"]
+ vfov = rad2deg(camera.vfov)
+ roll, pitch = rad2deg(gravity.rp).unbind(-1)
+
+ print("\nEstimated parameters (Pred):")
+ print(f"Roll: {roll.item():.1f}° (± {rad2deg(results['roll_uncertainty']).item():.1f})°")
+ print(f"Pitch: {pitch.item():.1f}° (± {rad2deg(results['pitch_uncertainty']).item():.1f})°")
+ print(f"vFoV: {vfov.item():.1f}° (± {rad2deg(results['vfov_uncertainty']).item():.1f})°")
+ print(f"Focal: {camera.f[0, 1].item():.1f} px (± {results['focal_uncertainty'].item():.1f} px)")
+
+ if hasattr(camera, "k1"):
+ print(f"K1: {camera.k1.item():.1f}")
diff --git a/geocalib/viz2d.py b/geocalib/viz2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..2747684c826d0ab610d399d99cd4eb5b22fbfd27
--- /dev/null
+++ b/geocalib/viz2d.py
@@ -0,0 +1,502 @@
+"""2D visualization primitives based on Matplotlib.
+
+1) Plot images with `plot_images`.
+2) Call functions to plot heatmaps, vector fields, and horizon lines.
+3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
+"""
+
+import matplotlib.patheffects as path_effects
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from geocalib.perspective_fields import get_perspective_field
+from geocalib.utils import rad2deg
+
+# mypy: ignore-errors
+
+
+def plot_images(imgs, titles=None, cmaps="gray", dpi=200, pad=0.5, adaptive=True):
+ """Plot a list of images.
+
+ Args:
+ imgs (List[np.ndarray]): List of images to plot.
+ titles (List[str], optional): Titles. Defaults to None.
+ cmaps (str, optional): Colormaps. Defaults to "gray".
+ dpi (int, optional): Dots per inch. Defaults to 200.
+ pad (float, optional): Padding. Defaults to 0.5.
+ adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True.
+
+ Returns:
+ plt.Figure: Figure of the images.
+ """
+ n = len(imgs)
+ if not isinstance(cmaps, (list, tuple)):
+ cmaps = [cmaps] * n
+
+ ratios = [i.shape[1] / i.shape[0] for i in imgs] if adaptive else [4 / 3] * n
+ figsize = [sum(ratios) * 4.5, 4.5]
+ fig, axs = plt.subplots(1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios})
+ if n == 1:
+ axs = [axs]
+ for i, (img, ax) in enumerate(zip(imgs, axs)):
+ ax.imshow(img, cmap=plt.get_cmap(cmaps[i]))
+ ax.set_axis_off()
+ if titles:
+ ax.set_title(titles[i])
+ fig.tight_layout(pad=pad)
+
+ return fig
+
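+# Usage sketch (hypothetical file name):
+#   fig = plot_images([img], titles=["input"])
+#   save_plot("calibration.png")  # save_plot is defined at the end of this module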
+
+def plot_image_grid(
+ imgs,
+ titles=None,
+ cmaps="gray",
+ dpi=100,
+ pad=0.5,
+ fig=None,
+ adaptive=True,
+ figs=3.0,
+ return_fig=False,
+ set_lim=False,
+) -> plt.Figure:
+ """Plot a grid of images.
+
+ Args:
+ imgs (List[np.ndarray]): List of images to plot.
+ titles (List[str], optional): Titles. Defaults to None.
+ cmaps (str, optional): Colormaps. Defaults to "gray".
+ dpi (int, optional): Dots per inch. Defaults to 100.
+ pad (float, optional): Padding. Defaults to 0.5.
+ fig (_type_, optional): Figure to plot on. Defaults to None.
+ adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True.
+ figs (float, optional): Figure size. Defaults to 3.0.
+ return_fig (bool, optional): Whether to return the figure. Defaults to False.
+ set_lim (bool, optional): Whether to set the limits. Defaults to False.
+
+ Returns:
+ plt.Figure: Figure and axes or just axes.
+ """
+ nr, n = len(imgs), len(imgs[0])
+ if not isinstance(cmaps, (list, tuple)):
+ cmaps = [cmaps] * n
+
+ if adaptive:
+ ratios = [i.shape[1] / i.shape[0] for i in imgs[0]] # W / H
+ else:
+ ratios = [4 / 3] * n
+
+ figsize = [sum(ratios) * figs, nr * figs]
+ if fig is None:
+ fig, axs = plt.subplots(
+ nr, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
+ )
+ else:
+ axs = fig.subplots(nr, n, gridspec_kw={"width_ratios": ratios})
+ fig.figure.set_size_inches(figsize)
+
+ if nr == 1 and n == 1:
+ axs = [[axs]]
+ elif n == 1:
+ axs = axs[:, None]
+ elif nr == 1:
+ axs = [axs]
+
+ for j in range(nr):
+ for i in range(n):
+ ax = axs[j][i]
+ ax.imshow(imgs[j][i], cmap=plt.get_cmap(cmaps[i]))
+ ax.set_axis_off()
+ if set_lim:
+ ax.set_xlim([0, imgs[j][i].shape[1]])
+ ax.set_ylim([imgs[j][i].shape[0], 0])
+ if titles:
+ ax.set_title(titles[j][i])
+ if isinstance(fig, plt.Figure):
+ fig.tight_layout(pad=pad)
+ return (fig, axs) if return_fig else axs
+
+
+def add_text(
+ idx,
+ text,
+ pos=(0.01, 0.99),
+ fs=15,
+ color="w",
+ lcolor="k",
+ lwidth=4,
+ ha="left",
+ va="top",
+ axes=None,
+ **kwargs,
+):
+ """Add text to a plot.
+
+ Args:
+ idx (int): Index of the axes.
+ text (str): Text to add.
+ pos (tuple, optional): Text position. Defaults to (0.01, 0.99).
+ fs (int, optional): Font size. Defaults to 15.
+ color (str, optional): Text color. Defaults to "w".
+ lcolor (str, optional): Line color. Defaults to "k".
+ lwidth (int, optional): Line width. Defaults to 4.
+ ha (str, optional): Horizontal alignment. Defaults to "left".
+ va (str, optional): Vertical alignment. Defaults to "top".
+ axes (List[plt.Axes], optional): Axes to put text on. Defaults to None.
+
+ Returns:
+ plt.Text: Text object.
+ """
+ if axes is None:
+ axes = plt.gcf().axes
+
+ ax = axes[idx]
+
+ t = ax.text(
+ *pos,
+ text,
+ fontsize=fs,
+ ha=ha,
+ va=va,
+ color=color,
+ transform=ax.transAxes,
+ zorder=5,
+ **kwargs,
+ )
+ if lcolor is not None:
+ t.set_path_effects(
+ [
+ path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
+ path_effects.Normal(),
+ ]
+ )
+ return t
+
+
+def plot_heatmaps(
+ heatmaps,
+ vmin=-1e-6, # include negative zero
+ vmax=None,
+ cmap="Spectral",
+ a=0.5,
+ axes=None,
+ contours_every=None,
+ contour_style="solid",
+ colorbar=False,
+):
+ """Plot heatmaps with optional contours.
+
+ To plot latitude field, set vmin=-90, vmax=90 and contours_every=15.
+
+ Args:
+ heatmaps (List[np.ndarray | torch.Tensor]): List of 2D heatmaps.
+ vmin (float, optional): Min Value. Defaults to -1e-6.
+ vmax (float, optional): Max Value. Defaults to None.
+ cmap (str, optional): Colormap. Defaults to "Spectral".
+ a (float, optional): Alpha value. Defaults to 0.5.
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
+ contours_every (int, optional): If not none, will draw contours. Defaults to None.
+ contour_style (str, optional): Style of the contours. Defaults to "solid".
+ colorbar (bool, optional): Whether to show colorbar. Defaults to False.
+
+ Returns:
+ List[plt.Artist]: List of artists.
+ """
+ if axes is None:
+ axes = plt.gcf().axes
+ artists = []
+
+ for i in range(len(axes)):
+ a_ = a if isinstance(a, float) else a[i]
+
+ if isinstance(heatmaps[i], torch.Tensor):
+ heatmaps[i] = heatmaps[i].cpu().numpy()
+
+ alpha = a_
+ # Plot the heatmap
+ art = axes[i].imshow(
+ heatmaps[i],
+ alpha=alpha,
+ vmin=vmin,
+ vmax=vmax,
+ cmap=cmap,
+ )
+ if colorbar:
+ cmax = vmax or np.percentile(heatmaps[i], 99)
+ art.set_clim(vmin, cmax)
+ cbar = plt.colorbar(art, ax=axes[i])
+ artists.append(cbar)
+
+ artists.append(art)
+
+ if contours_every is not None:
+ # Add contour lines to the heatmap
+ contour_data = np.arange(vmin, vmax + contours_every, contours_every)
+
+ # Get the colormap colors for contour lines
+ contour_colors = [
+ plt.colormaps.get_cmap(cmap)(plt.Normalize(vmin=vmin, vmax=vmax)(level))
+ for level in contour_data
+ ]
+ contours = axes[i].contour(
+ heatmaps[i],
+ levels=contour_data,
+ linewidths=2,
+ colors=contour_colors,
+ linestyles=contour_style,
+ )
+
+ contours.set_clim(vmin, vmax)
+
+ fmt = {
+ level: f"{label}°"
+ for level, label in zip(contour_data, contour_data.astype(int).astype(str))
+ }
+ t = axes[i].clabel(contours, inline=True, fmt=fmt, fontsize=16, colors="white")
+
+ for label in t:
+ label.set_path_effects(
+ [
+ path_effects.Stroke(linewidth=1, foreground="k"),
+ path_effects.Normal(),
+ ]
+ )
+ artists.append(contours)
+
+ return artists
+
+
+def plot_horizon_lines(
+ cameras, gravities, line_colors="orange", lw=2, styles="solid", alpha=1.0, ax=None
+):
+ """Plot horizon lines on the perspective field.
+
+ Args:
+ cameras (List[Camera]): List of cameras.
+ gravities (List[Gravity]): Gravities.
+ line_colors (str, optional): Line Colors. Defaults to "orange".
+ lw (int, optional): Line width. Defaults to 2.
+ styles (str, optional): Line styles. Defaults to "solid".
+ alpha (float, optional): Alphas. Defaults to 1.0.
+ ax (List[plt.Axes], optional): Axes to draw horizon line on. Defaults to None.
+ """
+ if not isinstance(line_colors, list):
+ line_colors = [line_colors] * len(cameras)
+
+ if not isinstance(styles, list):
+ styles = [styles] * len(cameras)
+
+ fig = plt.gcf()
+ ax = fig.gca() if ax is None else ax
+
+ if isinstance(ax, plt.Axes):
+ ax = [ax] * len(cameras)
+
+ assert len(ax) == len(cameras), f"{len(ax)}, {len(cameras)}"
+
+ for i in range(len(cameras)):
+ _, lat = get_perspective_field(cameras[i], gravities[i])
+ # horizon line is zero level of the latitude field
+ lat = lat[0, 0].cpu().numpy()
+ contours = ax[i].contour(lat, levels=[0], linewidths=lw, colors=line_colors[i])
+ for contour_line in contours.collections:
+ contour_line.set_linestyle(styles[i])
+
+
+def plot_vector_fields(
+ vector_fields,
+ cmap="lime",
+ subsample=15,
+ scale=None,
+ lw=None,
+ alphas=0.8,
+ axes=None,
+):
+ """Plot vector fields.
+
+ Args:
+ vector_fields (List[torch.Tensor]): List of vector fields of shape (2, H, W).
+ cmap (str, optional): Color of the vectors. Defaults to "lime".
+ subsample (int, optional): Subsample the vector field. Defaults to 15.
+ scale (float, optional): Scale of the vectors. Defaults to None.
+ lw (float, optional): Line width of the vectors. Defaults to None.
+ alphas (float | np.ndarray, optional): Alpha per vector or global. Defaults to 0.8.
+ axes (List[plt.Axes], optional): List of axes to draw on. Defaults to None.
+
+ Returns:
+ List[plt.Artist]: List of artists.
+ """
+ if axes is None:
+ axes = plt.gcf().axes
+
+ vector_fields = [v.cpu().numpy() if isinstance(v, torch.Tensor) else v for v in vector_fields]
+
+ artists = []
+
+ H, W = vector_fields[0].shape[-2:]
+ if scale is None:
+ scale = subsample / min(H, W)
+
+ if lw is None:
+ lw = 0.1 / subsample
+
+ if alphas is None:
+ alphas = np.ones_like(vector_fields[0][0])
+ alphas = np.stack([alphas] * len(vector_fields), 0)
+ elif isinstance(alphas, float):
+ alphas = np.ones_like(vector_fields[0][0]) * alphas
+ alphas = np.stack([alphas] * len(vector_fields), 0)
+ else:
+ alphas = np.array(alphas)
+
+ subsample = min(W, H) // subsample
+ offset_x = ((W % subsample) + subsample) // 2
+
+ samples_x = np.arange(offset_x, W, subsample)
+ samples_y = np.arange(int(subsample * 0.9), H, subsample)
+
+ x_grid, y_grid = np.meshgrid(samples_x, samples_y)
+
+ for i in range(len(axes)):
+ # vector field of shape (2, H, W) with vectors of norm == 1
+ vector_field = vector_fields[i]
+
+ a = alphas[i][samples_y][:, samples_x]
+ x, y = vector_field[:, samples_y][:, :, samples_x]
+
+ c = cmap
+ if not isinstance(cmap, str):
+ c = cmap[i][samples_y][:, samples_x].reshape(-1, 3)
+
+ s = scale * min(H, W)
+ arrows = axes[i].quiver(
+ x_grid,
+ y_grid,
+ x,
+ y,
+ scale=s,
+ scale_units="width" if H > W else "height",
+ units="width" if H > W else "height",
+ alpha=a,
+ color=c,
+ angles="xy",
+ antialiased=True,
+ width=lw,
+ headaxislength=3.5,
+ zorder=5,
+ )
+
+ artists.append(arrows)
+
+ return artists
+
+
+def plot_latitudes(
+ latitude,
+ is_radians=True,
+ vmin=-90,
+ vmax=90,
+ cmap="seismic",
+ contours_every=15,
+ alpha=0.4,
+ axes=None,
+ **kwargs,
+):
+ """Plot latitudes.
+
+ Args:
+ latitude (List[torch.Tensor]): List of latitudes.
+ is_radians (bool, optional): Whether the latitudes are in radians. Defaults to True.
+ vmin (int, optional): Min value to clip to. Defaults to -90.
+ vmax (int, optional): Max value to clip to. Defaults to 90.
+ cmap (str, optional): Colormap. Defaults to "seismic".
+ contours_every (int, optional): Contours every. Defaults to 15.
+ alpha (float, optional): Alpha value. Defaults to 0.4.
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
+
+ Returns:
+ List[plt.Artist]: List of artists.
+ """
+ if axes is None:
+ axes = plt.gcf().axes
+
+ assert len(axes) == len(latitude), f"{len(axes)}, {len(latitude)}"
+ lat = [rad2deg(lat) for lat in latitude] if is_radians else latitude
+ return plot_heatmaps(
+ lat,
+ vmin=vmin,
+ vmax=vmax,
+ cmap=cmap,
+ a=alpha,
+ axes=axes,
+ contours_every=contours_every,
+ **kwargs,
+ )
+
+
+def plot_perspective_fields(cameras, gravities, axes=None, **kwargs):
+ """Plot perspective fields.
+
+ Args:
+ cameras (List[Camera]): List of cameras.
+ gravities (List[Gravity]): List of gravities.
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
+
+ Returns:
+ List[plt.Artist]: List of artists.
+ """
+ if axes is None:
+ axes = plt.gcf().axes
+
+ assert len(axes) == len(cameras), f"{len(axes)}, {len(cameras)}"
+
+ artists = []
+ for i in range(len(axes)):
+ up, lat = get_perspective_field(cameras[i], gravities[i])
+ artists += plot_vector_fields([up[0]], axes=[axes[i]], **kwargs)
+ artists += plot_latitudes([lat[0, 0]], axes=[axes[i]], **kwargs)
+
+ return artists
+
+
+def plot_confidences(
+ confidence,
+ as_log=True,
+ vmin=-4,
+ vmax=0,
+ cmap="turbo",
+ alpha=0.4,
+ axes=None,
+ **kwargs,
+):
+ """Plot confidences.
+
+ Args:
+ confidence (List[torch.Tensor]): Confidence maps.
+ as_log (bool, optional): Whether to plot in log scale. Defaults to True.
+ vmin (int, optional): Min value to clip to. Defaults to -4.
+ vmax (int, optional): Max value to clip to. Defaults to 0.
+ cmap (str, optional): Colormap. Defaults to "turbo".
+ alpha (float, optional): Alpha value. Defaults to 0.4.
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
+
+ Returns:
+ List[plt.Artist]: List of artists.
+ """
+ if axes is None:
+ axes = plt.gcf().axes
+
+ assert len(axes) == len(confidence), f"{len(axes)}, {len(confidence)}"
+
+ if as_log:
+ confidence = [torch.log10(c.clip(1e-5)).clip(vmin, vmax) for c in confidence]
+
+ # normalize to [0, 1]
+ confidence = [(c - c.min()) / (c.max() - c.min()) for c in confidence]
+ return plot_heatmaps(confidence, vmin=0, vmax=1, cmap=cmap, a=alpha, axes=axes, **kwargs)
+
+
+def save_plot(path, **kw):
+ """Save the current figure without any white margin."""
+ plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
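+
+
+# Minimal usage sketch (illustrative only; it mirrors how gradio_app.py composes these
+# helpers and assumes `plot_images` from this module plus a calibrated camera/gravity pair):
+#
+#   fig = plot_images([image])  # image as an HxWx3 array in [0, 1]
+#   up, lat = get_perspective_field(camera, gravity)
+#   plot_vector_fields([up[0]], axes=fig.get_axes())
+#   plot_latitudes([lat[0, 0]], axes=fig.get_axes())
+#   save_plot("perspective_field.jpg")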
diff --git a/gradio_app.py b/gradio_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..11ba5653adf5bccc464e8334634b008e0ede1074
--- /dev/null
+++ b/gradio_app.py
@@ -0,0 +1,228 @@
+"""Gradio app for GeoCalib inference."""
+
+from copy import deepcopy
+from time import time
+
+import gradio as gr
+import numpy as np
+import spaces
+import torch
+
+from geocalib import viz2d
+from geocalib.camera import camera_models
+from geocalib.extractor import GeoCalib
+from geocalib.perspective_fields import get_perspective_field
+from geocalib.utils import rad2deg
+
+# flake8: noqa
+# mypy: ignore-errors
+
+description = """
+
+
+ GeoCalib 📸
+ Single-image Calibration with Geometric Optimization
+
+ Alexander Veicht
+ ·
+ Paul-Edouard Sarlin
+ ·
+ Philipp Lindenberger
+ ·
+ Marc Pollefeys
+
+
+
+ ECCV 2024
+ Paper |
+ Code |
+ Colab
+
+
+
+## Getting Started
+GeoCalib accurately estimates the camera intrinsics and gravity direction from a single image by
+combining geometric optimization with deep learning.
+
+To get started, upload an image or select one of the examples below.
+You can choose between different camera models and visualize the calibration results.
+
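+For programmatic use, the same pipeline is available through the Python API that powers
+this demo (a minimal sketch; the image path is one of the bundled example assets):
+
+```python
+from geocalib import GeoCalib
+
+model = GeoCalib()
+results = model.calibrate(model.load_image("assets/pinhole-church.jpg"))
+print(results["camera"], results["gravity"])
+```
+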
+"""
+
+example_images = [
+ ["assets/pinhole-church.jpg"],
+ ["assets/pinhole-garden.jpg"],
+ ["assets/fisheye-skyline.jpg"],
+ ["assets/fisheye-dog-pool.jpg"],
+]
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = GeoCalib().to(device)
+
+
+def format_output(results):
+ camera, gravity = results["camera"], results["gravity"]
+ vfov = rad2deg(camera.vfov)
+ roll, pitch = rad2deg(gravity.rp).unbind(-1)
+
+ txt = "Estimated parameters:\n"
+ txt += f"Roll: {roll.item():.2f}° (± {rad2deg(results['roll_uncertainty']).item():.2f})°\n"
+ txt += f"Pitch: {pitch.item():.2f}° (± {rad2deg(results['pitch_uncertainty']).item():.2f})°\n"
+ txt += f"vFoV: {vfov.item():.2f}° (± {rad2deg(results['vfov_uncertainty']).item():.2f})°\n"
+ txt += (
+ f"Focal: {camera.f[0, 1].item():.2f} px (± {results['focal_uncertainty'].item():.2f} px)\n"
+ )
+ if hasattr(camera, "k1"):
+ txt += f"K1: {camera.k1[0].item():.2f}\n"
+ return txt
+
+
+@spaces.GPU(duration=10)
+def inference(img, camera_model):
+ out = model.calibrate(img.to(device), camera_model=camera_model)
+ save_keys = ["camera", "gravity"] + [f"{k}_uncertainty" for k in ["roll", "pitch", "vfov"]]
+ res = {k: v.cpu() for k, v in out.items() if k in save_keys}
+ # not converting to numpy results in gpu abort
+ res["up_confidence"] = out["up_confidence"].cpu().numpy()
+ res["latitude_confidence"] = out["latitude_confidence"].cpu().numpy()
+ return res
+
+
+def process_results(
+ image_path,
+ camera_model,
+ plot_up,
+ plot_up_confidence,
+ plot_latitude,
+ plot_latitude_confidence,
+ plot_undistort,
+):
+ """Process the image and return the calibration results."""
+
+ if image_path is None:
+ raise gr.Error("Please upload an image first.")
+
+ img = model.load_image(image_path)
+ print("Running inference...")
+ start = time()
+ inference_result = inference(img, camera_model)
+ print(f"Done ({time() - start:.2f}s)")
+ inference_result["image"] = img.cpu()
+
+ if inference_result is None:
+ return ("", np.ones((128, 256, 3)), None)
+
+ plot_img = update_plot(
+ inference_result,
+ plot_up,
+ plot_up_confidence,
+ plot_latitude,
+ plot_latitude_confidence,
+ plot_undistort,
+ )
+
+ return format_output(inference_result), plot_img, inference_result
+
+
+def update_plot(
+ inference_result,
+ plot_up,
+ plot_up_confidence,
+ plot_latitude,
+ plot_latitude_confidence,
+ plot_undistort,
+):
+ """Update the plot based on the selected options."""
+ if inference_result is None:
+        gr.Warning("Please calibrate an image first.")
+ return np.ones((128, 256, 3))
+
+ camera, gravity = inference_result["camera"], inference_result["gravity"]
+ img = inference_result["image"].permute(1, 2, 0).numpy()
+
+ if plot_undistort:
+ if not hasattr(camera, "k1"):
+ return img
+
+ return camera.undistort_image(inference_result["image"][None])[0].permute(1, 2, 0).numpy()
+
+ up, lat = get_perspective_field(camera, gravity)
+
+ fig = viz2d.plot_images([img], pad=0)
+ ax = fig.get_axes()
+
+ if plot_up:
+ viz2d.plot_vector_fields([up[0]], axes=[ax[0]])
+
+ if plot_latitude:
+ viz2d.plot_latitudes([lat[0, 0]], axes=[ax[0]])
+
+ if plot_up_confidence:
+ viz2d.plot_confidences([inference_result["up_confidence"][0]], axes=[ax[0]])
+
+ if plot_latitude_confidence:
+ viz2d.plot_confidences([inference_result["latitude_confidence"][0]], axes=[ax[0]])
+
+ fig.canvas.draw()
+ img = np.array(fig.canvas.renderer.buffer_rgba())
+
+ return img
+
+
+# Create the Gradio interface
+with gr.Blocks() as demo:
+ gr.Markdown(description)
+ with gr.Row():
+ with gr.Column():
+ gr.Markdown("""## Input Image""")
+ image_path = gr.Image(label="Upload image to calibrate", type="filepath")
+ choice_input = gr.Dropdown(
+ choices=list(camera_models.keys()), label="Choose a camera model.", value="pinhole"
+ )
+ submit_btn = gr.Button("Calibrate 📸")
+ gr.Examples(examples=example_images, inputs=[image_path, choice_input])
+
+ with gr.Column():
+ gr.Markdown("""## Results""")
+ image_output = gr.Image(label="Calibration Results")
+ gr.Markdown("### Plot Options")
+ plot_undistort = gr.Checkbox(
+ label="undistort",
+ value=False,
+ info="Undistorted image "
+ + "(this is only available for models with distortion "
+ + "parameters and will overwrite other options).",
+ )
+
+ with gr.Row():
+ plot_up = gr.Checkbox(label="up-vectors", value=True)
+ plot_up_confidence = gr.Checkbox(label="up confidence", value=False)
+ plot_latitude = gr.Checkbox(label="latitude", value=True)
+ plot_latitude_confidence = gr.Checkbox(label="latitude confidence", value=False)
+
+ gr.Markdown("### Calibration Results")
+ text_output = gr.Textbox(label="Estimated parameters", type="text", lines=5)
+
+ # Define the action when the button is clicked
+ inference_state = gr.State()
+ plot_inputs = [
+ inference_state,
+ plot_up,
+ plot_up_confidence,
+ plot_latitude,
+ plot_latitude_confidence,
+ plot_undistort,
+ ]
+ submit_btn.click(
+ fn=process_results,
+ inputs=[image_path, choice_input] + plot_inputs[1:],
+ outputs=[text_output, image_output, inference_state],
+ )
+
+ # Define the action when the plot checkboxes are clicked
+ plot_up.change(fn=update_plot, inputs=plot_inputs, outputs=image_output)
+ plot_up_confidence.change(fn=update_plot, inputs=plot_inputs, outputs=image_output)
+ plot_latitude.change(fn=update_plot, inputs=plot_inputs, outputs=image_output)
+ plot_latitude_confidence.change(fn=update_plot, inputs=plot_inputs, outputs=image_output)
+ plot_undistort.change(fn=update_plot, inputs=plot_inputs, outputs=image_output)
+
+
+# Launch the app
+demo.launch()
diff --git a/hubconf.py b/hubconf.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1ad63268c53133b227bdf5cfec046bd81f7ecea
--- /dev/null
+++ b/hubconf.py
@@ -0,0 +1,14 @@
+"""Entrypoint for torch hub."""
+
+dependencies = ["torch", "torchvision", "opencv-python", "kornia", "matplotlib"]
+
+from geocalib import GeoCalib
+
+
+def model(*args, **kwargs):
+ """Pre-trained Geocalib model.
+
+ Args:
+ weights (str): trained variant, "pinhole" (default) or "distorted".
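+
+    Example (a minimal sketch; it assumes the repository is published on GitHub as
+    "cvg/GeoCalib", the URL listed in pyproject.toml):
+
+        >>> import torch
+        >>> calib = torch.hub.load("cvg/GeoCalib", "model", weights="pinhole")
+        >>> results = calib.calibrate(calib.load_image("path/to/image.jpg"))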
+ """
+ return GeoCalib(*args, **kwargs)
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..10b2c1c5800f42c9067b194bc9a1228d832ac4db
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,49 @@
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "geocalib"
+version = "1.0"
+description = "GeoCalib Inference Package"
+authors = [
+ { name = "Alexander Veicht" },
+ { name = "Paul-Edouard Sarlin" },
+ { name = "Philipp Lindenberger" },
+]
+readme = "README.md"
+requires-python = ">=3.9"
+license = { file = "LICENSE" }
+classifiers = [
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+ "Operating System :: OS Independent",
+]
+urls = { Repository = "https://github.com/cvg/GeoCalib" }
+
+dynamic = ["dependencies"]
+
+[project.optional-dependencies]
+dev = ["black==23.9.1", "flake8", "isort==5.12.0"]
+
+[tool.setuptools]
+packages = ["geocalib"]
+
+[tool.setuptools.dynamic]
+dependencies = { file = ["requirements.txt"] }
+
+
+[tool.black]
+line-length = 100
+exclude = "(venv/|docs/|third_party/)"
+
+[tool.isort]
+profile = "black"
+line_length = 100
+atomic = true
+
+[tool.flake8]
+max-line-length = 100
+docstring-convention = "google"
+ignore = ["E203", "W503", "E402"]
+exclude = [".git", "__pycache__", "venv", "docs", "third_party", "scripts"]
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..ea45682e0ba9634fb03cbca9df33eae28a2c6842
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,5 @@
+torch
+torchvision
+opencv-python
+kornia
+matplotlib
diff --git a/siclib/LICENSE b/siclib/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..797795d81ba8a5d06fefa772bc5b4d0b4bb94dc4
--- /dev/null
+++ b/siclib/LICENSE
@@ -0,0 +1,190 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ Copyright 2024 ETH Zurich
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/siclib/__init__.py b/siclib/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0944cfcf82a485c31ba77867ee3821ae59557ecb
--- /dev/null
+++ b/siclib/__init__.py
@@ -0,0 +1,15 @@
+import logging
+
+formatter = logging.Formatter(
+ fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
+)
+handler = logging.StreamHandler()
+handler.setFormatter(formatter)
+handler.setLevel(logging.INFO)
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.INFO)
+logger.addHandler(handler)
+logger.propagate = False
+
+__module_name__ = __name__
diff --git a/siclib/configs/deepcalib.yaml b/siclib/configs/deepcalib.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..994f19c929d55dffb2cac9d9f972b480457a607a
--- /dev/null
+++ b/siclib/configs/deepcalib.yaml
@@ -0,0 +1,12 @@
+defaults:
+ - data: openpano-radial
+ - train: deepcalib
+ - model: deepcalib
+ - _self_
+
+data:
+ train_batch_size: 32
+ val_batch_size: 32
+ test_batch_size: 32
+ augmentations:
+ name: "deepcalib"
diff --git a/siclib/configs/geocalib-radial.yaml b/siclib/configs/geocalib-radial.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4b492f9aac6f19cd1bb59c2cf38eeeaa007d1156
--- /dev/null
+++ b/siclib/configs/geocalib-radial.yaml
@@ -0,0 +1,38 @@
+defaults:
+ - data: openpano-radial
+ - train: geocalib
+ - model: geocalib
+ - _self_
+
+data:
+ # smaller batch size since lm takes more memory
+ train_batch_size: 18
+ val_batch_size: 18
+ test_batch_size: 18
+
+model:
+ optimizer:
+ camera_model: simple_radial
+
+ weights: weights/geocalib.tar
+
+train:
+ lr: 1e-5 # smaller lr since we are fine-tuning
+ num_steps: 200_000 # adapt to see same number of samples as previous training
+
+ lr_schedule:
+ type: SequentialLR
+ on_epoch: false
+ options:
+ # adapt to see same number of samples as previous training
+ milestones: [5_000]
+ schedulers:
+ - type: LinearLR
+ options:
+ start_factor: 1e-3
+ total_iters: 5_000
+ - type: MultiStepLR
+ options:
+ gamma: 0.1
+ # adapt to see same number of samples as previous training
+ milestones: [110_000, 170_000]
diff --git a/siclib/configs/geocalib.yaml b/siclib/configs/geocalib.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f965a269673bbe907ca60387e7458c7df3c5b4fc
--- /dev/null
+++ b/siclib/configs/geocalib.yaml
@@ -0,0 +1,10 @@
+defaults:
+ - data: openpano
+ - train: geocalib
+ - model: geocalib
+ - _self_
+
+data:
+ train_batch_size: 24
+ val_batch_size: 24
+ test_batch_size: 24
diff --git a/siclib/configs/model/deepcalib.yaml b/siclib/configs/model/deepcalib.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2d4d0670cfb5d1742029ba3dc32b81eb56b1b57e
--- /dev/null
+++ b/siclib/configs/model/deepcalib.yaml
@@ -0,0 +1,7 @@
+name: networks.deepcalib
+bounds:
+ roll: [-45, 45]
+ # rho = torch.tan(pitch) / torch.tan(vfov / 2) / 2 -> rho in [-1/0.3526, 1/0.0872]
+ rho: [-2.83607487, 2.83607487]
+ vfov: [20, 105]
+ k1_hat: [-0.7, 0.7]
diff --git a/siclib/configs/model/geocalib.yaml b/siclib/configs/model/geocalib.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..42c06f44b522c47ce3127a4a4bd4ac85d5686768
--- /dev/null
+++ b/siclib/configs/model/geocalib.yaml
@@ -0,0 +1,31 @@
+name: networks.geocalib
+
+ll_enc:
+ name: encoders.low_level_encoder
+
+backbone:
+ name: encoders.mscan
+ weights: weights/mscan_b.pth
+
+perspective_decoder:
+ name: decoders.perspective_decoder
+
+ up_decoder:
+ name: decoders.up_decoder
+ loss_type: l1
+ use_uncertainty_loss: true
+ decoder:
+ name: decoders.light_hamburger
+ predict_uncertainty: true
+
+ latitude_decoder:
+ name: decoders.latitude_decoder
+ loss_type: l1
+ use_uncertainty_loss: true
+ decoder:
+ name: decoders.light_hamburger
+ predict_uncertainty: true
+
+optimizer:
+ name: optimization.lm_optimizer
+ camera_model: pinhole
diff --git a/siclib/configs/train/deepcalib.yaml b/siclib/configs/train/deepcalib.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f4bd79685287ab0d926c11a709e98a466705b86c
--- /dev/null
+++ b/siclib/configs/train/deepcalib.yaml
@@ -0,0 +1,22 @@
+seed: 0
+num_steps: 20_000
+log_every_iter: 500
+eval_every_iter: 3000
+test_every_epoch: 1
+writer: null
+lr: 1.0e-4
+clip_grad: 1.0
+lr_schedule:
+ type: null
+optimizer: adam
+submodules: []
+median_metrics:
+ - roll_error
+ - pitch_error
+ - vfov_error
+recall_metrics:
+ roll_error: [1, 5, 10]
+ pitch_error: [1, 5, 10]
+ vfov_error: [1, 5, 10]
+
+plot: [3, "siclib.visualization.visualize_batch.make_perspective_figures"]
diff --git a/siclib/configs/train/geocalib.yaml b/siclib/configs/train/geocalib.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..966f227f685246f7c932f40046c423b0fc6c7bed
--- /dev/null
+++ b/siclib/configs/train/geocalib.yaml
@@ -0,0 +1,50 @@
+seed: 0
+num_steps: 150_000
+
+writer: null
+log_every_iter: 500
+eval_every_iter: 1000
+
+lr: 1e-4
+optimizer: adamw
+clip_grad: 1.0
+best_key: loss/param_total
+
+lr_schedule:
+ type: SequentialLR
+ on_epoch: false
+ options:
+ milestones: [4_000]
+ schedulers:
+ - type: LinearLR
+ options:
+ start_factor: 1e-3
+ total_iters: 4_000
+ - type: MultiStepLR
+ options:
+ gamma: 0.1
+ milestones: [80_000, 130_000]
+
+submodules: []
+
+median_metrics:
+ - roll_error
+ - pitch_error
+ - gravity_error
+ - vfov_error
+ - up_angle_error
+ - latitude_angle_error
+ - up_angle_recall@1
+ - up_angle_recall@5
+ - up_angle_recall@10
+ - latitude_angle_recall@1
+ - latitude_angle_recall@5
+ - latitude_angle_recall@10
+
+recall_metrics:
+ roll_error: [1, 3, 5, 10]
+ pitch_error: [1, 3, 5, 10]
+ gravity_error: [1, 3, 5, 10]
+ vfov_error: [1, 3, 5, 10]
+
+plot: [3, "siclib.visualization.visualize_batch.make_perspective_figures"]
diff --git a/siclib/datasets/__init__.py b/siclib/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f660a54e7a1174e83d999cd90c25ce85ac7294b
--- /dev/null
+++ b/siclib/datasets/__init__.py
@@ -0,0 +1,25 @@
+import importlib.util
+
+from siclib.datasets.base_dataset import BaseDataset
+from siclib.utils.tools import get_class
+
+
+def get_dataset(name):
+    """Resolve a dataset class by name, trying absolute and package-relative import paths."""
+ import_paths = [name, f"{__name__}.{name}"]
+ for path in import_paths:
+ try:
+ spec = importlib.util.find_spec(path)
+ except ModuleNotFoundError:
+ spec = None
+ if spec is not None:
+ try:
+ return get_class(path, BaseDataset)
+ except AssertionError:
+ mod = __import__(path, fromlist=[""])
+ try:
+ return mod.__main_dataset__
+ except AttributeError as exc:
+ print(exc)
+ continue
+
+ raise RuntimeError(f'Dataset {name} not found in any of [{" ".join(import_paths)}]')
diff --git a/siclib/datasets/augmentations.py b/siclib/datasets/augmentations.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f203f301d5760763a397e9666c0af80ef39a776
--- /dev/null
+++ b/siclib/datasets/augmentations.py
@@ -0,0 +1,359 @@
+from typing import Union
+
+import albumentations as A
+import cv2
+import numpy as np
+import torch
+from albumentations.pytorch.transforms import ToTensorV2
+from omegaconf import OmegaConf
+
+
+# flake8: noqa
+# mypy: ignore-errors
+class IdentityTransform(A.ImageOnlyTransform):
+ def apply(self, img, **params):
+ return img
+
+ def get_transform_init_args_names(self):
+ return ()
+
+
+class RandomAdditiveShade(A.ImageOnlyTransform):
+ def __init__(
+ self,
+ nb_ellipses=10,
+ transparency_limit=[-0.5, 0.8],
+ kernel_size_limit=[150, 350],
+ always_apply=False,
+ p=0.5,
+ ):
+ super().__init__(always_apply, p)
+ self.nb_ellipses = nb_ellipses
+ self.transparency_limit = transparency_limit
+ self.kernel_size_limit = kernel_size_limit
+
+ def apply(self, img, **params):
+ if img.dtype == np.float32:
+ shaded = self._py_additive_shade(img * 255.0)
+ shaded /= 255.0
+ elif img.dtype == np.uint8:
+ shaded = self._py_additive_shade(img.astype(np.float32))
+ shaded = shaded.astype(np.uint8)
+ else:
+ raise NotImplementedError(f"Data augmentation not available for type: {img.dtype}")
+ return shaded
+
+ def _py_additive_shade(self, img):
+ grayscale = len(img.shape) == 2
+ if grayscale:
+ img = img[None]
+ min_dim = min(img.shape[:2]) / 4
+ mask = np.zeros(img.shape[:2], img.dtype)
+ for i in range(self.nb_ellipses):
+ ax = int(max(np.random.rand() * min_dim, min_dim / 5))
+ ay = int(max(np.random.rand() * min_dim, min_dim / 5))
+ max_rad = max(ax, ay)
+ x = np.random.randint(max_rad, img.shape[1] - max_rad) # center
+ y = np.random.randint(max_rad, img.shape[0] - max_rad)
+ angle = np.random.rand() * 90
+ cv2.ellipse(mask, (x, y), (ax, ay), angle, 0, 360, 255, -1)
+
+ transparency = np.random.uniform(*self.transparency_limit)
+ ks = np.random.randint(*self.kernel_size_limit)
+ if (ks % 2) == 0: # kernel_size has to be odd
+ ks += 1
+ mask = cv2.GaussianBlur(mask.astype(np.float32), (ks, ks), 0)
+ shaded = img * (1 - transparency * mask[..., np.newaxis] / 255.0)
+ out = np.clip(shaded, 0, 255)
+ if grayscale:
+ out = out.squeeze(0)
+ return out
+
+ def get_transform_init_args_names(self):
+ return "transparency_limit", "kernel_size_limit", "nb_ellipses"
+
+
+def kw(entry: Union[float, dict], n=None, **default):
+ if not isinstance(entry, dict):
+ entry = {"p": entry}
+ entry = OmegaConf.create(entry)
+ if n is not None:
+ entry = default.get(n, entry)
+ return OmegaConf.merge(default, entry)
+
+
+def kwi(entry: Union[float, dict], n=None, **default):
+ conf = kw(entry, n=n, **default)
+ return {k: conf[k] for k in set(default.keys()).union(set(["p"]))}
+
+
+def replay_str(transforms, s="Replay:\n", log_inactive=True):
+ for t in transforms:
+ if "transforms" in t.keys():
+ s = replay_str(t["transforms"], s=s)
+ elif t["applied"] or log_inactive:
+ s += t["__class_fullname__"] + " " + str(t["applied"]) + "\n"
+ return s
+
+
+class BaseAugmentation(object):
+ base_default_conf = {
+ "name": "???",
+ "shuffle": False,
+ "p": 1.0,
+ "verbose": False,
+ "dtype": "uint8", # (byte, float)
+ }
+
+ default_conf = {}
+
+ def __init__(self, conf={}):
+ """Perform some logic and call the _init method of the child model."""
+ default_conf = OmegaConf.merge(
+ OmegaConf.create(self.base_default_conf),
+ OmegaConf.create(self.default_conf),
+ )
+ OmegaConf.set_struct(default_conf, True)
+ if isinstance(conf, dict):
+ conf = OmegaConf.create(conf)
+ self.conf = OmegaConf.merge(default_conf, conf)
+ OmegaConf.set_readonly(self.conf, True)
+ self._init(self.conf)
+
+ self.conf = OmegaConf.merge(self.conf, conf)
+ if self.conf.verbose:
+ self.compose = A.ReplayCompose
+ else:
+ self.compose = A.Compose
+ if self.conf.dtype == "uint8":
+ self.dtype = np.uint8
+ self.preprocess = A.FromFloat(always_apply=True, dtype="uint8")
+ self.postprocess = A.ToFloat(always_apply=True)
+ elif self.conf.dtype == "float32":
+ self.dtype = np.float32
+ self.preprocess = A.ToFloat(always_apply=True)
+ self.postprocess = IdentityTransform()
+ else:
+ raise ValueError(f"Unsupported dtype {self.conf.dtype}")
+ self.to_tensor = ToTensorV2()
+
+ def _init(self, conf):
+ """Child class overwrites this, setting up a list of transforms"""
+ self.transforms = []
+
+ def __call__(self, image, return_tensor=False):
+ """image as HW or HWC"""
+ if isinstance(image, torch.Tensor):
+ image = image.cpu().numpy()
+ data = {"image": image}
+ if image.dtype != self.dtype:
+ data = self.preprocess(**data)
+ transforms = self.transforms
+ if self.conf.shuffle:
+ order = [i for i, _ in enumerate(transforms)]
+ np.random.shuffle(order)
+ transforms = [transforms[i] for i in order]
+ transformed = self.compose(transforms, p=self.conf.p)(**data)
+ if self.conf.verbose:
+ print(replay_str(transformed["replay"]["transforms"]))
+ transformed = self.postprocess(**transformed)
+ if return_tensor:
+ return self.to_tensor(**transformed)["image"]
+ else:
+ return transformed["image"]
+
+
+class IdentityAugmentation(BaseAugmentation):
+ default_conf = {}
+
+ def _init(self, conf):
+ self.transforms = [IdentityTransform(p=1.0)]
+
+
+class DarkAugmentation(BaseAugmentation):
+ default_conf = {"p": 0.75}
+
+ def _init(self, conf):
+ bright_contr = 0.5
+ blur = 0.1
+ random_gamma = 0.1
+ hue = 0.1
+ self.transforms = [
+ A.RandomRain(p=0.2),
+ A.RandomBrightnessContrast(
+ **kw(
+ bright_contr,
+ brightness_limit=(-0.4, 0.0),
+ contrast_limit=(-0.3, 0.0),
+ )
+ ),
+ A.OneOf(
+ [
+ A.Blur(**kwi(blur, p=0.1, blur_limit=(3, 9), n="blur")),
+ A.MotionBlur(**kwi(blur, p=0.2, blur_limit=(3, 25), n="motion_blur")),
+ A.ISONoise(),
+ A.ImageCompression(),
+ ],
+ **kwi(blur, p=0.1),
+ ),
+ A.RandomGamma(**kw(random_gamma, gamma_limit=(15, 65))),
+ A.OneOf(
+ [
+ A.Equalize(),
+ A.CLAHE(p=0.2),
+ A.ToGray(),
+ A.ToSepia(p=0.1),
+ A.HueSaturationValue(**kw(hue, val_shift_limit=(-100, -40))),
+ ],
+ p=0.5,
+ ),
+ ]
+
+
+class DefaultAugmentation(BaseAugmentation):
+ default_conf = {"p": 1.0}
+
+ def _init(self, conf):
+ self.transforms = [
+ A.RandomBrightnessContrast(p=0.2),
+ A.HueSaturationValue(p=0.2),
+ A.ToGray(p=0.2),
+ A.ImageCompression(quality_lower=30, quality_upper=100, p=0.5),
+ A.OneOf(
+ [
+ A.MotionBlur(p=0.2),
+ A.MedianBlur(blur_limit=3, p=0.1),
+ A.Blur(blur_limit=3, p=0.1),
+ ],
+ p=0.2,
+ ),
+ ]
+
+
+class PerspectiveAugmentation(BaseAugmentation):
+ default_conf = {"p": 1.0}
+
+ def _init(self, conf):
+ self.transforms = [
+ A.RandomBrightnessContrast(p=0.2),
+ A.HueSaturationValue(p=0.2),
+ A.ToGray(p=0.2),
+ A.ImageCompression(quality_lower=30, quality_upper=100, p=0.5),
+ A.OneOf(
+ [
+ A.MotionBlur(p=0.2),
+ A.MedianBlur(blur_limit=3, p=0.1),
+ A.Blur(blur_limit=3, p=0.1),
+ ],
+ p=0.2,
+ ),
+ ]
+
+
+class DeepCalibAugmentations(BaseAugmentation):
+ default_conf = {"p": 1.0}
+
+ def _init(self, conf):
+ self.transforms = [
+ A.RandomBrightnessContrast(p=0.5),
+ A.GaussNoise(var_limit=(5.0, 112.0), mean=0, per_channel=True, p=0.75),
+ A.Downscale(
+ scale_min=0.5,
+ scale_max=0.95,
+ interpolation=dict(downscale=cv2.INTER_AREA, upscale=cv2.INTER_LINEAR),
+ p=0.5,
+ ),
+ A.Downscale(scale_min=0.5, scale_max=0.95, interpolation=cv2.INTER_LINEAR, p=0.5),
+ A.ImageCompression(quality_lower=20, quality_upper=85, p=1, always_apply=True),
+ A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.4),
+ A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.5),
+ A.ToGray(always_apply=False, p=0.2),
+ A.GaussianBlur(blur_limit=(3, 5), sigma_limit=0, p=0.25),
+ A.MotionBlur(blur_limit=5, allow_shifted=True, p=0.25),
+ A.MultiplicativeNoise(multiplier=[0.85, 1.15], elementwise=True, p=0.5),
+ ]
+
+
+class GeoCalibAugmentations(BaseAugmentation):
+ default_conf = {"p": 1.0}
+
+ def _init(self, conf):
+ self.color_transforms = [
+ A.RandomGamma(gamma_limit=(80, 180), p=0.8),
+ A.RandomToneCurve(scale=0.1, p=0.5),
+ A.RandomBrightnessContrast(p=0.5),
+ A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.4),
+ A.OneOf([A.ToGray(p=0.1), A.ToSepia(p=0.1), IdentityTransform(p=0.8)], p=1),
+ ]
+
+ self.noise_transforms = [
+ A.GaussNoise(var_limit=(5.0, 112.0), mean=0, per_channel=True, p=0.75),
+ A.ImageCompression(quality_lower=20, quality_upper=100, p=1),
+ A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=0.5),
+ A.OneOrOther(
+ first=A.Compose(
+ [
+ A.AdvancedBlur(
+ p=1,
+ blur_limit=(3, 7),
+ sigmaX_limit=(0.2, 1.0),
+ sigmaY_limit=(0.2, 1.0),
+ rotate_limit=(-90, 90),
+ beta_limit=(0.5, 8.0),
+ noise_limit=(0.9, 1.1),
+ ),
+ A.Sharpen(p=0.5, alpha=(0.2, 0.5), lightness=(0.5, 1.0)),
+ ]
+ ),
+ second=A.Compose(
+ [
+ A.Sharpen(p=0.5, alpha=(0.2, 0.5), lightness=(0.5, 1.0)),
+ A.AdvancedBlur(
+ p=1,
+ blur_limit=(3, 7),
+ sigmaX_limit=(0.2, 1.0),
+ sigmaY_limit=(0.2, 1.0),
+ rotate_limit=(-90, 90),
+ beta_limit=(0.5, 8.0),
+ noise_limit=(0.9, 1.1),
+ ),
+ ]
+ ),
+ ),
+ ]
+
+ self.image_transforms = [
+ A.OneOf(
+ [
+ A.Downscale(
+ scale_min=0.5,
+ scale_max=0.99,
+ interpolation=dict(downscale=down, upscale=up),
+ p=1,
+ )
+ for down, up in [
+ (cv2.INTER_AREA, cv2.INTER_LINEAR),
+ (cv2.INTER_LINEAR, cv2.INTER_CUBIC),
+ (cv2.INTER_CUBIC, cv2.INTER_LINEAR),
+ (cv2.INTER_LINEAR, cv2.INTER_AREA),
+ ]
+ ],
+ p=1,
+ )
+ ]
+
+ self.transforms = [
+ *self.color_transforms,
+ *self.noise_transforms,
+ *self.image_transforms,
+ ]
+
+
+augmentations = {
+ "default": DefaultAugmentation,
+ "dark": DarkAugmentation,
+ "perspective": PerspectiveAugmentation,
+ "deepcalib": DeepCalibAugmentations,
+ "geocalib": GeoCalibAugmentations,
+ "identity": IdentityAugmentation,
+}
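+
+
+# Minimal usage sketch (illustrative only): an augmentation is typically selected by name
+# from the data config (e.g. `augmentations: {name: "geocalib"}` in the dataset yaml):
+#
+#   aug = augmentations["geocalib"]({"name": "geocalib"})
+#   augmented = aug(image, return_tensor=True)  # image as an HW or HWC uint8/float array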
diff --git a/siclib/datasets/base_dataset.py b/siclib/datasets/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..7491b20f4f77338e5800b00f2624c84bb1ca35fb
--- /dev/null
+++ b/siclib/datasets/base_dataset.py
@@ -0,0 +1,218 @@
+"""Base class for dataset.
+
+See mnist.py for an example of dataset.
+"""
+
+import collections
+import logging
+from abc import ABCMeta, abstractmethod
+
+import omegaconf
+import torch
+from omegaconf import OmegaConf
+from torch.utils.data import DataLoader, Sampler, get_worker_info
+from torch.utils.data._utils.collate import default_collate_err_msg_format, np_str_obj_array_pattern
+
+from siclib.utils.tensor import string_classes
+from siclib.utils.tools import set_num_threads, set_seed
+
+logger = logging.getLogger(__name__)
+
+# mypy: ignore-errors
+
+
+class LoopSampler(Sampler):
+ """Infinite sampler that loops over a given number of elements."""
+
+ def __init__(self, loop_size: int, total_size: int = None):
+ """Initialize the sampler.
+
+ Args:
+ loop_size (int): Number of elements to loop over.
+ total_size (int, optional): Total number of elements. Defaults to None.
+ """
+ self.loop_size = loop_size
+ self.total_size = total_size - (total_size % loop_size)
+
+ def __iter__(self):
+ """Return an iterator over the elements."""
+ return (i % self.loop_size for i in range(self.total_size))
+
+ def __len__(self):
+ """Return the number of elements."""
+ return self.total_size
+
+
+def worker_init_fn(i):
+ """Initialize the workers with a different seed."""
+ info = get_worker_info()
+ if hasattr(info.dataset, "conf"):
+ conf = info.dataset.conf
+ set_seed(info.id + conf.seed)
+ set_num_threads(conf.num_threads)
+ else:
+ set_num_threads(1)
+
+
+def collate(batch):
+ """Difference with PyTorch default_collate: it can stack of other objects."""
+ if not isinstance(batch, list): # no batching
+ return batch
+ elem = batch[0]
+ elem_type = type(elem)
+ if isinstance(elem, torch.Tensor):
+ # out = None
+ if torch.utils.data.get_worker_info() is not None:
+ # If we're in a background process, concatenate directly into a
+ # shared memory tensor to avoid an extra copy
+ numel = sum([x.numel() for x in batch])
+ try:
+ _ = elem.untyped_storage()._new_shared(numel)
+ except AttributeError:
+ _ = elem.storage()._new_shared(numel)
+ return torch.stack(batch, dim=0)
+ elif (
+ elem_type.__module__ == "numpy"
+ and elem_type.__name__ != "str_"
+ and elem_type.__name__ != "string_"
+ ):
+ if elem_type.__name__ in ["ndarray", "memmap"]:
+ # array of string classes and object
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+ return collate([torch.as_tensor(b) for b in batch])
+ elif elem.shape == (): # scalars
+ return torch.as_tensor(batch)
+ elif isinstance(elem, float):
+ return torch.tensor(batch, dtype=torch.float64)
+ elif isinstance(elem, int):
+ return torch.tensor(batch)
+ elif isinstance(elem, string_classes):
+ return batch
+ elif isinstance(elem, collections.abc.Mapping):
+ return {key: collate([d[key] for d in batch]) for key in elem}
+ elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple
+ return elem_type(*(collate(samples) for samples in zip(*batch)))
+ elif isinstance(elem, collections.abc.Sequence):
+ # check to make sure that the elements in batch have consistent size
+ it = iter(batch)
+ elem_size = len(next(it))
+ if any(len(elem) != elem_size for elem in it):
+ raise RuntimeError("each element in list of batch should be of equal size")
+ transposed = zip(*batch)
+ return [collate(samples) for samples in transposed]
+ elif elem is None:
+ return elem
+ else:
+ # try to stack anyway in case the object implements stacking.
+ return torch.stack(batch, 0)
+
+
+class BaseDataset(metaclass=ABCMeta):
+ """Base class for dataset.
+
+    What the dataset model is expected to declare:
+        default_conf: dictionary of the default configuration of the dataset.
+        It overwrites base_default_conf in BaseDataset, and it is overwritten by
+ the user-provided configuration passed to __init__.
+ Configurations can be nested.
+
+ _init(self, conf): initialization method, where conf is the final
+ configuration object (also accessible with `self.conf`). Accessing
+ unknown configuration entries will raise an error.
+
+ get_dataset(self, split): method that returns an instance of
+ torch.utils.data.Dataset corresponding to the requested split string,
+ which can be `'train'`, `'val'`, or `'test'`.
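+
+    A minimal sketch of a conforming dataset (the names below are illustrative and not
+    part of the library):
+
+        class DummyDataset(BaseDataset, torch.utils.data.Dataset):
+            default_conf = {"num_items": 8, "num_workers": 0, "batch_size": 2}
+
+            def _init(self, conf):
+                pass
+
+            def get_dataset(self, split):
+                return self
+
+            def __getitem__(self, idx):
+                return {"index": torch.tensor(idx)}
+
+            def __len__(self):
+                return self.conf.num_items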
+ """
+
+ base_default_conf = {
+ "name": "???",
+ "num_workers": "???",
+ "train_batch_size": "???",
+ "val_batch_size": "???",
+ "test_batch_size": "???",
+ "shuffle_training": True,
+ "batch_size": 1,
+ "num_threads": 1,
+ "seed": 0,
+ "prefetch_factor": 2,
+ }
+ default_conf = {}
+
+ def __init__(self, conf):
+ """Perform some logic and call the _init method of the child model."""
+ default_conf = OmegaConf.merge(
+ OmegaConf.create(self.base_default_conf),
+ OmegaConf.create(self.default_conf),
+ )
+ OmegaConf.set_struct(default_conf, True)
+ if isinstance(conf, dict):
+ conf = OmegaConf.create(conf)
+ self.conf = OmegaConf.merge(default_conf, conf)
+ OmegaConf.set_readonly(self.conf, True)
+ logger.info(f"Creating dataset {self.__class__.__name__}")
+ self._init(self.conf)
+
+ @abstractmethod
+ def _init(self, conf):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_dataset(self, split):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ def get_data_loader(self, split, shuffle=None, pinned=False, distributed=False):
+ """Return a data loader for a given split."""
+ assert split in ["train", "val", "test"]
+ dataset = self.get_dataset(split)
+ try:
+ batch_size = self.conf[f"{split}_batch_size"]
+ except omegaconf.MissingMandatoryValue:
+ batch_size = self.conf.batch_size
+ num_workers = self.conf.get("num_workers", batch_size)
+ if distributed:
+ shuffle = False
+ sampler = torch.utils.data.distributed.DistributedSampler(dataset)
+ else:
+ sampler = None
+ if shuffle is None:
+ shuffle = split == "train" and self.conf.shuffle_training
+ return DataLoader(
+ dataset,
+ batch_size=batch_size,
+ shuffle=shuffle,
+ sampler=sampler,
+ pin_memory=pinned,
+ collate_fn=collate,
+ num_workers=num_workers,
+ worker_init_fn=worker_init_fn,
+ prefetch_factor=self.conf.prefetch_factor,
+ )
+
+ def get_overfit_loader(self, split: str):
+ """Return an overfit data loader.
+
+ The training set is composed of a single duplicated batch, while
+ the validation and test sets contain a single copy of this same batch.
+ This is useful to debug a model and make sure that losses and metrics
+ correlate well.
+ """
+ assert split in {"train", "val", "test"}
+ dataset = self.get_dataset("train")
+ sampler = LoopSampler(
+ self.conf.batch_size,
+ len(dataset) if split == "train" else self.conf.batch_size,
+ )
+ num_workers = self.conf.get("num_workers", self.conf.batch_size)
+ return DataLoader(
+ dataset,
+ batch_size=self.conf.batch_size,
+ pin_memory=True,
+ num_workers=num_workers,
+ sampler=sampler,
+ worker_init_fn=worker_init_fn,
+ collate_fn=collate,
+ )
diff --git a/siclib/datasets/configs/openpano-radial.yaml b/siclib/datasets/configs/openpano-radial.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..717073904732c8bd9b2076b826f65664a0b9cd8f
--- /dev/null
+++ b/siclib/datasets/configs/openpano-radial.yaml
@@ -0,0 +1,41 @@
+name: openpano_radial
+base_dir: data/openpano
+pano_dir: "${.base_dir}/panoramas"
+images_per_pano: 16
+resize_factor: null
+n_workers: 1
+device: cpu
+overwrite: true
+parameter_dists:
+ roll:
+ type: uniform # uni[-45, 45]
+ options:
+ loc: -0.7853981633974483 # -45 degrees
+ scale: 1.5707963267948966 # 90 degrees
+ pitch:
+ type: uniform # uni[-45, 45]
+ options:
+ loc: -0.7853981633974483 # -45 degrees
+ scale: 1.5707963267948966 # 90 degrees
+ vfov:
+ type: uniform # uni[20, 105]
+ options:
+ loc: 0.3490658503988659 # 20 degrees
+ scale: 1.48352986419518 # 85 degrees
+ k1_hat:
+ type: truncnorm
+ options:
+ a: -4.285714285714286 # corresponds to -0.3
+ b: 4.285714285714286 # corresponds to 0.3
+ loc: 0
+ scale: 0.07
+ resize_factor:
+ type: uniform
+ options:
+ loc: 1.2
+ scale: 0.5
+ shape:
+ type: fix
+ value:
+ - 640
+ - 640
diff --git a/siclib/datasets/configs/openpano.yaml b/siclib/datasets/configs/openpano.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..07e55d33873eb8ab342f31664bb32ca895c4fce7
--- /dev/null
+++ b/siclib/datasets/configs/openpano.yaml
@@ -0,0 +1,34 @@
+name: openpano
+base_dir: data/openpano
+pano_dir: "${.base_dir}/panoramas"
+images_per_pano: 16
+resize_factor: null
+n_workers: 1
+device: cpu
+overwrite: true
+parameter_dists:
+ roll:
+ type: uniform # uni[-45, 45]
+ options:
+ loc: -0.7853981633974483 # -45 degrees
+ scale: 1.5707963267948966 # 90 degrees
+ pitch:
+ type: uniform # uni[-45, 45]
+ options:
+ loc: -0.7853981633974483 # -45 degrees
+ scale: 1.5707963267948966 # 90 degrees
+ vfov:
+ type: uniform # uni[20, 105]
+ options:
+ loc: 0.3490658503988659 # 20 degrees
+ scale: 1.48352986419518 # 85 degrees
+ resize_factor:
+ type: uniform
+ options:
+ loc: 1.2
+ scale: 0.5
+ shape:
+ type: fix
+ value:
+ - 640
+ - 640
diff --git a/siclib/datasets/create_dataset_from_pano.py b/siclib/datasets/create_dataset_from_pano.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7e672e92e005fced0bde624cc9da31abc94fb89
--- /dev/null
+++ b/siclib/datasets/create_dataset_from_pano.py
@@ -0,0 +1,350 @@
+"""Script to create a dataset from panorama images."""
+
+import hashlib
+import logging
+from concurrent import futures
+from pathlib import Path
+
+import hydra
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import scipy
+import torch
+from omegaconf import DictConfig, OmegaConf
+from tqdm import tqdm
+
+from siclib.geometry.camera import camera_models
+from siclib.geometry.gravity import Gravity
+from siclib.utils.conversions import deg2rad, focal2fov, fov2focal, rad2deg
+from siclib.utils.image import load_image, write_image
+
+logger = logging.getLogger(__name__)
+
+
+# mypy: ignore-errors
+
+
+def max_radius(a, b):
+ """Compute the maximum radius of a Brown distortion model."""
+ discrim = a * a - 4 * b
+ # if torch.isfinite(discrim) and discrim >= 0.0:
+ # discrim = np.sqrt(discrim) - a
+ # if discrim > 0.0:
+ # return 2.0 / discrim
+
+ valid = torch.isfinite(discrim) & (discrim >= 0.0)
+ discrim = torch.sqrt(discrim) - a
+ valid &= discrim > 0.0
+ return 2.0 / torch.where(valid, discrim, 0)
+
+
+def brown_max_radius(k1, k2):
+ """Compute the maximum radius of a Brown distortion model."""
+ # fold the constants from the derivative into a and b
+ a = k1 * 3
+ b = k2 * 5
+ return torch.sqrt(max_radius(a, b))
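+
+
+# Background on the two helpers above: the Brown radial model maps
+# r -> r * (1 + k1 * r**2 + k2 * r**4) and stays monotonic while its derivative
+# 1 + 3 * k1 * r**2 + 5 * k2 * r**4 is positive. With a = 3 * k1 and b = 5 * k2,
+# max_radius returns the root of 1 + a * x + b * x**2 = 0 in x = r**2 (or an unbounded
+# value if there is no positive root), and brown_max_radius takes the square root to
+# recover the radius itself.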
+
+
+class ParallelProcessor:
+ """Generic parallel processor class."""
+
+ def __init__(self, max_workers):
+ """Init processor and pbars."""
+ self.max_workers = max_workers
+ self.executor = futures.ProcessPoolExecutor(max_workers=self.max_workers)
+ self.pbars = {}
+
+ def update_pbar(self, pbar_key):
+ """Update progressbar."""
+ pbar = self.pbars.get(pbar_key)
+ pbar.update(1)
+
+ def submit_tasks(self, task_func, task_args, pbar_key):
+ """Submit tasks."""
+ pbar = tqdm(total=len(task_args), desc=f"Processing {pbar_key}", ncols=80)
+ self.pbars[pbar_key] = pbar
+
+ def update_pbar(future):
+ self.update_pbar(pbar_key)
+
+ futures = []
+ for args in task_args:
+ future = self.executor.submit(task_func, *args)
+ future.add_done_callback(update_pbar)
+ futures.append(future)
+
+ return futures
+
+ def wait_for_completion(self, futures):
+ """Wait for completion and return results."""
+ results = []
+ for f in futures:
+ results += f.result()
+
+ for key in self.pbars.keys():
+ self.pbars[key].close()
+
+ return results
+
+ def shutdown(self):
+ """Close the executer."""
+ self.executor.shutdown()
+
+
+class DatasetGenerator:
+ """Dataset generator class to create perspective datasets from panoramas."""
+
+ default_conf = {
+ "name": "???",
+ # paths
+ "base_dir": "???",
+ "pano_dir": "${.base_dir}/panoramas",
+ "pano_train": "${.pano_dir}/train",
+ "pano_val": "${.pano_dir}/val",
+ "pano_test": "${.pano_dir}/test",
+ "perspective_dir": "${.base_dir}/${.name}",
+ "perspective_train": "${.perspective_dir}/train",
+ "perspective_val": "${.perspective_dir}/val",
+ "perspective_test": "${.perspective_dir}/test",
+ "train_csv": "${.perspective_dir}/train.csv",
+ "val_csv": "${.perspective_dir}/val.csv",
+ "test_csv": "${.perspective_dir}/test.csv",
+ # data options
+ "camera_model": "pinhole",
+ "parameter_dists": {
+ "roll": {
+ "type": "uniform",
+ "options": {"loc": deg2rad(-45), "scale": deg2rad(90)}, # in [-45, 45]
+ },
+ "pitch": {
+ "type": "uniform",
+ "options": {"loc": deg2rad(-45), "scale": deg2rad(90)}, # in [-45, 45]
+ },
+ "vfov": {
+ "type": "uniform",
+ "options": {"loc": deg2rad(20), "scale": deg2rad(85)}, # in [20, 105]
+ },
+ "resize_factor": {
+ "type": "uniform",
+ "options": {"loc": 1.0, "scale": 1.0}, # factor in [1.0, 2.0]
+ },
+ "shape": {"type": "fix", "value": (640, 640)},
+ },
+ "images_per_pano": 16,
+ "n_workers": 10,
+ "device": "cpu",
+ "overwrite": False,
+ }
+
+ def __init__(self, conf):
+ """Init the class by merging and storing the config."""
+ self.conf = OmegaConf.merge(
+ OmegaConf.create(self.default_conf),
+ OmegaConf.create(conf),
+ )
+ logger.info(f"Config:\n{OmegaConf.to_yaml(self.conf)}")
+
+ self.infos = {}
+ self.device = self.conf.device
+
+ self.camera_model = camera_models[self.conf.camera_model]
+
+ def sample_value(self, parameter_name, seed=None):
+ """Sample a value from the specified distribution."""
+ param_conf = self.conf["parameter_dists"][parameter_name]
+
+ if param_conf.type == "fix":
+ return torch.tensor(param_conf.value)
+
+ # fix seed for reproducibility
+ generator = None
+ if seed:
+ if not isinstance(seed, (int, float)):
+ seed = int(hashlib.sha256(seed.encode()).hexdigest(), 16) % (2**32)
+ generator = np.random.default_rng(seed)
+
+ sampler = getattr(scipy.stats, param_conf.type)
+ return torch.tensor(sampler.rvs(random_state=generator, **param_conf.options))
+
+ def plot_distributions(self):
+ """Plot parameter distributions."""
+ fig, ax = plt.subplots(3, 3, figsize=(15, 10))
+ for i, split in enumerate(["train", "val", "test"]):
+ roll_vals = [rad2deg(row["roll"]) for row in self.infos[split]]
+ ax[i, 0].hist(roll_vals, bins=100)
+ ax[i, 0].set_xlabel("Roll (°)")
+ ax[i, 0].set_ylabel(f"Count {split}")
+
+ pitch_vals = [rad2deg(row["pitch"]) for row in self.infos[split]]
+ ax[i, 1].hist(pitch_vals, bins=100)
+ ax[i, 1].set_xlabel("Pitch (°)")
+ ax[i, 1].set_ylabel(f"Count {split}")
+
+ vfov_vals = [rad2deg(row["vfov"]) for row in self.infos[split]]
+ ax[i, 2].hist(vfov_vals, bins=100)
+ ax[i, 2].set_xlabel("vFoV (°)")
+ ax[i, 2].set_ylabel(f"Count {split}")
+
+ plt.tight_layout()
+ plt.savefig(Path(self.conf.perspective_dir) / "distributions.pdf")
+
+ fig, ax = plt.subplots(3, 3, figsize=(15, 10))
+ for i, k1 in enumerate(["roll", "pitch", "vfov"]):
+ for j, k2 in enumerate(["roll", "pitch", "vfov"]):
+ ax[i, j].scatter(
+ [rad2deg(row[k1]) for row in self.infos["train"]],
+ [rad2deg(row[k2]) for row in self.infos["train"]],
+ s=1,
+ label="train",
+ )
+
+ ax[i, j].scatter(
+ [rad2deg(row[k1]) for row in self.infos["val"]],
+ [rad2deg(row[k2]) for row in self.infos["val"]],
+ s=1,
+ label="val",
+ )
+
+ ax[i, j].scatter(
+ [rad2deg(row[k1]) for row in self.infos["test"]],
+ [rad2deg(row[k2]) for row in self.infos["test"]],
+ s=1,
+ label="test",
+ )
+
+ ax[i, j].set_xlabel(k1)
+ ax[i, j].set_ylabel(k2)
+ ax[i, j].legend()
+
+ plt.tight_layout()
+ plt.savefig(Path(self.conf.perspective_dir) / "distributions_scatter.pdf")
+
+ def generate_images_from_pano(self, pano_path: Path, out_dir: Path):
+ """Generate perspective images from a single panorama."""
+ infos = []
+
+ pano = load_image(pano_path).to(self.device)
+
+ yaws = np.linspace(0, 2 * np.pi, self.conf.images_per_pano, endpoint=False)
+ params = {
+ k: [self.sample_value(k, pano_path.stem + k + str(i)) for i in yaws]
+ for k in self.conf.parameter_dists
+ if k != "shape"
+ }
+ shapes = [self.sample_value("shape", pano_path.stem + "shape") for _ in yaws]
+ params |= {
+ "height": [shape[0] for shape in shapes],
+ "width": [shape[1] for shape in shapes],
+ }
+
+ if "k1_hat" in params:
+ height = torch.tensor(params["height"])
+ width = torch.tensor(params["width"])
+ k1_hat = torch.tensor(params["k1_hat"])
+ vfov = torch.tensor(params["vfov"])
+            focal = fov2focal(vfov, height)
+            rel_focal = focal / height
+ k1 = k1_hat * rel_focal
+
+            # distance to the image corner in pixels:
+            # r_max_im = f_px * r_max * (1 + k1 * r_max**2)
+            # solving for the focal length: f_px = r_max_im / (r_max * (1 + k1 * r_max**2))
+ min_permissible_rmax = torch.sqrt((height / 2) ** 2 + (width / 2) ** 2)
+ r_max = brown_max_radius(k1=k1, k2=0)
+ lowest_possible_f_px = min_permissible_rmax / (r_max * (1 + k1 * r_max**2))
+ valid = lowest_possible_f_px <= focal
+
+ f = torch.where(valid, focal, lowest_possible_f_px)
+ vfov = focal2fov(f, height)
+
+ params["vfov"] = vfov
+ params |= {"k1": k1}
+
+ cam = self.camera_model.from_dict(params).float().to(self.device)
+ gravity = Gravity.from_rp(params["roll"], params["pitch"]).float().to(self.device)
+
+ if (out_dir / f"{pano_path.stem}_0.jpg").exists() and not self.conf.overwrite:
+ for i in range(self.conf.images_per_pano):
+ perspective_name = f"{pano_path.stem}_{i}.jpg"
+ info = {"fname": perspective_name} | {k: v[i].item() for k, v in params.items()}
+ infos.append(info)
+
+ logger.info(f"Perspectives for {pano_path.stem} already exist.")
+
+ return infos
+
+ perspective_images = cam.get_img_from_pano(
+ pano_img=pano, gravity=gravity, yaws=yaws, resize_factor=params["resize_factor"]
+ )
+
+ for i, perspective_image in enumerate(perspective_images):
+ perspective_name = f"{pano_path.stem}_{i}.jpg"
+
+ n_pixels = perspective_image.shape[-2] * perspective_image.shape[-1]
+ valid = (torch.sum(perspective_image.sum(0) == 0) / n_pixels) < 0.01
+ if not valid:
+ logger.debug(f"Perspective {perspective_name} has too many black pixels.")
+ continue
+
+ write_image(perspective_image, out_dir / perspective_name)
+
+ info = {"fname": perspective_name} | {k: v[i].item() for k, v in params.items()}
+ infos.append(info)
+
+ return infos
+
+ def generate_split(self, split: str, parallel_processor: ParallelProcessor):
+ """Generate a single split of a dataset."""
+ self.infos[split] = []
+ panorama_paths = [
+ path
+ for path in Path(self.conf[f"pano_{split}"]).glob("*")
+ if not path.name.startswith(".")
+ ]
+
+ out_dir = Path(self.conf[f"perspective_{split}"])
+ logger.info(f"Writing perspective images to {str(out_dir)}")
+ if not out_dir.exists():
+ out_dir.mkdir(parents=True)
+
+ futures = parallel_processor.submit_tasks(
+ self.generate_images_from_pano, [(f, out_dir) for f in panorama_paths], split
+ )
+ self.infos[split] = parallel_processor.wait_for_completion(futures)
+ # parallel_processor.shutdown()
+
+ metadata = pd.DataFrame(data=self.infos[split])
+ metadata.to_csv(self.conf[f"{split}_csv"])
+
+ def generate_dataset(self):
+ """Generate all splits of a dataset."""
+ out_dir = Path(self.conf.perspective_dir)
+ if not out_dir.exists():
+ out_dir.mkdir(parents=True)
+
+ OmegaConf.save(self.conf, out_dir / "config.yaml")
+
+ processor = ParallelProcessor(self.conf.n_workers)
+ for split in ["train", "val", "test"]:
+ self.generate_split(split=split, parallel_processor=processor)
+
+ processor.shutdown()
+
+ for split in ["train", "val", "test"]:
+ logger.info(f"Generated {len(self.infos[split])} {split} images.")
+
+ self.plot_distributions()
+
+
+@hydra.main(version_base=None, config_path="configs", config_name="SUN360")
+def main(cfg: DictConfig) -> None:
+ """Run dataset generation."""
+ generator = DatasetGenerator(conf=cfg)
+ generator.generate_dataset()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/siclib/datasets/simple_dataset.py b/siclib/datasets/simple_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb2ac7d4b1116f5967ae9ee0e62bab7ce83b7c92
--- /dev/null
+++ b/siclib/datasets/simple_dataset.py
@@ -0,0 +1,237 @@
+"""Dataset for images created with 'create_dataset_from_pano.py'."""
+
+import logging
+from pathlib import Path
+from typing import Any, Dict, List, Tuple
+
+import pandas as pd
+import torch
+from omegaconf import DictConfig
+
+from siclib.datasets.augmentations import IdentityAugmentation, augmentations
+from siclib.datasets.base_dataset import BaseDataset
+from siclib.geometry.camera import SimpleRadial
+from siclib.geometry.gravity import Gravity
+from siclib.geometry.perspective_fields import get_perspective_field
+from siclib.utils.conversions import fov2focal
+from siclib.utils.image import ImagePreprocessor, load_image
+from siclib.utils.tools import fork_rng
+
+logger = logging.getLogger(__name__)
+
+# mypy: ignore-errors
+
+
+def load_csv(
+ csv_file: Path, img_root: Path
+) -> Tuple[List[Dict[str, Any]], torch.Tensor, torch.Tensor]:
+ """Load a CSV file containing image information.
+
+    Args:
+        csv_file (Path): Path to the CSV file.
+        img_root (Path): Path to the root directory containing the images.
+
+    Returns:
+        Tuple of image info dictionaries, stacked camera parameters, and gravity (roll, pitch).
+    """
+ df = pd.read_csv(csv_file)
+
+ infos, params, gravity = [], [], []
+ for _, row in df.iterrows():
+ h = row["height"]
+ w = row["width"]
+ px = row.get("px", w / 2)
+ py = row.get("py", h / 2)
+ vfov = row["vfov"]
+ f = fov2focal(torch.tensor(vfov), h)
+ k1 = row.get("k1", 0)
+ k2 = row.get("k2", 0)
+ params.append(torch.tensor([w, h, f, f, px, py, k1, k2]))
+
+ roll = row["roll"]
+ pitch = row["pitch"]
+ gravity.append(torch.tensor([roll, pitch]))
+
+ infos.append({"name": row["fname"], "file_name": str(img_root / row["fname"])})
+
+ params = torch.stack(params).float()
+ gravity = torch.stack(gravity).float()
+ return infos, params, gravity
+
+
+class SimpleDataset(BaseDataset):
+ """Dataset for images created with 'create_dataset_from_pano.py'."""
+
+ default_conf = {
+ # paths
+ "dataset_dir": "???",
+ "train_img_dir": "${.dataset_dir}/train",
+ "val_img_dir": "${.dataset_dir}/val",
+ "test_img_dir": "${.dataset_dir}/test",
+ "train_csv": "${.dataset_dir}/train.csv",
+ "val_csv": "${.dataset_dir}/val.csv",
+ "test_csv": "${.dataset_dir}/test.csv",
+ # data options
+ "use_up": True,
+ "use_latitude": True,
+ "use_prior_focal": False,
+ "use_prior_gravity": False,
+ "use_prior_k1": False,
+ # image options
+ "grayscale": False,
+ "preprocessing": ImagePreprocessor.default_conf,
+ "augmentations": {"name": "geocalib", "verbose": False},
+ "p_rotate": 0.0, # probability to rotate image by +/- 90°
+ "reseed": False,
+ "seed": 0,
+ # data loader options
+ "num_workers": 8,
+ "prefetch_factor": 2,
+ "train_batch_size": 32,
+ "val_batch_size": 32,
+ "test_batch_size": 32,
+ }
+
+ def _init(self, conf):
+ pass
+
+ def get_dataset(self, split: str) -> torch.utils.data.Dataset:
+ """Return a dataset for a given split."""
+ return _SimpleDataset(self.conf, split)
+
+
+class _SimpleDataset(torch.utils.data.Dataset):
+ """Dataset for dataset for images created with 'create_dataset_from_pano.py'."""
+
+ def __init__(self, conf: DictConfig, split: str):
+ """Initialize the dataset."""
+ self.conf = conf
+ self.split = split
+ self.img_dir = Path(conf.get(f"{split}_img_dir"))
+
+ self.preprocessor = ImagePreprocessor(conf.preprocessing)
+
+ # load image information
+ assert f"{split}_csv" in conf, f"Missing {split}_csv in conf"
+ infos_path = self.conf.get(f"{split}_csv")
+ self.infos, self.parameters, self.gravity = load_csv(infos_path, self.img_dir)
+
+ # define augmentations
+ aug_name = conf.augmentations.name
+ assert (
+ aug_name in augmentations.keys()
+ ), f'{aug_name} not in {" ".join(augmentations.keys())}'
+
+ if self.split == "train":
+ self.augmentation = augmentations[aug_name](conf.augmentations)
+ else:
+ self.augmentation = IdentityAugmentation()
+
+ def __len__(self):
+ return len(self.infos)
+
+ def __getitem__(self, idx):
+ if not self.conf.reseed:
+ return self.getitem(idx)
+ with fork_rng(self.conf.seed + idx, False):
+ return self.getitem(idx)
+
+ def _read_image(
+ self, infos: Dict[str, Any], parameters: torch.Tensor, gravity: torch.Tensor
+ ) -> Dict[str, Any]:
+ path = Path(str(infos["file_name"]))
+
+ # load image as uint8 and HWC for augmentation
+ image = load_image(path, self.conf.grayscale, return_tensor=False)
+ image = self.augmentation(image, return_tensor=True)
+
+ # create radial camera -> same as pinhole if k1 = 0
+ camera = SimpleRadial(parameters[None]).float()
+
+ roll, pitch = gravity[None].unbind(-1)
+ gravity = Gravity.from_rp(roll, pitch)
+
+ # preprocess
+ data = self.preprocessor(image)
+ camera = camera.scale(data["scales"])
+ camera = camera.crop(data["crop_pad"]) if "crop_pad" in data else camera
+
+ priors = {"prior_gravity": gravity} if self.conf.use_prior_gravity else {}
+ priors |= {"prior_focal": camera.f[..., 1]} if self.conf.use_prior_focal else {}
+ priors |= {"prior_k1": camera.k1} if self.conf.use_prior_k1 else {}
+ return {
+ "name": infos["name"],
+ "path": str(path),
+ "camera": camera[0],
+ "gravity": gravity[0],
+ **priors,
+ **data,
+ }
+
+ def _get_perspective(self, data):
+ """Get perspective field."""
+ camera = data["camera"]
+ gravity = data["gravity"]
+
+ up_field, lat_field = get_perspective_field(
+ camera, gravity, use_up=self.conf.use_up, use_latitude=self.conf.use_latitude
+ )
+
+ out = {}
+ if self.conf.use_up:
+ out["up_field"] = up_field[0]
+ if self.conf.use_latitude:
+ out["latitude_field"] = lat_field[0]
+
+ return out
+
+ def getitem(self, idx: int):
+ """Return a sample from the dataset."""
+ infos = self.infos[idx]
+ parameters = self.parameters[idx]
+ gravity = self.gravity[idx]
+ data = self._read_image(infos, parameters, gravity)
+
+ if self.conf.use_up or self.conf.use_latitude:
+ data |= self._get_perspective(data)
+
+ return data
+
+
+if __name__ == "__main__":
+ # Create a dump of the dataset
+ import argparse
+
+ import matplotlib.pyplot as plt
+
+ from siclib.visualization.visualize_batch import make_perspective_figures
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--name", type=str, required=True)
+ parser.add_argument("--data_dir", type=str)
+ parser.add_argument("--split", type=str, default="train")
+ parser.add_argument("--shuffle", action="store_true")
+ parser.add_argument("--n_rows", type=int, default=4)
+ parser.add_argument("--dpi", type=int, default=100)
+ args = parser.parse_intermixed_args()
+
+ dconf = SimpleDataset.default_conf
+ dconf["name"] = args.name
+ dconf["num_workers"] = 0
+ dconf["prefetch_factor"] = None
+
+ dconf["dataset_dir"] = args.data_dir
+ dconf[f"{args.split}_batch_size"] = args.n_rows
+
+ torch.set_grad_enabled(False)
+
+ dataset = SimpleDataset(dconf)
+ loader = dataset.get_data_loader(args.split, args.shuffle)
+
+ with fork_rng(seed=42):
+ for data in loader:
+ pred = data
+ break
+ fig = make_perspective_figures(pred, data, n_pairs=args.n_rows)
+
+ plt.show()
diff --git a/siclib/datasets/utils/__init__.py b/siclib/datasets/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/siclib/datasets/utils/align_megadepth.py b/siclib/datasets/utils/align_megadepth.py
new file mode 100644
index 0000000000000000000000000000000000000000..123e456bd3603f7187785153733bd843b782d53b
--- /dev/null
+++ b/siclib/datasets/utils/align_megadepth.py
@@ -0,0 +1,41 @@
+import argparse
+import subprocess
+from pathlib import Path
+
+# flake8: noqa
+# mypy: ignore-errors
+
+parser = argparse.ArgumentParser(description="Aligns a COLMAP model and plots the horizon lines")
+parser.add_argument(
+ "--base_dir", type=str, help="Path to the base directory of the MegaDepth dataset"
+)
+parser.add_argument("--out_dir", type=str, help="Path to the output directory")
+args = parser.parse_args()
+
+base_dir = Path(args.base_dir)
+out_dir = Path(args.out_dir)
+
+scenes = [d.name for d in base_dir.iterdir() if d.is_dir()]
+print(scenes[:3], len(scenes))
+
+
+for scene in scenes:
+ image_dir = base_dir / scene / "images"
+ sfm_dir = base_dir / scene / "sparse" / "manhattan" / "0"
+
+ # Align model
+ align_dir = out_dir / scene / "sparse" / "align"
+ align_dir.mkdir(exist_ok=True, parents=True)
+
+ print(f"image_dir ({image_dir.exists()}): {image_dir}")
+ print(f"sfm_dir ({sfm_dir.exists()}): {sfm_dir}")
+ print(f"align_dir ({align_dir.exists()}): {align_dir}")
+
+ cmd = (
+ "colmap model_orientation_aligner "
+ + f"--image_path {image_dir} "
+ + f"--input_path {sfm_dir} "
+ + f"--output_path {str(align_dir)}"
+ )
+ subprocess.run(cmd, shell=True)
diff --git a/siclib/datasets/utils/download_openpano.py b/siclib/datasets/utils/download_openpano.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff751da36313d74fd268f7fa057c6a0bcb6c611d
--- /dev/null
+++ b/siclib/datasets/utils/download_openpano.py
@@ -0,0 +1,75 @@
+"""Helper script to download and extract OpenPano dataset."""
+
+import argparse
+import shutil
+from pathlib import Path
+
+import torch
+from tqdm import tqdm
+
+from siclib import logger
+
+PANO_URL = "https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/openpano.zip"
+
+
+def download_and_extract_dataset(name: str, url: Path, output: Path) -> None:
+ """Download and extract a dataset from a URL."""
+ dataset_dir = output / name
+ if not output.exists():
+ output.mkdir(parents=True)
+
+ if dataset_dir.exists():
+ logger.info(f"Dataset {name} already exists at {dataset_dir}, skipping download.")
+ return
+
+ zip_file = output / f"{name}.zip"
+
+ if not zip_file.exists():
+ logger.info(f"Downloading dataset {name} to {zip_file} from {url}.")
+ torch.hub.download_url_to_file(url, zip_file)
+
+ logger.info(f"Extracting dataset {name} in {output}.")
+ shutil.unpack_archive(zip_file, output, format="zip")
+ zip_file.unlink()
+
+
+def main():
+ """Prepare the OpenPano dataset."""
+ parser = argparse.ArgumentParser(description="Download and extract OpenPano dataset.")
+ parser.add_argument("--name", type=str, default="openpano", help="Name of the dataset.")
+ parser.add_argument(
+ "--laval_dir", type=str, default="data/laval-tonemap", help="Path the Laval dataset."
+ )
+
+ args = parser.parse_args()
+
+ out_dir = Path("data")
+ download_and_extract_dataset(args.name, PANO_URL, out_dir)
+
+ pano_dir = out_dir / args.name / "panoramas"
+ for split in ["train", "test", "val"]:
+ with open(pano_dir / f"{split}_panos.txt", "r") as f:
+ pano_list = f.readlines()
+ pano_list = [fname.strip() for fname in pano_list]
+
+ for fname in tqdm(pano_list, ncols=80, desc=f"Copying {split} panoramas"):
+ laval_path = Path(args.laval_dir) / fname
+ target_path = pano_dir / split / fname
+
+            # either the pano is already in the split folder or it must be copied from Laval
+ if target_path.exists():
+ continue
+
+ if laval_path.exists():
+ shutil.copy(laval_path, target_path)
+ else: # not in laval and not in split
+ logger.warning(f"Panorama {fname} not found in {args.laval_dir} or {split} split.")
+
+ n_train = len(list(pano_dir.glob("train/*.jpg")))
+ n_test = len(list(pano_dir.glob("test/*.jpg")))
+ n_val = len(list(pano_dir.glob("val/*.jpg")))
+ logger.info(f"{args.name} contains {n_train}/{n_test}/{n_val} train/test/val panoramas.")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/siclib/datasets/utils/tonemapping.py b/siclib/datasets/utils/tonemapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..3da1740724f003900771b099ef8f2de00deb8ea4
--- /dev/null
+++ b/siclib/datasets/utils/tonemapping.py
@@ -0,0 +1,316 @@
+import os
+
+os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
+import argparse
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from matplotlib import colors
+from tqdm import tqdm
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class tonemap:
+ def __init__(self):
+ pass
+
+ def process(self, img):
+ return img
+
+ def inv_process(self, img):
+ return img
+
+
+# Log correction
+class log_tonemap(tonemap):
+    def __init__(self, base, scale=1, offset=1):
+        """Log tonemap with the given log base, output scale, and offset."""
+ self.base = base
+ self.scale = scale
+ self.offset = offset
+
+ def process(self, img):
+ tonemapped = (np.log(img + self.offset) / np.log(self.base)) * self.scale
+ return tonemapped
+
+ def inv_process(self, img):
+ inverse_tonemapped = np.power(self.base, (img) / self.scale) - self.offset
+ return inverse_tonemapped
+
+
+class log_tonemap_clip(tonemap):
+    def __init__(self, base, scale=1, offset=1):
+        """Clipped log tonemap with the given log base, output scale, and offset."""
+ self.base = base
+ self.scale = scale
+ self.offset = offset
+
+ def process(self, img):
+ tonemapped = np.clip((np.log(img * self.scale + self.offset) / np.log(self.base)), 0, 2) - 1
+ return tonemapped
+
+ def inv_process(self, img):
+ inverse_tonemapped = (np.power(self.base, (img + 1)) - self.offset) / self.scale
+ return inverse_tonemapped
+
+
+# Gamma Tonemap
+class gamma_tonemap(tonemap):
+ def __init__(
+ self,
+ gamma,
+ ):
+ self.gamma = gamma
+
+ def process(self, img):
+ tonemapped = np.power(img, 1 / self.gamma)
+ return tonemapped
+
+ def inv_process(self, img):
+ inverse_tonemapped = np.power(img, self.gamma)
+ return inverse_tonemapped
+
+
+class linear_clip(tonemap):
+ def __init__(self, scale, mean):
+ self.scale = scale
+ self.mean = mean
+
+ def process(self, img):
+ tonemapped = np.clip((img - self.mean) / self.scale, -1, 1)
+ return tonemapped
+
+ def inv_process(self, img):
+ inverse_tonemapped = img * self.scale + self.mean
+ return inverse_tonemapped
+
+
+def make_tonemap_HDR(opt):
+ if opt.mode == "luminance":
+ res_tonemap = log_tonemap_clip(10, 1.0, 1.0)
+ else: # temperature
+ res_tonemap = linear_clip(5000.0, 5000.0)
+ return res_tonemap
+
+
+class LDRfromHDR:
+    def __init__(
+        self, tonemap=("none", None), orig_scale=False, clip=True,
+        quantization=0, color_jitter=0, noise=0,
+    ):
+        # `tonemap` is a (name, value) pair, e.g. ("gamma", 2.2) or ("log10", 10).
+        self.tonemap_str, val = tonemap
+ if tonemap[0] == "gamma":
+ self.tonemap = gamma_tonemap(val)
+ elif tonemap[0] == "log10":
+ self.tonemap = log_tonemap(val)
+ else:
+ print("Warning: No tonemap specified, using linear")
+
+ self.clip = clip
+ self.orig_scale = orig_scale
+ self.bits = quantization
+ self.jitter = color_jitter
+ self.noise = noise
+
+ self.wbModel = None
+
+ def process(self, HDR):
+ LDR, normalized_scale = self.rescale(HDR)
+ LDR = self.apply_clip(LDR)
+ LDR = self.apply_scale(LDR, normalized_scale)
+ LDR = self.apply_tonemap(LDR)
+ LDR = self.colorJitter(LDR)
+ LDR = self.gaussianNoise(LDR)
+ LDR = self.quantize(LDR)
+ LDR = self.apply_white_balance(LDR)
+ return LDR, normalized_scale
+
+ def rescale(self, img, percentile=90, max_mapping=0.8):
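+        # Re-expose the HDR image so that its `percentile` luminance maps to `max_mapping`,
+        # and return the log-normalized inverse exposure scale for later recovery.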
+ r_percentile = np.percentile(img, percentile)
+ alpha = max_mapping / (r_percentile + 1e-10)
+
+ img_reexposed = img * alpha
+
+ normalized_scale = normalizeScale(1 / alpha)
+
+ return img_reexposed, normalized_scale
+
+ def rescaleAlpha(self, img, percentile=90, max_mapping=0.8):
+ r_percentile = np.percentile(img, percentile)
+ alpha = max_mapping / (r_percentile + 1e-10)
+
+ return alpha
+
+ def apply_clip(self, img):
+ if self.clip:
+ img = np.clip(img, 0, 1)
+ return img
+
+ def apply_scale(self, img, scale):
+ if self.orig_scale:
+ scale = unNormalizeScale(scale)
+ img = img * scale
+ return img
+
+ def apply_tonemap(self, img):
+ if self.tonemap_str == "none":
+ return img
+ gammaed = self.tonemap.process(img)
+ return gammaed
+
+ def quantize(self, img):
+ if self.bits == 0:
+ return img
+ max_val = np.power(2, self.bits)
+ img = img * max_val
+ img = np.floor(img)
+ img = img / max_val
+ return img
+
+ def colorJitter(self, img):
+ if self.jitter == 0:
+ return img
+ hsv = colors.rgb_to_hsv(img)
+ hue_offset = np.random.normal(0, self.jitter, 1)
+ hsv[:, :, 0] = (hsv[:, :, 0] + hue_offset) % 1.0
+ rgb = colors.hsv_to_rgb(hsv)
+ return rgb
+
+ def gaussianNoise(self, img):
+ if self.noise == 0:
+ return img
+ noise_amount = np.random.uniform(0, self.noise, 1)
+ noise_img = np.random.normal(0, noise_amount, img.shape)
+ img = img + noise_img
+ img = np.clip(img, 0, 1).astype(np.float32)
+ return img
+
+ def apply_white_balance(self, img):
+ if self.wbModel is None:
+ return img
+ img = self.wbModel.correctImage(img)
+ return img.copy()
+
+
+def make_LDRfromHDR(opt):
+ LDR_from_HDR = LDRfromHDR(
+ opt.tonemap_LDR, opt.orig_scale, opt.clip, opt.quantization, opt.color_jitter, opt.noise
+ )
+ return LDR_from_HDR
+
+
+def torchnormalizeEV(EV, mean=5.12, scale=6, clip=True):
+ # Normalize based on the computed distribution between -1 1
+ EV -= mean
+ EV = EV / scale
+
+ if clip:
+ EV = torch.clip(EV, min=-1, max=1)
+
+ return EV
+
+
+def torchnormalizeEV0(EV, mean=5.12, scale=6, clip=True):
+ # Normalize based on the computed distribution between 0 1
+ EV -= mean
+ EV = EV / scale
+
+ if clip:
+ EV = torch.clip(EV, min=-1, max=1)
+
+ EV += 0.5
+ EV = EV / 2
+
+ return EV
+
+
+def normalizeScale(x, scale=4):
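+    # Map a non-negative scale to roughly [-1, 1] via log10(x + 1) / (scale / 2) - 1;
+    # unNormalizeScale below is the exact inverse.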
+ x = np.log10(x + 1)
+
+ x = x / (scale / 2)
+ x = x - 1
+
+ return x
+
+
+def unNormalizeScale(x, scale=4):
+ x = x + 1
+ x = x * (scale / 2)
+
+ x = np.power(10, x) - 1
+
+ return x
+
+
+def normalizeIlluminance(x, scale=5):
+ x = np.log10(x + 1)
+
+ x = x / (scale / 2)
+ x = x - 1
+
+ return x
+
+
+def unNormalizeIlluminance(x, scale=5):
+ x = x + 1
+ x = x * (scale / 2)
+
+ x = np.power(10, x) - 1
+
+ return x
+
+
+def main(args):
+ processor = LDRfromHDR(
+ # tonemap=("log10", 10),
+ tonemap=("gamma", args.gamma),
+ orig_scale=False,
+ clip=True,
+ quantization=0,
+ color_jitter=0,
+ noise=0,
+ )
+
+ img_list = list(os.listdir(args.hdr_dir))
+ img_list = [f for f in img_list if f.endswith(args.extension)]
+ img_list = [f for f in img_list if not f.startswith("._")]
+
+ if not os.path.exists(args.out_dir):
+ os.makedirs(args.out_dir)
+
+ for fname in tqdm(img_list):
+ fname_out = ".".join(fname.split(".")[:-1])
+ out = os.path.join(args.out_dir, f"{fname_out}.jpg")
+ if os.path.exists(out) and not args.overwrite:
+ continue
+
+ fpath = os.path.join(args.hdr_dir, fname)
+ img = cv2.imread(fpath, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ ldr, scale = processor.process(img)
+
+ ldr = (ldr * 255).astype(np.uint8)
+ ldr = cv2.cvtColor(ldr, cv2.COLOR_RGB2BGR)
+ cv2.imwrite(out, ldr)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--hdr_dir", type=str, default="hdr")
+ parser.add_argument("--out_dir", type=str, default="ldr")
+ parser.add_argument("--extension", type=str, default=".exr")
+ parser.add_argument("--overwrite", action="store_true")
+ parser.add_argument("--gamma", type=float, default=2)
+ args = parser.parse_args()
+
+ main(args)
diff --git a/siclib/eval/__init__.py b/siclib/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..550493c7bc54a9d827d8e12010b29a7017970821
--- /dev/null
+++ b/siclib/eval/__init__.py
@@ -0,0 +1,18 @@
+import torch
+
+from siclib.eval.eval_pipeline import EvalPipeline
+from siclib.utils.tools import get_class
+
+
+def get_benchmark(benchmark):
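+    """Load the benchmark pipeline class from siclib.eval by name."""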
+ return get_class(f"{__name__}.{benchmark}", EvalPipeline)
+
+
+@torch.no_grad()
+def run_benchmark(benchmark, eval_conf, experiment_dir, model=None):
+ """This overwrites existing benchmarks"""
+ experiment_dir.mkdir(exist_ok=True, parents=True)
+ bm = get_benchmark(benchmark)
+
+ pipeline = bm(eval_conf)
+ return pipeline.run(experiment_dir, model=model, overwrite=True, overwrite_eval=True)
diff --git a/siclib/eval/configs/deepcalib.yaml b/siclib/eval/configs/deepcalib.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..2a0173265d7e8ff64110ccff94e90d6d9a31fdf4
--- /dev/null
+++ b/siclib/eval/configs/deepcalib.yaml
@@ -0,0 +1,3 @@
+model:
+ name: networks.deepcalib
+ weights: weights/deepcalib.tar
diff --git a/siclib/eval/configs/geocalib-pinhole.yaml b/siclib/eval/configs/geocalib-pinhole.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..32b036201ae27f38811dc7558c8bf4449ee6630c
--- /dev/null
+++ b/siclib/eval/configs/geocalib-pinhole.yaml
@@ -0,0 +1,2 @@
+model:
+ name: networks.geocalib_pretrained
diff --git a/siclib/eval/configs/geocalib-simple_radial.yaml b/siclib/eval/configs/geocalib-simple_radial.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..6bdeeb4e21cb9b6f3a179d829c2adc3ae7f9f0e5
--- /dev/null
+++ b/siclib/eval/configs/geocalib-simple_radial.yaml
@@ -0,0 +1,4 @@
+model:
+ name: networks.geocalib_pretrained
+ camera_model: simple_radial
+ model_weights: distorted
diff --git a/siclib/eval/configs/uvp.yaml b/siclib/eval/configs/uvp.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..887dfde752becf3ec06be35411a68cf9af11eb74
--- /dev/null
+++ b/siclib/eval/configs/uvp.yaml
@@ -0,0 +1,18 @@
+model:
+ name: optimization.vp_from_prior
+ SOLVER_FLAGS: [True, True, True, True, True]
+ magsac_scoring: true
+ min_lines: 5
+ verbose: false
+
+ # RANSAC inlier threshold
+ th_pixels: 3
+
+ # 3 uses the gravity in the LS refinement, 2 does not. Here we use a prior on the gravity, so use 2
+ ls_refinement: 2
+
+  # change to 3 to add a Ceres optimization after the non-minimal solver (slower)
+ nms: 1
+
+ # deeplsd, lsd
+ line_type: deeplsd
diff --git a/siclib/eval/eval_pipeline.py b/siclib/eval/eval_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb5e4ada2779bd7b42b14488037b45395131fd4e
--- /dev/null
+++ b/siclib/eval/eval_pipeline.py
@@ -0,0 +1,106 @@
+import json
+
+import h5py
+import numpy as np
+from omegaconf import OmegaConf
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+def load_eval(dir):
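+    """Load cached summaries and per-image results from an experiment directory."""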
+ summaries, results = {}, {}
+ with h5py.File(str(dir / "results.h5"), "r") as hfile:
+ for k in hfile.keys():
+ r = np.array(hfile[k])
+ if len(r.shape) < 3:
+ results[k] = r
+ for k, v in hfile.attrs.items():
+ summaries[k] = v
+ with open(dir / "summaries.json", "r") as f:
+ s = json.load(f)
+ summaries = {k: v if v is not None else np.nan for k, v in s.items()}
+ return summaries, results
+
+
+def save_eval(dir, summaries, figures, results):
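+    """Write summaries, per-image results, and figures to an experiment directory."""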
+ with h5py.File(str(dir / "results.h5"), "w") as hfile:
+ for k, v in results.items():
+ arr = np.array(v)
+ if not np.issubdtype(arr.dtype, np.number):
+ arr = arr.astype("object")
+ hfile.create_dataset(k, data=arr)
+ # just to be safe, not used in practice
+ for k, v in summaries.items():
+ hfile.attrs[k] = v
+ s = {
+ k: float(v) if np.isfinite(v) else None
+ for k, v in summaries.items()
+ if not isinstance(v, list)
+ }
+ s = {**s, **{k: v for k, v in summaries.items() if isinstance(v, list)}}
+ with open(dir / "summaries.json", "w") as f:
+ json.dump(s, f, indent=4)
+
+ for fig_name, fig in figures.items():
+ fig.savefig(dir / f"{fig_name}.png")
+
+
+def exists_eval(dir):
+ return (dir / "results.h5").exists() and (dir / "summaries.json").exists()
+
+
+class EvalPipeline:
+ default_conf = {}
+
+ export_keys = []
+ optional_export_keys = []
+
+ def __init__(self, conf):
+ """Assumes"""
+ self.default_conf = OmegaConf.create(self.default_conf)
+ self.conf = OmegaConf.merge(self.default_conf, conf)
+ self._init(self.conf)
+
+ def _init(self, conf):
+ pass
+
+ @classmethod
+ def get_dataloader(cls, data_conf=None):
+ """Returns a data loader with samples for each eval datapoint"""
+ raise NotImplementedError
+
+ def get_predictions(self, experiment_dir, model=None, overwrite=False):
+ """Export a prediction file for each eval datapoint"""
+ raise NotImplementedError
+
+ def run_eval(self, loader, pred_file):
+ """Run the eval on cached predictions"""
+ raise NotImplementedError
+
+ def run(self, experiment_dir, model=None, overwrite=False, overwrite_eval=False):
+ """Run export+eval loop"""
+ self.save_conf(experiment_dir, overwrite=overwrite, overwrite_eval=overwrite_eval)
+ pred_file = self.get_predictions(experiment_dir, model=model, overwrite=overwrite)
+
+ f = {}
+ if not exists_eval(experiment_dir) or overwrite_eval or overwrite:
+ s, f, r = self.run_eval(self.get_dataloader(self.conf.data, 1), pred_file)
+ save_eval(experiment_dir, s, f, r)
+ s, r = load_eval(experiment_dir)
+ return s, f, r
+
+ def save_conf(self, experiment_dir, overwrite=False, overwrite_eval=False):
+ # store config
+ conf_output_path = experiment_dir / "conf.yaml"
+ if conf_output_path.exists():
+ saved_conf = OmegaConf.load(conf_output_path)
+ if (saved_conf.data != self.conf.data) or (saved_conf.model != self.conf.model):
+ assert (
+ overwrite
+ ), "configs changed, add --overwrite to rerun experiment with new conf"
+ if saved_conf.eval != self.conf.eval:
+ assert (
+ overwrite or overwrite_eval
+ ), "eval configs changed, add --overwrite_eval to rerun evaluation"
+ OmegaConf.save(self.conf, experiment_dir / "conf.yaml")
diff --git a/siclib/eval/inspect.py b/siclib/eval/inspect.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac00c89c736a4b9bb6a7351b103d120f3745ad43
--- /dev/null
+++ b/siclib/eval/inspect.py
@@ -0,0 +1,62 @@
+import argparse
+from collections import defaultdict
+from pathlib import Path
+from pprint import pprint
+
+import matplotlib
+import matplotlib.pyplot as plt
+
+from siclib.eval import get_benchmark
+from siclib.eval.eval_pipeline import load_eval
+from siclib.settings import EVAL_PATH
+from siclib.visualization.global_frame import GlobalFrame
+from siclib.visualization.two_view_frame import TwoViewFrame
+
+# flake8: noqa
+# mypy: ignore-errors
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("benchmark", type=str)
+ parser.add_argument("--x", type=str, default=None)
+ parser.add_argument("--y", type=str, default=None)
+ parser.add_argument("--backend", type=str, default=None)
+ parser.add_argument("--default_plot", type=str, default=TwoViewFrame.default_conf["default"])
+
+ parser.add_argument("dotlist", nargs="*")
+ args = parser.parse_intermixed_args()
+
+ output_dir = Path(EVAL_PATH, args.benchmark)
+
+ results = {}
+ summaries = defaultdict(dict)
+
+ predictions = {}
+
+ if args.backend:
+ matplotlib.use(args.backend)
+
+ bm = get_benchmark(args.benchmark)
+ loader = bm.get_dataloader()
+
+ for name in args.dotlist:
+ experiment_dir = output_dir / name
+ pred_file = experiment_dir / "predictions.h5"
+ s, results[name] = load_eval(experiment_dir)
+ predictions[name] = pred_file
+ for k, v in s.items():
+ summaries[k][name] = v
+
+ pprint(summaries)
+
+ plt.close("all")
+
+ frame = GlobalFrame(
+ {"child": {"default": args.default_plot}, **vars(args)},
+ results,
+ loader,
+ predictions,
+ child_frame=TwoViewFrame,
+ )
+ frame.draw()
+ plt.show()
diff --git a/siclib/eval/io.py b/siclib/eval/io.py
new file mode 100644
index 0000000000000000000000000000000000000000..276a47c6e2fa58ba4d9980bd8ea5af624da44c10
--- /dev/null
+++ b/siclib/eval/io.py
@@ -0,0 +1,105 @@
+import argparse
+from pathlib import Path
+from pprint import pprint
+from typing import Optional
+
+import pkg_resources
+from hydra import compose, initialize
+from omegaconf import OmegaConf
+
+from siclib.models import get_model
+from siclib.settings import TRAINING_PATH
+from siclib.utils.experiments import load_experiment
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+def parse_config_path(name_or_path: Optional[str], defaults: str) -> Path:
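+    """Resolve a config name to a packaged default config or an existing file path."""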
+ default_configs = {}
+ print(f"Looking for default config: {'siclib', str(defaults)}")
+ for c in pkg_resources.resource_listdir("siclib.eval", str(defaults)):
+ if c.endswith(".yaml"):
+ default_configs[Path(c).stem] = Path(
+ pkg_resources.resource_filename("siclib.eval", defaults + c)
+ )
+ if name_or_path is None:
+ return None
+ if name_or_path in default_configs:
+ return default_configs[name_or_path]
+ path = Path(name_or_path)
+ if not path.exists():
+ raise FileNotFoundError(
+ f"Cannot find the config file: {name_or_path}. "
+ f"Not in the default configs {list(default_configs.keys())} "
+ "and not an existing path."
+ )
+ return Path(path)
+
+
+def extract_benchmark_conf(conf, benchmark):
+ mconf = OmegaConf.create({"model": conf.get("model", {})})
+ if "benchmarks" in conf.keys():
+ return OmegaConf.merge(mconf, conf.benchmarks.get(benchmark, {}))
+ else:
+ return mconf
+
+
+def parse_eval_args(benchmark, args, configs_path, default=None):
+ conf = {"data": {}, "model": {}, "eval": {}}
+
+ if args.conf:
+ print(f"Loading config: {configs_path}")
+ conf_path = parse_config_path(args.conf, configs_path)
+ initialize(version_base=None, config_path=configs_path)
+ custom_conf = compose(config_name=args.conf)
+ conf = extract_benchmark_conf(OmegaConf.merge(conf, custom_conf), benchmark)
+ args.tag = args.tag if args.tag is not None else conf_path.name.replace(".yaml", "")
+
+ cli_conf = OmegaConf.from_cli(args.dotlist)
+ conf = OmegaConf.merge(conf, cli_conf)
+ conf.checkpoint = args.checkpoint or conf.get("checkpoint")
+
+ if conf.checkpoint and not conf.checkpoint.endswith(".tar"):
+ checkpoint_conf = OmegaConf.load(TRAINING_PATH / conf.checkpoint / "config.yaml")
+ conf = OmegaConf.merge(extract_benchmark_conf(checkpoint_conf, benchmark), conf)
+
+ if default:
+ conf = OmegaConf.merge(default, conf)
+
+ if args.tag is not None:
+ name = args.tag
+ elif args.conf and conf.checkpoint:
+ name = f"{args.conf}_{conf.checkpoint}"
+ elif args.conf:
+ name = args.conf
+ elif conf.checkpoint:
+ name = conf.checkpoint
+ if len(args.dotlist) > 0 and not args.tag:
+ name = f"{name}_" + ":".join(args.dotlist)
+
+ print("Running benchmark:", benchmark)
+ print("Experiment tag:", name)
+ print("Config:")
+ pprint(OmegaConf.to_container(conf))
+ return name, conf
+
+
+def load_model(model_conf, checkpoint, get_last=False):
+ if checkpoint:
+ model = load_experiment(checkpoint, conf=model_conf, get_last=get_last).eval()
+ else:
+ model = get_model(model_conf.name)(model_conf).eval()
+ return model
+
+
+def get_eval_parser():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--tag", type=str, default=None)
+ parser.add_argument("--checkpoint", type=str, default=None)
+ parser.add_argument("--conf", type=str, default=None)
+ parser.add_argument("--overwrite", action="store_true")
+ parser.add_argument("--overwrite_eval", action="store_true")
+ parser.add_argument("--plot", action="store_true")
+ parser.add_argument("dotlist", nargs="*")
+ return parser
diff --git a/siclib/eval/lamar2k.py b/siclib/eval/lamar2k.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d2087ff460bf2ea07ce1ca29069188a512fddd8
--- /dev/null
+++ b/siclib/eval/lamar2k.py
@@ -0,0 +1,89 @@
+import resource
+from pathlib import Path
+from pprint import pprint
+
+import matplotlib.pyplot as plt
+import torch
+from omegaconf import OmegaConf
+
+from siclib.eval.io import get_eval_parser, parse_eval_args
+from siclib.eval.simple_pipeline import SimplePipeline
+from siclib.settings import EVAL_PATH
+
+# flake8: noqa
+# mypy: ignore-errors
+
+rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
+
+torch.set_grad_enabled(False)
+
+
+class Lamar2k(SimplePipeline):
+ default_conf = {
+ "data": {
+ "name": "simple_dataset",
+ "dataset_dir": "data/lamar2k",
+ "test_img_dir": "${.dataset_dir}/images",
+ "test_csv": "${.dataset_dir}/images.csv",
+ "augmentations": {"name": "identity"},
+ "preprocessing": {"resize": 320, "edge_divisible_by": 32},
+ "test_batch_size": 1,
+ },
+ "model": {},
+ "eval": {
+ "thresholds": [1, 5, 10],
+ "pixel_thresholds": [0.5, 1, 3, 5],
+ "num_vis": 10,
+ "verbose": True,
+ },
+ "url": "https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/lamar2k.zip",
+ }
+
+ export_keys = [
+ "camera",
+ "gravity",
+ ]
+
+ optional_export_keys = [
+ "focal_uncertainty",
+ "vfov_uncertainty",
+ "roll_uncertainty",
+ "pitch_uncertainty",
+ "gravity_uncertainty",
+ "up_field",
+ "up_confidence",
+ "latitude_field",
+ "latitude_confidence",
+ ]
+
+
+if __name__ == "__main__":
+ dataset_name = Path(__file__).stem
+ parser = get_eval_parser()
+ args = parser.parse_intermixed_args()
+
+ default_conf = OmegaConf.create(Lamar2k.default_conf)
+
+ # mingle paths
+ output_dir = Path(EVAL_PATH, dataset_name)
+ output_dir.mkdir(exist_ok=True, parents=True)
+
+ name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf)
+
+ experiment_dir = output_dir / name
+ experiment_dir.mkdir(exist_ok=True)
+
+ pipeline = Lamar2k(conf)
+ s, f, r = pipeline.run(
+ experiment_dir,
+ overwrite=args.overwrite,
+ overwrite_eval=args.overwrite_eval,
+ )
+
+ pprint(s)
+
+ if args.plot:
+ for name, fig in f.items():
+ fig.canvas.manager.set_window_title(name)
+ plt.show()
diff --git a/siclib/eval/megadepth2k.py b/siclib/eval/megadepth2k.py
new file mode 100644
index 0000000000000000000000000000000000000000..58ee733c376871eb33c6cddd983dcac769db374d
--- /dev/null
+++ b/siclib/eval/megadepth2k.py
@@ -0,0 +1,89 @@
+import resource
+from pathlib import Path
+from pprint import pprint
+
+import matplotlib.pyplot as plt
+import torch
+from omegaconf import OmegaConf
+
+from siclib.eval.io import get_eval_parser, parse_eval_args
+from siclib.eval.simple_pipeline import SimplePipeline
+from siclib.settings import EVAL_PATH
+
+# flake8: noqa
+# mypy: ignore-errors
+
+rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
+
+torch.set_grad_enabled(False)
+
+
+class Megadepth2k(SimplePipeline):
+ default_conf = {
+ "data": {
+ "name": "simple_dataset",
+ "dataset_dir": "data/megadepth2k",
+ "test_img_dir": "${.dataset_dir}/images",
+ "test_csv": "${.dataset_dir}/images.csv",
+ "augmentations": {"name": "identity"},
+ "preprocessing": {"resize": 320, "edge_divisible_by": 32},
+ "test_batch_size": 1,
+ },
+ "model": {},
+ "eval": {
+ "thresholds": [1, 5, 10],
+ "pixel_thresholds": [0.5, 1, 3, 5],
+ "num_vis": 10,
+ "verbose": True,
+ },
+ "url": "https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/megadepth2k.zip",
+ }
+
+ export_keys = [
+ "camera",
+ "gravity",
+ ]
+
+ optional_export_keys = [
+ "focal_uncertainty",
+ "vfov_uncertainty",
+ "roll_uncertainty",
+ "pitch_uncertainty",
+ "gravity_uncertainty",
+ "up_field",
+ "up_confidence",
+ "latitude_field",
+ "latitude_confidence",
+ ]
+
+
+if __name__ == "__main__":
+ dataset_name = Path(__file__).stem
+ parser = get_eval_parser()
+ args = parser.parse_intermixed_args()
+
+ default_conf = OmegaConf.create(Megadepth2k.default_conf)
+
+ # mingle paths
+ output_dir = Path(EVAL_PATH, dataset_name)
+ output_dir.mkdir(exist_ok=True, parents=True)
+
+ name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf)
+
+ experiment_dir = output_dir / name
+ experiment_dir.mkdir(exist_ok=True)
+
+ pipeline = Megadepth2k(conf)
+ s, f, r = pipeline.run(
+ experiment_dir,
+ overwrite=args.overwrite,
+ overwrite_eval=args.overwrite_eval,
+ )
+
+ pprint(s)
+
+ if args.plot:
+ for name, fig in f.items():
+ fig.canvas.manager.set_window_title(name)
+ plt.show()
diff --git a/siclib/eval/megadepth2k_radial.py b/siclib/eval/megadepth2k_radial.py
new file mode 100644
index 0000000000000000000000000000000000000000..092d0ff94c2034ca4c0d4fe965fac6490c31e75a
--- /dev/null
+++ b/siclib/eval/megadepth2k_radial.py
@@ -0,0 +1,101 @@
+import resource
+from pathlib import Path
+from pprint import pprint
+
+import matplotlib.pyplot as plt
+import torch
+from omegaconf import OmegaConf
+
+from siclib.eval.io import get_eval_parser, parse_eval_args
+from siclib.eval.simple_pipeline import SimplePipeline
+from siclib.eval.utils import download_and_extract_benchmark
+from siclib.geometry.camera import SimpleRadial
+from siclib.settings import EVAL_PATH
+
+# flake8: noqa
+# mypy: ignore-errors
+
+rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
+
+torch.set_grad_enabled(False)
+
+
+class Megadepth2kRadial(SimplePipeline):
+ default_conf = {
+ "data": {
+ "name": "simple_dataset",
+ "dataset_dir": "data/megadepth2k-radial",
+ "test_img_dir": "${.dataset_dir}/images",
+ "test_csv": "${.dataset_dir}/images.csv",
+ "augmentations": {"name": "identity"},
+ "preprocessing": {"resize": 320, "edge_divisible_by": 32},
+ "test_batch_size": 1,
+ },
+ "model": {},
+ "eval": {
+ "thresholds": [1, 5, 10],
+ "pixel_thresholds": [0.5, 1, 3, 5],
+ "num_vis": 10,
+ "verbose": True,
+ },
+ "url": "https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/megadepth2k-radial.zip",
+ }
+
+ export_keys = [
+ "camera",
+ "gravity",
+ ]
+
+ optional_export_keys = [
+ "focal_uncertainty",
+ "vfov_uncertainty",
+ "roll_uncertainty",
+ "pitch_uncertainty",
+ "gravity_uncertainty",
+ "up_field",
+ "up_confidence",
+ "latitude_field",
+ "latitude_confidence",
+ ]
+
+ def _init(self, conf):
+ self.verbose = conf.eval.verbose
+ self.num_vis = self.conf.eval.num_vis
+
+ self.CameraModel = SimpleRadial
+
+ if conf.url is not None:
+ ds_dir = Path(conf.data.dataset_dir)
+ download_and_extract_benchmark(ds_dir.name, conf.url, ds_dir.parent)
+
+
+if __name__ == "__main__":
+ dataset_name = Path(__file__).stem
+ parser = get_eval_parser()
+ args = parser.parse_intermixed_args()
+
+ default_conf = OmegaConf.create(Megadepth2kRadial.default_conf)
+
+ # mingle paths
+ output_dir = Path(EVAL_PATH, dataset_name)
+ output_dir.mkdir(exist_ok=True, parents=True)
+
+ name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf)
+
+ experiment_dir = output_dir / name
+ experiment_dir.mkdir(exist_ok=True)
+
+ pipeline = Megadepth2kRadial(conf)
+ s, f, r = pipeline.run(
+ experiment_dir,
+ overwrite=args.overwrite,
+ overwrite_eval=args.overwrite_eval,
+ )
+
+ pprint(s)
+
+ if args.plot:
+ for name, fig in f.items():
+ fig.canvas.manager.set_window_title(name)
+ plt.show()
diff --git a/siclib/eval/openpano.py b/siclib/eval/openpano.py
new file mode 100644
index 0000000000000000000000000000000000000000..88a7f5544bba15fb70adec90351f05e12787c4e1
--- /dev/null
+++ b/siclib/eval/openpano.py
@@ -0,0 +1,70 @@
+import resource
+from pathlib import Path
+from pprint import pprint
+
+import matplotlib.pyplot as plt
+import torch
+from omegaconf import OmegaConf
+
+from siclib.eval.io import get_eval_parser, parse_eval_args
+from siclib.eval.simple_pipeline import SimplePipeline
+from siclib.settings import EVAL_PATH
+
+# flake8: noqa
+# mypy: ignore-errors
+
+rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
+
+torch.set_grad_enabled(False)
+
+
+class OpenPano(SimplePipeline):
+ default_conf = {
+ "data": {
+ "name": "simple_dataset",
+ "dataset_dir": "data/poly+maps+laval/poly+maps+laval",
+ "augmentations": {"name": "identity"},
+ "preprocessing": {"resize": 320, "edge_divisible_by": 32},
+ "test_batch_size": 1,
+ },
+ "model": {},
+ "eval": {
+ "thresholds": [1, 5, 10],
+ "pixel_thresholds": [0.5, 1, 3, 5],
+ "num_vis": 10,
+ "verbose": True,
+ },
+ "url": None,
+ }
+
+
+if __name__ == "__main__":
+ dataset_name = Path(__file__).stem
+ parser = get_eval_parser()
+ args = parser.parse_intermixed_args()
+
+ default_conf = OmegaConf.create(OpenPano.default_conf)
+
+ # mingle paths
+ output_dir = Path(EVAL_PATH, dataset_name)
+ output_dir.mkdir(exist_ok=True, parents=True)
+
+ name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf)
+
+ experiment_dir = output_dir / name
+ experiment_dir.mkdir(exist_ok=True)
+
+ pipeline = OpenPano(conf)
+ s, f, r = pipeline.run(
+ experiment_dir,
+ overwrite=args.overwrite,
+ overwrite_eval=args.overwrite_eval,
+ )
+
+ pprint(s)
+
+ if args.plot:
+ for name, fig in f.items():
+ fig.canvas.manager.set_window_title(name)
+ plt.show()
diff --git a/siclib/eval/openpano_radial.py b/siclib/eval/openpano_radial.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9264f045c43e2374cdca63f30cc20c91a4efe6a
--- /dev/null
+++ b/siclib/eval/openpano_radial.py
@@ -0,0 +1,70 @@
+import resource
+from pathlib import Path
+from pprint import pprint
+
+import matplotlib.pyplot as plt
+import torch
+from omegaconf import OmegaConf
+
+from siclib.eval.io import get_eval_parser, parse_eval_args
+from siclib.eval.simple_pipeline import SimplePipeline
+from siclib.settings import EVAL_PATH
+
+# flake8: noqa
+# mypy: ignore-errors
+
+rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
+
+torch.set_grad_enabled(False)
+
+
+class OpenPanoRadial(SimplePipeline):
+ default_conf = {
+ "data": {
+ "name": "simple_dataset",
+ "dataset_dir": "data/poly+maps+laval/pano_dataset_distorted",
+ "augmentations": {"name": "identity"},
+ "preprocessing": {"resize": 320, "edge_divisible_by": 32},
+ "test_batch_size": 1,
+ },
+ "model": {},
+ "eval": {
+ "thresholds": [1, 5, 10],
+ "pixel_thresholds": [0.5, 1, 3, 5],
+ "num_vis": 10,
+ "verbose": True,
+ },
+ "url": None,
+ }
+
+
+if __name__ == "__main__":
+ dataset_name = Path(__file__).stem
+ parser = get_eval_parser()
+ args = parser.parse_intermixed_args()
+
+ default_conf = OmegaConf.create(OpenPanoRadial.default_conf)
+
+ # mingle paths
+ output_dir = Path(EVAL_PATH, dataset_name)
+ output_dir.mkdir(exist_ok=True, parents=True)
+
+ name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf)
+
+ experiment_dir = output_dir / name
+ experiment_dir.mkdir(exist_ok=True)
+
+ pipeline = OpenPanoRadial(conf)
+ s, f, r = pipeline.run(
+ experiment_dir,
+ overwrite=args.overwrite,
+ overwrite_eval=args.overwrite_eval,
+ )
+
+ pprint(s)
+
+ if args.plot:
+ for name, fig in f.items():
+ fig.canvas.manager.set_window_title(name)
+ plt.show()
diff --git a/siclib/eval/run_perceptual.py b/siclib/eval/run_perceptual.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3d26cad2c05e44f925144489e936a3a5a22fa7b
--- /dev/null
+++ b/siclib/eval/run_perceptual.py
@@ -0,0 +1,84 @@
+"""Run the TPAMI 2023 Perceptual model.
+
+Run the model of the paper
+ A Deep Perceptual Measure for Lens and Camera Calibration, TPAMI 2023
+ https://lvsn.github.io/deepcalib/
+through the public Dashboard available at http://rachmaninoff.gel.ulaval.ca:8005.
+"""
+
+import argparse
+import json
+import re
+import time
+from pathlib import Path
+
+from selenium import webdriver
+from selenium.webdriver.common.by import By
+from tqdm import tqdm
+
+# mypy: ignore-errors
+
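+# JavaScript snippet that simulates dragging-and-dropping local files onto a page element,
+# so Selenium can feed images to the dashboard's upload widget without a native file input.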
+JS_DROP_FILES = "var k=arguments,d=k[0],g=k[1],c=k[2],m=d.ownerDocument||document;for(var e=0;;){var f=d.getBoundingClientRect(),b=f.left+(g||(f.width/2)),a=f.top+(c||(f.height/2)),h=m.elementFromPoint(b,a);if(h&&d.contains(h)){break}if(++e>1){var j=new Error('Element not interactable');j.code=15;throw j}d.scrollIntoView({behavior:'instant',block:'center',inline:'center'})}var l=m.createElement('INPUT');l.setAttribute('type','file');l.setAttribute('multiple','');l.setAttribute('style','position:fixed;z-index:2147483647;left:0;top:0;');l.onchange=function(q){l.parentElement.removeChild(l);q.stopPropagation();var r={constructor:DataTransfer,effectAllowed:'all',dropEffect:'none',types:['Files'],files:l.files,setData:function u(){},getData:function o(){},clearData:function s(){},setDragImage:function i(){}};if(window.DataTransferItemList){r.items=Object.setPrototypeOf(Array.prototype.map.call(l.files,function(x){return{constructor:DataTransferItem,kind:'file',type:x.type,getAsFile:function v(){return x},getAsString:function y(A){var z=new FileReader();z.onload=function(B){A(B.target.result)};z.readAsText(x)},webkitGetAsEntry:function w(){return{constructor:FileSystemFileEntry,name:x.name,fullPath:'/'+x.name,isFile:true,isDirectory:false,file:function z(A){A(x)}}}}}),{constructor:DataTransferItemList,add:function t(){},clear:function p(){},remove:function n(){}})}['dragenter','dragover','drop'].forEach(function(v){var w=m.createEvent('DragEvent');w.initMouseEvent(v,true,true,m.defaultView,0,0,0,b,a,false,false,false,false,0,null);Object.setPrototypeOf(w,null);w.dataTransfer=r;Object.setPrototypeOf(w,DragEvent.prototype);h.dispatchEvent(w)})};m.documentElement.appendChild(l);l.getBoundingClientRect();return l" # noqa E501
+
+
+def setup_driver():
+ """Setup the Selenium browser."""
+ options = webdriver.FirefoxOptions()
+ geckodriver_path = "/snap/bin/geckodriver" # specify the path to your geckodriver
+ driver_service = webdriver.FirefoxService(executable_path=geckodriver_path)
+ return webdriver.Firefox(options=options, service=driver_service)
+
+
+def run(args):
+ """Run on an image folder."""
+ driver = setup_driver()
+ driver.get("http://rachmaninoff.gel.ulaval.ca:8005/")
+ time.sleep(5)
+ result_div = driver.find_element(By.ID, "estimated-parameters-display")
+
+ def upload_image(path):
+ path = Path(path).absolute().as_posix()
+ elem = driver.find_element(By.ID, "dash-uploader")
+ inp = driver.execute_script(JS_DROP_FILES, elem, 25, 25)
+ inp._execute("sendKeysToElement", {"value": [path], "text": path})
+
+ def run_image(path, prev_result, timeout_seconds=60):
+ # One main assumption is that subsequent images will have different results
+ # from each other, otherwise we cannot detect that the inference has completed.
+ upload_image(path)
+ started = time.time()
+ while True:
+ result = result_div.text
+ if (result != prev_result) and result:
+ return result
+ prev_result = result
+ if (time.time() - started) > timeout_seconds:
+ raise TimeoutError
+
+ result = str(result_div.text)
+ number = r"(nan|-?\d*\.?\d*)"
+ pattern = re.compile(
+ f"Pitch: {number}° / Roll: {number}° / HFOV : {number}° / Distortion: {number}"
+ )
+
+ paths = sorted(args.images.iterdir())
+ results = {}
+ for path in (pbar := tqdm(paths)):
+ pbar.set_description(path.name)
+ result = run_image(path, result)
+ match = pattern.match(result)
+ if match is None:
+ print("Error, cannot parse:", result, path)
+ continue
+ results[path.name] = tuple(map(float, match.groups()))
+
+ args.results.write_text(json.dumps(results))
+ driver.close()
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("images", type=Path)
+ parser.add_argument("results", type=Path)
+ args = parser.parse_args()
+ run(args)
diff --git a/siclib/eval/simple_pipeline.py b/siclib/eval/simple_pipeline.py
new file mode 100644
index 0000000000000000000000000000000000000000..7271e961bd93192120e41767b44ad3ebbbbda105
--- /dev/null
+++ b/siclib/eval/simple_pipeline.py
@@ -0,0 +1,408 @@
+import logging
+import resource
+from collections import defaultdict
+from pathlib import Path
+from pprint import pprint
+from typing import Dict, List, Tuple
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from omegaconf import OmegaConf
+from tqdm import tqdm
+
+from siclib.datasets import get_dataset
+from siclib.eval.eval_pipeline import EvalPipeline
+from siclib.eval.io import get_eval_parser, load_model, parse_eval_args
+from siclib.eval.utils import download_and_extract_benchmark, plot_scatter_grid
+from siclib.geometry.base_camera import BaseCamera
+from siclib.geometry.camera import Pinhole
+from siclib.geometry.gravity import Gravity
+from siclib.models.cache_loader import CacheLoader
+from siclib.models.utils.metrics import (
+ gravity_error,
+ latitude_error,
+ pitch_error,
+ roll_error,
+ up_error,
+ vfov_error,
+)
+from siclib.settings import EVAL_PATH
+from siclib.utils.conversions import rad2deg
+from siclib.utils.export_predictions import export_predictions
+from siclib.utils.tensor import add_batch_dim
+from siclib.utils.tools import AUCMetric, set_seed
+from siclib.visualization import visualize_batch, viz2d
+
+# flake8: noqa
+# mypy: ignore-errors
+
+logger = logging.getLogger(__name__)
+
+rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
+
+torch.set_grad_enabled(False)
+
+
+def calculate_pixel_projection_error(
+ camera_pred: BaseCamera, camera_gt: BaseCamera, N: int = 500, distortion_only: bool = True
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Calculate the pixel projection error between two cameras.
+
+ 1. Project a grid of points with the ground truth camera to the image plane.
+ 2. Project the same grid of points with the estimated camera to the image plane.
+ 3. Calculate the pixel distance between the ground truth and estimated points.
+
+ Args:
+ camera_pred (Camera): Predicted camera.
+ camera_gt (Camera): Ground truth camera.
+ N (int, optional): Number of points in the grid. Defaults to 500.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Pixel distance and valid pixels.
+ """
+ H, W = camera_gt.size.unbind(-1)
+ H, W = H.int(), W.int()
+
+ assert torch.allclose(
+ camera_gt.size, camera_pred.size
+ ), f"Cameras must have the same size: {camera_gt.size} != {camera_pred.size}"
+
+ if distortion_only:
+ params = camera_gt._data.clone()
+ params[..., -2:] = camera_pred._data[..., -2:]
+ CameraModel = type(camera_gt)
+ camera_pred = CameraModel(params)
+
+ x_gt, y_gt = torch.meshgrid(
+ torch.linspace(0, H - 1, N), torch.linspace(0, W - 1, N), indexing="xy"
+ )
+ xy = torch.stack((x_gt, y_gt), dim=-1).reshape(-1, 2)
+
+ camera_pin_gt = camera_gt.pinhole()
+ uv_pin, _ = camera_pin_gt.image2world(xy)
+
+ # gt
+ xy_undist_gt, valid_dist_gt = camera_gt.world2image(uv_pin)
+ # pred
+ xy_undist, valid_dist = camera_pred.world2image(uv_pin)
+
+ valid = valid_dist_gt & valid_dist
+
+ dist = (xy_undist - xy_undist_gt) ** 2
+ dist = (dist.sum(-1)).sqrt()
+
+ return dist[valid_dist_gt], valid[valid_dist_gt]
+
+
+def compute_camera_metrics(
+ camera_pred: BaseCamera, camera_gt: BaseCamera, thresholds: List[float]
+) -> Dict[str, float]:
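+    """Compute vFoV, focal, and k1 errors, plus pixel projection errors at given thresholds."""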
+ results = defaultdict(list)
+ results["vfov"].append(rad2deg(camera_pred.vfov).item())
+ results["vfov_error"].append(vfov_error(camera_pred, camera_gt).item())
+
+ results["focal"].append(camera_pred.f[..., 1].item())
+ focal_error = torch.abs(camera_pred.f[..., 1] - camera_gt.f[..., 1])
+ results["focal_error"].append(focal_error.item())
+
+ rel_focal_error = torch.abs(camera_pred.f[..., 1] - camera_gt.f[..., 1]) / camera_gt.f[..., 1]
+ results["rel_focal_error"].append(rel_focal_error.item())
+
+ if hasattr(camera_pred, "k1"):
+ results["k1"].append(camera_pred.k1.item())
+ k1_error = torch.abs(camera_pred.k1 - camera_gt.k1)
+ results["k1_error"].append(k1_error.item())
+
+ if thresholds is None:
+ return results
+
+ err, valid = calculate_pixel_projection_error(camera_pred, camera_gt, distortion_only=False)
+ for th in thresholds:
+ results[f"pixel_projection_error@{th}"].append(
+ ((err[valid] < th).sum() / len(valid)).float().item()
+ )
+
+ err, valid = calculate_pixel_projection_error(camera_pred, camera_gt, distortion_only=True)
+ for th in thresholds:
+ results[f"pixel_distortion_error@{th}"].append(
+ ((err[valid] < th).sum() / len(valid)).float().item()
+ )
+ return results
+
+
+def compute_gravity_metrics(gravity_pred: Gravity, gravity_gt: Gravity) -> Dict[str, float]:
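+    """Compute roll, pitch, and overall gravity angular errors in degrees."""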
+ results = defaultdict(list)
+ results["roll"].append(rad2deg(gravity_pred.roll).item())
+ results["pitch"].append(rad2deg(gravity_pred.pitch).item())
+
+ results["roll_error"].append(roll_error(gravity_pred, gravity_gt).item())
+ results["pitch_error"].append(pitch_error(gravity_pred, gravity_gt).item())
+ results["gravity_error"].append(gravity_error(gravity_pred[None], gravity_gt[None]).item())
+ return results
+
+
+class SimplePipeline(EvalPipeline):
+ default_conf = {
+ "data": {},
+ "model": {},
+ "eval": {
+ "thresholds": [1, 5, 10],
+ "pixel_thresholds": [0.5, 1, 3, 5],
+ "num_vis": 10,
+ "verbose": True,
+ },
+ "url": None, # url to benchmark.zip
+ }
+
+ export_keys = [
+ "camera",
+ "gravity",
+ ]
+
+ optional_export_keys = [
+ "focal_uncertainty",
+ "vfov_uncertainty",
+ "roll_uncertainty",
+ "pitch_uncertainty",
+ "gravity_uncertainty",
+ "up_field",
+ "up_confidence",
+ "latitude_field",
+ "latitude_confidence",
+ ]
+
+ def _init(self, conf):
+ self.verbose = conf.eval.verbose
+ self.num_vis = self.conf.eval.num_vis
+
+ self.CameraModel = Pinhole
+
+ if conf.url is not None:
+ ds_dir = Path(conf.data.dataset_dir)
+ download_and_extract_benchmark(ds_dir.name, conf.url, ds_dir.parent)
+
+ @classmethod
+ def get_dataloader(cls, data_conf=None, batch_size=None):
+ """Returns a data loader with samples for each eval datapoint"""
+ data_conf = data_conf or cls.default_conf["data"]
+
+ if batch_size is not None:
+ data_conf["test_batch_size"] = batch_size
+
+ do_shuffle = data_conf["test_batch_size"] > 1
+ dataset = get_dataset(data_conf["name"])(data_conf)
+ return dataset.get_data_loader("test", shuffle=do_shuffle)
+
+ def get_predictions(self, experiment_dir, model=None, overwrite=False):
+ """Export a prediction file for each eval datapoint"""
+ # set_seed(0)
+ pred_file = experiment_dir / "predictions.h5"
+ if not pred_file.exists() or overwrite:
+ if model is None:
+ model = load_model(self.conf.model, self.conf.checkpoint)
+ export_predictions(
+ self.get_dataloader(self.conf.data),
+ model,
+ pred_file,
+ keys=self.export_keys,
+ optional_keys=self.optional_export_keys,
+ verbose=self.verbose,
+ )
+ return pred_file
+
+ def get_figures(self, results):
+ figures = {}
+
+ if self.num_vis == 0:
+ return figures
+
+ gl = ["up", "latitude"]
+ rpf = ["roll", "pitch", "vfov"]
+
+ # check if rpf in results
+ if all(k in results for k in rpf):
+ x_keys = [f"{k}_gt" for k in rpf]
+
+ # gt vs error
+ y_keys = [f"{k}_error" for k in rpf]
+ fig, _ = plot_scatter_grid(results, x_keys, y_keys, show_means=False)
+ figures |= {"rpf_gt_error": fig}
+
+ # gt vs pred
+ y_keys = [f"{k}" for k in rpf]
+ fig, _ = plot_scatter_grid(results, x_keys, y_keys, diag=True, show_means=False)
+ figures |= {"rpf_gt_pred": fig}
+
+ if all(f"{k}_error" in results for k in gl):
+ x_keys = [f"{k}_gt" for k in rpf]
+ y_keys = [f"{k}_error" for k in gl]
+ fig, _ = plot_scatter_grid(results, x_keys, y_keys, show_means=False)
+ figures |= {"gl_gt_error": fig}
+
+ return figures
+
+ def run_eval(self, loader, pred_file):
+ conf = self.conf.eval
+ results = defaultdict(list)
+
+ save_to = Path(pred_file).parent / "figures"
+ if not save_to.exists() and self.num_vis > 0:
+ save_to.mkdir()
+
+ cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()
+
+ if not self.verbose:
+ logger.info(f"Evaluating {pred_file}")
+
+ for i, data in enumerate(
+ tqdm(loader, desc="Evaluating", total=len(loader), ncols=80, disable=not self.verbose)
+ ):
+ # NOTE: data is batched but pred is not
+ pred = cache_loader(data)
+
+ results["names"].append(data["name"][0])
+
+ gt_cam = data["camera"][0]
+ gt_gravity = data["gravity"][0]
+ # add gt parameters
+ results["roll_gt"].append(rad2deg(gt_gravity.roll).item())
+ results["pitch_gt"].append(rad2deg(gt_gravity.pitch).item())
+ results["vfov_gt"].append(rad2deg(gt_cam.vfov).item())
+ results["focal_gt"].append(gt_cam.f[1].item())
+
+ results["k1_gt"].append(gt_cam.k1.item())
+
+ if "camera" in pred:
+ # pred["camera"] is a tensor of the parameters
+ pred_cam = self.CameraModel(pred["camera"])
+
+ pred_camera = pred_cam[None].undo_scale_crop(data)[0]
+ gt_camera = gt_cam[None].undo_scale_crop(data)[0]
+
+ camera_metrics = compute_camera_metrics(
+ pred_camera, gt_camera, conf.pixel_thresholds
+ )
+
+ for k, v in camera_metrics.items():
+ results[k].extend(v)
+
+ if "focal_uncertainty" in pred:
+ focal_uncertainty = pred["focal_uncertainty"]
+ results["focal_uncertainty"].append(focal_uncertainty.item())
+
+ if "vfov_uncertainty" in pred:
+ vfov_uncertainty = rad2deg(pred["vfov_uncertainty"])
+ results["vfov_uncertainty"].append(vfov_uncertainty.item())
+
+ if "gravity" in pred:
+ # pred["gravity"] is a tensor of the parameters
+ pred_gravity = Gravity(pred["gravity"])
+
+ gravity_metrics = compute_gravity_metrics(pred_gravity, gt_gravity)
+ for k, v in gravity_metrics.items():
+ results[k].extend(v)
+
+ if "roll_uncertainty" in pred:
+ roll_uncertainty = rad2deg(pred["roll_uncertainty"])
+ results["roll_uncertainty"].append(roll_uncertainty.item())
+
+ if "pitch_uncertainty" in pred:
+ pitch_uncertainty = rad2deg(pred["pitch_uncertainty"])
+ results["pitch_uncertainty"].append(pitch_uncertainty.item())
+
+ if "gravity_uncertainty" in pred:
+ gravity_uncertainty = rad2deg(pred["gravity_uncertainty"])
+ results["gravity_uncertainty"].append(gravity_uncertainty.item())
+
+ if "up_field" in pred:
+ up_err = up_error(pred["up_field"].unsqueeze(0), data["up_field"])
+ results["up_error"].append(up_err.mean(axis=(1, 2)).item())
+ results["up_med_error"].append(up_err.median().item())
+
+ if "up_confidence" in pred:
+ up_confidence = pred["up_confidence"].unsqueeze(0)
+ weighted_error = (up_err * up_confidence).sum(axis=(1, 2))
+ weighted_error = weighted_error / up_confidence.sum(axis=(1, 2))
+ results["up_weighted_error"].append(weighted_error.item())
+
+ if i < self.num_vis:
+ pred_batched = add_batch_dim(pred)
+ up_fig = visualize_batch.make_up_figure(pred=pred_batched, data=data)
+ up_fig = up_fig["up"]
+ plt.tight_layout()
+ viz2d.save_plot(save_to / f"up-{i}-{up_err.median().item():.3f}.jpg")
+ plt.close()
+
+ if "latitude_field" in pred:
+ lat_err = latitude_error(
+ pred["latitude_field"].unsqueeze(0), data["latitude_field"]
+ )
+ results["latitude_error"].append(lat_err.mean(axis=(1, 2)).item())
+ results["latitude_med_error"].append(lat_err.median().item())
+
+ if "latitude_confidence" in pred:
+ lat_confidence = pred["latitude_confidence"].unsqueeze(0)
+ weighted_error = (lat_err * lat_confidence).sum(axis=(1, 2))
+ weighted_error = weighted_error / lat_confidence.sum(axis=(1, 2))
+ results["latitude_weighted_error"].append(weighted_error.item())
+
+ if i < self.num_vis:
+ pred_batched = add_batch_dim(pred)
+ lat_fig = visualize_batch.make_latitude_figure(pred=pred_batched, data=data)
+ lat_fig = lat_fig["latitude"]
+ plt.tight_layout()
+ viz2d.save_plot(save_to / f"latitude-{i}-{lat_err.median().item():.3f}.jpg")
+ plt.close()
+
+ summaries = {}
+ for k, v in results.items():
+ arr = np.array(v)
+            if not np.issubdtype(arr.dtype, np.number):
+ continue
+
+ if k.endswith("_error") or "recall" in k or "pixel" in k:
+ summaries[f"mean_{k}"] = round(np.nanmean(arr), 3)
+ summaries[f"median_{k}"] = round(np.nanmedian(arr), 3)
+
+ if any(keyword in k for keyword in ["roll", "pitch", "vfov", "gravity"]):
+ if not conf.thresholds:
+ continue
+
+ auc = AUCMetric(
+ elements=arr, thresholds=list(conf.thresholds), min_error=1
+ ).compute()
+ for i, t in enumerate(conf.thresholds):
+ summaries[f"auc_{k}@{t}"] = round(auc[i], 3)
+
+ return summaries, self.get_figures(results), results
+
+
+if __name__ == "__main__":
+ dataset_name = Path(__file__).stem
+ parser = get_eval_parser()
+ args = parser.parse_intermixed_args()
+
+ default_conf = OmegaConf.create(SimplePipeline.default_conf)
+
+ # mingle paths
+ output_dir = Path(EVAL_PATH, dataset_name)
+ output_dir.mkdir(exist_ok=True, parents=True)
+
+ name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf)
+
+ experiment_dir = output_dir / name
+ experiment_dir.mkdir(exist_ok=True)
+
+ pipeline = SimplePipeline(conf)
+ s, f, r = pipeline.run(
+ experiment_dir, overwrite=args.overwrite, overwrite_eval=args.overwrite_eval
+ )
+
+ pprint(s)
+
+ if args.plot:
+ for name, fig in f.items():
+ fig.canvas.manager.set_window_title(name)
+ plt.show()
diff --git a/siclib/eval/stanford2d3d.py b/siclib/eval/stanford2d3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..5380b1a0d68118e9acbdea64a22f4d203962df91
--- /dev/null
+++ b/siclib/eval/stanford2d3d.py
@@ -0,0 +1,89 @@
+import resource
+from pathlib import Path
+from pprint import pprint
+
+import matplotlib.pyplot as plt
+import torch
+from omegaconf import OmegaConf
+
+from siclib.eval.io import get_eval_parser, parse_eval_args
+from siclib.eval.simple_pipeline import SimplePipeline
+from siclib.settings import EVAL_PATH
+
+# flake8: noqa
+# mypy: ignore-errors
+
+rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
+
+torch.set_grad_enabled(False)
+
+
+class Stanford2D3D(SimplePipeline):
+ default_conf = {
+ "data": {
+ "name": "simple_dataset",
+ "dataset_dir": "data/stanford2d3d",
+ "test_img_dir": "${.dataset_dir}/images",
+ "test_csv": "${.dataset_dir}/images.csv",
+ "augmentations": {"name": "identity"},
+ "preprocessing": {"resize": 320, "edge_divisible_by": 32},
+ "test_batch_size": 1,
+ },
+ "model": {},
+ "eval": {
+ "thresholds": [1, 5, 10],
+ "pixel_thresholds": [0.5, 1, 3, 5],
+ "num_vis": 10,
+ "verbose": True,
+ },
+ "url": "https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/stanford2d3d.zip",
+ }
+
+ export_keys = [
+ "camera",
+ "gravity",
+ ]
+
+ optional_export_keys = [
+ "focal_uncertainty",
+ "vfov_uncertainty",
+ "roll_uncertainty",
+ "pitch_uncertainty",
+ "gravity_uncertainty",
+ "up_field",
+ "up_confidence",
+ "latitude_field",
+ "latitude_confidence",
+ ]
+
+
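+# Typical invocation, as a sketch only (the exact flag spellings come from
+# siclib.eval.io.get_eval_parser; --plot and --overwrite are assumed from the argparse
+# destinations used below):
+#   python -m siclib.eval.stanford2d3d --plot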
+if __name__ == "__main__":
+ dataset_name = Path(__file__).stem
+ parser = get_eval_parser()
+ args = parser.parse_intermixed_args()
+
+ default_conf = OmegaConf.create(Stanford2D3D.default_conf)
+
+ # mingle paths
+ output_dir = Path(EVAL_PATH, dataset_name)
+ output_dir.mkdir(exist_ok=True, parents=True)
+
+ name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf)
+
+ experiment_dir = output_dir / name
+ experiment_dir.mkdir(exist_ok=True)
+
+ pipeline = Stanford2D3D(conf)
+ s, f, r = pipeline.run(
+ experiment_dir,
+ overwrite=args.overwrite,
+ overwrite_eval=args.overwrite_eval,
+ )
+
+ pprint(s)
+
+ if args.plot:
+ for name, fig in f.items():
+ fig.canvas.manager.set_window_title(name)
+ plt.show()
diff --git a/siclib/eval/tartanair.py b/siclib/eval/tartanair.py
new file mode 100644
index 0000000000000000000000000000000000000000..e12d78420c7581e2234a00cb0b558954ef545646
--- /dev/null
+++ b/siclib/eval/tartanair.py
@@ -0,0 +1,89 @@
+import resource
+from pathlib import Path
+from pprint import pprint
+
+import matplotlib.pyplot as plt
+import torch
+from omegaconf import OmegaConf
+
+from siclib.eval.io import get_eval_parser, parse_eval_args
+from siclib.eval.simple_pipeline import SimplePipeline
+from siclib.settings import EVAL_PATH
+
+# flake8: noqa
+# mypy: ignore-errors
+
+rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1]))
+
+torch.set_grad_enabled(False)
+
+
+class Tartanair(SimplePipeline):
+ default_conf = {
+ "data": {
+ "name": "simple_dataset",
+ "dataset_dir": "data/tartanair",
+ "test_img_dir": "${.dataset_dir}/images",
+ "test_csv": "${.dataset_dir}/images.csv",
+ "augmentations": {"name": "identity"},
+ "preprocessing": {"resize": 320, "edge_divisible_by": 32},
+ "test_batch_size": 1,
+ },
+ "model": {},
+ "eval": {
+ "thresholds": [1, 5, 10],
+ "pixel_thresholds": [0.5, 1, 3, 5],
+ "num_vis": 10,
+ "verbose": True,
+ },
+ "url": "https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/tartanair.zip",
+ }
+
+ export_keys = [
+ "camera",
+ "gravity",
+ ]
+
+ optional_export_keys = [
+ "focal_uncertainty",
+ "vfov_uncertainty",
+ "roll_uncertainty",
+ "pitch_uncertainty",
+ "gravity_uncertainty",
+ "up_field",
+ "up_confidence",
+ "latitude_field",
+ "latitude_confidence",
+ ]
+
+
+if __name__ == "__main__":
+ dataset_name = Path(__file__).stem
+ parser = get_eval_parser()
+ args = parser.parse_intermixed_args()
+
+ default_conf = OmegaConf.create(Tartanair.default_conf)
+
+ # mingle paths
+ output_dir = Path(EVAL_PATH, dataset_name)
+ output_dir.mkdir(exist_ok=True, parents=True)
+
+ name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf)
+
+ experiment_dir = output_dir / name
+ experiment_dir.mkdir(exist_ok=True)
+
+ pipeline = Tartanair(conf)
+ s, f, r = pipeline.run(
+ experiment_dir,
+ overwrite=args.overwrite,
+ overwrite_eval=args.overwrite_eval,
+ )
+
+ pprint(s)
+
+ if args.plot:
+ for name, fig in f.items():
+ fig.canvas.manager.set_window_title(name)
+ plt.show()
diff --git a/siclib/eval/utils.py b/siclib/eval/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b0ba53848603370aefea2f0fa1b06a5cd2abdd0
--- /dev/null
+++ b/siclib/eval/utils.py
@@ -0,0 +1,116 @@
+import logging
+import shutil
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+logger = logging.getLogger(__name__)
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+def download_and_extract_benchmark(name: str, url: Path, output: Path) -> None:
+ benchmark_dir = output / name
+ if not output.exists():
+ output.mkdir(parents=True)
+
+ if benchmark_dir.exists():
+ logger.info(f"Benchmark {name} already exists at {benchmark_dir}, skipping download.")
+ return
+
+ if name == "stanford2d3d":
+ # prompt user to sign data sharing and usage terms
+ txt = "\n" + "#" * 108 + "\n\n"
+ txt += "To download the Stanford2D3D dataset, you must agree to the terms of use:\n\n"
+ txt += (
+ "https://docs.google.com/forms/d/e/"
+ + "1FAIpQLScFR0U8WEUtb7tgjOhhnl31OrkEs73-Y8bQwPeXgebqVKNMpQ/viewform?c=0&w=1\n\n"
+ )
+ txt += "#" * 108 + "\n\n"
+ txt += "Did you fill out the data sharing and usage terms? [y/n] "
+ choice = input(txt)
+ if choice.lower() != "y":
+ raise ValueError(
+ "You must agree to the terms of use to download the Stanford2D3D dataset."
+ )
+
+ zip_file = output / f"{name}.zip"
+
+ if not zip_file.exists():
+ logger.info(f"Downloading benchmark {name} to {zip_file} from {url}.")
+ torch.hub.download_url_to_file(url, zip_file)
+
+ logger.info(f"Extracting benchmark {name} in {output}.")
+ shutil.unpack_archive(zip_file, output, format="zip")
+ zip_file.unlink()
+
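+# Example for download_and_extract_benchmark (a sketch: the URL is the one used by the tartanair
+# eval config, the output path is illustrative, and the archive is assumed to contain a
+# top-level <name>/ folder):
+#   download_and_extract_benchmark(
+#       "tartanair",
+#       "https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/tartanair.zip",
+#       Path("data"),
+#   )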
+
+def check_keys_recursive(d, pattern):
+ if isinstance(pattern, dict):
+        for k, v in pattern.items():
+            check_keys_recursive(d[k], v)
+ else:
+ for k in pattern:
+ assert k in d.keys()
+
+
+def plot_scatter_grid(
+ results, x_keys, y_keys, name=None, diag=False, ax=None, line_idx=0, show_means=True
+): # sourcery skip: low-code-quality
+ if ax is None:
+ N, M = len(y_keys), len(x_keys)
+ fig, ax = plt.subplots(N, M, figsize=(M * 6, N * 5))
+
+ if N == 1:
+ ax = np.array(ax)
+ ax = ax.reshape(1, -1)
+
+ if M == 1:
+ ax = np.array(ax)
+ ax = ax.reshape(-1, 1)
+ else:
+ fig = None
+
+ for j, kx in enumerate(x_keys):
+ for i, ky in enumerate(y_keys):
+ ax[i, j].scatter(
+ results[kx],
+ results[ky],
+ s=1,
+ alpha=0.5,
+ label=name or None,
+ )
+
+ ax[i, j].set_xlabel(f"{' '.join(kx.split('_')).title()}")
+ ax[i, j].set_ylabel(f"{' '.join(ky.split('_')).title()}")
+
+ low = min(ax[i, j].get_xlim()[0], ax[i, j].get_ylim()[0])
+ high = max(ax[i, j].get_xlim()[1], ax[i, j].get_ylim()[1])
+ if diag == "all" or (i == j and diag):
+ ax[i, j].plot([low, high], [low, high], ls="--", c="red", label="y=x")
+
+ if name or diag == "all" or (i == j and diag):
+ ax[i, j].legend()
+
+ if not show_means:
+ return fig, ax
+
+ means = {"y": {}, "x": {}}
+ for kx in x_keys:
+ for ky in y_keys:
+ means["x"][kx] = np.mean(results[kx])
+ means["y"][ky] = np.mean(results[ky])
+
+ for j, kx in enumerate(x_keys):
+ for i, ky in enumerate(y_keys):
+ xlim = np.min(results[kx]), np.max(results[kx])
+ ylim = np.min(results[ky]), np.max(results[ky])
+ means_x = [means["x"][kx]]
+ means_y = [means["y"][ky]]
+ color = plt.cm.tab10(line_idx)
+ ax[i, j].vlines(means_x, *ylim, colors=[color])
+ ax[i, j].hlines(means_y, *xlim, colors=[color])
+
+ return fig, ax
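+
+
+# Example for plot_scatter_grid (a sketch; the keys are illustrative and must be present in
+# `results`, as in siclib/eval/simple_pipeline.py):
+#   fig, ax = plot_scatter_grid(results, ["roll_gt", "pitch_gt"], ["roll_error", "pitch_error"])
+# draws one scatter plot per (x_key, y_key) pair on a len(y_keys) x len(x_keys) grid of axes.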
diff --git a/siclib/geometry/__init__.py b/siclib/geometry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/siclib/geometry/base_camera.py b/siclib/geometry/base_camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e07615c274f57f3cec64308b5c58dcbb96fc125
--- /dev/null
+++ b/siclib/geometry/base_camera.py
@@ -0,0 +1,523 @@
+# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich
+# https://github.com/cvg/pixloc
+# Released under the Apache License 2.0
+
+"""Convenience classes a for camera models.
+
+Based on PyTorch tensors: differentiable, batched, with GPU support.
+"""
+
+from abc import abstractmethod
+from typing import Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+from torch.func import jacfwd, vmap
+from torch.nn import functional as F
+
+from siclib.geometry.gravity import Gravity
+from siclib.utils.conversions import deg2rad, focal2fov, fov2focal, rad2rotmat
+from siclib.utils.tensor import TensorWrapper, autocast
+
+# mypy: ignore-errors
+
+
+class BaseCamera(TensorWrapper):
+ """Camera tensor class."""
+
+ eps = 1e-3
+
+ @autocast
+ def __init__(self, data: torch.Tensor):
+ """Camera parameters with shape (..., {w, h, fx, fy, cx, cy, *dist}).
+
+        Tensor convention: (..., {w, h, fx, fy, cx, cy, *dist}) where
+        - w, h: image size in pixels
+        - fx, fy: focal lengths in pixels
+        - cx, cy: principal point in pixels
+ - dist: distortion parameters
+
+ Args:
+ data (torch.Tensor): Camera parameters with shape (..., {6, 7, 8}).
+ """
+ # w, h, fx, fy, cx, cy, dist
+ assert data.shape[-1] in {6, 7, 8}, data.shape
+
+ pad = data.new_zeros(data.shape[:-1] + (8 - data.shape[-1],))
+ data = torch.cat([data, pad], -1) if data.shape[-1] != 8 else data
+ super().__init__(data)
+
+ @classmethod
+ def from_dict(cls, param_dict: Dict[str, torch.Tensor]) -> "BaseCamera":
+ """Create a Camera object from a dictionary of parameters.
+
+ Args:
+ param_dict (Dict[str, torch.Tensor]): Dictionary of parameters.
+
+ Returns:
+ Camera: Camera object.
+ """
+ for key, value in param_dict.items():
+ if not isinstance(value, torch.Tensor):
+ param_dict[key] = torch.tensor(value)
+
+ h, w = param_dict["height"], param_dict["width"]
+ cx, cy = param_dict.get("cx", w / 2), param_dict.get("cy", h / 2)
+
+ vfov = param_dict.get("vfov")
+ f = param_dict.get("f", fov2focal(vfov, h))
+
+ if "dist" in param_dict:
+ k1, k2 = param_dict["dist"][..., 0], param_dict["dist"][..., 1]
+ elif "k1_hat" in param_dict:
+ k1 = param_dict["k1_hat"] * (f / h) ** 2
+
+ k2 = param_dict.get("k2", torch.zeros_like(k1))
+ else:
+ k1 = param_dict.get("k1", torch.zeros_like(f))
+ k2 = param_dict.get("k2", torch.zeros_like(f))
+
+ fx, fy = f, f
+ if "scales" in param_dict:
+ scales = param_dict["scales"]
+ fx = fx * scales[..., 0] / scales[..., 1]
+
+ params = torch.stack([w, h, fx, fy, cx, cy, k1, k2], dim=-1)
+ return cls(params)
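+        # Minimal usage sketch (illustrative values; mirrors siclib/geometry/gradient_checker.py),
+        # using a concrete subclass such as SimpleRadial from siclib.geometry.camera:
+        #   cam = SimpleRadial.from_dict(
+        #       {"height": [320], "width": [320], "vfov": [1.05], "k1": [-0.1]}
+        #   )
+        #   cam.f     # focal lengths (fx, fy) in pixels, shape (1, 2)
+        #   cam.vfov  # vertical field of view in radians, shape (1,)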
+
+ def pinhole(self):
+ """Return the pinhole camera model."""
+ return self.__class__(self._data[..., :6])
+
+ @property
+ def size(self) -> torch.Tensor:
+ """Size (width height) of the images, with shape (..., 2)."""
+ return self._data[..., :2]
+
+ @property
+ def f(self) -> torch.Tensor:
+ """Focal lengths (fx, fy) with shape (..., 2)."""
+ return self._data[..., 2:4]
+
+ @property
+ def vfov(self) -> torch.Tensor:
+ """Vertical field of view in radians."""
+ return focal2fov(self.f[..., 1], self.size[..., 1])
+
+ @property
+ def hfov(self) -> torch.Tensor:
+ """Horizontal field of view in radians."""
+ return focal2fov(self.f[..., 0], self.size[..., 0])
+
+ @property
+ def c(self) -> torch.Tensor:
+ """Principal points (cx, cy) with shape (..., 2)."""
+ return self._data[..., 4:6]
+
+ @property
+ def K(self) -> torch.Tensor:
+ """Returns the self intrinsic matrix with shape (..., 3, 3)."""
+ shape = self.shape + (3, 3)
+ K = self._data.new_zeros(shape)
+ K[..., 0, 0] = self.f[..., 0]
+ K[..., 1, 1] = self.f[..., 1]
+ K[..., 0, 2] = self.c[..., 0]
+ K[..., 1, 2] = self.c[..., 1]
+ K[..., 2, 2] = 1
+ return K
+
+ def update_focal(self, delta: torch.Tensor, as_log: bool = False):
+ """Update the self parameters after changing the focal length."""
+ f = torch.exp(torch.log(self.f) + delta) if as_log else self.f + delta
+
+ # clamp focal length to a reasonable range for stability during training
+ min_f = fov2focal(self.new_ones(self.shape[0]) * deg2rad(150), self.size[..., 1])
+ max_f = fov2focal(self.new_ones(self.shape[0]) * deg2rad(5), self.size[..., 1])
+ min_f = min_f.unsqueeze(-1).expand(-1, 2)
+ max_f = max_f.unsqueeze(-1).expand(-1, 2)
+ f = f.clamp(min=min_f, max=max_f)
+
+        # make sure the focal ratio stays the same (avoid in-place operations)
+ fx = f[..., 1] * self.f[..., 0] / self.f[..., 1]
+ f = torch.stack([fx, f[..., 1]], -1)
+
+ dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
+ return self.__class__(torch.cat([self.size, f, self.c, dist], -1))
+
+ def scale(self, scales: Union[float, int, Tuple[Union[float, int]]]):
+ """Update the self parameters after resizing an image."""
+ scales = (scales, scales) if isinstance(scales, (int, float)) else scales
+ s = scales if isinstance(scales, torch.Tensor) else self.new_tensor(scales)
+
+ dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
+ return self.__class__(torch.cat([self.size * s, self.f * s, self.c * s, dist], -1))
+
+ def crop(self, pad: Tuple[float]):
+ """Update the self parameters after cropping an image."""
+ pad = pad if isinstance(pad, torch.Tensor) else self.new_tensor(pad)
+ size = self.size + pad.to(self.size)
+ c = self.c + pad.to(self.c) / 2
+
+ dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape)
+ return self.__class__(torch.cat([size, self.f, c, dist], -1))
+
+ def undo_scale_crop(self, data: Dict[str, torch.Tensor]):
+ """Undo transforms done during scaling and cropping."""
+ camera = self.crop(-data["crop_pad"]) if "crop_pad" in data else self
+ return camera.scale(1.0 / data["scales"])
+
+ @autocast
+ def in_image(self, p2d: torch.Tensor):
+ """Check if 2D points are within the image boundaries."""
+ assert p2d.shape[-1] == 2
+ size = self.size.unsqueeze(-2)
+ return torch.all((p2d >= 0) & (p2d <= (size - 1)), -1)
+
+ @autocast
+ def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Project 3D points into the self plane and check for visibility."""
+ z = p3d[..., -1]
+ valid = z > self.eps
+ z = z.clamp(min=self.eps)
+ p2d = p3d[..., :-1] / z.unsqueeze(-1)
+ return p2d, valid
+
+ def J_project(self, p3d: torch.Tensor):
+ """Jacobian of the projection function."""
+ x, y, z = p3d[..., 0], p3d[..., 1], p3d[..., 2]
+ zero = torch.zeros_like(z)
+ z = z.clamp(min=self.eps)
+ J = torch.stack([1 / z, zero, -x / z**2, zero, 1 / z, -y / z**2], dim=-1)
+ J = J.reshape(p3d.shape[:-1] + (2, 3))
+ return J # N x 2 x 3
+
+ @abstractmethod
+ def distort(self, pts: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
+ """Distort normalized 2D coordinates and check for validity of the distortion model."""
+ raise NotImplementedError("distort() must be implemented.")
+
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
+ """Jacobian of the distortion function."""
+ if wrt == "scale2pts": # (..., 2)
+ J = [
+ vmap(jacfwd(lambda x: self[idx].distort(x, return_scale=True)[0]))(p2d[idx])[None]
+ for idx in range(p2d.shape[0])
+ ]
+
+ return torch.cat(J, dim=0).squeeze(-3, -2)
+
+ elif wrt == "scale2dist": # (..., 1)
+ J = []
+ for idx in range(p2d.shape[0]): # loop to batch pts dimension
+
+ def func(x):
+ params = torch.cat([self._data[idx, :6], x[None]], -1)
+ return self.__class__(params).distort(p2d[idx], return_scale=True)[0]
+
+ J.append(vmap(jacfwd(func))(self[idx].dist))
+
+ return torch.cat(J, dim=0)
+
+ else:
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
+
+ @abstractmethod
+ def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Undistort normalized 2D coordinates and check for validity of the distortion model."""
+ raise NotImplementedError("undistort() must be implemented.")
+
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
+ """Jacobian of the undistortion function."""
+ if wrt == "pts": # (..., 2, 2)
+ J = [
+ vmap(jacfwd(lambda x: self[idx].undistort(x)[0]))(p2d[idx])[None]
+ for idx in range(p2d.shape[0])
+ ]
+
+ return torch.cat(J, dim=0).squeeze(-3)
+
+ elif wrt == "dist": # (..., 1)
+ J = []
+ for batch_idx in range(p2d.shape[0]): # loop to batch pts dimension
+
+ def func(x):
+ params = torch.cat([self._data[batch_idx, :6], x[None]], -1)
+ return self.__class__(params).undistort(p2d[batch_idx])[0]
+
+ J.append(vmap(jacfwd(func))(self[batch_idx].dist))
+
+ return torch.cat(J, dim=0)
+ else:
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
+
+ @autocast
+ def up_projection_offset(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Compute the offset for the up-projection."""
+ return self.J_distort(p2d, wrt="scale2pts") # (B, N, 2)
+
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
+ """Jacobian of the distortion offset for up-projection."""
+ if wrt == "uv": # (B, N, 2, 2)
+ J = [
+ vmap(jacfwd(lambda x: self[idx].up_projection_offset(x)[0, 0]))(p2d[idx])[None]
+ for idx in range(p2d.shape[0])
+ ]
+
+ return torch.cat(J, dim=0)
+
+ elif wrt == "dist": # (B, N, 2)
+ J = []
+ for batch_idx in range(p2d.shape[0]): # loop to batch pts dimension
+
+ def func(x):
+ params = torch.cat([self._data[batch_idx, :6], x[None]], -1)[None]
+ return self.__class__(params).up_projection_offset(p2d[batch_idx][None])
+
+ J.append(vmap(jacfwd(func))(self[batch_idx].dist))
+
+ return torch.cat(J, dim=0).squeeze(1)
+ else:
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
+
+ @autocast
+ def denormalize(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Convert normalized 2D coordinates into pixel coordinates."""
+ return p2d * self.f.unsqueeze(-2) + self.c.unsqueeze(-2)
+
+ def J_denormalize(self):
+ """Jacobian of the denormalization function."""
+ return torch.diag_embed(self.f) # ..., 2 x 2
+
+ @autocast
+ def normalize(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Convert pixel coordinates into normalized 2D coordinates."""
+ return (p2d - self.c.unsqueeze(-2)) / (self.f.unsqueeze(-2))
+
+ def J_normalize(self, p2d: torch.Tensor, wrt: str = "f"):
+ """Jacobian of the normalization function."""
+ # ... x N x 2 x 2
+ if wrt == "f":
+ J_f = -(p2d - self.c.unsqueeze(-2)) / ((self.f.unsqueeze(-2)) ** 2)
+ return torch.diag_embed(J_f)
+ elif wrt == "pts":
+ J_pts = 1 / self.f
+ return torch.diag_embed(J_pts)
+ else:
+ raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}")
+
+ def pixel_coordinates(self) -> torch.Tensor:
+ """Pixel coordinates in self frame.
+
+ Returns:
+ torch.Tensor: Pixel coordinates as a tensor of shape (B, h * w, 2).
+ """
+ w, h = self.size[0].unbind(-1)
+ h, w = h.round().to(int), w.round().to(int)
+
+ # create grid
+ x = torch.arange(0, w, dtype=self.dtype, device=self.device)
+ y = torch.arange(0, h, dtype=self.dtype, device=self.device)
+ x, y = torch.meshgrid(x, y, indexing="xy")
+ xy = torch.stack((x, y), dim=-1).reshape(-1, 2) # shape (h * w, 2)
+
+ # add batch dimension (normalize() would broadcast but we make it explicit)
+ B = self.shape[0]
+ xy = xy.unsqueeze(0).expand(B, -1, -1) # if B > 0 else xy
+
+ return xy.to(self.device).to(self.dtype)
+
+ def normalized_image_coordinates(self) -> torch.Tensor:
+ """Normalized image coordinates in self frame.
+
+ Returns:
+ torch.Tensor: Normalized image coordinates as a tensor of shape (B, h * w, 3).
+ """
+ xy = self.pixel_coordinates()
+ uv1, _ = self.image2world(xy)
+
+ B = self.shape[0]
+ uv1 = uv1.reshape(B, -1, 3)
+ return uv1.to(self.device).to(self.dtype)
+
+ @autocast
+ def pixel_bearing_many(self, p3d: torch.Tensor) -> torch.Tensor:
+ """Get the bearing vectors of pixel coordinates.
+
+ Args:
+ p2d (torch.Tensor): Pixel coordinates as a tensor of shape (..., 3).
+
+ Returns:
+ torch.Tensor: Bearing vectors as a tensor of shape (..., 3).
+ """
+ return F.normalize(p3d, dim=-1)
+
+ @autocast
+ def world2image(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Transform 3D points into 2D pixel coordinates."""
+ p2d, visible = self.project(p3d)
+ p2d, mask = self.distort(p2d)
+ p2d = self.denormalize(p2d)
+ valid = visible & mask & self.in_image(p2d)
+ return p2d, valid
+
+ @autocast
+ def J_world2image(self, p3d: torch.Tensor):
+ """Jacobian of the world2image function."""
+ p2d_proj, valid = self.project(p3d)
+
+ J_dnorm = self.J_denormalize()
+ J_dist = self.J_distort(p2d_proj)
+ J_proj = self.J_project(p3d)
+
+ J = torch.einsum("...ij,...jk,...kl->...il", J_dnorm, J_dist, J_proj)
+ return J, valid
+
+ @autocast
+ def image2world(self, p2d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Transform point in the image plane to 3D world coordinates."""
+ p2d = self.normalize(p2d)
+ p2d, valid = self.undistort(p2d)
+ ones = p2d.new_ones(p2d.shape[:-1] + (1,))
+ p3d = torch.cat([p2d, ones], -1)
+ return p3d, valid
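+        # Together with world2image this forms an (approximate) round trip, up to the accuracy of
+        # the distort/undistort pair (a sketch; `cam` and `pts` are placeholders):
+        #   p3d, valid = cam.image2world(pts)  # pixels -> bearings on the z=1 plane
+        #   p2d, valid = cam.world2image(p3d)  # ... and back to pixel coordinates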
+
+ @autocast
+ def J_image2world(self, p2d: torch.Tensor, wrt: str = "f") -> Tuple[torch.Tensor, torch.Tensor]:
+ """Jacobian of the image2world function."""
+ if wrt == "dist":
+ p2d_norm = self.normalize(p2d)
+ return self.J_undistort(p2d_norm, wrt)
+ elif wrt == "f":
+ J_norm2f = self.J_normalize(p2d, wrt)
+ p2d_norm = self.normalize(p2d)
+ J_dist2norm = self.J_undistort(p2d_norm, "pts")
+
+ return torch.einsum("...ij,...jk->...ik", J_dist2norm, J_norm2f)
+ else:
+ raise ValueError(f"Unknown wrt: {wrt}")
+
+ @autocast
+ def undistort_image(self, img: torch.Tensor) -> torch.Tensor:
+ """Undistort an image using the distortion model."""
+ assert self.shape[0] == 1, "Batch size must be 1."
+ W, H = self.size.unbind(-1)
+ H, W = H.int().item(), W.int().item()
+
+ x, y = torch.arange(0, W), torch.arange(0, H)
+ x, y = torch.meshgrid(x, y, indexing="xy")
+ coords = torch.stack((x, y), dim=-1).reshape(-1, 2)
+
+ p3d, _ = self.pinhole().image2world(coords.to(self.device).to(self.dtype))
+ p2d, _ = self.world2image(p3d)
+
+ mapx, mapy = p2d[..., 0].reshape((1, H, W)), p2d[..., 1].reshape((1, H, W))
+ grid = torch.stack((mapx, mapy), dim=-1)
+ grid = 2.0 * grid / torch.tensor([W - 1, H - 1]).to(grid) - 1
+ return F.grid_sample(img, grid, align_corners=True)
+
+ def get_img_from_pano(
+ self,
+ pano_img: torch.Tensor,
+ gravity: Gravity,
+ yaws: torch.Tensor = 0.0,
+ resize_factor: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Render an image from a panorama.
+
+ Args:
+ pano_img (torch.Tensor): Panorama image of shape (3, H, W) in [0, 1].
+ gravity (Gravity): Gravity direction of the camera.
+ yaws (torch.Tensor | list, optional): Yaw angle in radians. Defaults to 0.0.
+            resize_factor (torch.Tensor, optional): Factor by which to resize the panorama such
+                that the camera's field of view covers resize_factor times the image height.
+                Defaults to None (no resizing).
+
+ Returns:
+ torch.Tensor: Image rendered from the panorama.
+ """
+ B = self.shape[0]
+ if B > 0:
+ assert self.size[..., 0].unique().shape[0] == 1, "All images must have the same width."
+ assert self.size[..., 1].unique().shape[0] == 1, "All images must have the same height."
+
+ w, h = self.size[0].unbind(-1)
+ h, w = h.round().to(int), w.round().to(int)
+
+ if isinstance(yaws, (int, float)):
+ yaws = [yaws]
+ if isinstance(resize_factor, (int, float)):
+ resize_factor = [resize_factor]
+
+ yaws = (
+ yaws.to(self.dtype).to(self.device)
+ if isinstance(yaws, torch.Tensor)
+ else self.new_tensor(yaws)
+ )
+
+ if isinstance(resize_factor, torch.Tensor):
+ resize_factor = resize_factor.to(self.dtype).to(self.device)
+ elif resize_factor is not None:
+ resize_factor = self.new_tensor(resize_factor)
+
+ assert isinstance(pano_img, torch.Tensor), "Panorama image must be a torch.Tensor."
+    pano_img = pano_img if pano_img.dim() == 4 else pano_img.unsqueeze(0)  # B x 3 x H x W
+
+ pano_imgs = []
+ for i, yaw in enumerate(yaws):
+ if resize_factor is not None:
+                # resize the panorama such that the camera's vertical fov spans the same number
+                # of pixels as the image height (scaled by resize_factor)
+ vfov = self.vfov[i] if B != 0 else self.vfov
+ scale = np.pi / float(vfov) * float(h) / pano_img.shape[0] * resize_factor[i]
+ pano_shape = (int(pano_img.shape[0] * scale), int(pano_img.shape[1] * scale))
+
+ # pano_img = pano_img.permute(2, 0, 1).unsqueeze(0)
+ mode = "bicubic" if scale >= 1 else "area"
+ resized_pano = F.interpolate(pano_img, size=pano_shape, mode=mode)
+ else:
+ # make sure to copy: resized_pano = pano_img
+ resized_pano = pano_img
+ pano_shape = pano_img.shape[-2:][::-1]
+
+ pano_imgs.append((resized_pano, pano_shape))
+
+ xy = self.pixel_coordinates()
+ uv1, valid = self.image2world(xy)
+ bearings = self.pixel_bearing_many(uv1)
+
+ # rotate bearings
+ R_yaw = rad2rotmat(self.new_zeros(yaw.shape), self.new_zeros(yaw.shape), yaws)
+ rotated_bearings = bearings @ gravity.R @ R_yaw
+
+ # spherical coordinates
+ lon = torch.atan2(rotated_bearings[..., 0], rotated_bearings[..., 2])
+ lat = torch.atan2(
+ rotated_bearings[..., 1], torch.norm(rotated_bearings[..., [0, 2]], dim=-1)
+ )
+
+ images = []
+ for idx, (resized_pano, pano_shape) in enumerate(pano_imgs):
+ min_lon, max_lon = -torch.pi, torch.pi
+ min_lat, max_lat = -torch.pi / 2.0, torch.pi / 2.0
+ min_x, max_x = 0, pano_shape[0] - 1.0
+ min_y, max_y = 0, pano_shape[1] - 1.0
+
+ # map Spherical Coordinates to Panoramic Coordinates
+ nx = (lon[idx] - min_lon) / (max_lon - min_lon) * (max_x - min_x) + min_x
+ ny = (lat[idx] - min_lat) / (max_lat - min_lat) * (max_y - min_y) + min_y
+
+ # reshape and cast to numpy for remap
+ mapx = nx.reshape((1, h, w))
+ mapy = ny.reshape((1, h, w))
+
+            grid = torch.stack((mapx, mapy), dim=-1)  # sampling grid of shape (1, h, w, 2)
+ # Normalize to [-1, 1]
+ grid = 2.0 * grid / torch.tensor([pano_shape[-2] - 1, pano_shape[-1] - 1]).to(grid) - 1
+ # Apply grid sample
+ image = F.grid_sample(resized_pano, grid, align_corners=True)
+ images.append(image)
+
+ return torch.concatenate(images, 0) if B > 0 else images[0]
+
+ def __repr__(self):
+ """Print the Camera object."""
+ return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}"
diff --git a/siclib/geometry/camera.py b/siclib/geometry/camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..02e1139357036be37e3d15371676a97e67d4fdc9
--- /dev/null
+++ b/siclib/geometry/camera.py
@@ -0,0 +1,280 @@
+"""Implementation of the pinhole, simple radial, and simple divisional camera models."""
+
+from typing import Tuple
+
+import torch
+
+from siclib.geometry.base_camera import BaseCamera
+from siclib.utils.tensor import autocast
+
+# flake8: noqa: E741
+
+# mypy: ignore-errors
+
+
+class Pinhole(BaseCamera):
+ """Implementation of the pinhole camera model."""
+
+ def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
+ """Distort normalized 2D coordinates."""
+ if return_scale:
+ return p2d.new_ones(p2d.shape[:-1] + (1,))
+
+ return p2d, p2d.new_ones((p2d.shape[0], 1)).bool()
+
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
+ """Jacobian of the distortion function."""
+ if wrt == "pts":
+ return torch.eye(2, device=p2d.device, dtype=p2d.dtype).expand(p2d.shape[:-1] + (2, 2))
+ else:
+ raise ValueError(f"Unknown wrt: {wrt}")
+
+ def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Undistort normalized 2D coordinates."""
+ return pts, pts.new_ones((pts.shape[0], 1)).bool()
+
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
+ """Jacobian of the undistortion function."""
+ if wrt == "pts":
+ return torch.eye(2, device=p2d.device, dtype=p2d.dtype).expand(p2d.shape[:-1] + (2, 2))
+ else:
+ raise ValueError(f"Unknown wrt: {wrt}")
+
+
+class SimpleRadial(BaseCamera):
+ """Implementation of the simple radial camera model."""
+
+ @property
+ def dist(self) -> torch.Tensor:
+ """Distortion parameters, with shape (..., 1)."""
+ return self._data[..., 6:]
+
+ @property
+ def k1(self) -> torch.Tensor:
+ """Distortion parameters, with shape (...)."""
+ return self._data[..., 6]
+
+ @property
+ def k1_hat(self) -> torch.Tensor:
+ """Distortion parameters, with shape (...)."""
+ return self.k1 / (self.f[..., 1] / self.size[..., 1]) ** 2
+
+ def update_dist(self, delta: torch.Tensor, dist_range: Tuple[float, float] = (-0.7, 0.7)):
+ """Update the self parameters after changing the k1 distortion parameter."""
+ delta_dist = self.new_ones(self.dist.shape) * delta
+ dist = (self.dist + delta_dist).clamp(*dist_range)
+ data = torch.cat([self.size, self.f, self.c, dist], -1)
+ return self.__class__(data)
+
+ @autocast
+ def check_valid(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Check if the distorted points are valid."""
+ return p2d.new_ones(p2d.shape[:-1]).bool()
+
+ def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
+ """Distort normalized 2D coordinates and check for validity of the distortion model."""
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ radial = 1 + self.k1[..., None, None] * r2
+
+ if return_scale:
+ return radial, None
+
+ return p2d * radial, self.check_valid(p2d)
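+        # For intuition: with k1 = -0.1, a normalized point at radius r = 0.5 (r2 = 0.25) is
+        # scaled by 1 + k1 * r2 = 0.975, i.e. pulled slightly towards the principal point.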
+
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts"):
+ """Jacobian of the distortion function."""
+ k1 = self.k1[..., None, None]
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ if wrt == "pts": # (..., 2, 2)
+ radial = 1 + k1 * r2
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
+ return (2 * k1 * ppT) + torch.diag_embed(radial.expand(radial.shape[:-1] + (2,)))
+ elif wrt == "dist": # (..., 2)
+ return r2 * p2d
+ elif wrt == "scale2dist": # (..., 1)
+ return r2
+ elif wrt == "scale2pts": # (..., 2)
+ return 2 * k1 * p2d
+ else:
+ return super().J_distort(p2d, wrt)
+
+ @autocast
+ def undistort(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Undistort normalized 2D coordinates and check for validity of the distortion model."""
+ b1 = -self.k1[..., None, None]
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ radial = 1 + b1 * r2
+ return p2d * radial, self.check_valid(p2d)
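+        # Note: this is a first-order approximation of the inverse of distort() (scaling by
+        # 1 - k1 * r2 instead of solving for the exact inverse), accurate for small distortions.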
+
+ @autocast
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
+ """Jacobian of the undistortion function."""
+ b1 = -self.k1[..., None, None]
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ if wrt == "dist":
+ return -r2 * p2d
+ elif wrt == "pts":
+ radial = 1 + b1 * r2
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
+ return (2 * b1[..., None] * ppT) + torch.diag_embed(
+ radial.expand(radial.shape[:-1] + (2,))
+ )
+ else:
+ return super().J_undistort(p2d, wrt)
+
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
+ """Jacobian of the up-projection offset."""
+ if wrt == "uv": # (..., 2, 2)
+ return torch.diag_embed((2 * self.k1[..., None, None]).expand(p2d.shape[:-1] + (2,)))
+ elif wrt == "dist":
+ return 2 * p2d # (..., 2)
+ else:
+ return super().J_up_projection_offset(p2d, wrt)
+
+
+class SimpleDivisional(BaseCamera):
+ """Implementation of the simple divisional camera model."""
+
+ @property
+ def dist(self) -> torch.Tensor:
+ """Distortion parameters, with shape (..., 1)."""
+ return self._data[..., 6:]
+
+ @property
+ def k1(self) -> torch.Tensor:
+ """Distortion parameters, with shape (...)."""
+ return self._data[..., 6]
+
+ def update_dist(self, delta: torch.Tensor, dist_range: Tuple[float, float] = (-3.0, 3.0)):
+ """Update the self parameters after changing the k1 distortion parameter."""
+ delta_dist = self.new_ones(self.dist.shape) * delta
+ dist = (self.dist + delta_dist).clamp(*dist_range)
+ data = torch.cat([self.size, self.f, self.c, dist], -1)
+ return self.__class__(data)
+
+ @autocast
+ def check_valid(self, p2d: torch.Tensor) -> torch.Tensor:
+ """Check if the distorted points are valid."""
+ return p2d.new_ones(p2d.shape[:-1]).bool()
+
+ def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]:
+ """Distort normalized 2D coordinates and check for validity of the distortion model."""
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ radial = 1 - torch.sqrt((1 - 4 * self.k1[..., None, None] * r2).clamp(min=0))
+ denom = 2 * self.k1[..., None, None] * r2
+
+ ones = radial.new_ones(radial.shape)
+ radial = torch.where(denom == 0, ones, radial / denom.masked_fill(denom == 0, 1e6))
+
+ if return_scale:
+ return radial, None
+
+ return p2d * radial, self.check_valid(p2d)
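+        # The scale (1 - sqrt(1 - 4 * k1 * r2)) / (2 * k1 * r2) tends to 1 as k1 * r2 -> 0,
+        # which is why the denom == 0 case is mapped to a scale of 1 above.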
+
+ def J_distort(self, p2d: torch.Tensor, wrt: str = "pts"):
+ """Jacobian of the distortion function."""
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ t0 = torch.sqrt((1 - 4 * self.k1[..., None, None] * r2).clamp(min=1e-6))
+ if wrt == "scale2pts": # (B, N, 2)
+ d1 = t0 * 2 * r2
+ d2 = self.k1[..., None, None] * r2**2
+ denom = d1 * d2
+ return p2d * (4 * d2 - (1 - t0) * d1) / denom.masked_fill(denom == 0, 1e6)
+
+ elif wrt == "scale2dist":
+ d1 = 2 * self.k1[..., None, None] * t0
+ d2 = 2 * r2 * self.k1[..., None, None] ** 2
+ denom = d1 * d2
+ return (2 * d2 - (1 - t0) * d1) / denom.masked_fill(denom == 0, 1e6)
+
+ else:
+ return super().J_distort(p2d, wrt)
+
+ @autocast
+ def undistort(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]:
+ """Undistort normalized 2D coordinates and check for validity of the distortion model."""
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ denom = 1 + self.k1[..., None, None] * r2
+ radial = 1 / denom.masked_fill(denom == 0, 1e6)
+ return p2d * radial, self.check_valid(p2d)
+
+ def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor:
+ """Jacobian of the undistortion function."""
+ # return super().J_undistort(p2d, wrt)
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ k1 = self.k1[..., None, None]
+ if wrt == "dist":
+ denom = (1 + k1 * r2) ** 2
+ return -r2 / denom.masked_fill(denom == 0, 1e6) * p2d
+ elif wrt == "pts":
+ t0 = 1 + k1 * r2
+ t0 = t0.masked_fill(t0 == 0, 1e6)
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
+ J = torch.diag_embed((1 / t0).expand(p2d.shape[:-1] + (2,)))
+ return J - 2 * k1[..., None] * ppT / t0[..., None] ** 2 # (..., N, 2, 2)
+
+ else:
+ return super().J_undistort(p2d, wrt)
+
+ def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
+ """Jacobian of the up-projection offset.
+
+ func(uv, dist) = 4 / (2 * norm2(uv)^2 * (1-4*k1*norm2(uv)^2)^0.5) * uv
+ - (1-(1-4*k1*norm2(uv)^2)^0.5) / (k1 * norm2(uv)^4) * uv
+ """
+ k1 = self.k1[..., None, None]
+ r2 = torch.sum(p2d**2, -1, keepdim=True)
+ t0 = (1 - 4 * k1 * r2).clamp(min=1e-6)
+ t1 = torch.sqrt(t0)
+ if wrt == "dist":
+ denom = 4 * t0 ** (3 / 2)
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = 16 / denom
+
+ denom = r2 * t1 * k1
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = J - 2 / denom
+
+ denom = (r2 * k1) ** 2
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = J + (1 - t1) / denom
+
+ return J * p2d
+ elif wrt == "uv":
+ # ! unstable (gradient checker might fail), rewrite to use single division (by denom)
+ ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2)
+
+ denom = 2 * r2 * t1
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = torch.diag_embed((4 / denom).expand(p2d.shape[:-1] + (2,)))
+
+ denom = 4 * t1 * r2**2
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = J - 16 / denom[..., None] * ppT
+
+ denom = 4 * r2 * t0 ** (3 / 2)
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = J + (32 * k1[..., None]) / denom[..., None] * ppT
+
+ denom = r2**2 * t1
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = J - 4 / denom[..., None] * ppT
+
+ denom = k1 * r2**3
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = J + (4 * (1 - t1) / denom)[..., None] * ppT
+
+ denom = k1 * r2**2
+ denom = denom.masked_fill(denom == 0, 1e6)
+ J = J - torch.diag_embed(((1 - t1) / denom).expand(p2d.shape[:-1] + (2,)))
+
+ return J
+ else:
+ return super().J_up_projection_offset(p2d, wrt)
+
+
+camera_models = {
+ "pinhole": Pinhole,
+ "simple_radial": SimpleRadial,
+ "simple_divisional": SimpleDivisional,
+}
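+
+# Usage sketch (mirrors siclib/geometry/gradient_checker.py; the values are illustrative):
+#   Camera = camera_models["simple_radial"]
+#   cam = Camera.from_dict({"height": [320], "width": [320], "f": [300.0], "k1": [-0.1]})
+#   uv = cam.normalize(cam.pixel_coordinates())  # normalized grid coordinates, (1, H*W, 2)
+#   p2d, valid = cam.distort(uv)                 # distorted normalized coordinates and validity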
diff --git a/siclib/geometry/gradient_checker.py b/siclib/geometry/gradient_checker.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e55fbfa9a9d5a4a5bc33724902f0a79a5906567
--- /dev/null
+++ b/siclib/geometry/gradient_checker.py
@@ -0,0 +1,571 @@
+"""Test gradient implementations."""
+
+import logging
+import unittest
+
+import torch
+from torch.func import jacfwd, vmap
+
+from siclib.geometry.camera import camera_models
+from siclib.geometry.gravity import Gravity
+from siclib.geometry.jacobians import J_up_projection
+from siclib.geometry.manifolds import SphericalManifold
+from siclib.geometry.perspective_fields import J_perspective_field, get_perspective_field
+from siclib.models.optimization.lm_optimizer import LMOptimizer
+from siclib.utils.conversions import deg2rad, fov2focal
+
+# flake8: noqa E731
+# mypy: ignore-errors
+
+H, W = 320, 320
+
+K1 = -0.1
+
+# CAMERA_MODEL = "pinhole"
+CAMERA_MODEL = "simple_radial"
+# CAMERA_MODEL = "simple_divisional"
+
+Camera = camera_models[CAMERA_MODEL]
+
+# detect anomaly
+torch.autograd.set_detect_anomaly(True)
+
+
+logger = logging.getLogger("geocalib.models.base_model")
+logger.setLevel("ERROR")
+
+
+def get_toy_rpf(roll=None, pitch=None, vfov=None) -> torch.Tensor:
+ """Return a random roll, pitch, focal length if not specified."""
+ if roll is None:
+ roll = deg2rad((torch.rand(1) - 0.5) * 90) # -45 ~ 45
+ elif not isinstance(roll, torch.Tensor):
+ roll = torch.tensor(deg2rad(roll)).unsqueeze(0)
+
+ if pitch is None:
+ pitch = deg2rad((torch.rand(1) - 0.5) * 90) # -45 ~ 45
+ elif not isinstance(pitch, torch.Tensor):
+ pitch = torch.tensor(deg2rad(pitch)).unsqueeze(0)
+
+ if vfov is None:
+ vfov = deg2rad(5 + torch.rand(1) * 75) # 5 ~ 80
+ elif not isinstance(vfov, torch.Tensor):
+ vfov = torch.tensor(deg2rad(vfov)).unsqueeze(0)
+
+ return torch.stack([roll, pitch, fov2focal(vfov, H)], dim=-1).float()
+
+
+class TestJacobianFunctions(unittest.TestCase):
+ """Test the jacobian functions."""
+
+ eps = 5e-3
+
+ def validate(self, J: torch.Tensor, J_auto: torch.Tensor):
+ """Check if the jacobians are close and finite."""
+ self.assertTrue(torch.all(torch.isfinite(J)), "found nan in numerical")
+ self.assertTrue(torch.all(torch.isfinite(J_auto)), "found nan in auto")
+
+ text_j = f" > {self.eps}\nJ:\n{J[0, 0].numpy()}\nJ_auto:\n{J_auto[0, 0].numpy()}"
+ max_diff = torch.max(torch.abs(J - J_auto))
+ text = f"Overall - max diff is {max_diff:.4f}" + text_j
+ self.assertTrue(torch.allclose(J, J_auto, atol=self.eps), text)
+
+ def test_spherical_plus(self):
+ """Test the spherical plus operator."""
+ rpf = get_toy_rpf()
+ gravity = Gravity.from_rp(rpf[..., 0], rpf[..., 1])
+ J = gravity.J_update(spherical=True)
+
+ # auto jacobian
+ delta = gravity.vec3d.new_zeros(gravity.vec3d.shape)[..., :-1]
+
+ def spherical_plus(delta: torch.Tensor) -> torch.Tensor:
+ """Plus operator."""
+ return SphericalManifold.plus(gravity.vec3d, delta)
+
+ J_auto = vmap(jacfwd(spherical_plus))(delta).squeeze(0)
+
+ self.validate(J, J_auto)
+
+ def test_up_projection_uv(self):
+ """Test the up projection jacobians."""
+ rpf = get_toy_rpf()
+
+ r, p, f = rpf.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": [K1]})
+ gravity = Gravity.from_rp(r, p)
+ uv = camera.normalize(camera.pixel_coordinates())
+
+ J = J_up_projection(uv, gravity.vec3d, "uv")
+
+ # auto jacobian
+ def projection_uv(uv: torch.Tensor) -> torch.Tensor:
+ """Projection."""
+ abc = gravity.vec3d
+ projected_up2d = abc[..., None, :2] - abc[..., 2, None, None] * uv
+ return projected_up2d[0, 0]
+
+ J_auto = vmap(jacfwd(projection_uv))(uv[0])[None]
+
+ self.validate(J, J_auto)
+
+ def test_up_projection_abc(self):
+ """Test the up projection jacobians."""
+ rpf = get_toy_rpf()
+
+ r, p, f = rpf.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": [K1]})
+ gravity = Gravity.from_rp(r, p)
+ uv = camera.normalize(camera.pixel_coordinates())
+ J = J_up_projection(uv, gravity.vec3d, "abc")
+
+ # auto jacobian
+ def projection_abc(abc: torch.Tensor) -> torch.Tensor:
+ """Projection."""
+ return abc[..., None, :2] - abc[..., 2, None, None] * uv
+
+ J_auto = vmap(jacfwd(projection_abc))(gravity.vec3d)[0]
+
+ self.validate(J, J_auto)
+
+ def test_undistort_pts(self):
+ """Test the undistortion jacobians."""
+ if CAMERA_MODEL == "pinhole":
+ return
+
+ rpf = get_toy_rpf()
+ _, _, f = rpf.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": [K1]})
+ uv = camera.normalize(camera.pixel_coordinates())
+ J = camera.J_undistort(uv, "pts")
+
+ # auto jacobian
+ def func_pts(pts):
+ return camera.undistort(pts)[0][0]
+
+ J_auto = vmap(jacfwd(func_pts))(uv[0])[None].squeeze(-3)
+
+ self.validate(J, J_auto)
+
+ def test_undistort_k1(self):
+ """Test the undistortion jacobians."""
+ if CAMERA_MODEL == "pinhole":
+ return
+
+ rpf = get_toy_rpf()
+ _, _, f = rpf.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": [K1]})
+ uv = camera.normalize(camera.pixel_coordinates())
+ J = camera.J_undistort(uv, "dist")
+
+ # auto jacobian
+ def func_k1(k1):
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1})
+ return camera.undistort(uv)[0][0]
+
+ J_auto = vmap(jacfwd(func_k1))(camera.dist[..., :1]).squeeze(-1)
+
+ self.validate(J, J_auto)
+
+ def test_up_projection_offset(self):
+ """Test the up projection offset jacobians."""
+ if CAMERA_MODEL == "pinhole":
+ return
+
+ rpf = get_toy_rpf()
+ # J = up_projection_offset(rpf)
+ _, _, f = rpf.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": [K1]})
+ uv = camera.normalize(camera.pixel_coordinates())
+ J = camera.up_projection_offset(uv)
+
+ # auto jacobian
+ def projection_uv(uv: torch.Tensor) -> torch.Tensor:
+ """Projection."""
+ s, _ = camera.distort(uv, return_scale=True)
+ return s[0, 0, 0]
+
+ J_auto = vmap(jacfwd(projection_uv))(uv[0])[None].squeeze(-2)
+
+ self.validate(J, J_auto)
+
+ def test_J_up_projection_offset_uv(self):
+ """Test the up projection offset jacobians."""
+ if CAMERA_MODEL == "pinhole":
+ return
+
+ rpf = get_toy_rpf()
+ _, _, f = rpf.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": [K1]})
+ uv = camera.normalize(camera.pixel_coordinates())
+ J = camera.J_up_projection_offset(uv, "uv")
+
+ # auto jacobian
+ def projection_uv(uv: torch.Tensor) -> torch.Tensor:
+ """Projection."""
+ return camera.up_projection_offset(uv)[0, 0]
+
+ J_auto = vmap(jacfwd(projection_uv))(uv[0])[None]
+
+ # print(J.shape, J_auto.shape)
+
+ self.validate(J, J_auto)
+
+
+class TestEuclidean(unittest.TestCase):
+ """Test the Euclidean manifold jacobians."""
+
+ eps = 5e-3
+
+ def validate(self, J: torch.Tensor, J_auto: torch.Tensor):
+ """Check if the jacobians are close and finite."""
+ self.assertTrue(torch.all(torch.isfinite(J)), "found nan in numerical")
+ self.assertTrue(torch.all(torch.isfinite(J_auto)), "found nan in auto")
+
+ # print(f"analytical:\n{J[0, 0, 0].numpy()}\nauto:\n{J_auto[0, 0, 0].numpy()}")
+
+ text_j = f" > {self.eps}\nJ:\n{J[0, 0, 0].numpy()}\nJ_auto:\n{J_auto[0, 0, 0].numpy()}"
+
+ J_up2grav = J[..., :2, :2]
+ J_up2grav_auto = J_auto[..., :2, :2]
+ max_diff = torch.max(torch.abs(J_up2grav - J_up2grav_auto))
+ text = f"UP - GRAV max diff is {max_diff:.4f}" + text_j
+ self.assertTrue(torch.allclose(J_up2grav, J_up2grav_auto, atol=self.eps), text)
+
+ J_up2focal = J[..., :2, 2]
+ J_up2focal_auto = J_auto[..., :2, 2]
+ max_diff = torch.max(torch.abs(J_up2focal - J_up2focal_auto))
+ text = f"UP - FOCAL max diff is {max_diff:.4f}" + text_j
+ self.assertTrue(torch.allclose(J_up2focal, J_up2focal_auto, atol=self.eps), text)
+
+ if CAMERA_MODEL != "pinhole":
+ J_up2k1 = J[..., :2, 3]
+ J_up2k1_auto = J_auto[..., :2, 3]
+ max_diff = torch.max(torch.abs(J_up2k1 - J_up2k1_auto))
+ text = f"UP - K1 max diff is {max_diff:.4f}" + text_j
+ self.assertTrue(torch.allclose(J_up2k1, J_up2k1_auto, atol=self.eps), text)
+
+ J_lat2grav = J[..., 2:, :2]
+ J_lat2grav_auto = J_auto[..., 2:, :2]
+ max_diff = torch.max(torch.abs(J_lat2grav - J_lat2grav_auto))
+ text = f"LAT - GRAV max diff is {max_diff:.4f}" + text_j
+ self.assertTrue(torch.allclose(J_lat2grav, J_lat2grav_auto, atol=self.eps), text)
+
+ J_lat2focal = J[..., 2:, 2]
+ J_lat2focal_auto = J_auto[..., 2:, 2]
+ max_diff = torch.max(torch.abs(J_lat2focal - J_lat2focal_auto))
+ text = f"LAT - FOCAL max diff is {max_diff:.4f}" + text_j
+ self.assertTrue(torch.allclose(J_lat2focal, J_lat2focal_auto, atol=self.eps), text)
+
+ if CAMERA_MODEL != "pinhole":
+ J_lat2k1 = J[..., 2:, 3]
+ J_lat2k1_auto = J_auto[..., 2:, 3]
+ max_diff = torch.max(torch.abs(J_lat2k1 - J_lat2k1_auto))
+ text = f"LAT - K1 max diff is {max_diff:.4f}" + text_j
+ self.assertTrue(torch.allclose(J_lat2k1, J_lat2k1_auto, atol=self.eps), text)
+
+ max_diff = torch.max(torch.abs(J - J_auto[..., : J.shape[-1]]))
+ text = f"Overall - max diff is {max_diff:.4f}" + text_j
+ self.assertTrue(torch.allclose(J, J_auto[..., : J.shape[-1]], atol=self.eps), text)
+
+ def local_pf_calc(self, rpfk: torch.Tensor):
+ """Calculate the perspective field."""
+ r, p, f, k1 = rpfk.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1})
+ gravity = Gravity.from_rp(r, p)
+ up, lat = get_perspective_field(camera, gravity)
+ persp = torch.cat([up, torch.sin(lat)], dim=-3)
+ return persp.permute(0, 2, 3, 1).reshape(1, -1, 3)
+
+ def test_random(self):
+ """Random rpf."""
+ rpf = get_toy_rpf()
+ rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1)
+ r, p, f, k1 = rpfk.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1})
+ gravity = Gravity.from_rp(r, p)
+
+ J = torch.cat(J_perspective_field(camera, gravity, spherical=False), -2)
+ J_auto = jacfwd(self.local_pf_calc)(rpfk).squeeze(-2, -3).reshape(1, H, W, 3, 4)
+
+ self.validate(J, J_auto)
+
+ def test_zero_roll(self):
+ """Roll = 0."""
+ rpf = get_toy_rpf(roll=0)
+ rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1)
+ r, p, f, k1 = rpfk.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1})
+ gravity = Gravity.from_rp(r, p)
+
+ J = torch.cat(J_perspective_field(camera, gravity, spherical=False), -2)
+ J_auto = jacfwd(self.local_pf_calc)(rpfk).squeeze(-2, -3).reshape(1, H, W, 3, 4)
+
+ self.validate(J, J_auto)
+
+ def test_zero_pitch(self):
+ """Pitch = 0."""
+ rpf = get_toy_rpf(pitch=0)
+ rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1)
+ r, p, f, k1 = rpfk.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1})
+ gravity = Gravity.from_rp(r, p)
+
+ J = torch.cat(J_perspective_field(camera, gravity, spherical=False), -2)
+ J_auto = jacfwd(self.local_pf_calc)(rpfk).squeeze(-2, -3).reshape(1, H, W, 3, 4)
+
+ self.validate(J, J_auto)
+
+ def test_max_roll(self):
+ """Roll = -45, 45."""
+ for roll in [-45, 45]:
+ rpf = get_toy_rpf(roll=roll)
+ rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1)
+ r, p, f, k1 = rpfk.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1})
+ gravity = Gravity.from_rp(r, p)
+
+ J = torch.cat(J_perspective_field(camera, gravity, spherical=False), -2)
+ J_auto = jacfwd(self.local_pf_calc)(rpfk).squeeze(-2, -3).reshape(1, H, W, 3, 4)
+
+ self.validate(J, J_auto)
+
+ def test_max_pitch(self):
+ """Pitch = -45, 45."""
+ for pitch in [-45, 45]:
+ rpf = get_toy_rpf(pitch=pitch)
+ rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1)
+ r, p, f, k1 = rpfk.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1})
+ gravity = Gravity.from_rp(r, p)
+
+ J = torch.cat(J_perspective_field(camera, gravity, spherical=False), -2)
+ J_auto = jacfwd(self.local_pf_calc)(rpfk).squeeze(-2, -3).reshape(1, H, W, 3, 4)
+
+ self.validate(J, J_auto)
+
+
+class TestSpherical(unittest.TestCase):
+ """Test the spherical manifold jacobians."""
+
+ eps = 5e-3
+
+ def validate(self, J: torch.Tensor, J_auto: torch.Tensor):
+ """Check if the jacobians are close and finite."""
+ self.assertTrue(torch.all(torch.isfinite(J)), "found nan in numerical")
+ self.assertTrue(torch.all(torch.isfinite(J_auto)), "found nan in auto")
+
+ text_j = f" > {self.eps}\nJ:\n{J[0, 0, 0].numpy()}\nJ_auto:\n{J_auto[0, 0, 0].numpy()}"
+
+ J_up2grav = J[..., :2, :2]
+ J_up2grav_auto = J_auto[..., :2, :2]
+ max_diff = torch.max(torch.abs(J_up2grav - J_up2grav_auto))
+ text = f"UP - GRAV max diff is {max_diff:.4f}" + text_j
+
+ self.assertTrue(torch.allclose(J_up2grav, J_up2grav_auto, atol=self.eps), text)
+
+ J_up2focal = J[..., :2, 2]
+ J_up2focal_auto = J_auto[..., :2, 2]
+ max_diff = torch.max(torch.abs(J_up2focal - J_up2focal_auto))
+ text = f"UP - FOCAL max diff is {max_diff:.4f}" + text_j
+
+ self.assertTrue(torch.allclose(J_up2focal, J_up2focal_auto, atol=self.eps), text)
+
+ if CAMERA_MODEL != "pinhole":
+ J_up2k1 = J[..., :2, 3]
+ J_up2k1_auto = J_auto[..., :2, 3]
+ max_diff = torch.max(torch.abs(J_up2k1 - J_up2k1_auto))
+ text = f"UP - K1 max diff is {max_diff:.4f}" + text_j
+ self.assertTrue(torch.allclose(J_up2k1, J_up2k1_auto, atol=self.eps), text)
+
+ J_lat2grav = J[..., 2:, :2]
+ J_lat2grav_auto = J_auto[..., 2:, :2]
+ max_diff = torch.max(torch.abs(J_lat2grav - J_lat2grav_auto))
+ text = f"LAT - GRAV max diff is {max_diff:.4f}" + text_j
+ self.assertTrue(torch.allclose(J_lat2grav, J_lat2grav_auto, atol=self.eps), text)
+
+ J_lat2focal = J[..., 2:, 2]
+ J_lat2focal_auto = J_auto[..., 2:, 2]
+ max_diff = torch.max(torch.abs(J_lat2focal - J_lat2focal_auto))
+ text = f"LAT - FOCAL max diff is {max_diff:.4f}" + text_j
+ self.assertTrue(torch.allclose(J_lat2focal, J_lat2focal_auto, atol=self.eps), text)
+
+ if CAMERA_MODEL != "pinhole":
+ J_lat2k1 = J[..., 2:, 3]
+ J_lat2k1_auto = J_auto[..., 2:, 3]
+ max_diff = torch.max(torch.abs(J_lat2k1 - J_lat2k1_auto))
+ text = f"LAT - K1 max diff is {max_diff:.4f}" + text_j
+ self.assertTrue(torch.allclose(J_lat2k1, J_lat2k1_auto, atol=self.eps), text)
+
+ max_diff = torch.max(torch.abs(J - J_auto[..., : J.shape[-1]]))
+ text = f"Overall - max diff is {max_diff:.4f}" + text_j
+ self.assertTrue(torch.allclose(J, J_auto[..., : J.shape[-1]], atol=self.eps), text)
+
+ def local_pf_calc(self, uvfk: torch.Tensor, gravity: Gravity):
+ """Calculate the perspective field."""
+ delta, f, k1 = uvfk[..., :2], uvfk[..., 2], uvfk[..., 3]
+ cam = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1})
+ up, lat = get_perspective_field(cam, gravity.update(delta, spherical=True))
+ persp = torch.cat([up, torch.sin(lat)], dim=-3)
+ return persp.permute(0, 2, 3, 1).reshape(1, -1, 3)
+
+ def test_random(self):
+ """Test random rpf."""
+ rpf = get_toy_rpf()
+ rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1)
+ r, p, f, k1 = rpfk.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1})
+ gravity = Gravity.from_rp(r, p)
+
+ J = torch.cat(J_perspective_field(camera, gravity, spherical=True), -2)
+
+ uvfk = torch.zeros_like(rpfk)
+ uvfk[..., 2] = f
+ uvfk[..., 3] = k1
+ func = lambda uvfk: self.local_pf_calc(uvfk, gravity)
+ J_auto = jacfwd(func)(uvfk).squeeze(-2).reshape(1, H, W, 3, 4)
+
+ self.validate(J, J_auto)
+
+ def test_zero_roll(self):
+ """Test roll = 0."""
+ rpf = get_toy_rpf(roll=0)
+ rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1)
+ r, p, f, k1 = rpfk.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1})
+ gravity = Gravity.from_rp(r, p)
+
+ J = torch.cat(J_perspective_field(camera, gravity, spherical=True), -2)
+
+ uvfk = torch.zeros_like(rpfk)
+ uvfk[..., 2] = f
+ uvfk[..., 3] = k1
+ func = lambda uvfk: self.local_pf_calc(uvfk, gravity)
+ J_auto = jacfwd(func)(uvfk).squeeze(-2).reshape(1, H, W, 3, 4)
+
+ self.validate(J, J_auto)
+
+ def test_zero_pitch(self):
+ """Test pitch = 0."""
+ rpf = get_toy_rpf(pitch=0)
+ rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1)
+ r, p, f, k1 = rpfk.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1})
+ gravity = Gravity.from_rp(r, p)
+
+ J = torch.cat(J_perspective_field(camera, gravity, spherical=True), -2)
+
+ uvfk = torch.zeros_like(rpfk)
+ uvfk[..., 2] = f
+ uvfk[..., 3] = k1
+ func = lambda uvfk: self.local_pf_calc(uvfk, gravity)
+ J_auto = jacfwd(func)(uvfk).squeeze(-2).reshape(1, H, W, 3, 4)
+
+ self.validate(J, J_auto)
+
+ def test_max_roll(self):
+ """Test roll = -45, 45."""
+ for roll in [-45, 45]:
+ rpf = get_toy_rpf(roll=roll)
+ rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1)
+ r, p, f, k1 = rpfk.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1})
+ gravity = Gravity.from_rp(r, p)
+
+ J = torch.cat(J_perspective_field(camera, gravity, spherical=True), -2)
+
+ uvfk = torch.zeros_like(rpfk)
+ uvfk[..., 2] = f
+ uvfk[..., 3] = k1
+ func = lambda uvfk: self.local_pf_calc(uvfk, gravity)
+ J_auto = jacfwd(func)(uvfk).squeeze(-2).reshape(1, H, W, 3, 4)
+
+ self.validate(J, J_auto)
+
+ def test_max_pitch(self):
+ """Test pitch = -45, 45."""
+ for pitch in [-45, 45]:
+ rpf = get_toy_rpf(pitch=pitch)
+ rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1)
+ r, p, f, k1 = rpfk.unbind(dim=-1)
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1})
+ gravity = Gravity.from_rp(r, p)
+
+ J = torch.cat(J_perspective_field(camera, gravity, spherical=True), -2)
+
+ uvfk = torch.zeros_like(rpfk)
+ uvfk[..., 2] = f
+ uvfk[..., 3] = k1
+ func = lambda uvfk: self.local_pf_calc(uvfk, gravity)
+ J_auto = jacfwd(func)(uvfk).squeeze(-2).reshape(1, H, W, 3, 4)
+
+ self.validate(J, J_auto)
+
+
+class TestLM(unittest.TestCase):
+ """Test the LM optimizer."""
+
+ eps = 1e-3
+
+ def test_random_spherical(self):
+ """Test random rpf."""
+ rpf = get_toy_rpf()
+ gravity = Gravity.from_rp(rpf[..., 0], rpf[..., 1])
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": rpf[..., 2], "k1": [K1]})
+
+ up, lat = get_perspective_field(camera, gravity)
+
+ lm = LMOptimizer({"use_spherical_manifold": True, "camera_model": CAMERA_MODEL})
+
+ out = lm({"up_field": up, "latitude_field": lat})
+
+ cam_opt = out["camera"]
+ gravity_opt = out["gravity"]
+
+ if hasattr(cam_opt, "k1"):
+ text = f"cam_opt: {cam_opt.k1.numpy()} | rpf: {[K1]}"
+ self.assertTrue(
+ torch.allclose(cam_opt.k1, torch.tensor([K1]).float(), atol=self.eps), text
+ )
+
+ text = f"cam_opt: {cam_opt.f[..., 1].numpy()} | rpf: {rpf[..., 2].numpy()}"
+ self.assertTrue(torch.allclose(cam_opt.f[..., 1], rpf[..., 2], atol=self.eps), text)
+
+ text = f"gravity_opt.roll: {gravity_opt.roll.numpy()} | rpf: {rpf[..., 0].numpy()}"
+ self.assertTrue(torch.allclose(gravity_opt.roll, rpf[..., 0], atol=self.eps), text)
+
+ text = f"gravity_opt.pitch: {gravity_opt.pitch.numpy()} | rpf: {rpf[..., 1].numpy()}"
+ self.assertTrue(torch.allclose(gravity_opt.pitch, rpf[..., 1], atol=self.eps), text)
+
+ def test_random(self):
+ """Test random rpf."""
+ rpf = get_toy_rpf()
+ gravity = Gravity.from_rp(rpf[..., 0], rpf[..., 1])
+ camera = Camera.from_dict({"height": [H], "width": [W], "f": rpf[..., 2], "k1": [K1]})
+
+ up, lat = get_perspective_field(camera, gravity)
+
+ lm = LMOptimizer({"use_spherical_manifold": False, "camera_model": CAMERA_MODEL})
+ out = lm({"up_field": up, "latitude_field": lat})
+
+ cam_opt = out["camera"]
+ gravity_opt = out["gravity"]
+
+ if hasattr(cam_opt, "k1"):
+ text = f"cam_opt: {cam_opt.k1.numpy()} | rpf: {[K1]}"
+ self.assertTrue(
+ torch.allclose(cam_opt.k1, torch.tensor([K1]).float(), atol=self.eps), text
+ )
+
+ text = f"cam_opt: {cam_opt.f[..., 1].numpy()} | rpf: {rpf[..., 2].numpy()}"
+ self.assertTrue(torch.allclose(cam_opt.f[..., 1], rpf[..., 2], atol=self.eps), text)
+
+ text = f"gravity_opt.roll: {gravity_opt.roll.numpy()} | rpf: {rpf[..., 0].numpy()}"
+ self.assertTrue(torch.allclose(gravity_opt.roll, rpf[..., 0], atol=self.eps), text)
+
+ text = f"gravity_opt.pitch: {gravity_opt.pitch.numpy()} | rpf: {rpf[..., 1].numpy()}"
+ self.assertTrue(torch.allclose(gravity_opt.pitch, rpf[..., 1], atol=self.eps), text)
+
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/siclib/geometry/gravity.py b/siclib/geometry/gravity.py
new file mode 100644
index 0000000000000000000000000000000000000000..18e592f19261375e2c44c1305b809aaef303faf9
--- /dev/null
+++ b/siclib/geometry/gravity.py
@@ -0,0 +1,128 @@
+"""Tensor class for gravity vector in camera frame."""
+
+import torch
+from torch.nn import functional as F
+
+from siclib.geometry.manifolds import EuclideanManifold, SphericalManifold
+from siclib.utils.conversions import rad2rotmat
+from siclib.utils.tensor import TensorWrapper, autocast
+
+# mypy: ignore-errors
+
+
+class Gravity(TensorWrapper):
+ """Gravity vector in camera frame."""
+
+ eps = 1e-4
+
+ @autocast
+ def __init__(self, data: torch.Tensor) -> None:
+ """Create gravity vector from data.
+
+ Args:
+ data (torch.Tensor): gravity vector as 3D vector in camera frame.
+ """
+ assert data.shape[-1] == 3, data.shape
+
+ data = F.normalize(data, dim=-1)
+
+ super().__init__(data)
+
+ @classmethod
+ def from_rp(cls, roll: torch.Tensor, pitch: torch.Tensor) -> "Gravity":
+ """Create gravity vector from roll and pitch angles."""
+ if not isinstance(roll, torch.Tensor):
+ roll = torch.tensor(roll)
+ if not isinstance(pitch, torch.Tensor):
+ pitch = torch.tensor(pitch)
+
+ sr, cr = torch.sin(roll), torch.cos(roll)
+ sp, cp = torch.sin(pitch), torch.cos(pitch)
+ return cls(torch.stack([-sr * cp, -cr * cp, sp], dim=-1))
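+
+    # Usage sketch (illustrative values; angles are in radians):
+    #   >>> g = Gravity.from_rp(torch.tensor([0.1]), torch.tensor([-0.2]))
+    #   >>> g.rp  # approximately recovers tensor([[0.1000, -0.2000]])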
+
+ @property
+ def vec3d(self) -> torch.Tensor:
+ """Return the gravity vector in the representation."""
+ return self._data
+
+ @property
+ def x(self) -> torch.Tensor:
+ """Return first component of the gravity vector."""
+ return self._data[..., 0]
+
+ @property
+ def y(self) -> torch.Tensor:
+ """Return second component of the gravity vector."""
+ return self._data[..., 1]
+
+ @property
+ def z(self) -> torch.Tensor:
+ """Return third component of the gravity vector."""
+ return self._data[..., 2]
+
+ @property
+ def roll(self) -> torch.Tensor:
+ """Return the roll angle of the gravity vector."""
+ roll = torch.asin(-self.x / (torch.sqrt(1 - self.z**2) + self.eps))
+ offset = -torch.pi * torch.sign(self.x)
+ return torch.where(self.y < 0, roll, -roll + offset)
+
+ def J_roll(self) -> torch.Tensor:
+ """Return the Jacobian of the roll angle of the gravity vector."""
+ cp, _ = torch.cos(self.pitch), torch.sin(self.pitch)
+ cr, sr = torch.cos(self.roll), torch.sin(self.roll)
+ Jr = self.new_zeros(self.shape + (3,))
+ Jr[..., 0] = -cr * cp
+ Jr[..., 1] = sr * cp
+ return Jr
+
+ @property
+ def pitch(self) -> torch.Tensor:
+ """Return the pitch angle of the gravity vector."""
+ return torch.asin(self.z)
+
+ def J_pitch(self) -> torch.Tensor:
+ """Return the Jacobian of the pitch angle of the gravity vector."""
+ cp, sp = torch.cos(self.pitch), torch.sin(self.pitch)
+ cr, sr = torch.cos(self.roll), torch.sin(self.roll)
+
+ Jp = self.new_zeros(self.shape + (3,))
+ Jp[..., 0] = sr * sp
+ Jp[..., 1] = cr * sp
+ Jp[..., 2] = cp
+ return Jp
+
+ @property
+ def rp(self) -> torch.Tensor:
+ """Return the roll and pitch angles of the gravity vector."""
+ return torch.stack([self.roll, self.pitch], dim=-1)
+
+ def J_rp(self) -> torch.Tensor:
+ """Return the Jacobian of the roll and pitch angles of the gravity vector."""
+ return torch.stack([self.J_roll(), self.J_pitch()], dim=-1)
+
+ @property
+ def R(self) -> torch.Tensor:
+ """Return the rotation matrix from the gravity vector."""
+ return rad2rotmat(roll=self.roll, pitch=self.pitch)
+
+ def J_R(self) -> torch.Tensor:
+ """Return the Jacobian of the rotation matrix from the gravity vector."""
+ raise NotImplementedError
+
+ def update(self, delta: torch.Tensor, spherical: bool = False) -> "Gravity":
+ """Update the gravity vector by adding a delta."""
+ if spherical:
+ data = SphericalManifold.plus(self.vec3d, delta)
+ return self.__class__(data)
+
+ data = EuclideanManifold.plus(self.rp, delta)
+ return self.from_rp(data[..., 0], data[..., 1])
+
+    def J_update(self, spherical: bool = False):
+        """Return the manifold whose plus-operator Jacobian is used for the update."""
+ return SphericalManifold if spherical else EuclideanManifold
+
+ def __repr__(self):
+ """Print the Camera object."""
+ return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}"
diff --git a/siclib/geometry/jacobians.py b/siclib/geometry/jacobians.py
new file mode 100644
index 0000000000000000000000000000000000000000..444acc2a759308e3f7d4d7d5aac196a3ec51fca7
--- /dev/null
+++ b/siclib/geometry/jacobians.py
@@ -0,0 +1,64 @@
+"""Jacobians for optimization."""
+
+import torch
+
+# flake8: noqa: E741
+
+
+@torch.jit.script
+def J_vecnorm(vec: torch.Tensor) -> torch.Tensor:
+ """Compute the jacobian of vec / norm2(vec).
+
+ Args:
+ vec (torch.Tensor): [..., D] tensor.
+
+ Returns:
+ torch.Tensor: [..., D, D] Jacobian.
+ """
+ D = vec.shape[-1]
+ norm_x = torch.norm(vec, dim=-1, keepdim=True).unsqueeze(-1) # (..., 1, 1)
+
+ if (norm_x == 0).any():
+ norm_x = norm_x + 1e-6
+
+ xxT = torch.einsum("...i,...j->...ij", vec, vec) # (..., D, D)
+ identity = torch.eye(D, device=vec.device, dtype=vec.dtype) # (D, D)
+
+ return identity / norm_x - (xxT / norm_x**3) # (..., D, D)
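+
+
+# The closed form above is d(x / ||x||) / dx = I / ||x|| - x x^T / ||x||^3.
+# Quick check against autograd (illustrative sketch, not part of the library):
+#   >>> from torch.func import jacfwd
+#   >>> x = torch.randn(5)
+#   >>> torch.allclose(J_vecnorm(x), jacfwd(lambda v: v / v.norm())(x), atol=1e-5)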
+
+
+@torch.jit.script
+def J_focal2fov(focal: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
+ """Compute the jacobian of the focal2fov function."""
+ return -4 * h / (4 * focal**2 + h**2)
+
+
+@torch.jit.script
+def J_up_projection(uv: torch.Tensor, abc: torch.Tensor, wrt: str = "uv") -> torch.Tensor:
+ """Compute the jacobian of the up-vector projection.
+
+ Args:
+ uv (torch.Tensor): Normalized image coordinates of shape (..., 2).
+ abc (torch.Tensor): Gravity vector of shape (..., 3).
+ wrt (str, optional): Parameter to differentiate with respect to. Defaults to "uv".
+
+ Raises:
+ ValueError: If the wrt parameter is unknown.
+
+ Returns:
+ torch.Tensor: Jacobian with respect to the parameter.
+ """
+ if wrt == "uv":
+ c = abc[..., 2][..., None, None, None]
+ return -c * torch.eye(2, device=uv.device, dtype=uv.dtype).expand(uv.shape[:-1] + (2, 2))
+
+ elif wrt == "abc":
+ J = uv.new_zeros(uv.shape[:-1] + (2, 3))
+ J[..., 0, 0] = 1
+ J[..., 1, 1] = 1
+ J[..., 0, 2] = -uv[..., 0]
+ J[..., 1, 2] = -uv[..., 1]
+ return J
+
+ else:
+ raise ValueError(f"Unknown wrt: {wrt}")
diff --git a/siclib/geometry/manifolds.py b/siclib/geometry/manifolds.py
new file mode 100644
index 0000000000000000000000000000000000000000..dea879379cba6a6105a245720af6df336fd4e49f
--- /dev/null
+++ b/siclib/geometry/manifolds.py
@@ -0,0 +1,112 @@
+"""Implementation of manifolds."""
+
+import logging
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class EuclideanManifold:
+ """Simple euclidean manifold."""
+
+ @staticmethod
+ def J_plus(x: torch.Tensor) -> torch.Tensor:
+ """Plus operator Jacobian."""
+ return torch.eye(x.shape[-1]).to(x)
+
+ @staticmethod
+ def plus(x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
+ """Plus operator."""
+ return x + delta
+
+
+class SphericalManifold:
+ """Implementation of the spherical manifold.
+
+ Following the derivation from 'Integrating Generic Sensor Fusion Algorithms with Sound State
+ Representations through Encapsulation of Manifolds' by Hertzberg et al. (B.2, p. 25).
+
+ Householder transformation following Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by
+ Golub et al.
+ """
+
+ @staticmethod
+ def householder_vector(x: torch.Tensor) -> torch.Tensor:
+ """Return the Householder vector and beta.
+
+ Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by Golub et al. (Johns Hopkins Studies
+ in Mathematical Sciences) but using the nth element of the input vector as pivot instead of
+ first.
+
+ This computes the vector v with v(n) = 1 and beta such that H = I - beta * v * v^T is
+ orthogonal and H * x = ||x||_2 * e_n.
+
+ Args:
+ x (torch.Tensor): [..., n] tensor.
+
+ Returns:
+ torch.Tensor: v of shape [..., n]
+ torch.Tensor: beta of shape [...]
+ """
+ sigma = torch.sum(x[..., :-1] ** 2, -1)
+ xpiv = x[..., -1]
+ norm = torch.norm(x, dim=-1)
+ if torch.any(sigma < 1e-7):
+ sigma = torch.where(sigma < 1e-7, sigma + 1e-7, sigma)
+ logger.warning("sigma < 1e-7")
+
+ vpiv = torch.where(xpiv < 0, xpiv - norm, -sigma / (xpiv + norm))
+ beta = 2 * vpiv**2 / (sigma + vpiv**2)
+ v = torch.cat([x[..., :-1] / vpiv[..., None], torch.ones_like(vpiv)[..., None]], -1)
+ return v, beta
+
+ @staticmethod
+ def apply_householder(y: torch.Tensor, v: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
+ """Apply Householder transformation.
+
+ Args:
+ y (torch.Tensor): Vector to transform of shape [..., n].
+ v (torch.Tensor): Householder vector of shape [..., n].
+ beta (torch.Tensor): Householder beta of shape [...].
+
+ Returns:
+ torch.Tensor: Transformed vector of shape [..., n].
+ """
+ return y - v * (beta * torch.einsum("...i,...i->...", v, y))[..., None]
+
+ @classmethod
+ def J_plus(cls, x: torch.Tensor) -> torch.Tensor:
+ """Plus operator Jacobian."""
+ v, beta = cls.householder_vector(x)
+ H = -torch.einsum("..., ...k, ...l->...kl", beta, v, v)
+ H = H + torch.eye(H.shape[-1]).to(H)
+ return H[..., :-1] # J
+
+ @classmethod
+ def plus(cls, x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
+ """Plus operator.
+
+ Equation 109 (p. 25) from 'Integrating Generic Sensor Fusion Algorithms with Sound State
+ Representations through Encapsulation of Manifolds' by Hertzberg et al. but using the nth
+ element of the input vector as pivot instead of first.
+
+ Args:
+ x: point on the manifold
+ delta: tangent vector
+ """
+ eps = 1e-7
+        # keep the norm of x in case it is not equal to 1
+ nx = torch.norm(x, dim=-1, keepdim=True)
+ nd = torch.norm(delta, dim=-1, keepdim=True)
+
+ # make sure we don't divide by zero in backward as torch.where computes grad for both
+ # branches
+ nd_ = torch.where(nd < eps, nd + eps, nd)
+ sinc = torch.where(nd < eps, nd.new_ones(nd.shape), torch.sin(nd_) / nd_)
+
+        # the cosine term goes into the last entry since the pivot is the last element, not the first
+ exp_delta = torch.cat([sinc * delta, torch.cos(nd)], -1)
+
+ v, beta = cls.householder_vector(x)
+ return nx * cls.apply_householder(exp_delta, v, beta)
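+
+
+# Usage sketch (illustrative): the spherical plus operator keeps its output on the
+# sphere of the same radius as the input, unlike the euclidean one.
+#   >>> x = torch.tensor([0.6, 0.0, 0.8])              # unit vector
+#   >>> delta = torch.tensor([0.1, -0.2])              # 2D tangent update
+#   >>> SphericalManifold.plus(x, delta).norm()        # stays approximately 1
+#   >>> EuclideanManifold.plus(torch.zeros(2), delta)  # plain x + delta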
diff --git a/siclib/geometry/perspective_fields.py b/siclib/geometry/perspective_fields.py
new file mode 100644
index 0000000000000000000000000000000000000000..44395ba001e084d3bb391ecc11bca7254b5c5296
--- /dev/null
+++ b/siclib/geometry/perspective_fields.py
@@ -0,0 +1,367 @@
+"""Implementation of perspective fields.
+
+Adapted from https://github.com/jinlinyi/PerspectiveFields/blob/main/perspective2d/utils/panocam.py
+"""
+
+from typing import Tuple
+
+import torch
+from torch.nn import functional as F
+
+from siclib.geometry.base_camera import BaseCamera
+from siclib.geometry.gravity import Gravity
+from siclib.geometry.jacobians import J_up_projection, J_vecnorm
+from siclib.geometry.manifolds import SphericalManifold
+
+# flake8: noqa: E266
+
+
+def get_horizon_line(camera: BaseCamera, gravity: Gravity, relative: bool = True) -> torch.Tensor:
+ """Get the horizon line from the camera parameters.
+
+ Args:
+ camera (Camera): Camera parameters.
+ gravity (Gravity): Gravity vector.
+ relative (bool, optional): Whether to normalize horizon line by img_h. Defaults to True.
+
+    Returns:
+        torch.Tensor: Horizon line, given as the heights (relative to the image height if
+        `relative` is True) at which it intersects the left and right image borders.
+ """
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
+
+ # project horizon midpoint to image plane
+ horizon_midpoint = camera.new_tensor([0, 0, 1])
+ horizon_midpoint = camera.K @ gravity.R @ horizon_midpoint
+ midpoint = horizon_midpoint[:2] / horizon_midpoint[2]
+
+ # compute left and right offset to borders
+ left_offset = midpoint[0] * torch.tan(gravity.roll)
+ right_offset = (camera.size[0] - midpoint[0]) * torch.tan(gravity.roll)
+ left, right = midpoint[1] + left_offset, midpoint[1] - right_offset
+
+ horizon = camera.new_tensor([left, right])
+ return horizon / camera.size[1] if relative else horizon
+
+
+def get_up_field(camera: BaseCamera, gravity: Gravity, normalize: bool = True) -> torch.Tensor:
+ """Get the up vector field from the camera parameters.
+
+ Args:
+        camera (Camera): Camera parameters.
+        gravity (Gravity): Gravity vector.
+        normalize (bool, optional): Whether to normalize the up vector. Defaults to True.
+
+ Returns:
+ torch.Tensor: up vector field as tensor of shape (..., h, w, 2).
+ """
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
+
+ w, h = camera.size[0].unbind(-1)
+ h, w = h.round().to(int), w.round().to(int)
+
+ uv = camera.normalize(camera.pixel_coordinates())
+
+ # projected up is (a, b) - c * (u, v)
+ abc = gravity.vec3d
+ projected_up2d = abc[..., None, :2] - abc[..., 2, None, None] * uv # (..., N, 2)
+
+ if hasattr(camera, "dist"):
+ d_uv = camera.distort(uv, return_scale=True)[0] # (..., N, 1)
+ d_uv = torch.diag_embed(d_uv.expand(d_uv.shape[:-1] + (2,))) # (..., N, 2, 2)
+ offset = camera.up_projection_offset(uv) # (..., N, 2)
+ offset = torch.einsum("...i,...j->...ij", offset, uv) # (..., N, 2, 2)
+
+ # (..., N, 2)
+ projected_up2d = torch.einsum("...Nij,...Nj->...Ni", d_uv + offset, projected_up2d)
+
+ if normalize:
+ projected_up2d = F.normalize(projected_up2d, dim=-1) # (..., N, 2)
+
+ return projected_up2d.reshape(camera.shape[0], h, w, 2)
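+
+
+# Usage sketch (mirrors the unit tests; Camera.from_dict and Gravity.from_rp are the
+# constructors used there, with the focal length in pixels and angles in radians):
+#   >>> cam = Camera.from_dict({"height": [32], "width": [32], "f": [32.0], "k1": [0.0]})
+#   >>> grav = Gravity.from_rp(torch.tensor([0.1]), torch.tensor([-0.1]))
+#   >>> get_up_field(cam, grav).shape  # torch.Size([1, 32, 32, 2])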
+
+
+def J_up_field(
+ camera: BaseCamera, gravity: Gravity, spherical: bool = False, log_focal: bool = False
+) -> torch.Tensor:
+ """Get the jacobian of the up field.
+
+ Args:
+ camera (Camera): Camera parameters.
+ gravity (Gravity): Gravity vector.
+ spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
+ log_focal (bool, optional): Whether to use log-focal length. Defaults to False.
+
+ Returns:
+ torch.Tensor: Jacobian of the up field as a tensor of shape (..., h, w, 2, 2, 3).
+ """
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
+
+ w, h = camera.size[0].unbind(-1)
+ h, w = h.round().to(int), w.round().to(int)
+
+ # Forward
+ xy = camera.pixel_coordinates()
+ uv = camera.normalize(xy)
+
+ projected_up2d = gravity.vec3d[..., None, :2] - gravity.vec3d[..., 2, None, None] * uv
+
+ # Backward
+ J = []
+
+ # (..., N, 2, 2)
+ J_norm2proj = J_vecnorm(
+ get_up_field(camera, gravity, normalize=False).reshape(camera.shape[0], -1, 2)
+ )
+
+ # distortion values
+ if hasattr(camera, "dist"):
+ d_uv = camera.distort(uv, return_scale=True)[0] # (..., N, 1)
+ d_uv = torch.diag_embed(d_uv.expand(d_uv.shape[:-1] + (2,))) # (..., N, 2, 2)
+ offset = camera.up_projection_offset(uv) # (..., N, 2)
+ offset_uv = torch.einsum("...i,...j->...ij", offset, uv) # (..., N, 2, 2)
+
+ ######################
+ ## Gravity Jacobian ##
+ ######################
+
+ J_proj2abc = J_up_projection(uv, gravity.vec3d, wrt="abc") # (..., N, 2, 3)
+
+ if hasattr(camera, "dist"):
+ # (..., N, 2, 3)
+ J_proj2abc = torch.einsum("...Nij,...Njk->...Nik", d_uv + offset_uv, J_proj2abc)
+
+ J_abc2delta = SphericalManifold.J_plus(gravity.vec3d) if spherical else gravity.J_rp()
+ J_proj2delta = torch.einsum("...Nij,...jk->...Nik", J_proj2abc, J_abc2delta)
+ J_up2delta = torch.einsum("...Nij,...Njk->...Nik", J_norm2proj, J_proj2delta)
+ J.append(J_up2delta)
+
+ ######################
+ ### Focal Jacobian ###
+ ######################
+
+ J_proj2uv = J_up_projection(uv, gravity.vec3d, wrt="uv") # (..., N, 2, 2)
+
+ if hasattr(camera, "dist"):
+ J_proj2up = torch.einsum("...Nij,...Njk->...Nik", d_uv + offset_uv, J_proj2uv)
+ J_proj2duv = torch.einsum("...i,...j->...ji", offset, projected_up2d)
+
+ inner = (uv * projected_up2d).sum(-1)[..., None, None]
+ J_proj2offset1 = inner * camera.J_up_projection_offset(uv, wrt="uv")
+ J_proj2offset2 = torch.einsum("...i,...j->...ij", offset, projected_up2d) # (..., N, 2, 2)
+ J_proj2uv = (J_proj2duv + J_proj2offset1 + J_proj2offset2) + J_proj2up
+
+ J_uv2f = camera.J_normalize(xy) # (..., N, 2, 2)
+
+ if log_focal:
+ J_uv2f = J_uv2f * camera.f[..., None, None, :] # (..., N, 2, 2)
+
+ J_uv2f = J_uv2f.sum(-1) # (..., N, 2)
+
+ J_proj2f = torch.einsum("...ij,...j->...i", J_proj2uv, J_uv2f) # (..., N, 2)
+ J_up2f = torch.einsum("...Nij,...Nj->...Ni", J_norm2proj, J_proj2f)[..., None] # (..., N, 2, 1)
+ J.append(J_up2f)
+
+ ######################
+ ##### K1 Jacobian ####
+ ######################
+
+ if hasattr(camera, "dist"):
+ J_duv = camera.J_distort(uv, wrt="scale2dist")
+ J_duv = torch.diag_embed(J_duv.expand(J_duv.shape[:-1] + (2,))) # (..., N, 2, 2)
+ J_offset = torch.einsum(
+ "...i,...j->...ij", camera.J_up_projection_offset(uv, wrt="dist"), uv
+ )
+ J_proj2k1 = torch.einsum("...Nij,...Nj->...Ni", J_duv + J_offset, projected_up2d)
+ J_k1 = torch.einsum("...Nij,...Nj->...Ni", J_norm2proj, J_proj2k1)[..., None]
+ J.append(J_k1)
+
+ n_params = sum(j.shape[-1] for j in J)
+ return torch.cat(J, axis=-1).reshape(camera.shape[0], h, w, 2, n_params)
+
+
+def get_latitude_field(camera: BaseCamera, gravity: Gravity) -> torch.Tensor:
+ """Get the latitudes of the camera pixels in radians.
+
+    Latitudes are defined as the angle between the ray and the horizontal plane, i.e. the
+    complement of the angle between the ray and the up vector.
+
+ Args:
+ camera (Camera): Camera parameters.
+ gravity (Gravity): Gravity vector.
+
+ Returns:
+ torch.Tensor: Latitudes in radians as a tensor of shape (..., h, w, 1).
+ """
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
+
+ w, h = camera.size[0].unbind(-1)
+ h, w = h.round().to(int), w.round().to(int)
+
+ uv1, _ = camera.image2world(camera.pixel_coordinates())
+ rays = camera.pixel_bearing_many(uv1)
+
+ lat = torch.einsum("...Nj,...j->...N", rays, gravity.vec3d)
+
+ eps = 1e-6
+ lat_asin = torch.asin(lat.clamp(min=-1 + eps, max=1 - eps))
+
+ return lat_asin.reshape(camera.shape[0], h, w, 1)
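+
+
+# Usage sketch (reusing cam/grav from the sketch above): latitude is an elevation angle,
+# so all values stay within [-pi/2, pi/2].
+#   >>> lat = get_latitude_field(cam, grav)      # (1, h, w, 1), in radians
+#   >>> bool((lat.abs() <= torch.pi / 2).all())  # True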
+
+
+def J_latitude_field(
+ camera: BaseCamera, gravity: Gravity, spherical: bool = False, log_focal: bool = False
+) -> torch.Tensor:
+ """Get the jacobian of the latitude field.
+
+ Args:
+ camera (Camera): Camera parameters.
+ gravity (Gravity): Gravity vector.
+ spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
+ log_focal (bool, optional): Whether to use log-focal length. Defaults to False.
+
+ Returns:
+ torch.Tensor: Jacobian of the latitude field as a tensor of shape (..., h, w, 1, 3).
+ """
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
+
+ w, h = camera.size[0].unbind(-1)
+ h, w = h.round().to(int), w.round().to(int)
+
+ # Forward
+ xy = camera.pixel_coordinates()
+ uv1, _ = camera.image2world(xy)
+ uv1_norm = camera.pixel_bearing_many(uv1) # (..., N, 3)
+
+ # Backward
+ J = []
+    J_norm2w_to_img = J_vecnorm(uv1)[..., :2]  # (..., N, 3, 2)
+
+ ######################
+ ## Gravity Jacobian ##
+ ######################
+
+ J_delta = SphericalManifold.J_plus(gravity.vec3d) if spherical else gravity.J_rp()
+ J_delta = torch.einsum("...Ni,...ij->...Nj", uv1_norm, J_delta) # (..., N, 2)
+ J.append(J_delta)
+
+ ######################
+ ### Focal Jacobian ###
+ ######################
+
+ J_w_to_img2f = camera.J_image2world(xy, "f") # (..., N, 2, 2)
+ if log_focal:
+ J_w_to_img2f = J_w_to_img2f * camera.f[..., None, None, :]
+ J_w_to_img2f = J_w_to_img2f.sum(-1) # (..., N, 2)
+
+ J_norm2f = torch.einsum("...Nij,...Nj->...Ni", J_norm2w_to_img, J_w_to_img2f) # (..., N, 3)
+ J_f = torch.einsum("...Ni,...i->...N", J_norm2f, gravity.vec3d).unsqueeze(-1) # (..., N, 1)
+ J.append(J_f)
+
+ ######################
+ ##### K1 Jacobian ####
+ ######################
+
+ if hasattr(camera, "dist"):
+ J_w_to_img2k1 = camera.J_image2world(xy, "dist") # (..., N, 2)
+        # (..., N, 3)
+ J_norm2k1 = torch.einsum("...Nij,...Nj->...Ni", J_norm2w_to_img, J_w_to_img2k1)
+ # (..., N, 1)
+ J_k1 = torch.einsum("...Ni,...i->...N", J_norm2k1, gravity.vec3d).unsqueeze(-1)
+ J.append(J_k1)
+
+ n_params = sum(j.shape[-1] for j in J)
+ return torch.cat(J, axis=-1).reshape(camera.shape[0], h, w, 1, n_params)
+
+
+def get_perspective_field(
+ camera: BaseCamera,
+ gravity: Gravity,
+ use_up: bool = True,
+ use_latitude: bool = True,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Get the perspective field from the camera parameters.
+
+ Args:
+ camera (Camera): Camera parameters.
+ gravity (Gravity): Gravity vector.
+ use_up (bool, optional): Whether to include the up vector field. Defaults to True.
+ use_latitude (bool, optional): Whether to include the latitude field. Defaults to True.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Up and latitude fields as tensors of shape
+ (..., 2, h, w) and (..., 1, h, w).
+ """
+ assert use_up or use_latitude, "At least one of use_up or use_latitude must be True."
+
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
+
+ w, h = camera.size[0].unbind(-1)
+ h, w = h.round().to(int), w.round().to(int)
+
+ if use_up:
+ permute = (0, 3, 1, 2)
+ # (..., 2, h, w)
+ up = get_up_field(camera, gravity).permute(permute)
+ else:
+ shape = (camera.shape[0], 2, h, w)
+ up = camera.new_zeros(shape)
+
+ if use_latitude:
+ permute = (0, 3, 1, 2)
+ # (..., 1, h, w)
+ lat = get_latitude_field(camera, gravity).permute(permute)
+ else:
+ shape = (camera.shape[0], 1, h, w)
+ lat = camera.new_zeros(shape)
+
+ return up, lat
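+
+
+# Usage sketch (reusing cam/grav from the sketches above; this is exactly how the unit
+# tests drive the function):
+#   >>> up, lat = get_perspective_field(cam, grav)
+#   >>> up.shape, lat.shape  # (1, 2, h, w) and (1, 1, h, w), channels-first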
+
+
+def J_perspective_field(
+ camera: BaseCamera,
+ gravity: Gravity,
+ use_up: bool = True,
+ use_latitude: bool = True,
+ spherical: bool = False,
+ log_focal: bool = False,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Get the jacobian of the perspective field.
+
+ Args:
+ camera (Camera): Camera parameters.
+ gravity (Gravity): Gravity vector.
+ use_up (bool, optional): Whether to include the up vector field. Defaults to True.
+ use_latitude (bool, optional): Whether to include the latitude field. Defaults to True.
+ spherical (bool, optional): Whether to use spherical coordinates. Defaults to False.
+ log_focal (bool, optional): Whether to use log-focal length. Defaults to False.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Up and latitude jacobians as tensors of shape
+ (..., h, w, 2, 4) and (..., h, w, 1, 4).
+ """
+ assert use_up or use_latitude, "At least one of use_up or use_latitude must be True."
+
+ camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera
+ gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity
+
+ w, h = camera.size[0].unbind(-1)
+ h, w = h.round().to(int), w.round().to(int)
+
+ if use_up:
+ J_up = J_up_field(camera, gravity, spherical, log_focal) # (..., h, w, 2, 4)
+ else:
+ shape = (camera.shape[0], h, w, 2, 4)
+ J_up = camera.new_zeros(shape)
+
+ if use_latitude:
+ J_lat = J_latitude_field(camera, gravity, spherical, log_focal) # (..., h, w, 1, 4)
+ else:
+ shape = (camera.shape[0], h, w, 1, 4)
+ J_lat = camera.new_zeros(shape)
+
+ return J_up, J_lat
diff --git a/siclib/models/__init__.py b/siclib/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2eee392a5a1f927763a532527a48090564e191fb
--- /dev/null
+++ b/siclib/models/__init__.py
@@ -0,0 +1,28 @@
+import importlib.util
+
+from siclib.models.base_model import BaseModel
+from siclib.utils.tools import get_class
+
+
+def get_model(name):
+ import_paths = [
+ name,
+ f"{__name__}.{name}",
+ ]
+ for path in import_paths:
+ try:
+ spec = importlib.util.find_spec(path)
+ except ModuleNotFoundError:
+ spec = None
+ if spec is not None:
+ try:
+ return get_class(path, BaseModel)
+ except AssertionError:
+ mod = __import__(path, fromlist=[""])
+ try:
+ return mod.__main_model__
+ except AttributeError as exc:
+ print(exc)
+ continue
+
+ raise RuntimeError(f'Model {name} not found in any of [{" ".join(import_paths)}]')
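+
+
+# Usage sketch (hedged): names resolve relative to siclib.models, e.g.
+#   >>> UpDecoder = get_model("decoders.up_decoder")
+#   >>> head = UpDecoder({"loss_type": "l1"})  # conf is merged with the model's default_conf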
diff --git a/siclib/models/base_model.py b/siclib/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fb793f3ffac32ad353e0653f5a6218f1d6d3fbe
--- /dev/null
+++ b/siclib/models/base_model.py
@@ -0,0 +1,205 @@
+"""Base class for trainable models."""
+
+import logging
+import re
+from abc import ABCMeta, abstractmethod
+from copy import copy
+
+import omegaconf
+import torch
+from omegaconf import OmegaConf
+from torch import nn
+
+logger = logging.getLogger(__name__)
+
+try:
+ import wandb
+except ImportError:
+ logger.debug("Could not import wandb.")
+ wandb = None
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class MetaModel(ABCMeta):
+ def __prepare__(name, bases, **kwds):
+ total_conf = OmegaConf.create()
+ for base in bases:
+ for key in ("base_default_conf", "default_conf"):
+ update = getattr(base, key, {})
+ if isinstance(update, dict):
+ update = OmegaConf.create(update)
+ total_conf = OmegaConf.merge(total_conf, update)
+ return dict(base_default_conf=total_conf)
+
+
+class BaseModel(nn.Module, metaclass=MetaModel):
+ """
+    What the child model is expected to declare:
+ default_conf: dictionary of the default configuration of the model.
+ It recursively updates the default_conf of all parent classes, and
+ it is updated by the user-provided configuration passed to __init__.
+ Configurations can be nested.
+
+ required_data_keys: list of expected keys in the input data dictionary.
+
+ strict_conf (optional): boolean. If false, BaseModel does not raise
+ an error when the user provides an unknown configuration entry.
+
+ _init(self, conf): initialization method, where conf is the final
+ configuration object (also accessible with `self.conf`). Accessing
+ unknown configuration entries will raise an error.
+
+ _forward(self, data): method that returns a dictionary of batched
+ prediction tensors based on a dictionary of batched input data tensors.
+
+ loss(self, pred, data): method that returns a dictionary of losses,
+ computed from model predictions and input data. Each loss is a batch
+ of scalars, i.e. a torch.Tensor of shape (B,).
+ The total loss to be optimized has the key `'total'`.
+
+ metrics(self, pred, data): method that returns a dictionary of metrics,
+ each as a batch of scalars.
+ """
+
+ default_conf = {
+ "name": None,
+ "trainable": True, # if false: do not optimize this model parameters
+ "freeze_batch_normalization": False, # use test-time statistics
+ "timeit": False, # time forward pass
+ "watch": False, # log weights and gradients to wandb
+ }
+ required_data_keys = []
+ strict_conf = False
+
+ def __init__(self, conf):
+ """Perform some logic and call the _init method of the child model."""
+ super().__init__()
+ default_conf = OmegaConf.merge(self.base_default_conf, OmegaConf.create(self.default_conf))
+ if self.strict_conf:
+ OmegaConf.set_struct(default_conf, True)
+
+ # fixme: backward compatibility
+ if "pad" in conf and "pad" not in default_conf: # backward compat.
+ with omegaconf.read_write(conf):
+ with omegaconf.open_dict(conf):
+ conf["interpolation"] = {"pad": conf.pop("pad")}
+
+ if isinstance(conf, dict):
+ conf = OmegaConf.create(conf)
+ self.conf = conf = OmegaConf.merge(default_conf, conf)
+ OmegaConf.set_readonly(conf, True)
+ OmegaConf.set_struct(conf, True)
+ self.required_data_keys = copy(self.required_data_keys)
+ self._init(conf)
+
+ # load pretrained weights
+ if "weights" in conf and conf.weights is not None:
+ logger.info(f"Loading checkpoint {conf.weights}")
+ ckpt = torch.load(str(conf.weights), map_location="cpu", weights_only=False)
+ weights_key = "model" if "model" in ckpt else "state_dict"
+ self.flexible_load(ckpt[weights_key])
+
+ if not conf.trainable:
+ for p in self.parameters():
+ p.requires_grad = False
+
+ if conf.watch:
+ try:
+ wandb.watch(self, log="all", log_graph=True, log_freq=10)
+ logger.info(f"Watching {self.__class__.__name__}.")
+ except ValueError:
+ logger.warning(f"Could not watch {self.__class__.__name__}.")
+
+ n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)
+ logger.info(f"Creating model {self.__class__.__name__} ({n_trainable/1e6:.2f} Mio)")
+
+ def flexible_load(self, state_dict):
+ """TODO: fix a probable nasty bug, and move to BaseModel."""
+ # replace *gravity* with *up*
+ for key in list(state_dict.keys()):
+ if "gravity" in key:
+ new_key = key.replace("gravity", "up")
+ state_dict[new_key] = state_dict.pop(key)
+ # print(f"Renaming {key} to {new_key}")
+
+ # replace *_head.* with *_head.decoder.* for original paramnet checkpoints
+ for key in list(state_dict.keys()):
+ if "linear_pred_latitude" in key or "linear_pred_up" in key:
+ continue
+
+ if "_head" in key and "_head.decoder" not in key:
+ # check if _head.{num} in key
+ pattern = r"_head\.\d+"
+ if re.search(pattern, key):
+ continue
+ new_key = key.replace("_head.", "_head.decoder.")
+ state_dict[new_key] = state_dict.pop(key)
+ # print(f"Renaming {key} to {new_key}")
+
+ dict_params = set(state_dict.keys())
+ model_params = set(map(lambda n: n[0], self.named_parameters()))
+
+ if dict_params == model_params: # perfect fit
+ logger.info("Loading all parameters of the checkpoint.")
+ self.load_state_dict(state_dict, strict=True)
+ return
+ elif len(dict_params & model_params) == 0: # perfect mismatch
+ strip_prefix = lambda x: ".".join(x.split(".")[:1] + x.split(".")[2:])
+ state_dict = {strip_prefix(n): p for n, p in state_dict.items()}
+ dict_params = set(state_dict.keys())
+ if len(dict_params & model_params) == 0:
+                raise ValueError(
+                    "Could not manage to load the checkpoint with parameters:\n\t"
+                    + "\n\t".join(sorted(dict_params))
+                )
+ common_params = dict_params & model_params
+ left_params = dict_params - model_params
+ left_params = [
+ p for p in left_params if "running" not in p and "num_batches_tracked" not in p
+ ]
+ logger.debug("Loading parameters:\n\t" + "\n\t".join(sorted(common_params)))
+ if left_params:
+ # ignore running stats of batchnorm
+ logger.warning("Could not load parameters:\n\t" + "\n\t".join(sorted(left_params)))
+ self.load_state_dict(state_dict, strict=False)
+
+ def train(self, mode=True):
+ super().train(mode)
+
+ def freeze_bn(module):
+ if isinstance(module, nn.modules.batchnorm._BatchNorm):
+ module.eval()
+
+ if self.conf.freeze_batch_normalization:
+ self.apply(freeze_bn)
+
+ return self
+
+ def forward(self, data):
+ """Check the data and call the _forward method of the child model."""
+
+ def recursive_key_check(expected, given):
+ for key in expected:
+ assert key in given, f"Missing key {key} in data: {list(given.keys())}"
+ if isinstance(expected, dict):
+ recursive_key_check(expected[key], given[key])
+
+ recursive_key_check(self.required_data_keys, data)
+ return self._forward(data)
+
+ @abstractmethod
+ def _init(self, conf):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def _forward(self, data):
+ """To be implemented by the child class."""
+ raise NotImplementedError
+
+ @abstractmethod
+ def loss(self, pred, data):
+ """To be implemented by the child class."""
+ raise NotImplementedError
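+
+
+# Minimal child-model sketch (hypothetical, for illustration only). The concrete decoders
+# in siclib.models follow this pattern and return a (losses, metrics) pair from `loss`:
+#
+#   class DummyModel(BaseModel):
+#       default_conf = {"scale": 1.0}
+#       required_data_keys = ["image"]
+#
+#       def _init(self, conf):
+#           self.scale = conf.scale
+#
+#       def _forward(self, data):
+#           return {"pred": data["image"] * self.scale}
+#
+#       def loss(self, pred, data):
+#           loss = (pred["pred"] - data["image"]).abs().mean(dim=(1, 2, 3))
+#           return {"total": loss}, {}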
diff --git a/siclib/models/cache_loader.py b/siclib/models/cache_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..23c72b1f557f7e39a1240e5f748ef5c7550ac2a6
--- /dev/null
+++ b/siclib/models/cache_loader.py
@@ -0,0 +1,109 @@
+import string
+
+import h5py
+import torch
+
+from siclib.datasets.base_dataset import collate
+from siclib.models.base_model import BaseModel
+from siclib.settings import DATA_PATH
+from siclib.utils.tensor import batch_to_device
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+def pad_line_features(pred, seq_l: int = None):
+ raise NotImplementedError
+
+
+def recursive_load(grp, pkeys):
+ return {
+ k: (
+ torch.from_numpy(grp[k].__array__())
+ if isinstance(grp[k], h5py.Dataset)
+ else recursive_load(grp[k], list(grp.keys()))
+ )
+ for k in pkeys
+ }
+
+
+class CacheLoader(BaseModel):
+ default_conf = {
+ "path": "???", # can be a format string like exports/{scene}/
+ "data_keys": None, # load all keys
+ "device": None, # load to same device as data
+ "trainable": False,
+ "add_data_path": True,
+ "collate": True,
+ "scale": ["keypoints"],
+ "padding_fn": None,
+ "padding_length": None, # required for batching!
+ "numeric_type": "float32", # [None, "float16", "float32", "float64"]
+ }
+
+ required_data_keys = ["name"] # we need an identifier
+
+ def _init(self, conf):
+ self.hfiles = {}
+ self.padding_fn = conf.padding_fn
+ if self.padding_fn is not None:
+ self.padding_fn = eval(self.padding_fn)
+ self.numeric_dtype = {
+ None: None,
+ "float16": torch.float16,
+ "float32": torch.float32,
+ "float64": torch.float64,
+ }[conf.numeric_type]
+
+ def _forward(self, data): # sourcery skip: low-code-quality
+ preds = []
+ device = self.conf.device
+ if not device:
+ if devices := {v.device for v in data.values() if isinstance(v, torch.Tensor)}:
+ assert len(devices) == 1
+ device = devices.pop()
+
+ else:
+ device = "cpu"
+
+ var_names = [x[1] for x in string.Formatter().parse(self.conf.path) if x[1]]
+ for i, name in enumerate(data["name"]):
+ fpath = self.conf.path.format(**{k: data[k][i] for k in var_names})
+ if self.conf.add_data_path:
+ fpath = DATA_PATH / fpath
+ hfile = h5py.File(str(fpath), "r")
+ grp = hfile[name]
+ pkeys = self.conf.data_keys if self.conf.data_keys is not None else grp.keys()
+ pred = recursive_load(grp, pkeys)
+ if self.numeric_dtype is not None:
+ pred = {
+ k: (
+ v
+ if not isinstance(v, torch.Tensor) or not torch.is_floating_point(v)
+ else v.to(dtype=self.numeric_dtype)
+ )
+ for k, v in pred.items()
+ }
+ pred = batch_to_device(pred, device)
+ for k, v in pred.items():
+ for pattern in self.conf.scale:
+ if k.startswith(pattern):
+ view_idx = k.replace(pattern, "")
+ scales = (
+ data["scales"]
+ if len(view_idx) == 0
+ else data[f"view{view_idx}"]["scales"]
+ )
+ pred[k] = pred[k] * scales[i]
+ # use this function to fix number of keypoints etc.
+ if self.padding_fn is not None:
+ pred = self.padding_fn(pred, self.conf.padding_length)
+ preds.append(pred)
+ hfile.close()
+ if self.conf.collate:
+ return batch_to_device(collate(preds), device)
+ assert len(preds) == 1
+ return batch_to_device(preds[0], device)
+
+ def loss(self, pred, data):
+ raise NotImplementedError
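+
+
+# Config sketch (hedged; assumes an HDF5 cache with one group per sample name exists):
+#   >>> loader = CacheLoader({"path": "exports/{scene}/cache.h5", "data_keys": ["up_field"]})
+#   >>> pred = loader({"name": ["0001"], "scene": ["city"]})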
diff --git a/siclib/models/decoders/__init__.py b/siclib/models/decoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/siclib/models/decoders/fpn.py b/siclib/models/decoders/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..78d0085b645ac64c0c6ede144474a9d6446214fd
--- /dev/null
+++ b/siclib/models/decoders/fpn.py
@@ -0,0 +1,198 @@
+import logging
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from siclib.models import BaseModel
+from siclib.models.utils.modules import ConvModule, FeatureFusionBlock
+
+logger = logging.getLogger(__name__)
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class DecoderBlock(nn.Module):
+ def __init__(
+ self,
+ previous,
+ out,
+ ksize=3,
+ num_convs=1,
+ norm_str="BatchNorm2d",
+ padding="zeros",
+ fusion="sum",
+ ):
+ super().__init__()
+
+ self.fusion = fusion
+
+ if self.fusion == "sum":
+ self.fusion_layers = nn.Identity()
+ elif self.fusion == "glu":
+ self.fusion_layers = nn.Sequential(
+ nn.Conv2d(2 * out, 2 * out, 1, padding=0, bias=True),
+ nn.GLU(dim=1),
+ )
+ elif self.fusion == "ff":
+ self.fusion_layers = FeatureFusionBlock(out, upsample=False)
+ else:
+ raise ValueError(f"Unknown fusion: {self.fusion}")
+
+ if norm_str is not None:
+ norm = getattr(nn, norm_str, None)
+
+ layers = []
+ for i in range(num_convs):
+ conv = nn.Conv2d(
+ previous if i == 0 else out,
+ out,
+ kernel_size=ksize,
+ padding=ksize // 2,
+ bias=norm_str is None,
+ padding_mode=padding,
+ )
+ layers.append(conv)
+ if norm_str is not None:
+ layers.append(norm(out))
+ layers.append(nn.ReLU(inplace=True))
+ self.layers = nn.Sequential(*layers)
+
+ def forward(self, previous, skip):
+ _, _, hp, wp = previous.shape
+ _, _, hs, ws = skip.shape
+ scale = 2 ** np.round(np.log2(np.array([hs / hp, ws / wp])))
+
+ upsampled = nn.functional.interpolate(
+ previous, scale_factor=scale.tolist(), mode="bilinear", align_corners=False
+ )
+ # If the shape of the input map `skip` is not a multiple of 2,
+ # it will not match the shape of the upsampled map `upsampled`.
+ # If the downsampling uses ceil_mode=False, we need to crop `skip`.
+ # If it uses ceil_mode=True (not supported here), we should pad it.
+ _, _, hu, wu = upsampled.shape
+ _, _, hs, ws = skip.shape
+ if (hu <= hs) and (wu <= ws):
+ skip = skip[:, :, :hu, :wu]
+ elif (hu >= hs) and (wu >= ws):
+ skip = nn.functional.pad(skip, [0, wu - ws, 0, hu - hs])
+ else:
+ raise ValueError(f"Inconsistent skip vs upsampled shapes: {(hs, ws)}, {(hu, wu)}")
+
+ skip = skip.clone()
+ feats_skip = self.layers(skip)
+ if self.fusion == "sum":
+ return self.fusion_layers(feats_skip + upsampled)
+ elif self.fusion == "glu":
+ x = torch.cat([feats_skip, upsampled], dim=1)
+ return self.fusion_layers(x)
+ elif self.fusion == "ff":
+ return self.fusion_layers(feats_skip, upsampled)
+ else:
+ raise ValueError(f"Unknown fusion: {self.fusion}")
+
+
+class FPN(BaseModel):
+ default_conf = {
+ "predict_uncertainty": True,
+ "in_channels_list": [64, 128, 256, 512],
+ "out_channels": 64,
+ "num_convs": 1,
+ "norm": None,
+ "padding": "zeros",
+ "fusion": "sum",
+ "with_low_level": True,
+ }
+
+ required_data_keys = ["hl"]
+
+ def _init(self, conf):
+ self.in_channels_list = conf.in_channels_list
+ self.out_channels = conf.out_channels
+
+ self.num_convs = conf.num_convs
+ self.norm = conf.norm
+ self.padding = conf.padding
+
+ self.fusion = conf.fusion
+
+ self.first = nn.Conv2d(
+ self.in_channels_list[-1], self.out_channels, 1, padding=0, bias=True
+ )
+ self.blocks = nn.ModuleList(
+ [
+ DecoderBlock(
+ c,
+ self.out_channels,
+ ksize=1,
+ num_convs=self.num_convs,
+ norm_str=self.norm,
+ padding=self.padding,
+ fusion=self.fusion,
+ )
+ for c in self.in_channels_list[::-1][1:]
+ ]
+ )
+ self.out = nn.Sequential(
+ ConvModule(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ ConvModule(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ )
+
+ self.predict_uncertainty = conf.predict_uncertainty
+ if self.predict_uncertainty:
+ self.linear_pred_uncertainty = nn.Sequential(
+ ConvModule(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ nn.Conv2d(in_channels=self.out_channels, out_channels=1, kernel_size=1),
+ )
+
+ self.with_ll = conf.with_low_level
+ if self.with_ll:
+ self.out_conv = ConvModule(self.out_channels, self.out_channels, 3, padding=1)
+ self.ll_fusion = FeatureFusionBlock(self.out_channels, upsample=False)
+
+ def _forward(self, features):
+ layers = features["hl"]
+ feats = None
+
+ for idx, x in enumerate(reversed(layers)):
+ feats = self.first(x) if feats is None else self.blocks[idx - 1](feats, x)
+
+ feats = self.out(feats)
+
+        if self.with_ll:
+            assert "ll" in features, "Low-level features are required for this model"
+            # out_conv and ll_fusion only exist when with_low_level is enabled
+            feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False)
+            feats = self.out_conv(feats)
+            feats_ll = features["ll"].clone()
+            feats = self.ll_fusion(feats, feats_ll)
+
+ uncertainty = (
+ self.linear_pred_uncertainty(feats).squeeze(1) if self.predict_uncertainty else None
+ )
+ return feats, uncertainty
+
+ def loss(self, pred, data):
+ raise NotImplementedError
+
+ def metrics(self, pred, data):
+ raise NotImplementedError
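+
+
+# Usage sketch (hedged; assumes FeatureFusionBlock fuses two maps of identical shape):
+#   >>> fpn = FPN({})  # default conf, expects both "hl" and "ll" features
+#   >>> hl = [torch.rand(1, c, 32 // 2**i, 32 // 2**i) for i, c in enumerate([64, 128, 256, 512])]
+#   >>> ll = torch.rand(1, 64, 64, 64)
+#   >>> feats, uncertainty = fpn({"hl": hl, "ll": ll})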
diff --git a/siclib/models/decoders/latitude_decoder.py b/siclib/models/decoders/latitude_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..b61269f30336a20aef048c3a50011313c2a4f4e4
--- /dev/null
+++ b/siclib/models/decoders/latitude_decoder.py
@@ -0,0 +1,133 @@
+"""Latitude decoder head.
+
+Adapted from https://github.com/jinlinyi/PerspectiveFields
+"""
+
+import logging
+
+import torch
+from torch import nn
+
+from siclib.models import get_model
+from siclib.models.base_model import BaseModel
+from siclib.models.utils.metrics import latitude_error
+from siclib.models.utils.perspective_encoding import decode_bin_latitude
+from siclib.utils.conversions import deg2rad
+
+logger = logging.getLogger(__name__)
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class LatitudeDecoder(BaseModel):
+ default_conf = {
+ "loss_type": "l1",
+ "use_loss": True,
+ "use_uncertainty_loss": True,
+ "loss_weight": 1.0,
+ "recall_thresholds": [1, 3, 5, 10],
+ "use_tanh": True, # backward compatibility to original perspective weights
+ "decoder": {"name": "decoders.light_hamburger", "predict_uncertainty": True},
+ }
+
+ required_data_keys = ["features"]
+
+ def _init(self, conf):
+ self.loss_type = conf.loss_type
+ self.loss_weight = conf.loss_weight
+
+ self.use_uncertainty_loss = conf.use_uncertainty_loss
+ self.predict_uncertainty = conf.decoder.predict_uncertainty
+
+ self.num_classes = 1
+ self.is_classification = self.conf.loss_type == "classification"
+ if self.is_classification:
+ self.num_classes = 180
+
+ self.decoder = get_model(conf.decoder.name)(conf.decoder)
+ self.linear_pred_latitude = nn.Conv2d(
+ self.decoder.out_channels, self.num_classes, kernel_size=1
+ )
+
+ def calculate_losses(self, predictions, targets, confidence=None):
+ predictions = predictions.float() # https://github.com/pytorch/pytorch/issues/48163
+
+ residuals = predictions - targets
+ if self.loss_type == "l2":
+ loss = (residuals**2).sum(axis=1)
+ elif self.loss_type == "l1":
+ loss = residuals.abs().sum(axis=1)
+ elif self.loss_type == "cauchy":
+ c = 0.007 # -> corresponds to about 5 degrees
+ residuals = (residuals**2).sum(axis=1)
+ loss = c**2 / 2 * torch.log(1 + residuals / c**2)
+ elif self.loss_type == "huber":
+ c = deg2rad(1)
+ loss = nn.HuberLoss(reduction="none", delta=c)(predictions, targets).sum(axis=1)
+ else:
+ raise NotImplementedError(f"Unknown loss type {self.conf.loss_type}")
+
+ if confidence is not None and self.use_uncertainty_loss:
+ conf_weight = confidence / confidence.sum(axis=(-2, -1), keepdims=True)
+ conf_weight = conf_weight * (conf_weight.size(-1) * conf_weight.size(-2))
+ loss = loss * conf_weight.detach()
+
+ losses = {f"latitude-{self.loss_type}-loss": loss.mean(axis=(1, 2))}
+ losses = {k: v * self.loss_weight for k, v in losses.items()}
+
+ return losses
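+
+    # The Cauchy loss above is the robust kernel rho(r) = c^2 / 2 * log(1 + r^2 / c^2): it
+    # behaves like a squared loss for residuals much smaller than c and only grows
+    # logarithmically for outliers, so a few bad pixels cannot dominate the per-image loss.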
+
+ def _forward(self, data):
+ out = {}
+ x, log_confidence = self.decoder(data["features"])
+ lat = self.linear_pred_latitude(x)
+
+ if self.predict_uncertainty:
+ out["latitude_confidence"] = torch.sigmoid(log_confidence)
+
+ if self.is_classification:
+ out["latitude_field_logits"] = lat
+ out["latitude_field"] = decode_bin_latitude(
+ lat.argmax(dim=1), self.num_classes
+ ).unsqueeze(1)
+ return out
+
+ eps = 1e-5 # avoid nan in backward of asin
+ lat = torch.tanh(lat) if self.conf.use_tanh else lat
+ lat = torch.asin(torch.clamp(lat, -1 + eps, 1 - eps))
+
+ out["latitude_field"] = lat
+ return out
+
+ def loss(self, pred, data):
+ if not self.conf.use_loss or self.is_classification:
+ return {}, self.metrics(pred, data)
+
+ predictions = pred["latitude_field"]
+ targets = data["latitude_field"]
+
+ losses = self.calculate_losses(predictions, targets, pred.get("latitude_confidence"))
+
+ total = 0 + losses[f"latitude-{self.loss_type}-loss"]
+ losses |= {"latitude_total": total}
+ return losses, self.metrics(pred, data)
+
+ def metrics(self, pred, data):
+ predictions = pred["latitude_field"]
+ targets = data["latitude_field"]
+
+ error = latitude_error(predictions, targets)
+ out = {"latitude_angle_error": error.mean(axis=(1, 2))}
+
+ if "latitude_confidence" in pred:
+ weighted_error = (error * pred["latitude_confidence"]).sum(axis=(1, 2))
+ out["latitude_angle_error_weighted"] = weighted_error / pred["latitude_confidence"].sum(
+ axis=(1, 2)
+ )
+
+ for th in self.conf.recall_thresholds:
+ rec = (error < th).float().mean(axis=(1, 2))
+ out[f"latitude_angle_recall@{th}"] = rec
+
+ return out
diff --git a/siclib/models/decoders/light_hamburger.py b/siclib/models/decoders/light_hamburger.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f7973084b206b32321e68e6d7298f2d7a709d23
--- /dev/null
+++ b/siclib/models/decoders/light_hamburger.py
@@ -0,0 +1,243 @@
+"""Light HamHead Decoder.
+
+Adapted from:
+https://github.com/Visual-Attention-Network/SegNeXt/blob/main/mmseg/models/decode_heads/ham_head.py
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from siclib.models import BaseModel
+from siclib.models.utils.modules import ConvModule, FeatureFusionBlock
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class _MatrixDecomposition2DBase(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ self.spatial = True
+
+ self.S = 1
+ self.D = 512
+ self.R = 64
+
+ self.train_steps = 6
+ self.eval_steps = 7
+
+ self.inv_t = 100
+ self.eta = 0.9
+
+ self.rand_init = True
+
+ def _build_bases(self, B, S, D, R, device="cpu"):
+ raise NotImplementedError
+
+ def local_step(self, x, bases, coef):
+ raise NotImplementedError
+
+ # @torch.no_grad()
+ def local_inference(self, x, bases):
+ # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
+ coef = torch.bmm(x.transpose(1, 2), bases)
+ coef = F.softmax(self.inv_t * coef, dim=-1)
+
+ steps = self.train_steps if self.training else self.eval_steps
+ for _ in range(steps):
+ bases, coef = self.local_step(x, bases, coef)
+
+ return bases, coef
+
+ def compute_coef(self, x, bases, coef):
+ raise NotImplementedError
+
+ def forward(self, x, return_bases=False):
+ B, C, H, W = x.shape
+
+ # (B, C, H, W) -> (B * S, D, N)
+ if self.spatial:
+ D = C // self.S
+ N = H * W
+ x = x.view(B * self.S, D, N)
+ else:
+ D = H * W
+ N = C // self.S
+ x = x.view(B * self.S, N, D).transpose(1, 2)
+
+ if not self.rand_init and not hasattr(self, "bases"):
+ bases = self._build_bases(1, self.S, D, self.R, device=x.device)
+ self.register_buffer("bases", bases)
+
+ # (S, D, R) -> (B * S, D, R)
+ if self.rand_init:
+ bases = self._build_bases(B, self.S, D, self.R, device=x.device)
+ else:
+ bases = self.bases.repeat(B, 1, 1)
+
+ bases, coef = self.local_inference(x, bases)
+
+ # (B * S, N, R)
+ coef = self.compute_coef(x, bases, coef)
+
+ # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N)
+ x = torch.bmm(bases, coef.transpose(1, 2))
+
+ # (B * S, D, N) -> (B, C, H, W)
+ x = x.view(B, C, H, W) if self.spatial else x.transpose(1, 2).view(B, C, H, W)
+ # (B * H, D, R) -> (B, H, N, D)
+ bases = bases.view(B, self.S, D, self.R)
+
+ return x
+
+
+class NMF2D(_MatrixDecomposition2DBase):
+ def __init__(self):
+ super().__init__()
+
+ self.inv_t = 1
+
+ def _build_bases(self, B, S, D, R, device="cpu"):
+ bases = torch.rand((B * S, D, R)).to(device)
+ bases = F.normalize(bases, dim=1)
+
+ return bases
+
+ # @torch.no_grad()
+ def local_step(self, x, bases, coef):
+ # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
+ numerator = torch.bmm(x.transpose(1, 2), bases)
+ # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R)
+ denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
+ # Multiplicative Update
+ coef = coef * numerator / (denominator + 1e-6)
+
+ # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R)
+ numerator = torch.bmm(x, coef)
+ # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R)
+ denominator = bases.bmm(coef.transpose(1, 2).bmm(coef))
+ # Multiplicative Update
+ bases = bases * numerator / (denominator + 1e-6)
+
+ return bases, coef
+
+ def compute_coef(self, x, bases, coef):
+ # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R)
+ numerator = torch.bmm(x.transpose(1, 2), bases)
+ # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R)
+ denominator = coef.bmm(bases.transpose(1, 2).bmm(bases))
+ # multiplication update
+ coef = coef * numerator / (denominator + 1e-6)
+
+ return coef
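+
+    # The multiplicative updates above implement NMF (Lee & Seung style) for X ~= B C^T:
+    #   C <- C * (X^T B) / (C B^T B)    and    B <- B * (X C) / (B C^T C),
+    # which keeps both factors non-negative as long as they are initialized non-negative.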
+
+
+class Hamburger(nn.Module):
+ def __init__(self, ham_channels=512, norm_cfg=None, **kwargs):
+ super().__init__()
+
+ self.ham_in = ConvModule(ham_channels, ham_channels, 1)
+
+ self.ham = NMF2D()
+
+ self.ham_out = ConvModule(ham_channels, ham_channels, 1)
+
+ def forward(self, x):
+ enjoy = self.ham_in(x)
+ enjoy = F.relu(enjoy, inplace=False)
+ enjoy = self.ham(enjoy)
+ enjoy = self.ham_out(enjoy)
+ ham = F.relu(x + enjoy, inplace=False)
+
+ return ham
+
+
+class LightHamHead(BaseModel):
+ """Is Attention Better Than Matrix Decomposition?
+ This head is the implementation of `HamNet
+ `_.
+
+ Args:
+ ham_channels (int): input channels for Hamburger.
+ ham_kwargs (int): kwagrs for Ham.
+ """
+
+ default_conf = {
+ "predict_uncertainty": True,
+ "out_channels": 64,
+ "in_channels": [64, 128, 320, 512],
+ "in_index": [0, 1, 2, 3],
+ "ham_channels": 512,
+ "with_low_level": True,
+ }
+
+ def _init(self, conf):
+ self.in_index = conf.in_index
+ self.in_channels = conf.in_channels
+ self.out_channels = conf.out_channels
+ self.ham_channels = conf.ham_channels
+ self.align_corners = False
+ self.predict_uncertainty = conf.predict_uncertainty
+
+ self.squeeze = ConvModule(sum(self.in_channels), self.ham_channels, 1)
+
+ self.hamburger = Hamburger(self.ham_channels)
+
+ self.align = ConvModule(self.ham_channels, self.out_channels, 1)
+
+ if self.predict_uncertainty:
+ self.linear_pred_uncertainty = nn.Sequential(
+ ConvModule(
+ in_channels=self.out_channels,
+ out_channels=self.out_channels,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ ),
+ nn.Conv2d(in_channels=self.out_channels, out_channels=1, kernel_size=1),
+ )
+
+ self.with_ll = conf.with_low_level
+ if self.with_ll:
+ self.out_conv = ConvModule(
+ self.out_channels, self.out_channels, 3, padding=1, bias=False
+ )
+ self.ll_fusion = FeatureFusionBlock(self.out_channels, upsample=False)
+
+ def _forward(self, features):
+ """Forward function."""
+ # inputs = self._transform_inputs(inputs)
+ inputs = [features["hl"][i] for i in self.in_index]
+
+ inputs = [
+ F.interpolate(
+ level, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners
+ )
+ for level in inputs
+ ]
+
+ inputs = torch.cat(inputs, dim=1)
+ x = self.squeeze(inputs)
+
+ x = self.hamburger(x)
+
+ feats = self.align(x)
+
+ if self.with_ll:
+ assert "ll" in features, "Low-level features are required for this model"
+ feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False)
+ feats = self.out_conv(feats)
+ feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False)
+ feats_ll = features["ll"].clone()
+ feats = self.ll_fusion(feats, feats_ll)
+
+ uncertainty = (
+ self.linear_pred_uncertainty(feats).squeeze(1) if self.predict_uncertainty else None
+ )
+
+ return feats, uncertainty
+
+ def loss(self, pred, data):
+ raise NotImplementedError
diff --git a/siclib/models/decoders/perspective_decoder.py b/siclib/models/decoders/perspective_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4537586cec3e6d3a411eead431c4633395ef93a
--- /dev/null
+++ b/siclib/models/decoders/perspective_decoder.py
@@ -0,0 +1,59 @@
+"""Perspective fields decoder heads.
+
+Adapted from https://github.com/jinlinyi/PerspectiveFields
+"""
+
+import logging
+
+from siclib.models import get_model
+from siclib.models.base_model import BaseModel
+
+logger = logging.getLogger(__name__)
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class PerspectiveDecoder(BaseModel):
+ default_conf = {
+ "up_decoder": {"name": "decoders.up_decoder"},
+ "latitude_decoder": {"name": "decoders.latitude_decoder"},
+ }
+
+ required_data_keys = ["features"]
+
+ def _init(self, conf):
+ logger.debug(f"Initializing PerspectiveDecoder with config: {conf}")
+ self.use_up = conf.up_decoder is not None
+ self.use_latitude = conf.latitude_decoder is not None
+
+ if self.use_up:
+ self.up_head = get_model(conf.up_decoder.name)(conf.up_decoder)
+
+ if self.use_latitude:
+ self.latitude_head = get_model(conf.latitude_decoder.name)(conf.latitude_decoder)
+
+ def _forward(self, data):
+ out_up = self.up_head(data) if self.use_up else {}
+ out_lat = self.latitude_head(data) if self.use_latitude else {}
+ return out_up | out_lat
+
+ def loss(self, pred, data):
+ ref = data["up_field"] if self.use_up else data["latitude_field"]
+
+ total = ref.new_zeros(ref.shape[0])
+ losses, metrics = {}, {}
+ if self.use_up:
+ up_losses, up_metrics = self.up_head.loss(pred, data)
+ losses |= up_losses
+ metrics |= up_metrics
+ total = total + losses.get("up_total", 0)
+
+ if self.use_latitude:
+ latitude_losses, latitude_metrics = self.latitude_head.loss(pred, data)
+ losses |= latitude_losses
+ metrics |= latitude_metrics
+ total = total + losses.get("latitude_total", 0)
+
+ losses["perspective_total"] = total
+ return losses, metrics
diff --git a/siclib/models/decoders/up_decoder.py b/siclib/models/decoders/up_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..30c72713cd2df8b92f156adc84628d18a934be33
--- /dev/null
+++ b/siclib/models/decoders/up_decoder.py
@@ -0,0 +1,128 @@
+"""up decoder head.
+
+Adapted from https://github.com/jinlinyi/PerspectiveFields
+"""
+
+import logging
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from siclib.models import get_model
+from siclib.models.base_model import BaseModel
+from siclib.models.utils.metrics import up_error
+from siclib.models.utils.perspective_encoding import decode_up_bin
+from siclib.utils.conversions import deg2rad
+
+logger = logging.getLogger(__name__)
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class UpDecoder(BaseModel):
+ default_conf = {
+ "loss_type": "l1",
+ "use_loss": True,
+ "use_uncertainty_loss": True,
+ "loss_weight": 1.0,
+ "recall_thresholds": [1, 3, 5, 10],
+ "decoder": {"name": "decoders.light_hamburger", "predict_uncertainty": True},
+ }
+
+ required_data_keys = ["features"]
+
+ def _init(self, conf):
+ self.loss_type = conf.loss_type
+ self.loss_weight = conf.loss_weight
+
+ self.use_uncertainty_loss = conf.use_uncertainty_loss
+ self.predict_uncertainty = conf.decoder.predict_uncertainty
+
+ self.num_classes = 2
+ self.is_classification = self.conf.loss_type == "classification"
+ if self.is_classification:
+ self.num_classes = 73
+
+ self.decoder = get_model(conf.decoder.name)(conf.decoder)
+ self.linear_pred_up = nn.Conv2d(self.decoder.out_channels, self.num_classes, kernel_size=1)
+
+ def calculate_losses(self, predictions, targets, confidence=None):
+ predictions = predictions.float() # https://github.com/pytorch/pytorch/issues/48163
+
+ residuals = predictions - targets
+ if self.loss_type == "l2":
+ loss = (residuals**2).sum(axis=1)
+ elif self.loss_type == "l1":
+ loss = residuals.abs().sum(axis=1)
+ elif self.loss_type == "dot":
+ loss = 1 - (residuals * targets).sum(axis=1)
+ elif self.loss_type == "cauchy":
+ c = 0.007 # -> corresponds to about 5 degrees
+ residuals = (residuals**2).sum(axis=1)
+ loss = c**2 / 2 * torch.log(1 + residuals / c**2)
+ elif self.loss_type == "huber":
+ c = deg2rad(1)
+ loss = nn.HuberLoss(reduction="none", delta=c)(predictions, targets).sum(axis=1)
+ else:
+ raise NotImplementedError(f"Unknown loss type {self.conf.loss_type}")
+
+ if confidence is not None and self.use_uncertainty_loss:
+ conf_weight = confidence / confidence.sum(axis=(-2, -1), keepdims=True)
+ conf_weight = conf_weight * (conf_weight.size(-1) * conf_weight.size(-2))
+ loss = loss * conf_weight.detach()
+
+ losses = {f"up-{self.loss_type}-loss": loss.mean(axis=(1, 2))}
+ losses = {k: v * self.loss_weight for k, v in losses.items()}
+
+ return losses
+
+ def _forward(self, data):
+ out = {}
+ x, log_confidence = self.decoder(data["features"])
+ up = self.linear_pred_up(x)
+
+ if self.predict_uncertainty:
+ out["up_confidence"] = torch.sigmoid(log_confidence)
+
+ if self.is_classification:
+ out["up_field"] = decode_up_bin(up.argmax(dim=1), self.num_classes)
+ return out
+
+ up = F.normalize(up, dim=1)
+
+ out["up_field"] = up
+ return out
+
+ def loss(self, pred, data):
+ if not self.conf.use_loss or self.is_classification:
+ return {}, self.metrics(pred, data)
+
+ predictions = pred["up_field"]
+ targets = data["up_field"]
+
+ losses = self.calculate_losses(predictions, targets, pred.get("up_confidence"))
+
+ total = 0 + losses[f"up-{self.loss_type}-loss"]
+ losses |= {"up_total": total}
+ return losses, self.metrics(pred, data)
+
+ def metrics(self, pred, data):
+ predictions = pred["up_field"]
+ targets = data["up_field"]
+
+ mask = predictions.sum(axis=1) != 0
+
+ error = up_error(predictions, targets) * mask
+ out = {"up_angle_error": error.mean(axis=(1, 2))}
+
+ if "up_confidence" in pred:
+ weighted_error = (error * pred["up_confidence"]).sum(axis=(1, 2))
+ out["up_angle_error_weighted"] = weighted_error / pred["up_confidence"].sum(axis=(1, 2))
+
+ for th in self.conf.recall_thresholds:
+ rec = (error < th).float().mean(axis=(1, 2))
+ out[f"up_angle_recall@{th}"] = rec
+
+ return out
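+
+
+if __name__ == "__main__":
+    # Standalone sketch of the confidence weighting used in `calculate_losses` above
+    # (illustration only; shapes and values are arbitrary).
+    dummy_loss = torch.rand(2, 8, 8)  # per-pixel loss (B, H, W)
+    confidence = torch.rand(2, 8, 8)  # predicted per-pixel confidence in [0, 1]
+    weight = confidence / confidence.sum(axis=(-2, -1), keepdims=True)
+    weight = weight * (weight.size(-1) * weight.size(-2))  # mean weight stays 1
+    print((dummy_loss * weight.detach()).mean(axis=(1, 2)))  # per-image weighted loss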
diff --git a/siclib/models/encoders/__init__.py b/siclib/models/encoders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/siclib/models/encoders/low_level_encoder.py b/siclib/models/encoders/low_level_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..99d4086131dea93d908052813621268f967c18f9
--- /dev/null
+++ b/siclib/models/encoders/low_level_encoder.py
@@ -0,0 +1,54 @@
+import logging
+
+import torch.nn as nn
+
+from siclib.models.base_model import BaseModel
+from siclib.models.utils.modules import ConvModule
+
+logger = logging.getLogger(__name__)
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class LowLevelEncoder(BaseModel):
+ default_conf = {
+ "feat_dim": 64,
+ "in_channel": 3,
+ "keep_resolution": True,
+ }
+
+ required_data_keys = ["image"]
+
+ def _init(self, conf):
+ logger.debug(f"Initializing LowLevelEncoder with {conf}")
+
+ if self.conf.keep_resolution:
+ self.conv1 = ConvModule(conf.in_channel, conf.feat_dim, kernel_size=3, padding=1)
+ self.conv2 = ConvModule(conf.feat_dim, conf.feat_dim, kernel_size=3, padding=1)
+ else:
+ self.conv1 = nn.Conv2d(
+ conf.in_channel, conf.feat_dim, kernel_size=7, stride=2, padding=3, bias=False
+ )
+ self.bn1 = nn.BatchNorm2d(conf.feat_dim)
+ self.relu = nn.ReLU(inplace=True)
+
+ def _forward(self, data):
+ x = data["image"]
+
+        assert (
+            x.shape[-1] % 32 == 0 and x.shape[-2] % 32 == 0
+        ), "Image height and width must each be a multiple of 32."
+
+ if self.conf.keep_resolution:
+ c1 = self.conv1(x)
+ c2 = self.conv2(c1)
+ else:
+ x = self.conv1(x)
+ x = self.bn1(x)
+ c2 = self.relu(x)
+
+ return {"features": c2}
+
+ def loss(self, pred, data):
+ raise NotImplementedError
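+
+
+if __name__ == "__main__":
+    # Quick sanity-check sketch. Assumes BaseModel merges this dict with `default_conf`
+    # and dispatches the call to `_forward`, as done elsewhere in this repository.
+    import torch
+
+    encoder = LowLevelEncoder({})
+    out = encoder({"image": torch.rand(1, 3, 64, 64)})
+    print(out["features"].shape)  # with keep_resolution=True: (1, 64, 64, 64)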
diff --git a/siclib/models/encoders/mscan.py b/siclib/models/encoders/mscan.py
new file mode 100644
index 0000000000000000000000000000000000000000..81131d09082e9634df195f72a3fd47012162b209
--- /dev/null
+++ b/siclib/models/encoders/mscan.py
@@ -0,0 +1,258 @@
+"""Implementation of MSCAN from SegNeXt: Rethinking Convolutional Attention Design for Semantic
+Segmentation (NeurIPS 2022)
+
+based on: https://github.com/Visual-Attention-Network/SegNeXt
+"""
+
+import torch
+import torch.nn as nn
+from torch.nn.modules.utils import _pair as to_2tuple
+
+from siclib.models import BaseModel
+from siclib.models.utils.modules import DropPath, DWConv
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0
+ ):
+ """Initialize the MLP."""
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
+ self.dwconv = DWConv(hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ """Forward pass."""
+ x = self.fc1(x)
+
+ x = self.dwconv(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+
+ return x
+
+
+class StemConv(nn.Module):
+ def __init__(self, in_channels, out_channels):
+ super(StemConv, self).__init__()
+
+ self.proj = nn.Sequential(
+ nn.Conv2d(
+ in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
+ ),
+ nn.BatchNorm2d(out_channels // 2),
+ nn.GELU(),
+ nn.Conv2d(
+ out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1)
+ ),
+ nn.BatchNorm2d(out_channels),
+ )
+
+ def forward(self, x):
+ x = self.proj(x)
+ _, _, H, W = x.size()
+ x = x.flatten(2).transpose(1, 2)
+ return x, H, W
+
+
+class AttentionModule(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
+ self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim)
+ self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim)
+
+ self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
+ self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
+
+ self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
+ self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
+ self.conv3 = nn.Conv2d(dim, dim, 1)
+
+ def forward(self, x):
+ u = x.clone()
+ attn = self.conv0(x)
+
+ attn_0 = self.conv0_1(attn)
+ attn_0 = self.conv0_2(attn_0)
+
+ attn_1 = self.conv1_1(attn)
+ attn_1 = self.conv1_2(attn_1)
+
+ attn_2 = self.conv2_1(attn)
+ attn_2 = self.conv2_2(attn_2)
+ attn = attn + attn_0 + attn_1 + attn_2
+
+ attn = self.conv3(attn)
+
+ return attn * u
+
+
+class SpatialAttention(nn.Module):
+ def __init__(self, d_model):
+ super().__init__()
+ self.d_model = d_model
+ self.proj_1 = nn.Conv2d(d_model, d_model, 1)
+ self.activation = nn.GELU()
+ self.spatial_gating_unit = AttentionModule(d_model)
+ self.proj_2 = nn.Conv2d(d_model, d_model, 1)
+
+ def forward(self, x):
+        shortcut = x.clone()
+        x = self.proj_1(x)
+        x = self.activation(x)
+        x = self.spatial_gating_unit(x)
+        x = self.proj_2(x)
+        x = x + shortcut
+ return x
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim,
+ mlp_ratio=4.0,
+ drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ ):
+ super().__init__()
+ self.norm1 = nn.BatchNorm2d(dim)
+ self.attn = SpatialAttention(dim)
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = nn.BatchNorm2d(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop
+ )
+ layer_scale_init_value = 1e-2
+ self.layer_scale_1 = nn.Parameter(
+ layer_scale_init_value * torch.ones((dim)), requires_grad=True
+ )
+ self.layer_scale_2 = nn.Parameter(
+ layer_scale_init_value * torch.ones((dim)), requires_grad=True
+ )
+
+ def forward(self, x, H, W):
+ B, N, C = x.shape
+ x = x.permute(0, 2, 1).view(B, C, H, W)
+ x = x + self.drop_path(
+ self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x))
+ )
+ x = x + self.drop_path(
+ self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x))
+ )
+ x = x.view(B, C, N).permute(0, 2, 1)
+ return x
+
+
+class OverlapPatchEmbed(nn.Module):
+ """Image to Patch Embedding"""
+
+ def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768):
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+
+ self.proj = nn.Conv2d(
+ in_chans,
+ embed_dim,
+ kernel_size=patch_size,
+ stride=stride,
+ padding=(patch_size[0] // 2, patch_size[1] // 2),
+ )
+ self.norm = nn.BatchNorm2d(embed_dim)
+
+ def forward(self, x):
+ x = self.proj(x)
+ _, _, H, W = x.shape
+ x = self.norm(x)
+
+ x = x.flatten(2).transpose(1, 2)
+
+ return x, H, W
+
+
+class MSCAN(BaseModel):
+ default_conf = {
+ "in_channels": 3,
+ "embed_dims": [64, 128, 320, 512],
+ "mlp_ratios": [8, 8, 4, 4],
+ "drop_rate": 0.0,
+ "drop_path_rate": 0.1,
+ "depths": [3, 3, 12, 3],
+ "num_stages": 4,
+ }
+
+ required_data_keys = ["image"]
+
+ def _init(self, conf):
+ self.depths = conf.depths
+ self.num_stages = conf.num_stages
+
+ # stochastic depth decay rule
+ dpr = [x.item() for x in torch.linspace(0, conf.drop_path_rate, sum(conf.depths))]
+ cur = 0
+
+ for i in range(conf.num_stages):
+ if i == 0:
+ patch_embed = StemConv(3, conf.embed_dims[0])
+ else:
+                patch_embed = OverlapPatchEmbed(
+                    patch_size=7 if i == 0 else 3,
+                    stride=4 if i == 0 else 2,
+                    in_chans=conf.in_channels if i == 0 else conf.embed_dims[i - 1],
+                    embed_dim=conf.embed_dims[i],
+                )
+
+ block = nn.ModuleList(
+ [
+ Block(
+ dim=conf.embed_dims[i],
+ mlp_ratio=conf.mlp_ratios[i],
+ drop=conf.drop_rate,
+ drop_path=dpr[cur + j],
+ )
+ for j in range(conf.depths[i])
+ ]
+ )
+ norm = nn.LayerNorm(conf.embed_dims[i])
+ cur += conf.depths[i]
+
+ setattr(self, f"patch_embed{i + 1}", patch_embed)
+ setattr(self, f"block{i + 1}", block)
+ setattr(self, f"norm{i + 1}", norm)
+
+ def _forward(self, data):
+ img = data["image"]
+ # rgb -> bgr and from [0, 1] to [0, 255]
+ x = img[:, [2, 1, 0], :, :] * 255.0
+
+ B = x.shape[0]
+ outs = []
+
+ for i in range(self.num_stages):
+ patch_embed = getattr(self, f"patch_embed{i + 1}")
+ block = getattr(self, f"block{i + 1}")
+ norm = getattr(self, f"norm{i + 1}")
+ x, H, W = patch_embed(x)
+ for blk in block:
+ x = blk(x, H, W)
+ x = norm(x)
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+ outs.append(x)
+
+ return {"features": outs}
+
+ def loss(self, pred, data):
+ """Compute the loss."""
+ raise NotImplementedError
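+
+
+if __name__ == "__main__":
+    # Feature-pyramid sketch. Assumes BaseModel merges the empty dict with `default_conf`
+    # and dispatches the call to `_forward`, as in the other encoders of this repository.
+    model = MSCAN({})
+    feats = model({"image": torch.rand(1, 3, 224, 224)})["features"]
+    # Default config: strides 4/8/16/32 -> maps of 56/28/14/7 px with 64/128/320/512 channels.
+    print([tuple(f.shape) for f in feats])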
diff --git a/siclib/models/encoders/resnet.py b/siclib/models/encoders/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c8a6521417f09e536788eadf3baf037046a2d44
--- /dev/null
+++ b/siclib/models/encoders/resnet.py
@@ -0,0 +1,83 @@
+"""Basic ResNet encoder for image feature extraction.
+
+https://pytorch.org/hub/pytorch_vision_resnet/
+"""
+
+import torch
+import torch.nn as nn
+import torchvision
+from torchvision.models.feature_extraction import create_feature_extractor
+
+from siclib.models.base_model import BaseModel
+
+# mypy: ignore-errors
+
+
+def remove_conv_stride(conv):
+ """Remove the stride from a convolutional layer."""
+ conv_new = nn.Conv2d(
+ conv.in_channels,
+ conv.out_channels,
+ conv.kernel_size,
+ bias=conv.bias is not None,
+ stride=1,
+ padding=conv.padding,
+ )
+ conv_new.weight = conv.weight
+ conv_new.bias = conv.bias
+ return conv_new
+
+
+class ResNet(BaseModel):
+ """ResNet encoder for image features extraction."""
+
+ default_conf = {
+ "encoder": "resnet18",
+ "pretrained": True,
+ "input_dim": 3,
+ "remove_stride_from_first_conv": True,
+ "num_downsample": None, # how many downsample bloc
+ "pixel_mean": [0.485, 0.456, 0.406],
+ "pixel_std": [0.229, 0.224, 0.225],
+ }
+
+ required_data_keys = ["image"]
+
+ def build_encoder(self, conf):
+ """Build the encoder from the configuration."""
+ if conf.pretrained:
+ assert conf.input_dim == 3
+
+ Encoder = getattr(torchvision.models, conf.encoder)
+
+ layers = ["layer1", "layer2", "layer3", "layer4"]
+ kw = {"replace_stride_with_dilation": [False, False, False]}
+
+ if conf.num_downsample is not None:
+ layers = layers[: conf.num_downsample]
+
+ encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw)
+ encoder = create_feature_extractor(encoder, return_nodes=layers)
+
+ if conf.remove_stride_from_first_conv:
+ encoder.conv1 = remove_conv_stride(encoder.conv1)
+
+ return encoder, layers
+
+ def _init(self, conf):
+ self.register_buffer("pixel_mean", torch.tensor(conf.pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer("pixel_std", torch.tensor(conf.pixel_std).view(-1, 1, 1), False)
+
+ self.encoder, self.layers = self.build_encoder(conf)
+
+ def _forward(self, data):
+ image = data["image"]
+ image = (image - self.pixel_mean) / self.pixel_std
+ skip_features = list(self.encoder(image).values())
+ return {"features": skip_features}
+
+ def loss(self, pred, data):
+ """Compute the loss."""
+ raise NotImplementedError
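+
+
+if __name__ == "__main__":
+    # Usage sketch. Assumes partial config dicts are merged over `default_conf` by
+    # BaseModel; pretrained=False skips the ImageNet weight download for this demo.
+    model = ResNet({"encoder": "resnet18", "pretrained": False})
+    out = model({"image": torch.rand(1, 3, 224, 224)})
+    # With the stride removed from conv1, the first skip feature is at 1/2 resolution.
+    print([tuple(f.shape) for f in out["features"]])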
diff --git a/siclib/models/encoders/vgg.py b/siclib/models/encoders/vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8c62323dca054896aaeecd798afa335b32df3fa
--- /dev/null
+++ b/siclib/models/encoders/vgg.py
@@ -0,0 +1,75 @@
+"""Simple VGG encoder for image features extraction."""
+
+import torch
+import torchvision
+from torchvision.models.feature_extraction import create_feature_extractor
+
+from siclib.models.base_model import BaseModel
+
+# mypy: ignore-errors
+
+
+class VGG(BaseModel):
+ """VGG encoder for image features extraction."""
+
+ default_conf = {
+ "encoder": "vgg13",
+ "pretrained": True,
+ "input_dim": 3,
+ "num_downsample": None, # how many downsample blocs to use
+ "pixel_mean": [0.485, 0.456, 0.406],
+ "pixel_std": [0.229, 0.224, 0.225],
+ }
+
+ required_data_keys = ["image"]
+
+ def build_encoder(self, conf):
+ """Build the encoder from the configuration."""
+ if conf.pretrained:
+ assert conf.input_dim == 3
+
+ Encoder = getattr(torchvision.models, conf.encoder)
+
+ kw = {}
+ if conf.encoder == "vgg13":
+ layers = [
+ "features.3",
+ "features.8",
+ "features.13",
+ "features.18",
+ "features.23",
+ ]
+ elif conf.encoder == "vgg16":
+ layers = [
+ "features.3",
+ "features.8",
+ "features.15",
+ "features.22",
+ "features.29",
+ ]
+ else:
+ raise NotImplementedError(f"Encoder not implemented: {conf.encoder}")
+
+ if conf.num_downsample is not None:
+ layers = layers[: conf.num_downsample]
+
+ encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw)
+ encoder = create_feature_extractor(encoder, return_nodes=layers)
+
+ return encoder, layers
+
+ def _init(self, conf):
+ self.register_buffer("pixel_mean", torch.tensor(conf.pixel_mean).view(-1, 1, 1), False)
+ self.register_buffer("pixel_std", torch.tensor(conf.pixel_std).view(-1, 1, 1), False)
+
+ self.encoder, self.layers = self.build_encoder(conf)
+
+ def _forward(self, data):
+ image = data["image"]
+ image = (image - self.pixel_mean) / self.pixel_std
+ skip_features = self.encoder(image).values()
+ return {"features": skip_features}
+
+ def loss(self, pred, data):
+ """Compute the loss."""
+ raise NotImplementedError
diff --git a/siclib/models/extractor.py b/siclib/models/extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..4096666163a097504f78df505972c048ac523192
--- /dev/null
+++ b/siclib/models/extractor.py
@@ -0,0 +1,126 @@
+"""Simple interface for GeoCalib model."""
+
+from pathlib import Path
+from typing import Dict, Optional
+
+import torch
+import torch.nn as nn
+from torch.nn.functional import interpolate
+
+from siclib.geometry.base_camera import BaseCamera
+from siclib.models.networks.geocalib import GeoCalib as Model
+from siclib.utils.image import ImagePreprocessor, load_image
+
+
+class GeoCalib(nn.Module):
+ """Simple interface for GeoCalib model."""
+
+ def __init__(self, weights: str = "pinhole"):
+ """Initialize the model with optional config overrides.
+
+ Args:
+ weights (str, optional): Weights to load. Defaults to "pinhole".
+ """
+ super().__init__()
+ if weights == "pinhole":
+ url = "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-pinhole.tar"
+ elif weights == "distorted":
+ url = (
+ "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-simple_radial.tar"
+ )
+ else:
+ raise ValueError(f"Unknown weights: {weights}")
+
+ # load checkpoint
+ model_dir = f"{torch.hub.get_dir()}/geocalib"
+ state_dict = torch.hub.load_state_dict_from_url(
+ url, model_dir, map_location="cpu", file_name=f"{weights}.tar"
+ )
+
+ self.model = Model({})
+ self.model.flexible_load(state_dict["model"])
+ self.model.eval()
+
+ self.image_processor = ImagePreprocessor({"resize": 320, "edge_divisible_by": 32})
+
+ def load_image(self, path: Path) -> torch.Tensor:
+ """Load image from path."""
+ return load_image(path)
+
+ def _post_process(
+ self, camera: BaseCamera, img_data: dict[str, torch.Tensor], out: dict[str, torch.Tensor]
+ ) -> tuple[BaseCamera, dict[str, torch.Tensor]]:
+ """Post-process model output by undoing scaling and cropping."""
+ camera = camera.undo_scale_crop(img_data)
+
+ w, h = camera.size.unbind(-1)
+ h = h[0].round().int().item()
+ w = w[0].round().int().item()
+
+ for k in ["latitude_field", "up_field"]:
+ out[k] = interpolate(out[k], size=(h, w), mode="bilinear")
+ for k in ["up_confidence", "latitude_confidence"]:
+ out[k] = interpolate(out[k][:, None], size=(h, w), mode="bilinear")[:, 0]
+
+ inverse_scales = 1.0 / img_data["scales"]
+ zero = camera.new_zeros(camera.f.shape[0])
+ out["focal_uncertainty"] = out.get("focal_uncertainty", zero) * inverse_scales[1]
+ return camera, out
+
+ @torch.no_grad()
+ def calibrate(
+ self,
+ img: torch.Tensor,
+ camera_model: str = "pinhole",
+ priors: Optional[Dict[str, torch.Tensor]] = None,
+ shared_intrinsics: bool = False,
+ ) -> Dict[str, torch.Tensor]:
+ """Perform calibration with online resizing.
+
+ Assumes input image is in range [0, 1] and in RGB format.
+
+ Args:
+ img (torch.Tensor): Input image, shape (C, H, W) or (1, C, H, W)
+ camera_model (str, optional): Camera model. Defaults to "pinhole".
+ priors (Dict[str, torch.Tensor], optional): Prior parameters. Defaults to {}.
+ shared_intrinsics (bool, optional): Whether to share intrinsics. Defaults to False.
+
+ Returns:
+ Dict[str, torch.Tensor]: camera and gravity vectors and uncertainties.
+ """
+ if len(img.shape) == 3:
+ img = img[None] # add batch dim
+ if not shared_intrinsics:
+ assert len(img.shape) == 4 and img.shape[0] == 1
+
+ img_data = self.image_processor(img)
+
+ if priors is None:
+ priors = {}
+
+ prior_values = {}
+ if prior_focal := priors.get("focal"):
+ prior_focal = prior_focal[None] if len(prior_focal.shape) == 0 else prior_focal
+ prior_values["prior_focal"] = prior_focal * img_data["scales"][1]
+
+ if "gravity" in priors:
+ prior_gravity = priors["gravity"]
+ prior_gravity = prior_gravity[None] if len(prior_gravity.shape) == 0 else prior_gravity
+ prior_values["prior_gravity"] = prior_gravity
+
+ self.model.optimizer.set_camera_model(camera_model)
+ self.model.optimizer.shared_intrinsics = shared_intrinsics
+
+ out = self.model(img_data | prior_values)
+
+ camera, gravity = out["camera"], out["gravity"]
+ camera, out = self._post_process(camera, img_data, out)
+
+ return {
+ "camera": camera,
+ "gravity": gravity,
+ "covariance": out["covariance"],
+ **{k: out[k] for k in out.keys() if "field" in k},
+ **{k: out[k] for k in out.keys() if "confidence" in k},
+ **{k: out[k] for k in out.keys() if "uncertainty" in k},
+ }
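+
+
+if __name__ == "__main__":
+    # End-to-end usage sketch: downloads the released checkpoint on first use; the image
+    # path is a placeholder.
+    calib = GeoCalib(weights="pinhole")
+    image = calib.load_image("path/to/image.jpg")
+    result = calib.calibrate(image)
+    print(result["camera"], result["gravity"])
+    # Priors can optionally be passed, e.g. a known focal length in pixels:
+    # result = calib.calibrate(image, priors={"focal": torch.tensor(500.0)})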
diff --git a/siclib/models/networks/__init__.py b/siclib/models/networks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/siclib/models/networks/deepcalib.py b/siclib/models/networks/deepcalib.py
new file mode 100644
index 0000000000000000000000000000000000000000..a05d00e09bbcf8f0dbeec5ac0fcace781232eea3
--- /dev/null
+++ b/siclib/models/networks/deepcalib.py
@@ -0,0 +1,299 @@
+import logging
+from copy import deepcopy
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+from torch.nn import Identity
+
+from siclib.geometry.camera import SimpleRadial
+from siclib.geometry.gravity import Gravity
+from siclib.models.base_model import BaseModel
+from siclib.models.utils.metrics import dist_error, pitch_error, roll_error, vfov_error
+from siclib.models.utils.modules import _DenseBlock, _Transition
+from siclib.utils.conversions import deg2rad, pitch2rho, rho2pitch
+
+logger = logging.getLogger(__name__)
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+def get_centers_and_edges(min: float, max: float, num_bins: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Return the bin centers and bin edges used by the classification heads."""
+    centers = torch.linspace(min, max + ((max - min) / (num_bins - 1)), num_bins + 1).float()
+    edges = centers.detach() - ((centers.detach()[1] - centers[0]) / 2.0)
+    return centers, edges
+
+
+class DeepCalib(BaseModel):
+ default_conf = {
+ "name": "densenet",
+ "model": "densenet161",
+ "loss": "NLL",
+ "num_bins": 256,
+ "freeze_batch_normalization": False,
+ "model": "densenet161",
+ "pretrained": True, # whether to use ImageNet weights
+ "heads": ["roll", "rho", "vfov", "k1_hat"],
+ "flip": [], # keys of predictions to flip the sign of
+ "rpf_scales": [1, 1, 1],
+ "bounds": {
+ "roll": [-45, 45],
+ "rho": [-1, 1],
+ "vfov": [20, 105],
+ "k1_hat": [-0.7, 0.7],
+ },
+ "use_softamax": False,
+ }
+
+ mean = [0.485, 0.456, 0.406]
+ std = [0.229, 0.224, 0.225]
+
+ strict_conf = False
+
+ required_data_keys = ["image", "image_size"]
+
+ def _init(self, conf):
+        self.is_classification = self.conf.loss in ["NLL"]
+
+ self.num_bins = conf.num_bins
+
+ self.roll_centers, self.roll_edges = get_centers_and_edges(
+ deg2rad(conf.bounds.roll[0]), deg2rad(conf.bounds.roll[1]), self.num_bins
+ )
+
+ self.rho_centers, self.rho_edges = get_centers_and_edges(
+ conf.bounds.rho[0], conf.bounds.rho[1], self.num_bins
+ )
+
+ self.fov_centers, self.fov_edges = get_centers_and_edges(
+ deg2rad(conf.bounds.vfov[0]), deg2rad(conf.bounds.vfov[1]), self.num_bins
+ )
+
+ self.k1_hat_centers, self.k1_hat_edges = get_centers_and_edges(
+ conf.bounds.k1_hat[0], conf.bounds.k1_hat[1], self.num_bins
+ )
+
+ Model = getattr(torchvision.models, conf.model)
+ weights = "DEFAULT" if self.conf.pretrained else None
+ self.model = Model(weights=weights)
+
+ layers = []
+
+ # 2208 for 161 layers. 1024 for 121
+ num_features = self.model.classifier.in_features
+ head_layers = 3
+ layers.append(_Transition(num_features, num_features // 2))
+ num_features = num_features // 2
+ growth_rate = 32
+ layers.append(
+ _DenseBlock(
+ num_layers=head_layers,
+ num_input_features=num_features,
+ growth_rate=growth_rate,
+ bn_size=4,
+ drop_rate=0,
+ )
+ )
+ layers.append(nn.BatchNorm2d(num_features + head_layers * growth_rate))
+ layers.append(nn.ReLU())
+ layers.append(nn.AdaptiveAvgPool2d((1, 1)))
+ layers.append(nn.Flatten())
+ layers.append(nn.Linear(num_features + head_layers * growth_rate, 512))
+ layers.append(nn.ReLU())
+ self.model.classifier = Identity()
+ self.model.features.norm5 = Identity()
+
+ if self.is_classification:
+ layers.append(nn.Linear(512, self.num_bins))
+ layers.append(nn.LogSoftmax(dim=1))
+ else:
+ layers.append(nn.Linear(512, 1))
+ layers.append(nn.Tanh())
+
+ self.roll_head = nn.Sequential(*deepcopy(layers))
+ self.rho_head = nn.Sequential(*deepcopy(layers))
+ self.vfov_head = nn.Sequential(*deepcopy(layers))
+ self.k1_hat_head = nn.Sequential(*deepcopy(layers))
+
+ def bins_to_val(self, centers, pred):
+ if centers.device != pred.device:
+ centers = centers.to(pred.device)
+
+        if not self.conf.use_softmax:
+ return centers[pred.argmax(1)]
+
+ beta = 1e-0
+ pred_softmax = F.softmax(pred / beta, dim=1)
+ weighted_centers = centers[:-1].unsqueeze(0) * pred_softmax
+ val = weighted_centers.sum(dim=1)
+ return val
+
+ def _forward(self, data):
+ image = data["image"]
+ mean, std = image.new_tensor(self.mean), image.new_tensor(self.std)
+ image = (image - mean[:, None, None]) / std[:, None, None]
+ shared_features = self.model.features(image)
+ pred = {}
+
+ if "roll" in self.conf.heads:
+ pred["roll"] = self.roll_head(shared_features)
+ if "rho" in self.conf.heads:
+ pred["rho"] = self.rho_head(shared_features)
+ if "vfov" in self.conf.heads:
+ pred["vfov"] = self.vfov_head(shared_features)
+ if "vfov" in self.conf.flip:
+ pred["vfov"] = pred["vfov"] * -1
+ if "k1_hat" in self.conf.heads:
+ pred["k1_hat"] = self.k1_hat_head(shared_features)
+
+ size = data["image_size"]
+ w, h = size[:, 0], size[:, 1]
+
+ if self.is_classification:
+ parameters = {
+ "roll": self.bins_to_val(self.roll_centers, pred["roll"]),
+ "rho": self.bins_to_val(self.rho_centers, pred["rho"]),
+ "vfov": self.bins_to_val(self.fov_centers, pred["vfov"]),
+ "k1_hat": self.bins_to_val(self.k1_hat_centers, pred["k1_hat"]),
+ "width": w,
+ "height": h,
+ }
+
+ for k in self.conf.flip:
+ parameters[k] = parameters[k] * -1
+
+ for i, k in enumerate(["roll", "rho", "vfov"]):
+ parameters[k] = parameters[k] * self.conf.rpf_scales[i]
+
+ camera = SimpleRadial.from_dict(parameters)
+
+ roll, pitch = parameters["roll"], rho2pitch(parameters["rho"], camera.f[..., 1], h)
+ gravity = Gravity.from_rp(roll, pitch)
+
+ else: # regression
+ if "roll" in self.conf.heads:
+ pred["roll"] = pred["roll"] * deg2rad(45)
+ if "vfov" in self.conf.heads:
+ pred["vfov"] = (pred["vfov"] + 1) * deg2rad((105 - 20) / 2 + 20)
+
+            camera = SimpleRadial.from_dict(pred | {"width": w, "height": h})
+            # mirror the classification branch: convert the predicted rho to a pitch angle
+            pitch = rho2pitch(pred["rho"], camera.f[..., 1], h)
+            gravity = Gravity.from_rp(pred["roll"], pitch)
+
+ return pred | {"camera": camera, "gravity": gravity}
+
+ def loss(self, pred, data):
+ loss = {"total": 0}
+ if self.conf.loss == "Huber":
+ loss_fn = nn.HuberLoss(reduction="none")
+ elif self.conf.loss == "L1":
+ loss_fn = nn.L1Loss(reduction="none")
+ elif self.conf.loss == "L2":
+ loss_fn = nn.MSELoss(reduction="none")
+ elif self.conf.loss == "NLL":
+ loss_fn = nn.NLLLoss(reduction="none")
+
+ gt_cam = data["camera"]
+
+ if "roll" in self.conf.heads:
+ # nbins softmax values if classification, else scalar value
+ gt_roll = data["gravity"].roll.float()
+ pred_roll = pred["roll"].float()
+
+ if gt_roll.device != self.roll_edges.device:
+ self.roll_edges = self.roll_edges.to(gt_roll.device)
+ self.roll_centers = self.roll_centers.to(gt_roll.device)
+
+ if self.is_classification:
+ gt_roll = (
+ torch.bucketize(gt_roll.contiguous(), self.roll_edges) - 1
+ ) # converted to class
+
+ assert (gt_roll >= 0).all(), gt_roll
+ assert (gt_roll < self.num_bins).all(), gt_roll
+ else:
+ assert pred_roll.dim() == gt_roll.dim()
+
+ loss_roll = loss_fn(pred_roll, gt_roll)
+ loss["roll"] = loss_roll
+ loss["total"] += loss_roll
+
+ if "rho" in self.conf.heads:
+ gt_rho = pitch2rho(data["gravity"].pitch, gt_cam.f[..., 1], gt_cam.size[..., 1]).float()
+ pred_rho = pred["rho"].float()
+
+ if gt_rho.device != self.rho_edges.device:
+ self.rho_edges = self.rho_edges.to(gt_rho.device)
+ self.rho_centers = self.rho_centers.to(gt_rho.device)
+
+ if self.is_classification:
+ gt_rho = torch.bucketize(gt_rho.contiguous(), self.rho_edges) - 1
+
+ assert (gt_rho >= 0).all(), gt_rho
+ assert (gt_rho < self.num_bins).all(), gt_rho
+ else:
+ assert pred_rho.dim() == gt_rho.dim()
+
+ # print(f"Rho: {gt_rho.shape}, {pred_rho.shape}")
+ loss_rho = loss_fn(pred_rho, gt_rho)
+ loss["rho"] = loss_rho
+ loss["total"] += loss_rho
+
+ if "vfov" in self.conf.heads:
+ gt_vfov = gt_cam.vfov.float()
+ pred_vfov = pred["vfov"].float()
+
+ if gt_vfov.device != self.fov_edges.device:
+ self.fov_edges = self.fov_edges.to(gt_vfov.device)
+ self.fov_centers = self.fov_centers.to(gt_vfov.device)
+
+ if self.is_classification:
+ gt_vfov = torch.bucketize(gt_vfov.contiguous(), self.fov_edges) - 1
+
+ assert (gt_vfov >= 0).all(), gt_vfov
+ assert (gt_vfov < self.num_bins).all(), gt_vfov
+ else:
+ min_vfov = deg2rad(self.conf.bounds.vfov[0])
+ max_vfov = deg2rad(self.conf.bounds.vfov[1])
+ gt_vfov = (2 * (gt_vfov - min_vfov) / (max_vfov - min_vfov)) - 1
+ assert pred_vfov.dim() == gt_vfov.dim()
+
+ loss_vfov = loss_fn(pred_vfov, gt_vfov)
+ loss["vfov"] = loss_vfov
+ loss["total"] += loss_vfov
+
+ if "k1_hat" in self.conf.heads:
+ gt_k1_hat = data["camera"].k1_hat.float()
+ pred_k1_hat = pred["k1_hat"].float()
+
+ if gt_k1_hat.device != self.k1_hat_edges.device:
+ self.k1_hat_edges = self.k1_hat_edges.to(gt_k1_hat.device)
+ self.k1_hat_centers = self.k1_hat_centers.to(gt_k1_hat.device)
+
+ if self.is_classification:
+ gt_k1_hat = torch.bucketize(gt_k1_hat.contiguous(), self.k1_hat_edges) - 1
+
+ assert (gt_k1_hat >= 0).all(), gt_k1_hat
+ assert (gt_k1_hat < self.num_bins).all(), gt_k1_hat
+ else:
+ assert pred_k1_hat.dim() == gt_k1_hat.dim()
+
+ loss_k1_hat = loss_fn(pred_k1_hat, gt_k1_hat)
+ loss["k1_hat"] = loss_k1_hat
+ loss["total"] += loss_k1_hat
+
+ return loss, self.metrics(pred, data)
+
+ def metrics(self, pred, data):
+ pred_cam, gt_cam = pred["camera"], data["camera"]
+ pred_gravity, gt_gravity = pred["gravity"], data["gravity"]
+
+ return {
+ "roll_error": roll_error(pred_gravity, gt_gravity),
+ "pitch_error": pitch_error(pred_gravity, gt_gravity),
+ "vfov_error": vfov_error(pred_cam, gt_cam),
+ "k1_error": dist_error(pred_cam, gt_cam),
+ }
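+
+
+if __name__ == "__main__":
+    # Standalone sketch of the two read-out modes mirrored by `bins_to_val`: argmax picks
+    # the most likely bin center, while the soft-argmax variant averages the centers with
+    # softmax weights and stays differentiable. Values are random and purely illustrative.
+    centers, _ = get_centers_and_edges(-1.0, 1.0, num_bins=8)
+    logits = torch.randn(2, 8)
+    hard = centers[logits.argmax(1)]
+    soft = (centers[:-1].unsqueeze(0) * F.softmax(logits, dim=1)).sum(dim=1)
+    print(hard, soft)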
diff --git a/siclib/models/networks/dust3r.py b/siclib/models/networks/dust3r.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d00ac2f898eec450ddcb20f1fd34e51171dd422
--- /dev/null
+++ b/siclib/models/networks/dust3r.py
@@ -0,0 +1,81 @@
+"""Wrapper for DUSt3R model to estimate focal length.
+
+DUSt3R: Geometric 3D Vision Made Easy, https://arxiv.org/abs/2312.14132
+"""
+
+import sys
+
+sys.path.append("third_party/dust3r")
+
+import torch
+from dust3r.cloud_opt import GlobalAlignerMode, global_aligner
+from dust3r.image_pairs import make_pairs
+from dust3r.inference import inference, load_model
+from dust3r.utils.image import load_images
+
+from siclib.geometry.base_camera import BaseCamera
+from siclib.geometry.gravity import Gravity
+from siclib.models import BaseModel
+
+# mypy: ignore-errors
+
+
+class Dust3R(BaseModel):
+ """DUSt3R model for focal length estimation."""
+
+ default_conf = {
+ "model_path": "weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth",
+ "device": "cuda",
+ "batch_size": 1,
+ "schedule": "cosine",
+ "lr": 0.01,
+ "niter": 300,
+ "show_scene": False,
+ }
+
+ required_data_keys = ["path"]
+
+ def _init(self, conf):
+ """Initialize the DUSt3R model."""
+ self.model = load_model(conf["model_path"], conf["device"])
+
+ def _forward(self, data):
+ """Forward pass of the DUSt."""
+ assert len(data["path"]) == 1, f"Only batch size of 1 is supported (bs={len(data['path'])}"
+
+ path = data["path"][0]
+ images = [path] * 2
+
+ with torch.enable_grad():
+ images = load_images(images, size=512)
+ pairs = make_pairs(images, scene_graph="complete", prefilter=None, symmetrize=True)
+ output = inference(
+ pairs, self.model, self.conf["device"], batch_size=self.conf["batch_size"]
+ )
+ scene = global_aligner(
+ output, device=self.conf["device"], mode=GlobalAlignerMode.PointCloudOptimizer
+ )
+ _ = scene.compute_global_alignment(
+ init="mst",
+ niter=self.conf["niter"],
+ schedule=self.conf["schedule"],
+ lr=self.conf["lr"],
+ )
+
+ # retrieve useful values from scene:
+ focals = scene.get_focals().mean(dim=0)
+
+ h, w = images[0]["true_shape"][:, 0], images[0]["true_shape"][:, 1]
+ h, w = focals.new_tensor(h), focals.new_tensor(w)
+
+ camera = BaseCamera.from_dict({"height": h, "width": w, "f": focals})
+ gravity = Gravity.from_rp([0.0], [0.0])
+
+ if self.conf["show_scene"]:
+ scene.show()
+
+ return {"camera": camera, "gravity": gravity}
+
+ def loss(self, pred, data):
+ """Loss function for DUSt3R model."""
+ return {}, {}
diff --git a/siclib/models/networks/geocalib.py b/siclib/models/networks/geocalib.py
new file mode 100644
index 0000000000000000000000000000000000000000..23af7f288c91eeea6716fb32e14609716cc3467e
--- /dev/null
+++ b/siclib/models/networks/geocalib.py
@@ -0,0 +1,66 @@
+import logging
+
+from siclib.models import get_model
+from siclib.models.base_model import BaseModel
+
+logger = logging.getLogger(__name__)
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class GeoCalib(BaseModel):
+ default_conf = {
+ "backbone": {"name": "encoders.mscan"},
+ "ll_enc": {"name": "encoders.low_level_encoder"},
+ "perspective_decoder": {"name": "decoders.perspective_decoder"},
+ "optimizer": {"name": "optimization.lm_optimizer"},
+ }
+
+ required_data_keys = ["image"]
+
+ def _init(self, conf):
+ logger.debug(f"Initializing GeoCalib with {conf}")
+ self.backbone = get_model(conf.backbone["name"])(conf.backbone)
+ self.ll_enc = get_model(conf.ll_enc["name"])(conf.ll_enc) if conf.ll_enc else None
+
+ self.perspective_decoder = get_model(conf.perspective_decoder["name"])(
+ conf.perspective_decoder
+ )
+
+ self.optimizer = (
+ get_model(conf.optimizer["name"])(conf.optimizer) if conf.optimizer else None
+ )
+
+ def _forward(self, data):
+ backbone_out = self.backbone(data)
+ features = {"hl": backbone_out["features"], "padding": backbone_out.get("padding", None)}
+
+ if self.ll_enc is not None:
+ features["ll"] = self.ll_enc(data)["features"] # low level features
+
+ out = self.perspective_decoder({"features": features})
+
+ out |= {
+ k: data[k]
+ for k in ["image", "scales", "prior_gravity", "prior_focal", "prior_k1"]
+ if k in data
+ }
+
+ if self.optimizer is not None:
+ out |= self.optimizer(out)
+
+ return out
+
+ def loss(self, pred, data):
+ losses, metrics = self.perspective_decoder.loss(pred, data)
+ total = losses["perspective_total"]
+
+ if self.optimizer is not None:
+ opt_losses, param_metrics = self.optimizer.loss(pred, data)
+ losses |= opt_losses
+ metrics |= param_metrics
+ total = total + opt_losses["param_total"]
+
+ losses["total"] = total
+ return losses, metrics
diff --git a/siclib/models/networks/geocalib_pretrained.py b/siclib/models/networks/geocalib_pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dde770f18bb5bb1a7c4a4ede06fbb9c907e7e50
--- /dev/null
+++ b/siclib/models/networks/geocalib_pretrained.py
@@ -0,0 +1,41 @@
+"""Interface for GeoCalib inference package."""
+
+from geocalib import GeoCalib
+from siclib.models.base_model import BaseModel
+
+
+# mypy: ignore-errors
+class GeoCalibPretrained(BaseModel):
+ """GeoCalib pretrained model."""
+
+ default_conf = {
+ "camera_model": "pinhole",
+ "model_weights": "pinhole",
+ }
+
+ def _init(self, conf):
+ """Initialize pretrained GeoCalib model."""
+ self.model = GeoCalib(weights=conf.model_weights)
+
+ def _forward(self, data):
+ """Forward pass."""
+ priors = {}
+ if "prior_gravity" in data:
+ priors["gravity"] = data["prior_gravity"]
+
+ if "prior_focal" in data:
+ priors["focal"] = data["prior_focal"]
+
+ results = self.model.calibrate(
+ data["image"], camera_model=self.conf.camera_model, priors=priors
+ )
+
+ return results
+
+ def metrics(self, pred, data):
+ """Compute metrics."""
+ raise NotImplementedError("GeoCalibPretrained does not support metrics computation.")
+
+ def loss(self, pred, data):
+ """Compute loss."""
+ raise NotImplementedError("GeoCalibPretrained does not support loss computation.")
diff --git a/siclib/models/optimization/__init__.py b/siclib/models/optimization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/siclib/models/optimization/inference_optimizer.py b/siclib/models/optimization/inference_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5666fc9d74bb7c6890bfd20ef36f13c4a47580b
--- /dev/null
+++ b/siclib/models/optimization/inference_optimizer.py
@@ -0,0 +1,36 @@
+import logging
+
+from siclib.models.optimization.lm_optimizer import LMOptimizer
+
+logger = logging.getLogger(__name__)
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class InferenceOptimizer(LMOptimizer):
+ default_conf = {
+ # Camera model parameters
+ "camera_model": "pinhole", # {"pinhole", "simple_radial", "simple_spherical"}
+ "shared_intrinsics": False, # share focal length across all images in batch
+ "estimate_gravity": True,
+ "estimate_focal": True,
+ "estimate_k1": True, # will be ignored if camera_model is pinhole
+ # LM optimizer parameters
+ "num_steps": 30,
+ "lambda_": 0.1,
+ "fix_lambda": False,
+ "early_stop": True,
+ "atol": 1e-8,
+ "rtol": 1e-8,
+ "use_spherical_manifold": True, # use spherical manifold for gravity optimization
+ "use_log_focal": True, # use log focal length for optimization
+ # Loss function parameters
+ "loss_fn": "huber_loss", # {"squared_loss", "huber_loss"}
+ "up_loss_fn_scale": 1e-2,
+ "lat_loss_fn_scale": 1e-2,
+ "init_conf": {"name": "trivial"}, # pass config of other models to use as initializer
+ # Misc
+ "loss_weight": 1,
+ "verbose": False,
+ }
diff --git a/siclib/models/optimization/lm_optimizer.py b/siclib/models/optimization/lm_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..50b899aa89c18c5d149987dad038a21a1db05748
--- /dev/null
+++ b/siclib/models/optimization/lm_optimizer.py
@@ -0,0 +1,603 @@
+import logging
+import time
+from typing import Dict, Tuple
+
+import torch
+from torch import nn
+
+import siclib.models.optimization.losses as losses
+from siclib.geometry.base_camera import BaseCamera
+from siclib.geometry.camera import camera_models
+from siclib.geometry.gravity import Gravity
+from siclib.geometry.jacobians import J_focal2fov
+from siclib.geometry.perspective_fields import J_perspective_field, get_perspective_field
+from siclib.models import get_model
+from siclib.models.base_model import BaseModel
+from siclib.models.optimization.utils import (
+ early_stop,
+ get_initial_estimation,
+ optimizer_step,
+ update_lambda,
+)
+from siclib.models.utils.metrics import (
+ dist_error,
+ gravity_error,
+ pitch_error,
+ roll_error,
+ vfov_error,
+)
+from siclib.utils.conversions import rad2deg
+
+logger = logging.getLogger(__name__)
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class LMOptimizer(BaseModel):
+ default_conf = {
+ # Camera model parameters
+ "camera_model": "pinhole", # {"pinhole", "simple_radial", "simple_spherical"}
+ "shared_intrinsics": False, # share focal length across all images in batch
+ # LM optimizer parameters
+ "num_steps": 10,
+ "lambda_": 0.1,
+ "fix_lambda": False,
+ "early_stop": False,
+ "atol": 1e-8,
+ "rtol": 1e-8,
+ "use_spherical_manifold": True, # use spherical manifold for gravity optimization
+ "use_log_focal": True, # use log focal length for optimization
+ # Loss function parameters
+ "loss_fn": "squared_loss", # {"squared_loss", "huber_loss"}
+ "up_loss_fn_scale": 1e-2,
+ "lat_loss_fn_scale": 1e-2,
+ "init_conf": {"name": "trivial"}, # pass config of other models to use as initializer
+ # Misc
+ "loss_weight": 1,
+ "verbose": False,
+ }
+
+ def _init(self, conf):
+ self.loss_fn = getattr(losses, conf.loss_fn)
+ self.num_steps = conf.num_steps
+
+ self.set_camera_model(conf.camera_model)
+
+ self.setup_optimization_and_priors(shared_intrinsics=conf.shared_intrinsics)
+
+ self.initializer = None
+ if self.conf.init_conf.name not in ["trivial", "heuristic"]:
+ self.initializer = get_model(conf.init_conf.name)(conf.init_conf)
+
+ def set_camera_model(self, camera_model: str) -> None:
+ """Set the camera model to use for the optimization.
+
+ Args:
+ camera_model (str): Camera model to use.
+ """
+ assert (
+ camera_model in camera_models.keys()
+ ), f"Unknown camera model: {camera_model} not in {camera_models.keys()}"
+ self.camera_model = camera_models[camera_model]
+ self.camera_has_distortion = hasattr(self.camera_model, "dist")
+
+ logger.debug(
+ f"Using camera model: {camera_model} (with distortion: {self.camera_has_distortion})"
+ )
+
+ def setup_optimization_and_priors(
+ self, data: Dict[str, torch.Tensor] = None, shared_intrinsics: bool = False
+ ) -> None:
+ """Setup the optimization and priors for the LM optimizer.
+
+ Args:
+ data (Dict[str, torch.Tensor], optional): Dict potentially containing priors. Defaults
+ to None.
+ shared_intrinsics (bool, optional): Whether to share the intrinsics across the batch.
+ Defaults to False.
+ """
+ if data is None:
+ data = {}
+ self.shared_intrinsics = shared_intrinsics
+
+ if shared_intrinsics: # si => must use pinhole
+ assert (
+ self.camera_model == camera_models["pinhole"]
+ ), f"Shared intrinsics only supported with pinhole camera model: {self.camera_model}"
+
+ self.estimate_gravity = True
+ if "prior_gravity" in data:
+ self.estimate_gravity = False
+ logger.debug("Using provided gravity as prior.")
+
+ self.estimate_focal = True
+ if "prior_focal" in data:
+ self.estimate_focal = False
+ logger.debug("Using provided focal as prior.")
+
+ self.estimate_k1 = True
+ if "prior_k1" in data:
+ self.estimate_k1 = False
+ logger.debug("Using provided k1 as prior.")
+
+ self.gravity_delta_dims = (0, 1) if self.estimate_gravity else (-1,)
+ self.focal_delta_dims = (
+ (max(self.gravity_delta_dims) + 1,) if self.estimate_focal else (-1,)
+ )
+ self.k1_delta_dims = (max(self.focal_delta_dims) + 1,) if self.estimate_k1 else (-1,)
+
+ logger.debug(f"Camera Model: {self.camera_model}")
+ logger.debug(f"Optimizing gravity: {self.estimate_gravity} ({self.gravity_delta_dims})")
+ logger.debug(f"Optimizing focal: {self.estimate_focal} ({self.focal_delta_dims})")
+ logger.debug(f"Optimizing k1: {self.estimate_k1} ({self.k1_delta_dims})")
+
+ logger.debug(f"Shared intrinsics: {self.shared_intrinsics}")
+
+ def calculate_residuals(
+ self, camera: BaseCamera, gravity: Gravity, data: Dict[str, torch.Tensor]
+ ) -> Dict[str, torch.Tensor]:
+ """Calculate the residuals for the optimization.
+
+ Args:
+ camera (BaseCamera): Optimized camera.
+ gravity (Gravity): Optimized gravity.
+ data (Dict[str, torch.Tensor]): Input data containing the up and latitude fields.
+
+ Returns:
+ Dict[str, torch.Tensor]: Residuals for the optimization.
+ """
+ perspective_up, perspective_lat = get_perspective_field(camera, gravity)
+ perspective_lat = torch.sin(perspective_lat)
+
+ residuals = {}
+ if "up_field" in data:
+ up_residual = (data["up_field"] - perspective_up).permute(0, 2, 3, 1)
+ residuals["up_residual"] = up_residual.reshape(up_residual.shape[0], -1, 2)
+
+ if "latitude_field" in data:
+ target_lat = torch.sin(data["latitude_field"])
+ lat_residual = (target_lat - perspective_lat).permute(0, 2, 3, 1)
+ residuals["latitude_residual"] = lat_residual.reshape(lat_residual.shape[0], -1, 1)
+
+ return residuals
+
+ def calculate_costs(
+ self, residuals: torch.Tensor, data: Dict[str, torch.Tensor]
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
+ """Calculate the costs and weights for the optimization.
+
+ Args:
+ residuals (torch.Tensor): Residuals for the optimization.
+ data (Dict[str, torch.Tensor]): Input data containing the up and latitude confidence.
+
+ Returns:
+ Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: Costs and weights for the
+ optimization.
+ """
+ costs, weights = {}, {}
+
+ if "up_residual" in residuals:
+ up_cost = (residuals["up_residual"] ** 2).sum(dim=-1)
+ up_cost, up_weight, _ = losses.scaled_loss(
+ up_cost, self.loss_fn, self.conf.up_loss_fn_scale
+ )
+
+ if "up_confidence" in data:
+ up_conf = data["up_confidence"].reshape(up_weight.shape[0], -1)
+ up_weight = up_weight * up_conf
+ up_cost = up_cost * up_conf
+
+ costs["up_cost"] = up_cost
+ weights["up_weights"] = up_weight
+
+ if "latitude_residual" in residuals:
+ lat_cost = (residuals["latitude_residual"] ** 2).sum(dim=-1)
+ lat_cost, lat_weight, _ = losses.scaled_loss(
+ lat_cost, self.loss_fn, self.conf.lat_loss_fn_scale
+ )
+
+ if "latitude_confidence" in data:
+ lat_conf = data["latitude_confidence"].reshape(lat_weight.shape[0], -1)
+ lat_weight = lat_weight * lat_conf
+ lat_cost = lat_cost * lat_conf
+
+ costs["latitude_cost"] = lat_cost
+ weights["latitude_weights"] = lat_weight
+
+ return costs, weights
+
+ def calculate_gradient_and_hessian(
+ self,
+ J: torch.Tensor,
+ residuals: torch.Tensor,
+ weights: torch.Tensor,
+ shared_intrinsics: bool,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Calculate the gradient and Hessian for given the Jacobian, residuals, and weights.
+
+ Args:
+ J (torch.Tensor): Jacobian.
+ residuals (torch.Tensor): Residuals.
+ weights (torch.Tensor): Weights.
+ shared_intrinsics (bool): Whether to share the intrinsics across the batch.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Gradient and Hessian.
+ """
+ dims = ()
+ if self.estimate_gravity:
+ dims = (0, 1)
+ if self.estimate_focal:
+ dims += (2,)
+ if self.camera_has_distortion and self.estimate_k1:
+ dims += (3,)
+ assert dims, "No parameters to optimize"
+
+ J = J[..., dims]
+
+ Grad = torch.einsum("...Njk,...Nj->...Nk", J, residuals)
+ Grad = weights[..., None] * Grad
+ Grad = Grad.sum(-2) # (B, N_params)
+
+ if shared_intrinsics:
+ # reshape to (1, B * (N_params-1) + 1)
+ Grad_g = Grad[..., :2].reshape(1, -1)
+ Grad_f = Grad[..., 2].reshape(1, -1).sum(-1, keepdim=True)
+ Grad = torch.cat([Grad_g, Grad_f], dim=-1)
+
+ Hess = torch.einsum("...Njk,...Njl->...Nkl", J, J)
+ Hess = weights[..., None, None] * Hess
+ Hess = Hess.sum(-3)
+
+ if shared_intrinsics:
+ H_g = torch.block_diag(*list(Hess[..., :2, :2]))
+ J_fg = Hess[..., :2, 2].flatten()
+ J_gf = Hess[..., 2, :2].flatten()
+ J_f = Hess[..., 2, 2].sum()
+ dims = H_g.shape[-1] + 1
+ Hess = Hess.new_zeros((dims, dims), dtype=torch.float32)
+ Hess[:-1, :-1] = H_g
+ Hess[-1, :-1] = J_gf
+ Hess[:-1, -1] = J_fg
+ Hess[-1, -1] = J_f
+ Hess = Hess.unsqueeze(0)
+
+ return Grad, Hess
+
+ def setup_system(
+ self,
+ camera: BaseCamera,
+ gravity: Gravity,
+ residuals: Dict[str, torch.Tensor],
+ weights: Dict[str, torch.Tensor],
+ as_rpf: bool = False,
+ shared_intrinsics: bool = False,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Calculate the gradient and Hessian for the optimization.
+
+ Args:
+ camera (BaseCamera): Optimized camera.
+ gravity (Gravity): Optimized gravity.
+ residuals (Dict[str, torch.Tensor]): Residuals for the optimization.
+ weights (Dict[str, torch.Tensor]): Weights for the optimization.
+            as_rpf (bool, optional): Whether to calculate the gradient and Hessian with respect to
+ roll, pitch, and focal length. Defaults to False.
+ shared_intrinsics (bool, optional): Whether to share the intrinsics across the batch.
+ Defaults to False.
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Gradient and Hessian for the optimization.
+ """
+ J_up, J_lat = J_perspective_field(
+ camera,
+ gravity,
+ spherical=self.conf.use_spherical_manifold and not as_rpf,
+ log_focal=self.conf.use_log_focal and not as_rpf,
+ )
+
+ J_up = J_up.reshape(J_up.shape[0], -1, J_up.shape[-2], J_up.shape[-1]) # (B, N, 2, 3)
+ J_lat = J_lat.reshape(J_lat.shape[0], -1, J_lat.shape[-2], J_lat.shape[-1]) # (B, N, 1, 3)
+
+ n_params = (
+ 2 * self.estimate_gravity
+ + self.estimate_focal
+ + (self.camera_has_distortion and self.estimate_k1)
+ )
+ Grad = J_up.new_zeros(J_up.shape[0], n_params)
+ Hess = J_up.new_zeros(J_up.shape[0], n_params, n_params)
+
+ if shared_intrinsics:
+ N_params = Grad.shape[0] * (n_params - 1) + 1
+ Grad = Grad.new_zeros(1, N_params)
+ Hess = Hess.new_zeros(1, N_params, N_params)
+
+ if "up_residual" in residuals:
+ Up_Grad, Up_Hess = self.calculate_gradient_and_hessian(
+ J_up, residuals["up_residual"], weights["up_weights"], shared_intrinsics
+ )
+
+ if self.conf.verbose:
+ logger.info(f"Up J:\n{Up_Grad.mean(0)}")
+
+ Grad = Grad + Up_Grad
+ Hess = Hess + Up_Hess
+
+ if "latitude_residual" in residuals:
+ Lat_Grad, Lat_Hess = self.calculate_gradient_and_hessian(
+ J_lat,
+ residuals["latitude_residual"],
+ weights["latitude_weights"],
+ shared_intrinsics,
+ )
+
+ if self.conf.verbose:
+ logger.info(f"Lat J:\n{Lat_Grad.mean(0)}")
+
+ Grad = Grad + Lat_Grad
+ Hess = Hess + Lat_Hess
+
+ return Grad, Hess
+
+ def estimate_uncertainty(
+ self,
+ camera_opt: BaseCamera,
+ gravity_opt: Gravity,
+ errors: Dict[str, torch.Tensor],
+ weights: Dict[str, torch.Tensor],
+ ) -> Dict[str, torch.Tensor]:
+ """Estimate the uncertainty of the optimized camera and gravity at the final step.
+
+ Args:
+ camera_opt (BaseCamera): Final optimized camera.
+ gravity_opt (Gravity): Final optimized gravity.
+ errors (Dict[str, torch.Tensor]): Costs for the optimization.
+ weights (Dict[str, torch.Tensor]): Weights for the optimization.
+
+ Returns:
+ Dict[str, torch.Tensor]: Uncertainty estimates for the optimized camera and gravity.
+ """
+ _, Hess = self.setup_system(
+ camera_opt, gravity_opt, errors, weights, as_rpf=True, shared_intrinsics=False
+ )
+ Cov = torch.inverse(Hess)
+
+ roll_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
+ pitch_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
+ gravity_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
+ if self.estimate_gravity:
+ roll_uncertainty = Cov[..., 0, 0]
+ pitch_uncertainty = Cov[..., 1, 1]
+
+ try:
+ delta_uncertainty = Cov[..., :2, :2]
+ eigenvalues = torch.linalg.eigvalsh(delta_uncertainty.cpu())
+ gravity_uncertainty = torch.max(eigenvalues, dim=-1).values.to(Cov.device)
+ except RuntimeError:
+ logger.warning("Could not calculate gravity uncertainty")
+ gravity_uncertainty = Cov.new_zeros(Cov.shape[0])
+
+ focal_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
+ fov_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
+ if self.estimate_focal:
+ focal_uncertainty = Cov[..., self.focal_delta_dims[0], self.focal_delta_dims[0]]
+ fov_uncertainty = (
+ J_focal2fov(camera_opt.f[..., 1], camera_opt.size[..., 1]) ** 2 * focal_uncertainty
+ )
+
+ return {
+ "covariance": Cov,
+ "roll_uncertainty": torch.sqrt(roll_uncertainty),
+ "pitch_uncertainty": torch.sqrt(pitch_uncertainty),
+ "gravity_uncertainty": torch.sqrt(gravity_uncertainty),
+ "focal_uncertainty": torch.sqrt(focal_uncertainty) / 2,
+ "vfov_uncertainty": torch.sqrt(fov_uncertainty / 2),
+ }
+
+ def update_estimate(
+ self, camera: BaseCamera, gravity: Gravity, delta: torch.Tensor
+ ) -> Tuple[BaseCamera, Gravity]:
+ """Update the camera and gravity estimates with the given delta.
+
+ Args:
+ camera (BaseCamera): Optimized camera.
+ gravity (Gravity): Optimized gravity.
+ delta (torch.Tensor): Delta to update the camera and gravity estimates.
+
+ Returns:
+ Tuple[BaseCamera, Gravity]: Updated camera and gravity estimates.
+ """
+ delta_gravity = (
+ delta[..., self.gravity_delta_dims]
+ if self.estimate_gravity
+ else delta.new_zeros(delta.shape[:-1] + (2,))
+ )
+ new_gravity = gravity.update(delta_gravity, spherical=self.conf.use_spherical_manifold)
+
+ delta_f = (
+ delta[..., self.focal_delta_dims]
+ if self.estimate_focal
+ else delta.new_zeros(delta.shape[:-1] + (1,))
+ )
+ new_camera = camera.update_focal(delta_f, as_log=self.conf.use_log_focal)
+
+ delta_dist = (
+ delta[..., self.k1_delta_dims]
+ if self.camera_has_distortion and self.estimate_k1
+ else delta.new_zeros(delta.shape[:-1] + (1,))
+ )
+ if self.camera_has_distortion:
+ new_camera = new_camera.update_dist(delta_dist)
+
+ return new_camera, new_gravity
+
+ def optimize(
+ self,
+ data: Dict[str, torch.Tensor],
+ camera_opt: BaseCamera,
+ gravity_opt: Gravity,
+ ) -> Tuple[BaseCamera, Gravity, Dict[str, torch.Tensor]]:
+ """Optimize the camera and gravity estimates.
+
+ Args:
+ data (Dict[str, torch.Tensor]): Input data.
+ camera_opt (BaseCamera): Optimized camera.
+ gravity_opt (Gravity): Optimized gravity.
+
+ Returns:
+ Tuple[BaseCamera, Gravity, Dict[str, torch.Tensor]]: Optimized camera, gravity
+ estimates and optimization information.
+ """
+ key = list(data.keys())[0]
+ B = data[key].shape[0]
+
+ lamb = data[key].new_ones(B) * self.conf.lambda_
+ if self.shared_intrinsics:
+ lamb = data[key].new_ones(1) * self.conf.lambda_
+
+ infos = {"stop_at": self.num_steps}
+ for i in range(self.num_steps):
+ if self.conf.verbose:
+ logger.info(f"Step {i+1}/{self.num_steps}")
+
+ errors = self.calculate_residuals(camera_opt, gravity_opt, data)
+ costs, weights = self.calculate_costs(errors, data)
+
+ if i == 0:
+ prev_cost = sum(c.mean(-1) for c in costs.values())
+ for k, c in costs.items():
+ infos[f"initial_{k}"] = c.mean(-1)
+
+ infos["initial_cost"] = prev_cost
+
+ Grad, Hess = self.setup_system(
+ camera_opt,
+ gravity_opt,
+ errors,
+ weights,
+ shared_intrinsics=self.shared_intrinsics,
+ )
+ delta = optimizer_step(Grad, Hess, lamb) # (B, N_params)
+
+ if self.shared_intrinsics:
+ delta_g = delta[..., :-1].reshape(B, 2)
+ delta_f = delta[..., -1].expand(B, 1)
+ delta = torch.cat([delta_g, delta_f], dim=-1)
+
+ # calculate new cost
+ camera_opt, gravity_opt = self.update_estimate(camera_opt, gravity_opt, delta)
+ new_cost, _ = self.calculate_costs(
+ self.calculate_residuals(camera_opt, gravity_opt, data), data
+ )
+ new_cost = sum(c.mean(-1) for c in new_cost.values())
+
+ if not self.conf.fix_lambda and not self.shared_intrinsics:
+ lamb = update_lambda(lamb, prev_cost, new_cost)
+
+ if self.conf.verbose:
+ logger.info(f"Cost:\nPrev: {prev_cost}\nNew: {new_cost}")
+ logger.info(f"Camera:\n{camera_opt._data}")
+
+ if early_stop(new_cost, prev_cost, atol=self.conf.atol, rtol=self.conf.rtol):
+ infos["stop_at"] = min(i + 1, infos["stop_at"])
+
+ if self.conf.early_stop:
+ if self.conf.verbose:
+ logger.info(f"Early stopping at step {i+1}")
+ break
+
+ prev_cost = new_cost
+
+ if i == self.num_steps - 1 and self.conf.early_stop:
+ logger.warning("Reached maximum number of steps without convergence.")
+
+ final_errors = self.calculate_residuals(camera_opt, gravity_opt, data) # (B, N, 3)
+ final_cost, weights = self.calculate_costs(final_errors, data) # (B, N)
+
+ if not self.training:
+ infos |= self.estimate_uncertainty(camera_opt, gravity_opt, final_errors, weights)
+
+ infos["stop_at"] = camera_opt.new_ones(camera_opt.shape[0]) * infos["stop_at"]
+ for k, c in final_cost.items():
+ infos[f"final_{k}"] = c.mean(-1)
+
+ infos["final_cost"] = sum(c.mean(-1) for c in final_cost.values())
+
+ return camera_opt, gravity_opt, infos
+
+ def _forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """Run the LM optimization."""
+ if self.initializer is None:
+ camera_init, gravity_init = get_initial_estimation(
+ data, self.camera_model, trivial_init=self.conf.init_conf.name == "trivial"
+ )
+ else:
+ out = self.initializer(data)
+ camera_init = out["camera"]
+ gravity_init = out["gravity"]
+
+ self.setup_optimization_and_priors(data, shared_intrinsics=self.shared_intrinsics)
+
+ start = time.time()
+ camera_opt, gravity_opt, infos = self.optimize(data, camera_init, gravity_init)
+
+ if self.conf.verbose:
+ logger.info(f"Optimization took {(time.time() - start)*1000:.2f} ms")
+
+ logger.info(f"Initial camera:\n{rad2deg(camera_init.vfov)}")
+ logger.info(f"Optimized camera:\n{rad2deg(camera_opt.vfov)}")
+
+ logger.info(f"Initial gravity:\n{rad2deg(gravity_init.rp)}")
+ logger.info(f"Optimized gravity:\n{rad2deg(gravity_opt.rp)}")
+
+ return {"camera": camera_opt, "gravity": gravity_opt, **infos}
+
+ def metrics(
+ self, pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor]
+ ) -> Dict[str, torch.Tensor]:
+ """Calculate the metrics for the optimization."""
+ pred_cam, gt_cam = pred["camera"], data["camera"]
+ pred_gravity, gt_gravity = pred["gravity"], data["gravity"]
+
+ infos = {"stop_at": pred["stop_at"]}
+ for k, v in pred.items():
+ if "initial" in k or "final" in k:
+ infos[k] = v
+
+ return {
+ "roll_error": roll_error(pred_gravity, gt_gravity),
+ "pitch_error": pitch_error(pred_gravity, gt_gravity),
+ "gravity_error": gravity_error(pred_gravity, gt_gravity),
+ "vfov_error": vfov_error(pred_cam, gt_cam),
+ "k1_error": dist_error(pred_cam, gt_cam),
+ **infos,
+ }
+
+ def loss(
+ self, pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor]
+ ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
+ """Calculate the loss for the optimization."""
+ pred_cam, gt_cam = pred["camera"], data["camera"]
+ pred_gravity, gt_gravity = pred["gravity"], data["gravity"]
+
+ loss_fn = nn.L1Loss(reduction="none")
+
+        # loss will be zero if estimation is disabled and a prior is provided during training
+ gravity_loss = loss_fn(pred_gravity.vec3d, gt_gravity.vec3d)
+
+ h = data["camera"].size[0, 0]
+ focal_loss = loss_fn(pred_cam.f, gt_cam.f).mean(-1) / h
+
+ dist_loss = focal_loss.new_zeros(focal_loss.shape)
+ if self.camera_has_distortion:
+ dist_loss = loss_fn(pred_cam.dist, gt_cam.dist).sum(-1)
+
+ losses = {
+ "gravity": gravity_loss.sum(-1),
+ "focal": focal_loss,
+ "dist": dist_loss,
+ "param_total": gravity_loss.sum(-1) + focal_loss + dist_loss,
+ }
+
+ losses = {k: v * self.conf.loss_weight for k, v in losses.items()}
+ return losses, self.metrics(pred, data)
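+
+
+# Example: a self-contained sketch of how the parameter loss above combines its terms.
+# The tensors below are illustrative stand-ins for the camera and gravity attributes
+# (pred_cam.f, pred_gravity.vec3d, ...), not the real classes.
+if __name__ == "__main__":
+    loss_fn = nn.L1Loss(reduction="none")
+    h = torch.tensor(480.0)  # hypothetical image height used to normalize the focal loss
+
+    pred_vec3d, gt_vec3d = torch.randn(2, 3), torch.randn(2, 3)  # gravity directions (B, 3)
+    pred_f, gt_f = 500 * torch.rand(2, 1), 500 * torch.rand(2, 1)  # focal lengths (B, 1)
+
+    gravity_loss = loss_fn(pred_vec3d, gt_vec3d).sum(-1)  # (B,)
+    focal_loss = loss_fn(pred_f, gt_f).mean(-1) / h  # (B,), normalized by the image height
+    param_total = gravity_loss + focal_loss  # a distortion term is added for distorted models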
diff --git a/siclib/models/optimization/losses.py b/siclib/models/optimization/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..34cd1f393c0465f860d8bc03f0ae4b8ecacd2b5a
--- /dev/null
+++ b/siclib/models/optimization/losses.py
@@ -0,0 +1,93 @@
+"""Generic losses and error functions for optimization or training deep networks."""
+
+from typing import Callable, Tuple
+
+import torch
+
+
+def scaled_loss(
+ x: torch.Tensor, fn: Callable, a: float
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Apply a loss function to a tensor and pre- and post-scale it.
+
+ Args:
+ x: the data tensor, should already be squared: `x = y**2`.
+ fn: the loss function, with signature `fn(x) -> y`.
+ a: the scale parameter.
+
+ Returns:
+ The value of the loss, and its first and second derivatives.
+ """
+ a2 = a**2
+ loss, loss_d1, loss_d2 = fn(x / a2)
+ return loss * a2, loss_d1, loss_d2 / a2
+
+
+def squared_loss(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """A dummy squared loss."""
+ return x, torch.ones_like(x), torch.zeros_like(x)
+
+
+def huber_loss(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """The classical robust Huber loss, with first and second derivatives."""
+ mask = x <= 1
+ sx = torch.sqrt(x + 1e-8) # avoid nan in backward pass
+ isx = torch.max(sx.new_tensor(torch.finfo(torch.float).eps), 1 / sx)
+ loss = torch.where(mask, x, 2 * sx - 1)
+ loss_d1 = torch.where(mask, torch.ones_like(x), isx)
+ loss_d2 = torch.where(mask, torch.zeros_like(x), -isx / (2 * x))
+ return loss, loss_d1, loss_d2
+
+
+def barron_loss(
+ x: torch.Tensor, alpha: torch.Tensor, derivatives: bool = True, eps: float = 1e-7
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """Parameterized & adaptive robust loss function.
+
+ Described in:
+ A General and Adaptive Robust Loss Function, Barron, CVPR 2019
+
+ alpha = 2 -> L2 loss
+ alpha = 1 -> Charbonnier loss (smooth L1)
+ alpha = 0 -> Cauchy loss
+ alpha = -2 -> Geman-McClure loss
+ alpha = -inf -> Welsch loss
+
+    Contrary to the original implementation, assume that the input is already
+ squared and scaled (basically scale=1). Computes the first derivative, but
+ not the second (TODO if needed).
+ """
+ loss_two = x
+ loss_zero = 2 * torch.log1p(torch.clamp(0.5 * x, max=33e37))
+
+ # The loss when not in one of the above special cases.
+ # Clamp |2-alpha| to be >= machine epsilon so that it's safe to divide by.
+ beta_safe = torch.abs(alpha - 2.0).clamp(min=eps)
+ # Clamp |alpha| to be >= machine epsilon so that it's safe to divide by.
+ alpha_safe = torch.where(alpha >= 0, torch.ones_like(alpha), -torch.ones_like(alpha))
+ alpha_safe = alpha_safe * torch.abs(alpha).clamp(min=eps)
+
+ loss_otherwise = (
+ 2 * (beta_safe / alpha_safe) * (torch.pow(x / beta_safe + 1.0, 0.5 * alpha) - 1.0)
+ )
+
+ # Select which of the cases of the loss to return.
+ loss = torch.where(alpha == 0, loss_zero, torch.where(alpha == 2, loss_two, loss_otherwise))
+ dummy = torch.zeros_like(x)
+
+ if derivatives:
+ loss_two_d1 = torch.ones_like(x)
+ loss_zero_d1 = 2 / (x + 2)
+ loss_otherwise_d1 = torch.pow(x / beta_safe + 1.0, 0.5 * alpha - 1.0)
+ loss_d1 = torch.where(
+ alpha == 0, loss_zero_d1, torch.where(alpha == 2, loss_two_d1, loss_otherwise_d1)
+ )
+
+ return loss, loss_d1, dummy
+ else:
+ return loss, dummy, dummy
+
+
+def scaled_barron(a, c):
+ """Return a scaled Barron loss function."""
+ return lambda x: scaled_loss(x, lambda y: barron_loss(y, y.new_tensor(a)), c)
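+
+
+# Example: a minimal sketch of how these robust losses can re-weight squared residuals,
+# e.g. inside an IRLS / Levenberg-Marquardt loop. The residual shapes and scale values
+# below are illustrative assumptions, not part of the API.
+if __name__ == "__main__":
+    residuals = torch.randn(4, 128)  # hypothetical per-point residuals (B, N)
+    x = residuals**2  # all losses in this module expect squared residuals
+
+    # Huber with scale a=1.5: loss value plus its first and second derivatives
+    loss_h, weight_h, _ = scaled_loss(x, huber_loss, a=1.5)
+
+    # Barron loss with alpha=1 (Charbonnier) and scale c=2.0
+    loss_b, weight_b, _ = scaled_barron(1.0, 2.0)(x)
+
+    # the first derivative acts as a per-residual weight when building the normal equations
+    weighted_cost = (weight_h * x).sum(-1)  # (B,)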
diff --git a/siclib/models/optimization/perspective_opt.py b/siclib/models/optimization/perspective_opt.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b3c1532c9c064368ef025d92aba2b4d287c62ef
--- /dev/null
+++ b/siclib/models/optimization/perspective_opt.py
@@ -0,0 +1,195 @@
+import logging
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from tqdm import tqdm
+
+from siclib.geometry.camera import Pinhole as Camera
+from siclib.geometry.gravity import Gravity
+from siclib.geometry.perspective_fields import get_perspective_field
+from siclib.models.base_model import BaseModel
+from siclib.models.utils.metrics import pitch_error, roll_error, vfov_error
+from siclib.utils.conversions import deg2rad
+
+logger = logging.getLogger(__name__)
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class PerspectiveParamOpt(BaseModel):
+ default_conf = {
+ "max_steps": 1000,
+ "lr": 0.01,
+ "lr_scheduler": {
+ "name": "ReduceLROnPlateau",
+ "options": {"mode": "min", "patience": 3},
+ },
+ "patience": 3,
+ "abs_tol": 1e-7,
+ "rel_tol": 1e-9,
+ "lamb": 0.5,
+ "verbose": False,
+ }
+
+ required_data_keys = ["up_field", "latitude_field"]
+
+ def _init(self, conf):
+ pass
+
+ def cost_function(self, pred, target):
+ """Compute cost function for perspective parameter optimization."""
+ eps = 1e-7
+
+ lat_loss = F.l1_loss(pred["latitude_field"], target["latitude_field"], reduction="none")
+ lat_loss = lat_loss.squeeze(1)
+
+ up_loss = F.cosine_similarity(pred["up_field"], target["up_field"], dim=1)
+ up_loss = torch.acos(torch.clip(up_loss, -1 + eps, 1 - eps))
+
+ cost = (self.conf.lamb * lat_loss) + ((1 - self.conf.lamb) * up_loss)
+ return {
+ "total": torch.mean(cost),
+ "up": torch.mean(up_loss),
+ "latitude": torch.mean(lat_loss),
+ }
+
+ def check_convergence(self, loss, losses_prev):
+ """Check if optimization has converged."""
+
+ if loss["total"].item() <= self.conf.abs_tol:
+ return True, losses_prev
+
+ if len(losses_prev) < self.conf.patience:
+ losses_prev.append(loss["total"].item())
+
+ elif np.abs(loss["total"].item() - losses_prev[0]) < self.conf.rel_tol:
+ return True, losses_prev
+
+ else:
+ losses_prev.append(loss["total"].item())
+ losses_prev = losses_prev[-self.conf.patience :]
+
+ return False, losses_prev
+
+ def _update_estimate(self, camera: Camera, gravity: Gravity):
+ """Update camera estimate based on current parameters."""
+
+ camera = Camera.from_dict(
+ {"height": camera.size[..., 1], "width": camera.size[..., 0], "vfov": self.vfov_opt}
+ )
+ gravity = Gravity.from_rp(self.roll_opt, self.pitch_opt)
+ return camera, gravity
+
+ def optimize(self, data, camera_init, gravity_init):
+ """Optimize camera parameters to minimize cost function."""
+ device = data["up_field"].device
+ self.roll_opt = nn.Parameter(gravity_init.roll, requires_grad=True).to(device)
+ self.pitch_opt = nn.Parameter(gravity_init.pitch, requires_grad=True).to(device)
+ self.vfov_opt = nn.Parameter(camera_init.vfov, requires_grad=True).to(device)
+
+ optimizer = torch.optim.Adam(
+ [self.roll_opt, self.pitch_opt, self.vfov_opt], lr=self.conf.lr
+ )
+
+ lr_scheduler = None
+ if self.conf.lr_scheduler["name"] is not None:
+ lr_scheduler = getattr(torch.optim.lr_scheduler, self.conf.lr_scheduler["name"])(
+ optimizer, **self.conf.lr_scheduler["options"]
+ )
+
+ losses_prev = []
+
+ loop = range(self.conf.max_steps)
+ if self.conf.verbose:
+ pbar = tqdm(loop, desc="Optimizing", total=len(loop), ncols=100)
+
+ with torch.set_grad_enabled(True):
+ self.train()
+ for _ in loop:
+ optimizer.zero_grad()
+
+ camera_opt, gravity_opt = self._update_estimate(camera_init, gravity_init)
+
+ up, lat = get_perspective_field(camera_opt, gravity_opt)
+ pred = {"up_field": up, "latitude_field": lat}
+
+ loss = self.cost_function(pred, data)
+ loss["total"].backward()
+ optimizer.step()
+
+ if lr_scheduler is not None:
+ lr_scheduler.step(loss["total"])
+
+ if self.conf.verbose:
+ pbar.set_postfix({k[:3]: v.item() for k, v in loss.items()})
+ pbar.update(1)
+
+ converged, losses_prev = self.check_convergence(loss, losses_prev)
+ if converged:
+ if self.conf.verbose:
+ pbar.close()
+ break
+
+ camera_opt, gravity_opt = self._update_estimate(camera_init, gravity_init)
+ return {"camera_opt": camera_opt, "gravity_opt": gravity_opt}
+
+ def _get_init_params(self, data) -> Tuple[Camera, Gravity]:
+ """Get initial camera parameters for optimization."""
+ up_ref = data["up_field"]
+ latitude_ref = data["latitude_field"]
+
+ h, w = latitude_ref.shape[-2:]
+
+ # init roll is angle of the up vector at the center of the image
+ init_r = -torch.arctan2(
+ up_ref[:, 0, int(h / 2), int(w / 2)],
+ -up_ref[:, 1, int(h / 2), int(w / 2)],
+ )
+
+ # init pitch is the value at the center of the latitude map
+ init_p = latitude_ref[:, 0, int(h / 2), int(w / 2)]
+
+ # init vfov is the difference between the central top and bottom of the latitude map
+ init_vfov = latitude_ref[:, 0, 0, int(w / 2)] - latitude_ref[:, 0, -1, int(w / 2)]
+ init_vfov = torch.abs(init_vfov)
+ init_vfov = init_vfov.clamp(min=deg2rad(20), max=deg2rad(120))
+
+ h, w = (
+ latitude_ref.new_ones(latitude_ref.shape[0]) * h,
+ latitude_ref.new_ones(latitude_ref.shape[0]) * w,
+ )
+ params = {"width": w, "height": h, "vfov": init_vfov}
+ camera = Camera.from_dict(params)
+ gravity = Gravity.from_rp(init_r, init_p)
+ return camera, gravity
+
+ def _forward(self, data):
+ """Forward pass of optimization model."""
+
+ assert data["up_field"].shape[0] == 1, "Batch size must be 1 for optimization model."
+
+ # detach all tensors to avoid backprop
+ for k, v in data.items():
+ if isinstance(v, torch.Tensor):
+ data[k] = v.detach()
+
+ camera_init, gravity_init = self._get_init_params(data)
+ return self.optimize(data, camera_init, gravity_init)
+
+ def metrics(self, pred, data):
+ pred_cam, gt_cam = pred["camera_opt"], data["camera"]
+ pred_grav, gt_grav = pred["gravity_opt"], data["gravity"]
+
+ return {
+ "roll_opt_error": roll_error(pred_grav, gt_grav),
+ "pitch_opt_error": pitch_error(pred_grav, gt_grav),
+ "vfov_opt_error": vfov_error(pred_cam, gt_cam),
+ }
+
+ def loss(self, pred, data):
+ """No loss function for this optimization model."""
+ return {"opt_param_total": 0}, self.metrics(pred, data)
diff --git a/siclib/models/optimization/ransac.py b/siclib/models/optimization/ransac.py
new file mode 100644
index 0000000000000000000000000000000000000000..04c9ca658059d3aa8e2c71322b6a4fa07f5cd342
--- /dev/null
+++ b/siclib/models/optimization/ransac.py
@@ -0,0 +1,407 @@
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from siclib.geometry.camera import Pinhole
+from siclib.geometry.gravity import Gravity
+from siclib.geometry.perspective_fields import get_latitude_field, get_up_field
+from siclib.models.base_model import BaseModel
+from siclib.models.utils.metrics import (
+ latitude_error,
+ pitch_error,
+ roll_error,
+ up_error,
+ vfov_error,
+)
+from siclib.utils.conversions import skew_symmetric
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+def get_up_lines(up, xy):
+ up_lines = torch.cat([up, torch.zeros_like(up[..., :1])], dim=-1)
+
+ xy1 = torch.cat([xy, torch.ones_like(xy[..., :1])], dim=-1)
+
+ xy2 = xy1 + up_lines
+
+ return torch.einsum("...ij,...j->...i", skew_symmetric(xy1), xy2)
+
+
+def calculate_vvp(line1, line2):
+ return torch.einsum("...ij,...j->...i", skew_symmetric(line1), line2)
+
+
+def calculate_vvps(xs, ys, up):
+ xy_grav = torch.stack([xs[..., :2], ys[..., :2]], dim=-1).float()
+ up_lines = get_up_lines(up, xy_grav) # (B, N, 2, D)
+ vvp = calculate_vvp(up_lines[..., 0, :], up_lines[..., 1, :]) # (B, N, 3)
+ vvp = vvp / vvp[..., (2,)]
+ return vvp
+
+
+def get_up_samples(pred, xs, ys):
+ B, N = xs.shape[:2]
+ batch_indices = torch.arange(B).unsqueeze(1).unsqueeze(2).expand(B, N, 3).to(xs.device)
+ zeros = torch.zeros_like(xs).to(xs.device)
+ ones = torch.ones_like(xs).to(xs.device)
+ sample_indices_x = torch.stack([batch_indices, zeros, ys, xs], dim=-1).long() # (B, N, 3, 4)
+ sample_indices_y = torch.stack([batch_indices, ones, ys, xs], dim=-1).long() # (B, N, 3, 4)
+ up_x = pred["up_field"][sample_indices_x[..., (0, 1), :].unbind(-1)] # (B, N, 2)
+ up_y = pred["up_field"][sample_indices_y[..., (0, 1), :].unbind(-1)] # (B, N, 2)
+ return torch.stack([up_x, up_y], dim=-1) # (B, N, 2, D)
+
+
+def get_latitude_samples(pred, xs, ys):
+ # Setup latitude
+ B, N = xs.shape[:2]
+ batch_indices = torch.arange(B).unsqueeze(1).unsqueeze(2).expand(B, N, 3).to(xs.device)
+ zeros = torch.zeros_like(xs).to(xs.device)
+ sample_indices = torch.stack([batch_indices, zeros, ys, xs], dim=-1).long() # (B, N, 3, 4)
+ latitude = pred["latitude_field"][sample_indices[..., 2, :].unbind(-1)]
+ return torch.sin(latitude) # (B, N)
+
+
+class MinimalSolver:
+ def __init__(self):
+ pass
+
+ @staticmethod
+ def solve_focal(
+ L: torch.Tensor, xy: torch.Tensor, vvp: torch.Tensor, c: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Solve for focal length.
+
+ Args:
+ L (torch.Tensor): Latitude samples.
+ xy (torch.Tensor): xy of latitude samples of shape (..., 2).
+ vvp (torch.Tensor): Vertical vanishing points of shape (..., 3).
+ c (torch.Tensor): Principal points of shape (..., 2).
+
+ Returns:
+ Tuple[torch.Tensor, torch.Tensor]: Positive and negative solution of focal length.
+ """
+ c = c.unsqueeze(1)
+ u, v = (xy - c).unbind(-1)
+
+ vx, vy, vz = vvp.unbind(-1)
+ cx, cy = c.unbind(-1)
+ vx = vx - cx * vz
+ vy = vy - cy * vz
+
+ # Solve quadratic equation
+ a0 = (L**2 - 1) * vz**2
+ a1 = L**2 * (vz**2 * (u**2 + v**2) + vx**2 + vy**2) - 2 * vz * (vx * u + vy * v)
+ a2 = L**2 * (v**2 + u**2) * (vx**2 + vy**2) - (u * vx + v * vy) ** 2
+
+ a0 = torch.where(a0 == 0, torch.ones_like(a0) * 1e-6, a0)
+
+ f2_pos = -a1 / (2 * a0) + torch.sqrt(a1**2 - 4 * a0 * a2) / (2 * a0)
+ f2_neg = -a1 / (2 * a0) - torch.sqrt(a1**2 - 4 * a0 * a2) / (2 * a0)
+
+ f_pos, f_neg = torch.sqrt(f2_pos), torch.sqrt(f2_neg)
+
+ return f_pos, f_neg
+
+ @staticmethod
+ def solve_scale(
+ L: torch.Tensor, xy: torch.Tensor, vvp: torch.Tensor, c: torch.Tensor, f: torch.Tensor
+ ) -> torch.Tensor:
+ """Solve for scale of homogeneous vector.
+
+ Args:
+ L (torch.Tensor): Latitude samples.
+ xy (torch.Tensor): xy of latitude samples of shape (..., 2).
+ vvp (torch.Tensor): Vertical vanishing points of shape (..., 3).
+ c (torch.Tensor): Principal points of shape (..., 2).
+ f (torch.Tensor): Focal lengths.
+
+ Returns:
+ torch.Tensor: Estimated scales.
+ """
+ c = c.unsqueeze(1)
+ u, v = (xy - c).unbind(-1)
+
+ vx, vy, vz = vvp.unbind(-1)
+ cx, cy = c.unbind(-1)
+ vx = vx - cx * vz
+ vy = vy - cy * vz
+
+ w2 = (f**2 * L**2 * (u**2 + v**2 + f**2)) / (vx * u + vy * v + vz * f**2) ** 2
+ return torch.sqrt(w2)
+
+ @staticmethod
+ def solve_abc(
+ vvp: torch.Tensor, c: torch.Tensor, f: torch.Tensor, w: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """Solve for abc vector (solution to homogeneous equation).
+
+ Args:
+ vvp (torch.Tensor): Vertical vanishing points of shape (..., 3).
+ c (torch.Tensor): Principal points of shape (..., 2).
+ f (torch.Tensor): Focal lengths.
+ w (torch.Tensor): Scales.
+
+ Returns:
+ torch.Tensor: Estimated abc vector.
+ """
+ vx, vy, vz = vvp.unbind(-1)
+ cx, cy = c.unsqueeze(1).unbind(-1)
+ vx = vx - cx * vz
+ vy = vy - cy * vz
+
+ a = vx / f
+ b = vy / f
+ c = vz
+
+ abc = torch.stack((a, b, c), dim=-1)
+
+ return F.normalize(abc, dim=-1) if w is None else abc * w.unsqueeze(-1)
+
+ @staticmethod
+ def solve_rp(abc: torch.Tensor) -> torch.Tensor:
+ """Solve for roll, pitch.
+
+ Args:
+ abc (torch.Tensor): Estimated abc vector.
+
+ Returns:
+ torch.Tensor: Estimated roll, pitch, focal length.
+ """
+ a, _, c = abc.unbind(-1)
+ roll = torch.asin(-a / torch.sqrt(1 - c**2))
+ pitch = torch.asin(c)
+ return roll, pitch
+
+
+class RPFSolver(BaseModel):
+ default_conf = {
+ "n_iter": 1000,
+ "up_inlier_th": 1,
+ "latitude_inlier_th": 1,
+ "error_fn": "angle", # angle or mse
+ "up_weight": 1,
+ "latitude_weight": 1,
+ "loss_weight": 1,
+ "use_latitude": True,
+ }
+
+ def _init(self, conf):
+ self.solver = MinimalSolver()
+
+ def check_up_inliers(self, pred, est_camera, est_gravity, N=1):
+ pred_up = pred["up_field"]
+        # expand from (B, C, H, W) to (B * N, C, H, W)
+ B = pred_up.shape[0]
+ pred_up = pred_up.unsqueeze(1).expand(-1, N, -1, -1, -1)
+ pred_up = pred_up.reshape(B * N, *pred_up.shape[2:])
+
+ est_up = get_up_field(est_camera, est_gravity).permute(0, 3, 1, 2)
+
+ if self.conf.error_fn == "angle":
+ mse = up_error(est_up, pred_up)
+ elif self.conf.error_fn == "mse":
+ mse = F.mse_loss(est_up, pred_up, reduction="none").mean(1)
+ else:
+ raise ValueError(f"Unknown error function: {self.conf.error_fn}")
+
+ # shape (B, H, W)
+ conf = pred.get("up_confidence", pred_up.new_ones(pred_up.shape[0], *pred_up.shape[-2:]))
+ # shape (B, N, H, W)
+ conf = conf.unsqueeze(1).expand(-1, N, -1, -1)
+ # shape (B * N, H, W)
+ conf = conf.reshape(B * N, *conf.shape[-2:])
+
+ return (mse < self.conf.up_inlier_th) * conf
+
+ def check_latitude_inliers(self, pred, est_camera, est_gravity, N=1):
+ B = pred["up_field"].shape[0]
+ pred_latitude = pred.get("latitude_field")
+
+ if pred_latitude is None:
+ shape = (B * N, *pred["up_field"].shape[-2:])
+ return est_camera.new_zeros(shape)
+
+        # expand from (B, 1, H, W) to (B * N, 1, H, W)
+ pred_latitude = pred_latitude.unsqueeze(1).expand(-1, N, -1, -1, -1)
+ pred_latitude = pred_latitude.reshape(B * N, *pred_latitude.shape[2:])
+
+ est_latitude = get_latitude_field(est_camera, est_gravity).permute(0, 3, 1, 2)
+
+ if self.conf.error_fn == "angle":
+ error = latitude_error(est_latitude, pred_latitude)
+ elif self.conf.error_fn == "mse":
+ error = F.mse_loss(est_latitude, pred_latitude, reduction="none").mean(1)
+ else:
+ raise ValueError(f"Unknown error function: {self.conf.error_fn}")
+
+ conf = pred.get(
+ "latitude_confidence",
+ pred_latitude.new_ones(pred_latitude.shape[0], *pred_latitude.shape[-2:]),
+ )
+ conf = conf.unsqueeze(1).expand(-1, N, -1, -1)
+ conf = conf.reshape(B * N, *conf.shape[-2:])
+ return (error < self.conf.latitude_inlier_th) * conf
+
+ def get_best_index(self, data, camera, gravity, inliers=None):
+ B, _, H, W = data["up_field"].shape
+ N = self.conf.n_iter
+
+ up_inliers = self.check_up_inliers(data, camera, gravity, N)
+ latitude_inliers = self.check_latitude_inliers(data, camera, gravity, N)
+
+ up_inliers = up_inliers.reshape(B, N, H, W)
+ latitude_inliers = latitude_inliers.reshape(B, N, H, W)
+
+ if inliers is not None:
+ up_inliers = up_inliers * inliers.unsqueeze(1)
+ latitude_inliers = latitude_inliers * inliers.unsqueeze(1)
+
+ up_inliers = up_inliers.sum((2, 3))
+ latitude_inliers = latitude_inliers.sum((2, 3))
+
+ total_inliers = (
+ self.conf.up_weight * up_inliers + self.conf.latitude_weight * latitude_inliers
+ )
+
+ best_idx = total_inliers.argmax(-1)
+
+ return best_idx, total_inliers[torch.arange(B), best_idx]
+
+ def solve_rpf(self, pred, xs, ys, principal_points, focal=None):
+ device = pred["up_field"].device
+
+ # Get samples
+ up = get_up_samples(pred, xs, ys)
+
+ # Calculate vvps
+ vvp = calculate_vvps(xs, ys, up).to(device)
+
+ # Solve for focal length
+ xy = torch.stack([xs[..., 2], ys[..., 2]], dim=-1).float()
+ if focal is not None:
+ f = focal.new_ones(xs[..., 2].shape) * focal.unsqueeze(-1)
+ f_pos, f_neg = f, f
+ else:
+ L = get_latitude_samples(pred, xs, ys)
+ f_pos, f_neg = self.solver.solve_focal(L, xy, vvp, principal_points)
+
+ # Solve for abc
+ abc_pos = self.solver.solve_abc(vvp, principal_points, f_pos)
+ abc_neg = self.solver.solve_abc(vvp, principal_points, f_neg)
+
+ # Solve for roll, pitch
+ roll_pos, pitch_pos = self.solver.solve_rp(abc_pos)
+ roll_neg, pitch_neg = self.solver.solve_rp(abc_neg)
+
+ rpf_pos = torch.stack([roll_pos, pitch_pos, f_pos], dim=-1)
+ rpf_neg = torch.stack([roll_neg, pitch_neg, f_neg], dim=-1)
+
+ return rpf_pos, rpf_neg
+
+ def get_camera_and_gravity(self, pred, rpf):
+ B, _, H, W = pred["up_field"].shape
+ N = rpf.shape[1]
+
+ w = pred["up_field"].new_ones(B, N) * W
+ h = pred["up_field"].new_ones(B, N) * H
+ cx = w / 2.0
+ cy = h / 2.0
+
+ roll, pitch, focal = rpf.unbind(-1)
+
+ params = torch.stack([w, h, focal, focal, cx, cy], dim=-1)
+ params = params.reshape(B * N, params.shape[-1])
+ cam = Pinhole(params)
+
+ roll, pitch = roll.reshape(B * N), pitch.reshape(B * N)
+ gravity = Gravity.from_rp(roll, pitch)
+
+ return cam, gravity
+
+ def _forward(self, data):
+ device = data["up_field"].device
+ B, _, H, W = data["up_field"].shape
+
+ principal_points = torch.tensor([H / 2.0, W / 2.0]).expand(B, 2).to(device)
+
+ if not self.conf.use_latitude and "latitude_field" in data:
+ data.pop("latitude_field")
+
+ if "inliers" in data:
+ indices = torch.nonzero(data["inliers"] == 1, as_tuple=False)
+ batch_indices = torch.unique(indices[:, 0])
+
+ sampled_indices = []
+ for batch_index in batch_indices:
+ batch_mask = indices[:, 0] == batch_index
+
+ batch_indices_sampled = np.random.choice(
+ batch_mask.sum(), self.conf.n_iter * 3, replace=True
+ )
+ batch_indices_sampled = batch_indices_sampled.reshape(self.conf.n_iter, 3)
+ sampled_indices.append(indices[batch_mask][batch_indices_sampled][:, :, 1:])
+
+ ys, xs = torch.stack(sampled_indices, dim=0).unbind(-1)
+
+ else:
+ xs = torch.randint(0, W, (B, self.conf.n_iter, 3)).to(device)
+ ys = torch.randint(0, H, (B, self.conf.n_iter, 3)).to(device)
+
+ rpf_pos, rpf_neg = self.solve_rpf(
+ data, xs, ys, principal_points, focal=data.get("prior_focal")
+ )
+
+ cams_pos, gravity_pos = self.get_camera_and_gravity(data, rpf_pos)
+ cams_neg, gravity_neg = self.get_camera_and_gravity(data, rpf_neg)
+
+ inliers = data.get("inliers", None)
+ best_pos, score_pos = self.get_best_index(data, cams_pos, gravity_pos, inliers)
+ best_neg, score_neg = self.get_best_index(data, cams_neg, gravity_neg, inliers)
+
+ rpf = rpf_pos[torch.arange(B), best_pos]
+ rpf[score_neg > score_pos] = rpf_neg[torch.arange(B), best_neg][score_neg > score_pos]
+
+ cam, gravity = self.get_camera_and_gravity(data, rpf.unsqueeze(1))
+
+ return {
+ "camera_opt": cam,
+ "gravity_opt": gravity,
+ "up_inliers": self.check_up_inliers(data, cam, gravity),
+ "latitude_inliers": self.check_latitude_inliers(data, cam, gravity),
+ }
+
+ def metrics(self, pred, data):
+ pred_cam, gt_cam = pred["camera_opt"], data["camera"]
+ pred_gravity, gt_gravity = pred["gravity_opt"], data["gravity"]
+
+ return {
+ "roll_opt_error": roll_error(pred_gravity, gt_gravity),
+ "pitch_opt_error": pitch_error(pred_gravity, gt_gravity),
+ "vfov_opt_error": vfov_error(pred_cam, gt_cam),
+ }
+
+ def loss(self, pred, data):
+ pred_cam, gt_cam = pred["camera_opt"], data["camera"]
+ pred_gravity, gt_gravity = pred["gravity_opt"], data["gravity"]
+
+ h = data["camera"].size[0, 0]
+
+ gravity_loss = F.l1_loss(pred_gravity.vec3d, gt_gravity.vec3d, reduction="none")
+ focal_loss = F.l1_loss(pred_cam.f, gt_cam.f, reduction="none").sum(-1) / h
+
+ total_loss = gravity_loss.sum(-1)
+ if self.conf.estimate_focal:
+ total_loss = total_loss + focal_loss
+
+ losses = {
+ "opt_gravity": gravity_loss.sum(-1),
+ "opt_focal": focal_loss,
+ "opt_param_total": total_loss,
+ }
+
+ losses = {k: v * self.conf.loss_weight for k, v in losses.items()}
+ return losses, self.metrics(pred, data)
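+
+
+# Example: a sketch of a single minimal sample passing through the solver chain
+# (focal -> abc -> roll/pitch). The sample values are illustrative assumptions; the real
+# sampling and hypothesis scoring happen in RPFSolver.solve_rpf and get_best_index.
+if __name__ == "__main__":
+    L = torch.tensor([[0.3]])  # sin(latitude) at the sampled pixel, shape (B, N)
+    xy = torch.tensor([[[320.0, 180.0]]])  # pixel of the latitude sample, shape (B, N, 2)
+    vvp = torch.tensor([[[5.0, -900.0, 1.0]]])  # vertical vanishing point, shape (B, N, 3)
+    c = torch.tensor([[320.0, 240.0]])  # principal point, shape (B, 2)
+
+    f_pos, f_neg = MinimalSolver.solve_focal(L, xy, vvp, c)  # two focal hypotheses
+    # one hypothesis may be invalid (NaN focal); the RANSAC loop scores both and keeps the best
+    abc = MinimalSolver.solve_abc(vvp, c, f_pos)
+    roll, pitch = MinimalSolver.solve_rp(abc)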
diff --git a/siclib/models/optimization/utils.py b/siclib/models/optimization/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aae429134c3c34f78ec4df439c83c2df58e0509
--- /dev/null
+++ b/siclib/models/optimization/utils.py
@@ -0,0 +1,172 @@
+import logging
+from typing import Dict, Tuple
+
+import torch
+
+from siclib.geometry.base_camera import BaseCamera
+from siclib.geometry.gravity import Gravity
+from siclib.utils.conversions import deg2rad, focal2fov
+
+logger = logging.getLogger(__name__)
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+def get_initial_estimation(
+ data: Dict[str, torch.Tensor], camera_model: BaseCamera, trivial_init: bool = True
+) -> Tuple[BaseCamera, Gravity]:
+    """Get initial camera and gravity for optimization using heuristics."""
+ return (
+ get_trivial_estimation(data, camera_model)
+ if trivial_init
+ else get_heuristic_estimation(data, camera_model)
+ )
+
+
+def get_heuristic_estimation(
+    data: Dict[str, torch.Tensor], camera_model: BaseCamera
+) -> Tuple[BaseCamera, Gravity]:
+ """Get initial camera for optimization using heuristics.
+
+    The initial camera and gravity are estimated with the following heuristics:
+ - roll is the angle of the up vector at the center of the image
+ - pitch is the value at the center of the latitude map
+ - vfov is the difference between the central top and bottom of the latitude map
+ - distortions are set to zero
+
+ Use the prior values if available.
+
+ Args:
+ data (Dict[str, torch.Tensor]): Input data dictionary.
+ camera_model (BaseCamera): Camera model to use.
+
+ Returns:
+        Tuple[BaseCamera, Gravity]: Initial camera and gravity for optimization.
+ """
+ up_ref = data["up_field"].detach()
+ latitude_ref = data["latitude_field"].detach()
+
+ h, w = up_ref.shape[-2:]
+ batch_h, batch_w = (
+ up_ref.new_ones((up_ref.shape[0],)) * h,
+ up_ref.new_ones((up_ref.shape[0],)) * w,
+ )
+
+ # init roll is angle of the up vector at the center of the image
+ init_r = -torch.atan2(
+ up_ref[:, 0, int(h / 2), int(w / 2)], -up_ref[:, 1, int(h / 2), int(w / 2)]
+ )
+ init_r = init_r.clamp(min=-deg2rad(45), max=deg2rad(45))
+
+ # init pitch is the value at the center of the latitude map
+ init_p = latitude_ref[:, 0, int(h / 2), int(w / 2)]
+ init_p = init_p.clamp(min=-deg2rad(45), max=deg2rad(45))
+
+ # init vfov is the difference between the central top and bottom of the latitude map
+ init_vfov = latitude_ref[:, 0, 0, int(w / 2)] - latitude_ref[:, 0, -1, int(w / 2)]
+ init_vfov = torch.abs(init_vfov)
+ init_vfov = init_vfov.clamp(min=deg2rad(20), max=deg2rad(120))
+
+ focal = data.get("prior_focal")
+ init_vfov = init_vfov if focal is None else focal2fov(focal, h)
+
+ params = {"width": batch_w, "height": batch_h, "vfov": init_vfov}
+ params |= {"scales": data["scales"]} if "scales" in data else {}
+ params |= {"k1": data["prior_k1"]} if "prior_k1" in data else {}
+ camera = camera_model.from_dict(params)
+ camera = camera.float().to(data["up_field"].device)
+
+ gravity = Gravity.from_rp(init_r, init_p).float().to(data["up_field"].device)
+ if "prior_gravity" in data:
+ gravity = data["prior_gravity"].float().to(up_ref.device)
+
+ return camera, gravity
+
+
+def get_trivial_estimation(
+    data: Dict[str, torch.Tensor], camera_model: BaseCamera
+) -> Tuple[BaseCamera, Gravity]:
+    """Get initial camera for optimization with roll=0, pitch=0 and focal=0.7 * max(h, w).
+
+ Args:
+ data (Dict[str, torch.Tensor]): Input data dictionary.
+ camera_model (BaseCamera): Camera model to use.
+
+ Returns:
+        Tuple[BaseCamera, Gravity]: Initial camera and gravity for optimization.
+ """
+ """Get initial camera for optimization with roll=0, pitch=0, vfov=0.7 * max(h, w)."""
+ ref = data.get("up_field", data["latitude_field"])
+ ref = ref.detach()
+
+ h, w = ref.shape[-2:]
+ batch_h, batch_w = (
+ ref.new_ones((ref.shape[0],)) * h,
+ ref.new_ones((ref.shape[0],)) * w,
+ )
+
+ init_r = ref.new_zeros((ref.shape[0],))
+ init_p = ref.new_zeros((ref.shape[0],))
+
+ focal = data.get("prior_focal", 0.7 * torch.max(batch_h, batch_w))
+    init_vfov = focal2fov(focal, h)  # focal always has a value here (prior or 0.7 * max(h, w))
+
+ params = {"width": batch_w, "height": batch_h, "vfov": init_vfov}
+ params |= {"scales": data["scales"]} if "scales" in data else {}
+ params |= {"k1": data["prior_k1"]} if "prior_k1" in data else {}
+ camera = camera_model.from_dict(params)
+ camera = camera.float().to(ref.device)
+
+ gravity = Gravity.from_rp(init_r, init_p).float().to(ref.device)
+
+ if "prior_gravity" in data:
+ gravity = data["prior_gravity"].float().to(ref.device)
+
+ return camera, gravity
+
+
+def early_stop(new_cost: torch.Tensor, prev_cost: torch.Tensor, atol: float, rtol: float) -> bool:
+ """Early stopping criterion based on cost convergence."""
+ return torch.allclose(new_cost, prev_cost, atol=atol, rtol=rtol)
+
+
+def update_lambda(
+ lamb: torch.Tensor,
+ prev_cost: torch.Tensor,
+ new_cost: torch.Tensor,
+ lambda_min: float = 1e-6,
+ lambda_max: float = 1e2,
+) -> torch.Tensor:
+ """Update damping factor for Levenberg-Marquardt optimization."""
+    new_lamb = lamb * torch.where(new_cost > prev_cost, 10, 0.1)
+ lamb = torch.clamp(new_lamb, lambda_min, lambda_max)
+ return lamb
+
+
+def optimizer_step(
+ G: torch.Tensor, H: torch.Tensor, lambda_: torch.Tensor, eps: float = 1e-6
+) -> torch.Tensor:
+ """One optimization step with Gauss-Newton or Levenberg-Marquardt.
+
+ Args:
+ G (torch.Tensor): Batched gradient tensor of size (..., N).
+ H (torch.Tensor): Batched hessian tensor of size (..., N, N).
+ lambda_ (torch.Tensor): Damping factor for LM (use GN if lambda_=0) with shape (B,).
+ eps (float, optional): Epsilon for damping. Defaults to 1e-6.
+
+ Returns:
+ torch.Tensor: Batched update tensor of size (..., N).
+ """
+ diag = H.diagonal(dim1=-2, dim2=-1)
+ diag = diag * lambda_.unsqueeze(-1) # (B, 3)
+
+ H = H + diag.clamp(min=eps).diag_embed()
+
+ H_, G_ = H.cpu(), G.cpu()
+ try:
+ U = torch.linalg.cholesky(H_)
+ except RuntimeError:
+ logger.warning("Cholesky decomposition failed. Stopping.")
+ delta = H.new_zeros((H.shape[0], H.shape[-1])) # (B, 3)
+ else:
+ delta = torch.cholesky_solve(G_[..., None], U)[..., 0]
+
+ return delta.to(H.device)
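+
+
+# Example: a small self-contained sketch of one damped (LM) step with optimizer_step,
+# followed by a damping update with update_lambda. The toy Jacobian and residuals are
+# illustrative assumptions; the real system is assembled by the optimizer model.
+if __name__ == "__main__":
+    B, M, N = 2, 8, 3  # batch size, number of residuals, number of parameters
+    J = torch.randn(B, M, N)  # hypothetical Jacobian of the residuals w.r.t. the parameters
+    r = torch.randn(B, M)  # residuals
+
+    Grad = torch.einsum("bmn,bm->bn", J, r)  # gradient (B, N)
+    Hess = torch.einsum("bmn,bmk->bnk", J, J)  # Gauss-Newton Hessian (B, N, N)
+
+    lamb = torch.full((B,), 0.1)
+    delta = optimizer_step(Grad, Hess, lamb)  # parameter update (B, N)
+
+    prev_cost = (r**2).sum(-1)
+    new_cost = 0.9 * prev_cost  # pretend the step decreased the cost
+    lamb = update_lambda(lamb, prev_cost, new_cost)  # damping shrinks when the cost improves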
diff --git a/siclib/models/optimization/vp_from_prior.py b/siclib/models/optimization/vp_from_prior.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8cd333573dffc7ff2fc8b1ffbbe751a6aae3d4b
--- /dev/null
+++ b/siclib/models/optimization/vp_from_prior.py
@@ -0,0 +1,182 @@
+"""Wrapper for VP estimation with prior gravity using the VP-Estimation-with-Prior-Gravity library.
+
+repo: https://github.com/cvg/VP-Estimation-with-Prior-Gravity
+"""
+
+import sys
+
+sys.path.append("third_party/VP-Estimation-with-Prior-Gravity")
+sys.path.append("third_party/VP-Estimation-with-Prior-Gravity/src/deeplsd")
+
+import logging
+import random
+
+import cv2
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from vp_estimation_with_prior_gravity.evaluation import get_labels_from_vp, project_vp_to_image
+from vp_estimation_with_prior_gravity.features.line_detector import LineDetector
+from vp_estimation_with_prior_gravity.solvers import run_hybrid_uncalibrated
+from vp_estimation_with_prior_gravity.visualization import plot_images, plot_lines, plot_vp
+
+from siclib.geometry.camera import Pinhole
+from siclib.geometry.gravity import Gravity
+from siclib.models import BaseModel
+from siclib.models.utils.metrics import gravity_error, pitch_error, roll_error, vfov_error
+
+# flake8: noqa
+# mypy: ignore-errors
+
+logger = logging.getLogger(__name__)
+
+
+class VPEstimator(BaseModel):
+    # Which solvers to use for our hybrid solver:
+ # 0 - 2lines 200g
+ # 1 - 2lines 110g
+ # 2 - 2lines 011g
+ # 3 - 4lines 211
+ # 4 - 4lines 220
+ default_conf = {
+ "SOLVER_FLAGS": [True, True, True, True, True],
+ "th_pixels": 3, # RANSAC inlier threshold
+ "ls_refinement": 2, # 3 uses the gravity in the LS refinement, 2 does not.
+ "nms": 3, # change to 3 to add a Ceres optimization after the non minimal solver (slower)
+ "magsac_scoring": True,
+ "line_type": "deeplsd", # 'lsd' or 'deeplsd'
+ "min_lines": 5, # only trust images with at least this many lines
+ "verbose": False,
+ }
+
+ def _init(self, conf):
+ if conf.SOLVER_FLAGS in [
+ [True, False, False, False, False],
+ [False, False, True, False, False],
+ ]:
+ self.vertical = np.array([random.random() / 1e12, 1, random.random() / 1e12])
+ self.vertical /= np.linalg.norm(self.vertical)
+ else:
+ self.vertical = np.array([0.0, 1, 0.0])
+
+ self.line_detector = LineDetector(line_detector=conf.line_type)
+
+ self.verbose = conf.verbose
+
+ def visualize_lines(self, vp, lines, img, K):
+ vp_labels = get_labels_from_vp(
+ lines[:, :, [1, 0]], project_vp_to_image(vp, K), threshold=self.conf.th_pixels
+ )[0]
+
+ plot_images([img, img])
+ plot_lines([lines, np.empty((0, 2, 2))])
+ plot_vp([np.empty((0, 2, 2)), lines], [[], vp_labels])
+
+ plt.show()
+
+ def get_vvp(self, vp, K):
+ best_idx, best_cossim = 0, -1
+ for i, point in enumerate(vp):
+ cossim = np.dot(self.vertical, point) / np.linalg.norm(point)
+ point = -point * np.dot(self.vertical, point)
+ try:
+ gravity = Gravity(np.linalg.inv(K) @ point)
+            except Exception:
+ continue
+
+ if (
+ np.abs(cossim) > best_cossim
+ and gravity.pitch.abs() <= np.pi / 4
+ and gravity.roll.abs() <= np.pi / 4
+ ):
+ best_idx, best_cossim = i, np.abs(cossim)
+
+ vvp = vp[best_idx]
+ return -vvp * np.sign(np.dot(self.vertical, vvp))
+
+ def _forward(self, data):
+ device = data["image"].device
+ images = data["image"].cpu()
+
+        estimations = []
+        out = {}  # last per-image debug outputs (vp, lines, K, labels), merged into the return dict
+ for idx, img in enumerate(images.unbind(0)):
+ if "prior_gravity" in data:
+ self.vertical = -data["prior_gravity"][idx].vec3d.cpu().numpy()
+ else:
+ self.vertical = np.array([0.0, 1, 0.0])
+
+ img = img.numpy().transpose(1, 2, 0) * 255
+ img = img.astype(np.uint8)
+ gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
+
+ lines = self.line_detector.detect_lines(gray_img)[:, :, [1, 0]]
+
+ if len(lines) < self.conf.min_lines:
+ logger.warning("Not enough lines detected! Skipping...")
+ gravity = Gravity.from_rp(np.nan, np.nan)
+ camera = Pinhole.from_dict(
+ {"f": np.nan, "height": img.shape[0], "width": img.shape[1]}
+ )
+ estimations.append({"camera": camera, "gravity": gravity})
+ continue
+
+ principle_point = np.array([img.shape[1] / 2.0, img.shape[0] / 2.0])
+ f, vp = run_hybrid_uncalibrated(
+ lines - principle_point[None, None, :],
+ self.vertical,
+ th_pixels=self.conf.th_pixels,
+ ls_refinement=self.conf.ls_refinement,
+ nms=self.conf.nms,
+ magsac_scoring=self.conf.magsac_scoring,
+ sprt=True,
+ solver_flags=self.conf.SOLVER_FLAGS,
+ )
+ vp[:, 1] *= -1
+
+ K = np.array(
+ [[f, 0.0, principle_point[0]], [0.0, f, principle_point[1]], [0.0, 0.0, 1.0]]
+ )
+
+ if self.verbose:
+ self.visualize_lines(vp, lines, img, K)
+
+ vp_labels = get_labels_from_vp(
+ lines[:, :, [1, 0]], project_vp_to_image(vp, K), threshold=self.conf.th_pixels
+ )[0]
+ out = {"vp": vp, "lines": lines, "K": K, "vp_labels": vp_labels}
+
+ vp = project_vp_to_image(vp, K)
+
+ vvp = self.get_vvp(vp, K)
+
+ vvp = -vvp * np.sign(np.dot(self.vertical, vvp))
+ try:
+ K_inv = np.linalg.inv(K)
+ gravity = Gravity(K_inv @ vvp)
+ except np.linalg.LinAlgError:
+ gravity = Gravity.from_rp(np.nan, np.nan)
+
+ camera = Pinhole.from_dict({"f": f, "height": img.shape[0], "width": img.shape[1]})
+ estimations.append({"camera": camera, "gravity": gravity})
+
+ if len(estimations) == 0:
+ return {}
+
+ gravity = torch.stack([Gravity(est["gravity"].vec3d) for est in estimations], dim=0)
+ camera = torch.stack([Pinhole(est["camera"]._data) for est in estimations], dim=0)
+
+ return {"camera": camera.float().to(device), "gravity": gravity.float().to(device)} | out
+
+ def metrics(self, pred, data):
+ pred_cam, gt_cam = pred["camera_opt"], data["camera"]
+ pred_gravity, gt_gravity = pred["gravity_opt"], data["gravity"]
+
+ return {
+ "roll_opt_error": roll_error(pred_gravity, gt_gravity),
+ "pitch_opt_error": pitch_error(pred_gravity, gt_gravity),
+ "gravity_opt_error": gravity_error(pred_gravity, gt_gravity),
+ "vfov_opt_error": vfov_error(pred_cam, gt_cam),
+ }
+
+ def loss(self, pred, data):
+ return {}, self.metrics(pred, data)
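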
diff --git a/siclib/models/utils/__init__.py b/siclib/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/siclib/models/utils/metrics.py b/siclib/models/utils/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a51c7fbea501185d8200b7e35045a6123d5f045
--- /dev/null
+++ b/siclib/models/utils/metrics.py
@@ -0,0 +1,123 @@
+"""Various metrics for evaluating predictions."""
+
+import logging
+
+import torch
+from torch.nn import functional as F
+
+from siclib.geometry.base_camera import BaseCamera
+from siclib.geometry.gravity import Gravity
+from siclib.utils.conversions import rad2deg
+
+logger = logging.getLogger(__name__)
+
+
+def pitch_error(pred_gravity: Gravity, target_gravity: Gravity) -> torch.Tensor:
+ """Computes the pitch error between two gravities.
+
+ Args:
+ pred_gravity (Gravity): Predicted camera.
+ target_gravity (Gravity): Ground truth camera.
+
+ Returns:
+ torch.Tensor: Pitch error in degrees.
+ """
+ return rad2deg(torch.abs(pred_gravity.pitch - target_gravity.pitch))
+
+
+def roll_error(pred_gravity: Gravity, target_gravity: Gravity) -> torch.Tensor:
+ """Computes the roll error between two gravities.
+
+ Args:
+ pred_gravity (Gravity): Predicted Gravity.
+ target_gravity (Gravity): Ground truth Gravity.
+
+ Returns:
+ torch.Tensor: Roll error in degrees.
+ """
+ return rad2deg(torch.abs(pred_gravity.roll - target_gravity.roll))
+
+
+def gravity_error(pred_gravity: Gravity, target_gravity: Gravity) -> torch.Tensor:
+ """Computes the gravity error between two gravities.
+
+ Args:
+ pred_gravity (Gravity): Predicted Gravity.
+ target_gravity (Gravity): Ground truth Gravity.
+
+ Returns:
+ torch.Tensor: Gravity error in degrees.
+ """
+ assert (
+ pred_gravity.vec3d.shape == target_gravity.vec3d.shape
+ ), f"{pred_gravity.vec3d.shape} != {target_gravity.vec3d.shape}"
+ assert pred_gravity.vec3d.ndim == 2, f"{pred_gravity.vec3d.ndim} != 2"
+ assert pred_gravity.vec3d.shape[1] == 3, f"{pred_gravity.vec3d.shape[1]} != 3"
+
+ cossim = F.cosine_similarity(pred_gravity.vec3d, target_gravity.vec3d, dim=-1).clamp(-1, 1)
+ return rad2deg(torch.acos(cossim))
+
+
+def vfov_error(pred_cam: BaseCamera, target_cam: BaseCamera) -> torch.Tensor:
+ """Computes the vertical field of view error between two cameras.
+
+ Args:
+ pred_cam (Camera): Predicted camera.
+ target_cam (Camera): Ground truth camera.
+
+ Returns:
+ torch.Tensor: Vertical field of view error in degrees.
+ """
+ return rad2deg(torch.abs(pred_cam.vfov - target_cam.vfov))
+
+
+def dist_error(pred_cam: BaseCamera, target_cam: BaseCamera) -> torch.Tensor:
+ """Computes the distortion parameter error between two cameras.
+
+ Returns zero if the cameras do not have distortion parameters.
+
+ Args:
+ pred_cam (Camera): Predicted camera.
+ target_cam (Camera): Ground truth camera.
+
+ Returns:
+ torch.Tensor: distortion error.
+ """
+ if hasattr(pred_cam, "dist") and hasattr(target_cam, "dist"):
+ return torch.abs(pred_cam.dist[..., 0] - target_cam.dist[..., 0])
+
+ logger.debug(
+ f"Predicted / target camera doesn't have distortion parameters: {pred_cam}/{target_cam}"
+ )
+ return pred_cam.new_zeros(pred_cam.f.shape[0])
+
+
+def latitude_error(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ """Computes the latitude error between two tensors.
+
+ Args:
+ predictions (torch.Tensor): Predicted latitude field of shape (B, 1, H, W).
+ targets (torch.Tensor): Ground truth latitude field of shape (B, 1, H, W).
+
+ Returns:
+ torch.Tensor: Latitude error in degrees of shape (B, H, W).
+ """
+ return rad2deg(torch.abs(predictions - targets)).squeeze(1)
+
+
+def up_error(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+ """Computes the up error between two tensors.
+
+ Args:
+ predictions (torch.Tensor): Predicted up field of shape (B, 2, H, W).
+ targets (torch.Tensor): Ground truth up field of shape (B, 2, H, W).
+
+ Returns:
+ torch.Tensor: Up error in degrees of shape (B, H, W).
+ """
+ assert predictions.shape == targets.shape, f"{predictions.shape} != {targets.shape}"
+ assert predictions.ndim == 4, f"{predictions.ndim} != 4"
+ assert predictions.shape[1] == 2, f"{predictions.shape[1]} != 2"
+
+ angle = F.cosine_similarity(predictions, targets, dim=1).clamp(-1, 1)
+ return rad2deg(torch.acos(angle))
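+
+
+# Example: a minimal sketch of the dense field metrics on dummy tensors. The shapes follow
+# the docstrings above; the values themselves are illustrative assumptions.
+if __name__ == "__main__":
+    pred_up = F.normalize(torch.randn(1, 2, 32, 32), dim=1)
+    gt_up = F.normalize(torch.randn(1, 2, 32, 32), dim=1)
+    angular = up_error(pred_up, gt_up)  # per-pixel angular error in degrees, (1, 32, 32)
+
+    pred_lat = torch.zeros(1, 1, 32, 32)
+    gt_lat = torch.full((1, 1, 32, 32), 0.1)
+    lat = latitude_error(pred_lat, gt_lat)  # ~5.73 degrees everywhere (0.1 rad), (1, 32, 32)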
diff --git a/siclib/models/utils/modules.py b/siclib/models/utils/modules.py
new file mode 100644
index 0000000000000000000000000000000000000000..c79e2c121ae739ac089b56efe33b4dfa5e31ca33
--- /dev/null
+++ b/siclib/models/utils/modules.py
@@ -0,0 +1,264 @@
+"""Various modules used in the decoder of the model.
+
+Adapted from https://github.com/jinlinyi/PerspectiveFields
+"""
+
+import logging
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.utils import checkpoint as cp
+
+logger = logging.getLogger(__name__)
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0 and scale_by_keep:
+ random_tensor.div_(keep_prob)
+ return x * random_tensor
+
+
+class DropPath(nn.Module):
+ """DropBlock, DropPath
+
+ PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
+
+ Papers:
+ DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
+
+ Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
+
+ Code:
+ DropBlock impl inspired by two Tensorflow impl:
+ - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
+ - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
+
+ Hacked together by / Copyright 2020 Ross Wightman
+ """
+
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+ self.scale_by_keep = scale_by_keep
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+ def extra_repr(self):
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
+
+
+class DWConv(nn.Module):
+ def __init__(self, dim=768):
+ super(DWConv, self).__init__()
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
+
+ def forward(self, x):
+ x = self.dwconv(x)
+ return x
+
+
+class MLP(nn.Module):
+ """Linear Embedding."""
+
+ def __init__(self, input_dim=2048, embed_dim=768):
+ super().__init__()
+ self.proj = nn.Linear(input_dim, embed_dim)
+
+ def forward(self, x):
+ x = x.flatten(2).transpose(1, 2)
+ x = self.proj(x)
+ return x
+
+
+class ConvModule(nn.Module):
+ """Replacement for mmcv.cnn.ConvModule to avoid mmcv dependency."""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ padding: int = 0,
+ use_norm: bool = False,
+ bias: bool = True,
+ ):
+ super().__init__()
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
+ self.bn = nn.BatchNorm2d(out_channels) if use_norm else nn.Identity()
+ self.activate = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return self.activate(x)
+
+
+class ResidualConvUnit(nn.Module):
+ """Residual convolution module."""
+
+ def __init__(self, features):
+ """Init.
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
+
+ self.relu = torch.nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ """Forward pass.
+ Args:
+ x (tensor): input
+ Returns:
+ tensor: output
+ """
+ out = self.relu(x)
+ out = self.conv1(out)
+ out = self.relu(out)
+ out = self.conv2(out)
+ return out + x
+
+
+class FeatureFusionBlock(nn.Module):
+ """Feature fusion block."""
+
+ def __init__(self, features, unit2only=False, upsample=True):
+ """Init.
+ Args:
+ features (int): number of features
+ """
+ super().__init__()
+ self.upsample = upsample
+
+ if not unit2only:
+ self.resConfUnit1 = ResidualConvUnit(features)
+ self.resConfUnit2 = ResidualConvUnit(features)
+
+ def forward(self, *xs):
+ """Forward pass."""
+ output = xs[0]
+
+ if len(xs) == 2:
+ output = output + self.resConfUnit1(xs[1])
+
+ output = self.resConfUnit2(output)
+
+ if self.upsample:
+ output = F.interpolate(output, scale_factor=2, mode="bilinear", align_corners=False)
+
+ return output
+
+
+class _DenseLayer(nn.Module):
+ def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient):
+ super().__init__()
+ self.norm1 = nn.BatchNorm2d(num_input_features)
+ self.relu1 = nn.ReLU(inplace=True)
+ self.conv1 = nn.Conv2d(
+ num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False
+ )
+
+ self.norm2 = nn.BatchNorm2d(bn_size * growth_rate)
+ self.relu2 = nn.ReLU(inplace=True)
+ self.conv2 = nn.Conv2d(
+ bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False
+ )
+
+ self.drop_rate = float(drop_rate)
+ self.memory_efficient = memory_efficient
+
+ def bn_function(self, inputs):
+ concated_features = torch.cat(inputs, 1)
+ return self.conv1(self.relu1(self.norm1(concated_features)))
+
+ def any_requires_grad(self, inp):
+ return any(tensor.requires_grad for tensor in inp)
+
+ @torch.jit.unused # noqa: T484
+ def call_checkpoint_bottleneck(self, inp):
+ def closure(*inputs):
+ return self.bn_function(inputs)
+
+ return cp.checkpoint(closure, *inp)
+
+ @torch.jit._overload_method # noqa: F811
+ def forward(self, inp) -> Tensor: # noqa: F811
+ pass
+
+ @torch.jit._overload_method # noqa: F811
+ def forward(self, inp): # noqa: F811
+ pass
+
+ # torchscript does not yet support *args, so we overload method
+ # allowing it to take either a List[Tensor] or single Tensor
+ def forward(self, inp): # noqa: F811
+ prev_features = [inp] if isinstance(inp, Tensor) else inp
+ if self.memory_efficient and self.any_requires_grad(prev_features):
+ if torch.jit.is_scripting():
+ raise Exception("Memory Efficient not supported in JIT")
+
+ bottleneck_output = self.call_checkpoint_bottleneck(prev_features)
+ else:
+ bottleneck_output = self.bn_function(prev_features)
+
+ new_features = self.conv2(self.relu2(self.norm2(bottleneck_output)))
+ if self.drop_rate > 0:
+ new_features = F.dropout(new_features, p=self.drop_rate, training=self.training)
+ return new_features
+
+
+class _DenseBlock(nn.ModuleDict):
+ _version = 2
+
+ def __init__(
+ self,
+ num_layers,
+ num_input_features,
+ bn_size,
+ growth_rate,
+ drop_rate,
+ memory_efficient=False,
+ ):
+ super().__init__()
+ for i in range(num_layers):
+ layer = _DenseLayer(
+ num_input_features + i * growth_rate,
+ growth_rate=growth_rate,
+ bn_size=bn_size,
+ drop_rate=drop_rate,
+ memory_efficient=memory_efficient,
+ )
+ self.add_module("denselayer%d" % (i + 1), layer)
+
+ def forward(self, init_features):
+ features = [init_features]
+ for name, layer in self.items():
+ new_features = layer(features)
+ features.append(new_features)
+ return torch.cat(features, 1)
+
+
+class _Transition(nn.Sequential):
+ def __init__(self, num_input_features, num_output_features):
+ super().__init__()
+ self.norm = nn.BatchNorm2d(num_input_features)
+ self.relu = nn.ReLU(inplace=True)
+ self.conv = nn.Conv2d(
+ num_input_features, num_output_features, kernel_size=1, stride=1, bias=False
+ )
+ self.pool = nn.AvgPool2d(kernel_size=2, stride=2)
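+
+
+# Example: a short sketch of the dense block and the fusion block on dummy feature maps.
+# The channel counts and spatial sizes are illustrative assumptions.
+if __name__ == "__main__":
+    x = torch.randn(1, 64, 32, 32)
+
+    block = _DenseBlock(
+        num_layers=2, num_input_features=64, bn_size=4, growth_rate=32, drop_rate=0.0
+    )
+    y = block(x)  # channels grow to 64 + 2 * 32 = 128
+
+    fusion = FeatureFusionBlock(features=64)
+    z = fusion(x, x)  # fused and upsampled to (1, 64, 64, 64)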
diff --git a/siclib/models/utils/perspective_encoding.py b/siclib/models/utils/perspective_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..1890fbd4caa5508f5768cc9111841f58d61e451d
--- /dev/null
+++ b/siclib/models/utils/perspective_encoding.py
@@ -0,0 +1,82 @@
+"""Perspective field utilities.
+
+Adapted from https://github.com/jinlinyi/PerspectiveFields
+"""
+
+import torch
+
+from siclib.utils.conversions import deg2rad, rad2deg
+
+
+def encode_up_bin(vector_field: torch.Tensor, num_bin: int) -> torch.Tensor:
+ """Encode vector field into classification bins.
+
+ Args:
+ vector_field (torch.Tensor): gravity field of shape (2, h, w), with channel 0 cos(theta) and
+ 1 sin(theta)
+ num_bin (int): number of classification bins
+
+ Returns:
+ torch.Tensor: encoded bin indices of shape (1, h, w)
+ """
+ angle = (
+ torch.atan2(vector_field[1, :, :], vector_field[0, :, :]) / torch.pi * 180 + 180
+ ) % 360 # [0,360)
+ angle_bin = torch.round(torch.div(angle, (360 / (num_bin - 1)))).long()
+ angle_bin[angle_bin == num_bin - 1] = 0
+ invalid = (vector_field == 0).sum(0) == vector_field.size(0)
+ angle_bin[invalid] = num_bin - 1
+ return deg2rad(angle_bin.type(torch.LongTensor))
+
+
+def decode_up_bin(angle_bin: torch.Tensor, num_bin: int) -> torch.Tensor:
+ """Decode classification bins into vector field.
+
+ Args:
+ angle_bin (torch.Tensor): bin indices of shape (1, h, w)
+ num_bin (int): number of classification bins
+
+ Returns:
+ torch.Tensor: decoded vector field of shape (2, h, w)
+ """
+ angle = (angle_bin * (360 / (num_bin - 1)) - 180) / 180 * torch.pi
+ cos = torch.cos(angle)
+ sin = torch.sin(angle)
+ vector_field = torch.stack((cos, sin), dim=1)
+ invalid = angle_bin == num_bin - 1
+ invalid = invalid.unsqueeze(1).repeat(1, 2, 1, 1)
+ vector_field[invalid] = 0
+ return vector_field
+
+
+def encode_bin_latitude(latimap: torch.Tensor, num_classes: int) -> torch.Tensor:
+ """Encode latitude map into classification bins.
+
+ Args:
+ latimap (torch.Tensor): latitude map of shape (h, w) with values in [-90, 90]
+ num_classes (int): number of classes
+
+ Returns:
+ torch.Tensor: encoded latitude bin indices
+ """
+ boundaries = torch.arange(-90, 90, 180 / num_classes)[1:]
+ binmap = torch.bucketize(rad2deg(latimap), boundaries)
+ return binmap.type(torch.LongTensor)
+
+
+def decode_bin_latitude(binmap: torch.Tensor, num_classes: int) -> torch.Tensor:
+ """Decode classification bins to latitude map.
+
+ Args:
+ binmap (torch.Tensor): encoded classification bins
+ num_classes (int): number of classes
+
+ Returns:
+ torch.Tensor: latitude map of shape (h, w)
+ """
+ bin_size = 180 / num_classes
+ bin_centers = torch.arange(-90, 90, bin_size) + bin_size / 2
+ bin_centers = bin_centers.to(binmap.device)
+ latimap = bin_centers[binmap]
+
+ return deg2rad(latimap)
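+
+
+# Example: a round-trip sketch of the latitude bin encoding. The latitude values and the
+# number of classes are illustrative assumptions.
+if __name__ == "__main__":
+    latimap = deg2rad(torch.tensor([[-45.0, 0.0, 30.0]]))  # latitude map in radians
+    bins = encode_bin_latitude(latimap, num_classes=64)  # integer bin indices
+    recon = decode_bin_latitude(bins, num_classes=64)  # close to latimap, up to the bin width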
diff --git a/siclib/pose_estimation.py b/siclib/pose_estimation.py
new file mode 100644
index 0000000000000000000000000000000000000000..13854e940712f9a6fb48f6a3965ac97e176f3137
--- /dev/null
+++ b/siclib/pose_estimation.py
@@ -0,0 +1,148 @@
+import pickle
+from pathlib import Path
+
+import numpy as np
+import poselib
+import pycolmap
+
+from siclib.models.extractor import VP
+
+from .models.extractor import GeoCalib
+from .utils.image import load_image
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class AbsolutePoseEstimator:
+ default_opts = {
+ "ransac": "poselib_gravity", # pycolmap, poselib, poselib_gravity
+ "refinement": "pycolmap_gravity", # pycolmap, pycolmap_gravity, none
+ "gravity_weight": 50000,
+ "max_reproj_error": 48.0,
+ "loss_function_scale": 1.0,
+ "use_vp": False,
+ "max_uncertainty": 10.0 / 180.0 * 3.1415, # radians
+ "cache_path": "../../outputs/inloc/calib.pickle",
+ }
+
+ def __init__(self, pose_opts=None):
+ pose_opts = {} if pose_opts is None else pose_opts
+ self.opts = {**self.default_opts, **pose_opts}
+ self.device = "cuda"
+
+ if self.opts["use_vp"]:
+ self.calib = VP().to(self.device)
+ self.cache_path = str(self.opts["cache_path"]).replace(".pickle", "_vp.pickle")
+ else:
+ self.calib = GeoCalib().to(self.device)
+ self.cache_path = str(self.opts["cache_path"])
+
+ # self.read_cache()
+ self.cache = {}
+
+ def read_cache(self):
+ print(f"Reading cache from {self.cache_path} ({Path(self.cache_path).exists()})")
+ if not Path(self.cache_path).exists():
+ self.cache = {}
+ return
+ with open(self.cache_path, "rb") as handle:
+ self.cache = pickle.load(handle)
+
+ def write_cache(self):
+ with open(self.cache_path, "wb") as handle:
+ pickle.dump(self.cache, handle, protocol=pickle.HIGHEST_PROTOCOL)
+
+ def __call__(self, query_path, p2d, p3d, camera_dict):
+ focal_length = pycolmap.Camera(camera_dict).mean_focal_length()
+
+ if query_path in self.cache:
+ calib = self.cache[query_path]
+ else:
+ calib = self.calib.calibrate(
+ load_image(query_path).to(self.device), priors={"f": focal_length}
+ )
+ calib = {k: v[0].detach().cpu().numpy() for k, v in calib.items()}
+ self.cache[query_path] = calib
+ # self.write_cache()
+
+ if self.opts["ransac"] == "pycolmap":
+ ret = pycolmap.absolute_pose_estimation(
+ p2d, p3d, camera_dict, self.opts["max_reproj_error"] # , do_refine=False
+ )
+ elif self.opts["ransac"] == "poselib":
+ M, ret = poselib.estimate_absolute_pose(
+ p2d,
+ p3d,
+ camera_dict,
+ ransac_opt={"max_reproj_error": self.opts["max_reproj_error"]},
+ )
+ ret["success"] = M is not None
+ ret["qvec"] = M.q
+ ret["tvec"] = M.t
+ elif self.opts["ransac"] == "poselib_gravity":
+ g_q = calib["gravity"].vec3d
+ g_qu = calib.get("gravity_uncertainty", self.opts["max_uncertainty"])
+ M, ret = poselib.estimate_absolute_pose_gravity(
+ p2d,
+ p3d,
+ camera_dict,
+ g_q,
+                g_qu * 2 * 180 / 3.1415,  # convert the (doubled) uncertainty from radians to degrees
+ ransac_opt={"max_reproj_error": self.opts["max_reproj_error"]},
+ )
+ ret["success"] = M is not None
+ ret["qvec"] = M.q
+ ret["tvec"] = M.t
+ else:
+ raise NotImplementedError(self.opts["ransac"])
+ r_opts = {
+ "refine_focal_length": False,
+ "refine_extra_params": False,
+ "print_summary": False,
+ "loss_function_scale": self.opts["loss_function_scale"],
+ }
+ if self.opts["refinement"] == "pycolmap_gravity":
+ g_q = calib["gravity"].vec3d
+ g_qu = calib.get("gravity_uncertainty", self.opts["max_uncertainty"])
+ if g_qu <= self.opts["max_uncertainty"]:
+ g_gt = np.array([0, 0, 1]) # world frame
+ ret_ref = pycolmap.pose_refinement_gravity(
+ ret["tvec"],
+ ret["qvec"],
+ p2d,
+ p3d,
+ ret["inliers"],
+ camera_dict,
+ g_q,
+ g_gt,
+ self.opts["gravity_weight"],
+ r_opts,
+ )
+ else:
+ ret_ref = pycolmap.pose_refinement(
+ ret["tvec"],
+ ret["qvec"],
+ p2d,
+ p3d,
+ ret["inliers"],
+ camera_dict,
+ r_opts,
+ )
+ elif self.opts["refinement"] == "pycolmap":
+ ret_ref = pycolmap.pose_refinement(
+ ret["tvec"],
+ ret["qvec"],
+ p2d,
+ p3d,
+ ret["inliers"],
+ camera_dict,
+ r_opts,
+ )
+ elif self.opts["refinement"] == "none":
+ ret_ref = {}
+ else:
+ raise NotImplementedError(self.opts["refinement"])
+ ret = {**ret, **ret_ref}
+ ret["camera_dict"] = camera_dict
+ return ret, calib
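+
+
+# Example: a hedged, non-runnable sketch of how the estimator is intended to be called.
+# The camera dict layout and the 2D-3D correspondences below are assumptions about the
+# caller's data (COLMAP-style camera), not something this module defines.
+#
+#   estimator = AbsolutePoseEstimator({"ransac": "poselib_gravity"})
+#   camera_dict = {"model": "PINHOLE", "width": 640, "height": 480,
+#                  "params": [500.0, 500.0, 320.0, 240.0]}
+#   ret, calib = estimator("query.jpg", p2d, p3d, camera_dict)  # p2d: (N, 2), p3d: (N, 3)
+#   # ret holds the (refined) qvec/tvec pose; calib holds the GeoCalib camera/gravity priors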
diff --git a/siclib/pyproject.toml b/siclib/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..7ce9891101e33ce630d1f25bc130f9066bc64fe3
--- /dev/null
+++ b/siclib/pyproject.toml
@@ -0,0 +1,47 @@
+[build-system]
+requires = ["setuptools", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "siclib"
+version = "1.0"
+description = "Training library for GeoCalib: Learning Single-image Calibration with Geometric Optimization"
+authors = [
+ { name = "Alexander Veicht" },
+ { name = "Paul-Edouard Sarlin" },
+ { name = "Philipp Lindenberger" },
+]
+requires-python = ">=3.9"
+license = { file = "LICENSE" }
+classifiers = [
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+ "Operating System :: OS Independent",
+]
+urls = { Repository = "https://github.com/cvg/GeoCalib" }
+
+dynamic = ["dependencies"]
+
+[project.optional-dependencies]
+dev = ["black==23.9.1", "flake8", "isort==5.12.0"]
+
+[tool.setuptools.packages.find]
+where = ["."]
+
+[tool.setuptools.dynamic]
+dependencies = { file = ["requirements.txt"] }
+
+[tool.black]
+line-length = 100
+exclude = "(venv/|docs/|third_party/)"
+
+[tool.isort]
+profile = "black"
+line_length = 100
+atomic = true
+
+[tool.flake8]
+max-line-length = 100
+docstring-convention = "google"
+ignore = ["E203", "W503", "E402"]
+exclude = [".git", "__pycache__", "venv", "docs", "third_party", "scripts"]
diff --git a/siclib/requirements.txt b/siclib/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..3b7748c50891d5ea018d02a7a12853780f06cde6
--- /dev/null
+++ b/siclib/requirements.txt
@@ -0,0 +1,14 @@
+torch
+torchvision
+opencv-python
+kornia
+matplotlib
+
+omegaconf
+albumentations
+h5py
+hydra-core
+pandas
+tqdm
+tensorboard
+wandb
diff --git a/siclib/settings.py b/siclib/settings.py
new file mode 100644
index 0000000000000000000000000000000000000000..3606b983b6ba69d96c4a1e2b14d8b4f0f8c70854
--- /dev/null
+++ b/siclib/settings.py
@@ -0,0 +1,12 @@
+from pathlib import Path
+
+# flake8: noqa
+# mypy: ignore-errors
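+# Paths can be overridden by placing a top-level `settings` module on the
+# PYTHONPATH that defines DATA_PATH, TRAINING_PATH and EVAL_PATH; otherwise the
+# defaults below, relative to the repository root, are used.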
+try:
+ from settings import DATA_PATH, EVAL_PATH, TRAINING_PATH
+except ModuleNotFoundError:
+ # @TODO: Add a way to patch paths
+ root = Path(__file__).parent.parent # top-level directory
+ DATA_PATH = root / "data/" # datasets and pretrained weights
+ TRAINING_PATH = root / "outputs/training/" # training checkpoints
+ EVAL_PATH = root / "outputs/results/" # evaluation results
diff --git a/siclib/train.py b/siclib/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce1ee134c56ee33ee9b13a452987dedc8b76a912
--- /dev/null
+++ b/siclib/train.py
@@ -0,0 +1,750 @@
+"""
+A generic training script that works with any model and dataset.
+
+Author: Paul-Edouard Sarlin (skydes)
+"""
+
+# Filter annoying warnings
+import warnings
+
+warnings.simplefilter("ignore", UserWarning)
+
+import argparse
+import copy
+import re
+import shutil
+import signal
+from collections import defaultdict
+from pathlib import Path
+from pydoc import locate
+
+import numpy as np
+import torch
+from hydra import compose, initialize
+from omegaconf import OmegaConf
+from torch.cuda.amp import GradScaler, autocast
+from tqdm import tqdm
+
+from siclib import __module_name__, logger
+from siclib.datasets import get_dataset
+from siclib.eval import run_benchmark
+from siclib.models import get_model
+from siclib.settings import EVAL_PATH, TRAINING_PATH
+from siclib.utils.experiments import get_best_checkpoint, get_last_checkpoint, save_experiment
+from siclib.utils.stdout_capturing import capture_outputs
+from siclib.utils.summary_writer import SummaryWriter
+from siclib.utils.tensor import batch_to_device
+from siclib.utils.tools import (
+ AverageMetric,
+ MedianMetric,
+ PRMetric,
+ RecallMetric,
+ fork_rng,
+ get_device,
+ set_seed,
+)
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+# TODO: Fix pbar pollution in logs
+# TODO: add plotting during evaluation
+
+default_train_conf = {
+ "seed": "???", # training seed
+ "epochs": 1, # number of epochs
+ "num_steps": None, # number of steps, overwrites epochs
+ "optimizer": "adam", # name of optimizer in [adam, sgd, rmsprop]
+ "opt_regexp": None, # regular expression to filter parameters to optimize
+ "optimizer_options": {}, # optional arguments passed to the optimizer
+ "lr": 0.001, # learning rate
+ "lr_schedule": {
+ "type": None,
+ "start": 0,
+ "exp_div_10": 0,
+ "on_epoch": False,
+ "factor": 1.0,
+ },
+ "lr_scaling": [(100, ["dampingnet.const"])],
+ "eval_every_iter": 1000, # interval for evaluation on the validation set
+ "save_every_iter": 5000, # interval for saving the current checkpoint
+ "log_every_iter": 200, # interval for logging the loss to the console
+ "log_grad_every_iter": None, # interval for logging gradient hists
+ "writer": "tensorboard", # tensorboard or wandb
+ "test_every_epoch": 1, # interval for evaluation on the test benchmarks
+ "keep_last_checkpoints": 10, # keep only the last X checkpoints
+ "load_experiment": None, # initialize the model from a previous experiment
+ "median_metrics": [], # add the median of some metrics
+ "recall_metrics": {}, # add the recall of some metrics
+ "pr_metrics": {}, # add pr curves, set labels/predictions/mask keys
+ "best_key": "loss/total", # key to use to select the best checkpoint
+ "dataset_callback_fn": None, # data func called at the start of each epoch
+ "dataset_callback_on_val": False, # call data func on val data?
+ "clip_grad": None,
+ "pr_curves": {},
+ "plot": None,
+ "submodules": [],
+}
+default_train_conf = OmegaConf.create(default_train_conf)
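+# Any entry above can be overridden from the command line via the dotlist
+# arguments parsed in __main__, e.g. by appending `train.lr=1e-4 train.num_steps=20000`
+# to the training command.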
+
+
+def get_lr_scheduler(optimizer, conf):
+ """Get lr scheduler specified by conf."""
+ # logger.info(f"Using lr scheduler with conf: {conf}")
+ if conf.type not in ["factor", "exp", None]:
+ if hasattr(conf.options, "schedulers"):
+ # Add option to chain multiple schedulers together
+ # This is useful for e.g. warmup, then cosine decay
+ """Example: {
+ "type": "SequentialLR",
+ "options": {
+ "milestones": [1_000],
+ "schedulers": [
+ {"type": "LinearLR", "options": {"total_iters": 10, "start_factor": 0.001}},
+ {"type": "MultiStepLR", "options": {"milestones": [40, 60], "gamma": 0.1}},
+ ],
+ }
+ }
+ """
+ schedulers = []
+ for scheduler_conf in conf.options.schedulers:
+ scheduler = get_lr_scheduler(optimizer, scheduler_conf)
+ schedulers.append(scheduler)
+
+ options = {k: v for k, v in conf.options.items() if k != "schedulers"}
+ return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, schedulers, **options)
+
+ return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, **conf.options)
+
+ # backward compatibility
+ def lr_fn(it): # noqa: E306
+ if conf.type is None:
+ return 1
+ if conf.type == "factor":
+ return 1.0 if it < conf.start else conf.factor
+ if conf.type == "exp":
+ gam = 10 ** (-1 / conf.exp_div_10)
+ return 1.0 if it < conf.start else gam
+ else:
+ raise ValueError(conf.type)
+
+ return torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn)
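+# Note on the legacy schedules: both "factor" and "exp" are applied multiplicatively
+# at every scheduler step once `start` is reached; e.g. {"type": "exp", "exp_div_10": N}
+# divides the learning rate by 10 every N steps.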
+
+
+@torch.no_grad()
+def do_evaluation(model, loader, device, loss_fn, conf, pbar=True):
+ model.eval()
+ results = {}
+ recall_results = {}
+ pr_metrics = defaultdict(PRMetric)
+ figures = []
+ if conf.plot is not None:
+ n, plot_fn = conf.plot
+ plot_ids = np.random.choice(len(loader), min(len(loader), n), replace=False)
+ for i, data in enumerate(tqdm(loader, desc="Evaluation", ascii=True, disable=not pbar)):
+ data = batch_to_device(data, device, non_blocking=True)
+ with torch.no_grad():
+ pred = model(data)
+ losses, metrics = loss_fn(pred, data)
+ if conf.plot is not None and i in plot_ids:
+ figures.append(locate(plot_fn)(pred, data))
+ # add PR curves
+ for k, v in conf.pr_curves.items():
+ pr_metrics[k].update(
+ pred[v["labels"]],
+ pred[v["predictions"]],
+ mask=pred[v["mask"]] if "mask" in v.keys() else None,
+ )
+ del pred, data
+
+ numbers = {**metrics, **{f"loss/{k}": v for k, v in losses.items()}}
+ for k, v in numbers.items():
+ if k not in results:
+ results[k] = AverageMetric()
+ if k in conf.median_metrics:
+ results[f"{k}_median"] = MedianMetric()
+
+ if k not in recall_results and k in conf.recall_metrics.keys():
+ ths = conf.recall_metrics[k]
+ recall_results[k] = RecallMetric(ths)
+
+ results[k].update(v)
+ if k in conf.median_metrics:
+ results[f"{k}_median"].update(v)
+ if k in conf.recall_metrics.keys():
+ recall_results[k].update(v)
+
+ del numbers
+
+ results = {k: results[k].compute() for k in results}
+
+ for k, v in recall_results.items():
+ for th, recall in zip(conf.recall_metrics[k], v.compute()):
+ results[f"{k}_recall@{th}"] = recall
+
+ return results, {k: v.compute() for k, v in pr_metrics.items()}, figures
+
+
+def filter_parameters(params, regexp):
+ """Filter trainable parameters based on regular expressions."""
+
+ # Examples of regexp:
+ # '.*(weight|bias)$'
+ # 'cnn\.(enc0|enc1).*bias'
+ def filter_fn(x):
+ n, p = x
+ match = re.search(regexp, n)
+ if not match:
+ p.requires_grad = False
+ return match
+
+ params = list(filter(filter_fn, params))
+ assert len(params) > 0, regexp
+ logger.info("Selected parameters:\n" + "\n".join(n for n, p in params))
+ return params
+
+
+def pack_lr_parameters(params, base_lr, lr_scaling):
+ """Pack each group of parameters with the respective scaled learning rate."""
+ filters, scales = tuple(zip(*[(n, s) for s, names in lr_scaling for n in names]))
+ scale2params = defaultdict(list)
+ for n, p in params:
+ scale = 1
+ is_match = [f in n for f in filters]
+ if any(is_match):
+ scale = scales[is_match.index(True)]
+ scale2params[scale].append((n, p))
+ logger.info(
+ "Parameters with scaled learning rate:\n%s",
+ {s: [n for n, _ in ps] for s, ps in scale2params.items() if s != 1},
+ )
+ return [
+ {"lr": scale * base_lr, "params": [p for _, p in ps]} for scale, ps in scale2params.items()
+ ]
+
+
+def training(rank, conf, output_dir, args):
+ if args.restore:
+ logger.info(f"Restoring from previous training of {args.experiment}")
+ try:
+ init_cp = get_last_checkpoint(args.experiment, allow_interrupted=False)
+ except AssertionError:
+ init_cp = get_best_checkpoint(args.experiment)
+ logger.info(f"Restoring from checkpoint {init_cp.name}")
+ init_cp = torch.load(str(init_cp), map_location="cpu")
+ conf = OmegaConf.merge(OmegaConf.create(init_cp["conf"]), conf)
+ conf.train = OmegaConf.merge(default_train_conf, conf.train)
+ epoch = init_cp["epoch"] + 1
+
+ # get the best loss or eval metric from the previous best checkpoint
+ best_cp = get_best_checkpoint(args.experiment)
+ best_cp = torch.load(str(best_cp), map_location="cpu")
+ best_eval = best_cp["eval"][conf.train.best_key]
+ del best_cp
+ else:
+ # we start a new, fresh training
+ conf.train = OmegaConf.merge(default_train_conf, conf.train)
+ epoch = 0
+ best_eval = float("inf")
+ if conf.train.load_experiment:
+ logger.info(f"Will fine-tune from weights of {conf.train.load_experiment}")
+ # the user has to make sure that the weights are compatible
+ try:
+ init_cp = get_last_checkpoint(conf.train.load_experiment)
+ except AssertionError:
+ init_cp = get_best_checkpoint(conf.train.load_experiment)
+ # init_cp = get_last_checkpoint(conf.train.load_experiment)
+ init_cp = torch.load(str(init_cp), map_location="cpu")
+ # load the model config of the old setup, and overwrite with current config
+ conf.model = OmegaConf.merge(OmegaConf.create(init_cp["conf"]).model, conf.model)
+ print(conf.model)
+ else:
+ init_cp = None
+
+ OmegaConf.set_struct(conf, True) # prevent access to unknown entries
+ set_seed(conf.train.seed)
+ if rank == 0:
+ writer = SummaryWriter(conf, args, str(output_dir))
+
+ data_conf = copy.deepcopy(conf.data)
+ if args.distributed:
+ logger.info(f"Training in distributed mode with {args.n_gpus} GPUs")
+ assert torch.cuda.is_available()
+ device = rank
+ torch.distributed.init_process_group(
+ backend="nccl",
+ world_size=args.n_gpus,
+ rank=device,
+ init_method="file://" + str(args.lock_file),
+ )
+ torch.cuda.set_device(device)
+
+ # adjust batch size and num of workers since these are per GPU
+ if "batch_size" in data_conf:
+ data_conf.batch_size = int(data_conf.batch_size / args.n_gpus)
+ if "train_batch_size" in data_conf:
+ data_conf.train_batch_size = int(data_conf.train_batch_size / args.n_gpus)
+ if "num_workers" in data_conf:
+ data_conf.num_workers = int((data_conf.num_workers + args.n_gpus - 1) / args.n_gpus)
+ else:
+ device = get_device()
+ logger.info(f"Using device {device}")
+
+ dataset = get_dataset(data_conf.name)(data_conf)
+
+ # Optionally load a different validation dataset than the training one
+ val_data_conf = conf.get("data_val", None)
+ if val_data_conf is None:
+ val_dataset = dataset
+ else:
+ val_dataset = get_dataset(val_data_conf.name)(val_data_conf)
+
+ # @TODO: add test data loader
+
+ if args.overfit:
+ # we train and eval with the same single training batch
+ logger.info("Data in overfitting mode")
+ assert not args.distributed
+ train_loader = dataset.get_overfit_loader("train")
+ val_loader = val_dataset.get_overfit_loader("val")
+ else:
+ train_loader = dataset.get_data_loader("train", distributed=args.distributed)
+ val_loader = val_dataset.get_data_loader("val")
+ if rank == 0:
+ logger.info(f"Training loader has {len(train_loader)} batches")
+ logger.info(f"Validation loader has {len(val_loader)} batches")
+
+ # interrupts are caught and delayed for graceful termination
+ def sigint_handler(signal, frame):
+ logger.info("Caught keyboard interrupt signal, will terminate")
+ nonlocal stop
+ if stop:
+ raise KeyboardInterrupt
+ stop = True
+
+ stop = False
+ signal.signal(signal.SIGINT, sigint_handler)
+
+ model = get_model(conf.model.name)(conf.model).to(device)
+ if args.compile:
+ model = torch.compile(model, mode=args.compile)
+ loss_fn = model.loss
+ if init_cp is not None:
+ model.load_state_dict(init_cp["model"], strict=False)
+ if args.distributed:
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device])
+ if rank == 0 and args.print_arch:
+ logger.info(f"Model: \n{model}")
+
+ torch.backends.cudnn.benchmark = True
+ if args.detect_anomaly:
+ logger.info("Enabling anomaly detection")
+ torch.autograd.set_detect_anomaly(True)
+
+ optimizer_fn = {
+ "sgd": torch.optim.SGD,
+ "adam": torch.optim.Adam,
+ "adamw": torch.optim.AdamW,
+ "rmsprop": torch.optim.RMSprop,
+ }[conf.train.optimizer]
+ params = [(n, p) for n, p in model.named_parameters() if p.requires_grad]
+ if conf.train.opt_regexp:
+ params = filter_parameters(params, conf.train.opt_regexp)
+ all_params = [p for n, p in params]
+ logger.info(f"Num parameters: {sum(p.numel() for p in all_params)}")
+
+ lr_params = pack_lr_parameters(params, conf.train.lr, conf.train.lr_scaling)
+ optimizer = optimizer_fn(lr_params, lr=conf.train.lr, **conf.train.optimizer_options)
+ scaler = GradScaler(enabled=args.mixed_precision is not None)
+ logger.info(f"Training with mixed_precision={args.mixed_precision}")
+
+ mp_dtype = {
+ "float16": torch.float16,
+ "bfloat16": torch.bfloat16,
+ None: torch.float32, # we disable it anyway
+ }[args.mixed_precision]
+
+ results = None # ensure results is defined before the periodic checkpoint save
+
+ lr_scheduler = get_lr_scheduler(optimizer=optimizer, conf=conf.train.lr_schedule)
+ logger.info(f"Using lr scheduler of type {type(lr_scheduler)}")
+
+ if args.restore:
+ optimizer.load_state_dict(init_cp["optimizer"])
+ if "lr_scheduler" in init_cp:
+ lr_scheduler.load_state_dict(init_cp["lr_scheduler"])
+
+ if rank == 0:
+ logger.info("Starting training with configuration:\n%s", OmegaConf.to_yaml(conf))
+ losses_ = None
+
+ def trace_handler(p):
+ # torch.profiler.tensorboard_trace_handler(str(output_dir))
+ output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=10)
+ print(output)
+ p.export_chrome_trace("trace_" + str(p.step_num) + ".json")
+ p.export_stacks("/tmp/profiler_stacks.txt", "self_cuda_time_total")
+
+ if args.profile:
+ prof = torch.profiler.profile(
+ schedule=torch.profiler.schedule(wait=1, warmup=1, active=1, repeat=1),
+ on_trace_ready=torch.profiler.tensorboard_trace_handler(str(output_dir)),
+ record_shapes=True,
+ profile_memory=True,
+ with_stack=True,
+ )
+ prof.__enter__()
+
+ if conf.train.log_grad_every_iter:
+ writer.watch(model, log_freq=conf.train.log_grad_every_iter)
+
+ if conf.train.num_steps is not None:
+ conf.train.epochs = conf.train.num_steps // len(train_loader) + 1
+ conf.train.epochs = conf.train.epochs // (args.n_gpus if args.distributed else 1)
+ logger.info(f"Setting epochs to {conf.train.epochs} to match num_steps.")
+
+ while epoch < conf.train.epochs and not stop:
+ tot_it = (len(train_loader) * epoch) * (args.n_gpus if args.distributed else 1)
+ tot_n_samples = tot_it * train_loader.batch_size
+
+ if conf.train.num_steps is not None and tot_it > conf.train.num_steps:
+ logger.info(f"Reached max number of steps {conf.train.num_steps}")
+ stop = True
+
+ if rank == 0:
+ logger.info(f"Starting epoch {epoch}")
+
+ # we first run the eval
+ if (
+ rank == 0
+ and epoch % conf.train.test_every_epoch == 0
+ and (epoch > 0 or not args.no_test_0)
+ ):
+ for bname, eval_conf in conf.get("benchmarks", {}).items():
+ logger.info(f"Running eval on {bname}")
+ s, f, r = run_benchmark(
+ bname,
+ eval_conf,
+ EVAL_PATH / bname / args.experiment / str(epoch),
+ model.eval(),
+ )
+ for metric_name, value in s.items():
+ writer.add_scalar(f"test/{bname}/{metric_name}", value, step=tot_n_samples)
+ for fig_name, fig in f.items():
+ writer.add_figure(f"figures/{bname}/{fig_name}", fig, step=tot_n_samples)
+
+ str_results = [f"{k} {v:.3E}" for k, v in s.items() if isinstance(v, float)]
+ if rank == 0:
+ logger.info(f'[Test {bname}] {{{", ".join(str_results)}}}')
+
+ # set the seed
+ set_seed(conf.train.seed + epoch)
+
+ # update learning rate
+ if conf.train.lr_schedule.on_epoch and epoch > 0:
+ old_lr = optimizer.param_groups[0]["lr"]
+ lr_scheduler.step(epoch)
+ logger.info(f'lr changed from {old_lr} to {optimizer.param_groups[0]["lr"]}')
+
+ if args.distributed:
+ train_loader.sampler.set_epoch(epoch)
+ if epoch > 0 and conf.train.dataset_callback_fn and not args.overfit:
+ loaders = [train_loader]
+ if conf.train.dataset_callback_on_val:
+ loaders += [val_loader]
+ for loader in loaders:
+ if isinstance(loader.dataset, torch.utils.data.Subset):
+ getattr(loader.dataset.dataset, conf.train.dataset_callback_fn)(
+ conf.train.seed + epoch
+ )
+ else:
+ getattr(loader.dataset, conf.train.dataset_callback_fn)(conf.train.seed + epoch)
+ for it, data in enumerate(train_loader):
+ # logger.info(f"Starting iteration {it} - epoch {epoch} - rank {rank}")
+ tot_it = (len(train_loader) * epoch + it) * (args.n_gpus if args.distributed else 1)
+ tot_n_samples = tot_it
+ if not args.log_it:
+ # We normalize the x-axis of tensorboard to num samples!
+ tot_n_samples *= train_loader.batch_size
+
+ model.train()
+ optimizer.zero_grad()
+
+ with autocast(enabled=args.mixed_precision is not None, dtype=mp_dtype):
+ data = batch_to_device(data, device, non_blocking=False)
+ pred = model(data)
+ losses, metrics = loss_fn(pred, data)
+ loss = torch.mean(losses["total"])
+
+ # Skip the iteration if any rank encountered a NaN
+ if loss_has_nan(loss, distributed=args.distributed):
+ logger.warning(f"Skipping iteration {it} due to NaN (rank {rank})")
+ del pred, data, loss, losses, metrics
+ torch.cuda.empty_cache()
+ continue
+
+ do_backward = loss.requires_grad
+ if args.distributed:
+ do_backward = torch.tensor(do_backward).float().to(device)
+ torch.distributed.all_reduce(do_backward, torch.distributed.ReduceOp.PRODUCT)
+ do_backward = do_backward > 0
+
+ if do_backward:
+ scaler.scale(loss).backward()
+ if args.detect_anomaly:
+ # Check for params without any gradient which causes
+ # problems in distributed training with checkpointing
+ detected_anomaly = False
+ for name, param in model.named_parameters():
+ if param.grad is None and param.requires_grad:
+ logger.warning(f"param {name} has no gradient.")
+ detected_anomaly = True
+ if detected_anomaly:
+ raise RuntimeError("Detected anomaly in training.")
+
+ if conf.train.get("clip_grad", None):
+ scaler.unscale_(optimizer)
+ try:
+ torch.nn.utils.clip_grad_norm_(
+ all_params,
+ max_norm=conf.train.clip_grad,
+ error_if_nonfinite=True,
+ )
+ scaler.step(optimizer)
+ except RuntimeError:
+ logger.warning("NaN detected in gradient clipping. Skipping iteration.")
+ scaler.update()
+ else:
+ scaler.step(optimizer)
+ scaler.update()
+
+ if not conf.train.lr_schedule.on_epoch:
+ [lr_scheduler.step() for _ in range(args.n_gpus if args.distributed else 1)]
+ else:
+ if rank == 0:
+ logger.warning(f"Skip iteration {it} due to detach/nan. (rank {rank})")
+
+ if args.profile:
+ prof.step()
+
+ if it % conf.train.log_every_iter == 0:
+ train_results = metrics | losses
+ for k in sorted(train_results.keys()):
+ if args.distributed:
+ train_results[k] = train_results[k].sum(-1)
+ torch.distributed.reduce(train_results[k], dst=0)
+ train_results[k] /= train_loader.batch_size * args.n_gpus
+ train_results[k] = torch.mean(train_results[k], -1)
+ train_results[k] = train_results[k].item()
+ if rank == 0:
+ str_losses = [f"{k} {v:.3E}" for k, v in train_results.items()]
+ logger.info(
+ "[E {} | it {}] loss {{{}}}".format(epoch, it, ", ".join(str_losses))
+ )
+ for k, v in train_results.items():
+ writer.add_scalar("training/" + k, v, tot_n_samples)
+
+ writer.add_scalar("training/lr", optimizer.param_groups[0]["lr"], tot_n_samples)
+ writer.add_scalar("training/epoch", epoch, tot_n_samples)
+
+ if (
+ conf.train.log_grad_every_iter is not None
+ and it % conf.train.log_grad_every_iter == 0
+ ):
+ grad_txt = ""
+ for name, param in model.named_parameters():
+ if param.grad is not None and param.requires_grad:
+ if name.endswith("bias"):
+ continue
+ writer.add_histogram(f"grad/{name}", param.grad.detach(), tot_n_samples)
+ norm = torch.norm(param.grad.detach(), 2)
+ grad_txt += f"{name} {norm.item():.3f} \n"
+ writer.add_text(f"grad/summary", grad_txt, tot_n_samples)
+ del pred, data, loss, losses
+
+ # Run validation
+ if (
+ (it % conf.train.eval_every_iter == 0 and (it > 0 or epoch == -int(args.no_eval_0)))
+ or stop
+ or it == (len(train_loader) - 1)
+ ):
+ with fork_rng(seed=conf.train.seed):
+ results, pr_metrics, figures = do_evaluation(
+ model,
+ val_loader,
+ device,
+ loss_fn,
+ conf.train,
+ pbar=(rank == -1),
+ )
+
+ if rank == 0:
+ str_results = [
+ f"{k} {v:.3E}" for k, v in results.items() if isinstance(v, float)
+ ]
+ logger.info(f'[Validation] {{{", ".join(str_results)}}}')
+ for k, v in results.items():
+ if isinstance(v, dict):
+ writer.add_scalars(f"figure/val/{k}", v, tot_n_samples)
+ else:
+ writer.add_scalar("val/" + k, v, tot_n_samples)
+ for k, v in pr_metrics.items():
+ writer.add_pr_curve("val/" + k, *v, tot_n_samples)
+ # @TODO: optional always save checkpoint
+ if results[conf.train.best_key] < best_eval:
+ best_eval = results[conf.train.best_key]
+ save_experiment(
+ model,
+ optimizer,
+ lr_scheduler,
+ conf,
+ losses_,
+ results,
+ best_eval,
+ epoch,
+ tot_it,
+ output_dir,
+ stop,
+ args.distributed,
+ cp_name="checkpoint_best.tar",
+ )
+ logger.info(f"New best val: {conf.train.best_key}={best_eval}")
+ if len(figures) > 0:
+ for i, figs in enumerate(figures):
+ for name, fig in figs.items():
+ writer.add_figure(f"figures/{i}_{name}", fig, tot_n_samples)
+ torch.cuda.empty_cache() # should be cleared at the first iter
+
+ if (tot_it % conf.train.save_every_iter == 0 and tot_it > 0) and rank == 0:
+ if results is None:
+ results, _, _ = do_evaluation(
+ model,
+ val_loader,
+ device,
+ loss_fn,
+ conf.train,
+ pbar=(rank == -1),
+ )
+ best_eval = results[conf.train.best_key]
+ best_eval = save_experiment(
+ model,
+ optimizer,
+ lr_scheduler,
+ conf,
+ losses_,
+ results,
+ best_eval,
+ epoch,
+ tot_it,
+ output_dir,
+ stop,
+ args.distributed,
+ )
+
+ if stop:
+ break
+
+ if rank == 0:
+ best_eval = save_experiment(
+ model,
+ optimizer,
+ lr_scheduler,
+ conf,
+ losses_,
+ results,
+ best_eval,
+ epoch,
+ tot_it,
+ output_dir=output_dir,
+ stop=stop,
+ distributed=args.distributed,
+ )
+
+ epoch += 1
+
+ logger.info(f"Finished training on process {rank}.")
+ if rank == 0:
+ writer.close()
+
+
+def loss_has_nan(loss: torch.Tensor, distributed: bool) -> bool:
+ """Check if any rank has encountered a NaN loss."""
+ has_nan = torch.tensor([torch.isnan(loss).any().float()]).to(loss.device)
+
+ # Synchronize the has_nan variable across all ranks
+ if distributed:
+ torch.distributed.all_reduce(has_nan, op=torch.distributed.ReduceOp.MAX)
+
+ return has_nan.item() > 0.5
+
+
+def main_worker(rank, conf, output_dir, args):
+ if rank == 0:
+ with capture_outputs(output_dir / "log.txt"):
+ training(rank, conf, output_dir, args)
+ else:
+ training(rank, conf, output_dir, args)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("experiment", type=str)
+ parser.add_argument("--conf", type=str)
+ parser.add_argument(
+ "--mixed_precision",
+ "--mp",
+ default=None,
+ type=str,
+ choices=["float16", "bfloat16"],
+ )
+ parser.add_argument(
+ "--compile",
+ default=None,
+ type=str,
+ choices=["default", "reduce-overhead", "max-autotune"],
+ )
+ parser.add_argument("--overfit", action="store_true")
+ parser.add_argument("--restore", action="store_true")
+ parser.add_argument("--distributed", action="store_true")
+ parser.add_argument("--profile", action="store_true")
+ parser.add_argument("--print_arch", "--pa", action="store_true")
+ parser.add_argument("--detect_anomaly", "--da", action="store_true")
+ parser.add_argument("--log_it", "--log_it", action="store_true")
+ parser.add_argument("--no_eval_0", action="store_true")
+ parser.add_argument("--no_test_0", action="store_true")
+ parser.add_argument("dotlist", nargs="*")
+ args = parser.parse_intermixed_args()
+
+ logger.info(f"Starting experiment {args.experiment}")
+ output_dir = Path(TRAINING_PATH, args.experiment)
+ output_dir.mkdir(exist_ok=True, parents=True)
+
+ conf = OmegaConf.from_cli(args.dotlist)
+
+ if args.conf:
+ initialize(version_base=None, config_path="configs")
+ conf = compose(config_name=args.conf, overrides=args.dotlist)
+ elif args.restore:
+ restore_conf = OmegaConf.load(output_dir / "config.yaml")
+ conf = OmegaConf.merge(restore_conf, conf)
+
+ if not args.restore:
+ if conf.train.seed is None:
+ conf.train.seed = torch.initial_seed() & (2**32 - 1)
+ OmegaConf.save(conf, str(output_dir / "config.yaml"))
+
+ # copy geocalib and submodule into output dir
+ for module in conf.train.submodules + [__module_name__]:
+ mod_dir = Path(__import__(str(module)).__file__).parent
+ shutil.copytree(mod_dir, output_dir / module, dirs_exist_ok=True)
+
+ if args.distributed:
+ args.n_gpus = torch.cuda.device_count()
+ args.lock_file = output_dir / "distributed_lock"
+ if args.lock_file.exists():
+ args.lock_file.unlink()
+ torch.multiprocessing.spawn(main_worker, nprocs=args.n_gpus, args=(conf, output_dir, args))
+ else:
+ main_worker(0, conf, output_dir, args)
diff --git a/siclib/utils/__init__.py b/siclib/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/siclib/utils/conversions.py b/siclib/utils/conversions.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd019d3c58fc76236543cebb985184a671f07146
--- /dev/null
+++ b/siclib/utils/conversions.py
@@ -0,0 +1,149 @@
+"""Utility functions for conversions between different representations."""
+
+from typing import Optional
+
+import torch
+
+
+def skew_symmetric(v: torch.Tensor) -> torch.Tensor:
+ """Create a skew-symmetric matrix from a (batched) vector of size (..., 3).
+
+ Args:
+ v (torch.Tensor): Vector of size (..., 3).
+
+ Returns:
+ (torch.Tensor): Skew-symmetric matrix of size (..., 3, 3).
+ """
+ z = torch.zeros_like(v[..., 0])
+ return torch.stack(
+ [
+ z,
+ -v[..., 2],
+ v[..., 1],
+ v[..., 2],
+ z,
+ -v[..., 0],
+ -v[..., 1],
+ v[..., 0],
+ z,
+ ],
+ dim=-1,
+ ).reshape(v.shape[:-1] + (3, 3))
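+# The returned matrix [v]_x implements the cross product: [v]_x @ w equals v x w
+# for any 3-vector w (treated as a column vector).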
+
+
+def rad2rotmat(
+ roll: torch.Tensor, pitch: torch.Tensor, yaw: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+ """Convert (batched) roll, pitch, yaw angles (in radians) to rotation matrix.
+
+ Args:
+ roll (torch.Tensor): Roll angle in radians.
+ pitch (torch.Tensor): Pitch angle in radians.
+ yaw (torch.Tensor, optional): Yaw angle in radians. Defaults to None.
+
+ Returns:
+ torch.Tensor: Rotation matrix of shape (..., 3, 3).
+ """
+ if yaw is None:
+ yaw = roll.new_zeros(roll.shape)
+
+ Rx = pitch.new_zeros(pitch.shape + (3, 3))
+ Rx[..., 0, 0] = 1
+ Rx[..., 1, 1] = torch.cos(pitch)
+ Rx[..., 1, 2] = torch.sin(pitch)
+ Rx[..., 2, 1] = -torch.sin(pitch)
+ Rx[..., 2, 2] = torch.cos(pitch)
+
+ Ry = yaw.new_zeros(yaw.shape + (3, 3))
+ Ry[..., 0, 0] = torch.cos(yaw)
+ Ry[..., 0, 2] = -torch.sin(yaw)
+ Ry[..., 1, 1] = 1
+ Ry[..., 2, 0] = torch.sin(yaw)
+ Ry[..., 2, 2] = torch.cos(yaw)
+
+ Rz = roll.new_zeros(roll.shape + (3, 3))
+ Rz[..., 0, 0] = torch.cos(roll)
+ Rz[..., 0, 1] = torch.sin(roll)
+ Rz[..., 1, 0] = -torch.sin(roll)
+ Rz[..., 1, 1] = torch.cos(roll)
+ Rz[..., 2, 2] = 1
+
+ return Rz @ Rx @ Ry
+
+
+def fov2focal(fov: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
+ """Compute focal length from (vertical/horizontal) field of view.
+
+ Args:
+ fov (torch.Tensor): Field of view in radians.
+ size (torch.Tensor): Image height / width in pixels.
+
+ Returns:
+ torch.Tensor: Focal length in pixels.
+ """
+ return size / 2 / torch.tan(fov / 2)
+
+
+def focal2fov(focal: torch.Tensor, size: torch.Tensor) -> torch.Tensor:
+ """Compute (vertical/horizontal) field of view from focal length.
+
+ Args:
+ focal (torch.Tensor): Focal length in pixels.
+ size (torch.Tensor): Image height / width in pixels.
+
+ Returns:
+ torch.Tensor: Field of view in radians.
+ """
+ return 2 * torch.arctan(size / (2 * focal))
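+# fov2focal and focal2fov are inverses of each other for a fixed image size.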
+
+
+def pitch2rho(pitch: torch.Tensor, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
+ """Compute the distance from principal point to the horizon.
+
+ Args:
+ pitch (torch.Tensor): Pitch angle in radians.
+ f (torch.Tensor): Focal length in pixels.
+ h (torch.Tensor): Image height in pixels.
+
+ Returns:
+ torch.Tensor: Relative distance to the horizon.
+ """
+ return torch.tan(pitch) * f / h
+
+
+def rho2pitch(rho: torch.Tensor, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
+ """Compute the pitch angle from the distance to the horizon.
+
+ Args:
+ rho (torch.Tensor): Relative distance to the horizon.
+ f (torch.Tensor): Focal length in pixels.
+ h (torch.Tensor): Image height in pixels.
+
+ Returns:
+ torch.Tensor: Pitch angle in radians.
+ """
+ return torch.atan(rho * h / f)
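+# rho2pitch is the inverse of pitch2rho for a fixed focal length and image height.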
+
+
+def rad2deg(rad: torch.Tensor) -> torch.Tensor:
+ """Convert radians to degrees.
+
+ Args:
+ rad (torch.Tensor): Angle in radians.
+
+ Returns:
+ torch.Tensor: Angle in degrees.
+ """
+ return rad / torch.pi * 180
+
+
+def deg2rad(deg: torch.Tensor) -> torch.Tensor:
+ """Convert degrees to radians.
+
+ Args:
+ deg (torch.Tensor): Angle in degrees.
+
+ Returns:
+ torch.Tensor: Angle in radians.
+ """
+ return deg / 180 * torch.pi
diff --git a/siclib/utils/experiments.py b/siclib/utils/experiments.py
new file mode 100644
index 0000000000000000000000000000000000000000..c35a99953d09710a12015173cf70b53d43591ed5
--- /dev/null
+++ b/siclib/utils/experiments.py
@@ -0,0 +1,135 @@
+"""
+A set of utilities to manage and load checkpoints of training experiments.
+
+Author: Paul-Edouard Sarlin (skydes)
+"""
+
+import logging
+import os
+import re
+import shutil
+from pathlib import Path
+
+import torch
+from omegaconf import OmegaConf
+
+from siclib.models import get_model
+from siclib.settings import TRAINING_PATH
+
+logger = logging.getLogger(__name__)
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+def list_checkpoints(dir_):
+ """List all valid checkpoints in a given directory."""
+ checkpoints = []
+ for p in dir_.glob("checkpoint_*.tar"):
+ numbers = re.findall(r"(\d+)", p.name)
+ assert len(numbers) <= 2
+ if len(numbers) == 0:
+ continue
+ if len(numbers) == 1:
+ checkpoints.append((int(numbers[0]), p))
+ else:
+ checkpoints.append((int(numbers[1]), p))
+ return checkpoints
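+# Checkpoints are expected to be named checkpoint_{epoch}_{iteration}[_interrupted].tar
+# (see save_experiment below); they are ordered by the iteration number, and
+# checkpoint_best.tar, which contains no number, is skipped here.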
+
+
+def get_last_checkpoint(exper, allow_interrupted=True):
+ """Get the last saved checkpoint for a given experiment name."""
+ ckpts = list_checkpoints(Path(TRAINING_PATH, exper))
+ if not allow_interrupted:
+ ckpts = [(n, p) for (n, p) in ckpts if "_interrupted" not in p.name]
+ assert len(ckpts) > 0
+ return sorted(ckpts)[-1][1]
+
+
+def get_best_checkpoint(exper):
+ """Get the checkpoint with the best loss, for a given experiment name."""
+ return Path(TRAINING_PATH, exper, "checkpoint_best.tar")
+
+
+def delete_old_checkpoints(dir_, num_keep):
+ """Delete all but the num_keep last saved checkpoints."""
+ ckpts = list_checkpoints(dir_)
+ ckpts = sorted(ckpts)[::-1]
+ kept = 0
+ for ckpt in ckpts:
+ if ("_interrupted" in str(ckpt[1]) and kept > 0) or kept >= num_keep:
+ logger.info(f"Deleting checkpoint {ckpt[1].name}")
+ ckpt[1].unlink()
+ else:
+ kept += 1
+
+
+def load_experiment(exper, conf=None, get_last=False, ckpt=None):
+ """Load and return the model of a given experiment."""
+ if conf is None:
+ conf = {}
+
+ exper = Path(exper)
+ if exper.suffix != ".tar":
+ ckpt = get_last_checkpoint(exper) if get_last else get_best_checkpoint(exper)
+ else:
+ ckpt = exper
+ logger.info(f"Loading checkpoint {ckpt.name}")
+ ckpt = torch.load(str(ckpt), map_location="cpu")
+
+ loaded_conf = OmegaConf.create(ckpt["conf"])
+ OmegaConf.set_struct(loaded_conf, False)
+ conf = OmegaConf.merge(loaded_conf.model, OmegaConf.create(conf))
+ model = get_model(conf.name)(conf).eval()
+
+ state_dict = ckpt["model"]
+
+ dict_params = set(state_dict.keys())
+ model_params = set(map(lambda n: n[0], model.named_parameters()))
+ diff = model_params - dict_params
+ if len(diff) > 0:
+ subs = os.path.commonprefix(list(diff)).rstrip(".")
+ logger.warning(f"Missing {len(diff)} parameters in {subs}: {diff}")
+ model.load_state_dict(state_dict, strict=False)
+ return model
+
+
+def save_experiment(
+ model,
+ optimizer,
+ lr_scheduler,
+ conf,
+ losses,
+ results,
+ best_eval,
+ epoch,
+ iter_i,
+ output_dir,
+ stop=False,
+ distributed=False,
+ cp_name=None,
+):
+ """Save the current model to a checkpoint
+ and return the best result so far."""
+ state = (model.module if distributed else model).state_dict()
+ checkpoint = {
+ "model": state,
+ "optimizer": optimizer.state_dict(),
+ "lr_scheduler": lr_scheduler.state_dict(),
+ "conf": OmegaConf.to_container(conf, resolve=True),
+ "epoch": epoch,
+ "losses": losses,
+ "eval": results,
+ }
+ if cp_name is None:
+ cp_name = f"checkpoint_{epoch}_{iter_i}" + ("_interrupted" if stop else "") + ".tar"
+ logger.info(f"Saving checkpoint {cp_name}")
+ cp_path = str(output_dir / cp_name)
+ torch.save(checkpoint, cp_path)
+
+ if cp_name != "checkpoint_best.tar" and results[conf.train.best_key] < best_eval:
+ best_eval = results[conf.train.best_key]
+ logger.info(f"New best val: {conf.train.best_key}={best_eval}")
+ shutil.copy(cp_path, str(output_dir / "checkpoint_best.tar"))
+ delete_old_checkpoints(output_dir, conf.train.keep_last_checkpoints)
+ return best_eval
diff --git a/siclib/utils/export_predictions.py b/siclib/utils/export_predictions.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c66369cdd57f278eb3167f32eefd6e53cd19189
--- /dev/null
+++ b/siclib/utils/export_predictions.py
@@ -0,0 +1,84 @@
+"""
+Export the predictions of a model for a given dataloader (e.g. ImageFolder).
+Run it as a standalone script with `python3 -m siclib.utils.export_predictions dir`
+or call it from another script.
+"""
+
+import logging
+from pathlib import Path
+
+import h5py
+import numpy as np
+import torch
+from tqdm import tqdm
+
+from siclib.utils.tensor import batch_to_device
+from siclib.utils.tools import get_device
+
+# flake8: noqa
+# mypy: ignore-errors
+
+logger = logging.getLogger(__name__)
+
+
+@torch.no_grad()
+def export_predictions(
+ loader,
+ model,
+ output_file,
+ as_half=False,
+ keys="*",
+ callback_fn=None,
+ optional_keys=None,
+ verbose=True,
+): # sourcery skip: low-code-quality
+ if optional_keys is None:
+ optional_keys = []
+
+ assert keys == "*" or isinstance(keys, (tuple, list))
+ Path(output_file).parent.mkdir(exist_ok=True, parents=True)
+ hfile = h5py.File(str(output_file), "w")
+ device = get_device()
+ model = model.to(device).eval()
+
+ if not verbose:
+ logger.info(f"Exporting predictions to {output_file}")
+
+ for data_ in tqdm(loader, desc="Exporting", total=len(loader), ncols=80, disable=not verbose):
+ data = batch_to_device(data_, device, non_blocking=True)
+ pred = model(data)
+ if callback_fn is not None:
+ pred = {**callback_fn(pred, data), **pred}
+ if keys != "*":
+ if len(set(keys) - set(pred.keys())) > 0:
+ raise ValueError(f"Missing key {set(keys) - set(pred.keys())}")
+ pred = {k: v for k, v in pred.items() if k in keys + optional_keys}
+
+ # assert len(pred) > 0, "No predictions found"
+
+ for idx in range(len(data["name"])):
+ pred_ = {k: v[idx].cpu().numpy() for k, v in pred.items()}
+
+ if as_half:
+ for k in pred_:
+ dt = pred_[k].dtype
+ if dt == np.float32:
+ pred_[k] = pred_[k].astype(np.float16)
+ try:
+ name = data["name"][idx]
+ try:
+ grp = hfile.create_group(name)
+ except ValueError as e:
+ raise ValueError(f"Group already exists {name}") from e
+
+ # grp = hfile.create_group(name)
+ for k, v in pred_.items():
+ grp.create_dataset(k, data=v)
+ except RuntimeError:
+ print(f"Failed to export {name}")
+ continue
+
+ del pred
+
+ hfile.close()
+ return output_file
diff --git a/siclib/utils/image.py b/siclib/utils/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..c13d7218b0cf28f3a16b9f43c3634c52beedfa5c
--- /dev/null
+++ b/siclib/utils/image.py
@@ -0,0 +1,167 @@
+"""Image preprocessing utilities."""
+
+import collections.abc as collections
+from pathlib import Path
+from typing import Optional, Tuple
+
+import cv2
+import kornia
+import numpy as np
+import torch
+import torchvision
+from omegaconf import OmegaConf
+
+from siclib.utils.tensor import fit_features_to_multiple
+
+# mypy: ignore-errors
+
+
+class ImagePreprocessor:
+ """Preprocess images for calibration."""
+
+ default_conf = {
+ "resize": 320, # target edge length, None for no resizing
+ "edge_divisible_by": None,
+ "side": "short",
+ "interpolation": "bilinear",
+ "align_corners": None,
+ "antialias": True,
+ "square_crop": False,
+ "add_padding_mask": False,
+ "resize_backend": "kornia", # torchvision, kornia
+ }
+
+ def __init__(self, conf) -> None:
+ """Initialize the image preprocessor."""
+ super().__init__()
+ default_conf = OmegaConf.create(self.default_conf)
+ OmegaConf.set_struct(default_conf, True)
+ self.conf = OmegaConf.merge(default_conf, conf)
+
+ def __call__(self, img: torch.Tensor, interpolation: Optional[str] = None) -> dict:
+ """Resize and preprocess an image, return image and resize scale."""
+ h, w = img.shape[-2:]
+ size = h, w
+
+ if self.conf.square_crop:
+ min_size = min(h, w)
+ offset = (h - min_size) // 2, (w - min_size) // 2
+ img = img[:, offset[0] : offset[0] + min_size, offset[1] : offset[1] + min_size]
+ size = img.shape[-2:]
+
+ if self.conf.resize is not None:
+ if interpolation is None:
+ interpolation = self.conf.interpolation
+ size = self.get_new_image_size(h, w)
+ img = self.resize(img, size, interpolation)
+
+ scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img)
+ T = np.diag([scale[0].cpu(), scale[1].cpu(), 1])
+
+ data = {
+ "scales": scale,
+ "image_size": np.array(size[::-1]),
+ "transform": T,
+ "original_image_size": np.array([w, h]),
+ }
+
+ if self.conf.edge_divisible_by is not None:
+ # crop to make the edge divisible by a number
+ w_, h_ = img.shape[-1], img.shape[-2]
+ img, _ = fit_features_to_multiple(img, self.conf.edge_divisible_by, crop=True)
+ crop_pad = torch.Tensor([img.shape[-1] - w_, img.shape[-2] - h_]).to(img)
+ data["crop_pad"] = crop_pad
+ data["image_size"] = np.array([img.shape[-1], img.shape[-2]])
+
+ data["image"] = img
+ return data
+
+ def resize(self, img: torch.Tensor, size: Tuple[int, int], interpolation: str) -> torch.Tensor:
+ """Resize an image using the specified backend."""
+ if self.conf.resize_backend == "kornia":
+ return kornia.geometry.transform.resize(
+ img,
+ size,
+ side=self.conf.side,
+ antialias=self.conf.antialias,
+ align_corners=self.conf.align_corners,
+ interpolation=interpolation,
+ )
+ elif self.conf.resize_backend == "torchvision":
+ return torchvision.transforms.Resize(size, antialias=self.conf.antialias)(img)
+ else:
+ raise ValueError(f"{self.conf.resize_backend} not implemented.")
+
+ def load_image(self, image_path: Path) -> dict:
+ """Load an image from a path and preprocess it."""
+ return self(load_image(image_path))
+
+ def get_new_image_size(self, h: int, w: int) -> Tuple[int, int]:
+ """Get the new image size after resizing."""
+ side = self.conf.side
+ if isinstance(self.conf.resize, collections.Iterable):
+ assert len(self.conf.resize) == 2
+ return tuple(self.conf.resize)
+ side_size = self.conf.resize
+ aspect_ratio = w / h
+ if side not in ("short", "long", "vert", "horz"):
+ raise ValueError(
+ f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{side}'"
+ )
+ return (
+ (side_size, int(side_size * aspect_ratio))
+ if side == "vert" or (side != "horz" and (side == "short") ^ (aspect_ratio < 1.0))
+ else (int(side_size / aspect_ratio), side_size)
+ )
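+# Minimal usage sketch (values are illustrative): given an image tensor `img` of
+# shape (3, H, W) with values in [0, 1],
+#   data = ImagePreprocessor({"resize": 320})(img)
+# returns a dict with the resized "image", the resize "scales" and the image sizes.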
+
+
+def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
+ """Normalize the image tensor and reorder the dimensions."""
+ if image.ndim == 3:
+ image = image.transpose((2, 0, 1)) # HxWxC to CxHxW
+ elif image.ndim == 2:
+ image = image[None] # add channel axis
+ else:
+ raise ValueError(f"Not an image: {image.shape}")
+ return torch.tensor(image / 255.0, dtype=torch.float)
+
+
+def torch_image_to_numpy(image: torch.Tensor) -> np.ndarray:
+ """Normalize and reorder the dimensions of an image tensor."""
+ if image.ndim == 3:
+ image = image.permute((1, 2, 0)) # CxHxW to HxWxC
+ elif image.ndim == 2:
+ image = image[None] # add channel axis
+ else:
+ raise ValueError(f"Not an image: {image.shape}")
+ return (image.cpu().detach().numpy() * 255).astype(np.uint8)
+
+
+def read_image(path: Path, grayscale: bool = False) -> np.ndarray:
+ """Read an image from path as RGB or grayscale."""
+ if not Path(path).exists():
+ raise FileNotFoundError(f"No image at path {path}.")
+ mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR
+ image = cv2.imread(str(path), mode)
+ if image is None:
+ raise IOError(f"Could not read image at {path}.")
+ if not grayscale:
+ image = image[..., ::-1]
+ return image
+
+
+def write_image(img: torch.Tensor, path: Path):
+ """Write an image tensor to a file."""
+ img = torch_image_to_numpy(img) if isinstance(img, torch.Tensor) else img
+ cv2.imwrite(str(path), img[..., ::-1])
+
+
+def load_image(path: Path, grayscale: bool = False, return_tensor: bool = True) -> torch.Tensor:
+ """Load an image from a path and return as a tensor."""
+ image = read_image(path, grayscale=grayscale)
+ if return_tensor:
+ return numpy_image_to_torch(image)
+
+ assert image.ndim in [2, 3], f"Not an image: {image.shape}"
+ image = image[None] if image.ndim == 2 else image
+ return torch.tensor(image.copy(), dtype=torch.uint8)
diff --git a/siclib/utils/stdout_capturing.py b/siclib/utils/stdout_capturing.py
new file mode 100644
index 0000000000000000000000000000000000000000..19d6701acfeca592a7eecf36edcbc6f614eb6ad0
--- /dev/null
+++ b/siclib/utils/stdout_capturing.py
@@ -0,0 +1,132 @@
+"""
+Based on sacred/stdout_capturing.py in project Sacred
+https://github.com/IDSIA/sacred
+
+Author: Paul-Edouard Sarlin (skydes)
+"""
+
+from __future__ import division, print_function, unicode_literals
+
+import os
+import subprocess
+import sys
+from contextlib import contextmanager
+from threading import Timer
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+def apply_backspaces_and_linefeeds(text):
+ """
+ Interpret backspaces and carriage returns in text like a terminal would,
+ applying them line by line.
+ If the final line ends with a carriage return, it is kept so that it can be
+ concatenated with the next output chunk.
+ """
+ orig_lines = text.split("\n")
+ orig_lines_len = len(orig_lines)
+ new_lines = []
+ for orig_line_idx, orig_line in enumerate(orig_lines):
+ chars, cursor = [], 0
+ orig_line_len = len(orig_line)
+ for orig_char_idx, orig_char in enumerate(orig_line):
+ if orig_char == "\r" and (
+ orig_char_idx != orig_line_len - 1 or orig_line_idx != orig_lines_len - 1
+ ):
+ cursor = 0
+ elif orig_char == "\b":
+ cursor = max(0, cursor - 1)
+ else:
+ if orig_char == "\r":
+ cursor = len(chars)
+ if cursor == len(chars):
+ chars.append(orig_char)
+ else:
+ chars[cursor] = orig_char
+ cursor += 1
+ new_lines.append("".join(chars))
+ return "\n".join(new_lines)
+
+
+def flush():
+ """Try to flush all stdio buffers, both from python and from C."""
+ try:
+ sys.stdout.flush()
+ sys.stderr.flush()
+ except (AttributeError, ValueError, IOError):
+ pass # unsupported
+
+
+# Duplicate stdout and stderr to a file. Inspired by:
+# http://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/
+# http://stackoverflow.com/a/651718/1388435
+# http://stackoverflow.com/a/22434262/1388435
+@contextmanager
+def capture_outputs(filename):
+ """Duplicate stdout and stderr to a file on the file descriptor level."""
+ with open(str(filename), "a+") as target:
+ original_stdout_fd = 1
+ original_stderr_fd = 2
+ target_fd = target.fileno()
+
+ # Save a copy of the original stdout and stderr file descriptors
+ saved_stdout_fd = os.dup(original_stdout_fd)
+ saved_stderr_fd = os.dup(original_stderr_fd)
+
+ tee_stdout = subprocess.Popen(
+ ["tee", "-a", "-i", "/dev/stderr"],
+ start_new_session=True,
+ stdin=subprocess.PIPE,
+ stderr=target_fd,
+ stdout=1,
+ )
+ tee_stderr = subprocess.Popen(
+ ["tee", "-a", "-i", "/dev/stderr"],
+ start_new_session=True,
+ stdin=subprocess.PIPE,
+ stderr=target_fd,
+ stdout=2,
+ )
+
+ flush()
+ os.dup2(tee_stdout.stdin.fileno(), original_stdout_fd)
+ os.dup2(tee_stderr.stdin.fileno(), original_stderr_fd)
+
+ try:
+ yield
+ finally:
+ flush()
+
+ # then redirect stdout back to the saved fd
+ tee_stdout.stdin.close()
+ tee_stderr.stdin.close()
+
+ # restore original fds
+ os.dup2(saved_stdout_fd, original_stdout_fd)
+ os.dup2(saved_stderr_fd, original_stderr_fd)
+
+ # wait for completion of the tee processes with timeout
+ # implemented using a timer because timeout support is py3 only
+ def kill_tees():
+ tee_stdout.kill()
+ tee_stderr.kill()
+
+ tee_timer = Timer(1, kill_tees)
+ try:
+ tee_timer.start()
+ tee_stdout.wait()
+ tee_stderr.wait()
+ finally:
+ tee_timer.cancel()
+
+ os.close(saved_stdout_fd)
+ os.close(saved_stderr_fd)
+
+ # Cleanup log file
+ with open(str(filename), "r") as target:
+ text = target.read()
+ text = apply_backspaces_and_linefeeds(text)
+ with open(str(filename), "w") as target:
+ target.write(text)
diff --git a/siclib/utils/summary_writer.py b/siclib/utils/summary_writer.py
new file mode 100644
index 0000000000000000000000000000000000000000..2542a4df5d27e777f99de6e1d6721b5a990104bc
--- /dev/null
+++ b/siclib/utils/summary_writer.py
@@ -0,0 +1,116 @@
+"""This module implements the writer class for logging to tensorboard or wandb."""
+
+import logging
+import os
+from typing import Any, Dict, Optional
+
+from omegaconf import DictConfig
+from torch import nn
+from torch.utils.tensorboard import SummaryWriter as TFSummaryWriter
+
+from siclib import __module_name__
+
+logger = logging.getLogger(__name__)
+
+try:
+ import wandb
+except ImportError:
+ logger.debug("Could not import wandb.")
+ wandb = None
+
+# mypy: ignore-errors
+
+
+def dot_conf(conf: DictConfig) -> Dict[str, Any]:
+ """Recursively convert a DictConfig to a flat dict with keys joined by dots."""
+ d = {}
+ for k, v in conf.items():
+ if isinstance(v, DictConfig):
+ d |= {f"{k}.{k2}": v2 for k2, v2 in dot_conf(v).items()}
+ else:
+ d[k] = v
+ return d
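+# e.g. for a DictConfig equivalent to {"train": {"lr": 0.001}}, dot_conf returns
+# {"train.lr": 0.001}.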
+
+
+class SummaryWriter:
+ """Writer class for logging to tensorboard or wandb."""
+
+ def __init__(self, conf: DictConfig, args: DictConfig, log_dir: str):
+ """Initialize the writer."""
+ self.conf = conf
+
+ if not conf.train.writer:
+ self.use_wandb = False
+ self.use_tensorboard = False
+ return
+
+ self.use_wandb = "wandb" in conf.train.writer
+ self.use_tensorboard = "tensorboard" in conf.train.writer
+
+ if self.use_wandb and not wandb:
+ raise ImportError("wandb not installed.")
+
+ if self.use_tensorboard:
+ self.writer = TFSummaryWriter(log_dir=log_dir)
+
+ if self.use_wandb:
+ os.environ["WANDB__SERVICE_WAIT"] = "300"
+ wandb.init(project=__module_name__, name=args.experiment, config=dot_conf(conf))
+
+ if conf.train.writer and not self.use_wandb and not self.use_tensorboard:
+ raise NotImplementedError(f"Writer {conf.train.writer} not implemented")
+
+ def add_scalar(self, tag: str, value: float, step: Optional[int] = None):
+ """Log a scalar value to tensorboard or wandb."""
+ if self.use_wandb:
+ step = 1 if step == 0 else step
+ wandb.log({tag: value}, step=step)
+
+ if self.use_tensorboard:
+ self.writer.add_scalar(tag, value, step)
+
+ def add_figure(self, tag: str, figure, step: Optional[int] = None):
+ """Log a figure to tensorboard or wandb."""
+ if self.use_wandb:
+ step = 1 if step == 0 else step
+ wandb.log({tag: figure}, step=step)
+ if self.use_tensorboard:
+ self.writer.add_figure(tag, figure, step)
+
+ def add_histogram(self, tag: str, values, step: Optional[int] = None):
+ """Log a histogram to tensorboard or wandb."""
+ if self.use_tensorboard:
+ self.writer.add_histogram(tag, values, step)
+
+ def add_text(self, tag: str, text: str, step: Optional[int] = None):
+ """Log text to tensorboard or wandb."""
+ if self.use_tensorboard:
+ self.writer.add_text(tag, text, step)
+
+ def add_pr_curve(self, tag: str, values, step: Optional[int] = None):
+ """Log a precision-recall curve to tensorboard or wandb."""
+ if self.use_wandb:
+ step = 1 if step == 0 else step
+ # @TODO: check if this works
+ # wandb.log({"pr": wandb.plots.precision_recall(y_test, y_probas, labels)})
+ wandb.log({tag: wandb.plots.precision_recall(values)}, step=step)
+
+ if self.use_tensorboard:
+ self.writer.add_pr_curve(tag, values, step)
+
+ def watch(self, model: nn.Module, log_freq: int = 1000):
+ """Watch a model for gradient updates."""
+ if self.use_wandb:
+ wandb.watch(
+ model,
+ log="gradients",
+ log_freq=log_freq,
+ )
+
+ def close(self):
+ """Close the writer."""
+ if self.use_wandb:
+ wandb.finish()
+
+ if self.use_tensorboard:
+ self.writer.close()
diff --git a/siclib/utils/tensor.py b/siclib/utils/tensor.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8dd14e9cdd746225523f01e997dd57aed3673e2
--- /dev/null
+++ b/siclib/utils/tensor.py
@@ -0,0 +1,251 @@
+"""
+Author: Paul-Edouard Sarlin (skydes)
+"""
+
+import collections.abc as collections
+import functools
+import inspect
+from typing import Callable, List, Tuple
+
+import numpy as np
+import torch
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+string_classes = (str, bytes)
+
+
+def autocast(func: Callable) -> Callable:
+ """Cast the inputs of a TensorWrapper method to PyTorch tensors if they are numpy arrays.
+
+ Use the device and dtype of the wrapper.
+
+ Args:
+ func (Callable): Method of a TensorWrapper class.
+
+ Returns:
+ Callable: Wrapped method.
+ """
+
+ @functools.wraps(func)
+ def wrap(self, *args):
+ device = torch.device("cpu")
+ dtype = None
+ if isinstance(self, TensorWrapper):
+ if self._data is not None:
+ device = self.device
+ dtype = self.dtype
+ elif not inspect.isclass(self) or not issubclass(self, TensorWrapper):
+ raise ValueError(self)
+
+ cast_args = []
+ for arg in args:
+ if isinstance(arg, np.ndarray):
+ arg = torch.from_numpy(arg)
+ arg = arg.to(device=device, dtype=dtype)
+ cast_args.append(arg)
+ return func(self, *cast_args)
+
+ return wrap
+
+
+class TensorWrapper:
+ """Wrapper for PyTorch tensors."""
+
+ _data = None
+
+ @autocast
+ def __init__(self, data: torch.Tensor):
+ """Wrapper for PyTorch tensors."""
+ self._data = data
+
+ @property
+ def shape(self) -> torch.Size:
+ """Shape of the underlying tensor."""
+ return self._data.shape[:-1]
+
+ @property
+ def device(self) -> torch.device:
+ """Get the device of the underlying tensor."""
+ return self._data.device
+
+ @property
+ def dtype(self) -> torch.dtype:
+ """Get the dtype of the underlying tensor."""
+ return self._data.dtype
+
+ def __getitem__(self, index) -> torch.Tensor:
+ """Get the underlying tensor."""
+ return self.__class__(self._data[index])
+
+ def __setitem__(self, index, item):
+ """Set the underlying tensor."""
+ self._data[index] = item.data
+
+ def to(self, *args, **kwargs):
+ """Move the underlying tensor to a new device."""
+ return self.__class__(self._data.to(*args, **kwargs))
+
+ def cpu(self):
+ """Move the underlying tensor to the CPU."""
+ return self.__class__(self._data.cpu())
+
+ def cuda(self):
+ """Move the underlying tensor to the GPU."""
+ return self.__class__(self._data.cuda())
+
+ def pin_memory(self):
+ """Pin the underlying tensor to memory."""
+ return self.__class__(self._data.pin_memory())
+
+ def float(self):
+ """Cast the underlying tensor to float."""
+ return self.__class__(self._data.float())
+
+ def double(self):
+ """Cast the underlying tensor to double."""
+ return self.__class__(self._data.double())
+
+ def detach(self):
+ """Detach the underlying tensor."""
+ return self.__class__(self._data.detach())
+
+ def numpy(self):
+ """Convert the underlying tensor to a numpy array."""
+ return self._data.detach().cpu().numpy()
+
+ def new_tensor(self, *args, **kwargs):
+ """Create a new tensor of the same type and device."""
+ return self._data.new_tensor(*args, **kwargs)
+
+ def new_zeros(self, *args, **kwargs):
+ """Create a new tensor of the same type and device."""
+ return self._data.new_zeros(*args, **kwargs)
+
+ def new_ones(self, *args, **kwargs):
+ """Create a new tensor of the same type and device."""
+ return self._data.new_ones(*args, **kwargs)
+
+ def new_full(self, *args, **kwargs):
+ """Create a new tensor of the same type and device."""
+ return self._data.new_full(*args, **kwargs)
+
+ def new_empty(self, *args, **kwargs):
+ """Create a new tensor of the same type and device."""
+ return self._data.new_empty(*args, **kwargs)
+
+ def unsqueeze(self, *args, **kwargs):
+ """Create a new tensor of the same type and device."""
+ return self.__class__(self._data.unsqueeze(*args, **kwargs))
+
+ def squeeze(self, *args, **kwargs):
+ """Create a new tensor of the same type and device."""
+ return self.__class__(self._data.squeeze(*args, **kwargs))
+
+ @classmethod
+ def stack(cls, objects: List, dim=0, *, out=None):
+ """Stack a list of objects with the same type and shape."""
+ data = torch.stack([obj._data for obj in objects], dim=dim, out=out)
+ return cls(data)
+
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ """Support torch functions."""
+ if kwargs is None:
+ kwargs = {}
+ return cls.stack(*args, **kwargs) if func is torch.stack else NotImplemented
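+# __torch_function__ is meant to let torch.stack be called directly on a list of
+# TensorWrapper instances, returning a new wrapper around the stacked underlying
+# tensors; other torch functions are not supported.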
+
+
+def map_tensor(input_, func):
+ if isinstance(input_, string_classes):
+ return input_
+ elif isinstance(input_, collections.Mapping):
+ return {k: map_tensor(sample, func) for k, sample in input_.items()}
+ elif isinstance(input_, collections.Sequence):
+ return [map_tensor(sample, func) for sample in input_]
+ elif input_ is None:
+ return None
+ else:
+ return func(input_)
+
+
+def batch_to_numpy(batch):
+ return map_tensor(batch, lambda tensor: tensor.cpu().numpy())
+
+
+def batch_to_device(batch, device, non_blocking=True, detach=False):
+ def _func(tensor):
+ t = tensor.to(device=device, non_blocking=non_blocking, dtype=torch.float32)
+ return t.detach() if detach else t
+
+ return map_tensor(batch, _func)
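+# Note: batch_to_device also casts every tensor in the batch to float32 while
+# moving it to the target device.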
+
+
+def remove_batch_dim(data: dict) -> dict:
+ """Remove batch dimension from elements in data"""
+ return {
+ k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v for k, v in data.items()
+ }
+
+
+def add_batch_dim(data: dict) -> dict:
+ """Add batch dimension to elements in data"""
+ return {
+ k: v[None] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v
+ for k, v in data.items()
+ }
+
+
+def fit_to_multiple(x: torch.Tensor, multiple: int, mode: str = "center", crop: bool = False):
+ """Get padding to make the image size a multiple of the given number.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+ multiple (int): Multiple to fit the size to.
+ mode (str, optional): Where to place the padding, "center" or "left". Defaults to "center".
+ crop (bool, optional): Whether to crop or pad. Defaults to False.
+
+ Returns:
+ torch.Tensor: Padding.
+ """
+ h, w = x.shape[-2:]
+
+ if crop:
+ pad_w = (w // multiple) * multiple - w
+ pad_h = (h // multiple) * multiple - h
+ else:
+ pad_w = (multiple - w % multiple) % multiple
+ pad_h = (multiple - h % multiple) % multiple
+
+ if mode == "center":
+ pad_l = pad_w // 2
+ pad_r = pad_w - pad_l
+ pad_t = pad_h // 2
+ pad_b = pad_h - pad_t
+ elif mode == "left":
+ pad_l = 0
+ pad_r = pad_w
+ pad_t = 0
+ pad_b = pad_h
+ else:
+ raise ValueError(f"Unknown mode {mode}")
+
+ return (pad_l, pad_r, pad_t, pad_b)
+
+
+def fit_features_to_multiple(
+ features: torch.Tensor, multiple: int = 32, crop: bool = False
+) -> Tuple[torch.Tensor, Tuple[int, int]]:
+ """Pad image to a multiple of the given number.
+
+ Args:
+ features (torch.Tensor): Input features.
+ multiple (int, optional): Multiple. Defaults to 32.
+ crop (bool, optional): Whether to crop or pad. Defaults to False.
+
+ Returns:
+ Tuple[torch.Tensor, Tuple[int, int]]: Padded features and padding.
+ """
+ pad = fit_to_multiple(features, multiple, crop=crop)
+ return torch.nn.functional.pad(features, pad, mode="reflect"), pad
diff --git a/siclib/utils/tools.py b/siclib/utils/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3869ecba347264be2112af1d1bc5f80359c716b
--- /dev/null
+++ b/siclib/utils/tools.py
@@ -0,0 +1,309 @@
+"""
+Various handy Python and PyTorch utils.
+
+Author: Paul-Edouard Sarlin (skydes)
+"""
+
+import os
+import random
+import time
+from collections.abc import Iterable
+from contextlib import contextmanager
+from typing import Optional
+
+import numpy as np
+import torch
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class AverageMetric:
+ def __init__(self, elements=None):
+ if elements is None:
+ elements = []
+ self._sum = 0
+ self._num_examples = 0
+ else:
+ mask = ~np.isnan(elements)
+ self._sum = sum(elements[mask])
+ self._num_examples = len(elements[mask])
+
+ def update(self, tensor):
+ assert tensor.dim() == 1, tensor.shape
+ tensor = tensor[~torch.isnan(tensor)]
+ self._sum += tensor.sum().item()
+ self._num_examples += len(tensor)
+
+ def compute(self):
+ return np.nan if self._num_examples == 0 else self._sum / self._num_examples
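+
+# Usage sketch (illustrative only): metrics accumulate 1D tensors of per-sample values
+# across batches, ignore NaNs, and reduce on `compute()`.
+#   m = AverageMetric()
+#   m.update(torch.tensor([1.0, 2.0, float("nan")]))
+#   m.update(torch.tensor([3.0]))
+#   m.compute()  # -> 2.0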
+
+
+# same as AverageMetric, but tracks all elements
+class FAverageMetric:
+ def __init__(self):
+ self._sum = 0
+ self._num_examples = 0
+ self._elements = []
+
+ def update(self, tensor):
+ self._elements += tensor.cpu().numpy().tolist()
+ assert tensor.dim() == 1, tensor.shape
+ tensor = tensor[~torch.isnan(tensor)]
+ self._sum += tensor.sum().item()
+ self._num_examples += len(tensor)
+
+ def compute(self):
+ return np.nan if self._num_examples == 0 else self._sum / self._num_examples
+
+
+class MedianMetric:
+ def __init__(self, elements=None):
+ if elements is None:
+ elements = []
+
+ self._elements = elements
+
+ def update(self, tensor):
+ assert tensor.dim() == 1, tensor.shape
+ self._elements += tensor.cpu().numpy().tolist()
+
+ def compute(self):
+ if len(self._elements) == 0:
+ return np.nan
+
+ # set nan to inf to avoid error
+ self._elements = np.array(self._elements)
+ self._elements[np.isnan(self._elements)] = np.inf
+ return np.nanmedian(self._elements)
+
+
+class PRMetric:
+ def __init__(self):
+ self.labels = []
+ self.predictions = []
+
+ @torch.no_grad()
+ def update(self, labels, predictions, mask=None):
+ assert labels.shape == predictions.shape
+ self.labels += (labels[mask] if mask is not None else labels).cpu().numpy().tolist()
+ self.predictions += (
+ (predictions[mask] if mask is not None else predictions).cpu().numpy().tolist()
+ )
+
+ @torch.no_grad()
+ def compute(self):
+ return np.array(self.labels), np.array(self.predictions)
+
+ def reset(self):
+ self.labels = []
+ self.predictions = []
+
+
+class QuantileMetric:
+ def __init__(self, q=0.05):
+ self._elements = []
+ self.q = q
+
+ def update(self, tensor):
+ assert tensor.dim() == 1
+ self._elements += tensor.cpu().numpy().tolist()
+
+ def compute(self):
+ if len(self._elements) == 0:
+ return np.nan
+ else:
+ return np.nanquantile(self._elements, self.q)
+
+
+class RecallMetric:
+ def __init__(self, ths, elements=None):
+ if elements is None:
+ elements = []
+
+ self._elements = elements
+ self.ths = ths
+
+ def update(self, tensor):
+ assert tensor.dim() == 1, tensor.shape
+ self._elements += tensor.cpu().numpy().tolist()
+
+ def compute(self):
+ # set nan to inf to avoid error
+ self._elements = np.array(self._elements)
+ self._elements[np.isnan(self._elements)] = np.inf
+
+ if isinstance(self.ths, Iterable):
+ return [self.compute_(th) for th in self.ths]
+ else:
+ return self.compute_(self.ths)
+
+ def compute_(self, th):
+ if len(self._elements) == 0:
+ return np.nan
+
+ s = (np.array(self._elements) < th).sum()
+ return s / len(self._elements)
+
+
+def compute_recall(errors):
+ num_elements = len(errors)
+ sort_idx = np.argsort(errors)
+ errors = np.array(errors.copy())[sort_idx]
+ recall = (np.arange(num_elements) + 1) / num_elements
+ return errors, recall
+
+
+def compute_auc(errors, thresholds, min_error: Optional[float] = None):
+ errors, recall = compute_recall(errors)
+
+ if min_error is not None:
+ min_index = np.searchsorted(errors, min_error, side="right")
+ min_score = min_index / len(errors)
+ recall = np.r_[min_score, min_score, recall[min_index:]]
+ errors = np.r_[0, min_error, errors[min_index:]]
+ else:
+ recall = np.r_[0, recall]
+ errors = np.r_[0, errors]
+
+ aucs = []
+ for t in thresholds:
+ last_index = np.searchsorted(errors, t, side="right")
+ r = np.r_[recall[:last_index], recall[last_index - 1]]
+ e = np.r_[errors[:last_index], t]
+ auc = np.trapz(r, x=e) / t
+ aucs.append(np.round(auc, 4))
+ return aucs
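+
+# Worked sketch (illustrative only): AUC of the error/recall curve, normalized by each threshold.
+#   compute_auc([0.5, 1.5, 3.0, 10.0], thresholds=[1, 5])  # -> [0.1875, 0.575]
+# Larger thresholds admit more of the sorted errors, so the normalized area grows.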
+
+
+class AUCMetric:
+ def __init__(self, thresholds, elements=None, min_error: Optional[float] = None):
+ self._elements = elements if elements is not None else []
+ self.thresholds = thresholds
+ self.min_error = min_error
+ if not isinstance(thresholds, list):
+ self.thresholds = [thresholds]
+
+ def update(self, tensor):
+ assert tensor.dim() == 1, tensor.shape
+ self._elements += tensor.cpu().numpy().tolist()
+
+ def compute(self):
+ if len(self._elements) == 0:
+ return np.nan
+
+ # set nan to inf to avoid error
+ self._elements = np.array(self._elements)
+ self._elements[np.isnan(self._elements)] = np.inf
+ return compute_auc(self._elements, self.thresholds, self.min_error)
+
+
+class Timer(object):
+ """A simpler timer context object.
+ Usage:
+ ```
+ > with Timer('mytimer'):
+ > # some computations
+ [mytimer] Elapsed: X
+ ```
+ """
+
+ def __init__(self, name=None):
+ self.name = name
+
+ def __enter__(self):
+ self.tstart = time.time()
+ return self
+
+ def __exit__(self, type, value, traceback):
+ self.duration = time.time() - self.tstart
+ if self.name is not None:
+ print(f"[{self.name}] Elapsed: {self.duration}")
+
+
+def get_class(mod_path, BaseClass):
+ """Get the class object which inherits from BaseClass and is defined in
+ the module named mod_name, child of base_path.
+ """
+ import inspect
+
+ mod = __import__(mod_path, fromlist=[""])
+ classes = inspect.getmembers(mod, inspect.isclass)
+ # Filter classes defined in the module
+ classes = [c for c in classes if c[1].__module__ == mod_path]
+ # Filter classes inheriting from BaseClass
+ classes = [c for c in classes if issubclass(c[1], BaseClass)]
+ assert len(classes) == 1, classes
+ return classes[0][1]
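+
+# Usage sketch (illustrative only; the module path below is hypothetical):
+#   ModelClass = get_class("siclib.models.my_model", BaseModel)
+# This assumes exactly one subclass of `BaseModel` is defined in that module.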
+
+
+def set_num_threads(nt):
+ """Force numpy and other libraries to use a limited number of threads."""
+ try:
+ import mkl # type: ignore
+ except ImportError:
+ pass
+ else:
+ mkl.set_num_threads(nt)
+ torch.set_num_threads(1)
+ os.environ["IPC_ENABLE"] = "1"
+ for o in [
+ "OPENBLAS_NUM_THREADS",
+ "NUMEXPR_NUM_THREADS",
+ "OMP_NUM_THREADS",
+ "MKL_NUM_THREADS",
+ ]:
+ os.environ[o] = str(nt)
+
+
+def set_seed(seed):
+ random.seed(seed)
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def get_random_state(with_cuda):
+ pth_state = torch.get_rng_state()
+ np_state = np.random.get_state()
+ py_state = random.getstate()
+ if torch.cuda.is_available() and with_cuda:
+ cuda_state = torch.cuda.get_rng_state_all()
+ else:
+ cuda_state = None
+ return pth_state, np_state, py_state, cuda_state
+
+
+def set_random_state(state):
+ pth_state, np_state, py_state, cuda_state = state
+ torch.set_rng_state(pth_state)
+ np.random.set_state(np_state)
+ random.setstate(py_state)
+ if (
+ cuda_state is not None
+ and torch.cuda.is_available()
+ and len(cuda_state) == torch.cuda.device_count()
+ ):
+ torch.cuda.set_rng_state_all(cuda_state)
+
+
+@contextmanager
+def fork_rng(seed=None, with_cuda=True):
+ state = get_random_state(with_cuda)
+ if seed is not None:
+ set_seed(seed)
+ try:
+ yield
+ finally:
+ set_random_state(state)
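+
+# Usage sketch (illustrative only): run a block with a fixed seed without disturbing the
+# global python / numpy / torch (and CUDA) RNG state.
+#   with fork_rng(seed=42):
+#       noise = torch.randn(3)  # reproducible
+#   torch.randn(3)  # resumes from the previous global state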
+
+
+def get_device() -> str:
+ device = "cpu"
+ if torch.cuda.is_available():
+ device = "cuda"
+ elif torch.backends.mps.is_available():
+ device = "mps"
+ return device
diff --git a/siclib/visualization/__init__.py b/siclib/visualization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/siclib/visualization/global_frame.py b/siclib/visualization/global_frame.py
new file mode 100644
index 0000000000000000000000000000000000000000..f89ce91406d442ef388a573effd4f9f423250c0b
--- /dev/null
+++ b/siclib/visualization/global_frame.py
@@ -0,0 +1,282 @@
+import functools
+import traceback
+from copy import deepcopy
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.widgets import Button
+from omegaconf import OmegaConf
+
+from ..datasets.base_dataset import collate
+from ..models.cache_loader import CacheLoader
+from .tools import RadioHideTool
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class GlobalFrame:
+ default_conf = {
+ "x": "???",
+ "y": "???",
+ "diff": False,
+ "child": {},
+ "remove_outliers": False,
+ }
+
+ child_frame = None # MatchFrame
+
+ childs = []
+
+ lines = []
+
+ scatters = {}
+
+ def __init__(self, conf, results, loader, predictions, title=None, child_frame=None):
+ self.child_frame = child_frame
+ if self.child_frame is not None:
+ # We do NOT merge inside the child frame to keep settings across figs
+ self.default_conf["child"] = self.child_frame.default_conf
+
+ self.conf = OmegaConf.merge(self.default_conf, conf)
+ self.results = results
+ self.loader = loader
+ self.predictions = predictions
+ self.metrics = set()
+ for k, v in results.items():
+ self.metrics.update(v.keys())
+ self.metrics = sorted(list(self.metrics))
+
+ self.conf.x = conf["x"] or self.metrics[0]
+ self.conf.y = conf["y"] or self.metrics[1]
+
+ assert self.conf.x in self.metrics
+ assert self.conf.y in self.metrics
+
+ self.names = list(results)
+ self.fig, self.axes = self.init_frame()
+ if title is not None:
+ self.fig.canvas.manager.set_window_title(title)
+
+ self.xradios = self.fig.canvas.manager.toolmanager.add_tool(
+ "x",
+ RadioHideTool,
+ options=self.metrics,
+ callback_fn=self.update_x,
+ active=self.conf.x,
+ keymap="x",
+ )
+
+ self.yradios = self.fig.canvas.manager.toolmanager.add_tool(
+ "y",
+ RadioHideTool,
+ options=self.metrics,
+ callback_fn=self.update_y,
+ active=self.conf.y,
+ keymap="y",
+ )
+ if self.fig.canvas.manager.toolbar is not None:
+ self.fig.canvas.manager.toolbar.add_tool("x", "navigation")
+ self.fig.canvas.manager.toolbar.add_tool("y", "navigation")
+
+ def init_frame(self):
+ """initialize frame"""
+ fig, ax = plt.subplots()
+ ax.set_title("click on points")
+ diffb_ax = fig.add_axes([0.01, 0.02, 0.12, 0.06])
+ self.diffb = Button(diffb_ax, label="diff_only")
+ self.diffb.on_clicked(self.diff_clicked)
+ fig.canvas.mpl_connect("pick_event", self.on_scatter_pick)
+ fig.canvas.mpl_connect("motion_notify_event", self.hover)
+ return fig, ax
+
+ def draw(self):
+ """redraw content in frame"""
+ self.scatters = {}
+ self.axes.clear()
+ self.axes.set_xlabel(self.conf.x)
+ self.axes.set_ylabel(self.conf.y)
+
+ refx = 0.0
+ refy = 0.0
+ x_cat = isinstance(self.results[self.names[0]][self.conf.x][0], (bytes, str))
+ y_cat = isinstance(self.results[self.names[0]][self.conf.y][0], (bytes, str))
+
+ if self.conf.diff:
+ if not x_cat:
+ refx = np.array(self.results[self.names[0]][self.conf.x])
+ if not y_cat:
+ refy = np.array(self.results[self.names[0]][self.conf.y])
+ for name in list(self.results.keys()):
+ x = np.array(self.results[name][self.conf.x])
+ y = np.array(self.results[name][self.conf.y])
+
+ if x_cat and np.char.isdigit(x.astype(str)).all():
+ x = x.astype(int)
+ if y_cat and np.char.isdigit(y.astype(str)).all():
+ y = y.astype(int)
+
+ x = x if x_cat else x - refx
+ y = y if y_cat else y - refy
+
+ (s,) = self.axes.plot(x, y, "o", markersize=3, label=name, picker=True, pickradius=5)
+ self.scatters[name] = s
+
+ if x_cat and not y_cat:
+ xunique, ind, xinv, xbin = np.unique(
+ x, return_inverse=True, return_counts=True, return_index=True
+ )
+ ybin = np.bincount(xinv, weights=y)
+ sort_ax = np.argsort(ind)
+ self.axes.step(
+ xunique[sort_ax],
+ (ybin / xbin)[sort_ax],
+ where="mid",
+ color=s.get_color(),
+ )
+
+ if not x_cat:
+ xavg = np.nan_to_num(x).mean()
+ self.axes.axvline(xavg, c=s.get_color(), zorder=1, alpha=1.0)
+ xmed = np.median(x - refx)
+ self.axes.axvline(
+ xmed,
+ c=s.get_color(),
+ zorder=0,
+ alpha=0.5,
+ linestyle="dashed",
+ visible=False,
+ )
+
+ if not y_cat:
+ yavg = np.nan_to_num(y).mean()
+ self.axes.axhline(yavg, c=s.get_color(), zorder=1, alpha=0.5)
+ ymed = np.median(y - refy)
+ self.axes.axhline(
+ ymed,
+ c=s.get_color(),
+ zorder=0,
+ alpha=0.5,
+ linestyle="dashed",
+ visible=False,
+ )
+ if x_cat and x.dtype == object and xunique.shape[0] > 5:
+ self.axes.set_xticklabels(xunique[sort_ax], rotation=90)
+ self.axes.legend()
+
+ def on_scatter_pick(self, handle):
+ try:
+ art = handle.artist
+ try:
+ event = handle.mouseevent.button.value
+ except AttributeError:
+ return
+ name = art.get_label()
+ ind = handle.ind[0]
+ # draw lines
+ self.spawn_child(name, ind, event=event)
+ except Exception:
+ traceback.print_exc()
+ exit(0)
+
+ def spawn_child(self, model_name, ind, event=None):
+ [line.remove() for line in self.lines]
+ self.lines = []
+
+ x_source = self.scatters[model_name].get_xdata()[ind]
+ y_source = self.scatters[model_name].get_ydata()[ind]
+ for oname in self.names:
+ xn = self.scatters[oname].get_xdata()[ind]
+ yn = self.scatters[oname].get_ydata()[ind]
+
+ (ln,) = self.axes.plot([x_source, xn], [y_source, yn], "r")
+ self.lines.append(ln)
+
+ self.fig.canvas.draw_idle()
+
+ if self.child_frame is None:
+ return
+
+ data = collate([self.loader.dataset[ind]])
+
+ preds = {
+ name: CacheLoader({"path": str(pfile), "add_data_path": False})(data)
+ for name, pfile in self.predictions.items()
+ }
+ summaries_i = {
+ name: {k: v[ind] for k, v in res.items() if k != "names"}
+ for name, res in self.results.items()
+ }
+ frame = self.child_frame(
+ self.conf.child,
+ deepcopy(data),
+ preds,
+ title=str(data["name"][0]),
+ event=event,
+ summaries=summaries_i,
+ )
+
+ frame.fig.canvas.mpl_connect(
+ "key_press_event",
+ functools.partial(self.on_childframe_key_event, frame=frame, ind=ind, event=event),
+ )
+ self.childs.append(frame)
+ self.childs[-1].fig.show()
+
+ def hover(self, event):
+ if event.inaxes != self.axes:
+ return
+
+ for _, s in self.scatters.items():
+ cont, ind = s.contains(event)
+ if cont:
+ ind = ind["ind"][0]
+ xdata, ydata = s.get_data()
+ [line.remove() for line in self.lines]
+ self.lines = []
+
+ for oname in self.names:
+ xn = self.scatters[oname].get_xdata()[ind]
+ yn = self.scatters[oname].get_ydata()[ind]
+
+ (ln,) = self.axes.plot(
+ [xdata[ind], xn],
+ [ydata[ind], yn],
+ "black",
+ zorder=0,
+ alpha=0.5,
+ )
+ self.lines.append(ln)
+ self.fig.canvas.draw_idle()
+ break
+
+ def diff_clicked(self, args):
+ self.conf.diff = not self.conf.diff
+ self.draw()
+ self.fig.canvas.draw_idle()
+
+ def update_x(self, x):
+ self.conf.x = x
+ self.draw()
+
+ def update_y(self, y):
+ self.conf.y = y
+ self.draw()
+
+ def on_childframe_key_event(self, key_event, frame, ind, event):
+ if key_event.key == "delete":
+ plt.close(frame.fig)
+ self.childs.remove(frame)
+ elif key_event.key in ["left", "right", "shift+left", "shift+right"]:
+ key = key_event.key
+ if key.startswith("shift+"):
+ key = key.replace("shift+", "")
+ else:
+ plt.close(frame.fig)
+ self.childs.remove(frame)
+ new_ind = ind + 1 if key == "right" else ind - 1
+ self.spawn_child(
+ self.names[0],
+ new_ind % len(self.loader),
+ event=event,
+ )
diff --git a/siclib/visualization/tools.py b/siclib/visualization/tools.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f28b9cb47140fb63a56b10baaec9fc7e02438ef
--- /dev/null
+++ b/siclib/visualization/tools.py
@@ -0,0 +1,472 @@
+import inspect
+import sys
+import warnings
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+from matplotlib.backend_tools import ToolToggleBase
+from matplotlib.widgets import Button, RadioButtons
+
+from siclib.geometry.camera import SimpleRadial as Camera
+from siclib.geometry.gravity import Gravity
+from siclib.geometry.perspective_fields import (
+ get_latitude_field,
+ get_perspective_field,
+ get_up_field,
+)
+from siclib.models.utils.metrics import latitude_error, up_error
+from siclib.utils.conversions import rad2deg
+from siclib.visualization.viz2d import (
+ add_text,
+ plot_confidences,
+ plot_heatmaps,
+ plot_horizon_lines,
+ plot_latitudes,
+ plot_vector_fields,
+)
+
+# flake8: noqa
+# mypy: ignore-errors
+
+with warnings.catch_warnings():
+ warnings.simplefilter("ignore")
+ plt.rcParams["toolbar"] = "toolmanager"
+
+
+class RadioHideTool(ToolToggleBase):
+ """Show lines with a given gid."""
+
+ default_keymap = "R"
+ description = "Show by gid"
+ default_toggled = False
+ radio_group = "default"
+
+ def __init__(self, *args, options=[], active=None, callback_fn=None, keymap="R", **kwargs):
+ super().__init__(*args, **kwargs)
+ self.f = 1.0
+ self.options = options
+ self.callback_fn = callback_fn
+ self.active = self.options.index(active) if active else 0
+ self.default_keymap = keymap
+
+ self.enabled = self.default_toggled
+
+ def build_radios(self):
+ w = 0.2
+ self.radios_ax = self.figure.add_axes([1.0 - w, 0.4, w, 0.5], zorder=1)
+ # self.radios_ax = self.figure.add_axes([0.5-w/2, 1.0-0.2, w, 0.2], zorder=1)
+ self.radios = RadioButtons(self.radios_ax, self.options, active=self.active)
+ self.radios.on_clicked(self.on_radio_clicked)
+
+ def enable(self, *args):
+ size = self.figure.get_size_inches()
+ size[0] *= self.f
+ self.build_radios()
+ self.figure.canvas.draw_idle()
+ self.enabled = True
+
+ def disable(self, *args):
+ size = self.figure.get_size_inches()
+ size[0] /= self.f
+ self.radios_ax.remove()
+ self.radios = None
+ self.figure.canvas.draw_idle()
+ self.enabled = False
+
+ def on_radio_clicked(self, value):
+ self.active = self.options.index(value)
+ enabled = self.enabled
+ if enabled:
+ self.disable()
+ if self.callback_fn is not None:
+ self.callback_fn(value)
+ if enabled:
+ self.enable()
+
+
+class ToggleTool(ToolToggleBase):
+ """Show lines with a given gid."""
+
+ default_keymap = "t"
+ description = "Show by gid"
+
+ def __init__(self, *args, callback_fn=None, keymap="t", **kwargs):
+ super().__init__(*args, **kwargs)
+ self.f = 1.0
+ self.callback_fn = callback_fn
+ self.default_keymap = keymap
+ self.enabled = self.default_toggled
+
+ def enable(self, *args):
+ self.callback_fn(True)
+
+ def disable(self, *args):
+ self.callback_fn(False)
+
+
+def add_whitespace_left(fig, factor):
+ w, h = fig.get_size_inches()
+ left = fig.subplotpars.left
+ fig.set_size_inches([w * (1 + factor), h])
+ fig.subplots_adjust(left=(factor + left) / (1 + factor))
+
+
+def add_whitespace_bottom(fig, factor):
+ w, h = fig.get_size_inches()
+ b = fig.subplotpars.bottom
+ fig.set_size_inches([w, h * (1 + factor)])
+ fig.subplots_adjust(bottom=(factor + b) / (1 + factor))
+ fig.canvas.draw_idle()
+
+
+class ImagePlot:
+ plot_name = "image"
+ required_keys = ["image"]
+
+ def __init__(self, fig, axes, data, preds):
+ pass
+
+
+class HorizonLinePlot:
+ plot_name = "horizon_line"
+ required_keys = ["camera", "gravity"]
+
+ def __init__(self, fig, axes, data, preds):
+ for idx, name in enumerate(preds):
+ pred = preds[name]
+ gt_cam = data["camera"][0].detach().cpu()
+ gt_gravity = data["gravity"][0].detach().cpu()
+ plot_horizon_lines([gt_cam], [gt_gravity], line_colors="r", ax=[axes[0][idx]])
+
+ if "camera" in pred and "gravity" in pred:
+ pred_cam = Camera(pred["camera"][0].detach().cpu())
+ gravity = Gravity(pred["gravity"][0].detach().cpu())
+ plot_horizon_lines([pred_cam], [gravity], line_colors="yellow", ax=[axes[0][idx]])
+
+
+class LatitudePlot:
+ plot_name = "latitude"
+ required_keys = ["latitude_field"]
+
+ def __init__(self, fig, axes, data, preds):
+ self.artists = []
+ self.gt_mode = False # Flag to track whether to display ground truth or predicted
+ self.text_objects = [] # To store text objects
+
+ self.fig = fig
+ self.axes = axes
+ self.data = data
+ self.preds = preds
+
+ # Create a toggle button on the lower left corner of the first axis
+ self.ax_button = self.fig.add_axes([0.01, 0.02, 0.2, 0.06])
+ self.button = Button(self.ax_button, "Toggle GT")
+ self.button.on_clicked(self.toggle_display)
+
+ self.update_plot()
+
+ def toggle_display(self, event):
+ # Toggle between ground truth and predicted latitudes
+ self.gt_mode = not self.gt_mode
+ self.update_plot()
+
+ def update_plot(self):
+ for x in self.artists:
+ x.remove()
+ for text in self.text_objects:
+ text.remove()
+
+ self.artists = []
+ self.text_objects = []
+
+ for idx, name in enumerate(self.preds):
+ pred = self.preds[name]
+
+ if self.gt_mode:
+ latitude = self.data["latitude_field"][0][0]
+ text = "\nGT"
+ else:
+ if "latitude_field" not in pred:
+ continue
+ latitude = pred["latitude_field"][0][0]
+ text = "\nPrediction"
+
+ self.artists += plot_latitudes([latitude], axes=[self.axes[0][idx]])
+
+ self.text_objects.append(add_text(idx, text))
+
+ # Update the plot
+ self.fig.canvas.draw()
+
+ def clear(self):
+ # Remove the button
+ self.button.disconnect_events()
+ self.ax_button.remove()
+
+ for x in self.artists:
+ x.remove()
+ for text in self.text_objects:
+ text.remove()
+
+ self.artists = []
+ self.text_objects = []
+
+
+class LatitudeErrorPlot:
+ plot_name = "latitude_error"
+ required_keys = ["latitude_field"]
+
+ def __init__(self, fig, axes, data, preds):
+ self.artists = []
+ for idx, name in enumerate(preds):
+ pred = preds[name]
+ gt = data["latitude_field"].detach().cpu()
+
+ if "latitude_field" in pred:
+ lat = pred["latitude_field"].detach().cpu()
+ error = latitude_error(lat, gt)[0].numpy()
+
+ if "latitude_confidence" in pred:
+ confidence = pred["latitude_confidence"].detach().cpu().numpy()
+ confidence = np.log10(confidence).clip(-5)
+ confidence = (confidence + 5) / (confidence.max() + 5)
+ arts = plot_heatmaps(
+ [error], cmap="turbo", axes=[axes[0][idx]], colorbar=True, a=confidence
+ )
+ else:
+ arts = plot_heatmaps([error], cmap="turbo", axes=[axes[0][idx]], colorbar=True)
+ self.artists += arts
+
+ def clear(self):
+ for x in self.artists:
+ x.remove()
+ x.colorbar.remove()
+
+ self.artists = []
+
+
+class LatitudeConfidencePlot:
+ plot_name = "latitude_confidence"
+ required_keys = []
+ # required_keys = ["latitude_confidence"]
+
+ def __init__(self, fig, axes, data, preds):
+ self.artists = []
+ for idx, name in enumerate(preds):
+ pred = preds[name]
+
+ if "latitude_confidence" in pred:
+ arts = plot_confidences([pred["latitude_confidence"][0]], axes=[axes[0][idx]])
+ self.artists += arts
+
+ def clear(self):
+ for x in self.artists:
+ x.remove()
+ x.colorbar.remove()
+
+ self.artists = []
+
+
+class UpPlot:
+ plot_name = "up"
+ required_keys = ["up_field"]
+
+ def __init__(self, fig, axes, data, preds):
+ self.artists = []
+ self.gt_mode = False # Flag to track whether to display ground truth or predicted
+ self.text_objects = [] # To store text objects
+
+ self.fig = fig
+ self.axes = axes
+ self.data = data
+ self.preds = preds
+
+ # Create a toggle button on the lower left corner of the first axis
+ self.ax_button = self.fig.add_axes([0.01, 0.02, 0.2, 0.06])
+ self.button = Button(self.ax_button, "Toggle GT")
+ self.button.on_clicked(self.toggle_display)
+
+ self.update_plot()
+
+ def toggle_display(self, event):
+ # Toggle between ground truth and predicted up fields
+ self.gt_mode = not self.gt_mode
+ self.update_plot()
+
+ def update_plot(self):
+ for x in self.artists:
+ x.remove()
+ for text in self.text_objects:
+ text.remove()
+
+ self.artists = []
+ self.text_objects = []
+
+ for idx, name in enumerate(self.preds):
+ pred = self.preds[name]
+
+ if self.gt_mode:
+ up = self.data["up_field"][0]
+ text = "\nGT"
+ else:
+ if "up_field" not in pred:
+ continue
+ up = pred["up_field"][0]
+ text = "\nPrediction"
+
+ # Plot up
+ self.artists += plot_vector_fields([up], axes=[self.axes[0][idx]])
+
+ self.text_objects.append(add_text(idx, text))
+
+ # Update the plot
+ self.fig.canvas.draw()
+
+ def clear(self):
+ # Remove the button
+ self.button.disconnect_events()
+ self.ax_button.remove()
+
+ for x in self.artists:
+ x.remove()
+ for text in self.text_objects:
+ text.remove()
+
+ self.artists = []
+ self.text_objects = []
+
+
+class UpErrorPlot:
+ plot_name = "up_error"
+ required_keys = ["up_field"]
+
+ def __init__(self, fig, axes, data, preds):
+ self.artists = []
+ for idx, name in enumerate(preds):
+ pred = preds[name]
+ gt = data["up_field"].detach().cpu()
+
+ if "up_field" in pred:
+ up = pred["up_field"].detach().cpu()
+ error = up_error(up, gt)[0].numpy()
+
+ if "up_confidence" in pred:
+ confidence = pred["up_confidence"].detach().cpu().numpy()
+ confidence = np.log10(confidence).clip(-5)
+ confidence = (confidence + 5) / (confidence.max() + 5)
+ arts = plot_heatmaps(
+ [error], cmap="turbo", axes=[axes[0][idx]], colorbar=True, a=confidence
+ )
+ else:
+ arts = plot_heatmaps([error], cmap="turbo", axes=[axes[0][idx]], colorbar=True)
+ self.artists += arts
+
+ def clear(self):
+ for x in self.artists:
+ x.remove()
+ x.colorbar.remove()
+
+ self.artists = []
+
+
+class UpConfidencePlot:
+ plot_name = "up_confidence"
+ required_keys = []
+ # required_keys = ["up_confidence"]
+
+ def __init__(self, fig, axes, data, preds):
+ self.artists = []
+ for idx, name in enumerate(preds):
+ pred = preds[name]
+
+ if "up_confidence" in pred:
+ arts = plot_confidences([pred["up_confidence"][0]], axes=[axes[0][idx]])
+ self.artists += arts
+
+ def clear(self):
+ for x in self.artists:
+ x.remove()
+ x.colorbar.remove()
+
+ self.artists = []
+
+
+class PerspectiveField:
+ plot_name = "perspective_field"
+ required_keys = ["camera", "gravity"]
+
+ def __init__(self, fig, axes, data, preds):
+ self.artists = []
+ self.gt_mode = False # Flag to track whether to display ground truth or predicted
+ self.text_objects = [] # To store text objects
+
+ self.fig = fig
+ self.axes = axes
+ self.data = data
+ self.preds = preds
+
+ # Create a toggle button on the lower left corner of the first axis
+ self.ax_button = self.fig.add_axes([0.01, 0.02, 0.2, 0.06])
+ self.button = Button(self.ax_button, "Toggle GT")
+ self.button.on_clicked(self.toggle_display)
+
+ self.update_plot()
+
+ def toggle_display(self, event):
+ # Toggle between ground truth and predicted perspective fields
+ self.gt_mode = not self.gt_mode
+ self.update_plot()
+
+ def update_plot(self):
+ for x in self.artists:
+ x.remove()
+ for text in self.text_objects:
+ text.remove()
+
+ self.artists = []
+ self.text_objects = []
+
+ for idx, name in enumerate(self.preds):
+ pred = self.preds[name]
+
+ if self.gt_mode:
+ camera = self.data["camera"]
+ gravity = self.data["gravity"]
+ text = "\nGT"
+ else:
+ camera = pred["camera"]
+ gravity = pred["gravity"]
+ text = "\nPrediction"
+ camera = Camera(camera)
+ gravity = Gravity(gravity)
+
+ up, latitude = get_perspective_field(camera, gravity)
+
+ self.artists += plot_latitudes([latitude[0, 0]], axes=[self.axes[0][idx]])
+ self.artists += plot_vector_fields([up[0]], axes=[self.axes[0][idx]])
+
+ self.text_objects.append(add_text(idx, text))
+
+ # Update the plot
+ self.fig.canvas.draw()
+
+ def clear(self):
+ # Remove the button
+ self.button.disconnect_events()
+ self.ax_button.remove()
+
+ for x in self.artists:
+ x.remove()
+ for text in self.text_objects:
+ text.remove()
+
+ self.artists = []
+ self.text_objects = []
+
+
+__plot_dict__ = {
+ obj.plot_name: obj
+ for _, obj in inspect.getmembers(sys.modules[__name__], predicate=inspect.isclass)
+ if hasattr(obj, "plot_name")
+}
diff --git a/siclib/visualization/two_view_frame.py b/siclib/visualization/two_view_frame.py
new file mode 100644
index 0000000000000000000000000000000000000000..bff49861d536d0455f9bbf44b318ec37e18987e7
--- /dev/null
+++ b/siclib/visualization/two_view_frame.py
@@ -0,0 +1,139 @@
+import pprint
+
+import numpy as np
+
+from . import viz2d
+from .tools import RadioHideTool, ToggleTool, __plot_dict__
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+class FormatPrinter(pprint.PrettyPrinter):
+ def __init__(self, formats):
+ super(FormatPrinter, self).__init__()
+ self.formats = formats
+
+ def format(self, obj, ctx, maxlvl, lvl):
+ if type(obj) in self.formats:
+ return self.formats[type(obj)] % obj, 1, 0
+ return pprint.PrettyPrinter.format(self, obj, ctx, maxlvl, lvl)
+
+
+class TwoViewFrame:
+ default_conf = {
+ "default": "image",
+ "summary_visible": False,
+ }
+
+ plot_dict = __plot_dict__
+
+ childs = []
+
+ event_to_image = [None, "image", "horizon_line", "lat_pred", "lat_gt"]
+
+ def __init__(self, conf, data, preds, title=None, event=1, summaries=None):
+ self.conf = conf
+ self.data = data
+ self.preds = preds
+ self.names = list(preds.keys())
+ self.plot = self.event_to_image[event]
+ self.summaries = summaries
+ self.fig, self.axes, self.summary_arts = self.init_frame()
+ if title is not None:
+ self.fig.canvas.manager.set_window_title(title)
+
+ keys = None
+ for _, pred in preds.items():
+ keys = set(pred.keys()) if keys is None else keys.intersection(pred.keys())
+
+ keys = keys.union(data.keys())
+
+ self.options = [k for k, v in self.plot_dict.items() if set(v.required_keys).issubset(keys)]
+ self.handle = None
+ self.radios = self.fig.canvas.manager.toolmanager.add_tool(
+ "switch plot",
+ RadioHideTool,
+ options=self.options,
+ callback_fn=self.draw,
+ active=conf.default,
+ keymap="R",
+ )
+
+ self.toggle_summary = self.fig.canvas.manager.toolmanager.add_tool(
+ "toggle summary",
+ ToggleTool,
+ toggled=self.conf.summary_visible,
+ callback_fn=self.set_summary_visible,
+ keymap="t",
+ )
+
+ if self.fig.canvas.manager.toolbar is not None:
+ self.fig.canvas.manager.toolbar.add_tool("switch plot", "navigation")
+ self.draw(conf.default)
+
+ def init_frame(self):
+ """initialize frame"""
+ imgs = [[self.data["image"][0].permute(1, 2, 0) for _ in self.names]]
+ # imgs = [imgs for _ in self.names] # repeat for each model
+
+ fig, axes = viz2d.plot_image_grid(imgs, return_fig=True, titles=None, figs=5)
+ [viz2d.add_text(i, n, axes=axes[0]) for i, n in enumerate(self.names)]
+
+ fig.canvas.mpl_connect("pick_event", self.click_artist)
+ if self.summaries is not None:
+ font_size = 7
+ formatter = FormatPrinter({np.float32: "%.4f", np.float64: "%.4f"})
+ toggle_artists = [
+ viz2d.add_text(
+ i,
+ formatter.pformat(self.summaries[n]),
+ axes=axes[0],
+ pos=(0.01, 0.01),
+ va="bottom",
+ backgroundcolor=(0, 0, 0, 0.5),
+ visible=self.conf.summary_visible,
+ fs=font_size,
+ )
+ for i, n in enumerate(self.names)
+ ]
+ else:
+ toggle_artists = []
+ return fig, axes, toggle_artists
+
+ def draw(self, value):
+ """redraw content in frame"""
+ self.clear()
+ self.conf.default = value
+ self.handle = self.plot_dict[value](self.fig, self.axes, self.data, self.preds)
+ return self.handle
+
+ def clear(self):
+ if self.handle is not None:
+ try:
+ self.handle.clear()
+ except AttributeError:
+ pass
+ self.handle = None
+ for row in self.axes:
+ for ax in row:
+ [li.remove() for li in ax.lines]
+ [c.remove() for c in ax.collections]
+ self.fig.artists.clear()
+ self.fig.canvas.draw_idle()
+ self.handle = None
+
+ def click_artist(self, event):
+ art = event.artist
+ select = art.get_arrowstyle().arrow == "-"
+ art.set_arrowstyle("<|-|>" if select else "-")
+ if select:
+ art.set_zorder(1)
+ if hasattr(self.handle, "click_artist"):
+ self.handle.click_artist(event)
+ self.fig.canvas.draw_idle()
+
+ def set_summary_visible(self, visible):
+ self.conf.summary_visible = visible
+ [s.set_visible(visible) for s in self.summary_arts]
+ self.fig.canvas.draw_idle()
diff --git a/siclib/visualization/visualize_batch.py b/siclib/visualization/visualize_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..47d59d2bde94d4af4821165351714f1d2ab6be21
--- /dev/null
+++ b/siclib/visualization/visualize_batch.py
@@ -0,0 +1,191 @@
+"""Visualization of predicted and ground truth for a single batch."""
+
+from typing import Any, Dict
+
+import numpy as np
+import torch
+
+from siclib.geometry.perspective_fields import get_latitude_field
+from siclib.models.utils.metrics import latitude_error, up_error
+from siclib.utils.conversions import rad2deg
+from siclib.utils.tensor import batch_to_device
+from siclib.visualization.viz2d import (
+ plot_confidences,
+ plot_heatmaps,
+ plot_image_grid,
+ plot_latitudes,
+ plot_vector_fields,
+)
+
+
+def make_up_figure(
+ pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
+) -> Dict[str, Any]:
+ """Get predicted and ground truth up fields and errors.
+
+ Args:
+ pred (Dict[str, torch.Tensor]): Predicted up field.
+ data (Dict[str, torch.Tensor]): Ground truth up field.
+ n_pairs (int): Number of pairs to visualize.
+
+ Returns:
+ Dict[str, Any]: Dictionary with figure.
+ """
+ pred = batch_to_device(pred, "cpu", detach=True)
+ data = batch_to_device(data, "cpu", detach=True)
+
+ n_pairs = min(n_pairs, len(data["image"]))
+
+ if "up_field" not in pred.keys():
+ return {}
+
+ errors = up_error(pred["up_field"], data["up_field"])
+
+ up_fields = []
+ for i in range(n_pairs):
+ row = [data["up_field"][i], pred["up_field"][i], errors[i]]
+ titles = ["Up GT", "Up Pred", "Up Error"]
+
+ if "up_confidence" in pred.keys():
+ row += [pred["up_confidence"][i]]
+ titles += ["Up Confidence"]
+
+ row = [r.float().numpy() if isinstance(r, torch.Tensor) else r for r in row]
+ up_fields.append(row)
+
+ # create figure
+ N, M = len(up_fields), len(up_fields[0]) + 1
+ imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)]
+ fig, ax = plot_image_grid(imgs, titles=[["Image"] + titles] * N, return_fig=True, set_lim=True)
+ ax = np.array(ax)
+
+ for i in range(n_pairs):
+ plot_vector_fields(up_fields[i][:2], axes=ax[i, [1, 2]])
+ plot_heatmaps([up_fields[i][2]], cmap="turbo", colorbar=True, axes=ax[i, [3]])
+
+ if "up_confidence" in pred.keys():
+ plot_confidences([up_fields[i][3]], axes=ax[i, [4]])
+
+ return {"up": fig}
+
+
+def make_latitude_figure(
+ pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
+) -> Dict[str, Any]:
+ """Get predicted and ground truth latitude fields and errors.
+
+ Args:
+ pred (Dict[str, torch.Tensor]): Predicted latitude field.
+ data (Dict[str, torch.Tensor]): Ground truth latitude field.
+ n_pairs (int, optional): Number of pairs to visualize. Defaults to 2.
+
+ Returns:
+ Dict[str, Any]: Dictionary with figure.
+ """
+ pred = batch_to_device(pred, "cpu", detach=True)
+ data = batch_to_device(data, "cpu", detach=True)
+
+ n_pairs = min(n_pairs, len(data["image"]))
+ latitude_fields = []
+
+ if "latitude_field" not in pred.keys():
+ return {}
+
+ errors = latitude_error(pred["latitude_field"], data["latitude_field"])
+ for i in range(n_pairs):
+ row = [
+ rad2deg(data["latitude_field"][i][0]),
+ rad2deg(pred["latitude_field"][i][0]),
+ errors[i],
+ ]
+ titles = ["Latitude GT", "Latitude Pred", "Latitude Error"]
+
+ if "latitude_confidence" in pred.keys():
+ row += [pred["latitude_confidence"][i]]
+ titles += ["Latitude Confidence"]
+
+ row = [r.float().numpy() if isinstance(r, torch.Tensor) else r for r in row]
+ latitude_fields.append(row)
+
+ # create figure
+ N, M = len(latitude_fields), len(latitude_fields[0]) + 1
+ imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)]
+ fig, ax = plot_image_grid(imgs, titles=[["Image"] + titles] * N, return_fig=True, set_lim=True)
+ ax = np.array(ax)
+
+ for i in range(n_pairs):
+ plot_latitudes(latitude_fields[i][:2], is_radians=False, axes=ax[i, [1, 2]])
+ plot_heatmaps([latitude_fields[i][2]], cmap="turbo", colorbar=True, axes=ax[i, [3]])
+
+ if "latitude_confidence" in pred.keys():
+ plot_confidences([latitude_fields[i][3]], axes=ax[i, [4]])
+
+ return {"latitude": fig}
+
+
+def make_camera_figure(
+ pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
+) -> Dict[str, Any]:
+ """Get predicted and ground truth camera parameters.
+
+ Args:
+ pred (Dict[str, torch.Tensor]): Predicted camera parameters.
+ data (Dict[str, torch.Tensor]): Ground truth camera parameters.
+ n_pairs (int, optional): Number of pairs to visualize. Defaults to 2.
+
+ Returns:
+ Dict[str, Any]: Dictionary with figure.
+ """
+ pred = batch_to_device(pred, "cpu", detach=True)
+ data = batch_to_device(data, "cpu", detach=True)
+
+ n_pairs = min(n_pairs, len(data["image"]))
+
+ if "camera" not in pred.keys():
+ return {}
+
+ latitudes = []
+ for i in range(n_pairs):
+ titles = ["Cameras GT"]
+ row = [get_latitude_field(data["camera"][i], data["gravity"][i])]
+
+ if "camera" in pred.keys() and "gravity" in pred.keys():
+ row += [get_latitude_field(pred["camera"][i], pred["gravity"][i])]
+ titles += ["Cameras Pred"]
+
+ row = [rad2deg(r).squeeze(-1).float().numpy()[0] for r in row]
+ latitudes.append(row)
+
+ # create figure
+ N, M = len(latitudes), len(latitudes[0]) + 1
+ imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)]
+ fig, ax = plot_image_grid(imgs, titles=[["Image"] + titles] * N, return_fig=True, set_lim=True)
+ ax = np.array(ax)
+
+ for i in range(n_pairs):
+ plot_latitudes(latitudes[i], is_radians=False, axes=ax[i, 1:])
+
+ return {"camera": fig}
+
+
+def make_perspective_figures(
+ pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2
+) -> Dict[str, Any]:
+ """Get predicted and ground truth perspective fields.
+
+ Args:
+ pred (Dict[str, torch.Tensor]): Predicted perspective fields.
+ data (Dict[str, torch.Tensor]): Ground truth perspective fields.
+ n_pairs (int, optional): Number of pairs to visualize. Defaults to 2.
+
+ Returns:
+ Dict[str, Any]: Dictionary with figure.
+ """
+ n_pairs = min(n_pairs, len(data["image"]))
+ figures = make_up_figure(pred, data, n_pairs)
+ figures |= make_latitude_figure(pred, data, n_pairs)
+ figures |= make_camera_figure(pred, data, n_pairs)
+
+ for fig in figures.values():
+ fig.tight_layout()
+
+ return figures
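+
+# Usage sketch (illustrative only): during evaluation, assuming `pred` and `data` are batched
+# dicts holding "image", "up_field", "latitude_field", "camera" and "gravity":
+#   figures = make_perspective_figures(pred, data, n_pairs=2)
+#   for name, fig in figures.items():
+#       fig.savefig(f"{name}.png")  # or log the figures to a dashboard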
diff --git a/siclib/visualization/viz2d.py b/siclib/visualization/viz2d.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a7d1773a7416e59d91c6be9bc0adbf58e29093b
--- /dev/null
+++ b/siclib/visualization/viz2d.py
@@ -0,0 +1,520 @@
+"""
+2D visualization primitives based on Matplotlib.
+1) Plot images with `plot_images`.
+2) Overlay heatmaps, vector fields, latitudes, or confidences with the `plot_*` helpers.
+3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
+"""
+
+import matplotlib.patheffects as path_effects
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+
+from siclib.geometry.perspective_fields import get_perspective_field
+from siclib.utils.conversions import rad2deg
+
+# flake8: noqa
+# mypy: ignore-errors
+
+
+def cm_ranking(sc, ths=None):
+ if ths is None:
+ ths = [512, 1024, 2048, 4096]
+
+ ls = sc.shape[0]
+ colors = ["red", "yellow", "lime", "cyan", "blue"]
+ out = ["gray"] * ls
+ for i in range(ls):
+ for c, th in zip(colors[: len(ths) + 1], ths + [ls]):
+ if i < th:
+ out[i] = c
+ break
+ sid = np.argsort(sc, axis=0)[::-1]
+ return np.array(out)[sid]
+
+
+def cm_RdBl(x):
+ """Custom colormap: red (0) -> yellow (0.5) -> green (1)."""
+ x = np.clip(x, 0, 1)[..., None] * 2
+ c = x * np.array([[0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0]])
+ return np.clip(c, 0, 1)
+
+
+def cm_RdGn(x):
+ """Custom colormap: red (0) -> yellow (0.5) -> green (1)."""
+ x = np.clip(x, 0, 1)[..., None] * 2
+ c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]])
+ return np.clip(c, 0, 1)
+
+
+def cm_BlRdGn(x_):
+ """Custom colormap: blue (-1) -> red (0.0) -> green (1)."""
+ x = np.clip(x_, 0, 1)[..., None] * 2
+ c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]])
+
+ xn = -np.clip(x_, -1, 0)[..., None] * 2
+ cn = xn * np.array([[0, 0.1, 1.0, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]])
+ return np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
+
+
+def plot_images(imgs, titles=None, cmaps="gray", dpi=200, pad=0.5, adaptive=True):
+ """Plot a list of images.
+
+ Args:
+ imgs (List[np.ndarray]): List of images to plot.
+ titles (List[str], optional): Titles. Defaults to None.
+ cmaps (str, optional): Colormaps. Defaults to "gray".
+ dpi (int, optional): Dots per inch. Defaults to 200.
+ pad (float, optional): Padding. Defaults to 0.5.
+ adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True.
+
+ Returns:
+ plt.Figure: Figure of the images.
+ """
+ n = len(imgs)
+ if not isinstance(cmaps, (list, tuple)):
+ cmaps = [cmaps] * n
+
+ ratios = [i.shape[1] / i.shape[0] for i in imgs] if adaptive else [4 / 3] * n
+ figsize = [sum(ratios) * 4.5, 4.5]
+ fig, axs = plt.subplots(1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios})
+ if n == 1:
+ axs = [axs]
+ for i, (img, ax) in enumerate(zip(imgs, axs)):
+ ax.imshow(img, cmap=plt.get_cmap(cmaps[i]))
+ ax.set_axis_off()
+ if titles:
+ ax.set_title(titles[i])
+ fig.tight_layout(pad=pad)
+
+ return fig
+
+
+def plot_image_grid(
+ imgs,
+ titles=None,
+ cmaps="gray",
+ dpi=100,
+ pad=0.5,
+ fig=None,
+ adaptive=True,
+ figs=3.0,
+ return_fig=False,
+ set_lim=False,
+) -> plt.Figure:
+ """Plot a grid of images.
+
+ Args:
+ imgs (List[np.ndarray]): List of images to plot.
+ titles (List[str], optional): Titles. Defaults to None.
+ cmaps (str, optional): Colormaps. Defaults to "gray".
+ dpi (int, optional): Dots per inch. Defaults to 100.
+ pad (float, optional): Padding. Defaults to 0.5.
+ fig (_type_, optional): Figure to plot on. Defaults to None.
+ adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True.
+ figs (float, optional): Figure size. Defaults to 3.0.
+ return_fig (bool, optional): Whether to return the figure. Defaults to False.
+ set_lim (bool, optional): Whether to set the limits. Defaults to False.
+
+ Returns:
+ plt.Figure: Figure and axes if return_fig is True, otherwise only the axes.
+ """
+ nr, n = len(imgs), len(imgs[0])
+ if not isinstance(cmaps, (list, tuple)):
+ cmaps = [cmaps] * n
+
+ if adaptive:
+ ratios = [i.shape[1] / i.shape[0] for i in imgs[0]] # W / H
+ else:
+ ratios = [4 / 3] * n
+
+ figsize = [sum(ratios) * figs, nr * figs]
+ if fig is None:
+ fig, axs = plt.subplots(
+ nr, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
+ )
+ else:
+ axs = fig.subplots(nr, n, gridspec_kw={"width_ratios": ratios})
+ fig.figure.set_size_inches(figsize)
+
+ if nr == 1 and n == 1:
+ axs = [[axs]]
+ elif n == 1:
+ axs = axs[:, None]
+ elif nr == 1:
+ axs = [axs]
+
+ for j in range(nr):
+ for i in range(n):
+ ax = axs[j][i]
+ ax.imshow(imgs[j][i], cmap=plt.get_cmap(cmaps[i]))
+ ax.set_axis_off()
+ if set_lim:
+ ax.set_xlim([0, imgs[j][i].shape[1]])
+ ax.set_ylim([imgs[j][i].shape[0], 0])
+ if titles:
+ ax.set_title(titles[j][i])
+ if isinstance(fig, plt.Figure):
+ fig.tight_layout(pad=pad)
+ return (fig, axs) if return_fig else axs
+
+
+def add_text(
+ idx,
+ text,
+ pos=(0.01, 0.99),
+ fs=15,
+ color="w",
+ lcolor="k",
+ lwidth=4,
+ ha="left",
+ va="top",
+ axes=None,
+ **kwargs,
+):
+ """Add text to a plot.
+
+ Args:
+ idx (int): Index of the axes.
+ text (str): Text to add.
+ pos (tuple, optional): Text position. Defaults to (0.01, 0.99).
+ fs (int, optional): Font size. Defaults to 15.
+ color (str, optional): Text color. Defaults to "w".
+ lcolor (str, optional): Line color. Defaults to "k".
+ lwidth (int, optional): Line width. Defaults to 4.
+ ha (str, optional): Horizontal alignment. Defaults to "left".
+ va (str, optional): Vertical alignment. Defaults to "top".
+ axes (List[plt.Axes], optional): Axes to put text on. Defaults to None.
+
+ Returns:
+ plt.Text: Text object.
+ """
+ if axes is None:
+ axes = plt.gcf().axes
+
+ ax = axes[idx]
+
+ t = ax.text(
+ *pos,
+ text,
+ fontsize=fs,
+ ha=ha,
+ va=va,
+ color=color,
+ transform=ax.transAxes,
+ zorder=5,
+ **kwargs,
+ )
+ if lcolor is not None:
+ t.set_path_effects(
+ [
+ path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
+ path_effects.Normal(),
+ ]
+ )
+ return t
+
+
+def plot_heatmaps(
+ heatmaps,
+ vmin=-1e-6, # include negative zero
+ vmax=None,
+ cmap="Spectral",
+ a=0.5,
+ axes=None,
+ contours_every=None,
+ contour_style="solid",
+ colorbar=False,
+):
+ """Plot heatmaps with optional contours.
+
+ To plot latitude field, set vmin=-90, vmax=90 and contours_every=15.
+
+ Args:
+ heatmaps (List[np.ndarray | torch.Tensor]): List of 2D heatmaps.
+ vmin (float, optional): Min Value. Defaults to -1e-6.
+ vmax (float, optional): Max Value. Defaults to None.
+ cmap (str, optional): Colormap. Defaults to "Spectral".
+ a (float, optional): Alpha value. Defaults to 0.5.
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
+ contours_every (int, optional): If not none, will draw contours. Defaults to None.
+ contour_style (str, optional): Style of the contours. Defaults to "solid".
+ colorbar (bool, optional): Whether to show colorbar. Defaults to False.
+
+ Returns:
+ List[plt.Artist]: List of artists.
+ """
+ if axes is None:
+ axes = plt.gcf().axes
+ artists = []
+
+ for i in range(len(axes)):
+ a_ = a if isinstance(a, float) else a[i]
+
+ if isinstance(heatmaps[i], torch.Tensor):
+ heatmaps[i] = heatmaps[i].detach().cpu().numpy()
+
+ alpha = a_
+ # Plot the heatmap
+ art = axes[i].imshow(
+ heatmaps[i],
+ alpha=alpha,
+ vmin=vmin,
+ vmax=vmax,
+ cmap=cmap,
+ )
+ if colorbar:
+ cmax = vmax or np.percentile(heatmaps[i], 99)
+ art.set_clim(vmin, cmax)
+ cbar = plt.colorbar(art, ax=axes[i])
+ artists.append(cbar)
+
+ artists.append(art)
+
+ if contours_every is not None:
+ # Add contour lines to the heatmap
+ contour_data = np.arange(vmin, vmax + contours_every, contours_every)
+
+ # Get the colormap colors for contour lines
+ contour_colors = [
+ plt.colormaps.get_cmap(cmap)(plt.Normalize(vmin=vmin, vmax=vmax)(level))
+ for level in contour_data
+ ]
+ contours = axes[i].contour(
+ heatmaps[i],
+ levels=contour_data,
+ linewidths=2,
+ colors=contour_colors,
+ linestyles=contour_style,
+ )
+
+ contours.set_clim(vmin, vmax)
+
+ fmt = {
+ level: f"{label}°"
+ for level, label in zip(contour_data, contour_data.astype(int).astype(str))
+ }
+ t = axes[i].clabel(contours, inline=True, fmt=fmt, fontsize=16, colors="white")
+
+ for label in t:
+ label.set_path_effects(
+ [
+ path_effects.Stroke(linewidth=1, foreground="k"),
+ path_effects.Normal(),
+ ]
+ )
+ artists.append(contours)
+
+ return artists
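+
+# Usage sketch (illustrative only): overlay a latitude map in degrees with 15° contours,
+# as suggested in the docstring above (`latitude_deg` is an HxW array in [-90, 90]).
+#   plot_images([image])
+#   plot_heatmaps([latitude_deg], vmin=-90, vmax=90, contours_every=15, cmap="seismic")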
+
+
+def plot_horizon_lines(
+ cameras, gravities, line_colors="orange", lw=2, styles="solid", alpha=1.0, ax=None
+):
+ """Plot horizon lines on the perspective field.
+
+ Args:
+ cameras (List[Camera]): List of cameras.
+ gravities (List[Gravity]): Gravities.
+ line_colors (str, optional): Line Colors. Defaults to "orange".
+ lw (int, optional): Line width. Defaults to 2.
+ styles (str, optional): Line styles. Defaults to "solid".
+ alpha (float, optional): Alphas. Defaults to 1.0.
+ ax (List[plt.Axes], optional): Axes to draw horizon line on. Defaults to None.
+ """
+ if not isinstance(line_colors, list):
+ line_colors = [line_colors] * len(cameras)
+
+ if not isinstance(styles, list):
+ styles = [styles] * len(cameras)
+
+ fig = plt.gcf()
+ ax = fig.gca() if ax is None else ax
+
+ if isinstance(ax, plt.Axes):
+ ax = [ax] * len(cameras)
+
+ assert len(ax) == len(cameras), f"{len(ax)}, {len(cameras)}"
+
+ for i in range(len(cameras)):
+ _, lat = get_perspective_field(cameras[i], gravities[i])
+ # horizon line is zero level of the latitude field
+ lat = lat[0, 0].cpu().numpy()
+ contours = ax[i].contour(lat, levels=[0], linewidths=lw, colors=line_colors[i])
+ for contour_line in contours.collections:
+ contour_line.set_linestyle(styles[i])
+
+
+def plot_vector_fields(
+ vector_fields,
+ cmap="lime",
+ subsample=15,
+ scale=None,
+ lw=None,
+ alphas=0.8,
+ axes=None,
+):
+ """Plot vector fields.
+
+ Args:
+ vector_fields (List[torch.Tensor]): List of vector fields of shape (2, H, W).
+ cmap (str, optional): Color of the vectors. Defaults to "lime".
+ subsample (int, optional): Subsample the vector field. Defaults to 15.
+ scale (float, optional): Scale of the vectors. Defaults to None.
+ lw (float, optional): Line width of the vectors. Defaults to None.
+ alphas (float | np.ndarray, optional): Alpha per vector or global. Defaults to 0.8.
+ axes (List[plt.Axes], optional): List of axes to draw on. Defaults to None.
+
+ Returns:
+ List[plt.Artist]: List of artists.
+ """
+ if axes is None:
+ axes = plt.gcf().axes
+
+ vector_fields = [v.cpu().numpy() if isinstance(v, torch.Tensor) else v for v in vector_fields]
+
+ artists = []
+
+ H, W = vector_fields[0].shape[-2:]
+ if scale is None:
+ scale = subsample / min(H, W)
+
+ if lw is None:
+ lw = 0.1 / subsample
+
+ if alphas is None:
+ alphas = np.ones_like(vector_fields[0][0])
+ alphas = np.stack([alphas] * len(vector_fields), 0)
+ elif isinstance(alphas, float):
+ alphas = np.ones_like(vector_fields[0][0]) * alphas
+ alphas = np.stack([alphas] * len(vector_fields), 0)
+ else:
+ alphas = np.array(alphas)
+
+ subsample = min(W, H) // subsample
+ offset_x = ((W % subsample) + subsample) // 2
+
+ samples_x = np.arange(offset_x, W, subsample)
+ samples_y = np.arange(int(subsample * 0.9), H, subsample)
+
+ x_grid, y_grid = np.meshgrid(samples_x, samples_y)
+
+ for i in range(len(axes)):
+ # vector field of shape (2, H, W) with vectors of norm == 1
+ vector_field = vector_fields[i]
+
+ a = alphas[i][samples_y][:, samples_x]
+ x, y = vector_field[:, samples_y][:, :, samples_x]
+
+ c = cmap
+ if not isinstance(cmap, str):
+ c = cmap[i][samples_y][:, samples_x].reshape(-1, 3)
+
+ s = scale * min(H, W)
+ arrows = axes[i].quiver(
+ x_grid,
+ y_grid,
+ x,
+ y,
+ scale=s,
+ scale_units="width" if H > W else "height",
+ units="width" if H > W else "height",
+ alpha=a,
+ color=c,
+ angles="xy",
+ antialiased=True,
+ width=lw,
+ headaxislength=3.5,
+ zorder=5,
+ )
+
+ artists.append(arrows)
+
+ return artists
+
+
+def plot_latitudes(
+ latitude,
+ is_radians=True,
+ vmin=-90,
+ vmax=90,
+ cmap="seismic",
+ contours_every=15,
+ alpha=0.4,
+ axes=None,
+ **kwargs,
+):
+ """Plot latitudes.
+
+ Args:
+ latitude (List[torch.Tensor]): List of latitudes.
+ is_radians (bool, optional): Whether the latitudes are in radians. Defaults to True.
+ vmin (int, optional): Min value to clip to. Defaults to -90.
+ vmax (int, optional): Max value to clip to. Defaults to 90.
+ cmap (str, optional): Colormap. Defaults to "seismic".
+ contours_every (int, optional): Contours every. Defaults to 15.
+ alpha (float, optional): Alpha value. Defaults to 0.4.
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
+
+ Returns:
+ List[plt.Artist]: List of artists.
+ """
+ if axes is None:
+ axes = plt.gcf().axes
+
+ assert len(axes) == len(latitude), f"{len(axes)}, {len(latitude)}"
+ lat = [rad2deg(lat) for lat in latitude] if is_radians else latitude
+ return plot_heatmaps(
+ lat,
+ vmin=vmin,
+ vmax=vmax,
+ cmap=cmap,
+ a=alpha,
+ axes=axes,
+ contours_every=contours_every,
+ **kwargs,
+ )
+
+
+def plot_confidences(
+ confidence,
+ as_log=True,
+ vmin=-4,
+ vmax=0,
+ cmap="turbo",
+ alpha=0.4,
+ axes=None,
+ **kwargs,
+):
+ """Plot confidences.
+
+ Args:
+ confidence (List[torch.Tensor]): Confidence maps.
+ as_log (bool, optional): Whether to plot in log scale. Defaults to True.
+ vmin (int, optional): Min value to clip to. Defaults to -4.
+ vmax (int, optional): Max value to clip to. Defaults to 0.
+ cmap (str, optional): Colormap. Defaults to "turbo".
+ alpha (float, optional): Alpha value. Defaults to 0.4.
+ axes (List[plt.Axes], optional): Axes to plot on. Defaults to None.
+
+ Returns:
+ List[plt.Artist]: List of artists.
+ """
+ if axes is None:
+ axes = plt.gcf().axes
+
+ confidence = [c.cpu() if isinstance(c, torch.Tensor) else torch.tensor(c) for c in confidence]
+
+ assert len(axes) == len(confidence), f"{len(axes)}, {len(confidence)}"
+
+ if as_log:
+ confidence = [torch.log10(c.clip(1e-5)).clip(vmin, vmax) for c in confidence]
+
+ # normalize to [0, 1]
+ confidence = [(c - c.min()) / (c.max() - c.min()) for c in confidence]
+ return plot_heatmaps(confidence, vmin=0, vmax=1, cmap=cmap, a=alpha, axes=axes, **kwargs)
+
+
+def save_plot(path, **kw):
+ """Save the current figure without any white margin."""
+ plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
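+
+# Usage sketch (illustrative only):
+#   plot_images([img0, img1], titles=["input", "prediction"])
+#   save_plot("comparison.png", dpi=200)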