diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..41d770a277b3c57c15453c4d17d11c015a5de85a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +assets/fisheye-skyline.jpg filter=lfs diff=lfs merge=lfs -text +assets/teaser.gif filter=lfs diff=lfs merge=lfs -text +demo.ipynb filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..7a404431fe7e7e9e10bb8fb485f39331c86553aa --- /dev/null +++ b/.gitignore @@ -0,0 +1,138 @@ +# folders +data/** +outputs/** +weights/** +**.DS_Store +.vscode/** +wandb/** +third_party/** + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d369a9455e7874673bb022ccd0023c3739dc9e5b --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,35 @@ +default_stages: [commit] +default_language_version: + python: python3.10 +repos: + - repo: https://github.com/psf/black + rev: 23.9.1 + hooks: + - id: black + args: [--line-length=100] + exclude: ^(venv/|docs/) + types: [python] + - repo: https://github.com/PyCQA/flake8 + rev: 6.1.0 + hooks: + - id: flake8 + additional_dependencies: [flake8-docstrings] + args: + [ + --max-line-length=100, + --docstring-convention=google, + --ignore=E203 W503 E402 E731, + ] + exclude: ^(venv/|docs/|.*__init__.py) + types: [python] + + - repo: https://github.com/pycqa/isort + rev: 5.12.0 + hooks: + - id: isort + args: [--line-length=100, --profile=black, --atomic] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.1.1 + hooks: + - id: mypy diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..797795d81ba8a5d06fefa772bc5b4d0b4bb94dc4 --- /dev/null +++ b/LICENSE @@ -0,0 +1,190 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. 
For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + Copyright 2024 ETH Zurich + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 631be781f03494b75818fa6ea30ebc99ba6e3cca..b68d36ad523c16e239d8ca94735a8821f8f721e2 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,587 @@ --- title: GeoCalib -emoji: 👀 -colorFrom: pink -colorTo: yellow +app_file: gradio_app.py sdk: gradio -sdk_version: 4.43.0 -app_file: app.py -pinned: false +sdk_version: 4.38.1 --- +

+

GeoCalib 📸
Single-image Calibration with Geometric Optimization

+

+ Alexander Veicht + · + Paul-Edouard Sarlin + · + Philipp Lindenberger + · + Marc Pollefeys +

+

+

ECCV 2024

+ Paper | + Colab | + Demo 🤗 +

+ +

+

+ example +
+ + GeoCalib accurately estimates the camera intrinsics and gravity direction from a single image +
+ by combining geometric optimization with deep learning. +
+

+ +## + +GeoCalib is an algorithm for single-image calibration: it estimates the camera intrinsics and gravity direction from a single image only. By combining geometric optimization with deep learning, GeoCalib provides a more flexible and accurate calibration compared to previous approaches. This repository hosts the [inference](#setup-and-demo), [evaluation](#evaluation), and [training](#training) code for GeoCalib and instructions to download our training set [OpenPano](#openpano-dataset). + + +## Setup and demo + +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1oMzgPGppAPAIQxe-s7SRd_q8r7dVfnqo#scrollTo=etdzQZQzoo-K) +[![Hugging Face](https://img.shields.io/badge/Gradio-Demo-blue)](https://huggingface.co/spaces/veichta/GeoCalib) + +We provide a small inference package [`geocalib`](geocalib) that requires only minimal dependencies and Python >= 3.9. First clone the repository and install the dependencies: + +```bash +git clone https://github.com/cvg/GeoCalib.git && cd GeoCalib +python -m pip install -e . +# OR +python -m pip install -e "git+https://github.com/cvg/GeoCalib#egg=geocalib" +``` + +Here is a minimal usage example: + +```python +import torch + +from geocalib import GeoCalib + +device = "cuda" if torch.cuda.is_available() else "cpu" +model = GeoCalib().to(device) + +# load image as tensor in range [0, 1] with shape [C, H, W] +img = model.load_image("path/to/image.jpg").to(device) +result = model.calibrate(img) + +print("camera:", result["camera"]) +print("gravity:", result["gravity"]) +``` + +When either the intrinsics or the gravity are already known, they can be provided: + +```python +# known intrinsics: +result = model.calibrate(img, priors={"focal": focal_length_tensor}) + +# known gravity: +result = model.calibrate(img, priors={"gravity": gravity_direction_tensor}) +``` + +The default model is optimized for pinhole images. To handle lens distortion, use the following: + +```python +model = GeoCalib(weights="distorted") # default is "pinhole" +result = model.calibrate(img, camera_model="simple_radial") # or pinhole, simple_divisional +``` + +Check out our [demo notebook](demo.ipynb) for a full working example. + +
+[Interactive demo for your webcam - click to expand] +Run the following command: + +```bash +python -m geocalib.interactive_demo --camera_id 0 +``` + +The demo will open a window showing the camera feed and the calibration results. If `--camera_id` is not provided, the demo will ask for the IP address of a [droidcam](https://droidcam.app) camera. + +Controls: + +>Toggle the different features using the following keys: +> +>- ```h```: Show the estimated horizon line +>- ```u```: Show the estimated up-vectors +>- ```l```: Show the estimated latitude heatmap +>- ```c```: Show the confidence heatmap for the up-vectors and latitudes +>- ```d```: Show the undistorted image (this will override the other features) +>- ```g```: Show a virtual grid of points +>- ```b```: Show a virtual box object +> +>Change the camera model using the following keys: +> +>- ```1```: Pinhole -> Simple and fast +>- ```2```: Simple Radial -> For small distortions +>- ```3```: Simple Divisional -> For large distortions +> +>Press ```q``` to quit the demo. + +
+ + +
+[Load GeoCalib with torch hub - click to expand] + +```python +model = torch.hub.load("cvg/GeoCalib", "GeoCalib", trust_repo=True) +``` + +
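However the model is loaded, the dictionary returned by `calibrate` can be inspected further. The following is a minimal sketch of reading the estimated intrinsics; it assumes that `result["camera"]` exposes the `K`, `f`, and `vfov` properties defined in `geocalib/camera.py` below, and that the outputs carry a leading batch dimension:

```python
import torch

from geocalib import GeoCalib

device = "cuda" if torch.cuda.is_available() else "cpu"
model = GeoCalib().to(device)

img = model.load_image("path/to/image.jpg").to(device)
result = model.calibrate(img)

camera, gravity = result["camera"], result["gravity"]
print("K:", camera.K)                        # (1, 3, 3) intrinsic matrix
print("focal (px):", camera.f)               # (1, 2) focal lengths (fx, fy)
print("vfov (deg):", camera.vfov.rad2deg())  # vertical field of view
print("gravity:", gravity)                   # gravity direction in the camera frame
```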
+ +## Evaluation + +The full evaluation and training code is provided in the single-image calibration library [`siclib`](siclib), which can be installed as: +```bash +python -m pip install -e siclib +``` + +Running the evaluation commands will write the results to `outputs/results/`. + +### LaMAR + +Running the evaluation commands will download the dataset to ```data/lamar2k``` which will take around 400 MB of disk space. + +
+[Evaluate GeoCalib] + +To evaluate GeoCalib trained on the OpenPano dataset, run: + +```bash +python -m siclib.eval.lamar2k --conf geocalib-pinhole --tag geocalib --overwrite +``` + +
+ +
+[Evaluate DeepCalib] + +To evaluate DeepCalib trained on the OpenPano dataset, run: + +```bash +python -m siclib.eval.lamar2k --conf deepcalib --tag deepcalib --overwrite +``` + +
+ +
+[Evaluate Perspective Fields] + +Coming soon! + +
+ +
+[Evaluate UVP] + +To evaluate UVP, install the [VP-Estimation-with-Prior-Gravity](https://github.com/cvg/VP-Estimation-with-Prior-Gravity) under ```third_party/VP-Estimation-with-Prior-Gravity```. Then run: + +```bash +python -m siclib.eval.lamar2k --conf uvp --tag uvp --overwrite data.preprocessing.edge_divisible_by=null +``` + +
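For reference, placing UVP in the expected location can be done as shown below; this is only a sketch of the assumed directory layout, so follow that repository's own instructions for its dependencies and build steps:

```bash
mkdir -p third_party
git clone https://github.com/cvg/VP-Estimation-with-Prior-Gravity third_party/VP-Estimation-with-Prior-Gravity
```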
+ +
+[Evaluate your own model] + +If you have trained your own model, you can evaluate it by running: + +```bash +python -m siclib.eval.lamar2k --checkpoint <experiment name> --tag <experiment tag> --overwrite +``` + +
+ + +
+[Results] + +Here are the results for the Area Under the Curve (AUC) for the roll, pitch and field of view (FoV) errors at 1/5/10 degrees for the different methods: + +| Approach | Roll | Pitch | FoV | +| --------- | ------------------ | ------------------ | ------------------ | +| DeepCalib | 44.1 / 73.9 / 84.8 | 10.8 / 28.3 / 49.8 | 0.7 / 13.0 / 24.0 | +| ParamNet | 51.7 / 77.0 / 86.0 | 27.0 / 52.7 / 70.2 | 02.8 / 06.8 / 14.3 | +| UVP | 72.7 / 81.8 / 85.7 | 42.3 / 59.9 / 69.4 | 15.6 / 30.6 / 43.5 | +| GeoCalib | 86.4 / 92.5 / 95.0 | 55.0 / 76.9 / 86.2 | 19.1 / 41.5 / 60.0 | +
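For context, each AUC score above is the area under the curve of recall versus angular error, integrated up to the threshold (1, 5, or 10 degrees) and normalized by it. A generic sketch of this computation (not necessarily the exact `siclib` implementation) is:

```python
import numpy as np

def auc(errors, thresholds=(1.0, 5.0, 10.0)):
    """Area under the cumulative error curve (in %) at each threshold."""
    errors = np.sort(np.asarray(errors, dtype=float))
    recall = np.arange(1, len(errors) + 1) / len(errors)
    errors = np.concatenate(([0.0], errors))   # start the curve at (0, 0)
    recall = np.concatenate(([0.0], recall))
    scores = []
    for t in thresholds:
        last = np.searchsorted(errors, t)      # first error above the threshold
        x = np.concatenate((errors[:last], [t]))
        y = np.concatenate((recall[:last], [recall[last - 1]]))
        scores.append(100.0 * np.trapz(y, x) / t)
    return scores

# e.g. roll errors in degrees over an evaluation set
print(auc([0.5, 1.2, 3.0, 7.5, 20.0]))
```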
+ +### MegaDepth + +Running the evaluation commands will download the dataset to ```data/megadepth2k``` or ```data/megadepth2k-radial``` which will take around 2.1 GB and 1.47 GB of disk space, respectively. + +
+[Evaluate GeoCalib] + +To evaluate GeoCalib trained on the OpenPano dataset, run: + +```bash +python -m siclib.eval.megadepth2k --conf geocalib-pinhole --tag geocalib --overwrite +``` + +To run the eval on the radial distorted images, run: + +```bash +python -m siclib.eval.megadepth2k_radial --conf geocalib-pinhole --tag geocalib --overwrite model.camera_model=simple_radial +``` + +
+ +
+[Evaluate DeepCalib] + +To evaluate DeepCalib trained on the OpenPano dataset, run: + +```bash +python -m siclib.eval.megadepth2k --conf deepcalib --tag deepcalib --overwrite +``` + +
+ +
+[Evaluate Perspective Fields] + +Coming soon! + +
+ +
+[Evaluate UVP] + +To evaluate UVP, install the [VP-Estimation-with-Prior-Gravity](https://github.com/cvg/VP-Estimation-with-Prior-Gravity) under ```third_party/VP-Estimation-with-Prior-Gravity```. Then run: + +```bash +python -m siclib.eval.megadepth2k --conf uvp --tag uvp --overwrite data.preprocessing.edge_divisible_by=null +``` + +
+ +
+[Evaluate your own model] + +If you have trained your own model, you can evaluate it by running: + +```bash +python -m siclib.eval.megadepth2k --checkpoint <experiment name> --tag <experiment tag> --overwrite +``` + +
+ +
+[Results] + +Here are the results for the Area Under the Curve (AUC) for the roll, pitch and field of view (FoV) errors at 1/5/10 degrees for the different methods: + +| Approach | Roll | Pitch | FoV | +| --------- | ------------------ | ------------------ | ------------------ | +| DeepCalib | 34.6 / 65.4 / 79.4 | 11.9 / 27.8 / 44.8 | 5.6 / 12.1 / 22.9 | +| ParamNet | 43.4 / 70.7 / 82.2 | 15.4 / 34.5 / 53.3 | 3.2 / 10.1 / 21.3 | +| UVP | 69.2 / 81.6 / 86.9 | 21.6 / 36.2 / 47.4 | 8.2 / 18.7 / 29.8 | +| GeoCalib | 82.6 / 90.6 / 94.0 | 32.4 / 53.3 / 67.5 | 13.6 / 31.7 / 48.2 | +
+ +### TartanAir + +Running the evaluation commands will download the dataset to ```data/tartanair``` which will take around 1.85 GB of disk space. + +
+[Evaluate GeoCalib] + +To evaluate GeoCalib trained on the OpenPano dataset, run: + +```bash +python -m siclib.eval.tartanair --conf geocalib-pinhole --tag geocalib --overwrite +``` + +
+ +
+[Evaluate DeepCalib] + +To evaluate DeepCalib trained on the OpenPano dataset, run: + +```bash +python -m siclib.eval.tartanair --conf deepcalib --tag deepcalib --overwrite +``` + +
+ +
+[Evaluate Perspective Fields] + +Coming soon! + +
+ +
+[Evaluate UVP] + +To evaluate UVP, install the [VP-Estimation-with-Prior-Gravity](https://github.com/cvg/VP-Estimation-with-Prior-Gravity) under ```third_party/VP-Estimation-with-Prior-Gravity```. Then run: + +```bash +python -m siclib.eval.tartanair --conf uvp --tag uvp --overwrite data.preprocessing.edge_divisible_by=null +``` + +
+ +
+[Evaluate your own model] + +If you have trained your own model, you can evaluate it by running: + +```bash +python -m siclib.eval.tartanair --checkpoint <experiment name> --tag <experiment tag> --overwrite +``` + +
+ +
+[Results] + +Here are the results for the Area Under the Curve (AUC) for the roll, pitch and field of view (FoV) errors at 1/5/10 degrees for the different methods: + +| Approach | Roll | Pitch | FoV | +| --------- | ------------------ | ------------------ | ------------------ | +| DeepCalib | 24.7 / 55.4 / 71.5 | 16.3 / 38.8 / 58.5 | 1.5 / 8.8 / 27.2 | +| ParamNet | 34.5 / 59.2 / 73.9 | 19.4 / 42.0 / 60.3 | 6.0 / 16.8 / 31.6 | +| UVP | 52.1 / 64.8 / 71.9 | 36.2 / 48.8 / 58.6 | 15.8 / 25.8 / 35.7 | +| GeoCalib | 71.3 / 83.8 / 89.8 | 38.2 / 62.9 / 76.6 | 14.1 / 30.4 / 47.6 | +
+ +### Stanford2D3D + +Before downloading and running the evaluation, you will need to agree to the [terms of use](https://docs.google.com/forms/d/e/1FAIpQLScFR0U8WEUtb7tgjOhhnl31OrkEs73-Y8bQwPeXgebqVKNMpQ/viewform?c=0&w=1) for the Stanford2D3D dataset. +Running the evaluation commands will download the dataset to ```data/stanford2d3d``` which will take around 885 MB of disk space. + +
+[Evaluate GeoCalib] + +To evaluate GeoCalib trained on the OpenPano dataset, run: + +```bash +python -m siclib.eval.stanford2d3d --conf geocalib-pinhole --tag geocalib --overwrite +``` + +
+ +
+[Evaluate DeepCalib] + +To evaluate DeepCalib trained on the OpenPano dataset, run: + +```bash +python -m siclib.eval.stanford2d3d --conf deepcalib --tag deepcalib --overwrite +``` + +
+ +
+[Evaluate Perspective Fields] + +Coming soon! + +
+ +
+[Evaluate UVP] + +To evaluate UVP, install the [VP-Estimation-with-Prior-Gravity](https://github.com/cvg/VP-Estimation-with-Prior-Gravity) under ```third_party/VP-Estimation-with-Prior-Gravity```. Then run: + +```bash +python -m siclib.eval.stanford2d3d --conf uvp --tag uvp --overwrite data.preprocessing.edge_divisible_by=null +``` + +
+ +
+[Evaluate your own model] + +If you have trained your own model, you can evaluate it by running: + +```bash +python -m siclib.eval.stanford2d3d --checkpoint <experiment name> --tag <experiment tag> --overwrite +``` + +
+ +
+[Results] + +Here are the results for the Area Under the Curve (AUC) for the roll, pitch and field of view (FoV) errors at 1/5/10 degrees for the different methods: + +| Approach | Roll | Pitch | FoV | +| --------- | ------------------ | ------------------ | ------------------ | +| DeepCalib | 33.8 / 63.9 / 79.2 | 21.6 / 46.9 / 65.7 | 8.1 / 20.6 / 37.6 | +| ParamNet | 44.6 / 73.9 / 84.8 | 29.2 / 56.7 / 73.1 | 5.8 / 14.3 / 27.8 | +| UVP | 65.3 / 74.6 / 79.1 | 51.2 / 63.0 / 69.2 | 22.2 / 39.5 / 51.3 | +| GeoCalib | 83.1 / 91.8 / 94.8 | 52.3 / 74.8 / 84.6 | 17.4 / 40.0 / 59.4 | + +
+ +### Evaluation options + +If you want to provide priors during the evaluation, you can add one or multiple of the following flags: + +```bash +python -m siclib.eval.<benchmark> --conf <config> \ + --tag <tag> \ + data.use_prior_focal=true \ + data.use_prior_gravity=true \ + data.use_prior_k1=true +``` + +
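For example, an evaluation of GeoCalib on LaMAR that additionally uses the gravity prior could look like the command below (the tag name is arbitrary):

```bash
python -m siclib.eval.lamar2k --conf geocalib-pinhole --tag geocalib-prior-gravity --overwrite \
    data.use_prior_gravity=true
```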
+[Visual inspection] + +To visually inspect the results of the evaluation, you can run the following command: + +```bash +python -m siclib.eval.inspect <benchmark> <one or multiple tags> +``` +For example, to inspect the results of the GeoCalib evaluation on the LaMAR dataset, you can run: +```bash +python -m siclib.eval.inspect lamar2k geocalib +``` +
+ +## OpenPano Dataset + +The OpenPano dataset is a new dataset for single-image calibration, which contains about 2.8k panoramas from various sources, namely [HDRMAPS](https://hdrmaps.com/hdris/), [PolyHaven](https://polyhaven.com/hdris), and the [Laval Indoor HDR dataset](http://hdrdb.com/indoor/#presentation). While this dataset is smaller than previous ones, it is publicly available and provides a better balance between indoor and outdoor scenes. + +
+[Downloading and preparing the dataset] + +In order to assemble the training set, first download the Laval dataset following the instructions on [the corresponding project page](http://hdrdb.com/indoor/#presentation) and place the panoramas in ```data/indoorDatasetCalibrated```. Then, tonemap the HDR images using the following command: + +```bash +python -m siclib.datasets.utils.tonemapping --hdr_dir data/indoorDatasetCalibrated --out_dir data/laval-tonemap +``` + +We provide a script to download the PolyHaven and HDRMAPS panos. The script will create folders ```data/openpano/panoramas/{split}``` containing the panoramas specified by the ```{split}_panos.txt``` files. To run the script, execute the following commands: + +```bash +python -m siclib.datasets.utils.download_openpano --name openpano --laval_dir data/laval-tonemap +``` +Alternatively, you can download the PolyHaven and HDRMAPS panos from [here](https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/). + + +After downloading the panoramas, you can create the training set by running the following command: + +```bash +python -m siclib.datasets.create_dataset_from_pano --config-name openpano +``` + +The dataset creation can be sped up by using multiple workers and a GPU. To do so, add the following arguments to the command: + +```bash +python -m siclib.datasets.create_dataset_from_pano --config-name openpano n_workers=10 device=cuda +``` + +This will create the training set in ```data/openpano/openpano``` with about 37k images for training, 2.1k for validation, and 2.1k for testing. + +
+[Distorted OpenPano] + +To create the OpenPano dataset with radial distortion, run the following command: + +```bash +python -m siclib.datasets.create_dataset_from_pano --config-name openpano_radial +``` + +
+ +
+ +## Training + +As for the evaluation, the training code is provided in the single-image calibration library [`siclib`](siclib), which can be installed by: + +```bash +python -m pip install -e siclib +``` + +Once the [OpenPano Dataset](#openpano-dataset) has been downloaded and prepared, we can train GeoCalib with it: + +First download the pre-trained weights for the [MSCAN-B](https://cloud.tsinghua.edu.cn/d/c15b25a6745946618462/) backbone: + +```bash +mkdir weights +wget "https://cloud.tsinghua.edu.cn/d/c15b25a6745946618462/files/?p=%2Fmscan_b.pth&dl=1" -O weights/mscan_b.pth +``` + +Then, start the training with the following command: + +```bash +python -m siclib.train geocalib-pinhole-openpano --conf geocalib --distributed +``` + +Feel free to use any other experiment name. By default, the checkpoints will be written to ```outputs/training/```. The default batch size is 24 which requires 2x 4090 GPUs with 24GB of VRAM each. Configurations are managed by [Hydra](https://hydra.cc/) and can be overwritten from the command line. +For example, to train GeoCalib on a single GPU with a batch size of 5, run: + +```bash +python -m siclib.train geocalib-pinhole-openpano \ + --conf geocalib \ + data.train_batch_size=5 # for 1x 2080 GPU +``` + +Be aware that this can impact the overall performance. You might need to adjust the learning rate and number of training steps accordingly. + +If you want to log the training progress to [tensorboard](https://www.tensorflow.org/tensorboard) or [wandb](https://wandb.ai/), you can set the ```train.writer``` option: + +```bash +python -m siclib.train geocalib-pinhole-openpano \ + --conf geocalib \ + --distributed \ + train.writer=tensorboard +``` + +The model can then be evaluated using its experiment name: + +```bash +python -m siclib.eval. --checkpoint geocalib-pinhole-openpano \ + --tag geocalib-retrained +``` + +
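When logging with `train.writer=tensorboard`, the run can be monitored with the standard TensorBoard CLI; assuming the event files are written next to the checkpoints under `outputs/training/` (not verified here), this would be:

```bash
tensorboard --logdir outputs/training/geocalib-pinhole-openpano
```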
+[Training DeepCalib] + +To train DeepCalib on the OpenPano dataset, run: + +```bash +python -m siclib.train deepcalib-openpano --conf deepcalib --distributed +``` + +Make sure that you have generated the [OpenPano Dataset](#openpano-dataset) with radial distortion or add +the flag ```data=openpano``` to the command to train on the pinhole images. + +
+ +
+[Training Perspective Fields] + +Coming soon! + +
+ +## BibTeX citation + +If you use any ideas from the paper or code from this repo, please consider citing: + +```bibtex +@inproceedings{veicht2024geocalib, + author = {Alexander Veicht and + Paul-Edouard Sarlin and + Philipp Lindenberger and + Marc Pollefeys}, + title = {{GeoCalib: Single-image Calibration with Geometric Optimization}}, + booktitle = {ECCV}, + year = {2024} +} +``` + +## License + +The code is provided under the [Apache-2.0 License](LICENSE) while the weights of the trained model are provided under the [Creative Commons Attribution 4.0 International Public License](https://creativecommons.org/licenses/by/4.0/legalcode). Thanks to the authors of the [Laval Indoor HDR dataset](http://hdrdb.com/indoor/#presentation) for allowing this. -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/assets/fisheye-dog-pool.jpg b/assets/fisheye-dog-pool.jpg new file mode 100644 index 0000000000000000000000000000000000000000..02278a47628312cbb922e5c60d17e1f0ea6f1876 Binary files /dev/null and b/assets/fisheye-dog-pool.jpg differ diff --git a/assets/fisheye-skyline.jpg b/assets/fisheye-skyline.jpg new file mode 100644 index 0000000000000000000000000000000000000000..00b83edbf3168fd1a46e615323c558489e870b71 --- /dev/null +++ b/assets/fisheye-skyline.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1f32ee048e0e9f931f9d42b19269419201834322f40a172367ebbca4752826a2 +size 1353673 diff --git a/assets/pinhole-church.jpg b/assets/pinhole-church.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ae9ad103798098af7f7e26ea13096ab8aabd4d5d Binary files /dev/null and b/assets/pinhole-church.jpg differ diff --git a/assets/pinhole-garden.jpg b/assets/pinhole-garden.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ecdcd0eee7c5ec1486bf93485c20d382f5c056f5 Binary files /dev/null and b/assets/pinhole-garden.jpg differ diff --git a/assets/teaser.gif b/assets/teaser.gif new file mode 100644 index 0000000000000000000000000000000000000000..a3e6e19e9c64f17138d09004c56fd9c353d857c3 --- /dev/null +++ b/assets/teaser.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:2944bad746c92415438f65adb4362bfd4b150db50729f118266d86f638ac3d21 +size 11997917 diff --git a/demo.ipynb b/demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..980f3eed9b70945fe68758485cd19ef621ba39e1 --- /dev/null +++ b/demo.ipynb @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:115886d06182eca375251eab9e301180e86465fd0ed152e917d43d7eb4cbd722 +size 13275966 diff --git a/geocalib/__init__.py b/geocalib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8a034045d7e3c9933383b6d5fcb1a7ce2dd89037 --- /dev/null +++ b/geocalib/__init__.py @@ -0,0 +1,17 @@ +import logging + +from geocalib.extractor import GeoCalib # noqa + +formatter = logging.Formatter( + fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%m/%d/%Y %H:%M:%S" +) +handler = logging.StreamHandler() +handler.setFormatter(formatter) +handler.setLevel(logging.INFO) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logger.addHandler(handler) +logger.propagate = False + +__module_name__ = __name__ diff --git a/geocalib/camera.py b/geocalib/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..1988ebf017d6a6d048cc0c9ca688b5dae35f50c4 --- /dev/null +++ b/geocalib/camera.py @@ -0,0 +1,774 @@ +"""Implementation of the 
pinhole, simple radial, and simple divisional camera models.""" + +from abc import abstractmethod +from typing import Dict, Optional, Tuple, Union + +import torch +from torch.func import jacfwd, vmap +from torch.nn import functional as F + +from geocalib.gravity import Gravity +from geocalib.misc import TensorWrapper, autocast +from geocalib.utils import deg2rad, focal2fov, fov2focal, rad2rotmat + +# flake8: noqa: E741 +# mypy: ignore-errors + + +class BaseCamera(TensorWrapper): + """Camera tensor class.""" + + eps = 1e-3 + + @autocast + def __init__(self, data: torch.Tensor): + """Camera parameters with shape (..., {w, h, fx, fy, cx, cy, *dist}). + + Tensor convention: (..., {w, h, fx, fy, cx, cy, pitch, roll, *dist}) where + - w, h: image size in pixels + - fx, fy: focal lengths in pixels + - cx, cy: principal points in normalized image coordinates + - dist: distortion parameters + + Args: + data (torch.Tensor): Camera parameters with shape (..., {6, 7, 8}). + """ + # w, h, fx, fy, cx, cy, dist + assert data.shape[-1] in {6, 7, 8}, data.shape + + pad = data.new_zeros(data.shape[:-1] + (8 - data.shape[-1],)) + data = torch.cat([data, pad], -1) if data.shape[-1] != 8 else data + super().__init__(data) + + @classmethod + def from_dict(cls, param_dict: Dict[str, torch.Tensor]) -> "BaseCamera": + """Create a Camera object from a dictionary of parameters. + + Args: + param_dict (Dict[str, torch.Tensor]): Dictionary of parameters. + + Returns: + Camera: Camera object. + """ + for key, value in param_dict.items(): + if not isinstance(value, torch.Tensor): + param_dict[key] = torch.tensor(value) + + h, w = param_dict["height"], param_dict["width"] + cx, cy = param_dict.get("cx", w / 2), param_dict.get("cy", h / 2) + + if "f" in param_dict: + f = param_dict["f"] + elif "vfov" in param_dict: + vfov = param_dict["vfov"] + f = fov2focal(vfov, h) + else: + raise ValueError("Focal length or vertical field of view must be provided.") + + if "dist" in param_dict: + k1, k2 = param_dict["dist"][..., 0], param_dict["dist"][..., 1] + elif "k1_hat" in param_dict: + k1 = param_dict["k1_hat"] * (f / h) ** 2 + + k2 = param_dict.get("k2", torch.zeros_like(k1)) + else: + k1 = param_dict.get("k1", torch.zeros_like(f)) + k2 = param_dict.get("k2", torch.zeros_like(f)) + + fx, fy = f, f + if "scales" in param_dict: + fx = fx * param_dict["scales"][..., 0] / param_dict["scales"][..., 1] + + params = torch.stack([w, h, fx, fy, cx, cy, k1, k2], dim=-1) + return cls(params) + + def pinhole(self): + """Return the pinhole camera model.""" + return self.__class__(self._data[..., :6]) + + @property + def size(self) -> torch.Tensor: + """Size (width height) of the images, with shape (..., 2).""" + return self._data[..., :2] + + @property + def f(self) -> torch.Tensor: + """Focal lengths (fx, fy) with shape (..., 2).""" + return self._data[..., 2:4] + + @property + def vfov(self) -> torch.Tensor: + """Vertical field of view in radians.""" + return focal2fov(self.f[..., 1], self.size[..., 1]) + + @property + def hfov(self) -> torch.Tensor: + """Horizontal field of view in radians.""" + return focal2fov(self.f[..., 0], self.size[..., 0]) + + @property + def c(self) -> torch.Tensor: + """Principal points (cx, cy) with shape (..., 2).""" + return self._data[..., 4:6] + + @property + def K(self) -> torch.Tensor: + """Returns the self intrinsic matrix with shape (..., 3, 3).""" + shape = self.shape + (3, 3) + K = self._data.new_zeros(shape) + K[..., 0, 0] = self.f[..., 0] + K[..., 1, 1] = self.f[..., 1] + K[..., 0, 2] = self.c[..., 
0] + K[..., 1, 2] = self.c[..., 1] + K[..., 2, 2] = 1 + return K + + def update_focal(self, delta: torch.Tensor, as_log: bool = False): + """Update the self parameters after changing the focal length.""" + f = torch.exp(torch.log(self.f) + delta) if as_log else self.f + delta + + # clamp focal length to a reasonable range for stability during training + min_f = fov2focal(self.new_ones(self.shape[0]) * deg2rad(150), self.size[..., 1]) + max_f = fov2focal(self.new_ones(self.shape[0]) * deg2rad(5), self.size[..., 1]) + min_f = min_f.unsqueeze(-1).expand(-1, 2) + max_f = max_f.unsqueeze(-1).expand(-1, 2) + f = f.clamp(min=min_f, max=max_f) + + # make sure focal ration stays the same (avoid inplace operations) + fx = f[..., 1] * self.f[..., 0] / self.f[..., 1] + f = torch.stack([fx, f[..., 1]], -1) + + dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape) + return self.__class__(torch.cat([self.size, f, self.c, dist], -1)) + + def scale(self, scales: Union[float, int, Tuple[Union[float, int]]]): + """Update the self parameters after resizing an image.""" + scales = (scales, scales) if isinstance(scales, (int, float)) else scales + s = scales if isinstance(scales, torch.Tensor) else self.new_tensor(scales) + + dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape) + return self.__class__(torch.cat([self.size * s, self.f * s, self.c * s, dist], -1)) + + def crop(self, pad: Tuple[float]): + """Update the self parameters after cropping an image.""" + pad = pad if isinstance(pad, torch.Tensor) else self.new_tensor(pad) + size = self.size + pad.to(self.size) + c = self.c + pad.to(self.c) / 2 + + dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape) + return self.__class__(torch.cat([size, self.f, c, dist], -1)) + + @autocast + def in_image(self, p2d: torch.Tensor): + """Check if 2D points are within the image boundaries.""" + assert p2d.shape[-1] == 2 + size = self.size.unsqueeze(-2) + return torch.all((p2d >= 0) & (p2d <= (size - 1)), -1) + + @autocast + def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]: + """Project 3D points into the self plane and check for visibility.""" + z = p3d[..., -1] + valid = z > self.eps + z = z.clamp(min=self.eps) + p2d = p3d[..., :-1] / z.unsqueeze(-1) + return p2d, valid + + def J_project(self, p3d: torch.Tensor): + """Jacobian of the projection function.""" + x, y, z = p3d[..., 0], p3d[..., 1], p3d[..., 2] + zero = torch.zeros_like(z) + z = z.clamp(min=self.eps) + J = torch.stack([1 / z, zero, -x / z**2, zero, 1 / z, -y / z**2], dim=-1) + J = J.reshape(p3d.shape[:-1] + (2, 3)) + return J # N x 2 x 3 + + def undo_scale_crop(self, data: Dict[str, torch.Tensor]): + """Undo transforms done during scaling and cropping.""" + camera = self.crop(-data["crop_pad"]) if "crop_pad" in data else self + return camera.scale(1.0 / data["scales"]) + + @abstractmethod + def distort(self, pts: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]: + """Distort normalized 2D coordinates and check for validity of the distortion model.""" + raise NotImplementedError("distort() must be implemented.") + + def J_distort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor: + """Jacobian of the distortion function.""" + if wrt == "scale2pts": # (..., 2) + J = [ + vmap(jacfwd(lambda x: self[idx].distort(x, return_scale=True)[0]))(p2d[idx])[None] + for idx in range(p2d.shape[0]) + ] + + return torch.cat(J, dim=0).squeeze(-3, -2) + + elif wrt == "scale2dist": # (..., 1) + J = [] + for idx in 
range(p2d.shape[0]): # loop to batch pts dimension + + def func(x): + params = torch.cat([self._data[idx, :6], x[None]], -1) + return self.__class__(params).distort(p2d[idx], return_scale=True)[0] + + J.append(vmap(jacfwd(func))(self[idx].dist)) + + return torch.cat(J, dim=0) + + else: + raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}") + + @abstractmethod + def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]: + """Undistort normalized 2D coordinates and check for validity of the distortion model.""" + raise NotImplementedError("undistort() must be implemented.") + + def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor: + """Jacobian of the undistortion function.""" + if wrt == "pts": # (..., 2, 2) + J = [ + vmap(jacfwd(lambda x: self[idx].undistort(x)[0]))(p2d[idx])[None] + for idx in range(p2d.shape[0]) + ] + + return torch.cat(J, dim=0).squeeze(-3) + + elif wrt == "dist": # (..., 1) + J = [] + for batch_idx in range(p2d.shape[0]): # loop to batch pts dimension + + def func(x): + params = torch.cat([self._data[batch_idx, :6], x[None]], -1) + return self.__class__(params).undistort(p2d[batch_idx])[0] + + J.append(vmap(jacfwd(func))(self[batch_idx].dist)) + + return torch.cat(J, dim=0) + else: + raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}") + + @autocast + def up_projection_offset(self, p2d: torch.Tensor) -> torch.Tensor: + """Compute the offset for the up-projection.""" + return self.J_distort(p2d, wrt="scale2pts") # (B, N, 2) + + def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor: + """Jacobian of the distortion offset for up-projection.""" + if wrt == "uv": # (B, N, 2, 2) + J = [ + vmap(jacfwd(lambda x: self[idx].up_projection_offset(x)[0, 0]))(p2d[idx])[None] + for idx in range(p2d.shape[0]) + ] + + return torch.cat(J, dim=0) + + elif wrt == "dist": # (B, N, 2) + J = [] + for batch_idx in range(p2d.shape[0]): # loop to batch pts dimension + + def func(x): + params = torch.cat([self._data[batch_idx, :6], x[None]], -1)[None] + return self.__class__(params).up_projection_offset(p2d[batch_idx][None]) + + J.append(vmap(jacfwd(func))(self[batch_idx].dist)) + + return torch.cat(J, dim=0).squeeze(1) + else: + raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}") + + @autocast + def denormalize(self, p2d: torch.Tensor) -> torch.Tensor: + """Convert normalized 2D coordinates into pixel coordinates.""" + return p2d * self.f.unsqueeze(-2) + self.c.unsqueeze(-2) + + def J_denormalize(self): + """Jacobian of the denormalization function.""" + return torch.diag_embed(self.f) # ..., 2 x 2 + + @autocast + def normalize(self, p2d: torch.Tensor) -> torch.Tensor: + """Convert pixel coordinates into normalized 2D coordinates.""" + return (p2d - self.c.unsqueeze(-2)) / (self.f.unsqueeze(-2)) + + def J_normalize(self, p2d: torch.Tensor, wrt: str = "f"): + """Jacobian of the normalization function.""" + # ... x N x 2 x 2 + if wrt == "f": + J_f = -(p2d - self.c.unsqueeze(-2)) / ((self.f.unsqueeze(-2)) ** 2) + return torch.diag_embed(J_f) + elif wrt == "pts": + J_pts = 1 / self.f + return torch.diag_embed(J_pts) + else: + raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}") + + def pixel_coordinates(self) -> torch.Tensor: + """Pixel coordinates in self frame. + + Returns: + torch.Tensor: Pixel coordinates as a tensor of shape (B, h * w, 2). 
+ """ + w, h = self.size[0].unbind(-1) + h, w = h.round().to(int), w.round().to(int) + + # create grid + x = torch.arange(0, w, dtype=self.dtype, device=self.device) + y = torch.arange(0, h, dtype=self.dtype, device=self.device) + x, y = torch.meshgrid(x, y, indexing="xy") + xy = torch.stack((x, y), dim=-1).reshape(-1, 2) # shape (h * w, 2) + + # add batch dimension (normalize() would broadcast but we make it explicit) + B = self.shape[0] + xy = xy.unsqueeze(0).expand(B, -1, -1) # if B > 0 else xy + + return xy.to(self.device).to(self.dtype) + + @autocast + def pixel_bearing_many(self, p3d: torch.Tensor) -> torch.Tensor: + """Get the bearing vectors of pixel coordinates by normalizing them.""" + return F.normalize(p3d, dim=-1) + + @autocast + def world2image(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]: + """Transform 3D points into 2D pixel coordinates.""" + p2d, visible = self.project(p3d) + p2d, mask = self.distort(p2d) + p2d = self.denormalize(p2d) + valid = visible & mask & self.in_image(p2d) + return p2d, valid + + @autocast + def J_world2image(self, p3d: torch.Tensor): + """Jacobian of the world2image function.""" + p2d_proj, valid = self.project(p3d) + + J_dnorm = self.J_denormalize() + J_dist = self.J_distort(p2d_proj) + J_proj = self.J_project(p3d) + + J = torch.einsum("...ij,...jk,...kl->...il", J_dnorm, J_dist, J_proj) + return J, valid + + @autocast + def image2world(self, p2d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Transform point in the image plane to 3D world coordinates.""" + p2d = self.normalize(p2d) + p2d, valid = self.undistort(p2d) + ones = p2d.new_ones(p2d.shape[:-1] + (1,)) + p3d = torch.cat([p2d, ones], -1) + return p3d, valid + + @autocast + def J_image2world(self, p2d: torch.Tensor, wrt: str = "f") -> Tuple[torch.Tensor, torch.Tensor]: + """Jacobian of the image2world function.""" + if wrt == "dist": + p2d_norm = self.normalize(p2d) + return self.J_undistort(p2d_norm, wrt) + elif wrt == "f": + J_norm2f = self.J_normalize(p2d, wrt) + p2d_norm = self.normalize(p2d) + J_dist2norm = self.J_undistort(p2d_norm, "pts") + + return torch.einsum("...ij,...jk->...ik", J_dist2norm, J_norm2f) + else: + raise ValueError(f"Unknown wrt: {wrt}") + + @autocast + def undistort_image(self, img: torch.Tensor) -> torch.Tensor: + """Undistort an image using the distortion model.""" + assert self.shape[0] == 1, "Batch size must be 1." + W, H = self.size.unbind(-1) + H, W = H.int().item(), W.int().item() + + x, y = torch.meshgrid(torch.arange(0, W), torch.arange(0, H), indexing="xy") + coords = torch.stack((x, y), dim=-1).reshape(-1, 2) + + p3d, _ = self.pinhole().image2world(coords.to(self.device).to(self.dtype)) + p2d, _ = self.world2image(p3d) + + mapx, mapy = p2d[..., 0].reshape((1, H, W)), p2d[..., 1].reshape((1, H, W)) + grid = torch.stack((mapx, mapy), dim=-1) + grid = 2.0 * grid / torch.tensor([W - 1, H - 1]).to(grid) - 1 + return F.grid_sample(img, grid, align_corners=True) + + def get_img_from_pano( + self, + pano_img: torch.Tensor, + gravity: Gravity, + yaws: torch.Tensor = 0.0, + resize_factor: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Render an image from a panorama. + + Args: + pano_img (torch.Tensor): Panorama image of shape (3, H, W) in [0, 1]. + gravity (Gravity): Gravity direction of the camera. + yaws (torch.Tensor | list, optional): Yaw angle in radians. Defaults to 0.0. + resize_factor (torch.Tensor, optional): Resize the panorama to be a multiple of the + field of view. Defaults to 1. 
+ + Returns: + torch.Tensor: Image rendered from the panorama. + """ + B = self.shape[0] + if B > 0: + assert self.size[..., 0].unique().shape[0] == 1, "All images must have the same width." + assert self.size[..., 1].unique().shape[0] == 1, "All images must have the same height." + + w, h = self.size[0].unbind(-1) + h, w = h.round().to(int), w.round().to(int) + + if isinstance(yaws, (int, float)): + yaws = [yaws] + if isinstance(resize_factor, (int, float)): + resize_factor = [resize_factor] + + yaws = ( + yaws.to(self.dtype).to(self.device) + if isinstance(yaws, torch.Tensor) + else self.new_tensor(yaws) + ) + + if isinstance(resize_factor, torch.Tensor): + resize_factor = resize_factor.to(self.dtype).to(self.device) + elif resize_factor is not None: + resize_factor = self.new_tensor(resize_factor) + + assert isinstance(pano_img, torch.Tensor), "Panorama image must be a torch.Tensor." + pano_img = pano_img if pano_img.dim() == 4 else pano_img.unsqueeze(0) # B x H x W x 3 + + pano_imgs = [] + for i, yaw in enumerate(yaws): + if resize_factor is not None: + # resize the panorama such that the fov of the panorama has the same height as the + # image + vfov = self.vfov[i] if B != 0 else self.vfov + scale = torch.pi / float(vfov) * float(h) / pano_img.shape[0] * resize_factor[i] + pano_shape = (int(pano_img.shape[0] * scale), int(pano_img.shape[1] * scale)) + + mode = "bicubic" if scale >= 1 else "area" + resized_pano = F.interpolate(pano_img, size=pano_shape, mode=mode) + else: + # make sure to copy: resized_pano = pano_img + resized_pano = pano_img + pano_shape = pano_img.shape[-2:][::-1] + + pano_imgs.append((resized_pano, pano_shape)) + + xy = self.pixel_coordinates() + uv1, _ = self.image2world(xy) + bearings = self.pixel_bearing_many(uv1) + + # rotate bearings + R_yaw = rad2rotmat(self.new_zeros(yaw.shape), self.new_zeros(yaw.shape), yaws) + rotated_bearings = bearings @ gravity.R @ R_yaw + + # spherical coordinates + lon = torch.atan2(rotated_bearings[..., 0], rotated_bearings[..., 2]) + lat = torch.atan2( + rotated_bearings[..., 1], torch.norm(rotated_bearings[..., [0, 2]], dim=-1) + ) + + images = [] + for idx, (resized_pano, pano_shape) in enumerate(pano_imgs): + min_lon, max_lon = -torch.pi, torch.pi + min_lat, max_lat = -torch.pi / 2.0, torch.pi / 2.0 + min_x, max_x = 0, pano_shape[0] - 1.0 + min_y, max_y = 0, pano_shape[1] - 1.0 + + # map Spherical Coordinates to Panoramic Coordinates + nx = (lon[idx] - min_lon) / (max_lon - min_lon) * (max_x - min_x) + min_x + ny = (lat[idx] - min_lat) / (max_lat - min_lat) * (max_y - min_y) + min_y + + # reshape and cast to numpy for remap + mapx, mapy = nx.reshape((1, h, w)), ny.reshape((1, h, w)) + + grid = torch.stack((mapx, mapy), dim=-1) # Add batch dimension + # Normalize to [-1, 1] + grid = 2.0 * grid / torch.tensor([pano_shape[-2] - 1, pano_shape[-1] - 1]).to(grid) - 1 + # Apply grid sample + image = F.grid_sample(resized_pano, grid, align_corners=True) + images.append(image) + + return torch.concatenate(images, 0) if B > 0 else images[0] + + def __repr__(self): + """Print the Camera object.""" + return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}" + + +class Pinhole(BaseCamera): + """Implementation of the pinhole camera model. + + Use this model for undistorted images. 
+ """ + + def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]: + """Distort normalized 2D coordinates.""" + if return_scale: + return p2d.new_ones(p2d.shape[:-1] + (1,)) + + return p2d, p2d.new_ones((p2d.shape[0], 1)).bool() + + def J_distort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor: + """Jacobian of the distortion function.""" + if wrt == "pts": + return torch.eye(2, device=p2d.device, dtype=p2d.dtype).expand(p2d.shape[:-1] + (2, 2)) + + raise ValueError(f"Unknown wrt: {wrt}") + + def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]: + """Undistort normalized 2D coordinates.""" + return pts, pts.new_ones((pts.shape[0], 1)).bool() + + def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor: + """Jacobian of the undistortion function.""" + if wrt == "pts": + return torch.eye(2, device=p2d.device, dtype=p2d.dtype).expand(p2d.shape[:-1] + (2, 2)) + + raise ValueError(f"Unknown wrt: {wrt}") + + def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor: + """Jacobian of the up-projection offset.""" + if wrt == "uv": + return torch.zeros(p2d.shape[:-1] + (2, 2), device=p2d.device, dtype=p2d.dtype) + + raise ValueError(f"Unknown wrt: {wrt}") + + +class SimpleRadial(BaseCamera): + """Implementation of the simple radial camera model. + + Use this model for weakly distorted images. + + The distortion model is 1 + k1 * r^2 where r^2 = x^2 + y^2. + The undistortion model is 1 - k1 * r^2 estimated as in + "An Exact Formula for Calculating Inverse Radial Lens Distortions" by Pierre Drap. + """ + + @property + def dist(self) -> torch.Tensor: + """Distortion parameters, with shape (..., 1).""" + return self._data[..., 6:] + + @property + def k1(self) -> torch.Tensor: + """Distortion parameters, with shape (...).""" + return self._data[..., 6] + + def update_dist(self, delta: torch.Tensor, dist_range: Tuple[float, float] = (-0.7, 0.7)): + """Update the self parameters after changing the k1 distortion parameter.""" + delta_dist = self.new_ones(self.dist.shape) * delta + dist = (self.dist + delta_dist).clamp(*dist_range) + data = torch.cat([self.size, self.f, self.c, dist], -1) + return self.__class__(data) + + @autocast + def check_valid(self, p2d: torch.Tensor) -> torch.Tensor: + """Check if the distorted points are valid.""" + return p2d.new_ones(p2d.shape[:-1]).bool() + + def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]: + """Distort normalized 2D coordinates and check for validity of the distortion model.""" + r2 = torch.sum(p2d**2, -1, keepdim=True) + radial = 1 + self.k1[..., None, None] * r2 + + if return_scale: + return radial, None + + return p2d * radial, self.check_valid(p2d) + + def J_distort(self, p2d: torch.Tensor, wrt: str = "pts"): + """Jacobian of the distortion function.""" + if wrt == "scale2dist": # (..., 1) + return torch.sum(p2d**2, -1, keepdim=True) + elif wrt == "scale2pts": # (..., 2) + return 2 * self.k1[..., None, None] * p2d + else: + return super().J_distort(p2d, wrt) + + @autocast + def undistort(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]: + """Undistort normalized 2D coordinates and check for validity of the distortion model.""" + b1 = -self.k1[..., None, None] + r2 = torch.sum(p2d**2, -1, keepdim=True) + radial = 1 + b1 * r2 + return p2d * radial, self.check_valid(p2d) + + @autocast + def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor: + """Jacobian of the undistortion function.""" + b1 = -self.k1[..., 
None, None] + r2 = torch.sum(p2d**2, -1, keepdim=True) + if wrt == "dist": + return -r2 * p2d + elif wrt == "pts": + radial = 1 + b1 * r2 + radial_diag = torch.diag_embed(radial.expand(radial.shape[:-1] + (2,))) + ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2) + return (2 * b1[..., None] * ppT) + radial_diag + else: + return super().J_undistort(p2d, wrt) + + def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor: + """Jacobian of the up-projection offset.""" + if wrt == "uv": # (..., 2, 2) + return torch.diag_embed((2 * self.k1[..., None, None]).expand(p2d.shape[:-1] + (2,))) + elif wrt == "dist": + return 2 * p2d # (..., 2) + else: + return super().J_up_projection_offset(p2d, wrt) + + +class SimpleDivisional(BaseCamera): + """Implementation of the simple divisional camera model. + + Use this model for strongly distorted images. + + The distortion model is (1 - sqrt(1 - 4 * k1 * r^2)) / (2 * k1 * r^2) where r^2 = x^2 + y^2. + The undistortion model is 1 / (1 + k1 * r^2). + """ + + @property + def dist(self) -> torch.Tensor: + """Distortion parameters, with shape (..., 1).""" + return self._data[..., 6:] + + @property + def k1(self) -> torch.Tensor: + """Distortion parameters, with shape (...).""" + return self._data[..., 6] + + def update_dist(self, delta: torch.Tensor, dist_range: Tuple[float, float] = (-3.0, 3.0)): + """Update the self parameters after changing the k1 distortion parameter.""" + delta_dist = self.new_ones(self.dist.shape) * delta + dist = (self.dist + delta_dist).clamp(*dist_range) + data = torch.cat([self.size, self.f, self.c, dist], -1) + return self.__class__(data) + + @autocast + def check_valid(self, p2d: torch.Tensor) -> torch.Tensor: + """Check if the distorted points are valid.""" + return p2d.new_ones(p2d.shape[:-1]).bool() + + def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]: + """Distort normalized 2D coordinates and check for validity of the distortion model.""" + r2 = torch.sum(p2d**2, -1, keepdim=True) + radial = 1 - torch.sqrt((1 - 4 * self.k1[..., None, None] * r2).clamp(min=0)) + denom = 2 * self.k1[..., None, None] * r2 + + ones = radial.new_ones(radial.shape) + radial = torch.where(denom == 0, ones, radial / denom.masked_fill(denom == 0, 1e6)) + + if return_scale: + return radial, None + + return p2d * radial, self.check_valid(p2d) + + def J_distort(self, p2d: torch.Tensor, wrt: str = "pts"): + """Jacobian of the distortion function.""" + r2 = torch.sum(p2d**2, -1, keepdim=True) + t0 = torch.sqrt((1 - 4 * self.k1[..., None, None] * r2).clamp(min=1e-6)) + if wrt == "scale2pts": # (B, N, 2) + d1 = t0 * 2 * r2 + d2 = self.k1[..., None, None] * r2**2 + denom = d1 * d2 + return p2d * (4 * d2 - (1 - t0) * d1) / denom.masked_fill(denom == 0, 1e6) + + elif wrt == "scale2dist": + d1 = 2 * self.k1[..., None, None] * t0 + d2 = 2 * r2 * self.k1[..., None, None] ** 2 + denom = d1 * d2 + return (2 * d2 - (1 - t0) * d1) / denom.masked_fill(denom == 0, 1e6) + + else: + return super().J_distort(p2d, wrt) + + @autocast + def undistort(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]: + """Undistort normalized 2D coordinates and check for validity of the distortion model.""" + r2 = torch.sum(p2d**2, -1, keepdim=True) + denom = 1 + self.k1[..., None, None] * r2 + radial = 1 / denom.masked_fill(denom == 0, 1e6) + return p2d * radial, self.check_valid(p2d) + + def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor: + """Jacobian of the undistortion function.""" + # return 
super().J_undistort(p2d, wrt) + r2 = torch.sum(p2d**2, -1, keepdim=True) + k1 = self.k1[..., None, None] + if wrt == "dist": + denom = (1 + k1 * r2) ** 2 + return -r2 / denom.masked_fill(denom == 0, 1e6) * p2d + elif wrt == "pts": + t0 = 1 + k1 * r2 + t0 = t0.masked_fill(t0 == 0, 1e6) + ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2) + J = torch.diag_embed((1 / t0).expand(p2d.shape[:-1] + (2,))) + return J - 2 * k1[..., None] * ppT / t0[..., None] ** 2 # (..., N, 2, 2) + + else: + return super().J_undistort(p2d, wrt) + + def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor: + """Jacobian of the up-projection offset. + + func(uv, dist) = 4 / (2 * norm2(uv)^2 * (1-4*k1*norm2(uv)^2)^0.5) * uv + - (1-(1-4*k1*norm2(uv)^2)^0.5) / (k1 * norm2(uv)^4) * uv + """ + k1 = self.k1[..., None, None] + r2 = torch.sum(p2d**2, -1, keepdim=True) + t0 = (1 - 4 * k1 * r2).clamp(min=1e-6) + t1 = torch.sqrt(t0) + if wrt == "dist": + denom = 4 * t0 ** (3 / 2) + denom = denom.masked_fill(denom == 0, 1e6) + J = 16 / denom + + denom = r2 * t1 * k1 + denom = denom.masked_fill(denom == 0, 1e6) + J = J - 2 / denom + + denom = (r2 * k1) ** 2 + denom = denom.masked_fill(denom == 0, 1e6) + J = J + (1 - t1) / denom + + return J * p2d + elif wrt == "uv": + # ! unstable (gradient checker might fail), rewrite to use single division (by denom) + ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2) + + denom = 2 * r2 * t1 + denom = denom.masked_fill(denom == 0, 1e6) + J = torch.diag_embed((4 / denom).expand(p2d.shape[:-1] + (2,))) + + denom = 4 * t1 * r2**2 + denom = denom.masked_fill(denom == 0, 1e6) + J = J - 16 / denom[..., None] * ppT + + denom = 4 * r2 * t0 ** (3 / 2) + denom = denom.masked_fill(denom == 0, 1e6) + J = J + (32 * k1[..., None]) / denom[..., None] * ppT + + denom = r2**2 * t1 + denom = denom.masked_fill(denom == 0, 1e6) + J = J - 4 / denom[..., None] * ppT + + denom = k1 * r2**3 + denom = denom.masked_fill(denom == 0, 1e6) + J = J + (4 * (1 - t1) / denom)[..., None] * ppT + + denom = k1 * r2**2 + denom = denom.masked_fill(denom == 0, 1e6) + J = J - torch.diag_embed(((1 - t1) / denom).expand(p2d.shape[:-1] + (2,))) + + return J + else: + return super().J_up_projection_offset(p2d, wrt) + + +camera_models = { + "pinhole": Pinhole, + "simple_radial": SimpleRadial, + "simple_divisional": SimpleDivisional, +} diff --git a/geocalib/extractor.py b/geocalib/extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..9fdda6ad30136385bcfe5fd0cf606cb1f860c4f9 --- /dev/null +++ b/geocalib/extractor.py @@ -0,0 +1,126 @@ +"""Simple interface for GeoCalib model.""" + +from pathlib import Path +from typing import Dict, Optional + +import torch +import torch.nn as nn +from torch.nn.functional import interpolate + +from geocalib.camera import BaseCamera +from geocalib.geocalib import GeoCalib as Model +from geocalib.utils import ImagePreprocessor, load_image + + +class GeoCalib(nn.Module): + """Simple interface for GeoCalib model.""" + + def __init__(self, weights: str = "pinhole"): + """Initialize the model with optional config overrides. + + Args: + weights (str): trained variant, "pinhole" (default) or "distorted". 
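+
+        Example (a minimal sketch; weights are downloaded to the torch hub cache on first use):
+            >>> model = GeoCalib(weights="pinhole")      # pinhole variant
+            >>> model = GeoCalib(weights="distorted")    # simple_radial variant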
+ """ + super().__init__() + if weights == "pinhole": + url = "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-pinhole.tar" + elif weights == "distorted": + url = ( + "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-simple_radial.tar" + ) + else: + raise ValueError(f"Unknown weights: {weights}") + + # load checkpoint + model_dir = f"{torch.hub.get_dir()}/geocalib" + state_dict = torch.hub.load_state_dict_from_url( + url, model_dir, map_location="cpu", file_name=f"{weights}.tar" + ) + + self.model = Model() + self.model.flexible_load(state_dict["model"]) + self.model.eval() + + self.image_processor = ImagePreprocessor({"resize": 320, "edge_divisible_by": 32}) + + def load_image(self, path: Path) -> torch.Tensor: + """Load image from path.""" + return load_image(path) + + def _post_process( + self, camera: BaseCamera, img_data: dict[str, torch.Tensor], out: dict[str, torch.Tensor] + ) -> tuple[BaseCamera, dict[str, torch.Tensor]]: + """Post-process model output by undoing scaling and cropping.""" + camera = camera.undo_scale_crop(img_data) + + w, h = camera.size.unbind(-1) + h = h[0].round().int().item() + w = w[0].round().int().item() + + for k in ["latitude_field", "up_field"]: + out[k] = interpolate(out[k], size=(h, w), mode="bilinear") + for k in ["up_confidence", "latitude_confidence"]: + out[k] = interpolate(out[k][:, None], size=(h, w), mode="bilinear")[:, 0] + + inverse_scales = 1.0 / img_data["scales"] + zero = camera.new_zeros(camera.f.shape[0]) + out["focal_uncertainty"] = out.get("focal_uncertainty", zero) * inverse_scales[1] + return camera, out + + @torch.no_grad() + def calibrate( + self, + img: torch.Tensor, + camera_model: str = "pinhole", + priors: Optional[Dict[str, torch.Tensor]] = None, + shared_intrinsics: bool = False, + ) -> Dict[str, torch.Tensor]: + """Perform calibration with online resizing. + + Assumes input image is in range [0, 1] and in RGB format. + + Args: + img (torch.Tensor): Input image, shape (C, H, W) or (1, C, H, W) + camera_model (str, optional): Camera model. Defaults to "pinhole". + priors (Dict[str, torch.Tensor], optional): Prior parameters. Defaults to {}. + shared_intrinsics (bool, optional): Whether to share intrinsics. Defaults to False. + + Returns: + Dict[str, torch.Tensor]: camera and gravity vectors and uncertainties. 
+ """ + if len(img.shape) == 3: + img = img[None] # add batch dim + if not shared_intrinsics: + assert len(img.shape) == 4 and img.shape[0] == 1 + + img_data = self.image_processor(img) + + if priors is None: + priors = {} + + prior_values = {} + if prior_focal := priors.get("focal"): + prior_focal = prior_focal[None] if len(prior_focal.shape) == 0 else prior_focal + prior_values["prior_focal"] = prior_focal * img_data["scales"][1] + + if "gravity" in priors: + prior_gravity = priors["gravity"] + prior_gravity = prior_gravity[None] if len(prior_gravity.shape) == 0 else prior_gravity + prior_values["prior_gravity"] = prior_gravity + + self.model.optimizer.set_camera_model(camera_model) + self.model.optimizer.shared_intrinsics = shared_intrinsics + + out = self.model(img_data | prior_values) + + camera, gravity = out["camera"], out["gravity"] + camera, out = self._post_process(camera, img_data, out) + + return { + "camera": camera, + "gravity": gravity, + "covariance": out["covariance"], + **{k: out[k] for k in out.keys() if "field" in k}, + **{k: out[k] for k in out.keys() if "confidence" in k}, + **{k: out[k] for k in out.keys() if "uncertainty" in k}, + } diff --git a/geocalib/geocalib.py b/geocalib/geocalib.py new file mode 100644 index 0000000000000000000000000000000000000000..c2dc70fdb0ed56774e920fbeeb64cb150e23e0f1 --- /dev/null +++ b/geocalib/geocalib.py @@ -0,0 +1,150 @@ +"""GeoCalib model definition.""" + +import logging +from typing import Dict + +import torch +from torch import nn +from torch.nn import functional as F + +from geocalib.lm_optimizer import LMOptimizer +from geocalib.modules import MSCAN, ConvModule, LightHamHead + +# mypy: ignore-errors + +logger = logging.getLogger(__name__) + + +class LowLevelEncoder(nn.Module): + """Very simple low-level encoder.""" + + def __init__(self): + """Simple low-level encoder.""" + super().__init__() + self.in_channel = 3 + self.feat_dim = 64 + + self.conv1 = ConvModule(self.in_channel, self.feat_dim, kernel_size=3, padding=1) + self.conv2 = ConvModule(self.feat_dim, self.feat_dim, kernel_size=3, padding=1) + + def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Forward pass.""" + x = data["image"] + + assert ( + x.shape[-1] % 32 == 0 and x.shape[-2] % 32 == 0 + ), "Image size must be multiple of 32 if not using single image input." 
+ + c1 = self.conv1(x) + c2 = self.conv2(c1) + + return {"features": c2} + + +class UpDecoder(nn.Module): + """Minimal implementation of UpDecoder.""" + + def __init__(self): + """Up decoder.""" + super().__init__() + self.decoder = LightHamHead() + self.linear_pred_up = nn.Conv2d(self.decoder.out_channels, 2, kernel_size=1) + + def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Forward pass.""" + x, log_confidence = self.decoder(data["features"]) + up = self.linear_pred_up(x) + return {"up_field": F.normalize(up, dim=1), "up_confidence": torch.sigmoid(log_confidence)} + + +class LatitudeDecoder(nn.Module): + """Minimal implementation of LatitudeDecoder.""" + + def __init__(self): + """Latitude decoder.""" + super().__init__() + self.decoder = LightHamHead() + self.linear_pred_latitude = nn.Conv2d(self.decoder.out_channels, 1, kernel_size=1) + + def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Forward pass.""" + x, log_confidence = self.decoder(data["features"]) + eps = 1e-5 # avoid nan in backward of asin + lat = torch.tanh(self.linear_pred_latitude(x)) + lat = torch.asin(torch.clamp(lat, -1 + eps, 1 - eps)) + return {"latitude_field": lat, "latitude_confidence": torch.sigmoid(log_confidence)} + + +class PerspectiveDecoder(nn.Module): + """Minimal implementation of PerspectiveDecoder.""" + + def __init__(self): + """Perspective decoder wrapping up and latitude decoders.""" + super().__init__() + self.up_head = UpDecoder() + self.latitude_head = LatitudeDecoder() + + def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Forward pass.""" + return self.up_head(data) | self.latitude_head(data) + + +class GeoCalib(nn.Module): + """GeoCalib inference model.""" + + def __init__(self, **optimizer_options): + """Initialize the GeoCalib inference model. + + Args: + optimizer_options: Options for the lm optimizer. 
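+                See LMOptimizer.default_conf for the available keys,
+                e.g. {"camera_model": "simple_radial", "num_steps": 30}.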
+ """ + super().__init__() + self.backbone = MSCAN() + self.ll_enc = LowLevelEncoder() + self.perspective_decoder = PerspectiveDecoder() + + self.optimizer = LMOptimizer({**optimizer_options}) + + def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Forward pass.""" + features = {"hl": self.backbone(data)["features"], "ll": self.ll_enc(data)["features"]} + out = self.perspective_decoder({"features": features}) + + out |= { + k: data[k] + for k in ["image", "scales", "prior_gravity", "prior_focal", "prior_k1"] + if k in data + } + + out |= self.optimizer(out) + + return out + + def flexible_load(self, state_dict: Dict[str, torch.Tensor]) -> None: + """Load a checkpoint with flexible key names.""" + dict_params = set(state_dict.keys()) + model_params = set(map(lambda n: n[0], self.named_parameters())) + + if dict_params == model_params: # perfect fit + logger.info("Loading all parameters of the checkpoint.") + self.load_state_dict(state_dict, strict=True) + return + elif len(dict_params & model_params) == 0: # perfect mismatch + strip_prefix = lambda x: ".".join(x.split(".")[:1] + x.split(".")[2:]) + state_dict = {strip_prefix(n): p for n, p in state_dict.items()} + dict_params = set(state_dict.keys()) + if len(dict_params & model_params) == 0: + raise ValueError( + "Could not manage to load the checkpoint with" + "parameters:" + "\n\t".join(sorted(dict_params)) + ) + common_params = dict_params & model_params + left_params = dict_params - model_params + left_params = [ + p for p in left_params if "running" not in p and "num_batches_tracked" not in p + ] + logger.debug("Loading parameters:\n\t" + "\n\t".join(sorted(common_params))) + if left_params: + # ignore running stats of batchnorm + logger.warning("Could not load parameters:\n\t" + "\n\t".join(sorted(left_params))) + self.load_state_dict(state_dict, strict=False) diff --git a/geocalib/gravity.py b/geocalib/gravity.py new file mode 100644 index 0000000000000000000000000000000000000000..31f0e91e720ac5cf0a23e8db538e1f76a18f7fee --- /dev/null +++ b/geocalib/gravity.py @@ -0,0 +1,131 @@ +"""Tensor class for gravity vector in camera frame.""" + +import torch +from torch.nn import functional as F + +from geocalib.misc import EuclideanManifold, SphericalManifold, TensorWrapper, autocast +from geocalib.utils import rad2rotmat + +# mypy: ignore-errors + + +class Gravity(TensorWrapper): + """Gravity vector in camera frame.""" + + eps = 1e-4 + + @autocast + def __init__(self, data: torch.Tensor) -> None: + """Create gravity vector from data. + + Args: + data (torch.Tensor): gravity vector as 3D vector in camera frame. 
+ """ + assert data.shape[-1] == 3, data.shape + + data = F.normalize(data, dim=-1) + + super().__init__(data) + + @classmethod + def from_rp(cls, roll: torch.Tensor, pitch: torch.Tensor) -> "Gravity": + """Create gravity vector from roll and pitch angles.""" + if not isinstance(roll, torch.Tensor): + roll = torch.tensor(roll) + if not isinstance(pitch, torch.Tensor): + pitch = torch.tensor(pitch) + + sr, cr = torch.sin(roll), torch.cos(roll) + sp, cp = torch.sin(pitch), torch.cos(pitch) + return cls(torch.stack([-sr * cp, -cr * cp, sp], dim=-1)) + + @property + def vec3d(self) -> torch.Tensor: + """Return the gravity vector in the representation.""" + return self._data + + @property + def x(self) -> torch.Tensor: + """Return first component of the gravity vector.""" + return self._data[..., 0] + + @property + def y(self) -> torch.Tensor: + """Return second component of the gravity vector.""" + return self._data[..., 1] + + @property + def z(self) -> torch.Tensor: + """Return third component of the gravity vector.""" + return self._data[..., 2] + + @property + def roll(self) -> torch.Tensor: + """Return the roll angle of the gravity vector.""" + roll = torch.asin(-self.x / (torch.sqrt(1 - self.z**2) + self.eps)) + offset = -torch.pi * torch.sign(self.x) + return torch.where(self.y < 0, roll, -roll + offset) + + def J_roll(self) -> torch.Tensor: + """Return the Jacobian of the roll angle of the gravity vector.""" + cp, _ = torch.cos(self.pitch), torch.sin(self.pitch) + cr, sr = torch.cos(self.roll), torch.sin(self.roll) + Jr = self.new_zeros(self.shape + (3,)) + Jr[..., 0] = -cr * cp + Jr[..., 1] = sr * cp + return Jr + + @property + def pitch(self) -> torch.Tensor: + """Return the pitch angle of the gravity vector.""" + return torch.asin(self.z) + + def J_pitch(self) -> torch.Tensor: + """Return the Jacobian of the pitch angle of the gravity vector.""" + cp, sp = torch.cos(self.pitch), torch.sin(self.pitch) + cr, sr = torch.cos(self.roll), torch.sin(self.roll) + + Jp = self.new_zeros(self.shape + (3,)) + Jp[..., 0] = sr * sp + Jp[..., 1] = cr * sp + Jp[..., 2] = cp + return Jp + + @property + def rp(self) -> torch.Tensor: + """Return the roll and pitch angles of the gravity vector.""" + return torch.stack([self.roll, self.pitch], dim=-1) + + def J_rp(self) -> torch.Tensor: + """Return the Jacobian of the roll and pitch angles of the gravity vector.""" + return torch.stack([self.J_roll(), self.J_pitch()], dim=-1) + + @property + def R(self) -> torch.Tensor: + """Return the rotation matrix from the gravity vector.""" + return rad2rotmat(roll=self.roll, pitch=self.pitch) + + def J_R(self) -> torch.Tensor: + """Return the Jacobian of the rotation matrix from the gravity vector.""" + raise NotImplementedError + + def update(self, delta: torch.Tensor, spherical: bool = False) -> "Gravity": + """Update the gravity vector by adding a delta.""" + if spherical: + data = SphericalManifold.plus(self.vec3d, delta) + return self.__class__(data) + + data = EuclideanManifold.plus(self.rp, delta) + return self.from_rp(data[..., 0], data[..., 1]) + + def J_update(self, spherical: bool = False) -> torch.Tensor: + """Return the Jacobian of the update.""" + return ( + SphericalManifold.J_plus(self.vec3d) + if spherical + else EuclideanManifold.J_plus(self.vec3d) + ) + + def __repr__(self): + """Print the Camera object.""" + return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}" diff --git a/geocalib/interactive_demo.py b/geocalib/interactive_demo.py new file mode 100644 index 
0000000000000000000000000000000000000000..a2f0c4f9a0c0fb0982f700f9c1d34a523fc85c66 --- /dev/null +++ b/geocalib/interactive_demo.py @@ -0,0 +1,450 @@ +import argparse +import logging +import queue +import threading +import time +from time import time + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch + +from geocalib.extractor import GeoCalib +from geocalib.perspective_fields import get_perspective_field +from geocalib.utils import get_device, rad2deg + +# flake8: noqa +# mypy: ignore-errors + + +description = """ +------------------------- +GeoCalib Interactive Demo +------------------------- + +This script is an interactive demo for GeoCalib. It will open a window showing the camera feed and +the calibration results. + +Arguments: +- '--camera_id': Camera ID to use. If none, will ask for ip of droidcam (https://droidcam.app) + +You can toggle different features using the following keys: + +- 'h': Toggle horizon line +- 'u': Toggle up vector field +- 'l': Toggle latitude heatmap +- 'c': Toggle confidence heatmap +- 'd': Toggle undistorted image +- 'g': Toggle grid of points +- 'b': Toggle box object + +You can also change the camera model using the following keys: + +- '1': Pinhole +- '2': Simple Radial +- '3': Simple Divisional + +Press 'q' to quit the demo. +""" + + +# Custom VideoCapture class to get the most recent frame instead FIFO +class VideoCapture: + def __init__(self, name): + self.cap = cv2.VideoCapture(name) + self.q = queue.Queue() + t = threading.Thread(target=self._reader) + t.daemon = True + t.start() + + # read frames as soon as they are available, keeping only most recent one + def _reader(self): + while True: + ret, frame = self.cap.read() + if not ret: + break + if not self.q.empty(): + try: + self.q.get_nowait() # discard previous (unprocessed) frame + except queue.Empty: + pass + self.q.put(frame) + + def read(self): + return 1, self.q.get() + + def isOpened(self): + return self.cap.isOpened() + + +def add_text(frame, text, align_left=True, align_top=True): + """Add text to a plot.""" + h, w = frame.shape[:2] + sc = min(h / 640.0, 2.0) + Ht = int(40 * sc) # text height + + for i, l in enumerate(text.split("\n")): + max_line = len(max([l for l in text.split("\n")], key=len)) + x = int(8 * sc if align_left else w - (max_line) * sc * 18) + y = Ht * (i + 1) if align_top else h - Ht * (len(text.split("\n")) - i - 1) - int(8 * sc) + + c_back, c_front = (0, 0, 0), (255, 255, 255) + font, style = cv2.FONT_HERSHEY_DUPLEX, cv2.LINE_AA + cv2.putText(frame, l, (x, y), font, 1.0 * sc, c_back, int(6 * sc), style) + cv2.putText(frame, l, (x, y), font, 1.0 * sc, c_front, int(1 * sc), style) + return frame + + +def is_corner(p, h, w): + """Check if a point is a corner.""" + return p in [(0, 0), (0, h - 1), (w - 1, 0), (w - 1, h - 1)] + + +def plot_latitude(frame, latitude): + """Plot latitude heatmap.""" + if not isinstance(latitude, np.ndarray): + latitude = latitude.cpu().numpy() + + cmap = plt.get_cmap("seismic") + h, w = frame.shape[0], frame.shape[1] + sc = min(h / 640.0, 2.0) + + vmin, vmax = -90, 90 + latitude = (latitude - vmin) / (vmax - vmin) + + colors = (cmap(latitude)[..., :3] * 255).astype(np.uint8)[..., ::-1] + frame = cv2.addWeighted(frame, 1 - 0.4, colors, 0.4, 0) + + for contour_line in np.linspace(vmin, vmax, 15): + contour_line = (contour_line - vmin) / (vmax - vmin) + + mask = (latitude > contour_line).astype(np.uint8) + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + for contour in contours: + 
color = (np.array(cmap(contour_line))[:3] * 255).astype(np.uint8)[::-1] + + # remove corners + contour = [p for p in contour if not is_corner(tuple(p[0]), h, w)] + for index, item in enumerate(contour[:-1]): + cv2.line(frame, item[0], contour[index + 1][0], color.tolist(), int(5 * sc)) + + return frame + + +def draw_horizon_line(frame, heatmap): + """Draw a horizon line.""" + if not isinstance(heatmap, np.ndarray): + heatmap = heatmap.cpu().numpy() + + h, w = frame.shape[0], frame.shape[1] + sc = min(h / 640.0, 2.0) + + color = (0, 255, 255) + vmin, vmax = -90, 90 + heatmap = (heatmap - vmin) / (vmax - vmin) + + contours, _ = cv2.findContours( + (heatmap > 0.5).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE + ) + if contours: + contour = [p for p in contours[0] if not is_corner(tuple(p[0]), h, w)] + for index, item in enumerate(contour[:-1]): + cv2.line(frame, item[0], contour[index + 1][0], color, int(5 * sc)) + return frame + + +def plot_confidence(frame, confidence): + """Plot confidence heatmap.""" + if not isinstance(confidence, np.ndarray): + confidence = confidence.cpu().numpy() + + confidence = np.log10(confidence.clip(1e-6)).clip(-4) + confidence = (confidence - confidence.min()) / (confidence.max() - confidence.min()) + + cmap = plt.get_cmap("turbo") + colors = (cmap(confidence)[..., :3] * 255).astype(np.uint8)[..., ::-1] + return cv2.addWeighted(frame, 1 - 0.4, colors, 0.4, 0) + + +def plot_vector_field(frame, vector_field): + """Plot a vector field.""" + if not isinstance(vector_field, np.ndarray): + vector_field = vector_field.cpu().numpy() + + H, W = frame.shape[:2] + sc = min(H / 640.0, 2.0) + + subsample = min(W, H) // 10 + offset_x = ((W % subsample) + subsample) // 2 + samples_x = np.arange(offset_x, W, subsample) + samples_y = np.arange(int(subsample * 0.9), H, subsample) + + vec_len = 40 * sc + x_grid, y_grid = np.meshgrid(samples_x, samples_y) + x, y = vector_field[:, samples_y][:, :, samples_x] + for xi, yi, xi_dir, yi_dir in zip(x_grid.ravel(), y_grid.ravel(), x.ravel(), y.ravel()): + start = (xi, yi) + end = (int(xi + xi_dir * vec_len), int(yi + yi_dir * vec_len)) + cv2.arrowedLine( + frame, start, end, (0, 255, 0), int(5 * sc), line_type=cv2.LINE_AA, tipLength=0.3 + ) + + return frame + + +def plot_box(frame, gravity, camera): + """Plot a box object.""" + pts = np.array( + [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]] + ) + pts = pts - np.array([0.5, 1, 0.5]) + rotation_vec = cv2.Rodrigues(gravity.R.numpy()[0])[0] + t = np.array([0, 0, 1], dtype=float) + K = camera.K[0].cpu().numpy().astype(float) + dist = np.zeros(4, dtype=float) + axis_points, _ = cv2.projectPoints( + 0.1 * pts.reshape(-1, 3).astype(float), rotation_vec, t, K, dist + ) + + h = frame.shape[0] + sc = min(h / 640.0, 2.0) + + color = (85, 108, 228) + for p in axis_points: + center = tuple((int(p[0][0]), int(p[0][1]))) + frame = cv2.circle(frame, center, 10, color, -1, cv2.LINE_AA) + + for i in range(0, 4): + p1 = axis_points[i].astype(int) + p2 = axis_points[i + 4].astype(int) + frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA) + + p1 = axis_points[i].astype(int) + p2 = axis_points[(i + 1) % 4].astype(int) + frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA) + + p1 = axis_points[i + 4].astype(int) + p2 = axis_points[(i + 1) % 4 + 4].astype(int) + frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA) + + return frame + + +def plot_grid(frame, 
gravity, camera, grid_size=0.2, num_points=5): + """Plot a grid of points.""" + h = frame.shape[0] + sc = min(h / 640.0, 2.0) + + samples = np.linspace(-grid_size, grid_size, num_points) + xz = np.meshgrid(samples, samples) + pts = np.stack((xz[0].ravel(), np.zeros_like(xz[0].ravel()), xz[1].ravel()), axis=-1) + + # project points + rotation_vec = cv2.Rodrigues(gravity.R.numpy()[0])[0] + t = np.array([0, 0, 1], dtype=float) + K = camera.K[0].cpu().numpy().astype(float) + dist = np.zeros(4, dtype=float) + axis_points, _ = cv2.projectPoints(pts.reshape(-1, 3).astype(float), rotation_vec, t, K, dist) + + color = (192, 77, 58) + # draw points + for p in axis_points: + center = tuple((int(p[0][0]), int(p[0][1]))) + frame = cv2.circle(frame, center, 10, color, -1, cv2.LINE_AA) + + # draw lines + for i in range(num_points): + for j in range(num_points - 1): + p1 = axis_points[i * num_points + j].astype(int) + p2 = axis_points[i * num_points + j + 1].astype(int) + frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA) + + p1 = axis_points[j * num_points + i].astype(int) + p2 = axis_points[(j + 1) * num_points + i].astype(int) + frame = cv2.line(frame, tuple(p1[0]), tuple(p2[0]), color, int(5 * sc), cv2.LINE_AA) + + return frame + + +def undistort_image(img, camera, padding=0.3): + """Undistort an image.""" + W, H = camera.size.unbind(-1) + H, W = H.int().item(), W.int().item() + + pad_h, pad_w = int(H * padding), int(W * padding) + x, y = torch.meshgrid(torch.arange(0, W + pad_w), torch.arange(0, H + pad_h), indexing="xy") + coords = torch.stack((x, y), dim=-1).reshape(-1, 2) - torch.tensor([pad_w / 2, pad_h / 2]) + + p3d, _ = camera.pinhole().image2world(coords.to(camera.device).to(camera.dtype)) + p2d, _ = camera.world2image(p3d) + + p2d = p2d.float().numpy().reshape(H + pad_h, W + pad_w, 2) + img = cv2.remap(img, p2d[..., 0], p2d[..., 1], cv2.INTER_LINEAR, borderValue=(254, 254, 254)) + return cv2.resize(img, (W, H)) + + +class InteractiveDemo: + def __init__(self, capture: VideoCapture, device: str) -> None: + self.cap = capture + + self.device = torch.device(device) + self.model = GeoCalib().to(device) + + self.up_toggle = False + self.lat_toggle = False + self.conf_toggle = False + + self.hl_toggle = False + self.grid_toggle = False + self.box_toggle = False + + self.undist_toggle = False + + self.camera_model = "pinhole" + + def render_frame(self, frame, calibration): + """Render the frame with the calibration results.""" + camera, gravity = calibration["camera"].cpu(), calibration["gravity"].cpu() + + if self.undist_toggle: + return undistort_image(frame, camera) + + up, lat = get_perspective_field(camera, gravity) + + if gravity.pitch[0] > 0: + frame = plot_box(frame, gravity, camera) if self.box_toggle else frame + frame = plot_grid(frame, gravity, camera) if self.grid_toggle else frame + else: + frame = plot_grid(frame, gravity, camera) if self.grid_toggle else frame + frame = plot_box(frame, gravity, camera) if self.box_toggle else frame + + frame = draw_horizon_line(frame, lat[0, 0]) if self.hl_toggle else frame + + if self.conf_toggle and self.up_toggle: + frame = plot_confidence(frame, calibration["up_confidence"][0]) + frame = plot_vector_field(frame, up[0]) if self.up_toggle else frame + + if self.conf_toggle and self.lat_toggle: + frame = plot_confidence(frame, calibration["latitude_confidence"][0]) + frame = plot_latitude(frame, rad2deg(lat)[0, 0]) if self.lat_toggle else frame + + return frame + + def format_results(self, calibration): + """Format 
the calibration results.""" + camera, gravity = calibration["camera"].cpu(), calibration["gravity"].cpu() + + vfov, focal = camera.vfov[0].item(), camera.f[0, 0].item() + fov_unc = rad2deg(calibration["vfov_uncertainty"].item()) + f_unc = calibration["focal_uncertainty"].item() + + roll, pitch = gravity.rp[0].unbind(-1) + roll, pitch, vfov = rad2deg(roll), rad2deg(pitch), rad2deg(vfov) + roll_unc = rad2deg(calibration["roll_uncertainty"].item()) + pitch_unc = rad2deg(calibration["pitch_uncertainty"].item()) + + text = f"{self.camera_model.replace('_', ' ').title()}\n" + text += f"Roll: {roll:.2f} (+- {roll_unc:.2f})\n" + text += f"Pitch: {pitch:.2f} (+- {pitch_unc:.2f})\n" + text += f"vFoV: {vfov:.2f} (+- {fov_unc:.2f})\n" + text += f"Focal: {focal:.2f} (+- {f_unc:.2f})" + + if hasattr(camera, "k1"): + text += f"\nK1: {camera.k1[0].item():.2f}" + + return text + + def update_toggles(self): + """Update the toggles.""" + key = cv2.waitKey(100) & 0xFF + if key == ord("h"): + self.hl_toggle = not self.hl_toggle + elif key == ord("u"): + self.up_toggle = not self.up_toggle + elif key == ord("l"): + self.lat_toggle = not self.lat_toggle + elif key == ord("c"): + self.conf_toggle = not self.conf_toggle + elif key == ord("d"): + self.undist_toggle = not self.undist_toggle + elif key == ord("g"): + self.grid_toggle = not self.grid_toggle + elif key == ord("b"): + self.box_toggle = not self.box_toggle + + elif key == ord("1"): + self.camera_model = "pinhole" + elif key == ord("2"): + self.camera_model = "simple_radial" + elif key == ord("3"): + self.camera_model = "simple_divisional" + + elif key == ord("q"): + return True + + return False + + def run(self): + """Run the interactive demo.""" + while True: + start = time() + ret, frame = self.cap.read() + + if not ret: + print("Error: Failed to retrieve frame.") + break + + # create tensor from frame + img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + img = torch.tensor(img).permute(2, 0, 1) / 255.0 + + calibration = self.model.calibrate(img.to(self.device), camera_model=self.camera_model) + + # render results to the frame + frame = self.render_frame(frame, calibration) + frame = add_text(frame, self.format_results(calibration)) + + end = time() + frame = add_text( + frame, f"FPS: {1 / (end - start):04.1f}", align_left=False, align_top=False + ) + + cv2.imshow("GeoCalib Demo", frame) + + if self.update_toggles(): + break + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--camera_id", + type=int, + default=None, + help="Camera ID to use. 
If none, will ask for ip of droidcam.", + ) + args = parser.parse_args() + + print(description) + + device = get_device() + print(f"Running on: {device}") + + # setup video capture + if args.camera_id is not None: + cap = VideoCapture(args.camera_id) + else: + ip = input("Enter the IP address of the camera: ") + cap = VideoCapture(f"http://{ip}:4747/video/force/1920x1080") + + if not cap.isOpened(): + raise ValueError("Error: Could not open camera.") + + demo = InteractiveDemo(cap, device) + demo.run() + + +if __name__ == "__main__": + main() diff --git a/geocalib/lm_optimizer.py b/geocalib/lm_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..45c28154aed238b35b5c48f37d04700882621450 --- /dev/null +++ b/geocalib/lm_optimizer.py @@ -0,0 +1,642 @@ +"""Implementation of the Levenberg-Marquardt optimizer for camera calibration.""" + +import logging +import time +from types import SimpleNamespace +from typing import Any, Callable, Dict, Tuple + +import torch +import torch.nn as nn + +from geocalib.camera import BaseCamera, camera_models +from geocalib.gravity import Gravity +from geocalib.misc import J_focal2fov +from geocalib.perspective_fields import J_perspective_field, get_perspective_field +from geocalib.utils import focal2fov, rad2deg + +logger = logging.getLogger(__name__) + + +def get_trivial_estimation(data: Dict[str, torch.Tensor], camera_model: BaseCamera) -> BaseCamera: + """Get initial camera for optimization with roll=0, pitch=0, vfov=0.7 * max(h, w). + + Args: + data (Dict[str, torch.Tensor]): Input data dictionary. + camera_model (BaseCamera): Camera model to use. + + Returns: + BaseCamera: Initial camera for optimization. + """ + """Get initial camera for optimization with roll=0, pitch=0, vfov=0.7 * max(h, w).""" + ref = data.get("up_field", data["latitude_field"]) + ref = ref.detach() + + h, w = ref.shape[-2:] + batch_h, batch_w = ( + ref.new_ones((ref.shape[0],)) * h, + ref.new_ones((ref.shape[0],)) * w, + ) + + init_r = ref.new_zeros((ref.shape[0],)) + init_p = ref.new_zeros((ref.shape[0],)) + + focal = data.get("prior_focal", 0.7 * torch.max(batch_h, batch_w)) + init_vfov = focal2fov(focal, h) + + params = {"width": batch_w, "height": batch_h, "vfov": init_vfov} + params |= {"scales": data["scales"]} if "scales" in data else {} + params |= {"k1": data["prior_k1"]} if "prior_k1" in data else {} + camera = camera_model.from_dict(params) + camera = camera.float().to(ref.device) + + gravity = Gravity.from_rp(init_r, init_p).float().to(ref.device) + + if "prior_gravity" in data: + gravity = data["prior_gravity"].float().to(ref.device) + gravity = Gravity(gravity) if isinstance(gravity, torch.Tensor) else gravity + + return camera, gravity + + +def scaled_loss( + x: torch.Tensor, fn: Callable, a: float +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Apply a loss function to a tensor and pre- and post-scale it. + + Args: + x: the data tensor, should already be squared: `x = y**2`. + fn: the loss function, with signature `fn(x) -> y`. + a: the scale parameter. + + Returns: + The value of the loss, and its first and second derivatives. 
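+
+    Note: for y = a**2 * fn(x / a**2) the chain rule gives dy/dx = fn'(x / a**2) and
+    d2y/dx2 = fn''(x / a**2) / a**2, which is why only the value and the second derivative
+    are rescaled below.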
+ """ + a2 = a**2 + loss, loss_d1, loss_d2 = fn(x / a2) + return loss * a2, loss_d1, loss_d2 / a2 + + +def huber_loss(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """The classical robust Huber loss, with first and second derivatives.""" + mask = x <= 1 + sx = torch.sqrt(x + 1e-8) # avoid nan in backward pass + isx = torch.max(sx.new_tensor(torch.finfo(torch.float).eps), 1 / sx) + loss = torch.where(mask, x, 2 * sx - 1) + loss_d1 = torch.where(mask, torch.ones_like(x), isx) + loss_d2 = torch.where(mask, torch.zeros_like(x), -isx / (2 * x)) + return loss, loss_d1, loss_d2 + + +def early_stop(new_cost: torch.Tensor, prev_cost: torch.Tensor, atol: float, rtol: float) -> bool: + """Early stopping criterion based on cost convergence.""" + return torch.allclose(new_cost, prev_cost, atol=atol, rtol=rtol) + + +def update_lambda( + lamb: torch.Tensor, + prev_cost: torch.Tensor, + new_cost: torch.Tensor, + lambda_min: float = 1e-6, + lambda_max: float = 1e2, +) -> torch.Tensor: + """Update damping factor for Levenberg-Marquardt optimization.""" + new_lamb = lamb.new_zeros(lamb.shape) + new_lamb = lamb * torch.where(new_cost > prev_cost, 10, 0.1) + lamb = torch.clamp(new_lamb, lambda_min, lambda_max) + return lamb + + +def optimizer_step( + G: torch.Tensor, H: torch.Tensor, lambda_: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: + """One optimization step with Gauss-Newton or Levenberg-Marquardt. + + Args: + G (torch.Tensor): Batched gradient tensor of size (..., N). + H (torch.Tensor): Batched hessian tensor of size (..., N, N). + lambda_ (torch.Tensor): Damping factor for LM (use GN if lambda_=0) with shape (B,). + eps (float, optional): Epsilon for damping. Defaults to 1e-6. + + Returns: + torch.Tensor: Batched update tensor of size (..., N). + """ + diag = H.diagonal(dim1=-2, dim2=-1) + diag = diag * lambda_.unsqueeze(-1) # (B, 3) + + H = H + diag.clamp(min=eps).diag_embed() + + H_, G_ = H.cpu(), G.cpu() + try: + U = torch.linalg.cholesky(H_) + except RuntimeError: + logger.warning("Cholesky decomposition failed. Stopping.") + delta = H.new_zeros((H.shape[0], H.shape[-1])) # (B, 3) + else: + delta = torch.cholesky_solve(G_[..., None], U)[..., 0] + + return delta.to(H.device) + + +# mypy: ignore-errors +class LMOptimizer(nn.Module): + """Levenberg-Marquardt optimizer for camera calibration.""" + + default_conf = { + # Camera model parameters + "camera_model": "pinhole", # {"pinhole", "simple_radial", "simple_spherical"} + "shared_intrinsics": False, # share focal length across all images in batch + # LM optimizer parameters + "num_steps": 30, + "lambda_": 0.1, + "fix_lambda": False, + "early_stop": True, + "atol": 1e-8, + "rtol": 1e-8, + "use_spherical_manifold": True, # use spherical manifold for gravity optimization + "use_log_focal": True, # use log focal length for optimization + # Loss function parameters + "up_loss_fn_scale": 1e-2, + "lat_loss_fn_scale": 1e-2, + # Misc + "verbose": False, + } + + def __init__(self, conf: Dict[str, Any]): + """Initialize the LM optimizer.""" + super().__init__() + self.conf = conf = SimpleNamespace(**{**self.default_conf, **conf}) + self.num_steps = conf.num_steps + + self.set_camera_model(conf.camera_model) + self.setup_optimization_and_priors(shared_intrinsics=conf.shared_intrinsics) + + def set_camera_model(self, camera_model: str) -> None: + """Set the camera model to use for the optimization. + + Args: + camera_model (str): Camera model to use. 
+ """ + assert ( + camera_model in camera_models.keys() + ), f"Unknown camera model: {camera_model} not in {camera_models.keys()}" + self.camera_model = camera_models[camera_model] + self.camera_has_distortion = hasattr(self.camera_model, "dist") + + logger.debug( + f"Using camera model: {camera_model} (with distortion: {self.camera_has_distortion})" + ) + + def setup_optimization_and_priors( + self, data: Dict[str, torch.Tensor] = None, shared_intrinsics: bool = False + ) -> None: + """Setup the optimization and priors for the LM optimizer. + + Args: + data (Dict[str, torch.Tensor], optional): Dict potentially containing priors. Defaults + to None. + shared_intrinsics (bool, optional): Whether to share the intrinsics across the batch. + Defaults to False. + """ + if data is None: + data = {} + self.shared_intrinsics = shared_intrinsics + + if shared_intrinsics: # si => must use pinhole + assert ( + self.camera_model == camera_models["pinhole"] + ), f"Shared intrinsics only supported with pinhole camera model: {self.camera_model}" + + self.estimate_gravity = True + if "prior_gravity" in data: + self.estimate_gravity = False + logger.debug("Using provided gravity as prior.") + + self.estimate_focal = True + if "prior_focal" in data: + self.estimate_focal = False + logger.debug("Using provided focal as prior.") + + self.estimate_k1 = True + if "prior_k1" in data: + self.estimate_k1 = False + logger.debug("Using provided k1 as prior.") + + self.gravity_delta_dims = (0, 1) if self.estimate_gravity else (-1,) + self.focal_delta_dims = ( + (max(self.gravity_delta_dims) + 1,) if self.estimate_focal else (-1,) + ) + self.k1_delta_dims = (max(self.focal_delta_dims) + 1,) if self.estimate_k1 else (-1,) + + logger.debug(f"Camera Model: {self.camera_model}") + logger.debug(f"Optimizing gravity: {self.estimate_gravity} ({self.gravity_delta_dims})") + logger.debug(f"Optimizing focal: {self.estimate_focal} ({self.focal_delta_dims})") + logger.debug(f"Optimizing k1: {self.estimate_k1} ({self.k1_delta_dims})") + + logger.debug(f"Shared intrinsics: {self.shared_intrinsics}") + + def calculate_residuals( + self, camera: BaseCamera, gravity: Gravity, data: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """Calculate the residuals for the optimization. + + Args: + camera (BaseCamera): Optimized camera. + gravity (Gravity): Optimized gravity. + data (Dict[str, torch.Tensor]): Input data containing the up and latitude fields. + + Returns: + Dict[str, torch.Tensor]: Residuals for the optimization. + """ + perspective_up, perspective_lat = get_perspective_field(camera, gravity) + perspective_lat = torch.sin(perspective_lat) + + residuals = {} + if "up_field" in data: + up_residual = (data["up_field"] - perspective_up).permute(0, 2, 3, 1) + residuals["up_residual"] = up_residual.reshape(up_residual.shape[0], -1, 2) + + if "latitude_field" in data: + target_lat = torch.sin(data["latitude_field"]) + lat_residual = (target_lat - perspective_lat).permute(0, 2, 3, 1) + residuals["latitude_residual"] = lat_residual.reshape(lat_residual.shape[0], -1, 1) + + return residuals + + def calculate_costs( + self, residuals: torch.Tensor, data: Dict[str, torch.Tensor] + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + """Calculate the costs and weights for the optimization. + + Args: + residuals (torch.Tensor): Residuals for the optimization. + data (Dict[str, torch.Tensor]): Input data containing the up and latitude confidence. 
+ + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: Costs and weights for the + optimization. + """ + costs, weights = {}, {} + + if "up_residual" in residuals: + up_cost = (residuals["up_residual"] ** 2).sum(dim=-1) + up_cost, up_weight, _ = scaled_loss(up_cost, huber_loss, self.conf.up_loss_fn_scale) + + if "up_confidence" in data: + up_conf = data["up_confidence"].reshape(up_weight.shape[0], -1) + up_weight = up_weight * up_conf + up_cost = up_cost * up_conf + + costs["up_cost"] = up_cost + weights["up_weights"] = up_weight + + if "latitude_residual" in residuals: + lat_cost = (residuals["latitude_residual"] ** 2).sum(dim=-1) + lat_cost, lat_weight, _ = scaled_loss(lat_cost, huber_loss, self.conf.lat_loss_fn_scale) + + if "latitude_confidence" in data: + lat_conf = data["latitude_confidence"].reshape(lat_weight.shape[0], -1) + lat_weight = lat_weight * lat_conf + lat_cost = lat_cost * lat_conf + + costs["latitude_cost"] = lat_cost + weights["latitude_weights"] = lat_weight + + return costs, weights + + def calculate_gradient_and_hessian( + self, + J: torch.Tensor, + residuals: torch.Tensor, + weights: torch.Tensor, + shared_intrinsics: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculate the gradient and Hessian for given the Jacobian, residuals, and weights. + + Args: + J (torch.Tensor): Jacobian. + residuals (torch.Tensor): Residuals. + weights (torch.Tensor): Weights. + shared_intrinsics (bool): Whether to share the intrinsics across the batch. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Gradient and Hessian. + """ + dims = () + if self.estimate_gravity: + dims = (0, 1) + if self.estimate_focal: + dims += (2,) + if self.camera_has_distortion and self.estimate_k1: + dims += (3,) + assert dims, "No parameters to optimize" + + J = J[..., dims] + + Grad = torch.einsum("...Njk,...Nj->...Nk", J, residuals) + Grad = weights[..., None] * Grad + Grad = Grad.sum(-2) # (B, N_params) + + if shared_intrinsics: + # reshape to (1, B * (N_params-1) + 1) + Grad_g = Grad[..., :2].reshape(1, -1) + Grad_f = Grad[..., 2].reshape(1, -1).sum(-1, keepdim=True) + Grad = torch.cat([Grad_g, Grad_f], dim=-1) + + Hess = torch.einsum("...Njk,...Njl->...Nkl", J, J) + Hess = weights[..., None, None] * Hess + Hess = Hess.sum(-3) + + if shared_intrinsics: + H_g = torch.block_diag(*list(Hess[..., :2, :2])) + J_fg = Hess[..., :2, 2].flatten() + J_gf = Hess[..., 2, :2].flatten() + J_f = Hess[..., 2, 2].sum() + dims = H_g.shape[-1] + 1 + Hess = Hess.new_zeros((dims, dims), dtype=torch.float32) + Hess[:-1, :-1] = H_g + Hess[-1, :-1] = J_gf + Hess[:-1, -1] = J_fg + Hess[-1, -1] = J_f + Hess = Hess.unsqueeze(0) + + return Grad, Hess + + def setup_system( + self, + camera: BaseCamera, + gravity: Gravity, + residuals: Dict[str, torch.Tensor], + weights: Dict[str, torch.Tensor], + as_rpf: bool = False, + shared_intrinsics: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculate the gradient and Hessian for the optimization. + + Args: + camera (BaseCamera): Optimized camera. + gravity (Gravity): Optimized gravity. + residuals (Dict[str, torch.Tensor]): Residuals for the optimization. + weights (Dict[str, torch.Tensor]): Weights for the optimization. + as_rpf (bool, optional): Wether to calculate the gradient and Hessian with respect to + roll, pitch, and focal length. Defaults to False. + shared_intrinsics (bool, optional): Whether to share the intrinsics across the batch. + Defaults to False. 
+ + Returns: + Tuple[torch.Tensor, torch.Tensor]: Gradient and Hessian for the optimization. + """ + J_up, J_lat = J_perspective_field( + camera, + gravity, + spherical=self.conf.use_spherical_manifold and not as_rpf, + log_focal=self.conf.use_log_focal and not as_rpf, + ) + + J_up = J_up.reshape(J_up.shape[0], -1, J_up.shape[-2], J_up.shape[-1]) # (B, N, 2, 3) + J_lat = J_lat.reshape(J_lat.shape[0], -1, J_lat.shape[-2], J_lat.shape[-1]) # (B, N, 1, 3) + + n_params = ( + 2 * self.estimate_gravity + + self.estimate_focal + + (self.camera_has_distortion and self.estimate_k1) + ) + Grad = J_up.new_zeros(J_up.shape[0], n_params) + Hess = J_up.new_zeros(J_up.shape[0], n_params, n_params) + + if shared_intrinsics: + N_params = Grad.shape[0] * (n_params - 1) + 1 + Grad = Grad.new_zeros(1, N_params) + Hess = Hess.new_zeros(1, N_params, N_params) + + if "up_residual" in residuals: + Up_Grad, Up_Hess = self.calculate_gradient_and_hessian( + J_up, residuals["up_residual"], weights["up_weights"], shared_intrinsics + ) + + if self.conf.verbose: + logger.info(f"Up J:\n{Up_Grad.mean(0)}") + + Grad = Grad + Up_Grad + Hess = Hess + Up_Hess + + if "latitude_residual" in residuals: + Lat_Grad, Lat_Hess = self.calculate_gradient_and_hessian( + J_lat, + residuals["latitude_residual"], + weights["latitude_weights"], + shared_intrinsics, + ) + + if self.conf.verbose: + logger.info(f"Lat J:\n{Lat_Grad.mean(0)}") + + Grad = Grad + Lat_Grad + Hess = Hess + Lat_Hess + + return Grad, Hess + + def estimate_uncertainty( + self, + camera_opt: BaseCamera, + gravity_opt: Gravity, + errors: Dict[str, torch.Tensor], + weights: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: + """Estimate the uncertainty of the optimized camera and gravity at the final step. + + Args: + camera_opt (BaseCamera): Final optimized camera. + gravity_opt (Gravity): Final optimized gravity. + errors (Dict[str, torch.Tensor]): Costs for the optimization. + weights (Dict[str, torch.Tensor]): Weights for the optimization. + + Returns: + Dict[str, torch.Tensor]: Uncertainty estimates for the optimized camera and gravity. 
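+
+        The covariance is the inverse of the Gauss-Newton Hessian of the final residuals
+        (a Laplace-style approximation); the individual uncertainty entries below are
+        derived from this covariance matrix.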
+ """ + _, Hess = self.setup_system( + camera_opt, gravity_opt, errors, weights, as_rpf=True, shared_intrinsics=False + ) + Cov = torch.inverse(Hess) + + roll_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape) + pitch_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape) + gravity_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape) + if self.estimate_gravity: + roll_uncertainty = Cov[..., 0, 0] + pitch_uncertainty = Cov[..., 1, 1] + + try: + delta_uncertainty = Cov[..., :2, :2] + eigenvalues = torch.linalg.eigvalsh(delta_uncertainty.cpu()) + gravity_uncertainty = torch.max(eigenvalues, dim=-1).values.to(Cov.device) + except RuntimeError: + logger.warning("Could not calculate gravity uncertainty") + gravity_uncertainty = Cov.new_zeros(Cov.shape[0]) + + focal_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape) + fov_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape) + if self.estimate_focal: + focal_uncertainty = Cov[..., self.focal_delta_dims[0], self.focal_delta_dims[0]] + fov_uncertainty = ( + J_focal2fov(camera_opt.f[..., 1], camera_opt.size[..., 1]) ** 2 * focal_uncertainty + ) + + return { + "covariance": Cov, + "roll_uncertainty": torch.sqrt(roll_uncertainty), + "pitch_uncertainty": torch.sqrt(pitch_uncertainty), + "gravity_uncertainty": torch.sqrt(gravity_uncertainty), + "focal_uncertainty": torch.sqrt(focal_uncertainty) / 2, + "vfov_uncertainty": torch.sqrt(fov_uncertainty / 2), + } + + def update_estimate( + self, camera: BaseCamera, gravity: Gravity, delta: torch.Tensor + ) -> Tuple[BaseCamera, Gravity]: + """Update the camera and gravity estimates with the given delta. + + Args: + camera (BaseCamera): Optimized camera. + gravity (Gravity): Optimized gravity. + delta (torch.Tensor): Delta to update the camera and gravity estimates. + + Returns: + Tuple[BaseCamera, Gravity]: Updated camera and gravity estimates. + """ + delta_gravity = ( + delta[..., self.gravity_delta_dims] + if self.estimate_gravity + else delta.new_zeros(delta.shape[:-1] + (2,)) + ) + new_gravity = gravity.update(delta_gravity, spherical=self.conf.use_spherical_manifold) + + delta_f = ( + delta[..., self.focal_delta_dims] + if self.estimate_focal + else delta.new_zeros(delta.shape[:-1] + (1,)) + ) + new_camera = camera.update_focal(delta_f, as_log=self.conf.use_log_focal) + + delta_dist = ( + delta[..., self.k1_delta_dims] + if self.camera_has_distortion and self.estimate_k1 + else delta.new_zeros(delta.shape[:-1] + (1,)) + ) + if self.camera_has_distortion: + new_camera = new_camera.update_dist(delta_dist) + + return new_camera, new_gravity + + def optimize( + self, + data: Dict[str, torch.Tensor], + camera_opt: BaseCamera, + gravity_opt: Gravity, + ) -> Tuple[BaseCamera, Gravity, Dict[str, torch.Tensor]]: + """Optimize the camera and gravity estimates. + + Args: + data (Dict[str, torch.Tensor]): Input data. + camera_opt (BaseCamera): Optimized camera. + gravity_opt (Gravity): Optimized gravity. + + Returns: + Tuple[BaseCamera, Gravity, Dict[str, torch.Tensor]]: Optimized camera, gravity + estimates and optimization information. 
+ """ + key = list(data.keys())[0] + B = data[key].shape[0] + + lamb = data[key].new_ones(B) * self.conf.lambda_ + if self.shared_intrinsics: + lamb = data[key].new_ones(1) * self.conf.lambda_ + + infos = {"stop_at": self.num_steps} + for i in range(self.num_steps): + if self.conf.verbose: + logger.info(f"Step {i+1}/{self.num_steps}") + + errors = self.calculate_residuals(camera_opt, gravity_opt, data) + costs, weights = self.calculate_costs(errors, data) + + if i == 0: + prev_cost = sum(c.mean(-1) for c in costs.values()) + for k, c in costs.items(): + infos[f"initial_{k}"] = c.mean(-1) + + infos["initial_cost"] = prev_cost + + Grad, Hess = self.setup_system( + camera_opt, + gravity_opt, + errors, + weights, + shared_intrinsics=self.shared_intrinsics, + ) + delta = optimizer_step(Grad, Hess, lamb) # (B, N_params) + + if self.shared_intrinsics: + delta_g = delta[..., :-1].reshape(B, 2) + delta_f = delta[..., -1].expand(B, 1) + delta = torch.cat([delta_g, delta_f], dim=-1) + + # calculate new cost + camera_opt, gravity_opt = self.update_estimate(camera_opt, gravity_opt, delta) + new_cost, _ = self.calculate_costs( + self.calculate_residuals(camera_opt, gravity_opt, data), data + ) + new_cost = sum(c.mean(-1) for c in new_cost.values()) + + if not self.conf.fix_lambda and not self.shared_intrinsics: + lamb = update_lambda(lamb, prev_cost, new_cost) + + if self.conf.verbose: + logger.info(f"Cost:\nPrev: {prev_cost}\nNew: {new_cost}") + logger.info(f"Camera:\n{camera_opt._data}") + + if early_stop(new_cost, prev_cost, atol=self.conf.atol, rtol=self.conf.rtol): + infos["stop_at"] = min(i + 1, infos["stop_at"]) + + if self.conf.early_stop: + if self.conf.verbose: + logger.info(f"Early stopping at step {i+1}") + break + + prev_cost = new_cost + + if i == self.num_steps - 1 and self.conf.early_stop: + logger.warning("Reached maximum number of steps without convergence.") + + final_errors = self.calculate_residuals(camera_opt, gravity_opt, data) # (B, N, 3) + final_cost, weights = self.calculate_costs(final_errors, data) # (B, N) + + if not self.training: + infos |= self.estimate_uncertainty(camera_opt, gravity_opt, final_errors, weights) + + infos["stop_at"] = camera_opt.new_ones(camera_opt.shape[0]) * infos["stop_at"] + for k, c in final_cost.items(): + infos[f"final_{k}"] = c.mean(-1) + + infos["final_cost"] = sum(c.mean(-1) for c in final_cost.values()) + + return camera_opt, gravity_opt, infos + + def forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Run the LM optimization.""" + camera_init, gravity_init = get_trivial_estimation(data, self.camera_model) + + self.setup_optimization_and_priors(data, shared_intrinsics=self.shared_intrinsics) + + start = time.time() + camera_opt, gravity_opt, infos = self.optimize(data, camera_init, gravity_init) + + if self.conf.verbose: + logger.info(f"Optimization took {(time.time() - start)*1000:.2f} ms") + + logger.info(f"Initial camera:\n{rad2deg(camera_init.vfov)}") + logger.info(f"Optimized camera:\n{rad2deg(camera_opt.vfov)}") + + logger.info(f"Initial gravity:\n{rad2deg(gravity_init.rp)}") + logger.info(f"Optimized gravity:\n{rad2deg(gravity_opt.rp)}") + + return {"camera": camera_opt, "gravity": gravity_opt, **infos} diff --git a/geocalib/misc.py b/geocalib/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..e6c481fee5276a69470559d427497d53b14acf7d --- /dev/null +++ b/geocalib/misc.py @@ -0,0 +1,318 @@ +"""Miscellaneous functions and classes for the geocalib_inference package.""" + +import functools 
+import inspect +import logging +from typing import Callable, List + +import numpy as np +import torch + +logger = logging.getLogger(__name__) + +# mypy: ignore-errors + + +def autocast(func: Callable) -> Callable: + """Cast the inputs of a TensorWrapper method to PyTorch tensors if they are numpy arrays. + + Use the device and dtype of the wrapper. + + Args: + func (Callable): Method of a TensorWrapper class. + + Returns: + Callable: Wrapped method. + """ + + @functools.wraps(func) + def wrap(self, *args): + device = torch.device("cpu") + dtype = None + if isinstance(self, TensorWrapper): + if self._data is not None: + device = self.device + dtype = self.dtype + elif not inspect.isclass(self) or not issubclass(self, TensorWrapper): + raise ValueError(self) + + cast_args = [] + for arg in args: + if isinstance(arg, np.ndarray): + arg = torch.from_numpy(arg) + arg = arg.to(device=device, dtype=dtype) + cast_args.append(arg) + return func(self, *cast_args) + + return wrap + + +class TensorWrapper: + """Wrapper for PyTorch tensors.""" + + _data = None + + @autocast + def __init__(self, data: torch.Tensor): + """Wrapper for PyTorch tensors.""" + self._data = data + + @property + def shape(self) -> torch.Size: + """Shape of the underlying tensor.""" + return self._data.shape[:-1] + + @property + def device(self) -> torch.device: + """Get the device of the underlying tensor.""" + return self._data.device + + @property + def dtype(self) -> torch.dtype: + """Get the dtype of the underlying tensor.""" + return self._data.dtype + + def __getitem__(self, index) -> torch.Tensor: + """Get the underlying tensor.""" + return self.__class__(self._data[index]) + + def __setitem__(self, index, item): + """Set the underlying tensor.""" + self._data[index] = item.data + + def to(self, *args, **kwargs): + """Move the underlying tensor to a new device.""" + return self.__class__(self._data.to(*args, **kwargs)) + + def cpu(self): + """Move the underlying tensor to the CPU.""" + return self.__class__(self._data.cpu()) + + def cuda(self): + """Move the underlying tensor to the GPU.""" + return self.__class__(self._data.cuda()) + + def pin_memory(self): + """Pin the underlying tensor to memory.""" + return self.__class__(self._data.pin_memory()) + + def float(self): + """Cast the underlying tensor to float.""" + return self.__class__(self._data.float()) + + def double(self): + """Cast the underlying tensor to double.""" + return self.__class__(self._data.double()) + + def detach(self): + """Detach the underlying tensor.""" + return self.__class__(self._data.detach()) + + def numpy(self): + """Convert the underlying tensor to a numpy array.""" + return self._data.detach().cpu().numpy() + + def new_tensor(self, *args, **kwargs): + """Create a new tensor of the same type and device.""" + return self._data.new_tensor(*args, **kwargs) + + def new_zeros(self, *args, **kwargs): + """Create a new tensor of the same type and device.""" + return self._data.new_zeros(*args, **kwargs) + + def new_ones(self, *args, **kwargs): + """Create a new tensor of the same type and device.""" + return self._data.new_ones(*args, **kwargs) + + def new_full(self, *args, **kwargs): + """Create a new tensor of the same type and device.""" + return self._data.new_full(*args, **kwargs) + + def new_empty(self, *args, **kwargs): + """Create a new tensor of the same type and device.""" + return self._data.new_empty(*args, **kwargs) + + def unsqueeze(self, *args, **kwargs): + """Create a new tensor of the same type and device.""" + return 
self.__class__(self._data.unsqueeze(*args, **kwargs)) + + def squeeze(self, *args, **kwargs): + """Create a new tensor of the same type and device.""" + return self.__class__(self._data.squeeze(*args, **kwargs)) + + @classmethod + def stack(cls, objects: List, dim=0, *, out=None): + """Stack a list of objects with the same type and shape.""" + data = torch.stack([obj._data for obj in objects], dim=dim, out=out) + return cls(data) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + """Support torch functions.""" + if kwargs is None: + kwargs = {} + return cls.stack(*args, **kwargs) if func is torch.stack else NotImplemented + + +class EuclideanManifold: + """Simple euclidean manifold.""" + + @staticmethod + def J_plus(x: torch.Tensor) -> torch.Tensor: + """Plus operator Jacobian.""" + return torch.eye(x.shape[-1]).to(x) + + @staticmethod + def plus(x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor: + """Plus operator.""" + return x + delta + + +class SphericalManifold: + """Implementation of the spherical manifold. + + Following the derivation from 'Integrating Generic Sensor Fusion Algorithms with Sound State + Representations through Encapsulation of Manifolds' by Hertzberg et al. (B.2, p. 25). + + Householder transformation following Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by + Golub et al. + """ + + @staticmethod + def householder_vector(x: torch.Tensor) -> torch.Tensor: + """Return the Householder vector and beta. + + Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by Golub et al. (Johns Hopkins Studies + in Mathematical Sciences) but using the nth element of the input vector as pivot instead of + first. + + This computes the vector v with v(n) = 1 and beta such that H = I - beta * v * v^T is + orthogonal and H * x = ||x||_2 * e_n. + + Args: + x (torch.Tensor): [..., n] tensor. + + Returns: + torch.Tensor: v of shape [..., n] + torch.Tensor: beta of shape [...] + """ + sigma = torch.sum(x[..., :-1] ** 2, -1) + xpiv = x[..., -1] + norm = torch.norm(x, dim=-1) + if torch.any(sigma < 1e-7): + sigma = torch.where(sigma < 1e-7, sigma + 1e-7, sigma) + logger.warning("sigma < 1e-7") + + vpiv = torch.where(xpiv < 0, xpiv - norm, -sigma / (xpiv + norm)) + beta = 2 * vpiv**2 / (sigma + vpiv**2) + v = torch.cat([x[..., :-1] / vpiv[..., None], torch.ones_like(vpiv)[..., None]], -1) + return v, beta + + @staticmethod + def apply_householder(y: torch.Tensor, v: torch.Tensor, beta: torch.Tensor) -> torch.Tensor: + """Apply Householder transformation. + + Args: + y (torch.Tensor): Vector to transform of shape [..., n]. + v (torch.Tensor): Householder vector of shape [..., n]. + beta (torch.Tensor): Householder beta of shape [...]. + + Returns: + torch.Tensor: Transformed vector of shape [..., n]. + """ + return y - v * (beta * torch.einsum("...i,...i->...", v, y))[..., None] + + @classmethod + def J_plus(cls, x: torch.Tensor) -> torch.Tensor: + """Plus operator Jacobian.""" + v, beta = cls.householder_vector(x) + H = -torch.einsum("..., ...k, ...l->...kl", beta, v, v) + H = H + torch.eye(H.shape[-1]).to(H) + return H[..., :-1] # J + + @classmethod + def plus(cls, x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor: + """Plus operator. + + Equation 109 (p. 25) from 'Integrating Generic Sensor Fusion Algorithms with Sound State + Representations through Encapsulation of Manifolds' by Hertzberg et al. but using the nth + element of the input vector as pivot instead of first. 
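+
+        Concretely, exp(delta) = [sinc(||delta||) * delta, cos(||delta||)] is formed in the
+        tangent space and mapped back as ||x|| * H(x) @ exp(delta), where H(x) is the
+        Householder reflection that sends x / ||x|| to e_n.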
+ + Args: + x: point on the manifold + delta: tangent vector + """ + eps = 1e-7 + # keep norm is not equal to 1 + nx = torch.norm(x, dim=-1, keepdim=True) + nd = torch.norm(delta, dim=-1, keepdim=True) + + # make sure we don't divide by zero in backward as torch.where computes grad for both + # branches + nd_ = torch.where(nd < eps, nd + eps, nd) + sinc = torch.where(nd < eps, nd.new_ones(nd.shape), torch.sin(nd_) / nd_) + + # cos is applied to last dim instead of first + exp_delta = torch.cat([sinc * delta, torch.cos(nd)], -1) + + v, beta = cls.householder_vector(x) + return nx * cls.apply_householder(exp_delta, v, beta) + + +@torch.jit.script +def J_vecnorm(vec: torch.Tensor) -> torch.Tensor: + """Compute the jacobian of vec / norm2(vec). + + Args: + vec (torch.Tensor): [..., D] tensor. + + Returns: + torch.Tensor: [..., D, D] Jacobian. + """ + D = vec.shape[-1] + norm_x = torch.norm(vec, dim=-1, keepdim=True).unsqueeze(-1) # (..., 1, 1) + + if (norm_x == 0).any(): + norm_x = norm_x + 1e-6 + + xxT = torch.einsum("...i,...j->...ij", vec, vec) # (..., D, D) + identity = torch.eye(D, device=vec.device, dtype=vec.dtype) # (D, D) + + return identity / norm_x - (xxT / norm_x**3) # (..., D, D) + + +@torch.jit.script +def J_focal2fov(focal: torch.Tensor, h: torch.Tensor) -> torch.Tensor: + """Compute the jacobian of the focal2fov function.""" + return -4 * h / (4 * focal**2 + h**2) + + +@torch.jit.script +def J_up_projection(uv: torch.Tensor, abc: torch.Tensor, wrt: str = "uv") -> torch.Tensor: + """Compute the jacobian of the up-vector projection. + + Args: + uv (torch.Tensor): Normalized image coordinates of shape (..., 2). + abc (torch.Tensor): Gravity vector of shape (..., 3). + wrt (str, optional): Parameter to differentiate with respect to. Defaults to "uv". + + Raises: + ValueError: If the wrt parameter is unknown. + + Returns: + torch.Tensor: Jacobian with respect to the parameter. + """ + if wrt == "uv": + c = abc[..., 2][..., None, None, None] + return -c * torch.eye(2, device=uv.device, dtype=uv.dtype).expand(uv.shape[:-1] + (2, 2)) + + elif wrt == "abc": + J = uv.new_zeros(uv.shape[:-1] + (2, 3)) + J[..., 0, 0] = 1 + J[..., 1, 1] = 1 + J[..., 0, 2] = -uv[..., 0] + J[..., 1, 2] = -uv[..., 1] + return J + + else: + raise ValueError(f"Unknown wrt: {wrt}") diff --git a/geocalib/modules.py b/geocalib/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..f635e84bc9e577711bd3e508800011a3821ce026 --- /dev/null +++ b/geocalib/modules.py @@ -0,0 +1,575 @@ +"""Implementation of MSCAN from SegNeXt: Rethinking Convolutional Attention Design for Semantic +Segmentation (NeurIPS 2022) adapted from + +https://github.com/Visual-Attention-Network/SegNeXt/blob/main/mmseg/models/backbones/mscan.py + + +Light Hamburger Decoder adapted from: + +https://github.com/Visual-Attention-Network/SegNeXt/blob/main/mmseg/models/decode_heads/ham_head.py +""" + +from typing import Dict, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.modules.utils import _pair as to_2tuple + +# flake8: noqa: E266 +# mypy: ignore-errors + + +class ConvModule(nn.Module): + """Replacement for mmcv.cnn.ConvModule to avoid mmcv dependency.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + padding: int = 0, + use_norm: bool = False, + bias: bool = True, + ): + """Simple convolution block. + + Args: + in_channels (int): Input channels. + out_channels (int): Output channels. + kernel_size (int): Kernel size. 
+ padding (int, optional): Padding. Defaults to 0. + use_norm (bool, optional): Whether to use normalization. Defaults to False. + bias (bool, optional): Whether to use bias. Defaults to True. + """ + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) + self.bn = nn.BatchNorm2d(out_channels) if use_norm else nn.Identity() + self.activate = nn.ReLU(inplace=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" + x = self.conv(x) + x = self.bn(x) + return self.activate(x) + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features): + """Simple residual convolution block. + + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True) + + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__(self, features: int, unit2only=False, upsample=True): + """Feature fusion block. + + Args: + features (int): Number of features. + unit2only (bool, optional): Whether to use only the second unit. Defaults to False. + upsample (bool, optional): Whether to upsample. Defaults to True. + """ + super().__init__() + self.upsample = upsample + + if not unit2only: + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs: torch.Tensor) -> torch.Tensor: + """Forward pass.""" + output = xs[0] + + if len(xs) == 2: + output = output + self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + if self.upsample: + output = F.interpolate(output, scale_factor=2, mode="bilinear", align_corners=False) + + return output + + +################################################### +########### Light Hamburger Decoder ############### +################################################### + + +class NMF2D(nn.Module): + """Non-negative Matrix Factorization (NMF) for 2D data.""" + + def __init__(self): + """Non-negative Matrix Factorization (NMF) for 2D data.""" + super().__init__() + self.S, self.D, self.R = 1, 512, 64 + self.train_steps = 6 + self.eval_steps = 7 + self.inv_t = 1 + + def _build_bases(self, B: int, S: int, D: int, R: int, device: str = "cpu") -> torch.Tensor: + bases = torch.rand((B * S, D, R)).to(device) + return F.normalize(bases, dim=1) + + def local_step( + self, x: torch.Tensor, bases: torch.Tensor, coef: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Update bases and coefficient.""" + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + numerator = torch.bmm(x.transpose(1, 2), bases) + # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R) + denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) + # Multiplicative Update + coef = coef * numerator / (denominator + 1e-6) + # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R) + numerator = torch.bmm(x, coef) + # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R) + denominator = bases.bmm(coef.transpose(1, 2).bmm(coef)) + # Multiplicative Update + bases = bases * numerator / (denominator + 1e-6) + return bases, coef + + def compute_coef( + self, x: torch.Tensor, 
bases: torch.Tensor, coef: torch.Tensor + ) -> torch.Tensor: + """Compute coefficient.""" + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + numerator = torch.bmm(x.transpose(1, 2), bases) + # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R) + denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) + # multiplication update + return coef * numerator / (denominator + 1e-6) + + def local_inference( + self, x: torch.Tensor, bases: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Local inference.""" + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + coef = torch.bmm(x.transpose(1, 2), bases) + coef = F.softmax(self.inv_t * coef, dim=-1) + + steps = self.train_steps if self.training else self.eval_steps + for _ in range(steps): + bases, coef = self.local_step(x, bases, coef) + + return bases, coef + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" + B, C, H, W = x.shape + + # (B, C, H, W) -> (B * S, D, N) + D = C // self.S + N = H * W + x = x.view(B * self.S, D, N) + + # (S, D, R) -> (B * S, D, R) + bases = self._build_bases(B, self.S, D, self.R, device=x.device) + bases, coef = self.local_inference(x, bases) + # (B * S, N, R) + coef = self.compute_coef(x, bases, coef) + # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N) + x = torch.bmm(bases, coef.transpose(1, 2)) + # (B * S, D, N) -> (B, C, H, W) + x = x.view(B, C, H, W) + # (B * H, D, R) -> (B, H, N, D) + bases = bases.view(B, self.S, D, self.R) + + return x + + +class Hamburger(nn.Module): + """Hamburger Module.""" + + def __init__(self, ham_channels: int = 512): + """Hambuger Module. + + Args: + ham_channels (int, optional): Number of channels in the hamburger module. Defaults to + 512. + """ + super().__init__() + self.ham_in = ConvModule(ham_channels, ham_channels, 1) + self.ham = NMF2D() + self.ham_out = ConvModule(ham_channels, ham_channels, 1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" + enjoy = self.ham_in(x) + enjoy = F.relu(enjoy, inplace=False) + enjoy = self.ham(enjoy) + enjoy = self.ham_out(enjoy) + ham = F.relu(x + enjoy, inplace=False) + return ham + + +class LightHamHead(nn.Module): + """Is Attention Better Than Matrix Decomposition? + + This head is the implementation of `HamNet `. 
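+
+    Example (illustrative shape check added for clarity; assumes a 320x320 input encoded by the
+    MSCAN backbone below at strides 4/8/16/32, with random tensors standing in for the features):
+
+        head = LightHamHead()
+        hl = [torch.rand(1, c, 320 // s, 320 // s)
+              for c, s in zip([64, 128, 320, 512], [4, 8, 16, 32])]
+        ll = torch.rand(1, 64, 320, 320)  # stand-in for the full-resolution low-level features
+        feats, uncertainty = head({"hl": hl, "ll": ll})
+        # feats: (1, 64, 320, 320), uncertainty: (1, 320, 320)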
+ """ + + def __init__(self): + """Light hamburger decoder head.""" + super().__init__() + self.in_index = [0, 1, 2, 3] + self.in_channels = [64, 128, 320, 512] + self.out_channels = 64 + self.ham_channels = 512 + self.align_corners = False + + self.squeeze = ConvModule(sum(self.in_channels), self.ham_channels, 1) + + self.hamburger = Hamburger(self.ham_channels) + + self.align = ConvModule(self.ham_channels, self.out_channels, 1) + + self.linear_pred_uncertainty = nn.Sequential( + ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + padding=1, + bias=False, + ), + nn.Conv2d(in_channels=self.out_channels, out_channels=1, kernel_size=1), + ) + + self.out_conv = ConvModule(self.out_channels, self.out_channels, 3, padding=1, bias=False) + self.ll_fusion = FeatureFusionBlock(self.out_channels, upsample=False) + + def forward(self, features: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass.""" + inputs = [features["hl"][i] for i in self.in_index] + + inputs = [ + F.interpolate( + level, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners + ) + for level in inputs + ] + + inputs = torch.cat(inputs, dim=1) + x = self.squeeze(inputs) + + x = self.hamburger(x) + + feats = self.align(x) + + assert "ll" in features, "Low-level features are required for this model" + feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False) + feats = self.out_conv(feats) + feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False) + feats = self.ll_fusion(feats, features["ll"].clone()) + + uncertainty = self.linear_pred_uncertainty(feats).squeeze(1) + + return feats, uncertainty + + +################################################### +########### MSCAN ################ +################################################### + + +class DWConv(nn.Module): + """Depthwise convolution.""" + + def __init__(self, dim: int = 768): + """Depthwise convolution. + + Args: + dim (int, optional): Number of features. Defaults to 768. + """ + super().__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" + return self.dwconv(x) + + +class Mlp(nn.Module): + """MLP module.""" + + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 + ): + """Initialize the MLP.""" + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.drop = nn.Dropout(drop) + + def forward(self, x): + """Forward pass.""" + x = self.fc1(x) + + x = self.dwconv(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + + return x + + +class StemConv(nn.Module): + """Simple stem convolution module.""" + + def __init__(self, in_channels: int, out_channels: int): + """Simple stem convolution module. + + Args: + in_channels (int): Input channels. + out_channels (int): Output channels. 
+ """ + super().__init__() + + self.proj = nn.Sequential( + nn.Conv2d( + in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1) + ), + nn.BatchNorm2d(out_channels // 2), + nn.GELU(), + nn.Conv2d( + out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1) + ), + nn.BatchNorm2d(out_channels), + ) + + def forward(self, x): + """Forward pass.""" + x = self.proj(x) + _, _, H, W = x.size() + x = x.flatten(2).transpose(1, 2) + return x, H, W + + +class AttentionModule(nn.Module): + """Attention module.""" + + def __init__(self, dim: int): + """Attention module. + + Args: + dim (int): Number of features. + """ + super().__init__() + self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) + self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim) + self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim) + + self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim) + self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim) + + self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim) + self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim) + self.conv3 = nn.Conv2d(dim, dim, 1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" + u = x.clone() + attn = self.conv0(x) + + attn_0 = self.conv0_1(attn) + attn_0 = self.conv0_2(attn_0) + + attn_1 = self.conv1_1(attn) + attn_1 = self.conv1_2(attn_1) + + attn_2 = self.conv2_1(attn) + attn_2 = self.conv2_2(attn_2) + attn = attn + attn_0 + attn_1 + attn_2 + + attn = self.conv3(attn) + return attn * u + + +class SpatialAttention(nn.Module): + """Spatial attention module.""" + + def __init__(self, dim: int): + """Spatial attention module. + + Args: + dim (int): Number of features. + """ + super().__init__() + self.d_model = dim + self.proj_1 = nn.Conv2d(dim, dim, 1) + self.activation = nn.GELU() + self.spatial_gating_unit = AttentionModule(dim) + self.proj_2 = nn.Conv2d(dim, dim, 1) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass.""" + shorcut = x.clone() + x = self.proj_1(x) + x = self.activation(x) + x = self.spatial_gating_unit(x) + x = self.proj_2(x) + x = x + shorcut + return x + + +class Block(nn.Module): + """MSCAN block.""" + + def __init__( + self, dim: int, mlp_ratio: float = 4.0, drop: float = 0.0, act_layer: nn.Module = nn.GELU + ): + """MSCAN block. + + Args: + dim (int): Number of features. + mlp_ratio (float, optional): Ratio of the hidden features in the MLP. Defaults to 4.0. + drop (float, optional): Dropout rate. Defaults to 0.0. + act_layer (nn.Module, optional): Activation layer. Defaults to nn.GELU. 
+ """ + super().__init__() + self.norm1 = nn.BatchNorm2d(dim) + self.attn = SpatialAttention(dim) + self.drop_path = nn.Identity() # only used in training + self.norm2 = nn.BatchNorm2d(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop + ) + layer_scale_init_value = 1e-2 + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True + ) + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True + ) + + def forward(self, x: torch.Tensor, H: int, W: int) -> torch.Tensor: + """Forward pass.""" + B, N, C = x.shape + x = x.permute(0, 2, 1).view(B, C, H, W) + x = x + self.drop_path(self.layer_scale_1[..., None, None] * self.attn(self.norm1(x))) + x = x + self.drop_path(self.layer_scale_2[..., None, None] * self.mlp(self.norm2(x))) + return x.view(B, C, N).permute(0, 2, 1) + + +class OverlapPatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__( + self, patch_size: int = 7, stride: int = 4, in_chans: int = 3, embed_dim: int = 768 + ): + """Image to Patch Embedding. + + Args: + patch_size (int, optional): Image patch size. Defaults to 7. + stride (int, optional): Stride. Defaults to 4. + in_chans (int, optional): Number of input channels. Defaults to 3. + embed_dim (int, optional): Embedding dimension. Defaults to 768. + """ + super().__init__() + patch_size = to_2tuple(patch_size) + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2), + ) + self.norm = nn.BatchNorm2d(embed_dim) + + def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, int, int]: + """Forward pass.""" + x = self.proj(x) + _, _, H, W = x.shape + x = self.norm(x) + x = x.flatten(2).transpose(1, 2) + return x, H, W + + +class MSCAN(nn.Module): + """Multi-scale convolutional attention network.""" + + def __init__(self): + """Multi-scale convolutional attention network.""" + super().__init__() + self.in_channels = 3 + self.embed_dims = [64, 128, 320, 512] + self.mlp_ratios = [8, 8, 4, 4] + self.drop_rate = 0.0 + self.drop_path_rate = 0.1 + self.depths = [3, 3, 12, 3] + self.num_stages = 4 + + for i in range(self.num_stages): + if i == 0: + patch_embed = StemConv(3, self.embed_dims[0]) + else: + patch_embed = OverlapPatchEmbed( + patch_size=7 if i == 0 else 3, + stride=4 if i == 0 else 2, + in_chans=self.in_chans if i == 0 else self.embed_dims[i - 1], + embed_dim=self.embed_dims[i], + ) + + block = nn.ModuleList( + [ + Block( + dim=self.embed_dims[i], + mlp_ratio=self.mlp_ratios[i], + drop=self.drop_rate, + ) + for _ in range(self.depths[i]) + ] + ) + norm = nn.LayerNorm(self.embed_dims[i]) + + setattr(self, f"patch_embed{i + 1}", patch_embed) + setattr(self, f"block{i + 1}", block) + setattr(self, f"norm{i + 1}", norm) + + def forward(self, data): + """Forward pass.""" + # rgb -> bgr and from [0, 1] to [0, 255] + x = data["image"][:, [2, 1, 0], :, :] * 255.0 + B = x.shape[0] + + outs = [] + for i in range(self.num_stages): + patch_embed = getattr(self, f"patch_embed{i + 1}") + block = getattr(self, f"block{i + 1}") + norm = getattr(self, f"norm{i + 1}") + x, H, W = patch_embed(x) + for blk in block: + x = blk(x, H, W) + x = norm(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + return {"features": outs} diff --git a/geocalib/perspective_fields.py b/geocalib/perspective_fields.py new file mode 100644 index 
0000000000000000000000000000000000000000..f6b099ce4e9da1a31b785969c95a4af94029cf76 --- /dev/null +++ b/geocalib/perspective_fields.py @@ -0,0 +1,366 @@ +"""Implementation of perspective fields. + +Adapted from https://github.com/jinlinyi/PerspectiveFields/blob/main/perspective2d/utils/panocam.py +""" + +from typing import Tuple + +import torch +from torch.nn import functional as F + +from geocalib.camera import BaseCamera +from geocalib.gravity import Gravity +from geocalib.misc import J_up_projection, J_vecnorm, SphericalManifold + +# flake8: noqa: E266 + + +def get_horizon_line(camera: BaseCamera, gravity: Gravity, relative: bool = True) -> torch.Tensor: + """Get the horizon line from the camera parameters. + + Args: + camera (Camera): Camera parameters. + gravity (Gravity): Gravity vector. + relative (bool, optional): Whether to normalize horizon line by img_h. Defaults to True. + + Returns: + torch.Tensor: In image frame, fraction of image left/right border intersection with + respect to image height. + """ + camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera + gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity + + # project horizon midpoint to image plane + horizon_midpoint = camera.new_tensor([0, 0, 1]) + horizon_midpoint = camera.K @ gravity.R @ horizon_midpoint + midpoint = horizon_midpoint[:2] / horizon_midpoint[2] + + # compute left and right offset to borders + left_offset = midpoint[0] * torch.tan(gravity.roll) + right_offset = (camera.size[0] - midpoint[0]) * torch.tan(gravity.roll) + left, right = midpoint[1] + left_offset, midpoint[1] - right_offset + + horizon = camera.new_tensor([left, right]) + return horizon / camera.size[1] if relative else horizon + + +def get_up_field(camera: BaseCamera, gravity: Gravity, normalize: bool = True) -> torch.Tensor: + """Get the up vector field from the camera parameters. + + Args: + camera (Camera): Camera parameters. + normalize (bool, optional): Whether to normalize the up vector. Defaults to True. + + Returns: + torch.Tensor: up vector field as tensor of shape (..., h, w, 2). + """ + camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera + gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity + + w, h = camera.size[0].unbind(-1) + h, w = h.round().to(int), w.round().to(int) + + uv = camera.normalize(camera.pixel_coordinates()) + + # projected up is (a, b) - c * (u, v) + abc = gravity.vec3d + projected_up2d = abc[..., None, :2] - abc[..., 2, None, None] * uv # (..., N, 2) + + if hasattr(camera, "dist"): + d_uv = camera.distort(uv, return_scale=True)[0] # (..., N, 1) + d_uv = torch.diag_embed(d_uv.expand(d_uv.shape[:-1] + (2,))) # (..., N, 2, 2) + offset = camera.up_projection_offset(uv) # (..., N, 2) + offset = torch.einsum("...i,...j->...ij", offset, uv) # (..., N, 2, 2) + + # (..., N, 2) + projected_up2d = torch.einsum("...Nij,...Nj->...Ni", d_uv + offset, projected_up2d) + + if normalize: + projected_up2d = F.normalize(projected_up2d, dim=-1) # (..., N, 2) + + return projected_up2d.reshape(camera.shape[0], h, w, 2) + + +def J_up_field( + camera: BaseCamera, gravity: Gravity, spherical: bool = False, log_focal: bool = False +) -> torch.Tensor: + """Get the jacobian of the up field. + + Args: + camera (Camera): Camera parameters. + gravity (Gravity): Gravity vector. + spherical (bool, optional): Whether to use spherical coordinates. Defaults to False. + log_focal (bool, optional): Whether to use log-focal length. Defaults to False. 
+ + Returns: + torch.Tensor: Jacobian of the up field as a tensor of shape (..., h, w, 2, 2, 3). + """ + camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera + gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity + + w, h = camera.size[0].unbind(-1) + h, w = h.round().to(int), w.round().to(int) + + # Forward + xy = camera.pixel_coordinates() + uv = camera.normalize(xy) + + projected_up2d = gravity.vec3d[..., None, :2] - gravity.vec3d[..., 2, None, None] * uv + + # Backward + J = [] + + # (..., N, 2, 2) + J_norm2proj = J_vecnorm( + get_up_field(camera, gravity, normalize=False).reshape(camera.shape[0], -1, 2) + ) + + # distortion values + if hasattr(camera, "dist"): + d_uv = camera.distort(uv, return_scale=True)[0] # (..., N, 1) + d_uv = torch.diag_embed(d_uv.expand(d_uv.shape[:-1] + (2,))) # (..., N, 2, 2) + offset = camera.up_projection_offset(uv) # (..., N, 2) + offset_uv = torch.einsum("...i,...j->...ij", offset, uv) # (..., N, 2, 2) + + ###################### + ## Gravity Jacobian ## + ###################### + + J_proj2abc = J_up_projection(uv, gravity.vec3d, wrt="abc") # (..., N, 2, 3) + + if hasattr(camera, "dist"): + # (..., N, 2, 3) + J_proj2abc = torch.einsum("...Nij,...Njk->...Nik", d_uv + offset_uv, J_proj2abc) + + J_abc2delta = SphericalManifold.J_plus(gravity.vec3d) if spherical else gravity.J_rp() + J_proj2delta = torch.einsum("...Nij,...jk->...Nik", J_proj2abc, J_abc2delta) + J_up2delta = torch.einsum("...Nij,...Njk->...Nik", J_norm2proj, J_proj2delta) + J.append(J_up2delta) + + ###################### + ### Focal Jacobian ### + ###################### + + J_proj2uv = J_up_projection(uv, gravity.vec3d, wrt="uv") # (..., N, 2, 2) + + if hasattr(camera, "dist"): + J_proj2up = torch.einsum("...Nij,...Njk->...Nik", d_uv + offset_uv, J_proj2uv) + J_proj2duv = torch.einsum("...i,...j->...ji", offset, projected_up2d) + + inner = (uv * projected_up2d).sum(-1)[..., None, None] + J_proj2offset1 = inner * camera.J_up_projection_offset(uv, wrt="uv") + J_proj2offset2 = torch.einsum("...i,...j->...ij", offset, projected_up2d) # (..., N, 2, 2) + J_proj2uv = (J_proj2duv + J_proj2offset1 + J_proj2offset2) + J_proj2up + + J_uv2f = camera.J_normalize(xy) # (..., N, 2, 2) + + if log_focal: + J_uv2f = J_uv2f * camera.f[..., None, None, :] # (..., N, 2, 2) + + J_uv2f = J_uv2f.sum(-1) # (..., N, 2) + + J_proj2f = torch.einsum("...ij,...j->...i", J_proj2uv, J_uv2f) # (..., N, 2) + J_up2f = torch.einsum("...Nij,...Nj->...Ni", J_norm2proj, J_proj2f)[..., None] # (..., N, 2, 1) + J.append(J_up2f) + + ###################### + ##### K1 Jacobian #### + ###################### + + if hasattr(camera, "dist"): + J_duv = camera.J_distort(uv, wrt="scale2dist") + J_duv = torch.diag_embed(J_duv.expand(J_duv.shape[:-1] + (2,))) # (..., N, 2, 2) + J_offset = torch.einsum( + "...i,...j->...ij", camera.J_up_projection_offset(uv, wrt="dist"), uv + ) + J_proj2k1 = torch.einsum("...Nij,...Nj->...Ni", J_duv + J_offset, projected_up2d) + J_k1 = torch.einsum("...Nij,...Nj->...Ni", J_norm2proj, J_proj2k1)[..., None] + J.append(J_k1) + + n_params = sum(j.shape[-1] for j in J) + return torch.cat(J, axis=-1).reshape(camera.shape[0], h, w, 2, n_params) + + +def get_latitude_field(camera: BaseCamera, gravity: Gravity) -> torch.Tensor: + """Get the latitudes of the camera pixels in radians. + + Latitudes are defined as the angle between the ray and the up vector. + + Args: + camera (Camera): Camera parameters. + gravity (Gravity): Gravity vector. 
+ + Returns: + torch.Tensor: Latitudes in radians as a tensor of shape (..., h, w, 1). + """ + camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera + gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity + + w, h = camera.size[0].unbind(-1) + h, w = h.round().to(int), w.round().to(int) + + uv1, _ = camera.image2world(camera.pixel_coordinates()) + rays = camera.pixel_bearing_many(uv1) + + lat = torch.einsum("...Nj,...j->...N", rays, gravity.vec3d) + + eps = 1e-6 + lat_asin = torch.asin(lat.clamp(min=-1 + eps, max=1 - eps)) + + return lat_asin.reshape(camera.shape[0], h, w, 1) + + +def J_latitude_field( + camera: BaseCamera, gravity: Gravity, spherical: bool = False, log_focal: bool = False +) -> torch.Tensor: + """Get the jacobian of the latitude field. + + Args: + camera (Camera): Camera parameters. + gravity (Gravity): Gravity vector. + spherical (bool, optional): Whether to use spherical coordinates. Defaults to False. + log_focal (bool, optional): Whether to use log-focal length. Defaults to False. + + Returns: + torch.Tensor: Jacobian of the latitude field as a tensor of shape (..., h, w, 1, 3). + """ + camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera + gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity + + w, h = camera.size[0].unbind(-1) + h, w = h.round().to(int), w.round().to(int) + + # Forward + xy = camera.pixel_coordinates() + uv1, _ = camera.image2world(xy) + uv1_norm = camera.pixel_bearing_many(uv1) # (..., N, 3) + + # Backward + J = [] + J_norm2w_to_img = J_vecnorm(uv1)[..., :2] # (..., N, 2) + + ###################### + ## Gravity Jacobian ## + ###################### + + J_delta = SphericalManifold.J_plus(gravity.vec3d) if spherical else gravity.J_rp() + J_delta = torch.einsum("...Ni,...ij->...Nj", uv1_norm, J_delta) # (..., N, 2) + J.append(J_delta) + + ###################### + ### Focal Jacobian ### + ###################### + + J_w_to_img2f = camera.J_image2world(xy, "f") # (..., N, 2, 2) + if log_focal: + J_w_to_img2f = J_w_to_img2f * camera.f[..., None, None, :] + J_w_to_img2f = J_w_to_img2f.sum(-1) # (..., N, 2) + + J_norm2f = torch.einsum("...Nij,...Nj->...Ni", J_norm2w_to_img, J_w_to_img2f) # (..., N, 3) + J_f = torch.einsum("...Ni,...i->...N", J_norm2f, gravity.vec3d).unsqueeze(-1) # (..., N, 1) + J.append(J_f) + + ###################### + ##### K1 Jacobian #### + ###################### + + if hasattr(camera, "dist"): + J_w_to_img2k1 = camera.J_image2world(xy, "dist") # (..., N, 2) + # (..., N, 2) + J_norm2k1 = torch.einsum("...Nij,...Nj->...Ni", J_norm2w_to_img, J_w_to_img2k1) + # (..., N, 1) + J_k1 = torch.einsum("...Ni,...i->...N", J_norm2k1, gravity.vec3d).unsqueeze(-1) + J.append(J_k1) + + n_params = sum(j.shape[-1] for j in J) + return torch.cat(J, axis=-1).reshape(camera.shape[0], h, w, 1, n_params) + + +def get_perspective_field( + camera: BaseCamera, + gravity: Gravity, + use_up: bool = True, + use_latitude: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the perspective field from the camera parameters. + + Args: + camera (Camera): Camera parameters. + gravity (Gravity): Gravity vector. + use_up (bool, optional): Whether to include the up vector field. Defaults to True. + use_latitude (bool, optional): Whether to include the latitude field. Defaults to True. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Up and latitude fields as tensors of shape + (..., 2, h, w) and (..., 1, h, w). 
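+
+    Example (illustrative usage added for clarity; assumes the calibration interface used in
+    gradio_app.py and the default pinhole camera model):
+
+        model = GeoCalib()
+        results = model.calibrate(model.load_image("assets/pinhole-church.jpg"))
+        up, lat = get_perspective_field(results["camera"], results["gravity"])
+        # up: (1, 2, h, w) unit up-vectors; lat: (1, 1, h, w) latitudes in radians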
+ """ + assert use_up or use_latitude, "At least one of use_up or use_latitude must be True." + + camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera + gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity + + w, h = camera.size[0].unbind(-1) + h, w = h.round().to(int), w.round().to(int) + + if use_up: + permute = (0, 3, 1, 2) + # (..., 2, h, w) + up = get_up_field(camera, gravity).permute(permute) + else: + shape = (camera.shape[0], 2, h, w) + up = camera.new_zeros(shape) + + if use_latitude: + permute = (0, 3, 1, 2) + # (..., 1, h, w) + lat = get_latitude_field(camera, gravity).permute(permute) + else: + shape = (camera.shape[0], 1, h, w) + lat = camera.new_zeros(shape) + + return up, lat + + +def J_perspective_field( + camera: BaseCamera, + gravity: Gravity, + use_up: bool = True, + use_latitude: bool = True, + spherical: bool = False, + log_focal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the jacobian of the perspective field. + + Args: + camera (Camera): Camera parameters. + gravity (Gravity): Gravity vector. + use_up (bool, optional): Whether to include the up vector field. Defaults to True. + use_latitude (bool, optional): Whether to include the latitude field. Defaults to True. + spherical (bool, optional): Whether to use spherical coordinates. Defaults to False. + log_focal (bool, optional): Whether to use log-focal length. Defaults to False. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Up and latitude jacobians as tensors of shape + (..., h, w, 2, 4) and (..., h, w, 1, 4). + """ + assert use_up or use_latitude, "At least one of use_up or use_latitude must be True." + + camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera + gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity + + w, h = camera.size[0].unbind(-1) + h, w = h.round().to(int), w.round().to(int) + + if use_up: + J_up = J_up_field(camera, gravity, spherical, log_focal) # (..., h, w, 2, 4) + else: + shape = (camera.shape[0], h, w, 2, 4) + J_up = camera.new_zeros(shape) + + if use_latitude: + J_lat = J_latitude_field(camera, gravity, spherical, log_focal) # (..., h, w, 1, 4) + else: + shape = (camera.shape[0], h, w, 1, 4) + J_lat = camera.new_zeros(shape) + + return J_up, J_lat diff --git a/geocalib/utils.py b/geocalib/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5372e157443502c6f4e5cacf46968839e14af341 --- /dev/null +++ b/geocalib/utils.py @@ -0,0 +1,325 @@ +"""Image loading and general conversion utilities.""" + +import collections.abc as collections +from pathlib import Path +from types import SimpleNamespace +from typing import Dict, Optional, Tuple + +import cv2 +import kornia +import numpy as np +import torch +import torchvision + +# mypy: ignore-errors + + +def fit_to_multiple(x: torch.Tensor, multiple: int, mode: str = "center", crop: bool = False): + """Get padding to make the image size a multiple of the given number. + + Args: + x (torch.Tensor): Input tensor. + multiple (int, optional): Multiple to fit to. + crop (bool, optional): Whether to crop or pad. Defaults to False. + + Returns: + torch.Tensor: Padding. 
+ """ + h, w = x.shape[-2:] + + if crop: + pad_w = (w // multiple) * multiple - w + pad_h = (h // multiple) * multiple - h + else: + pad_w = (multiple - w % multiple) % multiple + pad_h = (multiple - h % multiple) % multiple + + if mode == "center": + pad_l = pad_w // 2 + pad_r = pad_w - pad_l + pad_t = pad_h // 2 + pad_b = pad_h - pad_t + elif mode == "left": + pad_l, pad_r = 0, pad_w + pad_t, pad_b = 0, pad_h + else: + raise ValueError(f"Unknown mode {mode}") + + return (pad_l, pad_r, pad_t, pad_b) + + +def fit_features_to_multiple( + features: torch.Tensor, multiple: int = 32, crop: bool = False +) -> Tuple[torch.Tensor, Tuple[int, int]]: + """Pad or crop image to a multiple of the given number. + + Args: + features (torch.Tensor): Input features. + multiple (int, optional): Multiple. Defaults to 32. + crop (bool, optional): Whether to crop or pad. Defaults to False. + + Returns: + Tuple[torch.Tensor, Tuple[int, int]]: Padded features and padding. + """ + pad = fit_to_multiple(features, multiple, crop=crop) + return torch.nn.functional.pad(features, pad, mode="reflect"), pad + + +class ImagePreprocessor: + """Preprocess images for calibration.""" + + default_conf = { + "resize": 320, # target edge length, None for no resizing + "edge_divisible_by": None, + "side": "short", + "interpolation": "bilinear", + "align_corners": None, + "antialias": True, + "square_crop": False, + "add_padding_mask": False, + "resize_backend": "kornia", # torchvision, kornia + } + + def __init__(self, conf) -> None: + """Initialize the image preprocessor.""" + self.conf = {**self.default_conf, **conf} + self.conf = SimpleNamespace(**self.conf) + + def __call__(self, img: torch.Tensor, interpolation: Optional[str] = None) -> dict: + """Resize and preprocess an image, return image and resize scale.""" + h, w = img.shape[-2:] + size = h, w + + if self.conf.square_crop: + min_size = min(h, w) + offset = (h - min_size) // 2, (w - min_size) // 2 + img = img[:, offset[0] : offset[0] + min_size, offset[1] : offset[1] + min_size] + size = img.shape[-2:] + + if self.conf.resize is not None: + if interpolation is None: + interpolation = self.conf.interpolation + size = self.get_new_image_size(h, w) + img = self.resize(img, size, interpolation) + + scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img) + T = np.diag([scale[0].cpu(), scale[1].cpu(), 1]) + + data = { + "scales": scale, + "image_size": np.array(size[::-1]), + "transform": T, + "original_image_size": np.array([w, h]), + } + + if self.conf.edge_divisible_by is not None: + # crop to make the edge divisible by a number + w_, h_ = img.shape[-1], img.shape[-2] + img, _ = fit_features_to_multiple(img, self.conf.edge_divisible_by, crop=True) + crop_pad = torch.Tensor([img.shape[-1] - w_, img.shape[-2] - h_]).to(img) + data["crop_pad"] = crop_pad + data["image_size"] = np.array([img.shape[-1], img.shape[-2]]) + + data["image"] = img + return data + + def resize(self, img: torch.Tensor, size: Tuple[int, int], interpolation: str) -> torch.Tensor: + """Resize an image using the specified backend.""" + if self.conf.resize_backend == "kornia": + return kornia.geometry.transform.resize( + img, + size, + side=self.conf.side, + antialias=self.conf.antialias, + align_corners=self.conf.align_corners, + interpolation=interpolation, + ) + elif self.conf.resize_backend == "torchvision": + return torchvision.transforms.Resize(size, antialias=self.conf.antialias)(img) + else: + raise ValueError(f"{self.conf.resize_backend} not implemented.") + + def load_image(self, 
image_path: Path) -> dict: + """Load an image from a path and preprocess it.""" + return self(load_image(image_path)) + + def get_new_image_size(self, h: int, w: int) -> Tuple[int, int]: + """Get the new image size after resizing.""" + side = self.conf.side + if isinstance(self.conf.resize, collections.Iterable): + assert len(self.conf.resize) == 2 + return tuple(self.conf.resize) + side_size = self.conf.resize + aspect_ratio = w / h + if side not in ("short", "long", "vert", "horz"): + raise ValueError( + f"side can be one of 'short', 'long', 'vert', and 'horz'. Got '{side}'" + ) + return ( + (side_size, int(side_size * aspect_ratio)) + if side == "vert" or (side != "horz" and (side == "short") ^ (aspect_ratio < 1.0)) + else (int(side_size / aspect_ratio), side_size) + ) + + +def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor: + """Normalize the image tensor and reorder the dimensions.""" + if image.ndim == 3: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + elif image.ndim == 2: + image = image[None] # add channel axis + else: + raise ValueError(f"Not an image: {image.shape}") + return torch.tensor(image / 255.0, dtype=torch.float) + + +def torch_image_to_numpy(image: torch.Tensor) -> np.ndarray: + """Normalize and reorder the dimensions of an image tensor.""" + if image.ndim == 3: + image = image.permute((1, 2, 0)) # CxHxW to HxWxC + elif image.ndim == 2: + image = image[None] # add channel axis + else: + raise ValueError(f"Not an image: {image.shape}") + return (image.cpu().detach().numpy() * 255).astype(np.uint8) + + +def read_image(path: Path, grayscale: bool = False) -> np.ndarray: + """Read an image from path as RGB or grayscale.""" + if not Path(path).exists(): + raise FileNotFoundError(f"No image at path {path}.") + mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR + image = cv2.imread(str(path), mode) + if image is None: + raise IOError(f"Could not read image at {path}.") + if not grayscale: + image = image[..., ::-1] + return image + + +def write_image(img: torch.Tensor, path: Path): + """Write an image tensor to a file.""" + img = torch_image_to_numpy(img) if isinstance(img, torch.Tensor) else img + cv2.imwrite(str(path), img[..., ::-1]) + + +def load_image(path: Path, grayscale: bool = False, return_tensor: bool = True) -> torch.Tensor: + """Load an image from a path and return as a tensor.""" + image = read_image(path, grayscale=grayscale) + if return_tensor: + return numpy_image_to_torch(image) + + assert image.ndim in [2, 3], f"Not an image: {image.shape}" + image = image[None] if image.ndim == 2 else image + return torch.tensor(image.copy(), dtype=torch.uint8) + + +def skew_symmetric(v: torch.Tensor) -> torch.Tensor: + """Create a skew-symmetric matrix from a (batched) vector of size (..., 3). + + Args: + (torch.Tensor): Vector of size (..., 3). + + Returns: + (torch.Tensor): Skew-symmetric matrix of size (..., 3, 3). + """ + z = torch.zeros_like(v[..., 0]) + return torch.stack( + [z, -v[..., 2], v[..., 1], v[..., 2], z, -v[..., 0], -v[..., 1], v[..., 0], z], dim=-1 + ).reshape(v.shape[:-1] + (3, 3)) + + +def rad2rotmat( + roll: torch.Tensor, pitch: torch.Tensor, yaw: Optional[torch.Tensor] = None +) -> torch.Tensor: + """Convert (batched) roll, pitch, yaw angles (in radians) to rotation matrix. + + Args: + roll (torch.Tensor): Roll angle in radians. + pitch (torch.Tensor): Pitch angle in radians. + yaw (torch.Tensor, optional): Yaw angle in radians. Defaults to None. + + Returns: + torch.Tensor: Rotation matrix of shape (..., 3, 3). 
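+
+    Example (illustrative, added for clarity):
+
+        R = rad2rotmat(torch.tensor(0.1), torch.tensor(0.2))  # roll=0.1, pitch=0.2, yaw=0
+        # R has shape (3, 3) and is composed as Rz(roll) @ Rx(pitch) @ Ry(yaw)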
+ """ + if yaw is None: + yaw = roll.new_zeros(roll.shape) + + Rx = pitch.new_zeros(pitch.shape + (3, 3)) + Rx[..., 0, 0] = 1 + Rx[..., 1, 1] = torch.cos(pitch) + Rx[..., 1, 2] = torch.sin(pitch) + Rx[..., 2, 1] = -torch.sin(pitch) + Rx[..., 2, 2] = torch.cos(pitch) + + Ry = yaw.new_zeros(yaw.shape + (3, 3)) + Ry[..., 0, 0] = torch.cos(yaw) + Ry[..., 0, 2] = -torch.sin(yaw) + Ry[..., 1, 1] = 1 + Ry[..., 2, 0] = torch.sin(yaw) + Ry[..., 2, 2] = torch.cos(yaw) + + Rz = roll.new_zeros(roll.shape + (3, 3)) + Rz[..., 0, 0] = torch.cos(roll) + Rz[..., 0, 1] = torch.sin(roll) + Rz[..., 1, 0] = -torch.sin(roll) + Rz[..., 1, 1] = torch.cos(roll) + Rz[..., 2, 2] = 1 + + return Rz @ Rx @ Ry + + +def fov2focal(fov: torch.Tensor, size: torch.Tensor) -> torch.Tensor: + """Compute focal length from (vertical/horizontal) field of view.""" + return size / 2 / torch.tan(fov / 2) + + +def focal2fov(focal: torch.Tensor, size: torch.Tensor) -> torch.Tensor: + """Compute (vertical/horizontal) field of view from focal length.""" + return 2 * torch.arctan(size / (2 * focal)) + + +def pitch2rho(pitch: torch.Tensor, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor: + """Compute the distance from principal point to the horizon.""" + return torch.tan(pitch) * f / h + + +def rho2pitch(rho: torch.Tensor, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor: + """Compute the pitch angle from the distance to the horizon.""" + return torch.atan(rho * h / f) + + +def rad2deg(rad: torch.Tensor) -> torch.Tensor: + """Convert radians to degrees.""" + return rad / torch.pi * 180 + + +def deg2rad(deg: torch.Tensor) -> torch.Tensor: + """Convert degrees to radians.""" + return deg / 180 * torch.pi + + +def get_device() -> str: + """Get the device (cpu, cuda, mps) available.""" + device = "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + return device + + +def print_calibration(results: Dict[str, torch.Tensor]) -> None: + """Print the calibration results.""" + camera, gravity = results["camera"], results["gravity"] + vfov = rad2deg(camera.vfov) + roll, pitch = rad2deg(gravity.rp).unbind(-1) + + print("\nEstimated parameters (Pred):") + print(f"Roll: {roll.item():.1f}° (± {rad2deg(results['roll_uncertainty']).item():.1f})°") + print(f"Pitch: {pitch.item():.1f}° (± {rad2deg(results['pitch_uncertainty']).item():.1f})°") + print(f"vFoV: {vfov.item():.1f}° (± {rad2deg(results['vfov_uncertainty']).item():.1f})°") + print(f"Focal: {camera.f[0, 1].item():.1f} px (± {results['focal_uncertainty'].item():.1f} px)") + + if hasattr(camera, "k1"): + print(f"K1: {camera.k1.item():.1f}") diff --git a/geocalib/viz2d.py b/geocalib/viz2d.py new file mode 100644 index 0000000000000000000000000000000000000000..2747684c826d0ab610d399d99cd4eb5b22fbfd27 --- /dev/null +++ b/geocalib/viz2d.py @@ -0,0 +1,502 @@ +"""2D visualization primitives based on Matplotlib. + +1) Plot images with `plot_images`. +2) Call functions to plot heatmaps, vector fields, and horizon lines. +3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`. +""" + +import matplotlib.patheffects as path_effects +import matplotlib.pyplot as plt +import numpy as np +import torch + +from geocalib.perspective_fields import get_perspective_field +from geocalib.utils import rad2deg + +# mypy: ignore-errors + + +def plot_images(imgs, titles=None, cmaps="gray", dpi=200, pad=0.5, adaptive=True): + """Plot a list of images. + + Args: + imgs (List[np.ndarray]): List of images to plot. 
+ titles (List[str], optional): Titles. Defaults to None. + cmaps (str, optional): Colormaps. Defaults to "gray". + dpi (int, optional): Dots per inch. Defaults to 200. + pad (float, optional): Padding. Defaults to 0.5. + adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True. + + Returns: + plt.Figure: Figure of the images. + """ + n = len(imgs) + if not isinstance(cmaps, (list, tuple)): + cmaps = [cmaps] * n + + ratios = [i.shape[1] / i.shape[0] for i in imgs] if adaptive else [4 / 3] * n + figsize = [sum(ratios) * 4.5, 4.5] + fig, axs = plt.subplots(1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}) + if n == 1: + axs = [axs] + for i, (img, ax) in enumerate(zip(imgs, axs)): + ax.imshow(img, cmap=plt.get_cmap(cmaps[i])) + ax.set_axis_off() + if titles: + ax.set_title(titles[i]) + fig.tight_layout(pad=pad) + + return fig + + +def plot_image_grid( + imgs, + titles=None, + cmaps="gray", + dpi=100, + pad=0.5, + fig=None, + adaptive=True, + figs=3.0, + return_fig=False, + set_lim=False, +) -> plt.Figure: + """Plot a grid of images. + + Args: + imgs (List[np.ndarray]): List of images to plot. + titles (List[str], optional): Titles. Defaults to None. + cmaps (str, optional): Colormaps. Defaults to "gray". + dpi (int, optional): Dots per inch. Defaults to 100. + pad (float, optional): Padding. Defaults to 0.5. + fig (_type_, optional): Figure to plot on. Defaults to None. + adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True. + figs (float, optional): Figure size. Defaults to 3.0. + return_fig (bool, optional): Whether to return the figure. Defaults to False. + set_lim (bool, optional): Whether to set the limits. Defaults to False. + + Returns: + plt.Figure: Figure and axes or just axes. + """ + nr, n = len(imgs), len(imgs[0]) + if not isinstance(cmaps, (list, tuple)): + cmaps = [cmaps] * n + + if adaptive: + ratios = [i.shape[1] / i.shape[0] for i in imgs[0]] # W / H + else: + ratios = [4 / 3] * n + + figsize = [sum(ratios) * figs, nr * figs] + if fig is None: + fig, axs = plt.subplots( + nr, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios} + ) + else: + axs = fig.subplots(nr, n, gridspec_kw={"width_ratios": ratios}) + fig.figure.set_size_inches(figsize) + + if nr == 1 and n == 1: + axs = [[axs]] + elif n == 1: + axs = axs[:, None] + elif nr == 1: + axs = [axs] + + for j in range(nr): + for i in range(n): + ax = axs[j][i] + ax.imshow(imgs[j][i], cmap=plt.get_cmap(cmaps[i])) + ax.set_axis_off() + if set_lim: + ax.set_xlim([0, imgs[j][i].shape[1]]) + ax.set_ylim([imgs[j][i].shape[0], 0]) + if titles: + ax.set_title(titles[j][i]) + if isinstance(fig, plt.Figure): + fig.tight_layout(pad=pad) + return (fig, axs) if return_fig else axs + + +def add_text( + idx, + text, + pos=(0.01, 0.99), + fs=15, + color="w", + lcolor="k", + lwidth=4, + ha="left", + va="top", + axes=None, + **kwargs, +): + """Add text to a plot. + + Args: + idx (int): Index of the axes. + text (str): Text to add. + pos (tuple, optional): Text position. Defaults to (0.01, 0.99). + fs (int, optional): Font size. Defaults to 15. + color (str, optional): Text color. Defaults to "w". + lcolor (str, optional): Line color. Defaults to "k". + lwidth (int, optional): Line width. Defaults to 4. + ha (str, optional): Horizontal alignment. Defaults to "left". + va (str, optional): Vertical alignment. Defaults to "top". + axes (List[plt.Axes], optional): Axes to put text on. Defaults to None. + + Returns: + plt.Text: Text object. 
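+
+    Example (illustrative usage added for clarity; `img` stands for any image array):
+
+        fig = plot_images([img])
+        add_text(0, "GeoCalib", color="w")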
+ """ + if axes is None: + axes = plt.gcf().axes + + ax = axes[idx] + + t = ax.text( + *pos, + text, + fontsize=fs, + ha=ha, + va=va, + color=color, + transform=ax.transAxes, + zorder=5, + **kwargs, + ) + if lcolor is not None: + t.set_path_effects( + [ + path_effects.Stroke(linewidth=lwidth, foreground=lcolor), + path_effects.Normal(), + ] + ) + return t + + +def plot_heatmaps( + heatmaps, + vmin=-1e-6, # include negative zero + vmax=None, + cmap="Spectral", + a=0.5, + axes=None, + contours_every=None, + contour_style="solid", + colorbar=False, +): + """Plot heatmaps with optional contours. + + To plot latitude field, set vmin=-90, vmax=90 and contours_every=15. + + Args: + heatmaps (List[np.ndarray | torch.Tensor]): List of 2D heatmaps. + vmin (float, optional): Min Value. Defaults to -1e-6. + vmax (float, optional): Max Value. Defaults to None. + cmap (str, optional): Colormap. Defaults to "Spectral". + a (float, optional): Alpha value. Defaults to 0.5. + axes (List[plt.Axes], optional): Axes to plot on. Defaults to None. + contours_every (int, optional): If not none, will draw contours. Defaults to None. + contour_style (str, optional): Style of the contours. Defaults to "solid". + colorbar (bool, optional): Whether to show colorbar. Defaults to False. + + Returns: + List[plt.Artist]: List of artists. + """ + if axes is None: + axes = plt.gcf().axes + artists = [] + + for i in range(len(axes)): + a_ = a if isinstance(a, float) else a[i] + + if isinstance(heatmaps[i], torch.Tensor): + heatmaps[i] = heatmaps[i].cpu().numpy() + + alpha = a_ + # Plot the heatmap + art = axes[i].imshow( + heatmaps[i], + alpha=alpha, + vmin=vmin, + vmax=vmax, + cmap=cmap, + ) + if colorbar: + cmax = vmax or np.percentile(heatmaps[i], 99) + art.set_clim(vmin, cmax) + cbar = plt.colorbar(art, ax=axes[i]) + artists.append(cbar) + + artists.append(art) + + if contours_every is not None: + # Add contour lines to the heatmap + contour_data = np.arange(vmin, vmax + contours_every, contours_every) + + # Get the colormap colors for contour lines + contour_colors = [ + plt.colormaps.get_cmap(cmap)(plt.Normalize(vmin=vmin, vmax=vmax)(level)) + for level in contour_data + ] + contours = axes[i].contour( + heatmaps[i], + levels=contour_data, + linewidths=2, + colors=contour_colors, + linestyles=contour_style, + ) + + contours.set_clim(vmin, vmax) + + fmt = { + level: f"{label}°" + for level, label in zip(contour_data, contour_data.astype(int).astype(str)) + } + t = axes[i].clabel(contours, inline=True, fmt=fmt, fontsize=16, colors="white") + + for label in t: + label.set_path_effects( + [ + path_effects.Stroke(linewidth=1, foreground="k"), + path_effects.Normal(), + ] + ) + artists.append(contours) + + return artists + + +def plot_horizon_lines( + cameras, gravities, line_colors="orange", lw=2, styles="solid", alpha=1.0, ax=None +): + """Plot horizon lines on the perspective field. + + Args: + cameras (List[Camera]): List of cameras. + gravities (List[Gravity]): Gravities. + line_colors (str, optional): Line Colors. Defaults to "orange". + lw (int, optional): Line width. Defaults to 2. + styles (str, optional): Line styles. Defaults to "solid". + alpha (float, optional): Alphas. Defaults to 1.0. + ax (List[plt.Axes], optional): Axes to draw horizon line on. Defaults to None. 
+ """ + if not isinstance(line_colors, list): + line_colors = [line_colors] * len(cameras) + + if not isinstance(styles, list): + styles = [styles] * len(cameras) + + fig = plt.gcf() + ax = fig.gca() if ax is None else ax + + if isinstance(ax, plt.Axes): + ax = [ax] * len(cameras) + + assert len(ax) == len(cameras), f"{len(ax)}, {len(cameras)}" + + for i in range(len(cameras)): + _, lat = get_perspective_field(cameras[i], gravities[i]) + # horizon line is zero level of the latitude field + lat = lat[0, 0].cpu().numpy() + contours = ax[i].contour(lat, levels=[0], linewidths=lw, colors=line_colors[i]) + for contour_line in contours.collections: + contour_line.set_linestyle(styles[i]) + + +def plot_vector_fields( + vector_fields, + cmap="lime", + subsample=15, + scale=None, + lw=None, + alphas=0.8, + axes=None, +): + """Plot vector fields. + + Args: + vector_fields (List[torch.Tensor]): List of vector fields of shape (2, H, W). + cmap (str, optional): Color of the vectors. Defaults to "lime". + subsample (int, optional): Subsample the vector field. Defaults to 15. + scale (float, optional): Scale of the vectors. Defaults to None. + lw (float, optional): Line width of the vectors. Defaults to None. + alphas (float | np.ndarray, optional): Alpha per vector or global. Defaults to 0.8. + axes (List[plt.Axes], optional): List of axes to draw on. Defaults to None. + + Returns: + List[plt.Artist]: List of artists. + """ + if axes is None: + axes = plt.gcf().axes + + vector_fields = [v.cpu().numpy() if isinstance(v, torch.Tensor) else v for v in vector_fields] + + artists = [] + + H, W = vector_fields[0].shape[-2:] + if scale is None: + scale = subsample / min(H, W) + + if lw is None: + lw = 0.1 / subsample + + if alphas is None: + alphas = np.ones_like(vector_fields[0][0]) + alphas = np.stack([alphas] * len(vector_fields), 0) + elif isinstance(alphas, float): + alphas = np.ones_like(vector_fields[0][0]) * alphas + alphas = np.stack([alphas] * len(vector_fields), 0) + else: + alphas = np.array(alphas) + + subsample = min(W, H) // subsample + offset_x = ((W % subsample) + subsample) // 2 + + samples_x = np.arange(offset_x, W, subsample) + samples_y = np.arange(int(subsample * 0.9), H, subsample) + + x_grid, y_grid = np.meshgrid(samples_x, samples_y) + + for i in range(len(axes)): + # vector field of shape (2, H, W) with vectors of norm == 1 + vector_field = vector_fields[i] + + a = alphas[i][samples_y][:, samples_x] + x, y = vector_field[:, samples_y][:, :, samples_x] + + c = cmap + if not isinstance(cmap, str): + c = cmap[i][samples_y][:, samples_x].reshape(-1, 3) + + s = scale * min(H, W) + arrows = axes[i].quiver( + x_grid, + y_grid, + x, + y, + scale=s, + scale_units="width" if H > W else "height", + units="width" if H > W else "height", + alpha=a, + color=c, + angles="xy", + antialiased=True, + width=lw, + headaxislength=3.5, + zorder=5, + ) + + artists.append(arrows) + + return artists + + +def plot_latitudes( + latitude, + is_radians=True, + vmin=-90, + vmax=90, + cmap="seismic", + contours_every=15, + alpha=0.4, + axes=None, + **kwargs, +): + """Plot latitudes. + + Args: + latitude (List[torch.Tensor]): List of latitudes. + is_radians (bool, optional): Whether the latitudes are in radians. Defaults to True. + vmin (int, optional): Min value to clip to. Defaults to -90. + vmax (int, optional): Max value to clip to. Defaults to 90. + cmap (str, optional): Colormap. Defaults to "seismic". + contours_every (int, optional): Contours every. Defaults to 15. + alpha (float, optional): Alpha value. 
Defaults to 0.4. + axes (List[plt.Axes], optional): Axes to plot on. Defaults to None. + + Returns: + List[plt.Artist]: List of artists. + """ + if axes is None: + axes = plt.gcf().axes + + assert len(axes) == len(latitude), f"{len(axes)}, {len(latitude)}" + lat = [rad2deg(lat) for lat in latitude] if is_radians else latitude + return plot_heatmaps( + lat, + vmin=vmin, + vmax=vmax, + cmap=cmap, + a=alpha, + axes=axes, + contours_every=contours_every, + **kwargs, + ) + + +def plot_perspective_fields(cameras, gravities, axes=None, **kwargs): + """Plot perspective fields. + + Args: + cameras (List[Camera]): List of cameras. + gravities (List[Gravity]): List of gravities. + axes (List[plt.Axes], optional): Axes to plot on. Defaults to None. + + Returns: + List[plt.Artist]: List of artists. + """ + if axes is None: + axes = plt.gcf().axes + + assert len(axes) == len(cameras), f"{len(axes)}, {len(cameras)}" + + artists = [] + for i in range(len(axes)): + up, lat = get_perspective_field(cameras[i], gravities[i]) + artists += plot_vector_fields([up[0]], axes=[axes[i]], **kwargs) + artists += plot_latitudes([lat[0, 0]], axes=[axes[i]], **kwargs) + + return artists + + +def plot_confidences( + confidence, + as_log=True, + vmin=-4, + vmax=0, + cmap="turbo", + alpha=0.4, + axes=None, + **kwargs, +): + """Plot confidences. + + Args: + confidence (List[torch.Tensor]): Confidence maps. + as_log (bool, optional): Whether to plot in log scale. Defaults to True. + vmin (int, optional): Min value to clip to. Defaults to -4. + vmax (int, optional): Max value to clip to. Defaults to 0. + cmap (str, optional): Colormap. Defaults to "turbo". + alpha (float, optional): Alpha value. Defaults to 0.4. + axes (List[plt.Axes], optional): Axes to plot on. Defaults to None. + + Returns: + List[plt.Artist]: List of artists. + """ + if axes is None: + axes = plt.gcf().axes + + assert len(axes) == len(confidence), f"{len(axes)}, {len(confidence)}" + + if as_log: + confidence = [torch.log10(c.clip(1e-5)).clip(vmin, vmax) for c in confidence] + + # normalize to [0, 1] + confidence = [(c - c.min()) / (c.max() - c.min()) for c in confidence] + return plot_heatmaps(confidence, vmin=0, vmax=1, cmap=cmap, a=alpha, axes=axes, **kwargs) + + +def save_plot(path, **kw): + """Save the current figure without any white margin.""" + plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw) diff --git a/gradio_app.py b/gradio_app.py new file mode 100644 index 0000000000000000000000000000000000000000..11ba5653adf5bccc464e8334634b008e0ede1074 --- /dev/null +++ b/gradio_app.py @@ -0,0 +1,228 @@ +"""Gradio app for GeoCalib inference.""" + +from copy import deepcopy +from time import time + +import gradio as gr +import numpy as np +import spaces +import torch + +from geocalib import viz2d +from geocalib.camera import camera_models +from geocalib.extractor import GeoCalib +from geocalib.perspective_fields import get_perspective_field +from geocalib.utils import rad2deg + +# flake8: noqa +# mypy: ignore-errors + +description = """ +

+<h1 align="center">GeoCalib 📸<br>Single-image Calibration with Geometric Optimization</h1>
+<p align="center">
+  Alexander Veicht · Paul-Edouard Sarlin · Philipp Lindenberger · Marc Pollefeys
+</p>
+<h3 align="center">ECCV 2024</h3>
+<p align="center">
+  Paper | Code | Colab
+</p>
+ +## Getting Started +GeoCalib accurately estimates the camera intrinsics and gravity direction from a single image by +combining geometric optimization with deep learning. + +To get started, upload an image or select one of the examples below. +You can choose between different camera models and visualize the calibration results. + +""" + +example_images = [ + ["assets/pinhole-church.jpg"], + ["assets/pinhole-garden.jpg"], + ["assets/fisheye-skyline.jpg"], + ["assets/fisheye-dog-pool.jpg"], +] + +device = "cuda" if torch.cuda.is_available() else "cpu" +model = GeoCalib().to(device) + + +def format_output(results): + camera, gravity = results["camera"], results["gravity"] + vfov = rad2deg(camera.vfov) + roll, pitch = rad2deg(gravity.rp).unbind(-1) + + txt = "Estimated parameters:\n" + txt += f"Roll: {roll.item():.2f}° (± {rad2deg(results['roll_uncertainty']).item():.2f})°\n" + txt += f"Pitch: {pitch.item():.2f}° (± {rad2deg(results['pitch_uncertainty']).item():.2f})°\n" + txt += f"vFoV: {vfov.item():.2f}° (± {rad2deg(results['vfov_uncertainty']).item():.2f})°\n" + txt += ( + f"Focal: {camera.f[0, 1].item():.2f} px (± {results['focal_uncertainty'].item():.2f} px)\n" + ) + if hasattr(camera, "k1"): + txt += f"K1: {camera.k1[0].item():.2f}\n" + return txt + + +@spaces.GPU(duration=10) +def inference(img, camera_model): + out = model.calibrate(img.to(device), camera_model=camera_model) + save_keys = ["camera", "gravity"] + [f"{k}_uncertainty" for k in ["roll", "pitch", "vfov"]] + res = {k: v.cpu() for k, v in out.items() if k in save_keys} + # not converting to numpy results in gpu abort + res["up_confidence"] = out["up_confidence"].cpu().numpy() + res["latitude_confidence"] = out["latitude_confidence"].cpu().numpy() + return res + + +def process_results( + image_path, + camera_model, + plot_up, + plot_up_confidence, + plot_latitude, + plot_latitude_confidence, + plot_undistort, +): + """Process the image and return the calibration results.""" + + if image_path is None: + raise gr.Error("Please upload an image first.") + + img = model.load_image(image_path) + print("Running inference...") + start = time() + inference_result = inference(img, camera_model) + print(f"Done ({time() - start:.2f}s)") + inference_result["image"] = img.cpu() + + if inference_result is None: + return ("", np.ones((128, 256, 3)), None) + + plot_img = update_plot( + inference_result, + plot_up, + plot_up_confidence, + plot_latitude, + plot_latitude_confidence, + plot_undistort, + ) + + return format_output(inference_result), plot_img, inference_result + + +def update_plot( + inference_result, + plot_up, + plot_up_confidence, + plot_latitude, + plot_latitude_confidence, + plot_undistort, +): + """Update the plot based on the selected options.""" + if inference_result is None: + gr.Error("Please calibrate an image first.") + return np.ones((128, 256, 3)) + + camera, gravity = inference_result["camera"], inference_result["gravity"] + img = inference_result["image"].permute(1, 2, 0).numpy() + + if plot_undistort: + if not hasattr(camera, "k1"): + return img + + return camera.undistort_image(inference_result["image"][None])[0].permute(1, 2, 0).numpy() + + up, lat = get_perspective_field(camera, gravity) + + fig = viz2d.plot_images([img], pad=0) + ax = fig.get_axes() + + if plot_up: + viz2d.plot_vector_fields([up[0]], axes=[ax[0]]) + + if plot_latitude: + viz2d.plot_latitudes([lat[0, 0]], axes=[ax[0]]) + + if plot_up_confidence: + viz2d.plot_confidences([inference_result["up_confidence"][0]], axes=[ax[0]]) + + if 
plot_latitude_confidence: + viz2d.plot_confidences([inference_result["latitude_confidence"][0]], axes=[ax[0]]) + + fig.canvas.draw() + img = np.array(fig.canvas.renderer.buffer_rgba()) + + return img + + +# Create the Gradio interface +with gr.Blocks() as demo: + gr.Markdown(description) + with gr.Row(): + with gr.Column(): + gr.Markdown("""## Input Image""") + image_path = gr.Image(label="Upload image to calibrate", type="filepath") + choice_input = gr.Dropdown( + choices=list(camera_models.keys()), label="Choose a camera model.", value="pinhole" + ) + submit_btn = gr.Button("Calibrate 📸") + gr.Examples(examples=example_images, inputs=[image_path, choice_input]) + + with gr.Column(): + gr.Markdown("""## Results""") + image_output = gr.Image(label="Calibration Results") + gr.Markdown("### Plot Options") + plot_undistort = gr.Checkbox( + label="undistort", + value=False, + info="Undistorted image " + + "(this is only available for models with distortion " + + "parameters and will overwrite other options).", + ) + + with gr.Row(): + plot_up = gr.Checkbox(label="up-vectors", value=True) + plot_up_confidence = gr.Checkbox(label="up confidence", value=False) + plot_latitude = gr.Checkbox(label="latitude", value=True) + plot_latitude_confidence = gr.Checkbox(label="latitude confidence", value=False) + + gr.Markdown("### Calibration Results") + text_output = gr.Textbox(label="Estimated parameters", type="text", lines=5) + + # Define the action when the button is clicked + inference_state = gr.State() + plot_inputs = [ + inference_state, + plot_up, + plot_up_confidence, + plot_latitude, + plot_latitude_confidence, + plot_undistort, + ] + submit_btn.click( + fn=process_results, + inputs=[image_path, choice_input] + plot_inputs[1:], + outputs=[text_output, image_output, inference_state], + ) + + # Define the action when the plot checkboxes are clicked + plot_up.change(fn=update_plot, inputs=plot_inputs, outputs=image_output) + plot_up_confidence.change(fn=update_plot, inputs=plot_inputs, outputs=image_output) + plot_latitude.change(fn=update_plot, inputs=plot_inputs, outputs=image_output) + plot_latitude_confidence.change(fn=update_plot, inputs=plot_inputs, outputs=image_output) + plot_undistort.change(fn=update_plot, inputs=plot_inputs, outputs=image_output) + + +# Launch the app +demo.launch() diff --git a/hubconf.py b/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..b1ad63268c53133b227bdf5cfec046bd81f7ecea --- /dev/null +++ b/hubconf.py @@ -0,0 +1,14 @@ +"""Entrypoint for torch hub.""" + +dependencies = ["torch", "torchvision", "opencv-python", "kornia", "matplotlib"] + +from geocalib import GeoCalib + + +def model(*args, **kwargs): + """Pre-trained Geocalib model. + + Args: + weights (str): trained variant, "pinhole" (default) or "distorted". 
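+
+    Example (illustrative sketch; assumes the entrypoint is served from the
+    cvg/GeoCalib repository and that a local image path is supplied):
+        >>> import torch
+        >>> calib = torch.hub.load("cvg/GeoCalib", "model", weights="pinhole")
+        >>> image = calib.load_image("path/to/image.jpg")
+        >>> results = calib.calibrate(image)  # estimated camera and gravity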
+ """ + return GeoCalib(*args, **kwargs) diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..10b2c1c5800f42c9067b194bc9a1228d832ac4db --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,49 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "geocalib" +version = "1.0" +description = "GeoCalib Inference Package" +authors = [ + { name = "Alexander Veicht" }, + { name = "Paul-Edouard Sarlin" }, + { name = "Philipp Lindenberger" }, +] +readme = "README.md" +requires-python = ">=3.9" +license = { file = "LICENSE" } +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", +] +urls = { Repository = "https://github.com/cvg/GeoCalib" } + +dynamic = ["dependencies"] + +[project.optional-dependencies] +dev = ["black==23.9.1", "flake8", "isort==5.12.0"] + +[tool.setuptools] +packages = ["geocalib"] + +[tool.setuptools.dynamic] +dependencies = { file = ["requirements.txt"] } + + +[tool.black] +line-length = 100 +exclude = "(venv/|docs/|third_party/)" + +[tool.isort] +profile = "black" +line_length = 100 +atomic = true + +[tool.flake8] +max-line-length = 100 +docstring-convention = "google" +ignore = ["E203", "W503", "E402"] +exclude = [".git", "__pycache__", "venv", "docs", "third_party", "scripts"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..ea45682e0ba9634fb03cbca9df33eae28a2c6842 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +torch +torchvision +opencv-python +kornia +matplotlib diff --git a/siclib/LICENSE b/siclib/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..797795d81ba8a5d06fefa772bc5b4d0b4bb94dc4 --- /dev/null +++ b/siclib/LICENSE @@ -0,0 +1,190 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. 
+ + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + Copyright 2024 ETH Zurich + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/siclib/__init__.py b/siclib/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0944cfcf82a485c31ba77867ee3821ae59557ecb --- /dev/null +++ b/siclib/__init__.py @@ -0,0 +1,15 @@ +import logging + +formatter = logging.Formatter( + fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s", datefmt="%m/%d/%Y %H:%M:%S" +) +handler = logging.StreamHandler() +handler.setFormatter(formatter) +handler.setLevel(logging.INFO) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) +logger.addHandler(handler) +logger.propagate = False + +__module_name__ = __name__ diff --git a/siclib/configs/deepcalib.yaml b/siclib/configs/deepcalib.yaml new file mode 100644 index 0000000000000000000000000000000000000000..994f19c929d55dffb2cac9d9f972b480457a607a --- /dev/null +++ b/siclib/configs/deepcalib.yaml @@ -0,0 +1,12 @@ +defaults: + - data: openpano-radial + - train: deepcalib + - model: deepcalib + - _self_ + +data: + train_batch_size: 32 + val_batch_size: 32 + test_batch_size: 32 + augmentations: + name: "deepcalib" diff --git a/siclib/configs/geocalib-radial.yaml b/siclib/configs/geocalib-radial.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4b492f9aac6f19cd1bb59c2cf38eeeaa007d1156 --- /dev/null +++ b/siclib/configs/geocalib-radial.yaml @@ -0,0 +1,38 @@ +defaults: + - data: openpano-radial + - train: geocalib + - model: geocalib + - _self_ + +data: + # smaller batch size since lm takes more memory + train_batch_size: 18 + val_batch_size: 18 + test_batch_size: 18 + +model: + optimizer: + camera_model: simple_radial + + weights: weights/geocalib.tar + +train: + lr: 1e-5 # smaller lr since we are fine-tuning + num_steps: 200_000 # adapt to see same number of samples as previous training + + lr_schedule: + type: SequentialLR + on_epoch: false + options: + # adapt to see same number of samples as previous training + milestones: [5_000] + schedulers: + - type: LinearLR + options: + start_factor: 1e-3 + total_iters: 5_000 + - type: MultiStepLR + options: + gamma: 0.1 + # adapt to see same number of samples as previous training + milestones: [110_000, 170_000] diff --git a/siclib/configs/geocalib.yaml b/siclib/configs/geocalib.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f965a269673bbe907ca60387e7458c7df3c5b4fc --- /dev/null +++ b/siclib/configs/geocalib.yaml @@ -0,0 +1,10 @@ +defaults: + - data: openpano + - train: geocalib + - model: geocalib + - _self_ + +data: + train_batch_size: 24 + val_batch_size: 24 + test_batch_size: 24 diff --git a/siclib/configs/model/deepcalib.yaml b/siclib/configs/model/deepcalib.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2d4d0670cfb5d1742029ba3dc32b81eb56b1b57e --- /dev/null +++ b/siclib/configs/model/deepcalib.yaml @@ -0,0 +1,7 @@ +name: networks.deepcalib +bounds: + roll: [-45, 45] + # rho = torch.tan(pitch) / torch.tan(vfov / 2) / 2 -> rho in [-1/0.3526, 1/0.0872] + rho: [-2.83607487, 2.83607487] + vfov: [20, 105] + k1_hat: [-0.7, 0.7] diff --git a/siclib/configs/model/geocalib.yaml b/siclib/configs/model/geocalib.yaml new file mode 100644 index 0000000000000000000000000000000000000000..42c06f44b522c47ce3127a4a4bd4ac85d5686768 --- /dev/null +++ b/siclib/configs/model/geocalib.yaml @@ -0,0 +1,31 @@ +name: networks.geocalib + +ll_enc: + name: encoders.low_level_encoder + +backbone: + name: encoders.mscan + weights: weights/mscan_b.pth + +perspective_decoder: + name: decoders.perspective_decoder + + up_decoder: + name: decoders.up_decoder + 
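+    # Note: the up and latitude heads below each predict a per-pixel uncertainty
+    # (predict_uncertainty: true); these confidences are presumably what the
+    # lm_optimizer configured at the bottom of this file consumes downstream.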
loss_type: l1 + use_uncertainty_loss: true + decoder: + name: decoders.light_hamburger + predict_uncertainty: true + + latitude_decoder: + name: decoders.latitude_decoder + loss_type: l1 + use_uncertainty_loss: true + decoder: + name: decoders.light_hamburger + predict_uncertainty: true + +optimizer: + name: optimization.lm_optimizer + camera_model: pinhole diff --git a/siclib/configs/train/deepcalib.yaml b/siclib/configs/train/deepcalib.yaml new file mode 100644 index 0000000000000000000000000000000000000000..f4bd79685287ab0d926c11a709e98a466705b86c --- /dev/null +++ b/siclib/configs/train/deepcalib.yaml @@ -0,0 +1,22 @@ +seed: 0 +num_steps: 20_000 +log_every_iter: 500 +eval_every_iter: 3000 +test_every_epoch: 1 +writer: null +lr: 1.0e-4 +clip_grad: 1.0 +lr_schedule: + type: null +optimizer: adam +submodules: [] +median_metrics: + - roll_error + - pitch_error + - vfov_error +recall_metrics: + roll_error: [1, 5, 10] + pitch_error: [1, 5, 10] + vfov_error: [1, 5, 10] + +plot: [3, "siclib.visualization.visualize_batch.make_perspective_figures"] diff --git a/siclib/configs/train/geocalib.yaml b/siclib/configs/train/geocalib.yaml new file mode 100644 index 0000000000000000000000000000000000000000..966f227f685246f7c932f40046c423b0fc6c7bed --- /dev/null +++ b/siclib/configs/train/geocalib.yaml @@ -0,0 +1,50 @@ +seed: 0 +num_steps: 150_000 + +writer: null +log_every_iter: 500 +eval_every_iter: 1000 + +lr: 1e-4 +optimizer: adamw +clip_grad: 1.0 +best_key: loss/param_total + +lr_schedule: + type: SequentialLR + on_epoch: false + options: + milestones: [4_000] + schedulers: + - type: LinearLR + options: + start_factor: 1e-3 + total_iters: 4_000 + - type: MultiStepLR + options: + gamma: 0.1 + milestones: [80_000, 130_000] + +submodules: [] + +median_metrics: + - roll_error + - pitch_error + - gravity_error + - vfov_error + - up_angle_error + - latitude_angle_error + - up_angle_recall@1 + - up_angle_recall@5 + - up_angle_recall@10 + - latitude_angle_recall@1 + - latitude_angle_recall@5 + - latitude_angle_recall@10 + +recall_metrics: + roll_error: [1, 3, 5, 10] + pitch_error: [1, 3, 5, 10] + gravity_error: [1, 3, 5, 10] + vfov_error: [1, 3, 5, 10] + +plot: [3, "siclib.visualization.visualize_batch.make_perspective_figures"] diff --git a/siclib/datasets/__init__.py b/siclib/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1f660a54e7a1174e83d999cd90c25ce85ac7294b --- /dev/null +++ b/siclib/datasets/__init__.py @@ -0,0 +1,25 @@ +import importlib.util + +from siclib.datasets.base_dataset import BaseDataset +from siclib.utils.tools import get_class + + +def get_dataset(name): + import_paths = [name, f"{__name__}.{name}"] + for path in import_paths: + try: + spec = importlib.util.find_spec(path) + except ModuleNotFoundError: + spec = None + if spec is not None: + try: + return get_class(path, BaseDataset) + except AssertionError: + mod = __import__(path, fromlist=[""]) + try: + return mod.__main_dataset__ + except AttributeError as exc: + print(exc) + continue + + raise RuntimeError(f'Dataset {name} not found in any of [{" ".join(import_paths)}]') diff --git a/siclib/datasets/augmentations.py b/siclib/datasets/augmentations.py new file mode 100644 index 0000000000000000000000000000000000000000..6f203f301d5760763a397e9666c0af80ef39a776 --- /dev/null +++ b/siclib/datasets/augmentations.py @@ -0,0 +1,359 @@ +from typing import Union + +import albumentations as A +import cv2 +import numpy as np +import torch +from albumentations.pytorch.transforms import ToTensorV2 +from 
omegaconf import OmegaConf + + +# flake8: noqa +# mypy: ignore-errors +class IdentityTransform(A.ImageOnlyTransform): + def apply(self, img, **params): + return img + + def get_transform_init_args_names(self): + return () + + +class RandomAdditiveShade(A.ImageOnlyTransform): + def __init__( + self, + nb_ellipses=10, + transparency_limit=[-0.5, 0.8], + kernel_size_limit=[150, 350], + always_apply=False, + p=0.5, + ): + super().__init__(always_apply, p) + self.nb_ellipses = nb_ellipses + self.transparency_limit = transparency_limit + self.kernel_size_limit = kernel_size_limit + + def apply(self, img, **params): + if img.dtype == np.float32: + shaded = self._py_additive_shade(img * 255.0) + shaded /= 255.0 + elif img.dtype == np.uint8: + shaded = self._py_additive_shade(img.astype(np.float32)) + shaded = shaded.astype(np.uint8) + else: + raise NotImplementedError(f"Data augmentation not available for type: {img.dtype}") + return shaded + + def _py_additive_shade(self, img): + grayscale = len(img.shape) == 2 + if grayscale: + img = img[None] + min_dim = min(img.shape[:2]) / 4 + mask = np.zeros(img.shape[:2], img.dtype) + for i in range(self.nb_ellipses): + ax = int(max(np.random.rand() * min_dim, min_dim / 5)) + ay = int(max(np.random.rand() * min_dim, min_dim / 5)) + max_rad = max(ax, ay) + x = np.random.randint(max_rad, img.shape[1] - max_rad) # center + y = np.random.randint(max_rad, img.shape[0] - max_rad) + angle = np.random.rand() * 90 + cv2.ellipse(mask, (x, y), (ax, ay), angle, 0, 360, 255, -1) + + transparency = np.random.uniform(*self.transparency_limit) + ks = np.random.randint(*self.kernel_size_limit) + if (ks % 2) == 0: # kernel_size has to be odd + ks += 1 + mask = cv2.GaussianBlur(mask.astype(np.float32), (ks, ks), 0) + shaded = img * (1 - transparency * mask[..., np.newaxis] / 255.0) + out = np.clip(shaded, 0, 255) + if grayscale: + out = out.squeeze(0) + return out + + def get_transform_init_args_names(self): + return "transparency_limit", "kernel_size_limit", "nb_ellipses" + + +def kw(entry: Union[float, dict], n=None, **default): + if not isinstance(entry, dict): + entry = {"p": entry} + entry = OmegaConf.create(entry) + if n is not None: + entry = default.get(n, entry) + return OmegaConf.merge(default, entry) + + +def kwi(entry: Union[float, dict], n=None, **default): + conf = kw(entry, n=n, **default) + return {k: conf[k] for k in set(default.keys()).union(set(["p"]))} + + +def replay_str(transforms, s="Replay:\n", log_inactive=True): + for t in transforms: + if "transforms" in t.keys(): + s = replay_str(t["transforms"], s=s) + elif t["applied"] or log_inactive: + s += t["__class_fullname__"] + " " + str(t["applied"]) + "\n" + return s + + +class BaseAugmentation(object): + base_default_conf = { + "name": "???", + "shuffle": False, + "p": 1.0, + "verbose": False, + "dtype": "uint8", # (byte, float) + } + + default_conf = {} + + def __init__(self, conf={}): + """Perform some logic and call the _init method of the child model.""" + default_conf = OmegaConf.merge( + OmegaConf.create(self.base_default_conf), + OmegaConf.create(self.default_conf), + ) + OmegaConf.set_struct(default_conf, True) + if isinstance(conf, dict): + conf = OmegaConf.create(conf) + self.conf = OmegaConf.merge(default_conf, conf) + OmegaConf.set_readonly(self.conf, True) + self._init(self.conf) + + self.conf = OmegaConf.merge(self.conf, conf) + if self.conf.verbose: + self.compose = A.ReplayCompose + else: + self.compose = A.Compose + if self.conf.dtype == "uint8": + self.dtype = np.uint8 + 
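+            # The pair below converts float inputs (values in [0, 1]) to uint8 before
+            # the albumentations pipeline runs, and back to float afterwards; see the
+            # preprocess/postprocess calls in __call__.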
self.preprocess = A.FromFloat(always_apply=True, dtype="uint8") + self.postprocess = A.ToFloat(always_apply=True) + elif self.conf.dtype == "float32": + self.dtype = np.float32 + self.preprocess = A.ToFloat(always_apply=True) + self.postprocess = IdentityTransform() + else: + raise ValueError(f"Unsupported dtype {self.conf.dtype}") + self.to_tensor = ToTensorV2() + + def _init(self, conf): + """Child class overwrites this, setting up a list of transforms""" + self.transforms = [] + + def __call__(self, image, return_tensor=False): + """image as HW or HWC""" + if isinstance(image, torch.Tensor): + image = image.cpu().numpy() + data = {"image": image} + if image.dtype != self.dtype: + data = self.preprocess(**data) + transforms = self.transforms + if self.conf.shuffle: + order = [i for i, _ in enumerate(transforms)] + np.random.shuffle(order) + transforms = [transforms[i] for i in order] + transformed = self.compose(transforms, p=self.conf.p)(**data) + if self.conf.verbose: + print(replay_str(transformed["replay"]["transforms"])) + transformed = self.postprocess(**transformed) + if return_tensor: + return self.to_tensor(**transformed)["image"] + else: + return transformed["image"] + + +class IdentityAugmentation(BaseAugmentation): + default_conf = {} + + def _init(self, conf): + self.transforms = [IdentityTransform(p=1.0)] + + +class DarkAugmentation(BaseAugmentation): + default_conf = {"p": 0.75} + + def _init(self, conf): + bright_contr = 0.5 + blur = 0.1 + random_gamma = 0.1 + hue = 0.1 + self.transforms = [ + A.RandomRain(p=0.2), + A.RandomBrightnessContrast( + **kw( + bright_contr, + brightness_limit=(-0.4, 0.0), + contrast_limit=(-0.3, 0.0), + ) + ), + A.OneOf( + [ + A.Blur(**kwi(blur, p=0.1, blur_limit=(3, 9), n="blur")), + A.MotionBlur(**kwi(blur, p=0.2, blur_limit=(3, 25), n="motion_blur")), + A.ISONoise(), + A.ImageCompression(), + ], + **kwi(blur, p=0.1), + ), + A.RandomGamma(**kw(random_gamma, gamma_limit=(15, 65))), + A.OneOf( + [ + A.Equalize(), + A.CLAHE(p=0.2), + A.ToGray(), + A.ToSepia(p=0.1), + A.HueSaturationValue(**kw(hue, val_shift_limit=(-100, -40))), + ], + p=0.5, + ), + ] + + +class DefaultAugmentation(BaseAugmentation): + default_conf = {"p": 1.0} + + def _init(self, conf): + self.transforms = [ + A.RandomBrightnessContrast(p=0.2), + A.HueSaturationValue(p=0.2), + A.ToGray(p=0.2), + A.ImageCompression(quality_lower=30, quality_upper=100, p=0.5), + A.OneOf( + [ + A.MotionBlur(p=0.2), + A.MedianBlur(blur_limit=3, p=0.1), + A.Blur(blur_limit=3, p=0.1), + ], + p=0.2, + ), + ] + + +class PerspectiveAugmentation(BaseAugmentation): + default_conf = {"p": 1.0} + + def _init(self, conf): + self.transforms = [ + A.RandomBrightnessContrast(p=0.2), + A.HueSaturationValue(p=0.2), + A.ToGray(p=0.2), + A.ImageCompression(quality_lower=30, quality_upper=100, p=0.5), + A.OneOf( + [ + A.MotionBlur(p=0.2), + A.MedianBlur(blur_limit=3, p=0.1), + A.Blur(blur_limit=3, p=0.1), + ], + p=0.2, + ), + ] + + +class DeepCalibAugmentations(BaseAugmentation): + default_conf = {"p": 1.0} + + def _init(self, conf): + self.transforms = [ + A.RandomBrightnessContrast(p=0.5), + A.GaussNoise(var_limit=(5.0, 112.0), mean=0, per_channel=True, p=0.75), + A.Downscale( + scale_min=0.5, + scale_max=0.95, + interpolation=dict(downscale=cv2.INTER_AREA, upscale=cv2.INTER_LINEAR), + p=0.5, + ), + A.Downscale(scale_min=0.5, scale_max=0.95, interpolation=cv2.INTER_LINEAR, p=0.5), + A.ImageCompression(quality_lower=20, quality_upper=85, p=1, always_apply=True), + A.ColorJitter(brightness=0.2, contrast=0.2, 
saturation=0.2, hue=0.2, p=0.4), + A.Sharpen(alpha=(0.2, 0.5), lightness=(0.5, 1.0), p=0.5), + A.ToGray(always_apply=False, p=0.2), + A.GaussianBlur(blur_limit=(3, 5), sigma_limit=0, p=0.25), + A.MotionBlur(blur_limit=5, allow_shifted=True, p=0.25), + A.MultiplicativeNoise(multiplier=[0.85, 1.15], elementwise=True, p=0.5), + ] + + +class GeoCalibAugmentations(BaseAugmentation): + default_conf = {"p": 1.0} + + def _init(self, conf): + self.color_transforms = [ + A.RandomGamma(gamma_limit=(80, 180), p=0.8), + A.RandomToneCurve(scale=0.1, p=0.5), + A.RandomBrightnessContrast(p=0.5), + A.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.4), + A.OneOf([A.ToGray(p=0.1), A.ToSepia(p=0.1), IdentityTransform(p=0.8)], p=1), + ] + + self.noise_transforms = [ + A.GaussNoise(var_limit=(5.0, 112.0), mean=0, per_channel=True, p=0.75), + A.ImageCompression(quality_lower=20, quality_upper=100, p=1), + A.ISONoise(color_shift=(0.01, 0.05), intensity=(0.1, 0.5), p=0.5), + A.OneOrOther( + first=A.Compose( + [ + A.AdvancedBlur( + p=1, + blur_limit=(3, 7), + sigmaX_limit=(0.2, 1.0), + sigmaY_limit=(0.2, 1.0), + rotate_limit=(-90, 90), + beta_limit=(0.5, 8.0), + noise_limit=(0.9, 1.1), + ), + A.Sharpen(p=0.5, alpha=(0.2, 0.5), lightness=(0.5, 1.0)), + ] + ), + second=A.Compose( + [ + A.Sharpen(p=0.5, alpha=(0.2, 0.5), lightness=(0.5, 1.0)), + A.AdvancedBlur( + p=1, + blur_limit=(3, 7), + sigmaX_limit=(0.2, 1.0), + sigmaY_limit=(0.2, 1.0), + rotate_limit=(-90, 90), + beta_limit=(0.5, 8.0), + noise_limit=(0.9, 1.1), + ), + ] + ), + ), + ] + + self.image_transforms = [ + A.OneOf( + [ + A.Downscale( + scale_min=0.5, + scale_max=0.99, + interpolation=dict(downscale=down, upscale=up), + p=1, + ) + for down, up in [ + (cv2.INTER_AREA, cv2.INTER_LINEAR), + (cv2.INTER_LINEAR, cv2.INTER_CUBIC), + (cv2.INTER_CUBIC, cv2.INTER_LINEAR), + (cv2.INTER_LINEAR, cv2.INTER_AREA), + ] + ], + p=1, + ) + ] + + self.transforms = [ + *self.color_transforms, + *self.noise_transforms, + *self.image_transforms, + ] + + +augmentations = { + "default": DefaultAugmentation, + "dark": DarkAugmentation, + "perspective": PerspectiveAugmentation, + "deepcalib": DeepCalibAugmentations, + "geocalib": GeoCalibAugmentations, + "identity": IdentityAugmentation, +} diff --git a/siclib/datasets/base_dataset.py b/siclib/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7491b20f4f77338e5800b00f2624c84bb1ca35fb --- /dev/null +++ b/siclib/datasets/base_dataset.py @@ -0,0 +1,218 @@ +"""Base class for dataset. + +See mnist.py for an example of dataset. +""" + +import collections +import logging +from abc import ABCMeta, abstractmethod + +import omegaconf +import torch +from omegaconf import OmegaConf +from torch.utils.data import DataLoader, Sampler, get_worker_info +from torch.utils.data._utils.collate import default_collate_err_msg_format, np_str_obj_array_pattern + +from siclib.utils.tensor import string_classes +from siclib.utils.tools import set_num_threads, set_seed + +logger = logging.getLogger(__name__) + +# mypy: ignore-errors + + +class LoopSampler(Sampler): + """Infinite sampler that loops over a given number of elements.""" + + def __init__(self, loop_size: int, total_size: int = None): + """Initialize the sampler. + + Args: + loop_size (int): Number of elements to loop over. + total_size (int, optional): Total number of elements. Defaults to None. 
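+
+        Example (illustrative): LoopSampler(loop_size=2, total_size=5) truncates
+            total_size to 4 and yields the indices 0, 1, 0, 1.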
+ """ + self.loop_size = loop_size + self.total_size = total_size - (total_size % loop_size) + + def __iter__(self): + """Return an iterator over the elements.""" + return (i % self.loop_size for i in range(self.total_size)) + + def __len__(self): + """Return the number of elements.""" + return self.total_size + + +def worker_init_fn(i): + """Initialize the workers with a different seed.""" + info = get_worker_info() + if hasattr(info.dataset, "conf"): + conf = info.dataset.conf + set_seed(info.id + conf.seed) + set_num_threads(conf.num_threads) + else: + set_num_threads(1) + + +def collate(batch): + """Difference with PyTorch default_collate: it can stack of other objects.""" + if not isinstance(batch, list): # no batching + return batch + elem = batch[0] + elem_type = type(elem) + if isinstance(elem, torch.Tensor): + # out = None + if torch.utils.data.get_worker_info() is not None: + # If we're in a background process, concatenate directly into a + # shared memory tensor to avoid an extra copy + numel = sum([x.numel() for x in batch]) + try: + _ = elem.untyped_storage()._new_shared(numel) + except AttributeError: + _ = elem.storage()._new_shared(numel) + return torch.stack(batch, dim=0) + elif ( + elem_type.__module__ == "numpy" + and elem_type.__name__ != "str_" + and elem_type.__name__ != "string_" + ): + if elem_type.__name__ in ["ndarray", "memmap"]: + # array of string classes and object + if np_str_obj_array_pattern.search(elem.dtype.str) is not None: + raise TypeError(default_collate_err_msg_format.format(elem.dtype)) + return collate([torch.as_tensor(b) for b in batch]) + elif elem.shape == (): # scalars + return torch.as_tensor(batch) + elif isinstance(elem, float): + return torch.tensor(batch, dtype=torch.float64) + elif isinstance(elem, int): + return torch.tensor(batch) + elif isinstance(elem, string_classes): + return batch + elif isinstance(elem, collections.abc.Mapping): + return {key: collate([d[key] for d in batch]) for key in elem} + elif isinstance(elem, tuple) and hasattr(elem, "_fields"): # namedtuple + return elem_type(*(collate(samples) for samples in zip(*batch))) + elif isinstance(elem, collections.abc.Sequence): + # check to make sure that the elements in batch have consistent size + it = iter(batch) + elem_size = len(next(it)) + if any(len(elem) != elem_size for elem in it): + raise RuntimeError("each element in list of batch should be of equal size") + transposed = zip(*batch) + return [collate(samples) for samples in transposed] + elif elem is None: + return elem + else: + # try to stack anyway in case the object implements stacking. + return torch.stack(batch, 0) + + +class BaseDataset(metaclass=ABCMeta): + """Base class for dataset. + + What the dataset model is expect to declare: + default_conf: dictionary of the default configuration of the dataset. + It overwrites base_default_conf in BaseModel, and it is overwritten by + the user-provided configuration passed to __init__. + Configurations can be nested. + + _init(self, conf): initialization method, where conf is the final + configuration object (also accessible with `self.conf`). Accessing + unknown configuration entries will raise an error. + + get_dataset(self, split): method that returns an instance of + torch.utils.data.Dataset corresponding to the requested split string, + which can be `'train'`, `'val'`, or `'test'`. 
+ """ + + base_default_conf = { + "name": "???", + "num_workers": "???", + "train_batch_size": "???", + "val_batch_size": "???", + "test_batch_size": "???", + "shuffle_training": True, + "batch_size": 1, + "num_threads": 1, + "seed": 0, + "prefetch_factor": 2, + } + default_conf = {} + + def __init__(self, conf): + """Perform some logic and call the _init method of the child model.""" + default_conf = OmegaConf.merge( + OmegaConf.create(self.base_default_conf), + OmegaConf.create(self.default_conf), + ) + OmegaConf.set_struct(default_conf, True) + if isinstance(conf, dict): + conf = OmegaConf.create(conf) + self.conf = OmegaConf.merge(default_conf, conf) + OmegaConf.set_readonly(self.conf, True) + logger.info(f"Creating dataset {self.__class__.__name__}") + self._init(self.conf) + + @abstractmethod + def _init(self, conf): + """To be implemented by the child class.""" + raise NotImplementedError + + @abstractmethod + def get_dataset(self, split): + """To be implemented by the child class.""" + raise NotImplementedError + + def get_data_loader(self, split, shuffle=None, pinned=False, distributed=False): + """Return a data loader for a given split.""" + assert split in ["train", "val", "test"] + dataset = self.get_dataset(split) + try: + batch_size = self.conf[f"{split}_batch_size"] + except omegaconf.MissingMandatoryValue: + batch_size = self.conf.batch_size + num_workers = self.conf.get("num_workers", batch_size) + if distributed: + shuffle = False + sampler = torch.utils.data.distributed.DistributedSampler(dataset) + else: + sampler = None + if shuffle is None: + shuffle = split == "train" and self.conf.shuffle_training + return DataLoader( + dataset, + batch_size=batch_size, + shuffle=shuffle, + sampler=sampler, + pin_memory=pinned, + collate_fn=collate, + num_workers=num_workers, + worker_init_fn=worker_init_fn, + prefetch_factor=self.conf.prefetch_factor, + ) + + def get_overfit_loader(self, split: str): + """Return an overfit data loader. + + The training set is composed of a single duplicated batch, while + the validation and test sets contain a single copy of this same batch. + This is useful to debug a model and make sure that losses and metrics + correlate well. 
+ """ + assert split in {"train", "val", "test"} + dataset = self.get_dataset("train") + sampler = LoopSampler( + self.conf.batch_size, + len(dataset) if split == "train" else self.conf.batch_size, + ) + num_workers = self.conf.get("num_workers", self.conf.batch_size) + return DataLoader( + dataset, + batch_size=self.conf.batch_size, + pin_memory=True, + num_workers=num_workers, + sampler=sampler, + worker_init_fn=worker_init_fn, + collate_fn=collate, + ) diff --git a/siclib/datasets/configs/openpano-radial.yaml b/siclib/datasets/configs/openpano-radial.yaml new file mode 100644 index 0000000000000000000000000000000000000000..717073904732c8bd9b2076b826f65664a0b9cd8f --- /dev/null +++ b/siclib/datasets/configs/openpano-radial.yaml @@ -0,0 +1,41 @@ +name: openpano_radial +base_dir: data/openpano +pano_dir: "${.base_dir}/panoramas" +images_per_pano: 16 +resize_factor: null +n_workers: 1 +device: cpu +overwrite: true +parameter_dists: + roll: + type: uniform # uni[-45, 45] + options: + loc: -0.7853981633974483 # -45 degrees + scale: 1.5707963267948966 # 90 degrees + pitch: + type: uniform # uni[-45, 45] + options: + loc: -0.7853981633974483 # -45 degrees + scale: 1.5707963267948966 # 90 degrees + vfov: + type: uniform # uni[20, 105] + options: + loc: 0.3490658503988659 # 20 degrees + scale: 1.48352986419518 # 85 degrees + k1_hat: + type: truncnorm + options: + a: -4.285714285714286 # corresponds to -0.3 + b: 4.285714285714286 # corresponds to 0.3 + loc: 0 + scale: 0.07 + resize_factor: + type: uniform + options: + loc: 1.2 + scale: 0.5 + shape: + type: fix + value: + - 640 + - 640 diff --git a/siclib/datasets/configs/openpano.yaml b/siclib/datasets/configs/openpano.yaml new file mode 100644 index 0000000000000000000000000000000000000000..07e55d33873eb8ab342f31664bb32ca895c4fce7 --- /dev/null +++ b/siclib/datasets/configs/openpano.yaml @@ -0,0 +1,34 @@ +name: openpano +base_dir: data/openpano +pano_dir: "${.base_dir}/panoramas" +images_per_pano: 16 +resize_factor: null +n_workers: 1 +device: cpu +overwrite: true +parameter_dists: + roll: + type: uniform # uni[-45, 45] + options: + loc: -0.7853981633974483 # -45 degrees + scale: 1.5707963267948966 # 90 degrees + pitch: + type: uniform # uni[-45, 45] + options: + loc: -0.7853981633974483 # -45 degrees + scale: 1.5707963267948966 # 90 degrees + vfov: + type: uniform # uni[20, 105] + options: + loc: 0.3490658503988659 # 20 degrees + scale: 1.48352986419518 # 85 degrees + resize_factor: + type: uniform + options: + loc: 1.2 + scale: 0.5 + shape: + type: fix + value: + - 640 + - 640 diff --git a/siclib/datasets/create_dataset_from_pano.py b/siclib/datasets/create_dataset_from_pano.py new file mode 100644 index 0000000000000000000000000000000000000000..f7e672e92e005fced0bde624cc9da31abc94fb89 --- /dev/null +++ b/siclib/datasets/create_dataset_from_pano.py @@ -0,0 +1,350 @@ +"""Script to create a dataset from panorama images.""" + +import hashlib +import logging +from concurrent import futures +from pathlib import Path + +import hydra +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import scipy +import torch +from omegaconf import DictConfig, OmegaConf +from tqdm import tqdm + +from siclib.geometry.camera import camera_models +from siclib.geometry.gravity import Gravity +from siclib.utils.conversions import deg2rad, focal2fov, fov2focal, rad2deg +from siclib.utils.image import load_image, write_image + +logger = logging.getLogger(__name__) + + +# mypy: ignore-errors + + +def max_radius(a, b): + """Compute the maximum radius 
of a Brown distortion model.""" + discrim = a * a - 4 * b + # if torch.isfinite(discrim) and discrim >= 0.0: + # discrim = np.sqrt(discrim) - a + # if discrim > 0.0: + # return 2.0 / discrim + + valid = torch.isfinite(discrim) & (discrim >= 0.0) + discrim = torch.sqrt(discrim) - a + valid &= discrim > 0.0 + return 2.0 / torch.where(valid, discrim, 0) + + +def brown_max_radius(k1, k2): + """Compute the maximum radius of a Brown distortion model.""" + # fold the constants from the derivative into a and b + a = k1 * 3 + b = k2 * 5 + return torch.sqrt(max_radius(a, b)) + + +class ParallelProcessor: + """Generic parallel processor class.""" + + def __init__(self, max_workers): + """Init processor and pbars.""" + self.max_workers = max_workers + self.executor = futures.ProcessPoolExecutor(max_workers=self.max_workers) + self.pbars = {} + + def update_pbar(self, pbar_key): + """Update progressbar.""" + pbar = self.pbars.get(pbar_key) + pbar.update(1) + + def submit_tasks(self, task_func, task_args, pbar_key): + """Submit tasks.""" + pbar = tqdm(total=len(task_args), desc=f"Processing {pbar_key}", ncols=80) + self.pbars[pbar_key] = pbar + + def update_pbar(future): + self.update_pbar(pbar_key) + + futures = [] + for args in task_args: + future = self.executor.submit(task_func, *args) + future.add_done_callback(update_pbar) + futures.append(future) + + return futures + + def wait_for_completion(self, futures): + """Wait for completion and return results.""" + results = [] + for f in futures: + results += f.result() + + for key in self.pbars.keys(): + self.pbars[key].close() + + return results + + def shutdown(self): + """Close the executer.""" + self.executor.shutdown() + + +class DatasetGenerator: + """Dataset generator class to create perspective datasets from panoramas.""" + + default_conf = { + "name": "???", + # paths + "base_dir": "???", + "pano_dir": "${.base_dir}/panoramas", + "pano_train": "${.pano_dir}/train", + "pano_val": "${.pano_dir}/val", + "pano_test": "${.pano_dir}/test", + "perspective_dir": "${.base_dir}/${.name}", + "perspective_train": "${.perspective_dir}/train", + "perspective_val": "${.perspective_dir}/val", + "perspective_test": "${.perspective_dir}/test", + "train_csv": "${.perspective_dir}/train.csv", + "val_csv": "${.perspective_dir}/val.csv", + "test_csv": "${.perspective_dir}/test.csv", + # data options + "camera_model": "pinhole", + "parameter_dists": { + "roll": { + "type": "uniform", + "options": {"loc": deg2rad(-45), "scale": deg2rad(90)}, # in [-45, 45] + }, + "pitch": { + "type": "uniform", + "options": {"loc": deg2rad(-45), "scale": deg2rad(90)}, # in [-45, 45] + }, + "vfov": { + "type": "uniform", + "options": {"loc": deg2rad(20), "scale": deg2rad(85)}, # in [20, 105] + }, + "resize_factor": { + "type": "uniform", + "options": {"loc": 1.0, "scale": 1.0}, # factor in [1.0, 2.0] + }, + "shape": {"type": "fix", "value": (640, 640)}, + }, + "images_per_pano": 16, + "n_workers": 10, + "device": "cpu", + "overwrite": False, + } + + def __init__(self, conf): + """Init the class by merging and storing the config.""" + self.conf = OmegaConf.merge( + OmegaConf.create(self.default_conf), + OmegaConf.create(conf), + ) + logger.info(f"Config:\n{OmegaConf.to_yaml(self.conf)}") + + self.infos = {} + self.device = self.conf.device + + self.camera_model = camera_models[self.conf.camera_model] + + def sample_value(self, parameter_name, seed=None): + """Sample a value from the specified distribution.""" + param_conf = self.conf["parameter_dists"][parameter_name] + + if 
param_conf.type == "fix": + return torch.tensor(param_conf.value) + + # fix seed for reproducibility + generator = None + if seed: + if not isinstance(seed, (int, float)): + seed = int(hashlib.sha256(seed.encode()).hexdigest(), 16) % (2**32) + generator = np.random.default_rng(seed) + + sampler = getattr(scipy.stats, param_conf.type) + return torch.tensor(sampler.rvs(random_state=generator, **param_conf.options)) + + def plot_distributions(self): + """Plot parameter distributions.""" + fig, ax = plt.subplots(3, 3, figsize=(15, 10)) + for i, split in enumerate(["train", "val", "test"]): + roll_vals = [rad2deg(row["roll"]) for row in self.infos[split]] + ax[i, 0].hist(roll_vals, bins=100) + ax[i, 0].set_xlabel("Roll (°)") + ax[i, 0].set_ylabel(f"Count {split}") + + pitch_vals = [rad2deg(row["pitch"]) for row in self.infos[split]] + ax[i, 1].hist(pitch_vals, bins=100) + ax[i, 1].set_xlabel("Pitch (°)") + ax[i, 1].set_ylabel(f"Count {split}") + + vfov_vals = [rad2deg(row["vfov"]) for row in self.infos[split]] + ax[i, 2].hist(vfov_vals, bins=100) + ax[i, 2].set_xlabel("vFoV (°)") + ax[i, 2].set_ylabel(f"Count {split}") + + plt.tight_layout() + plt.savefig(Path(self.conf.perspective_dir) / "distributions.pdf") + + fig, ax = plt.subplots(3, 3, figsize=(15, 10)) + for i, k1 in enumerate(["roll", "pitch", "vfov"]): + for j, k2 in enumerate(["roll", "pitch", "vfov"]): + ax[i, j].scatter( + [rad2deg(row[k1]) for row in self.infos["train"]], + [rad2deg(row[k2]) for row in self.infos["train"]], + s=1, + label="train", + ) + + ax[i, j].scatter( + [rad2deg(row[k1]) for row in self.infos["val"]], + [rad2deg(row[k2]) for row in self.infos["val"]], + s=1, + label="val", + ) + + ax[i, j].scatter( + [rad2deg(row[k1]) for row in self.infos["test"]], + [rad2deg(row[k2]) for row in self.infos["test"]], + s=1, + label="test", + ) + + ax[i, j].set_xlabel(k1) + ax[i, j].set_ylabel(k2) + ax[i, j].legend() + + plt.tight_layout() + plt.savefig(Path(self.conf.perspective_dir) / "distributions_scatter.pdf") + + def generate_images_from_pano(self, pano_path: Path, out_dir: Path): + """Generate perspective images from a single panorama.""" + infos = [] + + pano = load_image(pano_path).to(self.device) + + yaws = np.linspace(0, 2 * np.pi, self.conf.images_per_pano, endpoint=False) + params = { + k: [self.sample_value(k, pano_path.stem + k + str(i)) for i in yaws] + for k in self.conf.parameter_dists + if k != "shape" + } + shapes = [self.sample_value("shape", pano_path.stem + "shape") for _ in yaws] + params |= { + "height": [shape[0] for shape in shapes], + "width": [shape[1] for shape in shapes], + } + + if "k1_hat" in params: + height = torch.tensor(params["height"]) + width = torch.tensor(params["width"]) + k1_hat = torch.tensor(params["k1_hat"]) + vfov = torch.tensor(params["vfov"]) + focal = fov2focal(vfov, height) + focal = focal + rel_focal = focal / height + k1 = k1_hat * rel_focal + + # distance to image corner + # r_max_im = f_px * r_max * (1 + k1*r_max**2) + # function of r_max_im: f_px = r_max_im / (r_max * (1 + k1*r_max**2)) + min_permissible_rmax = torch.sqrt((height / 2) ** 2 + (width / 2) ** 2) + r_max = brown_max_radius(k1=k1, k2=0) + lowest_possible_f_px = min_permissible_rmax / (r_max * (1 + k1 * r_max**2)) + valid = lowest_possible_f_px <= focal + + f = torch.where(valid, focal, lowest_possible_f_px) + vfov = focal2fov(f, height) + + params["vfov"] = vfov + params |= {"k1": k1} + + cam = self.camera_model.from_dict(params).float().to(self.device) + gravity = Gravity.from_rp(params["roll"], 
params["pitch"]).float().to(self.device) + + if (out_dir / f"{pano_path.stem}_0.jpg").exists() and not self.conf.overwrite: + for i in range(self.conf.images_per_pano): + perspective_name = f"{pano_path.stem}_{i}.jpg" + info = {"fname": perspective_name} | {k: v[i].item() for k, v in params.items()} + infos.append(info) + + logger.info(f"Perspectives for {pano_path.stem} already exist.") + + return infos + + perspective_images = cam.get_img_from_pano( + pano_img=pano, gravity=gravity, yaws=yaws, resize_factor=params["resize_factor"] + ) + + for i, perspective_image in enumerate(perspective_images): + perspective_name = f"{pano_path.stem}_{i}.jpg" + + n_pixels = perspective_image.shape[-2] * perspective_image.shape[-1] + valid = (torch.sum(perspective_image.sum(0) == 0) / n_pixels) < 0.01 + if not valid: + logger.debug(f"Perspective {perspective_name} has too many black pixels.") + continue + + write_image(perspective_image, out_dir / perspective_name) + + info = {"fname": perspective_name} | {k: v[i].item() for k, v in params.items()} + infos.append(info) + + return infos + + def generate_split(self, split: str, parallel_processor: ParallelProcessor): + """Generate a single split of a dataset.""" + self.infos[split] = [] + panorama_paths = [ + path + for path in Path(self.conf[f"pano_{split}"]).glob("*") + if not path.name.startswith(".") + ] + + out_dir = Path(self.conf[f"perspective_{split}"]) + logger.info(f"Writing perspective images to {str(out_dir)}") + if not out_dir.exists(): + out_dir.mkdir(parents=True) + + futures = parallel_processor.submit_tasks( + self.generate_images_from_pano, [(f, out_dir) for f in panorama_paths], split + ) + self.infos[split] = parallel_processor.wait_for_completion(futures) + # parallel_processor.shutdown() + + metadata = pd.DataFrame(data=self.infos[split]) + metadata.to_csv(self.conf[f"{split}_csv"]) + + def generate_dataset(self): + """Generate all splits of a dataset.""" + out_dir = Path(self.conf.perspective_dir) + if not out_dir.exists(): + out_dir.mkdir(parents=True) + + OmegaConf.save(self.conf, out_dir / "config.yaml") + + processor = ParallelProcessor(self.conf.n_workers) + for split in ["train", "val", "test"]: + self.generate_split(split=split, parallel_processor=processor) + + processor.shutdown() + + for split in ["train", "val", "test"]: + logger.info(f"Generated {len(self.infos[split])} {split} images.") + + self.plot_distributions() + + +@hydra.main(version_base=None, config_path="configs", config_name="SUN360") +def main(cfg: DictConfig) -> None: + """Run dataset generation.""" + generator = DatasetGenerator(conf=cfg) + generator.generate_dataset() + + +if __name__ == "__main__": + main() diff --git a/siclib/datasets/simple_dataset.py b/siclib/datasets/simple_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..eb2ac7d4b1116f5967ae9ee0e62bab7ce83b7c92 --- /dev/null +++ b/siclib/datasets/simple_dataset.py @@ -0,0 +1,237 @@ +"""Dataset for images created with 'create_dataset_from_pano.py'.""" + +import logging +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import pandas as pd +import torch +from omegaconf import DictConfig + +from siclib.datasets.augmentations import IdentityAugmentation, augmentations +from siclib.datasets.base_dataset import BaseDataset +from siclib.geometry.camera import SimpleRadial +from siclib.geometry.gravity import Gravity +from siclib.geometry.perspective_fields import get_perspective_field +from siclib.utils.conversions import fov2focal +from 
siclib.utils.image import ImagePreprocessor, load_image +from siclib.utils.tools import fork_rng + +logger = logging.getLogger(__name__) + +# mypy: ignore-errors + + +def load_csv( + csv_file: Path, img_root: Path +) -> Tuple[List[Dict[str, Any]], torch.Tensor, torch.Tensor]: + """Load a CSV file containing image information. + + Args: + csv_file (str): Path to the CSV file. + img_root (str): Path to the root directory containing the images. + + Returns: + list: List of dictionaries containing the image paths and camera parameters. + """ + df = pd.read_csv(csv_file) + + infos, params, gravity = [], [], [] + for _, row in df.iterrows(): + h = row["height"] + w = row["width"] + px = row.get("px", w / 2) + py = row.get("py", h / 2) + vfov = row["vfov"] + f = fov2focal(torch.tensor(vfov), h) + k1 = row.get("k1", 0) + k2 = row.get("k2", 0) + params.append(torch.tensor([w, h, f, f, px, py, k1, k2])) + + roll = row["roll"] + pitch = row["pitch"] + gravity.append(torch.tensor([roll, pitch])) + + infos.append({"name": row["fname"], "file_name": str(img_root / row["fname"])}) + + params = torch.stack(params).float() + gravity = torch.stack(gravity).float() + return infos, params, gravity + + +class SimpleDataset(BaseDataset): + """Dataset for images created with 'create_dataset_from_pano.py'.""" + + default_conf = { + # paths + "dataset_dir": "???", + "train_img_dir": "${.dataset_dir}/train", + "val_img_dir": "${.dataset_dir}/val", + "test_img_dir": "${.dataset_dir}/test", + "train_csv": "${.dataset_dir}/train.csv", + "val_csv": "${.dataset_dir}/val.csv", + "test_csv": "${.dataset_dir}/test.csv", + # data options + "use_up": True, + "use_latitude": True, + "use_prior_focal": False, + "use_prior_gravity": False, + "use_prior_k1": False, + # image options + "grayscale": False, + "preprocessing": ImagePreprocessor.default_conf, + "augmentations": {"name": "geocalib", "verbose": False}, + "p_rotate": 0.0, # probability to rotate image by +/- 90° + "reseed": False, + "seed": 0, + # data loader options + "num_workers": 8, + "prefetch_factor": 2, + "train_batch_size": 32, + "val_batch_size": 32, + "test_batch_size": 32, + } + + def _init(self, conf): + pass + + def get_dataset(self, split: str) -> torch.utils.data.Dataset: + """Return a dataset for a given split.""" + return _SimpleDataset(self.conf, split) + + +class _SimpleDataset(torch.utils.data.Dataset): + """Dataset for dataset for images created with 'create_dataset_from_pano.py'.""" + + def __init__(self, conf: DictConfig, split: str): + """Initialize the dataset.""" + self.conf = conf + self.split = split + self.img_dir = Path(conf.get(f"{split}_img_dir")) + + self.preprocessor = ImagePreprocessor(conf.preprocessing) + + # load image information + assert f"{split}_csv" in conf, f"Missing {split}_csv in conf" + infos_path = self.conf.get(f"{split}_csv") + self.infos, self.parameters, self.gravity = load_csv(infos_path, self.img_dir) + + # define augmentations + aug_name = conf.augmentations.name + assert ( + aug_name in augmentations.keys() + ), f'{aug_name} not in {" ".join(augmentations.keys())}' + + if self.split == "train": + self.augmentation = augmentations[aug_name](conf.augmentations) + else: + self.augmentation = IdentityAugmentation() + + def __len__(self): + return len(self.infos) + + def __getitem__(self, idx): + if not self.conf.reseed: + return self.getitem(idx) + with fork_rng(self.conf.seed + idx, False): + return self.getitem(idx) + + def _read_image( + self, infos: Dict[str, Any], parameters: torch.Tensor, gravity: torch.Tensor + 
) -> Dict[str, Any]: + path = Path(str(infos["file_name"])) + + # load image as uint8 and HWC for augmentation + image = load_image(path, self.conf.grayscale, return_tensor=False) + image = self.augmentation(image, return_tensor=True) + + # create radial camera -> same as pinhole if k1 = 0 + camera = SimpleRadial(parameters[None]).float() + + roll, pitch = gravity[None].unbind(-1) + gravity = Gravity.from_rp(roll, pitch) + + # preprocess + data = self.preprocessor(image) + camera = camera.scale(data["scales"]) + camera = camera.crop(data["crop_pad"]) if "crop_pad" in data else camera + + priors = {"prior_gravity": gravity} if self.conf.use_prior_gravity else {} + priors |= {"prior_focal": camera.f[..., 1]} if self.conf.use_prior_focal else {} + priors |= {"prior_k1": camera.k1} if self.conf.use_prior_k1 else {} + return { + "name": infos["name"], + "path": str(path), + "camera": camera[0], + "gravity": gravity[0], + **priors, + **data, + } + + def _get_perspective(self, data): + """Get perspective field.""" + camera = data["camera"] + gravity = data["gravity"] + + up_field, lat_field = get_perspective_field( + camera, gravity, use_up=self.conf.use_up, use_latitude=self.conf.use_latitude + ) + + out = {} + if self.conf.use_up: + out["up_field"] = up_field[0] + if self.conf.use_latitude: + out["latitude_field"] = lat_field[0] + + return out + + def getitem(self, idx: int): + """Return a sample from the dataset.""" + infos = self.infos[idx] + parameters = self.parameters[idx] + gravity = self.gravity[idx] + data = self._read_image(infos, parameters, gravity) + + if self.conf.use_up or self.conf.use_latitude: + data |= self._get_perspective(data) + + return data + + +if __name__ == "__main__": + # Create a dump of the dataset + import argparse + + import matplotlib.pyplot as plt + + from siclib.visualization.visualize_batch import make_perspective_figures + + parser = argparse.ArgumentParser() + parser.add_argument("--name", type=str, required=True) + parser.add_argument("--data_dir", type=str) + parser.add_argument("--split", type=str, default="train") + parser.add_argument("--shuffle", action="store_true") + parser.add_argument("--n_rows", type=int, default=4) + parser.add_argument("--dpi", type=int, default=100) + args = parser.parse_intermixed_args() + + dconf = SimpleDataset.default_conf + dconf["name"] = args.name + dconf["num_workers"] = 0 + dconf["prefetch_factor"] = None + + dconf["dataset_dir"] = args.data_dir + dconf[f"{args.split}_batch_size"] = args.n_rows + + torch.set_grad_enabled(False) + + dataset = SimpleDataset(dconf) + loader = dataset.get_data_loader(args.split, args.shuffle) + + with fork_rng(seed=42): + for data in loader: + pred = data + break + fig = make_perspective_figures(pred, data, n_pairs=args.n_rows) + + plt.show() diff --git a/siclib/datasets/utils/__init__.py b/siclib/datasets/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/siclib/datasets/utils/align_megadepth.py b/siclib/datasets/utils/align_megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..123e456bd3603f7187785153733bd843b782d53b --- /dev/null +++ b/siclib/datasets/utils/align_megadepth.py @@ -0,0 +1,41 @@ +import argparse +import subprocess +from pathlib import Path + +# flake8: noqa +# mypy: ignore-errors + +parser = argparse.ArgumentParser(description="Aligns a COLMAP model and plots the horizon lines") +parser.add_argument( + "--base_dir", type=str, help="Path to the base 
directory of the MegaDepth dataset" +) +parser.add_argument("--out_dir", type=str, help="Path to the output directory") +args = parser.parse_args() + +base_dir = Path(args.base_dir) +out_dir = Path(args.out_dir) + +scenes = [d.name for d in base_dir.iterdir() if d.is_dir()] +print(scenes[:3], len(scenes)) + +# exit() + +for scene in scenes: + image_dir = base_dir / scene / "images" + sfm_dir = base_dir / scene / "sparse" / "manhattan" / "0" + + # Align model + align_dir = out_dir / scene / "sparse" / "align" + align_dir.mkdir(exist_ok=True, parents=True) + + print(f"image_dir ({image_dir.exists()}): {image_dir}") + print(f"sfm_dir ({sfm_dir.exists()}): {sfm_dir}") + print(f"align_dir ({align_dir.exists()}): {align_dir}") + + cmd = ( + "colmap model_orientation_aligner " + + f"--image_path {image_dir} " + + f"--input_path {sfm_dir} " + + f"--output_path {str(align_dir)}" + ) + subprocess.run(cmd, shell=True) diff --git a/siclib/datasets/utils/download_openpano.py b/siclib/datasets/utils/download_openpano.py new file mode 100644 index 0000000000000000000000000000000000000000..ff751da36313d74fd268f7fa057c6a0bcb6c611d --- /dev/null +++ b/siclib/datasets/utils/download_openpano.py @@ -0,0 +1,75 @@ +"""Helper script to download and extract OpenPano dataset.""" + +import argparse +import shutil +from pathlib import Path + +import torch +from tqdm import tqdm + +from siclib import logger + +PANO_URL = "https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/openpano.zip" + + +def download_and_extract_dataset(name: str, url: Path, output: Path) -> None: + """Download and extract a dataset from a URL.""" + dataset_dir = output / name + if not output.exists(): + output.mkdir(parents=True) + + if dataset_dir.exists(): + logger.info(f"Dataset {name} already exists at {dataset_dir}, skipping download.") + return + + zip_file = output / f"{name}.zip" + + if not zip_file.exists(): + logger.info(f"Downloading dataset {name} to {zip_file} from {url}.") + torch.hub.download_url_to_file(url, zip_file) + + logger.info(f"Extracting dataset {name} in {output}.") + shutil.unpack_archive(zip_file, output, format="zip") + zip_file.unlink() + + +def main(): + """Prepare the OpenPano dataset.""" + parser = argparse.ArgumentParser(description="Download and extract OpenPano dataset.") + parser.add_argument("--name", type=str, default="openpano", help="Name of the dataset.") + parser.add_argument( + "--laval_dir", type=str, default="data/laval-tonemap", help="Path the Laval dataset." 
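# A minimal sketch of invoking the same COLMAP aligner with an argument list instead of a
# single shell string: this avoids quoting issues with paths that contain spaces. It assumes
# only the `colmap model_orientation_aligner` flags already used above.
import subprocess
from pathlib import Path

def align_scene(image_dir: Path, sfm_dir: Path, align_dir: Path) -> None:
    align_dir.mkdir(exist_ok=True, parents=True)
    subprocess.run(
        [
            "colmap", "model_orientation_aligner",
            "--image_path", str(image_dir),
            "--input_path", str(sfm_dir),
            "--output_path", str(align_dir),
        ],
        check=True,  # raise if COLMAP exits with a non-zero status
    )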
+ ) + + args = parser.parse_args() + + out_dir = Path("data") + download_and_extract_dataset(args.name, PANO_URL, out_dir) + + pano_dir = out_dir / args.name / "panoramas" + for split in ["train", "test", "val"]: + with open(pano_dir / f"{split}_panos.txt", "r") as f: + pano_list = f.readlines() + pano_list = [fname.strip() for fname in pano_list] + + for fname in tqdm(pano_list, ncols=80, desc=f"Copying {split} panoramas"): + laval_path = Path(args.laval_dir) / fname + target_path = pano_dir / split / fname + + # pano either exists in laval or is in split + if target_path.exists(): + continue + + if laval_path.exists(): + shutil.copy(laval_path, target_path) + else: # not in laval and not in split + logger.warning(f"Panorama {fname} not found in {args.laval_dir} or {split} split.") + + n_train = len(list(pano_dir.glob("train/*.jpg"))) + n_test = len(list(pano_dir.glob("test/*.jpg"))) + n_val = len(list(pano_dir.glob("val/*.jpg"))) + logger.info(f"{args.name} contains {n_train}/{n_test}/{n_val} train/test/val panoramas.") + + +if __name__ == "__main__": + main() diff --git a/siclib/datasets/utils/tonemapping.py b/siclib/datasets/utils/tonemapping.py new file mode 100644 index 0000000000000000000000000000000000000000..3da1740724f003900771b099ef8f2de00deb8ea4 --- /dev/null +++ b/siclib/datasets/utils/tonemapping.py @@ -0,0 +1,316 @@ +import os + +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" +import argparse + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib import colors +from tqdm import tqdm + +# flake8: noqa +# mypy: ignore-errors + + +class tonemap: + def __init__(self): + pass + + def process(self, img): + return img + + def inv_process(self, img): + return img + + +# Log correction +class log_tonemap(tonemap): + # Constructor + # Base of log + # Scale of tonemapped + # Offset + def __init__(self, base, scale=1, offset=1): + self.base = base + self.scale = scale + self.offset = offset + + def process(self, img): + tonemapped = (np.log(img + self.offset) / np.log(self.base)) * self.scale + return tonemapped + + def inv_process(self, img): + inverse_tonemapped = np.power(self.base, (img) / self.scale) - self.offset + return inverse_tonemapped + + +class log_tonemap_clip(tonemap): + # Constructor + # Base of log + # Scale of tonemapped + # Offset + def __init__(self, base, scale=1, offset=1): + self.base = base + self.scale = scale + self.offset = offset + + def process(self, img): + tonemapped = np.clip((np.log(img * self.scale + self.offset) / np.log(self.base)), 0, 2) - 1 + return tonemapped + + def inv_process(self, img): + inverse_tonemapped = (np.power(self.base, (img + 1)) - self.offset) / self.scale + return inverse_tonemapped + + +# Gamma Tonemap +class gamma_tonemap(tonemap): + def __init__( + self, + gamma, + ): + self.gamma = gamma + + def process(self, img): + tonemapped = np.power(img, 1 / self.gamma) + return tonemapped + + def inv_process(self, img): + inverse_tonemapped = np.power(img, self.gamma) + return inverse_tonemapped + + +class linear_clip(tonemap): + def __init__(self, scale, mean): + self.scale = scale + self.mean = mean + + def process(self, img): + tonemapped = np.clip((img - self.mean) / self.scale, -1, 1) + return tonemapped + + def inv_process(self, img): + inverse_tonemapped = img * self.scale + self.mean + return inverse_tonemapped + + +def make_tonemap_HDR(opt): + if opt.mode == "luminance": + res_tonemap = log_tonemap_clip(10, 1.0, 1.0) + else: # temperature + res_tonemap = linear_clip(5000.0, 5000.0) + 
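# Small numeric sanity check (illustrative) for the tonemap classes defined above in this
# file: applying process() followed by inv_process() should return the input up to floating
# point error.
import numpy as np

_hdr = np.array([0.01, 0.5, 2.0, 10.0])

_gamma_tm = gamma_tonemap(2.2)
assert np.allclose(_gamma_tm.inv_process(_gamma_tm.process(_hdr)), _hdr)

_log_tm = log_tonemap(base=10, scale=1, offset=1)   # process() becomes log10(x + 1)
assert np.allclose(_log_tm.inv_process(_log_tm.process(_hdr)), _hdr)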
return res_tonemap + + +class LDRfromHDR: + def __init__( + self, tonemap="none", orig_scale=False, clip=True, quantization=0, color_jitter=0, noise=0 + ): + self.tonemap_str, val = tonemap + if tonemap[0] == "gamma": + self.tonemap = gamma_tonemap(val) + elif tonemap[0] == "log10": + self.tonemap = log_tonemap(val) + else: + print("Warning: No tonemap specified, using linear") + + self.clip = clip + self.orig_scale = orig_scale + self.bits = quantization + self.jitter = color_jitter + self.noise = noise + + self.wbModel = None + + def process(self, HDR): + LDR, normalized_scale = self.rescale(HDR) + LDR = self.apply_clip(LDR) + LDR = self.apply_scale(LDR, normalized_scale) + LDR = self.apply_tonemap(LDR) + LDR = self.colorJitter(LDR) + LDR = self.gaussianNoise(LDR) + LDR = self.quantize(LDR) + LDR = self.apply_white_balance(LDR) + return LDR, normalized_scale + + def rescale(self, img, percentile=90, max_mapping=0.8): + r_percentile = np.percentile(img, percentile) + alpha = max_mapping / (r_percentile + 1e-10) + + img_reexposed = img * alpha + + normalized_scale = normalizeScale(1 / alpha) + + return img_reexposed, normalized_scale + + def rescaleAlpha(self, img, percentile=90, max_mapping=0.8): + r_percentile = np.percentile(img, percentile) + alpha = max_mapping / (r_percentile + 1e-10) + + return alpha + + def apply_clip(self, img): + if self.clip: + img = np.clip(img, 0, 1) + return img + + def apply_scale(self, img, scale): + if self.orig_scale: + scale = unNormalizeScale(scale) + img = img * scale + return img + + def apply_tonemap(self, img): + if self.tonemap_str == "none": + return img + gammaed = self.tonemap.process(img) + return gammaed + + def quantize(self, img): + if self.bits == 0: + return img + max_val = np.power(2, self.bits) + img = img * max_val + img = np.floor(img) + img = img / max_val + return img + + def colorJitter(self, img): + if self.jitter == 0: + return img + hsv = colors.rgb_to_hsv(img) + hue_offset = np.random.normal(0, self.jitter, 1) + hsv[:, :, 0] = (hsv[:, :, 0] + hue_offset) % 1.0 + rgb = colors.hsv_to_rgb(hsv) + return rgb + + def gaussianNoise(self, img): + if self.noise == 0: + return img + noise_amount = np.random.uniform(0, self.noise, 1) + noise_img = np.random.normal(0, noise_amount, img.shape) + img = img + noise_img + img = np.clip(img, 0, 1).astype(np.float32) + return img + + def apply_white_balance(self, img): + if self.wbModel is None: + return img + img = self.wbModel.correctImage(img) + return img.copy() + + +def make_LDRfromHDR(opt): + LDR_from_HDR = LDRfromHDR( + opt.tonemap_LDR, opt.orig_scale, opt.clip, opt.quantization, opt.color_jitter, opt.noise + ) + return LDR_from_HDR + + +def torchnormalizeEV(EV, mean=5.12, scale=6, clip=True): + # Normalize based on the computed distribution between -1 1 + EV -= mean + EV = EV / scale + + if clip: + EV = torch.clip(EV, min=-1, max=1) + + return EV + + +def torchnormalizeEV0(EV, mean=5.12, scale=6, clip=True): + # Normalize based on the computed distribution between 0 1 + EV -= mean + EV = EV / scale + + if clip: + EV = torch.clip(EV, min=-1, max=1) + + EV += 0.5 + EV = EV / 2 + + return EV + + +def normalizeScale(x, scale=4): + x = np.log10(x + 1) + + x = x / (scale / 2) + x = x - 1 + + return x + + +def unNormalizeScale(x, scale=4): + x = x + 1 + x = x * (scale / 2) + + x = np.power(10, x) - 1 + + return x + + +def normalizeIlluminance(x, scale=5): + x = np.log10(x + 1) + + x = x / (scale / 2) + x = x - 1 + + return x + + +def unNormalizeIlluminance(x, scale=5): + x = x + 1 + x = x * 
(scale / 2) + + x = np.power(10, x) - 1 + + return x + + +def main(args): + processor = LDRfromHDR( + # tonemap=("log10", 10), + tonemap=("gamma", args.gamma), + orig_scale=False, + clip=True, + quantization=0, + color_jitter=0, + noise=0, + ) + + img_list = list(os.listdir(args.hdr_dir)) + img_list = [f for f in img_list if f.endswith(args.extension)] + img_list = [f for f in img_list if not f.startswith("._")] + + if not os.path.exists(args.out_dir): + os.makedirs(args.out_dir) + + for fname in tqdm(img_list): + fname_out = ".".join(fname.split(".")[:-1]) + out = os.path.join(args.out_dir, f"{fname_out}.jpg") + if os.path.exists(out) and not args.overwrite: + continue + + fpath = os.path.join(args.hdr_dir, fname) + img = cv2.imread(fpath, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + ldr, scale = processor.process(img) + + ldr = (ldr * 255).astype(np.uint8) + ldr = cv2.cvtColor(ldr, cv2.COLOR_RGB2BGR) + cv2.imwrite(out, ldr) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--hdr_dir", type=str, default="hdr") + parser.add_argument("--out_dir", type=str, default="ldr") + parser.add_argument("--extension", type=str, default=".exr") + parser.add_argument("--overwrite", action="store_true") + parser.add_argument("--gamma", type=float, default=2) + args = parser.parse_args() + + main(args) diff --git a/siclib/eval/__init__.py b/siclib/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..550493c7bc54a9d827d8e12010b29a7017970821 --- /dev/null +++ b/siclib/eval/__init__.py @@ -0,0 +1,18 @@ +import torch + +from siclib.eval.eval_pipeline import EvalPipeline +from siclib.utils.tools import get_class + + +def get_benchmark(benchmark): + return get_class(f"{__name__}.{benchmark}", EvalPipeline) + + +@torch.no_grad() +def run_benchmark(benchmark, eval_conf, experiment_dir, model=None): + """This overwrites existing benchmarks""" + experiment_dir.mkdir(exist_ok=True, parents=True) + bm = get_benchmark(benchmark) + + pipeline = bm(eval_conf) + return pipeline.run(experiment_dir, model=model, overwrite=True, overwrite_eval=True) diff --git a/siclib/eval/configs/deepcalib.yaml b/siclib/eval/configs/deepcalib.yaml new file mode 100644 index 0000000000000000000000000000000000000000..2a0173265d7e8ff64110ccff94e90d6d9a31fdf4 --- /dev/null +++ b/siclib/eval/configs/deepcalib.yaml @@ -0,0 +1,3 @@ +model: + name: networks.deepcalib + weights: weights/deepcalib.tar diff --git a/siclib/eval/configs/geocalib-pinhole.yaml b/siclib/eval/configs/geocalib-pinhole.yaml new file mode 100644 index 0000000000000000000000000000000000000000..32b036201ae27f38811dc7558c8bf4449ee6630c --- /dev/null +++ b/siclib/eval/configs/geocalib-pinhole.yaml @@ -0,0 +1,2 @@ +model: + name: networks.geocalib_pretrained diff --git a/siclib/eval/configs/geocalib-simple_radial.yaml b/siclib/eval/configs/geocalib-simple_radial.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6bdeeb4e21cb9b6f3a179d829c2adc3ae7f9f0e5 --- /dev/null +++ b/siclib/eval/configs/geocalib-simple_radial.yaml @@ -0,0 +1,4 @@ +model: + name: networks.geocalib_pretrained + camera_model: simple_radial + model_weights: distorted diff --git a/siclib/eval/configs/uvp.yaml b/siclib/eval/configs/uvp.yaml new file mode 100644 index 0000000000000000000000000000000000000000..887dfde752becf3ec06be35411a68cf9af11eb74 --- /dev/null +++ b/siclib/eval/configs/uvp.yaml @@ -0,0 +1,18 @@ +model: + name: optimization.vp_from_prior + SOLVER_FLAGS: 
[True, True, True, True, True] + magsac_scoring: true + min_lines: 5 + verbose: false + + # RANSAC inlier threshold + th_pixels: 3 + + # 3 uses the gravity in the LS refinement, 2 does not. Here we use a prior on the gravity, so use 2 + ls_refinement: 2 + + # change to 3 to add a Ceres optimization after the non minimal solver (slower) + nms: 1 + + # deeplsd, lsd + line_type: deeplsd diff --git a/siclib/eval/eval_pipeline.py b/siclib/eval/eval_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..cb5e4ada2779bd7b42b14488037b45395131fd4e --- /dev/null +++ b/siclib/eval/eval_pipeline.py @@ -0,0 +1,106 @@ +import json + +import h5py +import numpy as np +from omegaconf import OmegaConf + +# flake8: noqa +# mypy: ignore-errors + + +def load_eval(dir): + summaries, results = {}, {} + with h5py.File(str(dir / "results.h5"), "r") as hfile: + for k in hfile.keys(): + r = np.array(hfile[k]) + if len(r.shape) < 3: + results[k] = r + for k, v in hfile.attrs.items(): + summaries[k] = v + with open(dir / "summaries.json", "r") as f: + s = json.load(f) + summaries = {k: v if v is not None else np.nan for k, v in s.items()} + return summaries, results + + +def save_eval(dir, summaries, figures, results): + with h5py.File(str(dir / "results.h5"), "w") as hfile: + for k, v in results.items(): + arr = np.array(v) + if not np.issubdtype(arr.dtype, np.number): + arr = arr.astype("object") + hfile.create_dataset(k, data=arr) + # just to be safe, not used in practice + for k, v in summaries.items(): + hfile.attrs[k] = v + s = { + k: float(v) if np.isfinite(v) else None + for k, v in summaries.items() + if not isinstance(v, list) + } + s = {**s, **{k: v for k, v in summaries.items() if isinstance(v, list)}} + with open(dir / "summaries.json", "w") as f: + json.dump(s, f, indent=4) + + for fig_name, fig in figures.items(): + fig.savefig(dir / f"{fig_name}.png") + + +def exists_eval(dir): + return (dir / "results.h5").exists() and (dir / "summaries.json").exists() + + +class EvalPipeline: + default_conf = {} + + export_keys = [] + optional_export_keys = [] + + def __init__(self, conf): + """Assumes""" + self.default_conf = OmegaConf.create(self.default_conf) + self.conf = OmegaConf.merge(self.default_conf, conf) + self._init(self.conf) + + def _init(self, conf): + pass + + @classmethod + def get_dataloader(cls, data_conf=None): + """Returns a data loader with samples for each eval datapoint""" + raise NotImplementedError + + def get_predictions(self, experiment_dir, model=None, overwrite=False): + """Export a prediction file for each eval datapoint""" + raise NotImplementedError + + def run_eval(self, loader, pred_file): + """Run the eval on cached predictions""" + raise NotImplementedError + + def run(self, experiment_dir, model=None, overwrite=False, overwrite_eval=False): + """Run export+eval loop""" + self.save_conf(experiment_dir, overwrite=overwrite, overwrite_eval=overwrite_eval) + pred_file = self.get_predictions(experiment_dir, model=model, overwrite=overwrite) + + f = {} + if not exists_eval(experiment_dir) or overwrite_eval or overwrite: + s, f, r = self.run_eval(self.get_dataloader(self.conf.data, 1), pred_file) + save_eval(experiment_dir, s, f, r) + s, r = load_eval(experiment_dir) + return s, f, r + + def save_conf(self, experiment_dir, overwrite=False, overwrite_eval=False): + # store config + conf_output_path = experiment_dir / "conf.yaml" + if conf_output_path.exists(): + saved_conf = OmegaConf.load(conf_output_path) + if (saved_conf.data != self.conf.data) or 
(saved_conf.model != self.conf.model): + assert ( + overwrite + ), "configs changed, add --overwrite to rerun experiment with new conf" + if saved_conf.eval != self.conf.eval: + assert ( + overwrite or overwrite_eval + ), "eval configs changed, add --overwrite_eval to rerun evaluation" + OmegaConf.save(self.conf, experiment_dir / "conf.yaml") diff --git a/siclib/eval/inspect.py b/siclib/eval/inspect.py new file mode 100644 index 0000000000000000000000000000000000000000..ac00c89c736a4b9bb6a7351b103d120f3745ad43 --- /dev/null +++ b/siclib/eval/inspect.py @@ -0,0 +1,62 @@ +import argparse +from collections import defaultdict +from pathlib import Path +from pprint import pprint + +import matplotlib +import matplotlib.pyplot as plt + +from siclib.eval import get_benchmark +from siclib.eval.eval_pipeline import load_eval +from siclib.settings import EVAL_PATH +from siclib.visualization.global_frame import GlobalFrame +from siclib.visualization.two_view_frame import TwoViewFrame + +# flake8: noqa +# mypy: ignore-errors + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("benchmark", type=str) + parser.add_argument("--x", type=str, default=None) + parser.add_argument("--y", type=str, default=None) + parser.add_argument("--backend", type=str, default=None) + parser.add_argument("--default_plot", type=str, default=TwoViewFrame.default_conf["default"]) + + parser.add_argument("dotlist", nargs="*") + args = parser.parse_intermixed_args() + + output_dir = Path(EVAL_PATH, args.benchmark) + + results = {} + summaries = defaultdict(dict) + + predictions = {} + + if args.backend: + matplotlib.use(args.backend) + + bm = get_benchmark(args.benchmark) + loader = bm.get_dataloader() + + for name in args.dotlist: + experiment_dir = output_dir / name + pred_file = experiment_dir / "predictions.h5" + s, results[name] = load_eval(experiment_dir) + predictions[name] = pred_file + for k, v in s.items(): + summaries[k][name] = v + + pprint(summaries) + + plt.close("all") + + frame = GlobalFrame( + {"child": {"default": args.default_plot}, **vars(args)}, + results, + loader, + predictions, + child_frame=TwoViewFrame, + ) + frame.draw() + plt.show() diff --git a/siclib/eval/io.py b/siclib/eval/io.py new file mode 100644 index 0000000000000000000000000000000000000000..276a47c6e2fa58ba4d9980bd8ea5af624da44c10 --- /dev/null +++ b/siclib/eval/io.py @@ -0,0 +1,105 @@ +import argparse +from pathlib import Path +from pprint import pprint +from typing import Optional + +import pkg_resources +from hydra import compose, initialize +from omegaconf import OmegaConf + +from siclib.models import get_model +from siclib.settings import TRAINING_PATH +from siclib.utils.experiments import load_experiment + +# flake8: noqa +# mypy: ignore-errors + + +def parse_config_path(name_or_path: Optional[str], defaults: str) -> Path: + default_configs = {} + print(f"Looking for default config: {'siclib', str(defaults)}") + for c in pkg_resources.resource_listdir("siclib.eval", str(defaults)): + if c.endswith(".yaml"): + default_configs[Path(c).stem] = Path( + pkg_resources.resource_filename("siclib.eval", defaults + c) + ) + if name_or_path is None: + return None + if name_or_path in default_configs: + return default_configs[name_or_path] + path = Path(name_or_path) + if not path.exists(): + raise FileNotFoundError( + f"Cannot find the config file: {name_or_path}. " + f"Not in the default configs {list(default_configs.keys())} " + "and not an existing path." 
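# Minimal sketch of reading back a finished experiment directory, mirroring load_eval()
# above: per-image metrics are stored in results.h5 and the aggregated numbers in
# summaries.json. The experiment path below is hypothetical.
import json
from pathlib import Path

import h5py
import numpy as np

exp_dir = Path("outputs/results/lamar2k/geocalib-pinhole")  # hypothetical location

with h5py.File(str(exp_dir / "results.h5"), "r") as hfile:
    per_image = {k: np.array(hfile[k]) for k in hfile.keys()}

summaries = json.loads((exp_dir / "summaries.json").read_text())
print(summaries.get("median_roll_error"))  # available keys depend on the benchmark's metrics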
+ ) + return Path(path) + + +def extract_benchmark_conf(conf, benchmark): + mconf = OmegaConf.create({"model": conf.get("model", {})}) + if "benchmarks" in conf.keys(): + return OmegaConf.merge(mconf, conf.benchmarks.get(benchmark, {})) + else: + return mconf + + +def parse_eval_args(benchmark, args, configs_path, default=None): + conf = {"data": {}, "model": {}, "eval": {}} + + if args.conf: + print(f"Loading config: {configs_path}") + conf_path = parse_config_path(args.conf, configs_path) + initialize(version_base=None, config_path=configs_path) + custom_conf = compose(config_name=args.conf) + conf = extract_benchmark_conf(OmegaConf.merge(conf, custom_conf), benchmark) + args.tag = args.tag if args.tag is not None else conf_path.name.replace(".yaml", "") + + cli_conf = OmegaConf.from_cli(args.dotlist) + conf = OmegaConf.merge(conf, cli_conf) + conf.checkpoint = args.checkpoint or conf.get("checkpoint") + + if conf.checkpoint and not conf.checkpoint.endswith(".tar"): + checkpoint_conf = OmegaConf.load(TRAINING_PATH / conf.checkpoint / "config.yaml") + conf = OmegaConf.merge(extract_benchmark_conf(checkpoint_conf, benchmark), conf) + + if default: + conf = OmegaConf.merge(default, conf) + + if args.tag is not None: + name = args.tag + elif args.conf and conf.checkpoint: + name = f"{args.conf}_{conf.checkpoint}" + elif args.conf: + name = args.conf + elif conf.checkpoint: + name = conf.checkpoint + if len(args.dotlist) > 0 and not args.tag: + name = f"{name}_" + ":".join(args.dotlist) + + print("Running benchmark:", benchmark) + print("Experiment tag:", name) + print("Config:") + pprint(OmegaConf.to_container(conf)) + return name, conf + + +def load_model(model_conf, checkpoint, get_last=False): + if checkpoint: + model = load_experiment(checkpoint, conf=model_conf, get_last=get_last).eval() + else: + model = get_model(model_conf.name)(model_conf).eval() + return model + + +def get_eval_parser(): + parser = argparse.ArgumentParser() + parser.add_argument("--tag", type=str, default=None) + parser.add_argument("--checkpoint", type=str, default=None) + parser.add_argument("--conf", type=str, default=None) + parser.add_argument("--overwrite", action="store_true") + parser.add_argument("--overwrite_eval", action="store_true") + parser.add_argument("--plot", action="store_true") + parser.add_argument("dotlist", nargs="*") + return parser diff --git a/siclib/eval/lamar2k.py b/siclib/eval/lamar2k.py new file mode 100644 index 0000000000000000000000000000000000000000..9d2087ff460bf2ea07ce1ca29069188a512fddd8 --- /dev/null +++ b/siclib/eval/lamar2k.py @@ -0,0 +1,89 @@ +import resource +from pathlib import Path +from pprint import pprint + +import matplotlib.pyplot as plt +import torch +from omegaconf import OmegaConf + +from siclib.eval.io import get_eval_parser, parse_eval_args +from siclib.eval.simple_pipeline import SimplePipeline +from siclib.settings import EVAL_PATH + +# flake8: noqa +# mypy: ignore-errors + +rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) +resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) + +torch.set_grad_enabled(False) + + +class Lamar2k(SimplePipeline): + default_conf = { + "data": { + "name": "simple_dataset", + "dataset_dir": "data/lamar2k", + "test_img_dir": "${.dataset_dir}/images", + "test_csv": "${.dataset_dir}/images.csv", + "augmentations": {"name": "identity"}, + "preprocessing": {"resize": 320, "edge_divisible_by": 32}, + "test_batch_size": 1, + }, + "model": {}, + "eval": { + "thresholds": [1, 5, 10], + "pixel_thresholds": [0.5, 1, 3, 5], + 
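# Sketch of how the positional `dotlist` arguments collected by get_eval_parser() override
# the benchmark configuration through OmegaConf in parse_eval_args(); values are illustrative.
from omegaconf import OmegaConf

base = OmegaConf.create({"data": {"preprocessing": {"resize": 320}}, "model": {}})
cli = OmegaConf.from_cli(["data.preprocessing.resize=640", "model.name=networks.deepcalib"])
merged = OmegaConf.merge(base, cli)
assert merged.data.preprocessing.resize == 640
assert merged.model.name == "networks.deepcalib"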
"num_vis": 10, + "verbose": True, + }, + "url": "https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/lamar2k.zip", + } + + export_keys = [ + "camera", + "gravity", + ] + + optional_export_keys = [ + "focal_uncertainty", + "vfov_uncertainty", + "roll_uncertainty", + "pitch_uncertainty", + "gravity_uncertainty", + "up_field", + "up_confidence", + "latitude_field", + "latitude_confidence", + ] + + +if __name__ == "__main__": + dataset_name = Path(__file__).stem + parser = get_eval_parser() + args = parser.parse_intermixed_args() + + default_conf = OmegaConf.create(Lamar2k.default_conf) + + # mingle paths + output_dir = Path(EVAL_PATH, dataset_name) + output_dir.mkdir(exist_ok=True, parents=True) + + name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf) + + experiment_dir = output_dir / name + experiment_dir.mkdir(exist_ok=True) + + pipeline = Lamar2k(conf) + s, f, r = pipeline.run( + experiment_dir, + overwrite=args.overwrite, + overwrite_eval=args.overwrite_eval, + ) + + pprint(s) + + if args.plot: + for name, fig in f.items(): + fig.canvas.manager.set_window_title(name) + plt.show() diff --git a/siclib/eval/megadepth2k.py b/siclib/eval/megadepth2k.py new file mode 100644 index 0000000000000000000000000000000000000000..58ee733c376871eb33c6cddd983dcac769db374d --- /dev/null +++ b/siclib/eval/megadepth2k.py @@ -0,0 +1,89 @@ +import resource +from pathlib import Path +from pprint import pprint + +import matplotlib.pyplot as plt +import torch +from omegaconf import OmegaConf + +from siclib.eval.io import get_eval_parser, parse_eval_args +from siclib.eval.simple_pipeline import SimplePipeline +from siclib.settings import EVAL_PATH + +# flake8: noqa +# mypy: ignore-errors + +rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) +resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) + +torch.set_grad_enabled(False) + + +class Megadepth2k(SimplePipeline): + default_conf = { + "data": { + "name": "simple_dataset", + "dataset_dir": "data/megadepth2k", + "test_img_dir": "${.dataset_dir}/images", + "test_csv": "${.dataset_dir}/images.csv", + "augmentations": {"name": "identity"}, + "preprocessing": {"resize": 320, "edge_divisible_by": 32}, + "test_batch_size": 1, + }, + "model": {}, + "eval": { + "thresholds": [1, 5, 10], + "pixel_thresholds": [0.5, 1, 3, 5], + "num_vis": 10, + "verbose": True, + }, + "url": "https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/megadepth2k.zip", + } + + export_keys = [ + "camera", + "gravity", + ] + + optional_export_keys = [ + "focal_uncertainty", + "vfov_uncertainty", + "roll_uncertainty", + "pitch_uncertainty", + "gravity_uncertainty", + "up_field", + "up_confidence", + "latitude_field", + "latitude_confidence", + ] + + +if __name__ == "__main__": + dataset_name = Path(__file__).stem + parser = get_eval_parser() + args = parser.parse_intermixed_args() + + default_conf = OmegaConf.create(Megadepth2k.default_conf) + + # mingle paths + output_dir = Path(EVAL_PATH, dataset_name) + output_dir.mkdir(exist_ok=True, parents=True) + + name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf) + + experiment_dir = output_dir / name + experiment_dir.mkdir(exist_ok=True) + + pipeline = Megadepth2k(conf) + s, f, r = pipeline.run( + experiment_dir, + overwrite=args.overwrite, + overwrite_eval=args.overwrite_eval, + ) + + pprint(s) + + if args.plot: + for name, fig in f.items(): + fig.canvas.manager.set_window_title(name) + plt.show() diff --git a/siclib/eval/megadepth2k_radial.py b/siclib/eval/megadepth2k_radial.py new file mode 100644 index 
0000000000000000000000000000000000000000..092d0ff94c2034ca4c0d4fe965fac6490c31e75a --- /dev/null +++ b/siclib/eval/megadepth2k_radial.py @@ -0,0 +1,101 @@ +import resource +from pathlib import Path +from pprint import pprint + +import matplotlib.pyplot as plt +import torch +from omegaconf import OmegaConf + +from siclib.eval.io import get_eval_parser, parse_eval_args +from siclib.eval.simple_pipeline import SimplePipeline +from siclib.eval.utils import download_and_extract_benchmark +from siclib.geometry.camera import SimpleRadial +from siclib.settings import EVAL_PATH + +# flake8: noqa +# mypy: ignore-errors + +rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) +resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) + +torch.set_grad_enabled(False) + + +class Megadepth2kRadial(SimplePipeline): + default_conf = { + "data": { + "name": "simple_dataset", + "dataset_dir": "data/megadepth2k-radial", + "test_img_dir": "${.dataset_dir}/images", + "test_csv": "${.dataset_dir}/images.csv", + "augmentations": {"name": "identity"}, + "preprocessing": {"resize": 320, "edge_divisible_by": 32}, + "test_batch_size": 1, + }, + "model": {}, + "eval": { + "thresholds": [1, 5, 10], + "pixel_thresholds": [0.5, 1, 3, 5], + "num_vis": 10, + "verbose": True, + }, + "url": "https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/megadepth2k-radial.zip", + } + + export_keys = [ + "camera", + "gravity", + ] + + optional_export_keys = [ + "focal_uncertainty", + "vfov_uncertainty", + "roll_uncertainty", + "pitch_uncertainty", + "gravity_uncertainty", + "up_field", + "up_confidence", + "latitude_field", + "latitude_confidence", + ] + + def _init(self, conf): + self.verbose = conf.eval.verbose + self.num_vis = self.conf.eval.num_vis + + self.CameraModel = SimpleRadial + + if conf.url is not None: + ds_dir = Path(conf.data.dataset_dir) + download_and_extract_benchmark(ds_dir.name, conf.url, ds_dir.parent) + + +if __name__ == "__main__": + dataset_name = Path(__file__).stem + parser = get_eval_parser() + args = parser.parse_intermixed_args() + + default_conf = OmegaConf.create(Megadepth2kRadial.default_conf) + + # mingle paths + output_dir = Path(EVAL_PATH, dataset_name) + output_dir.mkdir(exist_ok=True, parents=True) + + name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf) + + experiment_dir = output_dir / name + experiment_dir.mkdir(exist_ok=True) + + pipeline = Megadepth2kRadial(conf) + s, f, r = pipeline.run( + experiment_dir, + overwrite=args.overwrite, + overwrite_eval=args.overwrite_eval, + ) + + pprint(s) + + if args.plot: + for name, fig in f.items(): + fig.canvas.manager.set_window_title(name) + plt.show() diff --git a/siclib/eval/openpano.py b/siclib/eval/openpano.py new file mode 100644 index 0000000000000000000000000000000000000000..88a7f5544bba15fb70adec90351f05e12787c4e1 --- /dev/null +++ b/siclib/eval/openpano.py @@ -0,0 +1,70 @@ +import resource +from pathlib import Path +from pprint import pprint + +import matplotlib.pyplot as plt +import torch +from omegaconf import OmegaConf + +from siclib.eval.io import get_eval_parser, parse_eval_args +from siclib.eval.simple_pipeline import SimplePipeline +from siclib.settings import EVAL_PATH + +# flake8: noqa +# mypy: ignore-errors + +rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) +resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) + +torch.set_grad_enabled(False) + + +class OpenPano(SimplePipeline): + default_conf = { + "data": { + "name": "simple_dataset", + "dataset_dir": "data/poly+maps+laval/poly+maps+laval", + 
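# The per-benchmark entry points above (Lamar2k, Megadepth2k, Megadepth2kRadial, ...) differ
# mainly in their default_conf. A new benchmark on a custom image folder could therefore be
# sketched as below; the dataset paths are hypothetical and the CSV is expected to follow the
# simple_dataset format.
from siclib.eval.simple_pipeline import SimplePipeline


class MyBenchmark(SimplePipeline):
    default_conf = {
        "data": {
            "name": "simple_dataset",
            "dataset_dir": "data/my_benchmark",  # hypothetical
            "test_img_dir": "${.dataset_dir}/images",
            "test_csv": "${.dataset_dir}/images.csv",
            "augmentations": {"name": "identity"},
            "preprocessing": {"resize": 320, "edge_divisible_by": 32},
            "test_batch_size": 1,
        },
        "model": {},
        "eval": {
            "thresholds": [1, 5, 10],
            "pixel_thresholds": [0.5, 1, 3, 5],
            "num_vis": 0,
            "verbose": True,
        },
        "url": None,  # no archive to download for a local dataset
    }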
"augmentations": {"name": "identity"}, + "preprocessing": {"resize": 320, "edge_divisible_by": 32}, + "test_batch_size": 1, + }, + "model": {}, + "eval": { + "thresholds": [1, 5, 10], + "pixel_thresholds": [0.5, 1, 3, 5], + "num_vis": 10, + "verbose": True, + }, + "url": None, + } + + +if __name__ == "__main__": + dataset_name = Path(__file__).stem + parser = get_eval_parser() + args = parser.parse_intermixed_args() + + default_conf = OmegaConf.create(OpenPano.default_conf) + + # mingle paths + output_dir = Path(EVAL_PATH, dataset_name) + output_dir.mkdir(exist_ok=True, parents=True) + + name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf) + + experiment_dir = output_dir / name + experiment_dir.mkdir(exist_ok=True) + + pipeline = OpenPano(conf) + s, f, r = pipeline.run( + experiment_dir, + overwrite=args.overwrite, + overwrite_eval=args.overwrite_eval, + ) + + pprint(s) + + if args.plot: + for name, fig in f.items(): + fig.canvas.manager.set_window_title(name) + plt.show() diff --git a/siclib/eval/openpano_radial.py b/siclib/eval/openpano_radial.py new file mode 100644 index 0000000000000000000000000000000000000000..a9264f045c43e2374cdca63f30cc20c91a4efe6a --- /dev/null +++ b/siclib/eval/openpano_radial.py @@ -0,0 +1,70 @@ +import resource +from pathlib import Path +from pprint import pprint + +import matplotlib.pyplot as plt +import torch +from omegaconf import OmegaConf + +from siclib.eval.io import get_eval_parser, parse_eval_args +from siclib.eval.simple_pipeline import SimplePipeline +from siclib.settings import EVAL_PATH + +# flake8: noqa +# mypy: ignore-errors + +rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) +resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) + +torch.set_grad_enabled(False) + + +class OpenPanoRadial(SimplePipeline): + default_conf = { + "data": { + "name": "simple_dataset", + "dataset_dir": "data/poly+maps+laval/pano_dataset_distorted", + "augmentations": {"name": "identity"}, + "preprocessing": {"resize": 320, "edge_divisible_by": 32}, + "test_batch_size": 1, + }, + "model": {}, + "eval": { + "thresholds": [1, 5, 10], + "pixel_thresholds": [0.5, 1, 3, 5], + "num_vis": 10, + "verbose": True, + }, + "url": None, + } + + +if __name__ == "__main__": + dataset_name = Path(__file__).stem + parser = get_eval_parser() + args = parser.parse_intermixed_args() + + default_conf = OmegaConf.create(OpenPanoRadial.default_conf) + + # mingle paths + output_dir = Path(EVAL_PATH, dataset_name) + output_dir.mkdir(exist_ok=True, parents=True) + + name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf) + + experiment_dir = output_dir / name + experiment_dir.mkdir(exist_ok=True) + + pipeline = OpenPanoRadial(conf) + s, f, r = pipeline.run( + experiment_dir, + overwrite=args.overwrite, + overwrite_eval=args.overwrite_eval, + ) + + pprint(s) + + if args.plot: + for name, fig in f.items(): + fig.canvas.manager.set_window_title(name) + plt.show() diff --git a/siclib/eval/run_perceptual.py b/siclib/eval/run_perceptual.py new file mode 100644 index 0000000000000000000000000000000000000000..a3d26cad2c05e44f925144489e936a3a5a22fa7b --- /dev/null +++ b/siclib/eval/run_perceptual.py @@ -0,0 +1,84 @@ +"""Run the TPAMI 2023 Perceptual model. + +Run the model of the paper + A Deep Perceptual Measure for Lens and Camera Calibration, TPAMI 2023 + https://lvsn.github.io/deepcalib/ +through the public Dashboard available at http://rachmaninoff.gel.ulaval.ca:8005. 
+""" + +import argparse +import json +import re +import time +from pathlib import Path + +from selenium import webdriver +from selenium.webdriver.common.by import By +from tqdm import tqdm + +# mypy: ignore-errors + +JS_DROP_FILES = "var k=arguments,d=k[0],g=k[1],c=k[2],m=d.ownerDocument||document;for(var e=0;;){var f=d.getBoundingClientRect(),b=f.left+(g||(f.width/2)),a=f.top+(c||(f.height/2)),h=m.elementFromPoint(b,a);if(h&&d.contains(h)){break}if(++e>1){var j=new Error('Element not interactable');j.code=15;throw j}d.scrollIntoView({behavior:'instant',block:'center',inline:'center'})}var l=m.createElement('INPUT');l.setAttribute('type','file');l.setAttribute('multiple','');l.setAttribute('style','position:fixed;z-index:2147483647;left:0;top:0;');l.onchange=function(q){l.parentElement.removeChild(l);q.stopPropagation();var r={constructor:DataTransfer,effectAllowed:'all',dropEffect:'none',types:['Files'],files:l.files,setData:function u(){},getData:function o(){},clearData:function s(){},setDragImage:function i(){}};if(window.DataTransferItemList){r.items=Object.setPrototypeOf(Array.prototype.map.call(l.files,function(x){return{constructor:DataTransferItem,kind:'file',type:x.type,getAsFile:function v(){return x},getAsString:function y(A){var z=new FileReader();z.onload=function(B){A(B.target.result)};z.readAsText(x)},webkitGetAsEntry:function w(){return{constructor:FileSystemFileEntry,name:x.name,fullPath:'/'+x.name,isFile:true,isDirectory:false,file:function z(A){A(x)}}}}}),{constructor:DataTransferItemList,add:function t(){},clear:function p(){},remove:function n(){}})}['dragenter','dragover','drop'].forEach(function(v){var w=m.createEvent('DragEvent');w.initMouseEvent(v,true,true,m.defaultView,0,0,0,b,a,false,false,false,false,0,null);Object.setPrototypeOf(w,null);w.dataTransfer=r;Object.setPrototypeOf(w,DragEvent.prototype);h.dispatchEvent(w)})};m.documentElement.appendChild(l);l.getBoundingClientRect();return l" # noqa E501 + + +def setup_driver(): + """Setup the Selenium browser.""" + options = webdriver.FirefoxOptions() + geckodriver_path = "/snap/bin/geckodriver" # specify the path to your geckodriver + driver_service = webdriver.FirefoxService(executable_path=geckodriver_path) + return webdriver.Firefox(options=options, service=driver_service) + + +def run(args): + """Run on an image folder.""" + driver = setup_driver() + driver.get("http://rachmaninoff.gel.ulaval.ca:8005/") + time.sleep(5) + result_div = driver.find_element(By.ID, "estimated-parameters-display") + + def upload_image(path): + path = Path(path).absolute().as_posix() + elem = driver.find_element(By.ID, "dash-uploader") + inp = driver.execute_script(JS_DROP_FILES, elem, 25, 25) + inp._execute("sendKeysToElement", {"value": [path], "text": path}) + + def run_image(path, prev_result, timeout_seconds=60): + # One main assumption is that subsequent images will have different results + # from each other, otherwise we cannot detect that the inference has completed. 
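# Standalone check (illustrative) of the regex used further below in run() to parse the
# dashboard's result string into (pitch, roll, hfov, distortion).
import re

_number = r"(nan|-?\d*\.?\d*)"
_pattern = re.compile(
    f"Pitch: {_number}° / Roll: {_number}° / HFOV : {_number}° / Distortion: {_number}"
)
_sample = "Pitch: -12.3° / Roll: 4.5° / HFOV : 65.0° / Distortion: 0.12"
_match = _pattern.match(_sample)
assert _match is not None
assert tuple(map(float, _match.groups())) == (-12.3, 4.5, 65.0, 0.12)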
+ upload_image(path) + started = time.time() + while True: + result = result_div.text + if (result != prev_result) and result: + return result + prev_result = result + if (time.time() - started) > timeout_seconds: + raise TimeoutError + + result = str(result_div.text) + number = r"(nan|-?\d*\.?\d*)" + pattern = re.compile( + f"Pitch: {number}° / Roll: {number}° / HFOV : {number}° / Distortion: {number}" + ) + + paths = sorted(args.images.iterdir()) + results = {} + for path in (pbar := tqdm(paths)): + pbar.set_description(path.name) + result = run_image(path, result) + match = pattern.match(result) + if match is None: + print("Error, cannot parse:", result, path) + continue + results[path.name] = tuple(map(float, match.groups())) + + args.results.write_text(json.dumps(results)) + driver.close() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("images", type=Path) + parser.add_argument("results", type=Path) + args = parser.parse_args() + run(args) diff --git a/siclib/eval/simple_pipeline.py b/siclib/eval/simple_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..7271e961bd93192120e41767b44ad3ebbbbda105 --- /dev/null +++ b/siclib/eval/simple_pipeline.py @@ -0,0 +1,408 @@ +import logging +import resource +from collections import defaultdict +from pathlib import Path +from pprint import pprint +from typing import Dict, List, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import torch +from omegaconf import OmegaConf +from tqdm import tqdm + +from siclib.datasets import get_dataset +from siclib.eval.eval_pipeline import EvalPipeline +from siclib.eval.io import get_eval_parser, load_model, parse_eval_args +from siclib.eval.utils import download_and_extract_benchmark, plot_scatter_grid +from siclib.geometry.base_camera import BaseCamera +from siclib.geometry.camera import Pinhole +from siclib.geometry.gravity import Gravity +from siclib.models.cache_loader import CacheLoader +from siclib.models.utils.metrics import ( + gravity_error, + latitude_error, + pitch_error, + roll_error, + up_error, + vfov_error, +) +from siclib.settings import EVAL_PATH +from siclib.utils.conversions import rad2deg +from siclib.utils.export_predictions import export_predictions +from siclib.utils.tensor import add_batch_dim +from siclib.utils.tools import AUCMetric, set_seed +from siclib.visualization import visualize_batch, viz2d + +# flake8: noqa +# mypy: ignore-errors + +logger = logging.getLogger(__name__) + +rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) +resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) + +torch.set_grad_enabled(False) + + +def calculate_pixel_projection_error( + camera_pred: BaseCamera, camera_gt: BaseCamera, N: int = 500, distortion_only: bool = True +) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculate the pixel projection error between two cameras. + + 1. Project a grid of points with the ground truth camera to the image plane. + 2. Project the same grid of points with the estimated camera to the image plane. + 3. Calculate the pixel distance between the ground truth and estimated points. + + Args: + camera_pred (Camera): Predicted camera. + camera_gt (Camera): Ground truth camera. + N (int, optional): Number of points in the grid. Defaults to 500. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Pixel distance and valid pixels. 
+ """ + H, W = camera_gt.size.unbind(-1) + H, W = H.int(), W.int() + + assert torch.allclose( + camera_gt.size, camera_pred.size + ), f"Cameras must have the same size: {camera_gt.size} != {camera_pred.size}" + + if distortion_only: + params = camera_gt._data.clone() + params[..., -2:] = camera_pred._data[..., -2:] + CameraModel = type(camera_gt) + camera_pred = CameraModel(params) + + x_gt, y_gt = torch.meshgrid( + torch.linspace(0, H - 1, N), torch.linspace(0, W - 1, N), indexing="xy" + ) + xy = torch.stack((x_gt, y_gt), dim=-1).reshape(-1, 2) + + camera_pin_gt = camera_gt.pinhole() + uv_pin, _ = camera_pin_gt.image2world(xy) + + # gt + xy_undist_gt, valid_dist_gt = camera_gt.world2image(uv_pin) + # pred + xy_undist, valid_dist = camera_pred.world2image(uv_pin) + + valid = valid_dist_gt & valid_dist + + dist = (xy_undist - xy_undist_gt) ** 2 + dist = (dist.sum(-1)).sqrt() + + return dist[valid_dist_gt], valid[valid_dist_gt] + + +def compute_camera_metrics( + camera_pred: BaseCamera, camera_gt: BaseCamera, thresholds: List[float] +) -> Dict[str, float]: + results = defaultdict(list) + results["vfov"].append(rad2deg(camera_pred.vfov).item()) + results["vfov_error"].append(vfov_error(camera_pred, camera_gt).item()) + + results["focal"].append(camera_pred.f[..., 1].item()) + focal_error = torch.abs(camera_pred.f[..., 1] - camera_gt.f[..., 1]) + results["focal_error"].append(focal_error.item()) + + rel_focal_error = torch.abs(camera_pred.f[..., 1] - camera_gt.f[..., 1]) / camera_gt.f[..., 1] + results["rel_focal_error"].append(rel_focal_error.item()) + + if hasattr(camera_pred, "k1"): + results["k1"].append(camera_pred.k1.item()) + k1_error = torch.abs(camera_pred.k1 - camera_gt.k1) + results["k1_error"].append(k1_error.item()) + + if thresholds is None: + return results + + err, valid = calculate_pixel_projection_error(camera_pred, camera_gt, distortion_only=False) + for th in thresholds: + results[f"pixel_projection_error@{th}"].append( + ((err[valid] < th).sum() / len(valid)).float().item() + ) + + err, valid = calculate_pixel_projection_error(camera_pred, camera_gt, distortion_only=True) + for th in thresholds: + results[f"pixel_distortion_error@{th}"].append( + ((err[valid] < th).sum() / len(valid)).float().item() + ) + return results + + +def compute_gravity_metrics(gravity_pred: Gravity, gravity_gt: Gravity) -> Dict[str, float]: + results = defaultdict(list) + results["roll"].append(rad2deg(gravity_pred.roll).item()) + results["pitch"].append(rad2deg(gravity_pred.pitch).item()) + + results["roll_error"].append(roll_error(gravity_pred, gravity_gt).item()) + results["pitch_error"].append(pitch_error(gravity_pred, gravity_gt).item()) + results["gravity_error"].append(gravity_error(gravity_pred[None], gravity_gt[None]).item()) + return results + + +class SimplePipeline(EvalPipeline): + default_conf = { + "data": {}, + "model": {}, + "eval": { + "thresholds": [1, 5, 10], + "pixel_thresholds": [0.5, 1, 3, 5], + "num_vis": 10, + "verbose": True, + }, + "url": None, # url to benchmark.zip + } + + export_keys = [ + "camera", + "gravity", + ] + + optional_export_keys = [ + "focal_uncertainty", + "vfov_uncertainty", + "roll_uncertainty", + "pitch_uncertainty", + "gravity_uncertainty", + "up_field", + "up_confidence", + "latitude_field", + "latitude_confidence", + ] + + def _init(self, conf): + self.verbose = conf.eval.verbose + self.num_vis = self.conf.eval.num_vis + + self.CameraModel = Pinhole + + if conf.url is not None: + ds_dir = Path(conf.data.dataset_dir) + 
download_and_extract_benchmark(ds_dir.name, conf.url, ds_dir.parent) + + @classmethod + def get_dataloader(cls, data_conf=None, batch_size=None): + """Returns a data loader with samples for each eval datapoint""" + data_conf = data_conf or cls.default_conf["data"] + + if batch_size is not None: + data_conf["test_batch_size"] = batch_size + + do_shuffle = data_conf["test_batch_size"] > 1 + dataset = get_dataset(data_conf["name"])(data_conf) + return dataset.get_data_loader("test", shuffle=do_shuffle) + + def get_predictions(self, experiment_dir, model=None, overwrite=False): + """Export a prediction file for each eval datapoint""" + # set_seed(0) + pred_file = experiment_dir / "predictions.h5" + if not pred_file.exists() or overwrite: + if model is None: + model = load_model(self.conf.model, self.conf.checkpoint) + export_predictions( + self.get_dataloader(self.conf.data), + model, + pred_file, + keys=self.export_keys, + optional_keys=self.optional_export_keys, + verbose=self.verbose, + ) + return pred_file + + def get_figures(self, results): + figures = {} + + if self.num_vis == 0: + return figures + + gl = ["up", "latitude"] + rpf = ["roll", "pitch", "vfov"] + + # check if rpf in results + if all(k in results for k in rpf): + x_keys = [f"{k}_gt" for k in rpf] + + # gt vs error + y_keys = [f"{k}_error" for k in rpf] + fig, _ = plot_scatter_grid(results, x_keys, y_keys, show_means=False) + figures |= {"rpf_gt_error": fig} + + # gt vs pred + y_keys = [f"{k}" for k in rpf] + fig, _ = plot_scatter_grid(results, x_keys, y_keys, diag=True, show_means=False) + figures |= {"rpf_gt_pred": fig} + + if all(f"{k}_error" in results for k in gl): + x_keys = [f"{k}_gt" for k in rpf] + y_keys = [f"{k}_error" for k in gl] + fig, _ = plot_scatter_grid(results, x_keys, y_keys, show_means=False) + figures |= {"gl_gt_error": fig} + + return figures + + def run_eval(self, loader, pred_file): + conf = self.conf.eval + results = defaultdict(list) + + save_to = Path(pred_file).parent / "figures" + if not save_to.exists() and self.num_vis > 0: + save_to.mkdir() + + cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval() + + if not self.verbose: + logger.info(f"Evaluating {pred_file}") + + for i, data in enumerate( + tqdm(loader, desc="Evaluating", total=len(loader), ncols=80, disable=not self.verbose) + ): + # NOTE: data is batched but pred is not + pred = cache_loader(data) + + results["names"].append(data["name"][0]) + + gt_cam = data["camera"][0] + gt_gravity = data["gravity"][0] + # add gt parameters + results["roll_gt"].append(rad2deg(gt_gravity.roll).item()) + results["pitch_gt"].append(rad2deg(gt_gravity.pitch).item()) + results["vfov_gt"].append(rad2deg(gt_cam.vfov).item()) + results["focal_gt"].append(gt_cam.f[1].item()) + + results["k1_gt"].append(gt_cam.k1.item()) + + if "camera" in pred: + # pred["camera"] is a tensor of the parameters + pred_cam = self.CameraModel(pred["camera"]) + + pred_camera = pred_cam[None].undo_scale_crop(data)[0] + gt_camera = gt_cam[None].undo_scale_crop(data)[0] + + camera_metrics = compute_camera_metrics( + pred_camera, gt_camera, conf.pixel_thresholds + ) + + for k, v in camera_metrics.items(): + results[k].extend(v) + + if "focal_uncertainty" in pred: + focal_uncertainty = pred["focal_uncertainty"] + results["focal_uncertainty"].append(focal_uncertainty.item()) + + if "vfov_uncertainty" in pred: + vfov_uncertainty = rad2deg(pred["vfov_uncertainty"]) + results["vfov_uncertainty"].append(vfov_uncertainty.item()) + + if "gravity" in pred: + # 
pred["gravity"] is a tensor of the parameters + pred_gravity = Gravity(pred["gravity"]) + + gravity_metrics = compute_gravity_metrics(pred_gravity, gt_gravity) + for k, v in gravity_metrics.items(): + results[k].extend(v) + + if "roll_uncertainty" in pred: + roll_uncertainty = rad2deg(pred["roll_uncertainty"]) + results["roll_uncertainty"].append(roll_uncertainty.item()) + + if "pitch_uncertainty" in pred: + pitch_uncertainty = rad2deg(pred["pitch_uncertainty"]) + results["pitch_uncertainty"].append(pitch_uncertainty.item()) + + if "gravity_uncertainty" in pred: + gravity_uncertainty = rad2deg(pred["gravity_uncertainty"]) + results["gravity_uncertainty"].append(gravity_uncertainty.item()) + + if "up_field" in pred: + up_err = up_error(pred["up_field"].unsqueeze(0), data["up_field"]) + results["up_error"].append(up_err.mean(axis=(1, 2)).item()) + results["up_med_error"].append(up_err.median().item()) + + if "up_confidence" in pred: + up_confidence = pred["up_confidence"].unsqueeze(0) + weighted_error = (up_err * up_confidence).sum(axis=(1, 2)) + weighted_error = weighted_error / up_confidence.sum(axis=(1, 2)) + results["up_weighted_error"].append(weighted_error.item()) + + if i < self.num_vis: + pred_batched = add_batch_dim(pred) + up_fig = visualize_batch.make_up_figure(pred=pred_batched, data=data) + up_fig = up_fig["up"] + plt.tight_layout() + viz2d.save_plot(save_to / f"up-{i}-{up_err.median().item():.3f}.jpg") + plt.close() + + if "latitude_field" in pred: + lat_err = latitude_error( + pred["latitude_field"].unsqueeze(0), data["latitude_field"] + ) + results["latitude_error"].append(lat_err.mean(axis=(1, 2)).item()) + results["latitude_med_error"].append(lat_err.median().item()) + + if "latitude_confidence" in pred: + lat_confidence = pred["latitude_confidence"].unsqueeze(0) + weighted_error = (lat_err * lat_confidence).sum(axis=(1, 2)) + weighted_error = weighted_error / lat_confidence.sum(axis=(1, 2)) + results["latitude_weighted_error"].append(weighted_error.item()) + + if i < self.num_vis: + pred_batched = add_batch_dim(pred) + lat_fig = visualize_batch.make_latitude_figure(pred=pred_batched, data=data) + lat_fig = lat_fig["latitude"] + plt.tight_layout() + viz2d.save_plot(save_to / f"latitude-{i}-{lat_err.median().item():.3f}.jpg") + plt.close() + + summaries = {} + for k, v in results.items(): + arr = np.array(v) + if not np.issubdtype(np.array(v).dtype, np.number): + continue + + if k.endswith("_error") or "recall" in k or "pixel" in k: + summaries[f"mean_{k}"] = round(np.nanmean(arr), 3) + summaries[f"median_{k}"] = round(np.nanmedian(arr), 3) + + if any(keyword in k for keyword in ["roll", "pitch", "vfov", "gravity"]): + if not conf.thresholds: + continue + + auc = AUCMetric( + elements=arr, thresholds=list(conf.thresholds), min_error=1 + ).compute() + for i, t in enumerate(conf.thresholds): + summaries[f"auc_{k}@{t}"] = round(auc[i], 3) + + return summaries, self.get_figures(results), results + + +if __name__ == "__main__": + dataset_name = Path(__file__).stem + parser = get_eval_parser() + args = parser.parse_intermixed_args() + + default_conf = OmegaConf.create(SimplePipeline.default_conf) + + # mingle paths + output_dir = Path(EVAL_PATH, dataset_name) + output_dir.mkdir(exist_ok=True, parents=True) + + name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf) + + experiment_dir = output_dir / name + experiment_dir.mkdir(exist_ok=True) + + pipeline = SimplePipeline(conf) + s, f, r = pipeline.run( + experiment_dir, overwrite=args.overwrite, 
overwrite_eval=args.overwrite_eval + ) + + pprint(s) + + if args.plot: + for name, fig in f.items(): + fig.canvas.manager.set_window_title(name) + plt.show() diff --git a/siclib/eval/stanford2d3d.py b/siclib/eval/stanford2d3d.py new file mode 100644 index 0000000000000000000000000000000000000000..5380b1a0d68118e9acbdea64a22f4d203962df91 --- /dev/null +++ b/siclib/eval/stanford2d3d.py @@ -0,0 +1,89 @@ +import resource +from pathlib import Path +from pprint import pprint + +import matplotlib.pyplot as plt +import torch +from omegaconf import OmegaConf + +from siclib.eval.io import get_eval_parser, parse_eval_args +from siclib.eval.simple_pipeline import SimplePipeline +from siclib.settings import EVAL_PATH + +# flake8: noqa +# mypy: ignore-errors + +rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) +resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) + +torch.set_grad_enabled(False) + + +class Stanford2D3D(SimplePipeline): + default_conf = { + "data": { + "name": "simple_dataset", + "dataset_dir": "data/stanford2d3d", + "test_img_dir": "${.dataset_dir}/images", + "test_csv": "${.dataset_dir}/images.csv", + "augmentations": {"name": "identity"}, + "preprocessing": {"resize": 320, "edge_divisible_by": 32}, + "test_batch_size": 1, + }, + "model": {}, + "eval": { + "thresholds": [1, 5, 10], + "pixel_thresholds": [0.5, 1, 3, 5], + "num_vis": 10, + "verbose": True, + }, + "url": "https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/stanford2d3d.zip", + } + + export_keys = [ + "camera", + "gravity", + ] + + optional_export_keys = [ + "focal_uncertainty", + "vfov_uncertainty", + "roll_uncertainty", + "pitch_uncertainty", + "gravity_uncertainty", + "up_field", + "up_confidence", + "latitude_field", + "latitude_confidence", + ] + + +if __name__ == "__main__": + dataset_name = Path(__file__).stem + parser = get_eval_parser() + args = parser.parse_intermixed_args() + + default_conf = OmegaConf.create(Stanford2D3D.default_conf) + + # mingle paths + output_dir = Path(EVAL_PATH, dataset_name) + output_dir.mkdir(exist_ok=True, parents=True) + + name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf) + + experiment_dir = output_dir / name + experiment_dir.mkdir(exist_ok=True) + + pipeline = Stanford2D3D(conf) + s, f, r = pipeline.run( + experiment_dir, + overwrite=args.overwrite, + overwrite_eval=args.overwrite_eval, + ) + + pprint(s) + + if args.plot: + for name, fig in f.items(): + fig.canvas.manager.set_window_title(name) + plt.show() diff --git a/siclib/eval/tartanair.py b/siclib/eval/tartanair.py new file mode 100644 index 0000000000000000000000000000000000000000..e12d78420c7581e2234a00cb0b558954ef545646 --- /dev/null +++ b/siclib/eval/tartanair.py @@ -0,0 +1,89 @@ +import resource +from pathlib import Path +from pprint import pprint + +import matplotlib.pyplot as plt +import torch +from omegaconf import OmegaConf + +from siclib.eval.io import get_eval_parser, parse_eval_args +from siclib.eval.simple_pipeline import SimplePipeline +from siclib.settings import EVAL_PATH + +# flake8: noqa +# mypy: ignore-errors + +rlimit = resource.getrlimit(resource.RLIMIT_NOFILE) +resource.setrlimit(resource.RLIMIT_NOFILE, (4096, rlimit[1])) + +torch.set_grad_enabled(False) + + +class Tartanair(SimplePipeline): + default_conf = { + "data": { + "name": "simple_dataset", + "dataset_dir": "data/tartanair", + "test_img_dir": "${.dataset_dir}/images", + "test_csv": "${.dataset_dir}/images.csv", + "augmentations": {"name": "identity"}, + "preprocessing": {"resize": 320, "edge_divisible_by": 32}, + 
"test_batch_size": 1, + }, + "model": {}, + "eval": { + "thresholds": [1, 5, 10], + "pixel_thresholds": [0.5, 1, 3, 5], + "num_vis": 10, + "verbose": True, + }, + "url": "https://cvg-data.inf.ethz.ch/GeoCalib_ECCV2024/tartanair.zip", + } + + export_keys = [ + "camera", + "gravity", + ] + + optional_export_keys = [ + "focal_uncertainty", + "vfov_uncertainty", + "roll_uncertainty", + "pitch_uncertainty", + "gravity_uncertainty", + "up_field", + "up_confidence", + "latitude_field", + "latitude_confidence", + ] + + +if __name__ == "__main__": + dataset_name = Path(__file__).stem + parser = get_eval_parser() + args = parser.parse_intermixed_args() + + default_conf = OmegaConf.create(Tartanair.default_conf) + + # mingle paths + output_dir = Path(EVAL_PATH, dataset_name) + output_dir.mkdir(exist_ok=True, parents=True) + + name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf) + + experiment_dir = output_dir / name + experiment_dir.mkdir(exist_ok=True) + + pipeline = Tartanair(conf) + s, f, r = pipeline.run( + experiment_dir, + overwrite=args.overwrite, + overwrite_eval=args.overwrite_eval, + ) + + pprint(s) + + if args.plot: + for name, fig in f.items(): + fig.canvas.manager.set_window_title(name) + plt.show() diff --git a/siclib/eval/utils.py b/siclib/eval/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9b0ba53848603370aefea2f0fa1b06a5cd2abdd0 --- /dev/null +++ b/siclib/eval/utils.py @@ -0,0 +1,116 @@ +import logging +import shutil +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import torch + +logger = logging.getLogger(__name__) + +# flake8: noqa +# mypy: ignore-errors + + +def download_and_extract_benchmark(name: str, url: Path, output: Path) -> None: + benchmark_dir = output / name + if not output.exists(): + output.mkdir(parents=True) + + if benchmark_dir.exists(): + logger.info(f"Benchmark {name} already exists at {benchmark_dir}, skipping download.") + return + + if name == "stanford2d3d": + # prompt user to sign data sharing and usage terms + txt = "\n" + "#" * 108 + "\n\n" + txt += "To download the Stanford2D3D dataset, you must agree to the terms of use:\n\n" + txt += ( + "https://docs.google.com/forms/d/e/" + + "1FAIpQLScFR0U8WEUtb7tgjOhhnl31OrkEs73-Y8bQwPeXgebqVKNMpQ/viewform?c=0&w=1\n\n" + ) + txt += "#" * 108 + "\n\n" + txt += "Did you fill out the data sharing and usage terms? [y/n] " + choice = input(txt) + if choice.lower() != "y": + raise ValueError( + "You must agree to the terms of use to download the Stanford2D3D dataset." 
+ ) + + zip_file = output / f"{name}.zip" + + if not zip_file.exists(): + logger.info(f"Downloading benchmark {name} to {zip_file} from {url}.") + torch.hub.download_url_to_file(url, zip_file) + + logger.info(f"Extracting benchmark {name} in {output}.") + shutil.unpack_archive(zip_file, output, format="zip") + zip_file.unlink() + + +def check_keys_recursive(d, pattern): + if isinstance(pattern, dict): + {check_keys_recursive(d[k], v) for k, v in pattern.items()} + else: + for k in pattern: + assert k in d.keys() + + +def plot_scatter_grid( + results, x_keys, y_keys, name=None, diag=False, ax=None, line_idx=0, show_means=True +): # sourcery skip: low-code-quality + if ax is None: + N, M = len(y_keys), len(x_keys) + fig, ax = plt.subplots(N, M, figsize=(M * 6, N * 5)) + + if N == 1: + ax = np.array(ax) + ax = ax.reshape(1, -1) + + if M == 1: + ax = np.array(ax) + ax = ax.reshape(-1, 1) + else: + fig = None + + for j, kx in enumerate(x_keys): + for i, ky in enumerate(y_keys): + ax[i, j].scatter( + results[kx], + results[ky], + s=1, + alpha=0.5, + label=name or None, + ) + + ax[i, j].set_xlabel(f"{' '.join(kx.split('_')).title()}") + ax[i, j].set_ylabel(f"{' '.join(ky.split('_')).title()}") + + low = min(ax[i, j].get_xlim()[0], ax[i, j].get_ylim()[0]) + high = max(ax[i, j].get_xlim()[1], ax[i, j].get_ylim()[1]) + if diag == "all" or (i == j and diag): + ax[i, j].plot([low, high], [low, high], ls="--", c="red", label="y=x") + + if name or diag == "all" or (i == j and diag): + ax[i, j].legend() + + if not show_means: + return fig, ax + + means = {"y": {}, "x": {}} + for kx in x_keys: + for ky in y_keys: + means["x"][kx] = np.mean(results[kx]) + means["y"][ky] = np.mean(results[ky]) + + for j, kx in enumerate(x_keys): + for i, ky in enumerate(y_keys): + xlim = np.min(results[kx]), np.max(results[kx]) + ylim = np.min(results[ky]), np.max(results[ky]) + means_x = [means["x"][kx]] + means_y = [means["y"][ky]] + color = plt.cm.tab10(line_idx) + ax[i, j].vlines(means_x, *ylim, colors=[color]) + ax[i, j].hlines(means_y, *xlim, colors=[color]) + + return fig, ax diff --git a/siclib/geometry/__init__.py b/siclib/geometry/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/siclib/geometry/base_camera.py b/siclib/geometry/base_camera.py new file mode 100644 index 0000000000000000000000000000000000000000..1e07615c274f57f3cec64308b5c58dcbb96fc125 --- /dev/null +++ b/siclib/geometry/base_camera.py @@ -0,0 +1,523 @@ +# Adapted from PixLoc, Paul-Edouard Sarlin, ETH Zurich +# https://github.com/cvg/pixloc +# Released under the Apache License 2.0 + +"""Convenience classes a for camera models. + +Based on PyTorch tensors: differentiable, batched, with GPU support. +""" + +from abc import abstractmethod +from typing import Dict, Optional, Tuple, Union + +import numpy as np +import torch +from torch.func import jacfwd, vmap +from torch.nn import functional as F + +from siclib.geometry.gravity import Gravity +from siclib.utils.conversions import deg2rad, focal2fov, fov2focal, rad2rotmat +from siclib.utils.tensor import TensorWrapper, autocast + +# mypy: ignore-errors + + +class BaseCamera(TensorWrapper): + """Camera tensor class.""" + + eps = 1e-3 + + @autocast + def __init__(self, data: torch.Tensor): + """Camera parameters with shape (..., {w, h, fx, fy, cx, cy, *dist}). 
+ + Tensor convention: (..., {w, h, fx, fy, cx, cy, pitch, roll, *dist}) where + - w, h: image size in pixels + - fx, fy: focal lengths in pixels + - cx, cy: principal points in normalized image coordinates + - dist: distortion parameters + + Args: + data (torch.Tensor): Camera parameters with shape (..., {6, 7, 8}). + """ + # w, h, fx, fy, cx, cy, dist + assert data.shape[-1] in {6, 7, 8}, data.shape + + pad = data.new_zeros(data.shape[:-1] + (8 - data.shape[-1],)) + data = torch.cat([data, pad], -1) if data.shape[-1] != 8 else data + super().__init__(data) + + @classmethod + def from_dict(cls, param_dict: Dict[str, torch.Tensor]) -> "BaseCamera": + """Create a Camera object from a dictionary of parameters. + + Args: + param_dict (Dict[str, torch.Tensor]): Dictionary of parameters. + + Returns: + Camera: Camera object. + """ + for key, value in param_dict.items(): + if not isinstance(value, torch.Tensor): + param_dict[key] = torch.tensor(value) + + h, w = param_dict["height"], param_dict["width"] + cx, cy = param_dict.get("cx", w / 2), param_dict.get("cy", h / 2) + + vfov = param_dict.get("vfov") + f = param_dict.get("f", fov2focal(vfov, h)) + + if "dist" in param_dict: + k1, k2 = param_dict["dist"][..., 0], param_dict["dist"][..., 1] + elif "k1_hat" in param_dict: + k1 = param_dict["k1_hat"] * (f / h) ** 2 + + k2 = param_dict.get("k2", torch.zeros_like(k1)) + else: + k1 = param_dict.get("k1", torch.zeros_like(f)) + k2 = param_dict.get("k2", torch.zeros_like(f)) + + fx, fy = f, f + if "scales" in param_dict: + scales = param_dict["scales"] + fx = fx * scales[..., 0] / scales[..., 1] + + params = torch.stack([w, h, fx, fy, cx, cy, k1, k2], dim=-1) + return cls(params) + + def pinhole(self): + """Return the pinhole camera model.""" + return self.__class__(self._data[..., :6]) + + @property + def size(self) -> torch.Tensor: + """Size (width height) of the images, with shape (..., 2).""" + return self._data[..., :2] + + @property + def f(self) -> torch.Tensor: + """Focal lengths (fx, fy) with shape (..., 2).""" + return self._data[..., 2:4] + + @property + def vfov(self) -> torch.Tensor: + """Vertical field of view in radians.""" + return focal2fov(self.f[..., 1], self.size[..., 1]) + + @property + def hfov(self) -> torch.Tensor: + """Horizontal field of view in radians.""" + return focal2fov(self.f[..., 0], self.size[..., 0]) + + @property + def c(self) -> torch.Tensor: + """Principal points (cx, cy) with shape (..., 2).""" + return self._data[..., 4:6] + + @property + def K(self) -> torch.Tensor: + """Returns the self intrinsic matrix with shape (..., 3, 3).""" + shape = self.shape + (3, 3) + K = self._data.new_zeros(shape) + K[..., 0, 0] = self.f[..., 0] + K[..., 1, 1] = self.f[..., 1] + K[..., 0, 2] = self.c[..., 0] + K[..., 1, 2] = self.c[..., 1] + K[..., 2, 2] = 1 + return K + + def update_focal(self, delta: torch.Tensor, as_log: bool = False): + """Update the self parameters after changing the focal length.""" + f = torch.exp(torch.log(self.f) + delta) if as_log else self.f + delta + + # clamp focal length to a reasonable range for stability during training + min_f = fov2focal(self.new_ones(self.shape[0]) * deg2rad(150), self.size[..., 1]) + max_f = fov2focal(self.new_ones(self.shape[0]) * deg2rad(5), self.size[..., 1]) + min_f = min_f.unsqueeze(-1).expand(-1, 2) + max_f = max_f.unsqueeze(-1).expand(-1, 2) + f = f.clamp(min=min_f, max=max_f) + + # make sure focal ration stays the same (avoid inplace operations) + fx = f[..., 1] * self.f[..., 0] / self.f[..., 1] + f = torch.stack([fx, 
f[..., 1]], -1) + + dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape) + return self.__class__(torch.cat([self.size, f, self.c, dist], -1)) + + def scale(self, scales: Union[float, int, Tuple[Union[float, int]]]): + """Update the self parameters after resizing an image.""" + scales = (scales, scales) if isinstance(scales, (int, float)) else scales + s = scales if isinstance(scales, torch.Tensor) else self.new_tensor(scales) + + dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape) + return self.__class__(torch.cat([self.size * s, self.f * s, self.c * s, dist], -1)) + + def crop(self, pad: Tuple[float]): + """Update the self parameters after cropping an image.""" + pad = pad if isinstance(pad, torch.Tensor) else self.new_tensor(pad) + size = self.size + pad.to(self.size) + c = self.c + pad.to(self.c) / 2 + + dist = self.dist if hasattr(self, "dist") else self.new_zeros(self.f.shape) + return self.__class__(torch.cat([size, self.f, c, dist], -1)) + + def undo_scale_crop(self, data: Dict[str, torch.Tensor]): + """Undo transforms done during scaling and cropping.""" + camera = self.crop(-data["crop_pad"]) if "crop_pad" in data else self + return camera.scale(1.0 / data["scales"]) + + @autocast + def in_image(self, p2d: torch.Tensor): + """Check if 2D points are within the image boundaries.""" + assert p2d.shape[-1] == 2 + size = self.size.unsqueeze(-2) + return torch.all((p2d >= 0) & (p2d <= (size - 1)), -1) + + @autocast + def project(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]: + """Project 3D points into the self plane and check for visibility.""" + z = p3d[..., -1] + valid = z > self.eps + z = z.clamp(min=self.eps) + p2d = p3d[..., :-1] / z.unsqueeze(-1) + return p2d, valid + + def J_project(self, p3d: torch.Tensor): + """Jacobian of the projection function.""" + x, y, z = p3d[..., 0], p3d[..., 1], p3d[..., 2] + zero = torch.zeros_like(z) + z = z.clamp(min=self.eps) + J = torch.stack([1 / z, zero, -x / z**2, zero, 1 / z, -y / z**2], dim=-1) + J = J.reshape(p3d.shape[:-1] + (2, 3)) + return J # N x 2 x 3 + + @abstractmethod + def distort(self, pts: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]: + """Distort normalized 2D coordinates and check for validity of the distortion model.""" + raise NotImplementedError("distort() must be implemented.") + + def J_distort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor: + """Jacobian of the distortion function.""" + if wrt == "scale2pts": # (..., 2) + J = [ + vmap(jacfwd(lambda x: self[idx].distort(x, return_scale=True)[0]))(p2d[idx])[None] + for idx in range(p2d.shape[0]) + ] + + return torch.cat(J, dim=0).squeeze(-3, -2) + + elif wrt == "scale2dist": # (..., 1) + J = [] + for idx in range(p2d.shape[0]): # loop to batch pts dimension + + def func(x): + params = torch.cat([self._data[idx, :6], x[None]], -1) + return self.__class__(params).distort(p2d[idx], return_scale=True)[0] + + J.append(vmap(jacfwd(func))(self[idx].dist)) + + return torch.cat(J, dim=0) + + else: + raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}") + + @abstractmethod + def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]: + """Undistort normalized 2D coordinates and check for validity of the distortion model.""" + raise NotImplementedError("undistort() must be implemented.") + + def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor: + """Jacobian of the undistortion function.""" + if wrt == "pts": # (..., 2, 2) + J = [ + vmap(jacfwd(lambda x: 
self[idx].undistort(x)[0]))(p2d[idx])[None] + for idx in range(p2d.shape[0]) + ] + + return torch.cat(J, dim=0).squeeze(-3) + + elif wrt == "dist": # (..., 1) + J = [] + for batch_idx in range(p2d.shape[0]): # loop to batch pts dimension + + def func(x): + params = torch.cat([self._data[batch_idx, :6], x[None]], -1) + return self.__class__(params).undistort(p2d[batch_idx])[0] + + J.append(vmap(jacfwd(func))(self[batch_idx].dist)) + + return torch.cat(J, dim=0) + else: + raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}") + + @autocast + def up_projection_offset(self, p2d: torch.Tensor) -> torch.Tensor: + """Compute the offset for the up-projection.""" + return self.J_distort(p2d, wrt="scale2pts") # (B, N, 2) + + def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor: + """Jacobian of the distortion offset for up-projection.""" + if wrt == "uv": # (B, N, 2, 2) + J = [ + vmap(jacfwd(lambda x: self[idx].up_projection_offset(x)[0, 0]))(p2d[idx])[None] + for idx in range(p2d.shape[0]) + ] + + return torch.cat(J, dim=0) + + elif wrt == "dist": # (B, N, 2) + J = [] + for batch_idx in range(p2d.shape[0]): # loop to batch pts dimension + + def func(x): + params = torch.cat([self._data[batch_idx, :6], x[None]], -1)[None] + return self.__class__(params).up_projection_offset(p2d[batch_idx][None]) + + J.append(vmap(jacfwd(func))(self[batch_idx].dist)) + + return torch.cat(J, dim=0).squeeze(1) + else: + raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}") + + @autocast + def denormalize(self, p2d: torch.Tensor) -> torch.Tensor: + """Convert normalized 2D coordinates into pixel coordinates.""" + return p2d * self.f.unsqueeze(-2) + self.c.unsqueeze(-2) + + def J_denormalize(self): + """Jacobian of the denormalization function.""" + return torch.diag_embed(self.f) # ..., 2 x 2 + + @autocast + def normalize(self, p2d: torch.Tensor) -> torch.Tensor: + """Convert pixel coordinates into normalized 2D coordinates.""" + return (p2d - self.c.unsqueeze(-2)) / (self.f.unsqueeze(-2)) + + def J_normalize(self, p2d: torch.Tensor, wrt: str = "f"): + """Jacobian of the normalization function.""" + # ... x N x 2 x 2 + if wrt == "f": + J_f = -(p2d - self.c.unsqueeze(-2)) / ((self.f.unsqueeze(-2)) ** 2) + return torch.diag_embed(J_f) + elif wrt == "pts": + J_pts = 1 / self.f + return torch.diag_embed(J_pts) + else: + raise NotImplementedError(f"Jacobian not implemented for wrt={wrt}") + + def pixel_coordinates(self) -> torch.Tensor: + """Pixel coordinates in self frame. + + Returns: + torch.Tensor: Pixel coordinates as a tensor of shape (B, h * w, 2). + """ + w, h = self.size[0].unbind(-1) + h, w = h.round().to(int), w.round().to(int) + + # create grid + x = torch.arange(0, w, dtype=self.dtype, device=self.device) + y = torch.arange(0, h, dtype=self.dtype, device=self.device) + x, y = torch.meshgrid(x, y, indexing="xy") + xy = torch.stack((x, y), dim=-1).reshape(-1, 2) # shape (h * w, 2) + + # add batch dimension (normalize() would broadcast but we make it explicit) + B = self.shape[0] + xy = xy.unsqueeze(0).expand(B, -1, -1) # if B > 0 else xy + + return xy.to(self.device).to(self.dtype) + + def normalized_image_coordinates(self) -> torch.Tensor: + """Normalized image coordinates in self frame. + + Returns: + torch.Tensor: Normalized image coordinates as a tensor of shape (B, h * w, 3). 
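+            These are homogeneous coordinates (u, v, 1) in the normalized, undistorted image plane, obtained by unprojecting every pixel with image2world.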
+ """ + xy = self.pixel_coordinates() + uv1, _ = self.image2world(xy) + + B = self.shape[0] + uv1 = uv1.reshape(B, -1, 3) + return uv1.to(self.device).to(self.dtype) + + @autocast + def pixel_bearing_many(self, p3d: torch.Tensor) -> torch.Tensor: + """Get the bearing vectors of pixel coordinates. + + Args: + p2d (torch.Tensor): Pixel coordinates as a tensor of shape (..., 3). + + Returns: + torch.Tensor: Bearing vectors as a tensor of shape (..., 3). + """ + return F.normalize(p3d, dim=-1) + + @autocast + def world2image(self, p3d: torch.Tensor) -> Tuple[torch.Tensor]: + """Transform 3D points into 2D pixel coordinates.""" + p2d, visible = self.project(p3d) + p2d, mask = self.distort(p2d) + p2d = self.denormalize(p2d) + valid = visible & mask & self.in_image(p2d) + return p2d, valid + + @autocast + def J_world2image(self, p3d: torch.Tensor): + """Jacobian of the world2image function.""" + p2d_proj, valid = self.project(p3d) + + J_dnorm = self.J_denormalize() + J_dist = self.J_distort(p2d_proj) + J_proj = self.J_project(p3d) + + J = torch.einsum("...ij,...jk,...kl->...il", J_dnorm, J_dist, J_proj) + return J, valid + + @autocast + def image2world(self, p2d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Transform point in the image plane to 3D world coordinates.""" + p2d = self.normalize(p2d) + p2d, valid = self.undistort(p2d) + ones = p2d.new_ones(p2d.shape[:-1] + (1,)) + p3d = torch.cat([p2d, ones], -1) + return p3d, valid + + @autocast + def J_image2world(self, p2d: torch.Tensor, wrt: str = "f") -> Tuple[torch.Tensor, torch.Tensor]: + """Jacobian of the image2world function.""" + if wrt == "dist": + p2d_norm = self.normalize(p2d) + return self.J_undistort(p2d_norm, wrt) + elif wrt == "f": + J_norm2f = self.J_normalize(p2d, wrt) + p2d_norm = self.normalize(p2d) + J_dist2norm = self.J_undistort(p2d_norm, "pts") + + return torch.einsum("...ij,...jk->...ik", J_dist2norm, J_norm2f) + else: + raise ValueError(f"Unknown wrt: {wrt}") + + @autocast + def undistort_image(self, img: torch.Tensor) -> torch.Tensor: + """Undistort an image using the distortion model.""" + assert self.shape[0] == 1, "Batch size must be 1." + W, H = self.size.unbind(-1) + H, W = H.int().item(), W.int().item() + + x, y = torch.arange(0, W), torch.arange(0, H) + x, y = torch.meshgrid(x, y, indexing="xy") + coords = torch.stack((x, y), dim=-1).reshape(-1, 2) + + p3d, _ = self.pinhole().image2world(coords.to(self.device).to(self.dtype)) + p2d, _ = self.world2image(p3d) + + mapx, mapy = p2d[..., 0].reshape((1, H, W)), p2d[..., 1].reshape((1, H, W)) + grid = torch.stack((mapx, mapy), dim=-1) + grid = 2.0 * grid / torch.tensor([W - 1, H - 1]).to(grid) - 1 + return F.grid_sample(img, grid, align_corners=True) + + def get_img_from_pano( + self, + pano_img: torch.Tensor, + gravity: Gravity, + yaws: torch.Tensor = 0.0, + resize_factor: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Render an image from a panorama. + + Args: + pano_img (torch.Tensor): Panorama image of shape (3, H, W) in [0, 1]. + gravity (Gravity): Gravity direction of the camera. + yaws (torch.Tensor | list, optional): Yaw angle in radians. Defaults to 0.0. + resize_factor (torch.Tensor, optional): Resize the panorama to be a multiple of the + field of view. Defaults to 1. + + Returns: + torch.Tensor: Image rendered from the panorama. + """ + B = self.shape[0] + if B > 0: + assert self.size[..., 0].unique().shape[0] == 1, "All images must have the same width." 
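+            # all cameras in the batch must share one image size so a single pixel grid can be reused for every yaw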
+ assert self.size[..., 1].unique().shape[0] == 1, "All images must have the same height." + + w, h = self.size[0].unbind(-1) + h, w = h.round().to(int), w.round().to(int) + + if isinstance(yaws, (int, float)): + yaws = [yaws] + if isinstance(resize_factor, (int, float)): + resize_factor = [resize_factor] + + yaws = ( + yaws.to(self.dtype).to(self.device) + if isinstance(yaws, torch.Tensor) + else self.new_tensor(yaws) + ) + + if isinstance(resize_factor, torch.Tensor): + resize_factor = resize_factor.to(self.dtype).to(self.device) + elif resize_factor is not None: + resize_factor = self.new_tensor(resize_factor) + + assert isinstance(pano_img, torch.Tensor), "Panorama image must be a torch.Tensor." + pano_img = pano_img if pano_img.dim() == 4 else pano_img.unsqueeze(0) # B x H x W x 3 + + pano_imgs = [] + for i, yaw in enumerate(yaws): + if resize_factor is not None: + # resize the panorama such that the fov of the panorama has the same height as the + # image + vfov = self.vfov[i] if B != 0 else self.vfov + scale = np.pi / float(vfov) * float(h) / pano_img.shape[0] * resize_factor[i] + pano_shape = (int(pano_img.shape[0] * scale), int(pano_img.shape[1] * scale)) + + # pano_img = pano_img.permute(2, 0, 1).unsqueeze(0) + mode = "bicubic" if scale >= 1 else "area" + resized_pano = F.interpolate(pano_img, size=pano_shape, mode=mode) + else: + # make sure to copy: resized_pano = pano_img + resized_pano = pano_img + pano_shape = pano_img.shape[-2:][::-1] + + pano_imgs.append((resized_pano, pano_shape)) + + xy = self.pixel_coordinates() + uv1, valid = self.image2world(xy) + bearings = self.pixel_bearing_many(uv1) + + # rotate bearings + R_yaw = rad2rotmat(self.new_zeros(yaw.shape), self.new_zeros(yaw.shape), yaws) + rotated_bearings = bearings @ gravity.R @ R_yaw + + # spherical coordinates + lon = torch.atan2(rotated_bearings[..., 0], rotated_bearings[..., 2]) + lat = torch.atan2( + rotated_bearings[..., 1], torch.norm(rotated_bearings[..., [0, 2]], dim=-1) + ) + + images = [] + for idx, (resized_pano, pano_shape) in enumerate(pano_imgs): + min_lon, max_lon = -torch.pi, torch.pi + min_lat, max_lat = -torch.pi / 2.0, torch.pi / 2.0 + min_x, max_x = 0, pano_shape[0] - 1.0 + min_y, max_y = 0, pano_shape[1] - 1.0 + + # map Spherical Coordinates to Panoramic Coordinates + nx = (lon[idx] - min_lon) / (max_lon - min_lon) * (max_x - min_x) + min_x + ny = (lat[idx] - min_lat) / (max_lat - min_lat) * (max_y - min_y) + min_y + + # reshape and cast to numpy for remap + mapx = nx.reshape((1, h, w)) + mapy = ny.reshape((1, h, w)) + + grid = torch.stack((mapx, mapy), dim=-1) # Add batch dimension + # Normalize to [-1, 1] + grid = 2.0 * grid / torch.tensor([pano_shape[-2] - 1, pano_shape[-1] - 1]).to(grid) - 1 + # Apply grid sample + image = F.grid_sample(resized_pano, grid, align_corners=True) + images.append(image) + + return torch.concatenate(images, 0) if B > 0 else images[0] + + def __repr__(self): + """Print the Camera object.""" + return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}" diff --git a/siclib/geometry/camera.py b/siclib/geometry/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..02e1139357036be37e3d15371676a97e67d4fdc9 --- /dev/null +++ b/siclib/geometry/camera.py @@ -0,0 +1,280 @@ +"""Implementation of the pinhole, simple radial, and simple divisional camera models.""" + +from typing import Tuple + +import torch + +from siclib.geometry.base_camera import BaseCamera +from siclib.utils.tensor import autocast + +# flake8: noqa: E741 + +# mypy: 
ignore-errors + + +class Pinhole(BaseCamera): + """Implementation of the pinhole camera model.""" + + def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]: + """Distort normalized 2D coordinates.""" + if return_scale: + return p2d.new_ones(p2d.shape[:-1] + (1,)) + + return p2d, p2d.new_ones((p2d.shape[0], 1)).bool() + + def J_distort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor: + """Jacobian of the distortion function.""" + if wrt == "pts": + return torch.eye(2, device=p2d.device, dtype=p2d.dtype).expand(p2d.shape[:-1] + (2, 2)) + else: + raise ValueError(f"Unknown wrt: {wrt}") + + def undistort(self, pts: torch.Tensor) -> Tuple[torch.Tensor]: + """Undistort normalized 2D coordinates.""" + return pts, pts.new_ones((pts.shape[0], 1)).bool() + + def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor: + """Jacobian of the undistortion function.""" + if wrt == "pts": + return torch.eye(2, device=p2d.device, dtype=p2d.dtype).expand(p2d.shape[:-1] + (2, 2)) + else: + raise ValueError(f"Unknown wrt: {wrt}") + + +class SimpleRadial(BaseCamera): + """Implementation of the simple radial camera model.""" + + @property + def dist(self) -> torch.Tensor: + """Distortion parameters, with shape (..., 1).""" + return self._data[..., 6:] + + @property + def k1(self) -> torch.Tensor: + """Distortion parameters, with shape (...).""" + return self._data[..., 6] + + @property + def k1_hat(self) -> torch.Tensor: + """Distortion parameters, with shape (...).""" + return self.k1 / (self.f[..., 1] / self.size[..., 1]) ** 2 + + def update_dist(self, delta: torch.Tensor, dist_range: Tuple[float, float] = (-0.7, 0.7)): + """Update the self parameters after changing the k1 distortion parameter.""" + delta_dist = self.new_ones(self.dist.shape) * delta + dist = (self.dist + delta_dist).clamp(*dist_range) + data = torch.cat([self.size, self.f, self.c, dist], -1) + return self.__class__(data) + + @autocast + def check_valid(self, p2d: torch.Tensor) -> torch.Tensor: + """Check if the distorted points are valid.""" + return p2d.new_ones(p2d.shape[:-1]).bool() + + def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]: + """Distort normalized 2D coordinates and check for validity of the distortion model.""" + r2 = torch.sum(p2d**2, -1, keepdim=True) + radial = 1 + self.k1[..., None, None] * r2 + + if return_scale: + return radial, None + + return p2d * radial, self.check_valid(p2d) + + def J_distort(self, p2d: torch.Tensor, wrt: str = "pts"): + """Jacobian of the distortion function.""" + k1 = self.k1[..., None, None] + r2 = torch.sum(p2d**2, -1, keepdim=True) + if wrt == "pts": # (..., 2, 2) + radial = 1 + k1 * r2 + ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2) + return (2 * k1 * ppT) + torch.diag_embed(radial.expand(radial.shape[:-1] + (2,))) + elif wrt == "dist": # (..., 2) + return r2 * p2d + elif wrt == "scale2dist": # (..., 1) + return r2 + elif wrt == "scale2pts": # (..., 2) + return 2 * k1 * p2d + else: + return super().J_distort(p2d, wrt) + + @autocast + def undistort(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]: + """Undistort normalized 2D coordinates and check for validity of the distortion model.""" + b1 = -self.k1[..., None, None] + r2 = torch.sum(p2d**2, -1, keepdim=True) + radial = 1 + b1 * r2 + return p2d * radial, self.check_valid(p2d) + + @autocast + def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor: + """Jacobian of the undistortion function.""" + b1 = 
-self.k1[..., None, None] + r2 = torch.sum(p2d**2, -1, keepdim=True) + if wrt == "dist": + return -r2 * p2d + elif wrt == "pts": + radial = 1 + b1 * r2 + ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2) + return (2 * b1[..., None] * ppT) + torch.diag_embed( + radial.expand(radial.shape[:-1] + (2,)) + ) + else: + return super().J_undistort(p2d, wrt) + + def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor: + """Jacobian of the up-projection offset.""" + if wrt == "uv": # (..., 2, 2) + return torch.diag_embed((2 * self.k1[..., None, None]).expand(p2d.shape[:-1] + (2,))) + elif wrt == "dist": + return 2 * p2d # (..., 2) + else: + return super().J_up_projection_offset(p2d, wrt) + + +class SimpleDivisional(BaseCamera): + """Implementation of the simple divisional camera model.""" + + @property + def dist(self) -> torch.Tensor: + """Distortion parameters, with shape (..., 1).""" + return self._data[..., 6:] + + @property + def k1(self) -> torch.Tensor: + """Distortion parameters, with shape (...).""" + return self._data[..., 6] + + def update_dist(self, delta: torch.Tensor, dist_range: Tuple[float, float] = (-3.0, 3.0)): + """Update the self parameters after changing the k1 distortion parameter.""" + delta_dist = self.new_ones(self.dist.shape) * delta + dist = (self.dist + delta_dist).clamp(*dist_range) + data = torch.cat([self.size, self.f, self.c, dist], -1) + return self.__class__(data) + + @autocast + def check_valid(self, p2d: torch.Tensor) -> torch.Tensor: + """Check if the distorted points are valid.""" + return p2d.new_ones(p2d.shape[:-1]).bool() + + def distort(self, p2d: torch.Tensor, return_scale: bool = False) -> Tuple[torch.Tensor]: + """Distort normalized 2D coordinates and check for validity of the distortion model.""" + r2 = torch.sum(p2d**2, -1, keepdim=True) + radial = 1 - torch.sqrt((1 - 4 * self.k1[..., None, None] * r2).clamp(min=0)) + denom = 2 * self.k1[..., None, None] * r2 + + ones = radial.new_ones(radial.shape) + radial = torch.where(denom == 0, ones, radial / denom.masked_fill(denom == 0, 1e6)) + + if return_scale: + return radial, None + + return p2d * radial, self.check_valid(p2d) + + def J_distort(self, p2d: torch.Tensor, wrt: str = "pts"): + """Jacobian of the distortion function.""" + r2 = torch.sum(p2d**2, -1, keepdim=True) + t0 = torch.sqrt((1 - 4 * self.k1[..., None, None] * r2).clamp(min=1e-6)) + if wrt == "scale2pts": # (B, N, 2) + d1 = t0 * 2 * r2 + d2 = self.k1[..., None, None] * r2**2 + denom = d1 * d2 + return p2d * (4 * d2 - (1 - t0) * d1) / denom.masked_fill(denom == 0, 1e6) + + elif wrt == "scale2dist": + d1 = 2 * self.k1[..., None, None] * t0 + d2 = 2 * r2 * self.k1[..., None, None] ** 2 + denom = d1 * d2 + return (2 * d2 - (1 - t0) * d1) / denom.masked_fill(denom == 0, 1e6) + + else: + return super().J_distort(p2d, wrt) + + @autocast + def undistort(self, p2d: torch.Tensor) -> Tuple[torch.Tensor]: + """Undistort normalized 2D coordinates and check for validity of the distortion model.""" + r2 = torch.sum(p2d**2, -1, keepdim=True) + denom = 1 + self.k1[..., None, None] * r2 + radial = 1 / denom.masked_fill(denom == 0, 1e6) + return p2d * radial, self.check_valid(p2d) + + def J_undistort(self, p2d: torch.Tensor, wrt: str = "pts") -> torch.Tensor: + """Jacobian of the undistortion function.""" + # return super().J_undistort(p2d, wrt) + r2 = torch.sum(p2d**2, -1, keepdim=True) + k1 = self.k1[..., None, None] + if wrt == "dist": + denom = (1 + k1 * r2) ** 2 + return -r2 / denom.masked_fill(denom == 0, 1e6) * 
p2d + elif wrt == "pts": + t0 = 1 + k1 * r2 + t0 = t0.masked_fill(t0 == 0, 1e6) + ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2) + J = torch.diag_embed((1 / t0).expand(p2d.shape[:-1] + (2,))) + return J - 2 * k1[..., None] * ppT / t0[..., None] ** 2 # (..., N, 2, 2) + + else: + return super().J_undistort(p2d, wrt) + + def J_up_projection_offset(self, p2d: torch.Tensor, wrt: str = "uv") -> torch.Tensor: + """Jacobian of the up-projection offset. + + func(uv, dist) = 4 / (2 * norm2(uv)^2 * (1-4*k1*norm2(uv)^2)^0.5) * uv + - (1-(1-4*k1*norm2(uv)^2)^0.5) / (k1 * norm2(uv)^4) * uv + """ + k1 = self.k1[..., None, None] + r2 = torch.sum(p2d**2, -1, keepdim=True) + t0 = (1 - 4 * k1 * r2).clamp(min=1e-6) + t1 = torch.sqrt(t0) + if wrt == "dist": + denom = 4 * t0 ** (3 / 2) + denom = denom.masked_fill(denom == 0, 1e6) + J = 16 / denom + + denom = r2 * t1 * k1 + denom = denom.masked_fill(denom == 0, 1e6) + J = J - 2 / denom + + denom = (r2 * k1) ** 2 + denom = denom.masked_fill(denom == 0, 1e6) + J = J + (1 - t1) / denom + + return J * p2d + elif wrt == "uv": + # ! unstable (gradient checker might fail), rewrite to use single division (by denom) + ppT = torch.einsum("...i,...j->...ij", p2d, p2d) # (..., 2, 2) + + denom = 2 * r2 * t1 + denom = denom.masked_fill(denom == 0, 1e6) + J = torch.diag_embed((4 / denom).expand(p2d.shape[:-1] + (2,))) + + denom = 4 * t1 * r2**2 + denom = denom.masked_fill(denom == 0, 1e6) + J = J - 16 / denom[..., None] * ppT + + denom = 4 * r2 * t0 ** (3 / 2) + denom = denom.masked_fill(denom == 0, 1e6) + J = J + (32 * k1[..., None]) / denom[..., None] * ppT + + denom = r2**2 * t1 + denom = denom.masked_fill(denom == 0, 1e6) + J = J - 4 / denom[..., None] * ppT + + denom = k1 * r2**3 + denom = denom.masked_fill(denom == 0, 1e6) + J = J + (4 * (1 - t1) / denom)[..., None] * ppT + + denom = k1 * r2**2 + denom = denom.masked_fill(denom == 0, 1e6) + J = J - torch.diag_embed(((1 - t1) / denom).expand(p2d.shape[:-1] + (2,))) + + return J + else: + return super().J_up_projection_offset(p2d, wrt) + + +camera_models = { + "pinhole": Pinhole, + "simple_radial": SimpleRadial, + "simple_divisional": SimpleDivisional, +} diff --git a/siclib/geometry/gradient_checker.py b/siclib/geometry/gradient_checker.py new file mode 100644 index 0000000000000000000000000000000000000000..6e55fbfa9a9d5a4a5bc33724902f0a79a5906567 --- /dev/null +++ b/siclib/geometry/gradient_checker.py @@ -0,0 +1,571 @@ +"""Test gradient implementations.""" + +import logging +import unittest + +import torch +from torch.func import jacfwd, vmap + +from siclib.geometry.camera import camera_models +from siclib.geometry.gravity import Gravity +from siclib.geometry.jacobians import J_up_projection +from siclib.geometry.manifolds import SphericalManifold +from siclib.geometry.perspective_fields import J_perspective_field, get_perspective_field +from siclib.models.optimization.lm_optimizer import LMOptimizer +from siclib.utils.conversions import deg2rad, fov2focal + +# flake8: noqa E731 +# mypy: ignore-errors + +H, W = 320, 320 + +K1 = -0.1 + +# CAMERA_MODEL = "pinhole" +CAMERA_MODEL = "simple_radial" +# CAMERA_MODEL = "simple_divisional" + +Camera = camera_models[CAMERA_MODEL] + +# detect anomaly +torch.autograd.set_detect_anomaly(True) + + +logger = logging.getLogger("geocalib.models.base_model") +logger.setLevel("ERROR") + + +def get_toy_rpf(roll=None, pitch=None, vfov=None) -> torch.Tensor: + """Return a random roll, pitch, focal length if not specified.""" + if roll is None: + roll = deg2rad((torch.rand(1) - 
0.5) * 90) # -45 ~ 45 + elif not isinstance(roll, torch.Tensor): + roll = torch.tensor(deg2rad(roll)).unsqueeze(0) + + if pitch is None: + pitch = deg2rad((torch.rand(1) - 0.5) * 90) # -45 ~ 45 + elif not isinstance(pitch, torch.Tensor): + pitch = torch.tensor(deg2rad(pitch)).unsqueeze(0) + + if vfov is None: + vfov = deg2rad(5 + torch.rand(1) * 75) # 5 ~ 80 + elif not isinstance(vfov, torch.Tensor): + vfov = torch.tensor(deg2rad(vfov)).unsqueeze(0) + + return torch.stack([roll, pitch, fov2focal(vfov, H)], dim=-1).float() + + +class TestJacobianFunctions(unittest.TestCase): + """Test the jacobian functions.""" + + eps = 5e-3 + + def validate(self, J: torch.Tensor, J_auto: torch.Tensor): + """Check if the jacobians are close and finite.""" + self.assertTrue(torch.all(torch.isfinite(J)), "found nan in numerical") + self.assertTrue(torch.all(torch.isfinite(J_auto)), "found nan in auto") + + text_j = f" > {self.eps}\nJ:\n{J[0, 0].numpy()}\nJ_auto:\n{J_auto[0, 0].numpy()}" + max_diff = torch.max(torch.abs(J - J_auto)) + text = f"Overall - max diff is {max_diff:.4f}" + text_j + self.assertTrue(torch.allclose(J, J_auto, atol=self.eps), text) + + def test_spherical_plus(self): + """Test the spherical plus operator.""" + rpf = get_toy_rpf() + gravity = Gravity.from_rp(rpf[..., 0], rpf[..., 1]) + J = gravity.J_update(spherical=True) + + # auto jacobian + delta = gravity.vec3d.new_zeros(gravity.vec3d.shape)[..., :-1] + + def spherical_plus(delta: torch.Tensor) -> torch.Tensor: + """Plus operator.""" + return SphericalManifold.plus(gravity.vec3d, delta) + + J_auto = vmap(jacfwd(spherical_plus))(delta).squeeze(0) + + self.validate(J, J_auto) + + def test_up_projection_uv(self): + """Test the up projection jacobians.""" + rpf = get_toy_rpf() + + r, p, f = rpf.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": [K1]}) + gravity = Gravity.from_rp(r, p) + uv = camera.normalize(camera.pixel_coordinates()) + + J = J_up_projection(uv, gravity.vec3d, "uv") + + # auto jacobian + def projection_uv(uv: torch.Tensor) -> torch.Tensor: + """Projection.""" + abc = gravity.vec3d + projected_up2d = abc[..., None, :2] - abc[..., 2, None, None] * uv + return projected_up2d[0, 0] + + J_auto = vmap(jacfwd(projection_uv))(uv[0])[None] + + self.validate(J, J_auto) + + def test_up_projection_abc(self): + """Test the up projection jacobians.""" + rpf = get_toy_rpf() + + r, p, f = rpf.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": [K1]}) + gravity = Gravity.from_rp(r, p) + uv = camera.normalize(camera.pixel_coordinates()) + J = J_up_projection(uv, gravity.vec3d, "abc") + + # auto jacobian + def projection_abc(abc: torch.Tensor) -> torch.Tensor: + """Projection.""" + return abc[..., None, :2] - abc[..., 2, None, None] * uv + + J_auto = vmap(jacfwd(projection_abc))(gravity.vec3d)[0] + + self.validate(J, J_auto) + + def test_undistort_pts(self): + """Test the undistortion jacobians.""" + if CAMERA_MODEL == "pinhole": + return + + rpf = get_toy_rpf() + _, _, f = rpf.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": [K1]}) + uv = camera.normalize(camera.pixel_coordinates()) + J = camera.J_undistort(uv, "pts") + + # auto jacobian + def func_pts(pts): + return camera.undistort(pts)[0][0] + + J_auto = vmap(jacfwd(func_pts))(uv[0])[None].squeeze(-3) + + self.validate(J, J_auto) + + def test_undistort_k1(self): + """Test the undistortion jacobians.""" + if CAMERA_MODEL == "pinhole": + return + + rpf = get_toy_rpf() + _, _, f = 
rpf.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": [K1]}) + uv = camera.normalize(camera.pixel_coordinates()) + J = camera.J_undistort(uv, "dist") + + # auto jacobian + def func_k1(k1): + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1}) + return camera.undistort(uv)[0][0] + + J_auto = vmap(jacfwd(func_k1))(camera.dist[..., :1]).squeeze(-1) + + self.validate(J, J_auto) + + def test_up_projection_offset(self): + """Test the up projection offset jacobians.""" + if CAMERA_MODEL == "pinhole": + return + + rpf = get_toy_rpf() + # J = up_projection_offset(rpf) + _, _, f = rpf.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": [K1]}) + uv = camera.normalize(camera.pixel_coordinates()) + J = camera.up_projection_offset(uv) + + # auto jacobian + def projection_uv(uv: torch.Tensor) -> torch.Tensor: + """Projection.""" + s, _ = camera.distort(uv, return_scale=True) + return s[0, 0, 0] + + J_auto = vmap(jacfwd(projection_uv))(uv[0])[None].squeeze(-2) + + self.validate(J, J_auto) + + def test_J_up_projection_offset_uv(self): + """Test the up projection offset jacobians.""" + if CAMERA_MODEL == "pinhole": + return + + rpf = get_toy_rpf() + _, _, f = rpf.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": [K1]}) + uv = camera.normalize(camera.pixel_coordinates()) + J = camera.J_up_projection_offset(uv, "uv") + + # auto jacobian + def projection_uv(uv: torch.Tensor) -> torch.Tensor: + """Projection.""" + return camera.up_projection_offset(uv)[0, 0] + + J_auto = vmap(jacfwd(projection_uv))(uv[0])[None] + + # print(J.shape, J_auto.shape) + + self.validate(J, J_auto) + + +class TestEuclidean(unittest.TestCase): + """Test the Euclidean manifold jacobians.""" + + eps = 5e-3 + + def validate(self, J: torch.Tensor, J_auto: torch.Tensor): + """Check if the jacobians are close and finite.""" + self.assertTrue(torch.all(torch.isfinite(J)), "found nan in numerical") + self.assertTrue(torch.all(torch.isfinite(J_auto)), "found nan in auto") + + # print(f"analytical:\n{J[0, 0, 0].numpy()}\nauto:\n{J_auto[0, 0, 0].numpy()}") + + text_j = f" > {self.eps}\nJ:\n{J[0, 0, 0].numpy()}\nJ_auto:\n{J_auto[0, 0, 0].numpy()}" + + J_up2grav = J[..., :2, :2] + J_up2grav_auto = J_auto[..., :2, :2] + max_diff = torch.max(torch.abs(J_up2grav - J_up2grav_auto)) + text = f"UP - GRAV max diff is {max_diff:.4f}" + text_j + self.assertTrue(torch.allclose(J_up2grav, J_up2grav_auto, atol=self.eps), text) + + J_up2focal = J[..., :2, 2] + J_up2focal_auto = J_auto[..., :2, 2] + max_diff = torch.max(torch.abs(J_up2focal - J_up2focal_auto)) + text = f"UP - FOCAL max diff is {max_diff:.4f}" + text_j + self.assertTrue(torch.allclose(J_up2focal, J_up2focal_auto, atol=self.eps), text) + + if CAMERA_MODEL != "pinhole": + J_up2k1 = J[..., :2, 3] + J_up2k1_auto = J_auto[..., :2, 3] + max_diff = torch.max(torch.abs(J_up2k1 - J_up2k1_auto)) + text = f"UP - K1 max diff is {max_diff:.4f}" + text_j + self.assertTrue(torch.allclose(J_up2k1, J_up2k1_auto, atol=self.eps), text) + + J_lat2grav = J[..., 2:, :2] + J_lat2grav_auto = J_auto[..., 2:, :2] + max_diff = torch.max(torch.abs(J_lat2grav - J_lat2grav_auto)) + text = f"LAT - GRAV max diff is {max_diff:.4f}" + text_j + self.assertTrue(torch.allclose(J_lat2grav, J_lat2grav_auto, atol=self.eps), text) + + J_lat2focal = J[..., 2:, 2] + J_lat2focal_auto = J_auto[..., 2:, 2] + max_diff = torch.max(torch.abs(J_lat2focal - J_lat2focal_auto)) + text = f"LAT - FOCAL max diff is 
{max_diff:.4f}" + text_j + self.assertTrue(torch.allclose(J_lat2focal, J_lat2focal_auto, atol=self.eps), text) + + if CAMERA_MODEL != "pinhole": + J_lat2k1 = J[..., 2:, 3] + J_lat2k1_auto = J_auto[..., 2:, 3] + max_diff = torch.max(torch.abs(J_lat2k1 - J_lat2k1_auto)) + text = f"LAT - K1 max diff is {max_diff:.4f}" + text_j + self.assertTrue(torch.allclose(J_lat2k1, J_lat2k1_auto, atol=self.eps), text) + + max_diff = torch.max(torch.abs(J - J_auto[..., : J.shape[-1]])) + text = f"Overall - max diff is {max_diff:.4f}" + text_j + self.assertTrue(torch.allclose(J, J_auto[..., : J.shape[-1]], atol=self.eps), text) + + def local_pf_calc(self, rpfk: torch.Tensor): + """Calculate the perspective field.""" + r, p, f, k1 = rpfk.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1}) + gravity = Gravity.from_rp(r, p) + up, lat = get_perspective_field(camera, gravity) + persp = torch.cat([up, torch.sin(lat)], dim=-3) + return persp.permute(0, 2, 3, 1).reshape(1, -1, 3) + + def test_random(self): + """Random rpf.""" + rpf = get_toy_rpf() + rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1) + r, p, f, k1 = rpfk.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1}) + gravity = Gravity.from_rp(r, p) + + J = torch.cat(J_perspective_field(camera, gravity, spherical=False), -2) + J_auto = jacfwd(self.local_pf_calc)(rpfk).squeeze(-2, -3).reshape(1, H, W, 3, 4) + + self.validate(J, J_auto) + + def test_zero_roll(self): + """Roll = 0.""" + rpf = get_toy_rpf(roll=0) + rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1) + r, p, f, k1 = rpfk.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1}) + gravity = Gravity.from_rp(r, p) + + J = torch.cat(J_perspective_field(camera, gravity, spherical=False), -2) + J_auto = jacfwd(self.local_pf_calc)(rpfk).squeeze(-2, -3).reshape(1, H, W, 3, 4) + + self.validate(J, J_auto) + + def test_zero_pitch(self): + """Pitch = 0.""" + rpf = get_toy_rpf(pitch=0) + rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1) + r, p, f, k1 = rpfk.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1}) + gravity = Gravity.from_rp(r, p) + + J = torch.cat(J_perspective_field(camera, gravity, spherical=False), -2) + J_auto = jacfwd(self.local_pf_calc)(rpfk).squeeze(-2, -3).reshape(1, H, W, 3, 4) + + self.validate(J, J_auto) + + def test_max_roll(self): + """Roll = -45, 45.""" + for roll in [-45, 45]: + rpf = get_toy_rpf(roll=roll) + rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1) + r, p, f, k1 = rpfk.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1}) + gravity = Gravity.from_rp(r, p) + + J = torch.cat(J_perspective_field(camera, gravity, spherical=False), -2) + J_auto = jacfwd(self.local_pf_calc)(rpfk).squeeze(-2, -3).reshape(1, H, W, 3, 4) + + self.validate(J, J_auto) + + def test_max_pitch(self): + """Pitch = -45, 45.""" + for pitch in [-45, 45]: + rpf = get_toy_rpf(pitch=pitch) + rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1) + r, p, f, k1 = rpfk.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1}) + gravity = Gravity.from_rp(r, p) + + J = torch.cat(J_perspective_field(camera, gravity, spherical=False), -2) + J_auto = jacfwd(self.local_pf_calc)(rpfk).squeeze(-2, -3).reshape(1, H, W, 3, 4) + + self.validate(J, J_auto) + + +class TestSpherical(unittest.TestCase): + """Test the spherical manifold jacobians.""" + + eps = 5e-3 + + def validate(self, J: 
torch.Tensor, J_auto: torch.Tensor): + """Check if the jacobians are close and finite.""" + self.assertTrue(torch.all(torch.isfinite(J)), "found nan in numerical") + self.assertTrue(torch.all(torch.isfinite(J_auto)), "found nan in auto") + + text_j = f" > {self.eps}\nJ:\n{J[0, 0, 0].numpy()}\nJ_auto:\n{J_auto[0, 0, 0].numpy()}" + + J_up2grav = J[..., :2, :2] + J_up2grav_auto = J_auto[..., :2, :2] + max_diff = torch.max(torch.abs(J_up2grav - J_up2grav_auto)) + text = f"UP - GRAV max diff is {max_diff:.4f}" + text_j + + self.assertTrue(torch.allclose(J_up2grav, J_up2grav_auto, atol=self.eps), text) + + J_up2focal = J[..., :2, 2] + J_up2focal_auto = J_auto[..., :2, 2] + max_diff = torch.max(torch.abs(J_up2focal - J_up2focal_auto)) + text = f"UP - FOCAL max diff is {max_diff:.4f}" + text_j + + self.assertTrue(torch.allclose(J_up2focal, J_up2focal_auto, atol=self.eps), text) + + if CAMERA_MODEL != "pinhole": + J_up2k1 = J[..., :2, 3] + J_up2k1_auto = J_auto[..., :2, 3] + max_diff = torch.max(torch.abs(J_up2k1 - J_up2k1_auto)) + text = f"UP - K1 max diff is {max_diff:.4f}" + text_j + self.assertTrue(torch.allclose(J_up2k1, J_up2k1_auto, atol=self.eps), text) + + J_lat2grav = J[..., 2:, :2] + J_lat2grav_auto = J_auto[..., 2:, :2] + max_diff = torch.max(torch.abs(J_lat2grav - J_lat2grav_auto)) + text = f"LAT - GRAV max diff is {max_diff:.4f}" + text_j + self.assertTrue(torch.allclose(J_lat2grav, J_lat2grav_auto, atol=self.eps), text) + + J_lat2focal = J[..., 2:, 2] + J_lat2focal_auto = J_auto[..., 2:, 2] + max_diff = torch.max(torch.abs(J_lat2focal - J_lat2focal_auto)) + text = f"LAT - FOCAL max diff is {max_diff:.4f}" + text_j + self.assertTrue(torch.allclose(J_lat2focal, J_lat2focal_auto, atol=self.eps), text) + + if CAMERA_MODEL != "pinhole": + J_lat2k1 = J[..., 2:, 3] + J_lat2k1_auto = J_auto[..., 2:, 3] + max_diff = torch.max(torch.abs(J_lat2k1 - J_lat2k1_auto)) + text = f"LAT - K1 max diff is {max_diff:.4f}" + text_j + self.assertTrue(torch.allclose(J_lat2k1, J_lat2k1_auto, atol=self.eps), text) + + max_diff = torch.max(torch.abs(J - J_auto[..., : J.shape[-1]])) + text = f"Overall - max diff is {max_diff:.4f}" + text_j + self.assertTrue(torch.allclose(J, J_auto[..., : J.shape[-1]], atol=self.eps), text) + + def local_pf_calc(self, uvfk: torch.Tensor, gravity: Gravity): + """Calculate the perspective field.""" + delta, f, k1 = uvfk[..., :2], uvfk[..., 2], uvfk[..., 3] + cam = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1}) + up, lat = get_perspective_field(cam, gravity.update(delta, spherical=True)) + persp = torch.cat([up, torch.sin(lat)], dim=-3) + return persp.permute(0, 2, 3, 1).reshape(1, -1, 3) + + def test_random(self): + """Test random rpf.""" + rpf = get_toy_rpf() + rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1) + r, p, f, k1 = rpfk.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1}) + gravity = Gravity.from_rp(r, p) + + J = torch.cat(J_perspective_field(camera, gravity, spherical=True), -2) + + uvfk = torch.zeros_like(rpfk) + uvfk[..., 2] = f + uvfk[..., 3] = k1 + func = lambda uvfk: self.local_pf_calc(uvfk, gravity) + J_auto = jacfwd(func)(uvfk).squeeze(-2).reshape(1, H, W, 3, 4) + + self.validate(J, J_auto) + + def test_zero_roll(self): + """Test roll = 0.""" + rpf = get_toy_rpf(roll=0) + rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1) + r, p, f, k1 = rpfk.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1}) + gravity = Gravity.from_rp(r, p) + + J = 
torch.cat(J_perspective_field(camera, gravity, spherical=True), -2) + + uvfk = torch.zeros_like(rpfk) + uvfk[..., 2] = f + uvfk[..., 3] = k1 + func = lambda uvfk: self.local_pf_calc(uvfk, gravity) + J_auto = jacfwd(func)(uvfk).squeeze(-2).reshape(1, H, W, 3, 4) + + self.validate(J, J_auto) + + def test_zero_pitch(self): + """Test pitch = 0.""" + rpf = get_toy_rpf(pitch=0) + rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1) + r, p, f, k1 = rpfk.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1}) + gravity = Gravity.from_rp(r, p) + + J = torch.cat(J_perspective_field(camera, gravity, spherical=True), -2) + + uvfk = torch.zeros_like(rpfk) + uvfk[..., 2] = f + uvfk[..., 3] = k1 + func = lambda uvfk: self.local_pf_calc(uvfk, gravity) + J_auto = jacfwd(func)(uvfk).squeeze(-2).reshape(1, H, W, 3, 4) + + self.validate(J, J_auto) + + def test_max_roll(self): + """Test roll = -45, 45.""" + for roll in [-45, 45]: + rpf = get_toy_rpf(roll=roll) + rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1) + r, p, f, k1 = rpfk.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1}) + gravity = Gravity.from_rp(r, p) + + J = torch.cat(J_perspective_field(camera, gravity, spherical=True), -2) + + uvfk = torch.zeros_like(rpfk) + uvfk[..., 2] = f + uvfk[..., 3] = k1 + func = lambda uvfk: self.local_pf_calc(uvfk, gravity) + J_auto = jacfwd(func)(uvfk).squeeze(-2).reshape(1, H, W, 3, 4) + + self.validate(J, J_auto) + + def test_max_pitch(self): + """Test pitch = -45, 45.""" + for pitch in [-45, 45]: + rpf = get_toy_rpf(pitch=pitch) + rpfk = torch.cat([rpf, torch.tensor([[K1]])], dim=-1) + r, p, f, k1 = rpfk.unbind(dim=-1) + camera = Camera.from_dict({"height": [H], "width": [W], "f": f, "k1": k1}) + gravity = Gravity.from_rp(r, p) + + J = torch.cat(J_perspective_field(camera, gravity, spherical=True), -2) + + uvfk = torch.zeros_like(rpfk) + uvfk[..., 2] = f + uvfk[..., 3] = k1 + func = lambda uvfk: self.local_pf_calc(uvfk, gravity) + J_auto = jacfwd(func)(uvfk).squeeze(-2).reshape(1, H, W, 3, 4) + + self.validate(J, J_auto) + + +class TestLM(unittest.TestCase): + """Test the LM optimizer.""" + + eps = 1e-3 + + def test_random_spherical(self): + """Test random rpf.""" + rpf = get_toy_rpf() + gravity = Gravity.from_rp(rpf[..., 0], rpf[..., 1]) + camera = Camera.from_dict({"height": [H], "width": [W], "f": rpf[..., 2], "k1": [K1]}) + + up, lat = get_perspective_field(camera, gravity) + + lm = LMOptimizer({"use_spherical_manifold": True, "camera_model": CAMERA_MODEL}) + + out = lm({"up_field": up, "latitude_field": lat}) + + cam_opt = out["camera"] + gravity_opt = out["gravity"] + + if hasattr(cam_opt, "k1"): + text = f"cam_opt: {cam_opt.k1.numpy()} | rpf: {[K1]}" + self.assertTrue( + torch.allclose(cam_opt.k1, torch.tensor([K1]).float(), atol=self.eps), text + ) + + text = f"cam_opt: {cam_opt.f[..., 1].numpy()} | rpf: {rpf[..., 2].numpy()}" + self.assertTrue(torch.allclose(cam_opt.f[..., 1], rpf[..., 2], atol=self.eps), text) + + text = f"gravity_opt.roll: {gravity_opt.roll.numpy()} | rpf: {rpf[..., 0].numpy()}" + self.assertTrue(torch.allclose(gravity_opt.roll, rpf[..., 0], atol=self.eps), text) + + text = f"gravity_opt.pitch: {gravity_opt.pitch.numpy()} | rpf: {rpf[..., 1].numpy()}" + self.assertTrue(torch.allclose(gravity_opt.pitch, rpf[..., 1], atol=self.eps), text) + + def test_random(self): + """Test random rpf.""" + rpf = get_toy_rpf() + gravity = Gravity.from_rp(rpf[..., 0], rpf[..., 1]) + camera = Camera.from_dict({"height": [H], 
"width": [W], "f": rpf[..., 2], "k1": [K1]}) + + up, lat = get_perspective_field(camera, gravity) + + lm = LMOptimizer({"use_spherical_manifold": False, "camera_model": CAMERA_MODEL}) + out = lm({"up_field": up, "latitude_field": lat}) + + cam_opt = out["camera"] + gravity_opt = out["gravity"] + + if hasattr(cam_opt, "k1"): + text = f"cam_opt: {cam_opt.k1.numpy()} | rpf: {[K1]}" + self.assertTrue( + torch.allclose(cam_opt.k1, torch.tensor([K1]).float(), atol=self.eps), text + ) + + text = f"cam_opt: {cam_opt.f[..., 1].numpy()} | rpf: {rpf[..., 2].numpy()}" + self.assertTrue(torch.allclose(cam_opt.f[..., 1], rpf[..., 2], atol=self.eps), text) + + text = f"gravity_opt.roll: {gravity_opt.roll.numpy()} | rpf: {rpf[..., 0].numpy()}" + self.assertTrue(torch.allclose(gravity_opt.roll, rpf[..., 0], atol=self.eps), text) + + text = f"gravity_opt.pitch: {gravity_opt.pitch.numpy()} | rpf: {rpf[..., 1].numpy()}" + self.assertTrue(torch.allclose(gravity_opt.pitch, rpf[..., 1], atol=self.eps), text) + + +if __name__ == "__main__": + unittest.main() diff --git a/siclib/geometry/gravity.py b/siclib/geometry/gravity.py new file mode 100644 index 0000000000000000000000000000000000000000..18e592f19261375e2c44c1305b809aaef303faf9 --- /dev/null +++ b/siclib/geometry/gravity.py @@ -0,0 +1,128 @@ +"""Tensor class for gravity vector in camera frame.""" + +import torch +from torch.nn import functional as F + +from siclib.geometry.manifolds import EuclideanManifold, SphericalManifold +from siclib.utils.conversions import rad2rotmat +from siclib.utils.tensor import TensorWrapper, autocast + +# mypy: ignore-errors + + +class Gravity(TensorWrapper): + """Gravity vector in camera frame.""" + + eps = 1e-4 + + @autocast + def __init__(self, data: torch.Tensor) -> None: + """Create gravity vector from data. + + Args: + data (torch.Tensor): gravity vector as 3D vector in camera frame. 
+ """ + assert data.shape[-1] == 3, data.shape + + data = F.normalize(data, dim=-1) + + super().__init__(data) + + @classmethod + def from_rp(cls, roll: torch.Tensor, pitch: torch.Tensor) -> "Gravity": + """Create gravity vector from roll and pitch angles.""" + if not isinstance(roll, torch.Tensor): + roll = torch.tensor(roll) + if not isinstance(pitch, torch.Tensor): + pitch = torch.tensor(pitch) + + sr, cr = torch.sin(roll), torch.cos(roll) + sp, cp = torch.sin(pitch), torch.cos(pitch) + return cls(torch.stack([-sr * cp, -cr * cp, sp], dim=-1)) + + @property + def vec3d(self) -> torch.Tensor: + """Return the gravity vector in the representation.""" + return self._data + + @property + def x(self) -> torch.Tensor: + """Return first component of the gravity vector.""" + return self._data[..., 0] + + @property + def y(self) -> torch.Tensor: + """Return second component of the gravity vector.""" + return self._data[..., 1] + + @property + def z(self) -> torch.Tensor: + """Return third component of the gravity vector.""" + return self._data[..., 2] + + @property + def roll(self) -> torch.Tensor: + """Return the roll angle of the gravity vector.""" + roll = torch.asin(-self.x / (torch.sqrt(1 - self.z**2) + self.eps)) + offset = -torch.pi * torch.sign(self.x) + return torch.where(self.y < 0, roll, -roll + offset) + + def J_roll(self) -> torch.Tensor: + """Return the Jacobian of the roll angle of the gravity vector.""" + cp, _ = torch.cos(self.pitch), torch.sin(self.pitch) + cr, sr = torch.cos(self.roll), torch.sin(self.roll) + Jr = self.new_zeros(self.shape + (3,)) + Jr[..., 0] = -cr * cp + Jr[..., 1] = sr * cp + return Jr + + @property + def pitch(self) -> torch.Tensor: + """Return the pitch angle of the gravity vector.""" + return torch.asin(self.z) + + def J_pitch(self) -> torch.Tensor: + """Return the Jacobian of the pitch angle of the gravity vector.""" + cp, sp = torch.cos(self.pitch), torch.sin(self.pitch) + cr, sr = torch.cos(self.roll), torch.sin(self.roll) + + Jp = self.new_zeros(self.shape + (3,)) + Jp[..., 0] = sr * sp + Jp[..., 1] = cr * sp + Jp[..., 2] = cp + return Jp + + @property + def rp(self) -> torch.Tensor: + """Return the roll and pitch angles of the gravity vector.""" + return torch.stack([self.roll, self.pitch], dim=-1) + + def J_rp(self) -> torch.Tensor: + """Return the Jacobian of the roll and pitch angles of the gravity vector.""" + return torch.stack([self.J_roll(), self.J_pitch()], dim=-1) + + @property + def R(self) -> torch.Tensor: + """Return the rotation matrix from the gravity vector.""" + return rad2rotmat(roll=self.roll, pitch=self.pitch) + + def J_R(self) -> torch.Tensor: + """Return the Jacobian of the rotation matrix from the gravity vector.""" + raise NotImplementedError + + def update(self, delta: torch.Tensor, spherical: bool = False) -> "Gravity": + """Update the gravity vector by adding a delta.""" + if spherical: + data = SphericalManifold.plus(self.vec3d, delta) + return self.__class__(data) + + data = EuclideanManifold.plus(self.rp, delta) + return self.from_rp(data[..., 0], data[..., 1]) + + def J_update(self, spherical: bool = False) -> torch.Tensor: + """Return the Jacobian of the update.""" + return SphericalManifold if spherical else EuclideanManifold + + def __repr__(self): + """Print the Camera object.""" + return f"{self.__class__.__name__} {self.shape} {self.dtype} {self.device}" diff --git a/siclib/geometry/jacobians.py b/siclib/geometry/jacobians.py new file mode 100644 index 
0000000000000000000000000000000000000000..444acc2a759308e3f7d4d7d5aac196a3ec51fca7 --- /dev/null +++ b/siclib/geometry/jacobians.py @@ -0,0 +1,64 @@ +"""Jacobians for optimization.""" + +import torch + +# flake8: noqa: E741 + + +@torch.jit.script +def J_vecnorm(vec: torch.Tensor) -> torch.Tensor: + """Compute the jacobian of vec / norm2(vec). + + Args: + vec (torch.Tensor): [..., D] tensor. + + Returns: + torch.Tensor: [..., D, D] Jacobian. + """ + D = vec.shape[-1] + norm_x = torch.norm(vec, dim=-1, keepdim=True).unsqueeze(-1) # (..., 1, 1) + + if (norm_x == 0).any(): + norm_x = norm_x + 1e-6 + + xxT = torch.einsum("...i,...j->...ij", vec, vec) # (..., D, D) + identity = torch.eye(D, device=vec.device, dtype=vec.dtype) # (D, D) + + return identity / norm_x - (xxT / norm_x**3) # (..., D, D) + + +@torch.jit.script +def J_focal2fov(focal: torch.Tensor, h: torch.Tensor) -> torch.Tensor: + """Compute the jacobian of the focal2fov function.""" + return -4 * h / (4 * focal**2 + h**2) + + +@torch.jit.script +def J_up_projection(uv: torch.Tensor, abc: torch.Tensor, wrt: str = "uv") -> torch.Tensor: + """Compute the jacobian of the up-vector projection. + + Args: + uv (torch.Tensor): Normalized image coordinates of shape (..., 2). + abc (torch.Tensor): Gravity vector of shape (..., 3). + wrt (str, optional): Parameter to differentiate with respect to. Defaults to "uv". + + Raises: + ValueError: If the wrt parameter is unknown. + + Returns: + torch.Tensor: Jacobian with respect to the parameter. + """ + if wrt == "uv": + c = abc[..., 2][..., None, None, None] + return -c * torch.eye(2, device=uv.device, dtype=uv.dtype).expand(uv.shape[:-1] + (2, 2)) + + elif wrt == "abc": + J = uv.new_zeros(uv.shape[:-1] + (2, 3)) + J[..., 0, 0] = 1 + J[..., 1, 1] = 1 + J[..., 0, 2] = -uv[..., 0] + J[..., 1, 2] = -uv[..., 1] + return J + + else: + raise ValueError(f"Unknown wrt: {wrt}") diff --git a/siclib/geometry/manifolds.py b/siclib/geometry/manifolds.py new file mode 100644 index 0000000000000000000000000000000000000000..dea879379cba6a6105a245720af6df336fd4e49f --- /dev/null +++ b/siclib/geometry/manifolds.py @@ -0,0 +1,112 @@ +"""Implementation of manifolds.""" + +import logging + +import torch + +logger = logging.getLogger(__name__) + + +class EuclideanManifold: + """Simple euclidean manifold.""" + + @staticmethod + def J_plus(x: torch.Tensor) -> torch.Tensor: + """Plus operator Jacobian.""" + return torch.eye(x.shape[-1]).to(x) + + @staticmethod + def plus(x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor: + """Plus operator.""" + return x + delta + + +class SphericalManifold: + """Implementation of the spherical manifold. + + Following the derivation from 'Integrating Generic Sensor Fusion Algorithms with Sound State + Representations through Encapsulation of Manifolds' by Hertzberg et al. (B.2, p. 25). + + Householder transformation following Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by + Golub et al. + """ + + @staticmethod + def householder_vector(x: torch.Tensor) -> torch.Tensor: + """Return the Householder vector and beta. + + Algorithm 5.1.1 (p. 210) from 'Matrix Computations' by Golub et al. (Johns Hopkins Studies + in Mathematical Sciences) but using the nth element of the input vector as pivot instead of + first. + + This computes the vector v with v(n) = 1 and beta such that H = I - beta * v * v^T is + orthogonal and H * x = ||x||_2 * e_n. + + Args: + x (torch.Tensor): [..., n] tensor. + + Returns: + torch.Tensor: v of shape [..., n] + torch.Tensor: beta of shape [...] 
+ """ + sigma = torch.sum(x[..., :-1] ** 2, -1) + xpiv = x[..., -1] + norm = torch.norm(x, dim=-1) + if torch.any(sigma < 1e-7): + sigma = torch.where(sigma < 1e-7, sigma + 1e-7, sigma) + logger.warning("sigma < 1e-7") + + vpiv = torch.where(xpiv < 0, xpiv - norm, -sigma / (xpiv + norm)) + beta = 2 * vpiv**2 / (sigma + vpiv**2) + v = torch.cat([x[..., :-1] / vpiv[..., None], torch.ones_like(vpiv)[..., None]], -1) + return v, beta + + @staticmethod + def apply_householder(y: torch.Tensor, v: torch.Tensor, beta: torch.Tensor) -> torch.Tensor: + """Apply Householder transformation. + + Args: + y (torch.Tensor): Vector to transform of shape [..., n]. + v (torch.Tensor): Householder vector of shape [..., n]. + beta (torch.Tensor): Householder beta of shape [...]. + + Returns: + torch.Tensor: Transformed vector of shape [..., n]. + """ + return y - v * (beta * torch.einsum("...i,...i->...", v, y))[..., None] + + @classmethod + def J_plus(cls, x: torch.Tensor) -> torch.Tensor: + """Plus operator Jacobian.""" + v, beta = cls.householder_vector(x) + H = -torch.einsum("..., ...k, ...l->...kl", beta, v, v) + H = H + torch.eye(H.shape[-1]).to(H) + return H[..., :-1] # J + + @classmethod + def plus(cls, x: torch.Tensor, delta: torch.Tensor) -> torch.Tensor: + """Plus operator. + + Equation 109 (p. 25) from 'Integrating Generic Sensor Fusion Algorithms with Sound State + Representations through Encapsulation of Manifolds' by Hertzberg et al. but using the nth + element of the input vector as pivot instead of first. + + Args: + x: point on the manifold + delta: tangent vector + """ + eps = 1e-7 + # keep norm is not equal to 1 + nx = torch.norm(x, dim=-1, keepdim=True) + nd = torch.norm(delta, dim=-1, keepdim=True) + + # make sure we don't divide by zero in backward as torch.where computes grad for both + # branches + nd_ = torch.where(nd < eps, nd + eps, nd) + sinc = torch.where(nd < eps, nd.new_ones(nd.shape), torch.sin(nd_) / nd_) + + # cos is applied to last dim instead of first + exp_delta = torch.cat([sinc * delta, torch.cos(nd)], -1) + + v, beta = cls.householder_vector(x) + return nx * cls.apply_householder(exp_delta, v, beta) diff --git a/siclib/geometry/perspective_fields.py b/siclib/geometry/perspective_fields.py new file mode 100644 index 0000000000000000000000000000000000000000..44395ba001e084d3bb391ecc11bca7254b5c5296 --- /dev/null +++ b/siclib/geometry/perspective_fields.py @@ -0,0 +1,367 @@ +"""Implementation of perspective fields. + +Adapted from https://github.com/jinlinyi/PerspectiveFields/blob/main/perspective2d/utils/panocam.py +""" + +from typing import Tuple + +import torch +from torch.nn import functional as F + +from siclib.geometry.base_camera import BaseCamera +from siclib.geometry.gravity import Gravity +from siclib.geometry.jacobians import J_up_projection, J_vecnorm +from siclib.geometry.manifolds import SphericalManifold + +# flake8: noqa: E266 + + +def get_horizon_line(camera: BaseCamera, gravity: Gravity, relative: bool = True) -> torch.Tensor: + """Get the horizon line from the camera parameters. + + Args: + camera (Camera): Camera parameters. + gravity (Gravity): Gravity vector. + relative (bool, optional): Whether to normalize horizon line by img_h. Defaults to True. + + Returns: + torch.Tensor: In image frame, fraction of image left/right border intersection with + respect to image height. 
+ """ + camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera + gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity + + # project horizon midpoint to image plane + horizon_midpoint = camera.new_tensor([0, 0, 1]) + horizon_midpoint = camera.K @ gravity.R @ horizon_midpoint + midpoint = horizon_midpoint[:2] / horizon_midpoint[2] + + # compute left and right offset to borders + left_offset = midpoint[0] * torch.tan(gravity.roll) + right_offset = (camera.size[0] - midpoint[0]) * torch.tan(gravity.roll) + left, right = midpoint[1] + left_offset, midpoint[1] - right_offset + + horizon = camera.new_tensor([left, right]) + return horizon / camera.size[1] if relative else horizon + + +def get_up_field(camera: BaseCamera, gravity: Gravity, normalize: bool = True) -> torch.Tensor: + """Get the up vector field from the camera parameters. + + Args: + camera (Camera): Camera parameters. + normalize (bool, optional): Whether to normalize the up vector. Defaults to True. + + Returns: + torch.Tensor: up vector field as tensor of shape (..., h, w, 2). + """ + camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera + gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity + + w, h = camera.size[0].unbind(-1) + h, w = h.round().to(int), w.round().to(int) + + uv = camera.normalize(camera.pixel_coordinates()) + + # projected up is (a, b) - c * (u, v) + abc = gravity.vec3d + projected_up2d = abc[..., None, :2] - abc[..., 2, None, None] * uv # (..., N, 2) + + if hasattr(camera, "dist"): + d_uv = camera.distort(uv, return_scale=True)[0] # (..., N, 1) + d_uv = torch.diag_embed(d_uv.expand(d_uv.shape[:-1] + (2,))) # (..., N, 2, 2) + offset = camera.up_projection_offset(uv) # (..., N, 2) + offset = torch.einsum("...i,...j->...ij", offset, uv) # (..., N, 2, 2) + + # (..., N, 2) + projected_up2d = torch.einsum("...Nij,...Nj->...Ni", d_uv + offset, projected_up2d) + + if normalize: + projected_up2d = F.normalize(projected_up2d, dim=-1) # (..., N, 2) + + return projected_up2d.reshape(camera.shape[0], h, w, 2) + + +def J_up_field( + camera: BaseCamera, gravity: Gravity, spherical: bool = False, log_focal: bool = False +) -> torch.Tensor: + """Get the jacobian of the up field. + + Args: + camera (Camera): Camera parameters. + gravity (Gravity): Gravity vector. + spherical (bool, optional): Whether to use spherical coordinates. Defaults to False. + log_focal (bool, optional): Whether to use log-focal length. Defaults to False. + + Returns: + torch.Tensor: Jacobian of the up field as a tensor of shape (..., h, w, 2, 2, 3). 
+ """ + camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera + gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity + + w, h = camera.size[0].unbind(-1) + h, w = h.round().to(int), w.round().to(int) + + # Forward + xy = camera.pixel_coordinates() + uv = camera.normalize(xy) + + projected_up2d = gravity.vec3d[..., None, :2] - gravity.vec3d[..., 2, None, None] * uv + + # Backward + J = [] + + # (..., N, 2, 2) + J_norm2proj = J_vecnorm( + get_up_field(camera, gravity, normalize=False).reshape(camera.shape[0], -1, 2) + ) + + # distortion values + if hasattr(camera, "dist"): + d_uv = camera.distort(uv, return_scale=True)[0] # (..., N, 1) + d_uv = torch.diag_embed(d_uv.expand(d_uv.shape[:-1] + (2,))) # (..., N, 2, 2) + offset = camera.up_projection_offset(uv) # (..., N, 2) + offset_uv = torch.einsum("...i,...j->...ij", offset, uv) # (..., N, 2, 2) + + ###################### + ## Gravity Jacobian ## + ###################### + + J_proj2abc = J_up_projection(uv, gravity.vec3d, wrt="abc") # (..., N, 2, 3) + + if hasattr(camera, "dist"): + # (..., N, 2, 3) + J_proj2abc = torch.einsum("...Nij,...Njk->...Nik", d_uv + offset_uv, J_proj2abc) + + J_abc2delta = SphericalManifold.J_plus(gravity.vec3d) if spherical else gravity.J_rp() + J_proj2delta = torch.einsum("...Nij,...jk->...Nik", J_proj2abc, J_abc2delta) + J_up2delta = torch.einsum("...Nij,...Njk->...Nik", J_norm2proj, J_proj2delta) + J.append(J_up2delta) + + ###################### + ### Focal Jacobian ### + ###################### + + J_proj2uv = J_up_projection(uv, gravity.vec3d, wrt="uv") # (..., N, 2, 2) + + if hasattr(camera, "dist"): + J_proj2up = torch.einsum("...Nij,...Njk->...Nik", d_uv + offset_uv, J_proj2uv) + J_proj2duv = torch.einsum("...i,...j->...ji", offset, projected_up2d) + + inner = (uv * projected_up2d).sum(-1)[..., None, None] + J_proj2offset1 = inner * camera.J_up_projection_offset(uv, wrt="uv") + J_proj2offset2 = torch.einsum("...i,...j->...ij", offset, projected_up2d) # (..., N, 2, 2) + J_proj2uv = (J_proj2duv + J_proj2offset1 + J_proj2offset2) + J_proj2up + + J_uv2f = camera.J_normalize(xy) # (..., N, 2, 2) + + if log_focal: + J_uv2f = J_uv2f * camera.f[..., None, None, :] # (..., N, 2, 2) + + J_uv2f = J_uv2f.sum(-1) # (..., N, 2) + + J_proj2f = torch.einsum("...ij,...j->...i", J_proj2uv, J_uv2f) # (..., N, 2) + J_up2f = torch.einsum("...Nij,...Nj->...Ni", J_norm2proj, J_proj2f)[..., None] # (..., N, 2, 1) + J.append(J_up2f) + + ###################### + ##### K1 Jacobian #### + ###################### + + if hasattr(camera, "dist"): + J_duv = camera.J_distort(uv, wrt="scale2dist") + J_duv = torch.diag_embed(J_duv.expand(J_duv.shape[:-1] + (2,))) # (..., N, 2, 2) + J_offset = torch.einsum( + "...i,...j->...ij", camera.J_up_projection_offset(uv, wrt="dist"), uv + ) + J_proj2k1 = torch.einsum("...Nij,...Nj->...Ni", J_duv + J_offset, projected_up2d) + J_k1 = torch.einsum("...Nij,...Nj->...Ni", J_norm2proj, J_proj2k1)[..., None] + J.append(J_k1) + + n_params = sum(j.shape[-1] for j in J) + return torch.cat(J, axis=-1).reshape(camera.shape[0], h, w, 2, n_params) + + +def get_latitude_field(camera: BaseCamera, gravity: Gravity) -> torch.Tensor: + """Get the latitudes of the camera pixels in radians. + + Latitudes are defined as the angle between the ray and the up vector. + + Args: + camera (Camera): Camera parameters. + gravity (Gravity): Gravity vector. + + Returns: + torch.Tensor: Latitudes in radians as a tensor of shape (..., h, w, 1). 
+ """ + camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera + gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity + + w, h = camera.size[0].unbind(-1) + h, w = h.round().to(int), w.round().to(int) + + uv1, _ = camera.image2world(camera.pixel_coordinates()) + rays = camera.pixel_bearing_many(uv1) + + lat = torch.einsum("...Nj,...j->...N", rays, gravity.vec3d) + + eps = 1e-6 + lat_asin = torch.asin(lat.clamp(min=-1 + eps, max=1 - eps)) + + return lat_asin.reshape(camera.shape[0], h, w, 1) + + +def J_latitude_field( + camera: BaseCamera, gravity: Gravity, spherical: bool = False, log_focal: bool = False +) -> torch.Tensor: + """Get the jacobian of the latitude field. + + Args: + camera (Camera): Camera parameters. + gravity (Gravity): Gravity vector. + spherical (bool, optional): Whether to use spherical coordinates. Defaults to False. + log_focal (bool, optional): Whether to use log-focal length. Defaults to False. + + Returns: + torch.Tensor: Jacobian of the latitude field as a tensor of shape (..., h, w, 1, 3). + """ + camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera + gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity + + w, h = camera.size[0].unbind(-1) + h, w = h.round().to(int), w.round().to(int) + + # Forward + xy = camera.pixel_coordinates() + uv1, _ = camera.image2world(xy) + uv1_norm = camera.pixel_bearing_many(uv1) # (..., N, 3) + + # Backward + J = [] + J_norm2w_to_img = J_vecnorm(uv1)[..., :2] # (..., N, 2) + + ###################### + ## Gravity Jacobian ## + ###################### + + J_delta = SphericalManifold.J_plus(gravity.vec3d) if spherical else gravity.J_rp() + J_delta = torch.einsum("...Ni,...ij->...Nj", uv1_norm, J_delta) # (..., N, 2) + J.append(J_delta) + + ###################### + ### Focal Jacobian ### + ###################### + + J_w_to_img2f = camera.J_image2world(xy, "f") # (..., N, 2, 2) + if log_focal: + J_w_to_img2f = J_w_to_img2f * camera.f[..., None, None, :] + J_w_to_img2f = J_w_to_img2f.sum(-1) # (..., N, 2) + + J_norm2f = torch.einsum("...Nij,...Nj->...Ni", J_norm2w_to_img, J_w_to_img2f) # (..., N, 3) + J_f = torch.einsum("...Ni,...i->...N", J_norm2f, gravity.vec3d).unsqueeze(-1) # (..., N, 1) + J.append(J_f) + + ###################### + ##### K1 Jacobian #### + ###################### + + if hasattr(camera, "dist"): + J_w_to_img2k1 = camera.J_image2world(xy, "dist") # (..., N, 2) + # (..., N, 2) + J_norm2k1 = torch.einsum("...Nij,...Nj->...Ni", J_norm2w_to_img, J_w_to_img2k1) + # (..., N, 1) + J_k1 = torch.einsum("...Ni,...i->...N", J_norm2k1, gravity.vec3d).unsqueeze(-1) + J.append(J_k1) + + n_params = sum(j.shape[-1] for j in J) + return torch.cat(J, axis=-1).reshape(camera.shape[0], h, w, 1, n_params) + + +def get_perspective_field( + camera: BaseCamera, + gravity: Gravity, + use_up: bool = True, + use_latitude: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the perspective field from the camera parameters. + + Args: + camera (Camera): Camera parameters. + gravity (Gravity): Gravity vector. + use_up (bool, optional): Whether to include the up vector field. Defaults to True. + use_latitude (bool, optional): Whether to include the latitude field. Defaults to True. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Up and latitude fields as tensors of shape + (..., 2, h, w) and (..., 1, h, w). + """ + assert use_up or use_latitude, "At least one of use_up or use_latitude must be True." 
+ + camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera + gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity + + w, h = camera.size[0].unbind(-1) + h, w = h.round().to(int), w.round().to(int) + + if use_up: + permute = (0, 3, 1, 2) + # (..., 2, h, w) + up = get_up_field(camera, gravity).permute(permute) + else: + shape = (camera.shape[0], 2, h, w) + up = camera.new_zeros(shape) + + if use_latitude: + permute = (0, 3, 1, 2) + # (..., 1, h, w) + lat = get_latitude_field(camera, gravity).permute(permute) + else: + shape = (camera.shape[0], 1, h, w) + lat = camera.new_zeros(shape) + + return up, lat + + +def J_perspective_field( + camera: BaseCamera, + gravity: Gravity, + use_up: bool = True, + use_latitude: bool = True, + spherical: bool = False, + log_focal: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Get the jacobian of the perspective field. + + Args: + camera (Camera): Camera parameters. + gravity (Gravity): Gravity vector. + use_up (bool, optional): Whether to include the up vector field. Defaults to True. + use_latitude (bool, optional): Whether to include the latitude field. Defaults to True. + spherical (bool, optional): Whether to use spherical coordinates. Defaults to False. + log_focal (bool, optional): Whether to use log-focal length. Defaults to False. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Up and latitude jacobians as tensors of shape + (..., h, w, 2, 4) and (..., h, w, 1, 4). + """ + assert use_up or use_latitude, "At least one of use_up or use_latitude must be True." + + camera = camera.unsqueeze(0) if len(camera.shape) == 0 else camera + gravity = gravity.unsqueeze(0) if len(gravity.shape) == 0 else gravity + + w, h = camera.size[0].unbind(-1) + h, w = h.round().to(int), w.round().to(int) + + if use_up: + J_up = J_up_field(camera, gravity, spherical, log_focal) # (..., h, w, 2, 4) + else: + shape = (camera.shape[0], h, w, 2, 4) + J_up = camera.new_zeros(shape) + + if use_latitude: + J_lat = J_latitude_field(camera, gravity, spherical, log_focal) # (..., h, w, 1, 4) + else: + shape = (camera.shape[0], h, w, 1, 4) + J_lat = camera.new_zeros(shape) + + return J_up, J_lat diff --git a/siclib/models/__init__.py b/siclib/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2eee392a5a1f927763a532527a48090564e191fb --- /dev/null +++ b/siclib/models/__init__.py @@ -0,0 +1,28 @@ +import importlib.util + +from siclib.models.base_model import BaseModel +from siclib.utils.tools import get_class + + +def get_model(name): + import_paths = [ + name, + f"{__name__}.{name}", + ] + for path in import_paths: + try: + spec = importlib.util.find_spec(path) + except ModuleNotFoundError: + spec = None + if spec is not None: + try: + return get_class(path, BaseModel) + except AssertionError: + mod = __import__(path, fromlist=[""]) + try: + return mod.__main_model__ + except AttributeError as exc: + print(exc) + continue + + raise RuntimeError(f'Model {name} not found in any of [{" ".join(import_paths)}]') diff --git a/siclib/models/base_model.py b/siclib/models/base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb793f3ffac32ad353e0653f5a6218f1d6d3fbe --- /dev/null +++ b/siclib/models/base_model.py @@ -0,0 +1,205 @@ +"""Base class for trainable models.""" + +import logging +import re +from abc import ABCMeta, abstractmethod +from copy import copy + +import omegaconf +import torch +from omegaconf import OmegaConf +from torch import nn + +logger = logging.getLogger(__name__) 
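The get_model helper above resolves a model class from a dotted name, either as an absolute module path or relative to siclib.models. A hedged sketch of the intended usage with decoder names that appear later in this diff; the config keys mirror UpDecoder.default_conf and are illustrative, not a canonical entry point:

from omegaconf import OmegaConf

from siclib.models import get_model

# hypothetical config mirroring UpDecoder.default_conf further down in this diff
conf = OmegaConf.create(
    {
        "name": "decoders.up_decoder",
        "loss_type": "l1",
        "decoder": {"name": "decoders.light_hamburger", "predict_uncertainty": True},
    }
)

UpDecoder = get_model(conf.name)  # resolved as siclib.models.decoders.up_decoder
model = UpDecoder(conf)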
+ +try: + import wandb +except ImportError: + logger.debug("Could not import wandb.") + wandb = None + +# flake8: noqa +# mypy: ignore-errors + + +class MetaModel(ABCMeta): + def __prepare__(name, bases, **kwds): + total_conf = OmegaConf.create() + for base in bases: + for key in ("base_default_conf", "default_conf"): + update = getattr(base, key, {}) + if isinstance(update, dict): + update = OmegaConf.create(update) + total_conf = OmegaConf.merge(total_conf, update) + return dict(base_default_conf=total_conf) + + +class BaseModel(nn.Module, metaclass=MetaModel): + """ + What the child model is expect to declare: + default_conf: dictionary of the default configuration of the model. + It recursively updates the default_conf of all parent classes, and + it is updated by the user-provided configuration passed to __init__. + Configurations can be nested. + + required_data_keys: list of expected keys in the input data dictionary. + + strict_conf (optional): boolean. If false, BaseModel does not raise + an error when the user provides an unknown configuration entry. + + _init(self, conf): initialization method, where conf is the final + configuration object (also accessible with `self.conf`). Accessing + unknown configuration entries will raise an error. + + _forward(self, data): method that returns a dictionary of batched + prediction tensors based on a dictionary of batched input data tensors. + + loss(self, pred, data): method that returns a dictionary of losses, + computed from model predictions and input data. Each loss is a batch + of scalars, i.e. a torch.Tensor of shape (B,). + The total loss to be optimized has the key `'total'`. + + metrics(self, pred, data): method that returns a dictionary of metrics, + each as a batch of scalars. + """ + + default_conf = { + "name": None, + "trainable": True, # if false: do not optimize this model parameters + "freeze_batch_normalization": False, # use test-time statistics + "timeit": False, # time forward pass + "watch": False, # log weights and gradients to wandb + } + required_data_keys = [] + strict_conf = False + + def __init__(self, conf): + """Perform some logic and call the _init method of the child model.""" + super().__init__() + default_conf = OmegaConf.merge(self.base_default_conf, OmegaConf.create(self.default_conf)) + if self.strict_conf: + OmegaConf.set_struct(default_conf, True) + + # fixme: backward compatibility + if "pad" in conf and "pad" not in default_conf: # backward compat. 
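The class docstring above spells out the contract for child models: declare default_conf and required_data_keys, implement _init and _forward, and return per-sample losses with a "total" entry. A hypothetical minimal child following that contract (the class name and the "target" key are made up for illustration; returning a (losses, metrics) tuple mirrors the decoders defined later in this diff):

import torch
from torch import nn

from siclib.models.base_model import BaseModel


class TinyRegressor(BaseModel):
    default_conf = {"hidden": 16}   # merged recursively into the parent defaults
    required_data_keys = ["image"]  # checked before _forward is called

    def _init(self, conf):
        self.net = nn.Sequential(
            nn.Conv2d(3, conf.hidden, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(conf.hidden, 1, 1),
        )

    def _forward(self, data):
        return {"scalar_map": self.net(data["image"])}

    def loss(self, pred, data):
        # one scalar per batch element; "target" is a hypothetical ground-truth key
        total = (pred["scalar_map"] - data["target"]).abs().mean(dim=(1, 2, 3))
        return {"total": total}, {}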
+ with omegaconf.read_write(conf): + with omegaconf.open_dict(conf): + conf["interpolation"] = {"pad": conf.pop("pad")} + + if isinstance(conf, dict): + conf = OmegaConf.create(conf) + self.conf = conf = OmegaConf.merge(default_conf, conf) + OmegaConf.set_readonly(conf, True) + OmegaConf.set_struct(conf, True) + self.required_data_keys = copy(self.required_data_keys) + self._init(conf) + + # load pretrained weights + if "weights" in conf and conf.weights is not None: + logger.info(f"Loading checkpoint {conf.weights}") + ckpt = torch.load(str(conf.weights), map_location="cpu", weights_only=False) + weights_key = "model" if "model" in ckpt else "state_dict" + self.flexible_load(ckpt[weights_key]) + + if not conf.trainable: + for p in self.parameters(): + p.requires_grad = False + + if conf.watch: + try: + wandb.watch(self, log="all", log_graph=True, log_freq=10) + logger.info(f"Watching {self.__class__.__name__}.") + except ValueError: + logger.warning(f"Could not watch {self.__class__.__name__}.") + + n_trainable = sum(p.numel() for p in self.parameters() if p.requires_grad) + logger.info(f"Creating model {self.__class__.__name__} ({n_trainable/1e6:.2f} Mio)") + + def flexible_load(self, state_dict): + """TODO: fix a probable nasty bug, and move to BaseModel.""" + # replace *gravity* with *up* + for key in list(state_dict.keys()): + if "gravity" in key: + new_key = key.replace("gravity", "up") + state_dict[new_key] = state_dict.pop(key) + # print(f"Renaming {key} to {new_key}") + + # replace *_head.* with *_head.decoder.* for original paramnet checkpoints + for key in list(state_dict.keys()): + if "linear_pred_latitude" in key or "linear_pred_up" in key: + continue + + if "_head" in key and "_head.decoder" not in key: + # check if _head.{num} in key + pattern = r"_head\.\d+" + if re.search(pattern, key): + continue + new_key = key.replace("_head.", "_head.decoder.") + state_dict[new_key] = state_dict.pop(key) + # print(f"Renaming {key} to {new_key}") + + dict_params = set(state_dict.keys()) + model_params = set(map(lambda n: n[0], self.named_parameters())) + + if dict_params == model_params: # perfect fit + logger.info("Loading all parameters of the checkpoint.") + self.load_state_dict(state_dict, strict=True) + return + elif len(dict_params & model_params) == 0: # perfect mismatch + strip_prefix = lambda x: ".".join(x.split(".")[:1] + x.split(".")[2:]) + state_dict = {strip_prefix(n): p for n, p in state_dict.items()} + dict_params = set(state_dict.keys()) + if len(dict_params & model_params) == 0: + raise ValueError( + "Could not manage to load the checkpoint with" + "parameters:" + "\n\t".join(sorted(dict_params)) + ) + common_params = dict_params & model_params + left_params = dict_params - model_params + left_params = [ + p for p in left_params if "running" not in p and "num_batches_tracked" not in p + ] + logger.debug("Loading parameters:\n\t" + "\n\t".join(sorted(common_params))) + if left_params: + # ignore running stats of batchnorm + logger.warning("Could not load parameters:\n\t" + "\n\t".join(sorted(left_params))) + self.load_state_dict(state_dict, strict=False) + + def train(self, mode=True): + super().train(mode) + + def freeze_bn(module): + if isinstance(module, nn.modules.batchnorm._BatchNorm): + module.eval() + + if self.conf.freeze_batch_normalization: + self.apply(freeze_bn) + + return self + + def forward(self, data): + """Check the data and call the _forward method of the child model.""" + + def recursive_key_check(expected, given): + for key in expected: + assert key 
in given, f"Missing key {key} in data: {list(given.keys())}" + if isinstance(expected, dict): + recursive_key_check(expected[key], given[key]) + + recursive_key_check(self.required_data_keys, data) + return self._forward(data) + + @abstractmethod + def _init(self, conf): + """To be implemented by the child class.""" + raise NotImplementedError + + @abstractmethod + def _forward(self, data): + """To be implemented by the child class.""" + raise NotImplementedError + + @abstractmethod + def loss(self, pred, data): + """To be implemented by the child class.""" + raise NotImplementedError diff --git a/siclib/models/cache_loader.py b/siclib/models/cache_loader.py new file mode 100644 index 0000000000000000000000000000000000000000..23c72b1f557f7e39a1240e5f748ef5c7550ac2a6 --- /dev/null +++ b/siclib/models/cache_loader.py @@ -0,0 +1,109 @@ +import string + +import h5py +import torch + +from siclib.datasets.base_dataset import collate +from siclib.models.base_model import BaseModel +from siclib.settings import DATA_PATH +from siclib.utils.tensor import batch_to_device + +# flake8: noqa +# mypy: ignore-errors + + +def pad_line_features(pred, seq_l: int = None): + raise NotImplementedError + + +def recursive_load(grp, pkeys): + return { + k: ( + torch.from_numpy(grp[k].__array__()) + if isinstance(grp[k], h5py.Dataset) + else recursive_load(grp[k], list(grp.keys())) + ) + for k in pkeys + } + + +class CacheLoader(BaseModel): + default_conf = { + "path": "???", # can be a format string like exports/{scene}/ + "data_keys": None, # load all keys + "device": None, # load to same device as data + "trainable": False, + "add_data_path": True, + "collate": True, + "scale": ["keypoints"], + "padding_fn": None, + "padding_length": None, # required for batching! + "numeric_type": "float32", # [None, "float16", "float32", "float64"] + } + + required_data_keys = ["name"] # we need an identifier + + def _init(self, conf): + self.hfiles = {} + self.padding_fn = conf.padding_fn + if self.padding_fn is not None: + self.padding_fn = eval(self.padding_fn) + self.numeric_dtype = { + None: None, + "float16": torch.float16, + "float32": torch.float32, + "float64": torch.float64, + }[conf.numeric_type] + + def _forward(self, data): # sourcery skip: low-code-quality + preds = [] + device = self.conf.device + if not device: + if devices := {v.device for v in data.values() if isinstance(v, torch.Tensor)}: + assert len(devices) == 1 + device = devices.pop() + + else: + device = "cpu" + + var_names = [x[1] for x in string.Formatter().parse(self.conf.path) if x[1]] + for i, name in enumerate(data["name"]): + fpath = self.conf.path.format(**{k: data[k][i] for k in var_names}) + if self.conf.add_data_path: + fpath = DATA_PATH / fpath + hfile = h5py.File(str(fpath), "r") + grp = hfile[name] + pkeys = self.conf.data_keys if self.conf.data_keys is not None else grp.keys() + pred = recursive_load(grp, pkeys) + if self.numeric_dtype is not None: + pred = { + k: ( + v + if not isinstance(v, torch.Tensor) or not torch.is_floating_point(v) + else v.to(dtype=self.numeric_dtype) + ) + for k, v in pred.items() + } + pred = batch_to_device(pred, device) + for k, v in pred.items(): + for pattern in self.conf.scale: + if k.startswith(pattern): + view_idx = k.replace(pattern, "") + scales = ( + data["scales"] + if len(view_idx) == 0 + else data[f"view{view_idx}"]["scales"] + ) + pred[k] = pred[k] * scales[i] + # use this function to fix number of keypoints etc. 
+ if self.padding_fn is not None: + pred = self.padding_fn(pred, self.conf.padding_length) + preds.append(pred) + hfile.close() + if self.conf.collate: + return batch_to_device(collate(preds), device) + assert len(preds) == 1 + return batch_to_device(preds[0], device) + + def loss(self, pred, data): + raise NotImplementedError diff --git a/siclib/models/decoders/__init__.py b/siclib/models/decoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/siclib/models/decoders/fpn.py b/siclib/models/decoders/fpn.py new file mode 100644 index 0000000000000000000000000000000000000000..78d0085b645ac64c0c6ede144474a9d6446214fd --- /dev/null +++ b/siclib/models/decoders/fpn.py @@ -0,0 +1,198 @@ +import logging + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from siclib.models import BaseModel +from siclib.models.utils.modules import ConvModule, FeatureFusionBlock + +logger = logging.getLogger(__name__) + +# flake8: noqa +# mypy: ignore-errors + + +class DecoderBlock(nn.Module): + def __init__( + self, + previous, + out, + ksize=3, + num_convs=1, + norm_str="BatchNorm2d", + padding="zeros", + fusion="sum", + ): + super().__init__() + + self.fusion = fusion + + if self.fusion == "sum": + self.fusion_layers = nn.Identity() + elif self.fusion == "glu": + self.fusion_layers = nn.Sequential( + nn.Conv2d(2 * out, 2 * out, 1, padding=0, bias=True), + nn.GLU(dim=1), + ) + elif self.fusion == "ff": + self.fusion_layers = FeatureFusionBlock(out, upsample=False) + else: + raise ValueError(f"Unknown fusion: {self.fusion}") + + if norm_str is not None: + norm = getattr(nn, norm_str, None) + + layers = [] + for i in range(num_convs): + conv = nn.Conv2d( + previous if i == 0 else out, + out, + kernel_size=ksize, + padding=ksize // 2, + bias=norm_str is None, + padding_mode=padding, + ) + layers.append(conv) + if norm_str is not None: + layers.append(norm(out)) + layers.append(nn.ReLU(inplace=True)) + self.layers = nn.Sequential(*layers) + + def forward(self, previous, skip): + _, _, hp, wp = previous.shape + _, _, hs, ws = skip.shape + scale = 2 ** np.round(np.log2(np.array([hs / hp, ws / wp]))) + + upsampled = nn.functional.interpolate( + previous, scale_factor=scale.tolist(), mode="bilinear", align_corners=False + ) + # If the shape of the input map `skip` is not a multiple of 2, + # it will not match the shape of the upsampled map `upsampled`. + # If the downsampling uses ceil_mode=False, we need to crop `skip`. + # If it uses ceil_mode=True (not supported here), we should pad it. 
+ _, _, hu, wu = upsampled.shape + _, _, hs, ws = skip.shape + if (hu <= hs) and (wu <= ws): + skip = skip[:, :, :hu, :wu] + elif (hu >= hs) and (wu >= ws): + skip = nn.functional.pad(skip, [0, wu - ws, 0, hu - hs]) + else: + raise ValueError(f"Inconsistent skip vs upsampled shapes: {(hs, ws)}, {(hu, wu)}") + + skip = skip.clone() + feats_skip = self.layers(skip) + if self.fusion == "sum": + return self.fusion_layers(feats_skip + upsampled) + elif self.fusion == "glu": + x = torch.cat([feats_skip, upsampled], dim=1) + return self.fusion_layers(x) + elif self.fusion == "ff": + return self.fusion_layers(feats_skip, upsampled) + else: + raise ValueError(f"Unknown fusion: {self.fusion}") + + +class FPN(BaseModel): + default_conf = { + "predict_uncertainty": True, + "in_channels_list": [64, 128, 256, 512], + "out_channels": 64, + "num_convs": 1, + "norm": None, + "padding": "zeros", + "fusion": "sum", + "with_low_level": True, + } + + required_data_keys = ["hl"] + + def _init(self, conf): + self.in_channels_list = conf.in_channels_list + self.out_channels = conf.out_channels + + self.num_convs = conf.num_convs + self.norm = conf.norm + self.padding = conf.padding + + self.fusion = conf.fusion + + self.first = nn.Conv2d( + self.in_channels_list[-1], self.out_channels, 1, padding=0, bias=True + ) + self.blocks = nn.ModuleList( + [ + DecoderBlock( + c, + self.out_channels, + ksize=1, + num_convs=self.num_convs, + norm_str=self.norm, + padding=self.padding, + fusion=self.fusion, + ) + for c in self.in_channels_list[::-1][1:] + ] + ) + self.out = nn.Sequential( + ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + padding=1, + bias=False, + ), + ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + padding=1, + bias=False, + ), + ) + + self.predict_uncertainty = conf.predict_uncertainty + if self.predict_uncertainty: + self.linear_pred_uncertainty = nn.Sequential( + ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + padding=1, + bias=False, + ), + nn.Conv2d(in_channels=self.out_channels, out_channels=1, kernel_size=1), + ) + + self.with_ll = conf.with_low_level + if self.with_ll: + self.out_conv = ConvModule(self.out_channels, self.out_channels, 3, padding=1) + self.ll_fusion = FeatureFusionBlock(self.out_channels, upsample=False) + + def _forward(self, features): + layers = features["hl"] + feats = None + + for idx, x in enumerate(reversed(layers)): + feats = self.first(x) if feats is None else self.blocks[idx - 1](feats, x) + + feats = self.out(feats) + feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False) + feats = self.out_conv(feats) + + if self.with_ll: + assert "ll" in features, "Low-level features are required for this model" + feats_ll = features["ll"].clone() + feats = self.ll_fusion(feats, feats_ll) + + uncertainty = ( + self.linear_pred_uncertainty(feats).squeeze(1) if self.predict_uncertainty else None + ) + return feats, uncertainty + + def loss(self, pred, data): + raise NotImplementedError + + def metrics(self, pred, data): + raise NotImplementedError diff --git a/siclib/models/decoders/latitude_decoder.py b/siclib/models/decoders/latitude_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..b61269f30336a20aef048c3a50011313c2a4f4e4 --- /dev/null +++ b/siclib/models/decoders/latitude_decoder.py @@ -0,0 +1,133 @@ +"""Latitude decoder head. 
+ +Adapted from https://github.com/jinlinyi/PerspectiveFields +""" + +import logging + +import torch +from torch import nn + +from siclib.models import get_model +from siclib.models.base_model import BaseModel +from siclib.models.utils.metrics import latitude_error +from siclib.models.utils.perspective_encoding import decode_bin_latitude +from siclib.utils.conversions import deg2rad + +logger = logging.getLogger(__name__) + +# flake8: noqa +# mypy: ignore-errors + + +class LatitudeDecoder(BaseModel): + default_conf = { + "loss_type": "l1", + "use_loss": True, + "use_uncertainty_loss": True, + "loss_weight": 1.0, + "recall_thresholds": [1, 3, 5, 10], + "use_tanh": True, # backward compatibility to original perspective weights + "decoder": {"name": "decoders.light_hamburger", "predict_uncertainty": True}, + } + + required_data_keys = ["features"] + + def _init(self, conf): + self.loss_type = conf.loss_type + self.loss_weight = conf.loss_weight + + self.use_uncertainty_loss = conf.use_uncertainty_loss + self.predict_uncertainty = conf.decoder.predict_uncertainty + + self.num_classes = 1 + self.is_classification = self.conf.loss_type == "classification" + if self.is_classification: + self.num_classes = 180 + + self.decoder = get_model(conf.decoder.name)(conf.decoder) + self.linear_pred_latitude = nn.Conv2d( + self.decoder.out_channels, self.num_classes, kernel_size=1 + ) + + def calculate_losses(self, predictions, targets, confidence=None): + predictions = predictions.float() # https://github.com/pytorch/pytorch/issues/48163 + + residuals = predictions - targets + if self.loss_type == "l2": + loss = (residuals**2).sum(axis=1) + elif self.loss_type == "l1": + loss = residuals.abs().sum(axis=1) + elif self.loss_type == "cauchy": + c = 0.007 # -> corresponds to about 5 degrees + residuals = (residuals**2).sum(axis=1) + loss = c**2 / 2 * torch.log(1 + residuals / c**2) + elif self.loss_type == "huber": + c = deg2rad(1) + loss = nn.HuberLoss(reduction="none", delta=c)(predictions, targets).sum(axis=1) + else: + raise NotImplementedError(f"Unknown loss type {self.conf.loss_type}") + + if confidence is not None and self.use_uncertainty_loss: + conf_weight = confidence / confidence.sum(axis=(-2, -1), keepdims=True) + conf_weight = conf_weight * (conf_weight.size(-1) * conf_weight.size(-2)) + loss = loss * conf_weight.detach() + + losses = {f"latitude-{self.loss_type}-loss": loss.mean(axis=(1, 2))} + losses = {k: v * self.loss_weight for k, v in losses.items()} + + return losses + + def _forward(self, data): + out = {} + x, log_confidence = self.decoder(data["features"]) + lat = self.linear_pred_latitude(x) + + if self.predict_uncertainty: + out["latitude_confidence"] = torch.sigmoid(log_confidence) + + if self.is_classification: + out["latitude_field_logits"] = lat + out["latitude_field"] = decode_bin_latitude( + lat.argmax(dim=1), self.num_classes + ).unsqueeze(1) + return out + + eps = 1e-5 # avoid nan in backward of asin + lat = torch.tanh(lat) if self.conf.use_tanh else lat + lat = torch.asin(torch.clamp(lat, -1 + eps, 1 - eps)) + + out["latitude_field"] = lat + return out + + def loss(self, pred, data): + if not self.conf.use_loss or self.is_classification: + return {}, self.metrics(pred, data) + + predictions = pred["latitude_field"] + targets = data["latitude_field"] + + losses = self.calculate_losses(predictions, targets, pred.get("latitude_confidence")) + + total = 0 + losses[f"latitude-{self.loss_type}-loss"] + losses |= {"latitude_total": total} + return losses, self.metrics(pred, data) + + 
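When use_uncertainty_loss is enabled, the per-pixel loss above is reweighted by the predicted confidence, normalized to unit mean over the image so the overall loss scale is preserved. A small standalone sketch of that weighting with toy tensors:

import torch

# toy per-pixel loss and confidence maps of shape (batch, H, W)
loss = torch.rand(2, 4, 4)
confidence = torch.rand(2, 4, 4)

# normalize confidence to unit mean per image, as in calculate_losses above
conf_weight = confidence / confidence.sum(dim=(-2, -1), keepdim=True)
conf_weight = conf_weight * (conf_weight.size(-1) * conf_weight.size(-2))

weighted = (loss * conf_weight.detach()).mean(dim=(1, 2))  # one value per image
print(weighted.shape)  # torch.Size([2])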
def metrics(self, pred, data): + predictions = pred["latitude_field"] + targets = data["latitude_field"] + + error = latitude_error(predictions, targets) + out = {"latitude_angle_error": error.mean(axis=(1, 2))} + + if "latitude_confidence" in pred: + weighted_error = (error * pred["latitude_confidence"]).sum(axis=(1, 2)) + out["latitude_angle_error_weighted"] = weighted_error / pred["latitude_confidence"].sum( + axis=(1, 2) + ) + + for th in self.conf.recall_thresholds: + rec = (error < th).float().mean(axis=(1, 2)) + out[f"latitude_angle_recall@{th}"] = rec + + return out diff --git a/siclib/models/decoders/light_hamburger.py b/siclib/models/decoders/light_hamburger.py new file mode 100644 index 0000000000000000000000000000000000000000..7f7973084b206b32321e68e6d7298f2d7a709d23 --- /dev/null +++ b/siclib/models/decoders/light_hamburger.py @@ -0,0 +1,243 @@ +"""Light HamHead Decoder. + +Adapted from: +https://github.com/Visual-Attention-Network/SegNeXt/blob/main/mmseg/models/decode_heads/ham_head.py +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from siclib.models import BaseModel +from siclib.models.utils.modules import ConvModule, FeatureFusionBlock + +# flake8: noqa +# mypy: ignore-errors + + +class _MatrixDecomposition2DBase(nn.Module): + def __init__(self): + super().__init__() + + self.spatial = True + + self.S = 1 + self.D = 512 + self.R = 64 + + self.train_steps = 6 + self.eval_steps = 7 + + self.inv_t = 100 + self.eta = 0.9 + + self.rand_init = True + + def _build_bases(self, B, S, D, R, device="cpu"): + raise NotImplementedError + + def local_step(self, x, bases, coef): + raise NotImplementedError + + # @torch.no_grad() + def local_inference(self, x, bases): + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + coef = torch.bmm(x.transpose(1, 2), bases) + coef = F.softmax(self.inv_t * coef, dim=-1) + + steps = self.train_steps if self.training else self.eval_steps + for _ in range(steps): + bases, coef = self.local_step(x, bases, coef) + + return bases, coef + + def compute_coef(self, x, bases, coef): + raise NotImplementedError + + def forward(self, x, return_bases=False): + B, C, H, W = x.shape + + # (B, C, H, W) -> (B * S, D, N) + if self.spatial: + D = C // self.S + N = H * W + x = x.view(B * self.S, D, N) + else: + D = H * W + N = C // self.S + x = x.view(B * self.S, N, D).transpose(1, 2) + + if not self.rand_init and not hasattr(self, "bases"): + bases = self._build_bases(1, self.S, D, self.R, device=x.device) + self.register_buffer("bases", bases) + + # (S, D, R) -> (B * S, D, R) + if self.rand_init: + bases = self._build_bases(B, self.S, D, self.R, device=x.device) + else: + bases = self.bases.repeat(B, 1, 1) + + bases, coef = self.local_inference(x, bases) + + # (B * S, N, R) + coef = self.compute_coef(x, bases, coef) + + # (B * S, D, R) @ (B * S, N, R)^T -> (B * S, D, N) + x = torch.bmm(bases, coef.transpose(1, 2)) + + # (B * S, D, N) -> (B, C, H, W) + x = x.view(B, C, H, W) if self.spatial else x.transpose(1, 2).view(B, C, H, W) + # (B * H, D, R) -> (B, H, N, D) + bases = bases.view(B, self.S, D, self.R) + + return x + + +class NMF2D(_MatrixDecomposition2DBase): + def __init__(self): + super().__init__() + + self.inv_t = 1 + + def _build_bases(self, B, S, D, R, device="cpu"): + bases = torch.rand((B * S, D, R)).to(device) + bases = F.normalize(bases, dim=1) + + return bases + + # @torch.no_grad() + def local_step(self, x, bases, coef): + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + numerator = torch.bmm(x.transpose(1, 2), 
bases) + # (B * S, N, R) @ [(B * S, D, R)^T @ (B * S, D, R)] -> (B * S, N, R) + denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) + # Multiplicative Update + coef = coef * numerator / (denominator + 1e-6) + + # (B * S, D, N) @ (B * S, N, R) -> (B * S, D, R) + numerator = torch.bmm(x, coef) + # (B * S, D, R) @ [(B * S, N, R)^T @ (B * S, N, R)] -> (B * S, D, R) + denominator = bases.bmm(coef.transpose(1, 2).bmm(coef)) + # Multiplicative Update + bases = bases * numerator / (denominator + 1e-6) + + return bases, coef + + def compute_coef(self, x, bases, coef): + # (B * S, D, N)^T @ (B * S, D, R) -> (B * S, N, R) + numerator = torch.bmm(x.transpose(1, 2), bases) + # (B * S, N, R) @ (B * S, D, R)^T @ (B * S, D, R) -> (B * S, N, R) + denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) + # multiplication update + coef = coef * numerator / (denominator + 1e-6) + + return coef + + +class Hamburger(nn.Module): + def __init__(self, ham_channels=512, norm_cfg=None, **kwargs): + super().__init__() + + self.ham_in = ConvModule(ham_channels, ham_channels, 1) + + self.ham = NMF2D() + + self.ham_out = ConvModule(ham_channels, ham_channels, 1) + + def forward(self, x): + enjoy = self.ham_in(x) + enjoy = F.relu(enjoy, inplace=False) + enjoy = self.ham(enjoy) + enjoy = self.ham_out(enjoy) + ham = F.relu(x + enjoy, inplace=False) + + return ham + + +class LightHamHead(BaseModel): + """Is Attention Better Than Matrix Decomposition? + This head is the implementation of `HamNet + `_. + + Args: + ham_channels (int): input channels for Hamburger. + ham_kwargs (int): kwagrs for Ham. + """ + + default_conf = { + "predict_uncertainty": True, + "out_channels": 64, + "in_channels": [64, 128, 320, 512], + "in_index": [0, 1, 2, 3], + "ham_channels": 512, + "with_low_level": True, + } + + def _init(self, conf): + self.in_index = conf.in_index + self.in_channels = conf.in_channels + self.out_channels = conf.out_channels + self.ham_channels = conf.ham_channels + self.align_corners = False + self.predict_uncertainty = conf.predict_uncertainty + + self.squeeze = ConvModule(sum(self.in_channels), self.ham_channels, 1) + + self.hamburger = Hamburger(self.ham_channels) + + self.align = ConvModule(self.ham_channels, self.out_channels, 1) + + if self.predict_uncertainty: + self.linear_pred_uncertainty = nn.Sequential( + ConvModule( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + padding=1, + bias=False, + ), + nn.Conv2d(in_channels=self.out_channels, out_channels=1, kernel_size=1), + ) + + self.with_ll = conf.with_low_level + if self.with_ll: + self.out_conv = ConvModule( + self.out_channels, self.out_channels, 3, padding=1, bias=False + ) + self.ll_fusion = FeatureFusionBlock(self.out_channels, upsample=False) + + def _forward(self, features): + """Forward function.""" + # inputs = self._transform_inputs(inputs) + inputs = [features["hl"][i] for i in self.in_index] + + inputs = [ + F.interpolate( + level, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners + ) + for level in inputs + ] + + inputs = torch.cat(inputs, dim=1) + x = self.squeeze(inputs) + + x = self.hamburger(x) + + feats = self.align(x) + + if self.with_ll: + assert "ll" in features, "Low-level features are required for this model" + feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False) + feats = self.out_conv(feats) + feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False) + feats_ll = features["ll"].clone() + feats = 
self.ll_fusion(feats, feats_ll) + + uncertainty = ( + self.linear_pred_uncertainty(feats).squeeze(1) if self.predict_uncertainty else None + ) + + return feats, uncertainty + + def loss(self, pred, data): + raise NotImplementedError diff --git a/siclib/models/decoders/perspective_decoder.py b/siclib/models/decoders/perspective_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..f4537586cec3e6d3a411eead431c4633395ef93a --- /dev/null +++ b/siclib/models/decoders/perspective_decoder.py @@ -0,0 +1,59 @@ +"""Perspective fields decoder heads. + +Adapted from https://github.com/jinlinyi/PerspectiveFields +""" + +import logging + +from siclib.models import get_model +from siclib.models.base_model import BaseModel + +logger = logging.getLogger(__name__) + +# flake8: noqa +# mypy: ignore-errors + + +class PerspectiveDecoder(BaseModel): + default_conf = { + "up_decoder": {"name": "decoders.up_decoder"}, + "latitude_decoder": {"name": "decoders.latitude_decoder"}, + } + + required_data_keys = ["features"] + + def _init(self, conf): + logger.debug(f"Initializing PerspectiveDecoder with config: {conf}") + self.use_up = conf.up_decoder is not None + self.use_latitude = conf.latitude_decoder is not None + + if self.use_up: + self.up_head = get_model(conf.up_decoder.name)(conf.up_decoder) + + if self.use_latitude: + self.latitude_head = get_model(conf.latitude_decoder.name)(conf.latitude_decoder) + + def _forward(self, data): + out_up = self.up_head(data) if self.use_up else {} + out_lat = self.latitude_head(data) if self.use_latitude else {} + return out_up | out_lat + + def loss(self, pred, data): + ref = data["up_field"] if self.use_up else data["latitude_field"] + + total = ref.new_zeros(ref.shape[0]) + losses, metrics = {}, {} + if self.use_up: + up_losses, up_metrics = self.up_head.loss(pred, data) + losses |= up_losses + metrics |= up_metrics + total = total + losses.get("up_total", 0) + + if self.use_latitude: + latitude_losses, latitude_metrics = self.latitude_head.loss(pred, data) + losses |= latitude_losses + metrics |= latitude_metrics + total = total + losses.get("latitude_total", 0) + + losses["perspective_total"] = total + return losses, metrics diff --git a/siclib/models/decoders/up_decoder.py b/siclib/models/decoders/up_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..30c72713cd2df8b92f156adc84628d18a934be33 --- /dev/null +++ b/siclib/models/decoders/up_decoder.py @@ -0,0 +1,128 @@ +"""up decoder head. 
+ +Adapted from https://github.com/jinlinyi/PerspectiveFields +""" + +import logging + +import torch +from torch import nn +from torch.nn import functional as F + +from siclib.models import get_model +from siclib.models.base_model import BaseModel +from siclib.models.utils.metrics import up_error +from siclib.models.utils.perspective_encoding import decode_up_bin +from siclib.utils.conversions import deg2rad + +logger = logging.getLogger(__name__) + +# flake8: noqa +# mypy: ignore-errors + + +class UpDecoder(BaseModel): + default_conf = { + "loss_type": "l1", + "use_loss": True, + "use_uncertainty_loss": True, + "loss_weight": 1.0, + "recall_thresholds": [1, 3, 5, 10], + "decoder": {"name": "decoders.light_hamburger", "predict_uncertainty": True}, + } + + required_data_keys = ["features"] + + def _init(self, conf): + self.loss_type = conf.loss_type + self.loss_weight = conf.loss_weight + + self.use_uncertainty_loss = conf.use_uncertainty_loss + self.predict_uncertainty = conf.decoder.predict_uncertainty + + self.num_classes = 2 + self.is_classification = self.conf.loss_type == "classification" + if self.is_classification: + self.num_classes = 73 + + self.decoder = get_model(conf.decoder.name)(conf.decoder) + self.linear_pred_up = nn.Conv2d(self.decoder.out_channels, self.num_classes, kernel_size=1) + + def calculate_losses(self, predictions, targets, confidence=None): + predictions = predictions.float() # https://github.com/pytorch/pytorch/issues/48163 + + residuals = predictions - targets + if self.loss_type == "l2": + loss = (residuals**2).sum(axis=1) + elif self.loss_type == "l1": + loss = residuals.abs().sum(axis=1) + elif self.loss_type == "dot": + loss = 1 - (residuals * targets).sum(axis=1) + elif self.loss_type == "cauchy": + c = 0.007 # -> corresponds to about 5 degrees + residuals = (residuals**2).sum(axis=1) + loss = c**2 / 2 * torch.log(1 + residuals / c**2) + elif self.loss_type == "huber": + c = deg2rad(1) + loss = nn.HuberLoss(reduction="none", delta=c)(predictions, targets).sum(axis=1) + else: + raise NotImplementedError(f"Unknown loss type {self.conf.loss_type}") + + if confidence is not None and self.use_uncertainty_loss: + conf_weight = confidence / confidence.sum(axis=(-2, -1), keepdims=True) + conf_weight = conf_weight * (conf_weight.size(-1) * conf_weight.size(-2)) + loss = loss * conf_weight.detach() + + losses = {f"up-{self.loss_type}-loss": loss.mean(axis=(1, 2))} + losses = {k: v * self.loss_weight for k, v in losses.items()} + + return losses + + def _forward(self, data): + out = {} + x, log_confidence = self.decoder(data["features"]) + up = self.linear_pred_up(x) + + if self.predict_uncertainty: + out["up_confidence"] = torch.sigmoid(log_confidence) + + if self.is_classification: + out["up_field"] = decode_up_bin(up.argmax(dim=1), self.num_classes) + return out + + up = F.normalize(up, dim=1) + + out["up_field"] = up + return out + + def loss(self, pred, data): + if not self.conf.use_loss or self.is_classification: + return {}, self.metrics(pred, data) + + predictions = pred["up_field"] + targets = data["up_field"] + + losses = self.calculate_losses(predictions, targets, pred.get("up_confidence")) + + total = 0 + losses[f"up-{self.loss_type}-loss"] + losses |= {"up_total": total} + return losses, self.metrics(pred, data) + + def metrics(self, pred, data): + predictions = pred["up_field"] + targets = data["up_field"] + + mask = predictions.sum(axis=1) != 0 + + error = up_error(predictions, targets) * mask + out = {"up_angle_error": error.mean(axis=(1, 2))} + + if 
"up_confidence" in pred: + weighted_error = (error * pred["up_confidence"]).sum(axis=(1, 2)) + out["up_angle_error_weighted"] = weighted_error / pred["up_confidence"].sum(axis=(1, 2)) + + for th in self.conf.recall_thresholds: + rec = (error < th).float().mean(axis=(1, 2)) + out[f"up_angle_recall@{th}"] = rec + + return out diff --git a/siclib/models/encoders/__init__.py b/siclib/models/encoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/siclib/models/encoders/low_level_encoder.py b/siclib/models/encoders/low_level_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..99d4086131dea93d908052813621268f967c18f9 --- /dev/null +++ b/siclib/models/encoders/low_level_encoder.py @@ -0,0 +1,54 @@ +import logging + +import torch.nn as nn + +from siclib.models.base_model import BaseModel +from siclib.models.utils.modules import ConvModule + +logger = logging.getLogger(__name__) + +# flake8: noqa +# mypy: ignore-errors + + +class LowLevelEncoder(BaseModel): + default_conf = { + "feat_dim": 64, + "in_channel": 3, + "keep_resolution": True, + } + + required_data_keys = ["image"] + + def _init(self, conf): + logger.debug(f"Initializing LowLevelEncoder with {conf}") + + if self.conf.keep_resolution: + self.conv1 = ConvModule(conf.in_channel, conf.feat_dim, kernel_size=3, padding=1) + self.conv2 = ConvModule(conf.feat_dim, conf.feat_dim, kernel_size=3, padding=1) + else: + self.conv1 = nn.Conv2d( + conf.in_channel, conf.feat_dim, kernel_size=7, stride=2, padding=3, bias=False + ) + self.bn1 = nn.BatchNorm2d(conf.feat_dim) + self.relu = nn.ReLU(inplace=True) + + def _forward(self, data): + x = data["image"] + + assert ( + x.shape[-1] % 32 == 0 and x.shape[-2] % 32 == 0 + ), "Image size must be multiple of 32 if not using single image input." 
+ + if self.conf.keep_resolution: + c1 = self.conv1(x) + c2 = self.conv2(c1) + else: + x = self.conv1(x) + x = self.bn1(x) + c2 = self.relu(x) + + return {"features": c2} + + def loss(self, pred, data): + raise NotImplementedError diff --git a/siclib/models/encoders/mscan.py b/siclib/models/encoders/mscan.py new file mode 100644 index 0000000000000000000000000000000000000000..81131d09082e9634df195f72a3fd47012162b209 --- /dev/null +++ b/siclib/models/encoders/mscan.py @@ -0,0 +1,258 @@ +"""Implementation of MSCAN from SegNeXt: Rethinking Convolutional Attention Design for Semantic +Segmentation (NeurIPS 2022) + +based on: https://github.com/Visual-Attention-Network/SegNeXt +""" + +import torch +import torch.nn as nn +from torch.nn.modules.utils import _pair as to_2tuple + +from siclib.models import BaseModel +from siclib.models.utils.modules import DropPath, DWConv + +# flake8: noqa +# mypy: ignore-errors + + +class Mlp(nn.Module): + def __init__( + self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0 + ): + """Initialize the MLP.""" + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Conv2d(in_features, hidden_features, 1) + self.dwconv = DWConv(hidden_features) + self.act = act_layer() + self.fc2 = nn.Conv2d(hidden_features, out_features, 1) + self.drop = nn.Dropout(drop) + + def forward(self, x): + """Forward pass.""" + x = self.fc1(x) + + x = self.dwconv(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + + return x + + +class StemConv(nn.Module): + def __init__(self, in_channels, out_channels): + super(StemConv, self).__init__() + + self.proj = nn.Sequential( + nn.Conv2d( + in_channels, out_channels // 2, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1) + ), + nn.BatchNorm2d(out_channels // 2), + nn.GELU(), + nn.Conv2d( + out_channels // 2, out_channels, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1) + ), + nn.BatchNorm2d(out_channels), + ) + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.size() + x = x.flatten(2).transpose(1, 2) + return x, H, W + + +class AttentionModule(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim) + self.conv0_1 = nn.Conv2d(dim, dim, (1, 7), padding=(0, 3), groups=dim) + self.conv0_2 = nn.Conv2d(dim, dim, (7, 1), padding=(3, 0), groups=dim) + + self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim) + self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim) + + self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim) + self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim) + self.conv3 = nn.Conv2d(dim, dim, 1) + + def forward(self, x): + u = x.clone() + attn = self.conv0(x) + + attn_0 = self.conv0_1(attn) + attn_0 = self.conv0_2(attn_0) + + attn_1 = self.conv1_1(attn) + attn_1 = self.conv1_2(attn_1) + + attn_2 = self.conv2_1(attn) + attn_2 = self.conv2_2(attn_2) + attn = attn + attn_0 + attn_1 + attn_2 + + attn = self.conv3(attn) + + return attn * u + + +class SpatialAttention(nn.Module): + def __init__(self, d_model): + super().__init__() + self.d_model = d_model + self.proj_1 = nn.Conv2d(d_model, d_model, 1) + self.activation = nn.GELU() + self.spatial_gating_unit = AttentionModule(d_model) + self.proj_2 = nn.Conv2d(d_model, d_model, 1) + + def forward(self, x): + shorcut = x.clone() + x = self.proj_1(x) + x = self.activation(x) + x = self.spatial_gating_unit(x) + x = 
self.proj_2(x) + x = x + shorcut + return x + + +class Block(nn.Module): + def __init__( + self, + dim, + mlp_ratio=4.0, + drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + ): + super().__init__() + self.norm1 = nn.BatchNorm2d(dim) + self.attn = SpatialAttention(dim) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = nn.BatchNorm2d(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop + ) + layer_scale_init_value = 1e-2 + self.layer_scale_1 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True + ) + self.layer_scale_2 = nn.Parameter( + layer_scale_init_value * torch.ones((dim)), requires_grad=True + ) + + def forward(self, x, H, W): + B, N, C = x.shape + x = x.permute(0, 2, 1).view(B, C, H, W) + x = x + self.drop_path( + self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)) + ) + x = x + self.drop_path( + self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)) + ) + x = x.view(B, C, N).permute(0, 2, 1) + return x + + +class OverlapPatchEmbed(nn.Module): + """Image to Patch Embedding""" + + def __init__(self, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + patch_size = to_2tuple(patch_size) + + self.proj = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=patch_size, + stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2), + ) + self.norm = nn.BatchNorm2d(embed_dim) + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + x = self.norm(x) + + x = x.flatten(2).transpose(1, 2) + + return x, H, W + + +class MSCAN(BaseModel): + default_conf = { + "in_channels": 3, + "embed_dims": [64, 128, 320, 512], + "mlp_ratios": [8, 8, 4, 4], + "drop_rate": 0.0, + "drop_path_rate": 0.1, + "depths": [3, 3, 12, 3], + "num_stages": 4, + } + + required_data_keys = ["image"] + + def _init(self, conf): + self.depths = conf.depths + self.num_stages = conf.num_stages + + # stochastic depth decay rule + dpr = [x.item() for x in torch.linspace(0, conf.drop_path_rate, sum(conf.depths))] + cur = 0 + + for i in range(conf.num_stages): + if i == 0: + patch_embed = StemConv(3, conf.embed_dims[0]) + else: + patch_embed = OverlapPatchEmbed( + patch_size=7 if i == 0 else 3, + stride=4 if i == 0 else 2, + in_chans=conf.in_chans if i == 0 else conf.embed_dims[i - 1], + embed_dim=conf.embed_dims[i], + ) + + block = nn.ModuleList( + [ + Block( + dim=conf.embed_dims[i], + mlp_ratio=conf.mlp_ratios[i], + drop=conf.drop_rate, + drop_path=dpr[cur + j], + ) + for j in range(conf.depths[i]) + ] + ) + norm = nn.LayerNorm(conf.embed_dims[i]) + cur += conf.depths[i] + + setattr(self, f"patch_embed{i + 1}", patch_embed) + setattr(self, f"block{i + 1}", block) + setattr(self, f"norm{i + 1}", norm) + + def _forward(self, data): + img = data["image"] + # rgb -> bgr and from [0, 1] to [0, 255] + x = img[:, [2, 1, 0], :, :] * 255.0 + + B = x.shape[0] + outs = [] + + for i in range(self.num_stages): + patch_embed = getattr(self, f"patch_embed{i + 1}") + block = getattr(self, f"block{i + 1}") + norm = getattr(self, f"norm{i + 1}") + x, H, W = patch_embed(x) + for blk in block: + x = blk(x, H, W) + x = norm(x) + x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() + outs.append(x) + + return {"features": outs} + + def loss(self, pred, data): + """Compute the loss.""" + raise NotImplementedError diff --git a/siclib/models/encoders/resnet.py b/siclib/models/encoders/resnet.py new file mode 100644 
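With the default configuration, the MSCAN encoder defined above produces a four-level feature pyramid at strides 4, 8, 16, and 32 with channel widths embed_dims = [64, 128, 320, 512] (a stride-4 stem followed by three stride-2 patch embeddings). A hedged sanity check, assuming the package is importable and construction with the default config works as written:

import torch
from omegaconf import OmegaConf

from siclib.models.encoders.mscan import MSCAN

encoder = MSCAN(OmegaConf.create({}))  # default configuration, randomly initialized
feats = encoder({"image": torch.rand(1, 3, 64, 64)})["features"]
print([tuple(f.shape) for f in feats])
# expected: (1, 64, 16, 16), (1, 128, 8, 8), (1, 320, 4, 4), (1, 512, 2, 2)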
index 0000000000000000000000000000000000000000..7c8a6521417f09e536788eadf3baf037046a2d44 --- /dev/null +++ b/siclib/models/encoders/resnet.py @@ -0,0 +1,83 @@ +"""Basic ResNet encoder for image feature extraction. + +https://pytorch.org/hub/pytorch_vision_resnet/ +""" + +import torch +import torch.nn as nn +import torchvision +from torchvision.models.feature_extraction import create_feature_extractor + +from siclib.models.base_model import BaseModel + +# mypy: ignore-errors + + +def remove_conv_stride(conv): + """Remove the stride from a convolutional layer.""" + conv_new = nn.Conv2d( + conv.in_channels, + conv.out_channels, + conv.kernel_size, + bias=conv.bias is not None, + stride=1, + padding=conv.padding, + ) + conv_new.weight = conv.weight + conv_new.bias = conv.bias + return conv_new + + +class ResNet(BaseModel): + """ResNet encoder for image features extraction.""" + + default_conf = { + "encoder": "resnet18", + "pretrained": True, + "input_dim": 3, + "remove_stride_from_first_conv": True, + "num_downsample": None, # how many downsample bloc + "pixel_mean": [0.485, 0.456, 0.406], + "pixel_std": [0.229, 0.224, 0.225], + } + + required_data_keys = ["image"] + + def build_encoder(self, conf): + """Build the encoder from the configuration.""" + if conf.pretrained: + assert conf.input_dim == 3 + + Encoder = getattr(torchvision.models, conf.encoder) + + layers = ["layer1", "layer2", "layer3", "layer4"] + kw = {"replace_stride_with_dilation": [False, False, False]} + + if conf.num_downsample is not None: + layers = layers[: conf.num_downsample] + + encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw) + encoder = create_feature_extractor(encoder, return_nodes=layers) + + if conf.remove_stride_from_first_conv: + encoder.conv1 = remove_conv_stride(encoder.conv1) + + return encoder, layers + + def _init(self, conf): + self.register_buffer("pixel_mean", torch.tensor(conf.pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.tensor(conf.pixel_std).view(-1, 1, 1), False) + + self.encoder, self.layers = self.build_encoder(conf) + + def _forward(self, data): + image = data["image"] + image = (image - self.pixel_mean) / self.pixel_std + skip_features = list(self.encoder(image).values()) + + # print(f"skip_features: {[f.shape for f in skip_features]}") + return {"features": skip_features} + + def loss(self, pred, data): + """Compute the loss.""" + raise NotImplementedError diff --git a/siclib/models/encoders/vgg.py b/siclib/models/encoders/vgg.py new file mode 100644 index 0000000000000000000000000000000000000000..e8c62323dca054896aaeecd798afa335b32df3fa --- /dev/null +++ b/siclib/models/encoders/vgg.py @@ -0,0 +1,75 @@ +"""Simple VGG encoder for image features extraction.""" + +import torch +import torchvision +from torchvision.models.feature_extraction import create_feature_extractor + +from siclib.models.base_model import BaseModel + +# mypy: ignore-errors + + +class VGG(BaseModel): + """VGG encoder for image features extraction.""" + + default_conf = { + "encoder": "vgg13", + "pretrained": True, + "input_dim": 3, + "num_downsample": None, # how many downsample blocs to use + "pixel_mean": [0.485, 0.456, 0.406], + "pixel_std": [0.229, 0.224, 0.225], + } + + required_data_keys = ["image"] + + def build_encoder(self, conf): + """Build the encoder from the configuration.""" + if conf.pretrained: + assert conf.input_dim == 3 + + Encoder = getattr(torchvision.models, conf.encoder) + + kw = {} + if conf.encoder == "vgg13": + layers = [ + "features.3", + 
"features.8", + "features.13", + "features.18", + "features.23", + ] + elif conf.encoder == "vgg16": + layers = [ + "features.3", + "features.8", + "features.15", + "features.22", + "features.29", + ] + else: + raise NotImplementedError(f"Encoder not implemented: {conf.encoder}") + + if conf.num_downsample is not None: + layers = layers[: conf.num_downsample] + + encoder = Encoder(weights="DEFAULT" if conf.pretrained else None, **kw) + encoder = create_feature_extractor(encoder, return_nodes=layers) + + return encoder, layers + + def _init(self, conf): + self.register_buffer("pixel_mean", torch.tensor(conf.pixel_mean).view(-1, 1, 1), False) + self.register_buffer("pixel_std", torch.tensor(conf.pixel_std).view(-1, 1, 1), False) + + self.encoder, self.layers = self.build_encoder(conf) + + def _forward(self, data): + image = data["image"] + image = (image - self.pixel_mean) / self.pixel_std + skip_features = self.encoder(image).values() + return {"features": skip_features} + + def loss(self, pred, data): + """Compute the loss.""" + raise NotImplementedError diff --git a/siclib/models/extractor.py b/siclib/models/extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..4096666163a097504f78df505972c048ac523192 --- /dev/null +++ b/siclib/models/extractor.py @@ -0,0 +1,126 @@ +"""Simple interface for GeoCalib model.""" + +from pathlib import Path +from typing import Dict, Optional + +import torch +import torch.nn as nn +from torch.nn.functional import interpolate + +from siclib.geometry.base_camera import BaseCamera +from siclib.models.networks.geocalib import GeoCalib as Model +from siclib.utils.image import ImagePreprocessor, load_image + + +class GeoCalib(nn.Module): + """Simple interface for GeoCalib model.""" + + def __init__(self, weights: str = "pinhole"): + """Initialize the model with optional config overrides. + + Args: + weights (str, optional): Weights to load. Defaults to "pinhole". 
+ """ + super().__init__() + if weights == "pinhole": + url = "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-pinhole.tar" + elif weights == "distorted": + url = ( + "https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-simple_radial.tar" + ) + else: + raise ValueError(f"Unknown weights: {weights}") + + # load checkpoint + model_dir = f"{torch.hub.get_dir()}/geocalib" + state_dict = torch.hub.load_state_dict_from_url( + url, model_dir, map_location="cpu", file_name=f"{weights}.tar" + ) + + self.model = Model({}) + self.model.flexible_load(state_dict["model"]) + self.model.eval() + + self.image_processor = ImagePreprocessor({"resize": 320, "edge_divisible_by": 32}) + + def load_image(self, path: Path) -> torch.Tensor: + """Load image from path.""" + return load_image(path) + + def _post_process( + self, camera: BaseCamera, img_data: dict[str, torch.Tensor], out: dict[str, torch.Tensor] + ) -> tuple[BaseCamera, dict[str, torch.Tensor]]: + """Post-process model output by undoing scaling and cropping.""" + camera = camera.undo_scale_crop(img_data) + + w, h = camera.size.unbind(-1) + h = h[0].round().int().item() + w = w[0].round().int().item() + + for k in ["latitude_field", "up_field"]: + out[k] = interpolate(out[k], size=(h, w), mode="bilinear") + for k in ["up_confidence", "latitude_confidence"]: + out[k] = interpolate(out[k][:, None], size=(h, w), mode="bilinear")[:, 0] + + inverse_scales = 1.0 / img_data["scales"] + zero = camera.new_zeros(camera.f.shape[0]) + out["focal_uncertainty"] = out.get("focal_uncertainty", zero) * inverse_scales[1] + return camera, out + + @torch.no_grad() + def calibrate( + self, + img: torch.Tensor, + camera_model: str = "pinhole", + priors: Optional[Dict[str, torch.Tensor]] = None, + shared_intrinsics: bool = False, + ) -> Dict[str, torch.Tensor]: + """Perform calibration with online resizing. + + Assumes input image is in range [0, 1] and in RGB format. + + Args: + img (torch.Tensor): Input image, shape (C, H, W) or (1, C, H, W) + camera_model (str, optional): Camera model. Defaults to "pinhole". + priors (Dict[str, torch.Tensor], optional): Prior parameters. Defaults to {}. + shared_intrinsics (bool, optional): Whether to share intrinsics. Defaults to False. + + Returns: + Dict[str, torch.Tensor]: camera and gravity vectors and uncertainties. 
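+
+        Example:
+            A minimal usage sketch (the image path is a placeholder):
+
+            >>> calib = GeoCalib(weights="pinhole")
+            >>> img = calib.load_image("path/to/image.jpg")
+            >>> result = calib.calibrate(img)
+            >>> camera, gravity = result["camera"], result["gravity"]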
+ """ + if len(img.shape) == 3: + img = img[None] # add batch dim + if not shared_intrinsics: + assert len(img.shape) == 4 and img.shape[0] == 1 + + img_data = self.image_processor(img) + + if priors is None: + priors = {} + + prior_values = {} + if prior_focal := priors.get("focal"): + prior_focal = prior_focal[None] if len(prior_focal.shape) == 0 else prior_focal + prior_values["prior_focal"] = prior_focal * img_data["scales"][1] + + if "gravity" in priors: + prior_gravity = priors["gravity"] + prior_gravity = prior_gravity[None] if len(prior_gravity.shape) == 0 else prior_gravity + prior_values["prior_gravity"] = prior_gravity + + self.model.optimizer.set_camera_model(camera_model) + self.model.optimizer.shared_intrinsics = shared_intrinsics + + out = self.model(img_data | prior_values) + + camera, gravity = out["camera"], out["gravity"] + camera, out = self._post_process(camera, img_data, out) + + return { + "camera": camera, + "gravity": gravity, + "covariance": out["covariance"], + **{k: out[k] for k in out.keys() if "field" in k}, + **{k: out[k] for k in out.keys() if "confidence" in k}, + **{k: out[k] for k in out.keys() if "uncertainty" in k}, + } diff --git a/siclib/models/networks/__init__.py b/siclib/models/networks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/siclib/models/networks/deepcalib.py b/siclib/models/networks/deepcalib.py new file mode 100644 index 0000000000000000000000000000000000000000..a05d00e09bbcf8f0dbeec5ac0fcace781232eea3 --- /dev/null +++ b/siclib/models/networks/deepcalib.py @@ -0,0 +1,299 @@ +import logging +from copy import deepcopy +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +from torch.nn import Identity + +from siclib.geometry.camera import SimpleRadial +from siclib.geometry.gravity import Gravity +from siclib.models.base_model import BaseModel +from siclib.models.utils.metrics import dist_error, pitch_error, roll_error, vfov_error +from siclib.models.utils.modules import _DenseBlock, _Transition +from siclib.utils.conversions import deg2rad, pitch2rho, rho2pitch + +logger = logging.getLogger(__name__) + +# flake8: noqa +# mypy: ignore-errors + + +def get_centers_and_edges(min: float, max: float, num_bins: int) -> Tuple[np.ndarray, torch.Tensor]: + centers = torch.linspace(min, max + ((max - min) / (num_bins - 1)), num_bins + 1).float() + edges = centers.detach() - ((centers.detach()[1] - centers[0]) / 2.0) + return centers, edges + + +class DeepCalib(BaseModel): + default_conf = { + "name": "densenet", + "model": "densenet161", + "loss": "NLL", + "num_bins": 256, + "freeze_batch_normalization": False, + "model": "densenet161", + "pretrained": True, # whether to use ImageNet weights + "heads": ["roll", "rho", "vfov", "k1_hat"], + "flip": [], # keys of predictions to flip the sign of + "rpf_scales": [1, 1, 1], + "bounds": { + "roll": [-45, 45], + "rho": [-1, 1], + "vfov": [20, 105], + "k1_hat": [-0.7, 0.7], + }, + "use_softamax": False, + } + + mean = [0.485, 0.456, 0.406] + std = [0.229, 0.224, 0.225] + + strict_conf = False + + required_data_keys = ["image", "image_size"] + + def _init(self, conf): + self.is_classification = True if self.conf.loss in ["NLL"] else False + + self.num_bins = conf.num_bins + + self.roll_centers, self.roll_edges = get_centers_and_edges( + deg2rad(conf.bounds.roll[0]), deg2rad(conf.bounds.roll[1]), self.num_bins + ) + + self.rho_centers, 
self.rho_edges = get_centers_and_edges( + conf.bounds.rho[0], conf.bounds.rho[1], self.num_bins + ) + + self.fov_centers, self.fov_edges = get_centers_and_edges( + deg2rad(conf.bounds.vfov[0]), deg2rad(conf.bounds.vfov[1]), self.num_bins + ) + + self.k1_hat_centers, self.k1_hat_edges = get_centers_and_edges( + conf.bounds.k1_hat[0], conf.bounds.k1_hat[1], self.num_bins + ) + + Model = getattr(torchvision.models, conf.model) + weights = "DEFAULT" if self.conf.pretrained else None + self.model = Model(weights=weights) + + layers = [] + + # 2208 for 161 layers. 1024 for 121 + num_features = self.model.classifier.in_features + head_layers = 3 + layers.append(_Transition(num_features, num_features // 2)) + num_features = num_features // 2 + growth_rate = 32 + layers.append( + _DenseBlock( + num_layers=head_layers, + num_input_features=num_features, + growth_rate=growth_rate, + bn_size=4, + drop_rate=0, + ) + ) + layers.append(nn.BatchNorm2d(num_features + head_layers * growth_rate)) + layers.append(nn.ReLU()) + layers.append(nn.AdaptiveAvgPool2d((1, 1))) + layers.append(nn.Flatten()) + layers.append(nn.Linear(num_features + head_layers * growth_rate, 512)) + layers.append(nn.ReLU()) + self.model.classifier = Identity() + self.model.features.norm5 = Identity() + + if self.is_classification: + layers.append(nn.Linear(512, self.num_bins)) + layers.append(nn.LogSoftmax(dim=1)) + else: + layers.append(nn.Linear(512, 1)) + layers.append(nn.Tanh()) + + self.roll_head = nn.Sequential(*deepcopy(layers)) + self.rho_head = nn.Sequential(*deepcopy(layers)) + self.vfov_head = nn.Sequential(*deepcopy(layers)) + self.k1_hat_head = nn.Sequential(*deepcopy(layers)) + + def bins_to_val(self, centers, pred): + if centers.device != pred.device: + centers = centers.to(pred.device) + + if not self.conf.use_softamax: + return centers[pred.argmax(1)] + + beta = 1e-0 + pred_softmax = F.softmax(pred / beta, dim=1) + weighted_centers = centers[:-1].unsqueeze(0) * pred_softmax + val = weighted_centers.sum(dim=1) + return val + + def _forward(self, data): + image = data["image"] + mean, std = image.new_tensor(self.mean), image.new_tensor(self.std) + image = (image - mean[:, None, None]) / std[:, None, None] + shared_features = self.model.features(image) + pred = {} + + if "roll" in self.conf.heads: + pred["roll"] = self.roll_head(shared_features) + if "rho" in self.conf.heads: + pred["rho"] = self.rho_head(shared_features) + if "vfov" in self.conf.heads: + pred["vfov"] = self.vfov_head(shared_features) + if "vfov" in self.conf.flip: + pred["vfov"] = pred["vfov"] * -1 + if "k1_hat" in self.conf.heads: + pred["k1_hat"] = self.k1_hat_head(shared_features) + + size = data["image_size"] + w, h = size[:, 0], size[:, 1] + + if self.is_classification: + parameters = { + "roll": self.bins_to_val(self.roll_centers, pred["roll"]), + "rho": self.bins_to_val(self.rho_centers, pred["rho"]), + "vfov": self.bins_to_val(self.fov_centers, pred["vfov"]), + "k1_hat": self.bins_to_val(self.k1_hat_centers, pred["k1_hat"]), + "width": w, + "height": h, + } + + for k in self.conf.flip: + parameters[k] = parameters[k] * -1 + + for i, k in enumerate(["roll", "rho", "vfov"]): + parameters[k] = parameters[k] * self.conf.rpf_scales[i] + + camera = SimpleRadial.from_dict(parameters) + + roll, pitch = parameters["roll"], rho2pitch(parameters["rho"], camera.f[..., 1], h) + gravity = Gravity.from_rp(roll, pitch) + + else: # regression + if "roll" in self.conf.heads: + pred["roll"] = pred["roll"] * deg2rad(45) + if "vfov" in self.conf.heads: + pred["vfov"] 
= (pred["vfov"] + 1) * deg2rad((105 - 20) / 2 + 20) + + camera = SimpleRadial.from_dict(pred | {"width": w, "height": h}) + gravity = Gravity.from_rp(pred["roll"], pred["pitch"]) + + return pred | {"camera": camera, "gravity": gravity} + + def loss(self, pred, data): + loss = {"total": 0} + if self.conf.loss == "Huber": + loss_fn = nn.HuberLoss(reduction="none") + elif self.conf.loss == "L1": + loss_fn = nn.L1Loss(reduction="none") + elif self.conf.loss == "L2": + loss_fn = nn.MSELoss(reduction="none") + elif self.conf.loss == "NLL": + loss_fn = nn.NLLLoss(reduction="none") + + gt_cam = data["camera"] + + if "roll" in self.conf.heads: + # nbins softmax values if classification, else scalar value + gt_roll = data["gravity"].roll.float() + pred_roll = pred["roll"].float() + + if gt_roll.device != self.roll_edges.device: + self.roll_edges = self.roll_edges.to(gt_roll.device) + self.roll_centers = self.roll_centers.to(gt_roll.device) + + if self.is_classification: + gt_roll = ( + torch.bucketize(gt_roll.contiguous(), self.roll_edges) - 1 + ) # converted to class + + assert (gt_roll >= 0).all(), gt_roll + assert (gt_roll < self.num_bins).all(), gt_roll + else: + assert pred_roll.dim() == gt_roll.dim() + + loss_roll = loss_fn(pred_roll, gt_roll) + loss["roll"] = loss_roll + loss["total"] += loss_roll + + if "rho" in self.conf.heads: + gt_rho = pitch2rho(data["gravity"].pitch, gt_cam.f[..., 1], gt_cam.size[..., 1]).float() + pred_rho = pred["rho"].float() + + if gt_rho.device != self.rho_edges.device: + self.rho_edges = self.rho_edges.to(gt_rho.device) + self.rho_centers = self.rho_centers.to(gt_rho.device) + + if self.is_classification: + gt_rho = torch.bucketize(gt_rho.contiguous(), self.rho_edges) - 1 + + assert (gt_rho >= 0).all(), gt_rho + assert (gt_rho < self.num_bins).all(), gt_rho + else: + assert pred_rho.dim() == gt_rho.dim() + + # print(f"Rho: {gt_rho.shape}, {pred_rho.shape}") + loss_rho = loss_fn(pred_rho, gt_rho) + loss["rho"] = loss_rho + loss["total"] += loss_rho + + if "vfov" in self.conf.heads: + gt_vfov = gt_cam.vfov.float() + pred_vfov = pred["vfov"].float() + + if gt_vfov.device != self.fov_edges.device: + self.fov_edges = self.fov_edges.to(gt_vfov.device) + self.fov_centers = self.fov_centers.to(gt_vfov.device) + + if self.is_classification: + gt_vfov = torch.bucketize(gt_vfov.contiguous(), self.fov_edges) - 1 + + assert (gt_vfov >= 0).all(), gt_vfov + assert (gt_vfov < self.num_bins).all(), gt_vfov + else: + min_vfov = deg2rad(self.conf.bounds.vfov[0]) + max_vfov = deg2rad(self.conf.bounds.vfov[1]) + gt_vfov = (2 * (gt_vfov - min_vfov) / (max_vfov - min_vfov)) - 1 + assert pred_vfov.dim() == gt_vfov.dim() + + loss_vfov = loss_fn(pred_vfov, gt_vfov) + loss["vfov"] = loss_vfov + loss["total"] += loss_vfov + + if "k1_hat" in self.conf.heads: + gt_k1_hat = data["camera"].k1_hat.float() + pred_k1_hat = pred["k1_hat"].float() + + if gt_k1_hat.device != self.k1_hat_edges.device: + self.k1_hat_edges = self.k1_hat_edges.to(gt_k1_hat.device) + self.k1_hat_centers = self.k1_hat_centers.to(gt_k1_hat.device) + + if self.is_classification: + gt_k1_hat = torch.bucketize(gt_k1_hat.contiguous(), self.k1_hat_edges) - 1 + + assert (gt_k1_hat >= 0).all(), gt_k1_hat + assert (gt_k1_hat < self.num_bins).all(), gt_k1_hat + else: + assert pred_k1_hat.dim() == gt_k1_hat.dim() + + loss_k1_hat = loss_fn(pred_k1_hat, gt_k1_hat) + loss["k1_hat"] = loss_k1_hat + loss["total"] += loss_k1_hat + + return loss, self.metrics(pred, data) + + def metrics(self, pred, data): + pred_cam, gt_cam = pred["camera"], 
data["camera"] + pred_gravity, gt_gravity = pred["gravity"], data["gravity"] + + return { + "roll_error": roll_error(pred_gravity, gt_gravity), + "pitch_error": pitch_error(pred_gravity, gt_gravity), + "vfov_error": vfov_error(pred_cam, gt_cam), + "k1_error": dist_error(pred_cam, gt_cam), + } diff --git a/siclib/models/networks/dust3r.py b/siclib/models/networks/dust3r.py new file mode 100644 index 0000000000000000000000000000000000000000..2d00ac2f898eec450ddcb20f1fd34e51171dd422 --- /dev/null +++ b/siclib/models/networks/dust3r.py @@ -0,0 +1,81 @@ +"""Wrapper for DUSt3R model to estimate focal length. + +DUSt3R: Geometric 3D Vision Made Easy, https://arxiv.org/abs/2312.14132 +""" + +import sys + +sys.path.append("third_party/dust3r") + +import torch +from dust3r.cloud_opt import GlobalAlignerMode, global_aligner +from dust3r.image_pairs import make_pairs +from dust3r.inference import inference, load_model +from dust3r.utils.image import load_images + +from siclib.geometry.base_camera import BaseCamera +from siclib.geometry.gravity import Gravity +from siclib.models import BaseModel + +# mypy: ignore-errors + + +class Dust3R(BaseModel): + """DUSt3R model for focal length estimation.""" + + default_conf = { + "model_path": "weights/DUSt3R_ViTLarge_BaseDecoder_512_dpt.pth", + "device": "cuda", + "batch_size": 1, + "schedule": "cosine", + "lr": 0.01, + "niter": 300, + "show_scene": False, + } + + required_data_keys = ["path"] + + def _init(self, conf): + """Initialize the DUSt3R model.""" + self.model = load_model(conf["model_path"], conf["device"]) + + def _forward(self, data): + """Forward pass of the DUSt.""" + assert len(data["path"]) == 1, f"Only batch size of 1 is supported (bs={len(data['path'])}" + + path = data["path"][0] + images = [path] * 2 + + with torch.enable_grad(): + images = load_images(images, size=512) + pairs = make_pairs(images, scene_graph="complete", prefilter=None, symmetrize=True) + output = inference( + pairs, self.model, self.conf["device"], batch_size=self.conf["batch_size"] + ) + scene = global_aligner( + output, device=self.conf["device"], mode=GlobalAlignerMode.PointCloudOptimizer + ) + _ = scene.compute_global_alignment( + init="mst", + niter=self.conf["niter"], + schedule=self.conf["schedule"], + lr=self.conf["lr"], + ) + + # retrieve useful values from scene: + focals = scene.get_focals().mean(dim=0) + + h, w = images[0]["true_shape"][:, 0], images[0]["true_shape"][:, 1] + h, w = focals.new_tensor(h), focals.new_tensor(w) + + camera = BaseCamera.from_dict({"height": h, "width": w, "f": focals}) + gravity = Gravity.from_rp([0.0], [0.0]) + + if self.conf["show_scene"]: + scene.show() + + return {"camera": camera, "gravity": gravity} + + def loss(self, pred, data): + """Loss function for DUSt3R model.""" + return {}, {} diff --git a/siclib/models/networks/geocalib.py b/siclib/models/networks/geocalib.py new file mode 100644 index 0000000000000000000000000000000000000000..23af7f288c91eeea6716fb32e14609716cc3467e --- /dev/null +++ b/siclib/models/networks/geocalib.py @@ -0,0 +1,66 @@ +import logging + +from siclib.models import get_model +from siclib.models.base_model import BaseModel + +logger = logging.getLogger(__name__) + +# flake8: noqa +# mypy: ignore-errors + + +class GeoCalib(BaseModel): + default_conf = { + "backbone": {"name": "encoders.mscan"}, + "ll_enc": {"name": "encoders.low_level_encoder"}, + "perspective_decoder": {"name": "decoders.perspective_decoder"}, + "optimizer": {"name": "optimization.lm_optimizer"}, + } + + required_data_keys = ["image"] 
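+
+    # Data flow (sketch): the MSCAN backbone extracts high-level features, the
+    # low-level encoder adds fine-grained ones, the perspective decoder predicts
+    # up/latitude fields with confidences, and the LM optimizer fits camera and
+    # gravity to them. Illustrative use with the default sub-modules (assumes a
+    # minimal data dict with only the image is sufficient):
+    #     model = GeoCalib({})
+    #     out = model({"image": img})  # img: (B, 3, H, W), RGB in [0, 1]
+    #     camera, gravity = out["camera"], out["gravity"]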
+ + def _init(self, conf): + logger.debug(f"Initializing GeoCalib with {conf}") + self.backbone = get_model(conf.backbone["name"])(conf.backbone) + self.ll_enc = get_model(conf.ll_enc["name"])(conf.ll_enc) if conf.ll_enc else None + + self.perspective_decoder = get_model(conf.perspective_decoder["name"])( + conf.perspective_decoder + ) + + self.optimizer = ( + get_model(conf.optimizer["name"])(conf.optimizer) if conf.optimizer else None + ) + + def _forward(self, data): + backbone_out = self.backbone(data) + features = {"hl": backbone_out["features"], "padding": backbone_out.get("padding", None)} + + if self.ll_enc is not None: + features["ll"] = self.ll_enc(data)["features"] # low level features + + out = self.perspective_decoder({"features": features}) + + out |= { + k: data[k] + for k in ["image", "scales", "prior_gravity", "prior_focal", "prior_k1"] + if k in data + } + + if self.optimizer is not None: + out |= self.optimizer(out) + + return out + + def loss(self, pred, data): + losses, metrics = self.perspective_decoder.loss(pred, data) + total = losses["perspective_total"] + + if self.optimizer is not None: + opt_losses, param_metrics = self.optimizer.loss(pred, data) + losses |= opt_losses + metrics |= param_metrics + total = total + opt_losses["param_total"] + + losses["total"] = total + return losses, metrics diff --git a/siclib/models/networks/geocalib_pretrained.py b/siclib/models/networks/geocalib_pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..4dde770f18bb5bb1a7c4a4ede06fbb9c907e7e50 --- /dev/null +++ b/siclib/models/networks/geocalib_pretrained.py @@ -0,0 +1,41 @@ +"""Interface for GeoCalib inference package.""" + +from geocalib import GeoCalib +from siclib.models.base_model import BaseModel + + +# mypy: ignore-errors +class GeoCalibPretrained(BaseModel): + """GeoCalib pretrained model.""" + + default_conf = { + "camera_model": "pinhole", + "model_weights": "pinhole", + } + + def _init(self, conf): + """Initialize pretrained GeoCalib model.""" + self.model = GeoCalib(weights=conf.model_weights) + + def _forward(self, data): + """Forward pass.""" + priors = {} + if "prior_gravity" in data: + priors["gravity"] = data["prior_gravity"] + + if "prior_focal" in data: + priors["focal"] = data["prior_focal"] + + results = self.model.calibrate( + data["image"], camera_model=self.conf.camera_model, priors=priors + ) + + return results + + def metrics(self, pred, data): + """Compute metrics.""" + raise NotImplementedError("GeoCalibPretrained does not support metrics computation.") + + def loss(self, pred, data): + """Compute loss.""" + raise NotImplementedError("GeoCalibPretrained does not support loss computation.") diff --git a/siclib/models/optimization/__init__.py b/siclib/models/optimization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/siclib/models/optimization/inference_optimizer.py b/siclib/models/optimization/inference_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..b5666fc9d74bb7c6890bfd20ef36f13c4a47580b --- /dev/null +++ b/siclib/models/optimization/inference_optimizer.py @@ -0,0 +1,36 @@ +import logging + +from siclib.models.optimization.lm_optimizer import LMOptimizer + +logger = logging.getLogger(__name__) + +# flake8: noqa +# mypy: ignore-errors + + +class InferenceOptimizer(LMOptimizer): + default_conf = { + # Camera model parameters + "camera_model": "pinhole", # {"pinhole", "simple_radial", "simple_spherical"} + 
"shared_intrinsics": False, # share focal length across all images in batch + "estimate_gravity": True, + "estimate_focal": True, + "estimate_k1": True, # will be ignored if camera_model is pinhole + # LM optimizer parameters + "num_steps": 30, + "lambda_": 0.1, + "fix_lambda": False, + "early_stop": True, + "atol": 1e-8, + "rtol": 1e-8, + "use_spherical_manifold": True, # use spherical manifold for gravity optimization + "use_log_focal": True, # use log focal length for optimization + # Loss function parameters + "loss_fn": "huber_loss", # {"squared_loss", "huber_loss"} + "up_loss_fn_scale": 1e-2, + "lat_loss_fn_scale": 1e-2, + "init_conf": {"name": "trivial"}, # pass config of other models to use as initializer + # Misc + "loss_weight": 1, + "verbose": False, + } diff --git a/siclib/models/optimization/lm_optimizer.py b/siclib/models/optimization/lm_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..50b899aa89c18c5d149987dad038a21a1db05748 --- /dev/null +++ b/siclib/models/optimization/lm_optimizer.py @@ -0,0 +1,603 @@ +import logging +import time +from typing import Dict, Tuple + +import torch +from torch import nn + +import siclib.models.optimization.losses as losses +from siclib.geometry.base_camera import BaseCamera +from siclib.geometry.camera import camera_models +from siclib.geometry.gravity import Gravity +from siclib.geometry.jacobians import J_focal2fov +from siclib.geometry.perspective_fields import J_perspective_field, get_perspective_field +from siclib.models import get_model +from siclib.models.base_model import BaseModel +from siclib.models.optimization.utils import ( + early_stop, + get_initial_estimation, + optimizer_step, + update_lambda, +) +from siclib.models.utils.metrics import ( + dist_error, + gravity_error, + pitch_error, + roll_error, + vfov_error, +) +from siclib.utils.conversions import rad2deg + +logger = logging.getLogger(__name__) + +# flake8: noqa +# mypy: ignore-errors + + +class LMOptimizer(BaseModel): + default_conf = { + # Camera model parameters + "camera_model": "pinhole", # {"pinhole", "simple_radial", "simple_spherical"} + "shared_intrinsics": False, # share focal length across all images in batch + # LM optimizer parameters + "num_steps": 10, + "lambda_": 0.1, + "fix_lambda": False, + "early_stop": False, + "atol": 1e-8, + "rtol": 1e-8, + "use_spherical_manifold": True, # use spherical manifold for gravity optimization + "use_log_focal": True, # use log focal length for optimization + # Loss function parameters + "loss_fn": "squared_loss", # {"squared_loss", "huber_loss"} + "up_loss_fn_scale": 1e-2, + "lat_loss_fn_scale": 1e-2, + "init_conf": {"name": "trivial"}, # pass config of other models to use as initializer + # Misc + "loss_weight": 1, + "verbose": False, + } + + def _init(self, conf): + self.loss_fn = getattr(losses, conf.loss_fn) + self.num_steps = conf.num_steps + + self.set_camera_model(conf.camera_model) + + self.setup_optimization_and_priors(shared_intrinsics=conf.shared_intrinsics) + + self.initializer = None + if self.conf.init_conf.name not in ["trivial", "heuristic"]: + self.initializer = get_model(conf.init_conf.name)(conf.init_conf) + + def set_camera_model(self, camera_model: str) -> None: + """Set the camera model to use for the optimization. + + Args: + camera_model (str): Camera model to use. 
+ """ + assert ( + camera_model in camera_models.keys() + ), f"Unknown camera model: {camera_model} not in {camera_models.keys()}" + self.camera_model = camera_models[camera_model] + self.camera_has_distortion = hasattr(self.camera_model, "dist") + + logger.debug( + f"Using camera model: {camera_model} (with distortion: {self.camera_has_distortion})" + ) + + def setup_optimization_and_priors( + self, data: Dict[str, torch.Tensor] = None, shared_intrinsics: bool = False + ) -> None: + """Setup the optimization and priors for the LM optimizer. + + Args: + data (Dict[str, torch.Tensor], optional): Dict potentially containing priors. Defaults + to None. + shared_intrinsics (bool, optional): Whether to share the intrinsics across the batch. + Defaults to False. + """ + if data is None: + data = {} + self.shared_intrinsics = shared_intrinsics + + if shared_intrinsics: # si => must use pinhole + assert ( + self.camera_model == camera_models["pinhole"] + ), f"Shared intrinsics only supported with pinhole camera model: {self.camera_model}" + + self.estimate_gravity = True + if "prior_gravity" in data: + self.estimate_gravity = False + logger.debug("Using provided gravity as prior.") + + self.estimate_focal = True + if "prior_focal" in data: + self.estimate_focal = False + logger.debug("Using provided focal as prior.") + + self.estimate_k1 = True + if "prior_k1" in data: + self.estimate_k1 = False + logger.debug("Using provided k1 as prior.") + + self.gravity_delta_dims = (0, 1) if self.estimate_gravity else (-1,) + self.focal_delta_dims = ( + (max(self.gravity_delta_dims) + 1,) if self.estimate_focal else (-1,) + ) + self.k1_delta_dims = (max(self.focal_delta_dims) + 1,) if self.estimate_k1 else (-1,) + + logger.debug(f"Camera Model: {self.camera_model}") + logger.debug(f"Optimizing gravity: {self.estimate_gravity} ({self.gravity_delta_dims})") + logger.debug(f"Optimizing focal: {self.estimate_focal} ({self.focal_delta_dims})") + logger.debug(f"Optimizing k1: {self.estimate_k1} ({self.k1_delta_dims})") + + logger.debug(f"Shared intrinsics: {self.shared_intrinsics}") + + def calculate_residuals( + self, camera: BaseCamera, gravity: Gravity, data: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + """Calculate the residuals for the optimization. + + Args: + camera (BaseCamera): Optimized camera. + gravity (Gravity): Optimized gravity. + data (Dict[str, torch.Tensor]): Input data containing the up and latitude fields. + + Returns: + Dict[str, torch.Tensor]: Residuals for the optimization. + """ + perspective_up, perspective_lat = get_perspective_field(camera, gravity) + perspective_lat = torch.sin(perspective_lat) + + residuals = {} + if "up_field" in data: + up_residual = (data["up_field"] - perspective_up).permute(0, 2, 3, 1) + residuals["up_residual"] = up_residual.reshape(up_residual.shape[0], -1, 2) + + if "latitude_field" in data: + target_lat = torch.sin(data["latitude_field"]) + lat_residual = (target_lat - perspective_lat).permute(0, 2, 3, 1) + residuals["latitude_residual"] = lat_residual.reshape(lat_residual.shape[0], -1, 1) + + return residuals + + def calculate_costs( + self, residuals: torch.Tensor, data: Dict[str, torch.Tensor] + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + """Calculate the costs and weights for the optimization. + + Args: + residuals (torch.Tensor): Residuals for the optimization. + data (Dict[str, torch.Tensor]): Input data containing the up and latitude confidence. 
+ + Returns: + Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: Costs and weights for the + optimization. + """ + costs, weights = {}, {} + + if "up_residual" in residuals: + up_cost = (residuals["up_residual"] ** 2).sum(dim=-1) + up_cost, up_weight, _ = losses.scaled_loss( + up_cost, self.loss_fn, self.conf.up_loss_fn_scale + ) + + if "up_confidence" in data: + up_conf = data["up_confidence"].reshape(up_weight.shape[0], -1) + up_weight = up_weight * up_conf + up_cost = up_cost * up_conf + + costs["up_cost"] = up_cost + weights["up_weights"] = up_weight + + if "latitude_residual" in residuals: + lat_cost = (residuals["latitude_residual"] ** 2).sum(dim=-1) + lat_cost, lat_weight, _ = losses.scaled_loss( + lat_cost, self.loss_fn, self.conf.lat_loss_fn_scale + ) + + if "latitude_confidence" in data: + lat_conf = data["latitude_confidence"].reshape(lat_weight.shape[0], -1) + lat_weight = lat_weight * lat_conf + lat_cost = lat_cost * lat_conf + + costs["latitude_cost"] = lat_cost + weights["latitude_weights"] = lat_weight + + return costs, weights + + def calculate_gradient_and_hessian( + self, + J: torch.Tensor, + residuals: torch.Tensor, + weights: torch.Tensor, + shared_intrinsics: bool, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculate the gradient and Hessian for given the Jacobian, residuals, and weights. + + Args: + J (torch.Tensor): Jacobian. + residuals (torch.Tensor): Residuals. + weights (torch.Tensor): Weights. + shared_intrinsics (bool): Whether to share the intrinsics across the batch. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Gradient and Hessian. + """ + dims = () + if self.estimate_gravity: + dims = (0, 1) + if self.estimate_focal: + dims += (2,) + if self.camera_has_distortion and self.estimate_k1: + dims += (3,) + assert dims, "No parameters to optimize" + + J = J[..., dims] + + Grad = torch.einsum("...Njk,...Nj->...Nk", J, residuals) + Grad = weights[..., None] * Grad + Grad = Grad.sum(-2) # (B, N_params) + + if shared_intrinsics: + # reshape to (1, B * (N_params-1) + 1) + Grad_g = Grad[..., :2].reshape(1, -1) + Grad_f = Grad[..., 2].reshape(1, -1).sum(-1, keepdim=True) + Grad = torch.cat([Grad_g, Grad_f], dim=-1) + + Hess = torch.einsum("...Njk,...Njl->...Nkl", J, J) + Hess = weights[..., None, None] * Hess + Hess = Hess.sum(-3) + + if shared_intrinsics: + H_g = torch.block_diag(*list(Hess[..., :2, :2])) + J_fg = Hess[..., :2, 2].flatten() + J_gf = Hess[..., 2, :2].flatten() + J_f = Hess[..., 2, 2].sum() + dims = H_g.shape[-1] + 1 + Hess = Hess.new_zeros((dims, dims), dtype=torch.float32) + Hess[:-1, :-1] = H_g + Hess[-1, :-1] = J_gf + Hess[:-1, -1] = J_fg + Hess[-1, -1] = J_f + Hess = Hess.unsqueeze(0) + + return Grad, Hess + + def setup_system( + self, + camera: BaseCamera, + gravity: Gravity, + residuals: Dict[str, torch.Tensor], + weights: Dict[str, torch.Tensor], + as_rpf: bool = False, + shared_intrinsics: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Calculate the gradient and Hessian for the optimization. + + Args: + camera (BaseCamera): Optimized camera. + gravity (Gravity): Optimized gravity. + residuals (Dict[str, torch.Tensor]): Residuals for the optimization. + weights (Dict[str, torch.Tensor]): Weights for the optimization. + as_rpf (bool, optional): Wether to calculate the gradient and Hessian with respect to + roll, pitch, and focal length. Defaults to False. + shared_intrinsics (bool, optional): Whether to share the intrinsics across the batch. + Defaults to False. 
+ + Returns: + Tuple[torch.Tensor, torch.Tensor]: Gradient and Hessian for the optimization. + """ + J_up, J_lat = J_perspective_field( + camera, + gravity, + spherical=self.conf.use_spherical_manifold and not as_rpf, + log_focal=self.conf.use_log_focal and not as_rpf, + ) + + J_up = J_up.reshape(J_up.shape[0], -1, J_up.shape[-2], J_up.shape[-1]) # (B, N, 2, 3) + J_lat = J_lat.reshape(J_lat.shape[0], -1, J_lat.shape[-2], J_lat.shape[-1]) # (B, N, 1, 3) + + n_params = ( + 2 * self.estimate_gravity + + self.estimate_focal + + (self.camera_has_distortion and self.estimate_k1) + ) + Grad = J_up.new_zeros(J_up.shape[0], n_params) + Hess = J_up.new_zeros(J_up.shape[0], n_params, n_params) + + if shared_intrinsics: + N_params = Grad.shape[0] * (n_params - 1) + 1 + Grad = Grad.new_zeros(1, N_params) + Hess = Hess.new_zeros(1, N_params, N_params) + + if "up_residual" in residuals: + Up_Grad, Up_Hess = self.calculate_gradient_and_hessian( + J_up, residuals["up_residual"], weights["up_weights"], shared_intrinsics + ) + + if self.conf.verbose: + logger.info(f"Up J:\n{Up_Grad.mean(0)}") + + Grad = Grad + Up_Grad + Hess = Hess + Up_Hess + + if "latitude_residual" in residuals: + Lat_Grad, Lat_Hess = self.calculate_gradient_and_hessian( + J_lat, + residuals["latitude_residual"], + weights["latitude_weights"], + shared_intrinsics, + ) + + if self.conf.verbose: + logger.info(f"Lat J:\n{Lat_Grad.mean(0)}") + + Grad = Grad + Lat_Grad + Hess = Hess + Lat_Hess + + return Grad, Hess + + def estimate_uncertainty( + self, + camera_opt: BaseCamera, + gravity_opt: Gravity, + errors: Dict[str, torch.Tensor], + weights: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: + """Estimate the uncertainty of the optimized camera and gravity at the final step. + + Args: + camera_opt (BaseCamera): Final optimized camera. + gravity_opt (Gravity): Final optimized gravity. + errors (Dict[str, torch.Tensor]): Costs for the optimization. + weights (Dict[str, torch.Tensor]): Weights for the optimization. + + Returns: + Dict[str, torch.Tensor]: Uncertainty estimates for the optimized camera and gravity. 
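+
+        Note:
+            The covariance is the inverse of the Hessian assembled at the final
+            estimate (as_rpf=True); roll/pitch uncertainties are read from its
+            gravity block, the focal uncertainty from the focal entry, and the
+            vfov uncertainty is propagated through J_focal2fov.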
+ """ + _, Hess = self.setup_system( + camera_opt, gravity_opt, errors, weights, as_rpf=True, shared_intrinsics=False + ) + Cov = torch.inverse(Hess) + + roll_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape) + pitch_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape) + gravity_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape) + if self.estimate_gravity: + roll_uncertainty = Cov[..., 0, 0] + pitch_uncertainty = Cov[..., 1, 1] + + try: + delta_uncertainty = Cov[..., :2, :2] + eigenvalues = torch.linalg.eigvalsh(delta_uncertainty.cpu()) + gravity_uncertainty = torch.max(eigenvalues, dim=-1).values.to(Cov.device) + except RuntimeError: + logger.warning("Could not calculate gravity uncertainty") + gravity_uncertainty = Cov.new_zeros(Cov.shape[0]) + + focal_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape) + fov_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape) + if self.estimate_focal: + focal_uncertainty = Cov[..., self.focal_delta_dims[0], self.focal_delta_dims[0]] + fov_uncertainty = ( + J_focal2fov(camera_opt.f[..., 1], camera_opt.size[..., 1]) ** 2 * focal_uncertainty + ) + + return { + "covariance": Cov, + "roll_uncertainty": torch.sqrt(roll_uncertainty), + "pitch_uncertainty": torch.sqrt(pitch_uncertainty), + "gravity_uncertainty": torch.sqrt(gravity_uncertainty), + "focal_uncertainty": torch.sqrt(focal_uncertainty) / 2, + "vfov_uncertainty": torch.sqrt(fov_uncertainty / 2), + } + + def update_estimate( + self, camera: BaseCamera, gravity: Gravity, delta: torch.Tensor + ) -> Tuple[BaseCamera, Gravity]: + """Update the camera and gravity estimates with the given delta. + + Args: + camera (BaseCamera): Optimized camera. + gravity (Gravity): Optimized gravity. + delta (torch.Tensor): Delta to update the camera and gravity estimates. + + Returns: + Tuple[BaseCamera, Gravity]: Updated camera and gravity estimates. + """ + delta_gravity = ( + delta[..., self.gravity_delta_dims] + if self.estimate_gravity + else delta.new_zeros(delta.shape[:-1] + (2,)) + ) + new_gravity = gravity.update(delta_gravity, spherical=self.conf.use_spherical_manifold) + + delta_f = ( + delta[..., self.focal_delta_dims] + if self.estimate_focal + else delta.new_zeros(delta.shape[:-1] + (1,)) + ) + new_camera = camera.update_focal(delta_f, as_log=self.conf.use_log_focal) + + delta_dist = ( + delta[..., self.k1_delta_dims] + if self.camera_has_distortion and self.estimate_k1 + else delta.new_zeros(delta.shape[:-1] + (1,)) + ) + if self.camera_has_distortion: + new_camera = new_camera.update_dist(delta_dist) + + return new_camera, new_gravity + + def optimize( + self, + data: Dict[str, torch.Tensor], + camera_opt: BaseCamera, + gravity_opt: Gravity, + ) -> Tuple[BaseCamera, Gravity, Dict[str, torch.Tensor]]: + """Optimize the camera and gravity estimates. + + Args: + data (Dict[str, torch.Tensor]): Input data. + camera_opt (BaseCamera): Optimized camera. + gravity_opt (Gravity): Optimized gravity. + + Returns: + Tuple[BaseCamera, Gravity, Dict[str, torch.Tensor]]: Optimized camera, gravity + estimates and optimization information. 
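+
+        Note:
+            Each step computes the up/latitude residuals and robust costs, builds
+            the damped normal equations (setup_system), applies a step from
+            optimizer_step(Grad, Hess, lamb), and adapts lamb with update_lambda
+            unless fix_lambda is set; the loop can stop early once the cost change
+            falls below atol/rtol.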
+ """ + key = list(data.keys())[0] + B = data[key].shape[0] + + lamb = data[key].new_ones(B) * self.conf.lambda_ + if self.shared_intrinsics: + lamb = data[key].new_ones(1) * self.conf.lambda_ + + infos = {"stop_at": self.num_steps} + for i in range(self.num_steps): + if self.conf.verbose: + logger.info(f"Step {i+1}/{self.num_steps}") + + errors = self.calculate_residuals(camera_opt, gravity_opt, data) + costs, weights = self.calculate_costs(errors, data) + + if i == 0: + prev_cost = sum(c.mean(-1) for c in costs.values()) + for k, c in costs.items(): + infos[f"initial_{k}"] = c.mean(-1) + + infos["initial_cost"] = prev_cost + + Grad, Hess = self.setup_system( + camera_opt, + gravity_opt, + errors, + weights, + shared_intrinsics=self.shared_intrinsics, + ) + delta = optimizer_step(Grad, Hess, lamb) # (B, N_params) + + if self.shared_intrinsics: + delta_g = delta[..., :-1].reshape(B, 2) + delta_f = delta[..., -1].expand(B, 1) + delta = torch.cat([delta_g, delta_f], dim=-1) + + # calculate new cost + camera_opt, gravity_opt = self.update_estimate(camera_opt, gravity_opt, delta) + new_cost, _ = self.calculate_costs( + self.calculate_residuals(camera_opt, gravity_opt, data), data + ) + new_cost = sum(c.mean(-1) for c in new_cost.values()) + + if not self.conf.fix_lambda and not self.shared_intrinsics: + lamb = update_lambda(lamb, prev_cost, new_cost) + + if self.conf.verbose: + logger.info(f"Cost:\nPrev: {prev_cost}\nNew: {new_cost}") + logger.info(f"Camera:\n{camera_opt._data}") + + if early_stop(new_cost, prev_cost, atol=self.conf.atol, rtol=self.conf.rtol): + infos["stop_at"] = min(i + 1, infos["stop_at"]) + + if self.conf.early_stop: + if self.conf.verbose: + logger.info(f"Early stopping at step {i+1}") + break + + prev_cost = new_cost + + if i == self.num_steps - 1 and self.conf.early_stop: + logger.warning("Reached maximum number of steps without convergence.") + + final_errors = self.calculate_residuals(camera_opt, gravity_opt, data) # (B, N, 3) + final_cost, weights = self.calculate_costs(final_errors, data) # (B, N) + + if not self.training: + infos |= self.estimate_uncertainty(camera_opt, gravity_opt, final_errors, weights) + + infos["stop_at"] = camera_opt.new_ones(camera_opt.shape[0]) * infos["stop_at"] + for k, c in final_cost.items(): + infos[f"final_{k}"] = c.mean(-1) + + infos["final_cost"] = sum(c.mean(-1) for c in final_cost.values()) + + return camera_opt, gravity_opt, infos + + def _forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Run the LM optimization.""" + if self.initializer is None: + camera_init, gravity_init = get_initial_estimation( + data, self.camera_model, trivial_init=self.conf.init_conf.name == "trivial" + ) + else: + out = self.initializer(data) + camera_init = out["camera"] + gravity_init = out["gravity"] + + self.setup_optimization_and_priors(data, shared_intrinsics=self.shared_intrinsics) + + start = time.time() + camera_opt, gravity_opt, infos = self.optimize(data, camera_init, gravity_init) + + if self.conf.verbose: + logger.info(f"Optimization took {(time.time() - start)*1000:.2f} ms") + + logger.info(f"Initial camera:\n{rad2deg(camera_init.vfov)}") + logger.info(f"Optimized camera:\n{rad2deg(camera_opt.vfov)}") + + logger.info(f"Initial gravity:\n{rad2deg(gravity_init.rp)}") + logger.info(f"Optimized gravity:\n{rad2deg(gravity_opt.rp)}") + + return {"camera": camera_opt, "gravity": gravity_opt, **infos} + + def metrics( + self, pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor] + ) -> Dict[str, torch.Tensor]: + 
"""Calculate the metrics for the optimization.""" + pred_cam, gt_cam = pred["camera"], data["camera"] + pred_gravity, gt_gravity = pred["gravity"], data["gravity"] + + infos = {"stop_at": pred["stop_at"]} + for k, v in pred.items(): + if "initial" in k or "final" in k: + infos[k] = v + + return { + "roll_error": roll_error(pred_gravity, gt_gravity), + "pitch_error": pitch_error(pred_gravity, gt_gravity), + "gravity_error": gravity_error(pred_gravity, gt_gravity), + "vfov_error": vfov_error(pred_cam, gt_cam), + "k1_error": dist_error(pred_cam, gt_cam), + **infos, + } + + def loss( + self, pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor] + ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + """Calculate the loss for the optimization.""" + pred_cam, gt_cam = pred["camera"], data["camera"] + pred_gravity, gt_gravity = pred["gravity"], data["gravity"] + + loss_fn = nn.L1Loss(reduction="none") + + # loss will be 0 if estimate is false and prior is provided during training + gravity_loss = loss_fn(pred_gravity.vec3d, gt_gravity.vec3d) + + h = data["camera"].size[0, 0] + focal_loss = loss_fn(pred_cam.f, gt_cam.f).mean(-1) / h + + dist_loss = focal_loss.new_zeros(focal_loss.shape) + if self.camera_has_distortion: + dist_loss = loss_fn(pred_cam.dist, gt_cam.dist).sum(-1) + + losses = { + "gravity": gravity_loss.sum(-1), + "focal": focal_loss, + "dist": dist_loss, + "param_total": gravity_loss.sum(-1) + focal_loss + dist_loss, + } + + losses = {k: v * self.conf.loss_weight for k, v in losses.items()} + return losses, self.metrics(pred, data) diff --git a/siclib/models/optimization/losses.py b/siclib/models/optimization/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..34cd1f393c0465f860d8bc03f0ae4b8ecacd2b5a --- /dev/null +++ b/siclib/models/optimization/losses.py @@ -0,0 +1,93 @@ +"""Generic losses and error functions for optimization or training deep networks.""" + +from typing import Callable, Tuple + +import torch + + +def scaled_loss( + x: torch.Tensor, fn: Callable, a: float +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Apply a loss function to a tensor and pre- and post-scale it. + + Args: + x: the data tensor, should already be squared: `x = y**2`. + fn: the loss function, with signature `fn(x) -> y`. + a: the scale parameter. + + Returns: + The value of the loss, and its first and second derivatives. + """ + a2 = a**2 + loss, loss_d1, loss_d2 = fn(x / a2) + return loss * a2, loss_d1, loss_d2 / a2 + + +def squared_loss(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """A dummy squared loss.""" + return x, torch.ones_like(x), torch.zeros_like(x) + + +def huber_loss(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """The classical robust Huber loss, with first and second derivatives.""" + mask = x <= 1 + sx = torch.sqrt(x + 1e-8) # avoid nan in backward pass + isx = torch.max(sx.new_tensor(torch.finfo(torch.float).eps), 1 / sx) + loss = torch.where(mask, x, 2 * sx - 1) + loss_d1 = torch.where(mask, torch.ones_like(x), isx) + loss_d2 = torch.where(mask, torch.zeros_like(x), -isx / (2 * x)) + return loss, loss_d1, loss_d2 + + +def barron_loss( + x: torch.Tensor, alpha: torch.Tensor, derivatives: bool = True, eps: float = 1e-7 +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Parameterized & adaptive robust loss function. 
+ + Described in: + A General and Adaptive Robust Loss Function, Barron, CVPR 2019 + + alpha = 2 -> L2 loss + alpha = 1 -> Charbonnier loss (smooth L1) + alpha = 0 -> Cauchy loss + alpha = -2 -> Geman-McClure loss + alpha = -inf -> Welsch loss + + Contrary to the original implementation, assume the the input is already + squared and scaled (basically scale=1). Computes the first derivative, but + not the second (TODO if needed). + """ + loss_two = x + loss_zero = 2 * torch.log1p(torch.clamp(0.5 * x, max=33e37)) + + # The loss when not in one of the above special cases. + # Clamp |2-alpha| to be >= machine epsilon so that it's safe to divide by. + beta_safe = torch.abs(alpha - 2.0).clamp(min=eps) + # Clamp |alpha| to be >= machine epsilon so that it's safe to divide by. + alpha_safe = torch.where(alpha >= 0, torch.ones_like(alpha), -torch.ones_like(alpha)) + alpha_safe = alpha_safe * torch.abs(alpha).clamp(min=eps) + + loss_otherwise = ( + 2 * (beta_safe / alpha_safe) * (torch.pow(x / beta_safe + 1.0, 0.5 * alpha) - 1.0) + ) + + # Select which of the cases of the loss to return. + loss = torch.where(alpha == 0, loss_zero, torch.where(alpha == 2, loss_two, loss_otherwise)) + dummy = torch.zeros_like(x) + + if derivatives: + loss_two_d1 = torch.ones_like(x) + loss_zero_d1 = 2 / (x + 2) + loss_otherwise_d1 = torch.pow(x / beta_safe + 1.0, 0.5 * alpha - 1.0) + loss_d1 = torch.where( + alpha == 0, loss_zero_d1, torch.where(alpha == 2, loss_two_d1, loss_otherwise_d1) + ) + + return loss, loss_d1, dummy + else: + return loss, dummy, dummy + + +def scaled_barron(a, c): + """Return a scaled Barron loss function.""" + return lambda x: scaled_loss(x, lambda y: barron_loss(y, y.new_tensor(a)), c) diff --git a/siclib/models/optimization/perspective_opt.py b/siclib/models/optimization/perspective_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3c1532c9c064368ef025d92aba2b4d287c62ef --- /dev/null +++ b/siclib/models/optimization/perspective_opt.py @@ -0,0 +1,195 @@ +import logging +from typing import Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from tqdm import tqdm + +from siclib.geometry.camera import Pinhole as Camera +from siclib.geometry.gravity import Gravity +from siclib.geometry.perspective_fields import get_perspective_field +from siclib.models.base_model import BaseModel +from siclib.models.utils.metrics import pitch_error, roll_error, vfov_error +from siclib.utils.conversions import deg2rad + +logger = logging.getLogger(__name__) + +# flake8: noqa +# mypy: ignore-errors + + +class PerspectiveParamOpt(BaseModel): + default_conf = { + "max_steps": 1000, + "lr": 0.01, + "lr_scheduler": { + "name": "ReduceLROnPlateau", + "options": {"mode": "min", "patience": 3}, + }, + "patience": 3, + "abs_tol": 1e-7, + "rel_tol": 1e-9, + "lamb": 0.5, + "verbose": False, + } + + required_data_keys = ["up_field", "latitude_field"] + + def _init(self, conf): + pass + + def cost_function(self, pred, target): + """Compute cost function for perspective parameter optimization.""" + eps = 1e-7 + + lat_loss = F.l1_loss(pred["latitude_field"], target["latitude_field"], reduction="none") + lat_loss = lat_loss.squeeze(1) + + up_loss = F.cosine_similarity(pred["up_field"], target["up_field"], dim=1) + up_loss = torch.acos(torch.clip(up_loss, -1 + eps, 1 - eps)) + + cost = (self.conf.lamb * lat_loss) + ((1 - self.conf.lamb) * up_loss) + return { + "total": torch.mean(cost), + "up": torch.mean(up_loss), + "latitude": torch.mean(lat_loss), + } + 
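+
+    # Per-pixel cost used above (sketch): with lamb in [0, 1],
+    #     cost = lamb * |lat_pred - lat_target| + (1 - lamb) * arccos(<up_pred, up_target>)
+    # averaged over all pixels. With the default lamb = 0.5, a latitude error of
+    # 0.1 rad and an up-vector angle error of 0.2 rad give a per-pixel cost of 0.15.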
+ def check_convergence(self, loss, losses_prev): + """Check if optimization has converged.""" + + if loss["total"].item() <= self.conf.abs_tol: + return True, losses_prev + + if len(losses_prev) < self.conf.patience: + losses_prev.append(loss["total"].item()) + + elif np.abs(loss["total"].item() - losses_prev[0]) < self.conf.rel_tol: + return True, losses_prev + + else: + losses_prev.append(loss["total"].item()) + losses_prev = losses_prev[-self.conf.patience :] + + return False, losses_prev + + def _update_estimate(self, camera: Camera, gravity: Gravity): + """Update camera estimate based on current parameters.""" + + camera = Camera.from_dict( + {"height": camera.size[..., 1], "width": camera.size[..., 0], "vfov": self.vfov_opt} + ) + gravity = Gravity.from_rp(self.roll_opt, self.pitch_opt) + return camera, gravity + + def optimize(self, data, camera_init, gravity_init): + """Optimize camera parameters to minimize cost function.""" + device = data["up_field"].device + self.roll_opt = nn.Parameter(gravity_init.roll, requires_grad=True).to(device) + self.pitch_opt = nn.Parameter(gravity_init.pitch, requires_grad=True).to(device) + self.vfov_opt = nn.Parameter(camera_init.vfov, requires_grad=True).to(device) + + optimizer = torch.optim.Adam( + [self.roll_opt, self.pitch_opt, self.vfov_opt], lr=self.conf.lr + ) + + lr_scheduler = None + if self.conf.lr_scheduler["name"] is not None: + lr_scheduler = getattr(torch.optim.lr_scheduler, self.conf.lr_scheduler["name"])( + optimizer, **self.conf.lr_scheduler["options"] + ) + + losses_prev = [] + + loop = range(self.conf.max_steps) + if self.conf.verbose: + pbar = tqdm(loop, desc="Optimizing", total=len(loop), ncols=100) + + with torch.set_grad_enabled(True): + self.train() + for _ in loop: + optimizer.zero_grad() + + camera_opt, gravity_opt = self._update_estimate(camera_init, gravity_init) + + up, lat = get_perspective_field(camera_opt, gravity_opt) + pred = {"up_field": up, "latitude_field": lat} + + loss = self.cost_function(pred, data) + loss["total"].backward() + optimizer.step() + + if lr_scheduler is not None: + lr_scheduler.step(loss["total"]) + + if self.conf.verbose: + pbar.set_postfix({k[:3]: v.item() for k, v in loss.items()}) + pbar.update(1) + + converged, losses_prev = self.check_convergence(loss, losses_prev) + if converged: + if self.conf.verbose: + pbar.close() + break + + camera_opt, gravity_opt = self._update_estimate(camera_init, gravity_init) + return {"camera_opt": camera_opt, "gravity_opt": gravity_opt} + + def _get_init_params(self, data) -> Tuple[Camera, Gravity]: + """Get initial camera parameters for optimization.""" + up_ref = data["up_field"] + latitude_ref = data["latitude_field"] + + h, w = latitude_ref.shape[-2:] + + # init roll is angle of the up vector at the center of the image + init_r = -torch.arctan2( + up_ref[:, 0, int(h / 2), int(w / 2)], + -up_ref[:, 1, int(h / 2), int(w / 2)], + ) + + # init pitch is the value at the center of the latitude map + init_p = latitude_ref[:, 0, int(h / 2), int(w / 2)] + + # init vfov is the difference between the central top and bottom of the latitude map + init_vfov = latitude_ref[:, 0, 0, int(w / 2)] - latitude_ref[:, 0, -1, int(w / 2)] + init_vfov = torch.abs(init_vfov) + init_vfov = init_vfov.clamp(min=deg2rad(20), max=deg2rad(120)) + + h, w = ( + latitude_ref.new_ones(latitude_ref.shape[0]) * h, + latitude_ref.new_ones(latitude_ref.shape[0]) * w, + ) + params = {"width": w, "height": h, "vfov": init_vfov} + camera = Camera.from_dict(params) + gravity = 
Gravity.from_rp(init_r, init_p) + return camera, gravity + + def _forward(self, data): + """Forward pass of optimization model.""" + + assert data["up_field"].shape[0] == 1, "Batch size must be 1 for optimization model." + + # detach all tensors to avoid backprop + for k, v in data.items(): + if isinstance(v, torch.Tensor): + data[k] = v.detach() + + camera_init, gravity_init = self._get_init_params(data) + return self.optimize(data, camera_init, gravity_init) + + def metrics(self, pred, data): + pred_cam, gt_cam = pred["camera_opt"], data["camera"] + pred_grav, gt_grav = pred["gravity_opt"], data["gravity"] + + return { + "roll_opt_error": roll_error(pred_grav, gt_grav), + "pitch_opt_error": pitch_error(pred_grav, gt_grav), + "vfov_opt_error": vfov_error(pred_cam, gt_cam), + } + + def loss(self, pred, data): + """No loss function for this optimization model.""" + return {"opt_param_total": 0}, self.metrics(pred, data) diff --git a/siclib/models/optimization/ransac.py b/siclib/models/optimization/ransac.py new file mode 100644 index 0000000000000000000000000000000000000000..04c9ca658059d3aa8e2c71322b6a4fa07f5cd342 --- /dev/null +++ b/siclib/models/optimization/ransac.py @@ -0,0 +1,407 @@ +from typing import Optional, Tuple + +import numpy as np +import torch +import torch.nn.functional as F + +from siclib.geometry.camera import Pinhole +from siclib.geometry.gravity import Gravity +from siclib.geometry.perspective_fields import get_latitude_field, get_up_field +from siclib.models.base_model import BaseModel +from siclib.models.utils.metrics import ( + latitude_error, + pitch_error, + roll_error, + up_error, + vfov_error, +) +from siclib.utils.conversions import skew_symmetric + +# flake8: noqa +# mypy: ignore-errors + + +def get_up_lines(up, xy): + up_lines = torch.cat([up, torch.zeros_like(up[..., :1])], dim=-1) + + xy1 = torch.cat([xy, torch.ones_like(xy[..., :1])], dim=-1) + + xy2 = xy1 + up_lines + + return torch.einsum("...ij,...j->...i", skew_symmetric(xy1), xy2) + + +def calculate_vvp(line1, line2): + return torch.einsum("...ij,...j->...i", skew_symmetric(line1), line2) + + +def calculate_vvps(xs, ys, up): + xy_grav = torch.stack([xs[..., :2], ys[..., :2]], dim=-1).float() + up_lines = get_up_lines(up, xy_grav) # (B, N, 2, D) + vvp = calculate_vvp(up_lines[..., 0, :], up_lines[..., 1, :]) # (B, N, 3) + vvp = vvp / vvp[..., (2,)] + return vvp + + +def get_up_samples(pred, xs, ys): + B, N = xs.shape[:2] + batch_indices = torch.arange(B).unsqueeze(1).unsqueeze(2).expand(B, N, 3).to(xs.device) + zeros = torch.zeros_like(xs).to(xs.device) + ones = torch.ones_like(xs).to(xs.device) + sample_indices_x = torch.stack([batch_indices, zeros, ys, xs], dim=-1).long() # (B, N, 3, 4) + sample_indices_y = torch.stack([batch_indices, ones, ys, xs], dim=-1).long() # (B, N, 3, 4) + up_x = pred["up_field"][sample_indices_x[..., (0, 1), :].unbind(-1)] # (B, N, 2) + up_y = pred["up_field"][sample_indices_y[..., (0, 1), :].unbind(-1)] # (B, N, 2) + return torch.stack([up_x, up_y], dim=-1) # (B, N, 2, D) + + +def get_latitude_samples(pred, xs, ys): + # Setup latitude + B, N = xs.shape[:2] + batch_indices = torch.arange(B).unsqueeze(1).unsqueeze(2).expand(B, N, 3).to(xs.device) + zeros = torch.zeros_like(xs).to(xs.device) + sample_indices = torch.stack([batch_indices, zeros, ys, xs], dim=-1).long() # (B, N, 3, 4) + latitude = pred["latitude_field"][sample_indices[..., 2, :].unbind(-1)] + return torch.sin(latitude) # (B, N) + + +class MinimalSolver: + def __init__(self): + pass + + @staticmethod + def 
solve_focal( + L: torch.Tensor, xy: torch.Tensor, vvp: torch.Tensor, c: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Solve for focal length. + + Args: + L (torch.Tensor): Latitude samples. + xy (torch.Tensor): xy of latitude samples of shape (..., 2). + vvp (torch.Tensor): Vertical vanishing points of shape (..., 3). + c (torch.Tensor): Principal points of shape (..., 2). + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Positive and negative solution of focal length. + """ + c = c.unsqueeze(1) + u, v = (xy - c).unbind(-1) + + vx, vy, vz = vvp.unbind(-1) + cx, cy = c.unbind(-1) + vx = vx - cx * vz + vy = vy - cy * vz + + # Solve quadratic equation + a0 = (L**2 - 1) * vz**2 + a1 = L**2 * (vz**2 * (u**2 + v**2) + vx**2 + vy**2) - 2 * vz * (vx * u + vy * v) + a2 = L**2 * (v**2 + u**2) * (vx**2 + vy**2) - (u * vx + v * vy) ** 2 + + a0 = torch.where(a0 == 0, torch.ones_like(a0) * 1e-6, a0) + + f2_pos = -a1 / (2 * a0) + torch.sqrt(a1**2 - 4 * a0 * a2) / (2 * a0) + f2_neg = -a1 / (2 * a0) - torch.sqrt(a1**2 - 4 * a0 * a2) / (2 * a0) + + f_pos, f_neg = torch.sqrt(f2_pos), torch.sqrt(f2_neg) + + return f_pos, f_neg + + @staticmethod + def solve_scale( + L: torch.Tensor, xy: torch.Tensor, vvp: torch.Tensor, c: torch.Tensor, f: torch.Tensor + ) -> torch.Tensor: + """Solve for scale of homogeneous vector. + + Args: + L (torch.Tensor): Latitude samples. + xy (torch.Tensor): xy of latitude samples of shape (..., 2). + vvp (torch.Tensor): Vertical vanishing points of shape (..., 3). + c (torch.Tensor): Principal points of shape (..., 2). + f (torch.Tensor): Focal lengths. + + Returns: + torch.Tensor: Estimated scales. + """ + c = c.unsqueeze(1) + u, v = (xy - c).unbind(-1) + + vx, vy, vz = vvp.unbind(-1) + cx, cy = c.unbind(-1) + vx = vx - cx * vz + vy = vy - cy * vz + + w2 = (f**2 * L**2 * (u**2 + v**2 + f**2)) / (vx * u + vy * v + vz * f**2) ** 2 + return torch.sqrt(w2) + + @staticmethod + def solve_abc( + vvp: torch.Tensor, c: torch.Tensor, f: torch.Tensor, w: Optional[torch.Tensor] = None + ) -> torch.Tensor: + """Solve for abc vector (solution to homogeneous equation). + + Args: + vvp (torch.Tensor): Vertical vanishing points of shape (..., 3). + c (torch.Tensor): Principal points of shape (..., 2). + f (torch.Tensor): Focal lengths. + w (torch.Tensor): Scales. + + Returns: + torch.Tensor: Estimated abc vector. + """ + vx, vy, vz = vvp.unbind(-1) + cx, cy = c.unsqueeze(1).unbind(-1) + vx = vx - cx * vz + vy = vy - cy * vz + + a = vx / f + b = vy / f + c = vz + + abc = torch.stack((a, b, c), dim=-1) + + return F.normalize(abc, dim=-1) if w is None else abc * w.unsqueeze(-1) + + @staticmethod + def solve_rp(abc: torch.Tensor) -> torch.Tensor: + """Solve for roll, pitch. + + Args: + abc (torch.Tensor): Estimated abc vector. + + Returns: + torch.Tensor: Estimated roll, pitch, focal length. 
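+
+        Example:
+            Illustrative call with hypothetical values; a unit-norm abc vector whose
+            first component is zero yields zero roll:
+
+            >>> abc = torch.tensor([[0.0, 0.8660, 0.5]])
+            >>> roll, pitch = MinimalSolver.solve_rp(abc)
+            >>> # roll = 0.0, pitch = asin(0.5) ~ 0.5236 rad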
+ """ + a, _, c = abc.unbind(-1) + roll = torch.asin(-a / torch.sqrt(1 - c**2)) + pitch = torch.asin(c) + return roll, pitch + + +class RPFSolver(BaseModel): + default_conf = { + "n_iter": 1000, + "up_inlier_th": 1, + "latitude_inlier_th": 1, + "error_fn": "angle", # angle or mse + "up_weight": 1, + "latitude_weight": 1, + "loss_weight": 1, + "use_latitude": True, + } + + def _init(self, conf): + self.solver = MinimalSolver() + + def check_up_inliers(self, pred, est_camera, est_gravity, N=1): + pred_up = pred["up_field"] + # expand from from (B, 1, H, W) to (B * N, 1, H, W) + B = pred_up.shape[0] + pred_up = pred_up.unsqueeze(1).expand(-1, N, -1, -1, -1) + pred_up = pred_up.reshape(B * N, *pred_up.shape[2:]) + + est_up = get_up_field(est_camera, est_gravity).permute(0, 3, 1, 2) + + if self.conf.error_fn == "angle": + mse = up_error(est_up, pred_up) + elif self.conf.error_fn == "mse": + mse = F.mse_loss(est_up, pred_up, reduction="none").mean(1) + else: + raise ValueError(f"Unknown error function: {self.conf.error_fn}") + + # shape (B, H, W) + conf = pred.get("up_confidence", pred_up.new_ones(pred_up.shape[0], *pred_up.shape[-2:])) + # shape (B, N, H, W) + conf = conf.unsqueeze(1).expand(-1, N, -1, -1) + # shape (B * N, H, W) + conf = conf.reshape(B * N, *conf.shape[-2:]) + + return (mse < self.conf.up_inlier_th) * conf + + def check_latitude_inliers(self, pred, est_camera, est_gravity, N=1): + B = pred["up_field"].shape[0] + pred_latitude = pred.get("latitude_field") + + if pred_latitude is None: + shape = (B * N, *pred["up_field"].shape[-2:]) + return est_camera.new_zeros(shape) + + # expand from from (B, 1, H, W) to (B * N, 1, H, W) + pred_latitude = pred_latitude.unsqueeze(1).expand(-1, N, -1, -1, -1) + pred_latitude = pred_latitude.reshape(B * N, *pred_latitude.shape[2:]) + + est_latitude = get_latitude_field(est_camera, est_gravity).permute(0, 3, 1, 2) + + if self.conf.error_fn == "angle": + error = latitude_error(est_latitude, pred_latitude) + elif self.conf.error_fn == "mse": + error = F.mse_loss(est_latitude, pred_latitude, reduction="none").mean(1) + else: + raise ValueError(f"Unknown error function: {self.conf.error_fn}") + + conf = pred.get( + "latitude_confidence", + pred_latitude.new_ones(pred_latitude.shape[0], *pred_latitude.shape[-2:]), + ) + conf = conf.unsqueeze(1).expand(-1, N, -1, -1) + conf = conf.reshape(B * N, *conf.shape[-2:]) + return (error < self.conf.latitude_inlier_th) * conf + + def get_best_index(self, data, camera, gravity, inliers=None): + B, _, H, W = data["up_field"].shape + N = self.conf.n_iter + + up_inliers = self.check_up_inliers(data, camera, gravity, N) + latitude_inliers = self.check_latitude_inliers(data, camera, gravity, N) + + up_inliers = up_inliers.reshape(B, N, H, W) + latitude_inliers = latitude_inliers.reshape(B, N, H, W) + + if inliers is not None: + up_inliers = up_inliers * inliers.unsqueeze(1) + latitude_inliers = latitude_inliers * inliers.unsqueeze(1) + + up_inliers = up_inliers.sum((2, 3)) + latitude_inliers = latitude_inliers.sum((2, 3)) + + total_inliers = ( + self.conf.up_weight * up_inliers + self.conf.latitude_weight * latitude_inliers + ) + + best_idx = total_inliers.argmax(-1) + + return best_idx, total_inliers[torch.arange(B), best_idx] + + def solve_rpf(self, pred, xs, ys, principal_points, focal=None): + device = pred["up_field"].device + + # Get samples + up = get_up_samples(pred, xs, ys) + + # Calculate vvps + vvp = calculate_vvps(xs, ys, up).to(device) + + # Solve for focal length + xy = torch.stack([xs[..., 2], ys[..., 
2]], dim=-1).float() + if focal is not None: + f = focal.new_ones(xs[..., 2].shape) * focal.unsqueeze(-1) + f_pos, f_neg = f, f + else: + L = get_latitude_samples(pred, xs, ys) + f_pos, f_neg = self.solver.solve_focal(L, xy, vvp, principal_points) + + # Solve for abc + abc_pos = self.solver.solve_abc(vvp, principal_points, f_pos) + abc_neg = self.solver.solve_abc(vvp, principal_points, f_neg) + + # Solve for roll, pitch + roll_pos, pitch_pos = self.solver.solve_rp(abc_pos) + roll_neg, pitch_neg = self.solver.solve_rp(abc_neg) + + rpf_pos = torch.stack([roll_pos, pitch_pos, f_pos], dim=-1) + rpf_neg = torch.stack([roll_neg, pitch_neg, f_neg], dim=-1) + + return rpf_pos, rpf_neg + + def get_camera_and_gravity(self, pred, rpf): + B, _, H, W = pred["up_field"].shape + N = rpf.shape[1] + + w = pred["up_field"].new_ones(B, N) * W + h = pred["up_field"].new_ones(B, N) * H + cx = w / 2.0 + cy = h / 2.0 + + roll, pitch, focal = rpf.unbind(-1) + + params = torch.stack([w, h, focal, focal, cx, cy], dim=-1) + params = params.reshape(B * N, params.shape[-1]) + cam = Pinhole(params) + + roll, pitch = roll.reshape(B * N), pitch.reshape(B * N) + gravity = Gravity.from_rp(roll, pitch) + + return cam, gravity + + def _forward(self, data): + device = data["up_field"].device + B, _, H, W = data["up_field"].shape + + principal_points = torch.tensor([H / 2.0, W / 2.0]).expand(B, 2).to(device) + + if not self.conf.use_latitude and "latitude_field" in data: + data.pop("latitude_field") + + if "inliers" in data: + indices = torch.nonzero(data["inliers"] == 1, as_tuple=False) + batch_indices = torch.unique(indices[:, 0]) + + sampled_indices = [] + for batch_index in batch_indices: + batch_mask = indices[:, 0] == batch_index + + batch_indices_sampled = np.random.choice( + batch_mask.sum(), self.conf.n_iter * 3, replace=True + ) + batch_indices_sampled = batch_indices_sampled.reshape(self.conf.n_iter, 3) + sampled_indices.append(indices[batch_mask][batch_indices_sampled][:, :, 1:]) + + ys, xs = torch.stack(sampled_indices, dim=0).unbind(-1) + + else: + xs = torch.randint(0, W, (B, self.conf.n_iter, 3)).to(device) + ys = torch.randint(0, H, (B, self.conf.n_iter, 3)).to(device) + + rpf_pos, rpf_neg = self.solve_rpf( + data, xs, ys, principal_points, focal=data.get("prior_focal") + ) + + cams_pos, gravity_pos = self.get_camera_and_gravity(data, rpf_pos) + cams_neg, gravity_neg = self.get_camera_and_gravity(data, rpf_neg) + + inliers = data.get("inliers", None) + best_pos, score_pos = self.get_best_index(data, cams_pos, gravity_pos, inliers) + best_neg, score_neg = self.get_best_index(data, cams_neg, gravity_neg, inliers) + + rpf = rpf_pos[torch.arange(B), best_pos] + rpf[score_neg > score_pos] = rpf_neg[torch.arange(B), best_neg][score_neg > score_pos] + + cam, gravity = self.get_camera_and_gravity(data, rpf.unsqueeze(1)) + + return { + "camera_opt": cam, + "gravity_opt": gravity, + "up_inliers": self.check_up_inliers(data, cam, gravity), + "latitude_inliers": self.check_latitude_inliers(data, cam, gravity), + } + + def metrics(self, pred, data): + pred_cam, gt_cam = pred["camera_opt"], data["camera"] + pred_gravity, gt_gravity = pred["gravity_opt"], data["gravity"] + + return { + "roll_opt_error": roll_error(pred_gravity, gt_gravity), + "pitch_opt_error": pitch_error(pred_gravity, gt_gravity), + "vfov_opt_error": vfov_error(pred_cam, gt_cam), + } + + def loss(self, pred, data): + pred_cam, gt_cam = pred["camera_opt"], data["camera"] + pred_gravity, gt_gravity = pred["gravity_opt"], data["gravity"] + + h = 
data["camera"].size[0, 0] + + gravity_loss = F.l1_loss(pred_gravity.vec3d, gt_gravity.vec3d, reduction="none") + focal_loss = F.l1_loss(pred_cam.f, gt_cam.f, reduction="none").sum(-1) / h + + total_loss = gravity_loss.sum(-1) + if self.conf.estimate_focal: + total_loss = total_loss + focal_loss + + losses = { + "opt_gravity": gravity_loss.sum(-1), + "opt_focal": focal_loss, + "opt_param_total": total_loss, + } + + losses = {k: v * self.conf.loss_weight for k, v in losses.items()} + return losses, self.metrics(pred, data) diff --git a/siclib/models/optimization/utils.py b/siclib/models/optimization/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1aae429134c3c34f78ec4df439c83c2df58e0509 --- /dev/null +++ b/siclib/models/optimization/utils.py @@ -0,0 +1,172 @@ +import logging +from typing import Dict + +import torch + +from siclib.geometry.base_camera import BaseCamera +from siclib.geometry.gravity import Gravity +from siclib.utils.conversions import deg2rad, focal2fov + +logger = logging.getLogger(__name__) + +# flake8: noqa +# mypy: ignore-errors + + +def get_initial_estimation( + data: Dict[str, torch.Tensor], camera_model: BaseCamera, trivial_init: bool = True +) -> BaseCamera: + """Get initial camera for optimization using heuristics.""" + return ( + get_trivial_estimation(data, camera_model) + if trivial_init + else get_heuristic_estimation(data, camera_model) + ) + + +def get_heuristic_estimation(data: Dict[str, torch.Tensor], camera_model: BaseCamera) -> BaseCamera: + """Get initial camera for optimization using heuristics. + + Initial camera is initialized with the following heuristics: + - roll is the angle of the up vector at the center of the image + - pitch is the value at the center of the latitude map + - vfov is the difference between the central top and bottom of the latitude map + - distortions are set to zero + + Use the prior values if available. + + Args: + data (Dict[str, torch.Tensor]): Input data dictionary. + camera_model (BaseCamera): Camera model to use. + + Returns: + BaseCamera: Initial camera for optimization. 
+ """ + up_ref = data["up_field"].detach() + latitude_ref = data["latitude_field"].detach() + + h, w = up_ref.shape[-2:] + batch_h, batch_w = ( + up_ref.new_ones((up_ref.shape[0],)) * h, + up_ref.new_ones((up_ref.shape[0],)) * w, + ) + + # init roll is angle of the up vector at the center of the image + init_r = -torch.atan2( + up_ref[:, 0, int(h / 2), int(w / 2)], -up_ref[:, 1, int(h / 2), int(w / 2)] + ) + init_r = init_r.clamp(min=-deg2rad(45), max=deg2rad(45)) + + # init pitch is the value at the center of the latitude map + init_p = latitude_ref[:, 0, int(h / 2), int(w / 2)] + init_p = init_p.clamp(min=-deg2rad(45), max=deg2rad(45)) + + # init vfov is the difference between the central top and bottom of the latitude map + init_vfov = latitude_ref[:, 0, 0, int(w / 2)] - latitude_ref[:, 0, -1, int(w / 2)] + init_vfov = torch.abs(init_vfov) + init_vfov = init_vfov.clamp(min=deg2rad(20), max=deg2rad(120)) + + focal = data.get("prior_focal") + init_vfov = init_vfov if focal is None else focal2fov(focal, h) + + params = {"width": batch_w, "height": batch_h, "vfov": init_vfov} + params |= {"scales": data["scales"]} if "scales" in data else {} + params |= {"k1": data["prior_k1"]} if "prior_k1" in data else {} + camera = camera_model.from_dict(params) + camera = camera.float().to(data["up_field"].device) + + gravity = Gravity.from_rp(init_r, init_p).float().to(data["up_field"].device) + if "prior_gravity" in data: + gravity = data["prior_gravity"].float().to(up_ref.device) + + return camera, gravity + + +def get_trivial_estimation(data: Dict[str, torch.Tensor], camera_model: BaseCamera) -> BaseCamera: + """Get initial camera for optimization with roll=0, pitch=0, vfov=0.7 * max(h, w). + + Args: + data (Dict[str, torch.Tensor]): Input data dictionary. + camera_model (BaseCamera): Camera model to use. + + Returns: + BaseCamera: Initial camera for optimization. 
+ """ + """Get initial camera for optimization with roll=0, pitch=0, vfov=0.7 * max(h, w).""" + ref = data.get("up_field", data["latitude_field"]) + ref = ref.detach() + + h, w = ref.shape[-2:] + batch_h, batch_w = ( + ref.new_ones((ref.shape[0],)) * h, + ref.new_ones((ref.shape[0],)) * w, + ) + + init_r = ref.new_zeros((ref.shape[0],)) + init_p = ref.new_zeros((ref.shape[0],)) + + focal = data.get("prior_focal", 0.7 * torch.max(batch_h, batch_w)) + init_vfov = init_vfov if focal is None else focal2fov(focal, h) + + params = {"width": batch_w, "height": batch_h, "vfov": init_vfov} + params |= {"scales": data["scales"]} if "scales" in data else {} + params |= {"k1": data["prior_k1"]} if "prior_k1" in data else {} + camera = camera_model.from_dict(params) + camera = camera.float().to(ref.device) + + gravity = Gravity.from_rp(init_r, init_p).float().to(ref.device) + + if "prior_gravity" in data: + gravity = data["prior_gravity"].float().to(ref.device) + + return camera, gravity + + +def early_stop(new_cost: torch.Tensor, prev_cost: torch.Tensor, atol: float, rtol: float) -> bool: + """Early stopping criterion based on cost convergence.""" + return torch.allclose(new_cost, prev_cost, atol=atol, rtol=rtol) + + +def update_lambda( + lamb: torch.Tensor, + prev_cost: torch.Tensor, + new_cost: torch.Tensor, + lambda_min: float = 1e-6, + lambda_max: float = 1e2, +) -> torch.Tensor: + """Update damping factor for Levenberg-Marquardt optimization.""" + new_lamb = lamb.new_zeros(lamb.shape) + new_lamb = lamb * torch.where(new_cost > prev_cost, 10, 0.1) + lamb = torch.clamp(new_lamb, lambda_min, lambda_max) + return lamb + + +def optimizer_step( + G: torch.Tensor, H: torch.Tensor, lambda_: torch.Tensor, eps: float = 1e-6 +) -> torch.Tensor: + """One optimization step with Gauss-Newton or Levenberg-Marquardt. + + Args: + G (torch.Tensor): Batched gradient tensor of size (..., N). + H (torch.Tensor): Batched hessian tensor of size (..., N, N). + lambda_ (torch.Tensor): Damping factor for LM (use GN if lambda_=0) with shape (B,). + eps (float, optional): Epsilon for damping. Defaults to 1e-6. + + Returns: + torch.Tensor: Batched update tensor of size (..., N). + """ + diag = H.diagonal(dim1=-2, dim2=-1) + diag = diag * lambda_.unsqueeze(-1) # (B, 3) + + H = H + diag.clamp(min=eps).diag_embed() + + H_, G_ = H.cpu(), G.cpu() + try: + U = torch.linalg.cholesky(H_) + except RuntimeError: + logger.warning("Cholesky decomposition failed. Stopping.") + delta = H.new_zeros((H.shape[0], H.shape[-1])) # (B, 3) + else: + delta = torch.cholesky_solve(G_[..., None], U)[..., 0] + + return delta.to(H.device) diff --git a/siclib/models/optimization/vp_from_prior.py b/siclib/models/optimization/vp_from_prior.py new file mode 100644 index 0000000000000000000000000000000000000000..b8cd333573dffc7ff2fc8b1ffbbe751a6aae3d4b --- /dev/null +++ b/siclib/models/optimization/vp_from_prior.py @@ -0,0 +1,182 @@ +"""Wrapper for VP estimation with prior gravity using the VP-Estimation-with-Prior-Gravity library. 
+ +repo: https://github.com/cvg/VP-Estimation-with-Prior-Gravity +""" + +import sys + +sys.path.append("third_party/VP-Estimation-with-Prior-Gravity") +sys.path.append("third_party/VP-Estimation-with-Prior-Gravity/src/deeplsd") + +import logging +import random + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +import torch +from vp_estimation_with_prior_gravity.evaluation import get_labels_from_vp, project_vp_to_image +from vp_estimation_with_prior_gravity.features.line_detector import LineDetector +from vp_estimation_with_prior_gravity.solvers import run_hybrid_uncalibrated +from vp_estimation_with_prior_gravity.visualization import plot_images, plot_lines, plot_vp + +from siclib.geometry.camera import Pinhole +from siclib.geometry.gravity import Gravity +from siclib.models import BaseModel +from siclib.models.utils.metrics import gravity_error, pitch_error, roll_error, vfov_error + +# flake8: noqa +# mypy: ignore-errors + +logger = logging.getLogger(__name__) + + +class VPEstimator(BaseModel): + # Which solvers to us for our hybrid solver: + # 0 - 2lines 200g + # 1 - 2lines 110g + # 2 - 2lines 011g + # 3 - 4lines 211 + # 4 - 4lines 220 + default_conf = { + "SOLVER_FLAGS": [True, True, True, True, True], + "th_pixels": 3, # RANSAC inlier threshold + "ls_refinement": 2, # 3 uses the gravity in the LS refinement, 2 does not. + "nms": 3, # change to 3 to add a Ceres optimization after the non minimal solver (slower) + "magsac_scoring": True, + "line_type": "deeplsd", # 'lsd' or 'deeplsd' + "min_lines": 5, # only trust images with at least this many lines + "verbose": False, + } + + def _init(self, conf): + if conf.SOLVER_FLAGS in [ + [True, False, False, False, False], + [False, False, True, False, False], + ]: + self.vertical = np.array([random.random() / 1e12, 1, random.random() / 1e12]) + self.vertical /= np.linalg.norm(self.vertical) + else: + self.vertical = np.array([0.0, 1, 0.0]) + + self.line_detector = LineDetector(line_detector=conf.line_type) + + self.verbose = conf.verbose + + def visualize_lines(self, vp, lines, img, K): + vp_labels = get_labels_from_vp( + lines[:, :, [1, 0]], project_vp_to_image(vp, K), threshold=self.conf.th_pixels + )[0] + + plot_images([img, img]) + plot_lines([lines, np.empty((0, 2, 2))]) + plot_vp([np.empty((0, 2, 2)), lines], [[], vp_labels]) + + plt.show() + + def get_vvp(self, vp, K): + best_idx, best_cossim = 0, -1 + for i, point in enumerate(vp): + cossim = np.dot(self.vertical, point) / np.linalg.norm(point) + point = -point * np.dot(self.vertical, point) + try: + gravity = Gravity(np.linalg.inv(K) @ point) + except: + continue + + if ( + np.abs(cossim) > best_cossim + and gravity.pitch.abs() <= np.pi / 4 + and gravity.roll.abs() <= np.pi / 4 + ): + best_idx, best_cossim = i, np.abs(cossim) + + vvp = vp[best_idx] + return -vvp * np.sign(np.dot(self.vertical, vvp)) + + def _forward(self, data): + device = data["image"].device + images = data["image"].cpu() + + estimations = [] + for idx, img in enumerate(images.unbind(0)): + if "prior_gravity" in data: + self.vertical = -data["prior_gravity"][idx].vec3d.cpu().numpy() + else: + self.vertical = np.array([0.0, 1, 0.0]) + + img = img.numpy().transpose(1, 2, 0) * 255 + img = img.astype(np.uint8) + gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY) + + lines = self.line_detector.detect_lines(gray_img)[:, :, [1, 0]] + + if len(lines) < self.conf.min_lines: + logger.warning("Not enough lines detected! 
Skipping...") + gravity = Gravity.from_rp(np.nan, np.nan) + camera = Pinhole.from_dict( + {"f": np.nan, "height": img.shape[0], "width": img.shape[1]} + ) + estimations.append({"camera": camera, "gravity": gravity}) + continue + + principle_point = np.array([img.shape[1] / 2.0, img.shape[0] / 2.0]) + f, vp = run_hybrid_uncalibrated( + lines - principle_point[None, None, :], + self.vertical, + th_pixels=self.conf.th_pixels, + ls_refinement=self.conf.ls_refinement, + nms=self.conf.nms, + magsac_scoring=self.conf.magsac_scoring, + sprt=True, + solver_flags=self.conf.SOLVER_FLAGS, + ) + vp[:, 1] *= -1 + + K = np.array( + [[f, 0.0, principle_point[0]], [0.0, f, principle_point[1]], [0.0, 0.0, 1.0]] + ) + + if self.verbose: + self.visualize_lines(vp, lines, img, K) + + vp_labels = get_labels_from_vp( + lines[:, :, [1, 0]], project_vp_to_image(vp, K), threshold=self.conf.th_pixels + )[0] + out = {"vp": vp, "lines": lines, "K": K, "vp_labels": vp_labels} + + vp = project_vp_to_image(vp, K) + + vvp = self.get_vvp(vp, K) + + vvp = -vvp * np.sign(np.dot(self.vertical, vvp)) + try: + K_inv = np.linalg.inv(K) + gravity = Gravity(K_inv @ vvp) + except np.linalg.LinAlgError: + gravity = Gravity.from_rp(np.nan, np.nan) + + camera = Pinhole.from_dict({"f": f, "height": img.shape[0], "width": img.shape[1]}) + estimations.append({"camera": camera, "gravity": gravity}) + + if len(estimations) == 0: + return {} + + gravity = torch.stack([Gravity(est["gravity"].vec3d) for est in estimations], dim=0) + camera = torch.stack([Pinhole(est["camera"]._data) for est in estimations], dim=0) + + return {"camera": camera.float().to(device), "gravity": gravity.float().to(device)} | out + + def metrics(self, pred, data): + pred_cam, gt_cam = pred["camera_opt"], data["camera"] + pred_gravity, gt_gravity = pred["gravity_opt"], data["gravity"] + + return { + "roll_opt_error": roll_error(pred_gravity, gt_gravity), + "pitch_opt_error": pitch_error(pred_gravity, gt_gravity), + "gravity_opt_error": gravity_error(pred_gravity, gt_gravity), + "vfov_opt_error": vfov_error(pred_cam, gt_cam), + } + + def loss(self, pred, data): + return {}, self.metrics(pred, data) diff --git a/siclib/models/utils/__init__.py b/siclib/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/siclib/models/utils/metrics.py b/siclib/models/utils/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..2a51c7fbea501185d8200b7e35045a6123d5f045 --- /dev/null +++ b/siclib/models/utils/metrics.py @@ -0,0 +1,123 @@ +"""Various metrics for evaluating predictions.""" + +import logging + +import torch +from torch.nn import functional as F + +from siclib.geometry.base_camera import BaseCamera +from siclib.geometry.gravity import Gravity +from siclib.utils.conversions import rad2deg + +logger = logging.getLogger(__name__) + + +def pitch_error(pred_gravity: Gravity, target_gravity: Gravity) -> torch.Tensor: + """Computes the pitch error between two gravities. + + Args: + pred_gravity (Gravity): Predicted camera. + target_gravity (Gravity): Ground truth camera. + + Returns: + torch.Tensor: Pitch error in degrees. + """ + return rad2deg(torch.abs(pred_gravity.pitch - target_gravity.pitch)) + + +def roll_error(pred_gravity: Gravity, target_gravity: Gravity) -> torch.Tensor: + """Computes the roll error between two gravities. + + Args: + pred_gravity (Gravity): Predicted Gravity. + target_gravity (Gravity): Ground truth Gravity. 
+ + Returns: + torch.Tensor: Roll error in degrees. + """ + return rad2deg(torch.abs(pred_gravity.roll - target_gravity.roll)) + + +def gravity_error(pred_gravity: Gravity, target_gravity: Gravity) -> torch.Tensor: + """Computes the gravity error between two gravities. + + Args: + pred_gravity (Gravity): Predicted Gravity. + target_gravity (Gravity): Ground truth Gravity. + + Returns: + torch.Tensor: Gravity error in degrees. + """ + assert ( + pred_gravity.vec3d.shape == target_gravity.vec3d.shape + ), f"{pred_gravity.vec3d.shape} != {target_gravity.vec3d.shape}" + assert pred_gravity.vec3d.ndim == 2, f"{pred_gravity.vec3d.ndim} != 2" + assert pred_gravity.vec3d.shape[1] == 3, f"{pred_gravity.vec3d.shape[1]} != 3" + + cossim = F.cosine_similarity(pred_gravity.vec3d, target_gravity.vec3d, dim=-1).clamp(-1, 1) + return rad2deg(torch.acos(cossim)) + + +def vfov_error(pred_cam: BaseCamera, target_cam: BaseCamera) -> torch.Tensor: + """Computes the vertical field of view error between two cameras. + + Args: + pred_cam (Camera): Predicted camera. + target_cam (Camera): Ground truth camera. + + Returns: + torch.Tensor: Vertical field of view error in degrees. + """ + return rad2deg(torch.abs(pred_cam.vfov - target_cam.vfov)) + + +def dist_error(pred_cam: BaseCamera, target_cam: BaseCamera) -> torch.Tensor: + """Computes the distortion parameter error between two cameras. + + Returns zero if the cameras do not have distortion parameters. + + Args: + pred_cam (Camera): Predicted camera. + target_cam (Camera): Ground truth camera. + + Returns: + torch.Tensor: distortion error. + """ + if hasattr(pred_cam, "dist") and hasattr(target_cam, "dist"): + return torch.abs(pred_cam.dist[..., 0] - target_cam.dist[..., 0]) + + logger.debug( + f"Predicted / target camera doesn't have distortion parameters: {pred_cam}/{target_cam}" + ) + return pred_cam.new_zeros(pred_cam.f.shape[0]) + + +def latitude_error(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """Computes the latitude error between two tensors. + + Args: + predictions (torch.Tensor): Predicted latitude field of shape (B, 1, H, W). + targets (torch.Tensor): Ground truth latitude field of shape (B, 1, H, W). + + Returns: + torch.Tensor: Latitude error in degrees of shape (B, H, W). + """ + return rad2deg(torch.abs(predictions - targets)).squeeze(1) + + +def up_error(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """Computes the up error between two tensors. + + Args: + predictions (torch.Tensor): Predicted up field of shape (B, 2, H, W). + targets (torch.Tensor): Ground truth up field of shape (B, 2, H, W). + + Returns: + torch.Tensor: Up error in degrees of shape (B, H, W). + """ + assert predictions.shape == targets.shape, f"{predictions.shape} != {targets.shape}" + assert predictions.ndim == 4, f"{predictions.ndim} != 4" + assert predictions.shape[1] == 2, f"{predictions.shape[1]} != 2" + + angle = F.cosine_similarity(predictions, targets, dim=1).clamp(-1, 1) + return rad2deg(torch.acos(angle)) diff --git a/siclib/models/utils/modules.py b/siclib/models/utils/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..c79e2c121ae739ac089b56efe33b4dfa5e31ca33 --- /dev/null +++ b/siclib/models/utils/modules.py @@ -0,0 +1,264 @@ +"""Various modules used in the decoder of the model. 
+ +Adapted from https://github.com/jinlinyi/PerspectiveFields +""" + +import logging + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +logger = logging.getLogger(__name__) + +# flake8: noqa +# mypy: ignore-errors + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0 and scale_by_keep: + random_tensor.div_(keep_prob) + return x * random_tensor + + +class DropPath(nn.Module): + """DropBlock, DropPath + + PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers. + + Papers: + DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890) + + Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382) + + Code: + DropBlock impl inspired by two Tensorflow impl: + - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74 + - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py + + Hacked together by / Copyright 2020 Ross Wightman + """ + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob,3):0.3f}" + + +class DWConv(nn.Module): + def __init__(self, dim=768): + super(DWConv, self).__init__() + self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim) + + def forward(self, x): + x = self.dwconv(x) + return x + + +class MLP(nn.Module): + """Linear Embedding.""" + + def __init__(self, input_dim=2048, embed_dim=768): + super().__init__() + self.proj = nn.Linear(input_dim, embed_dim) + + def forward(self, x): + x = x.flatten(2).transpose(1, 2) + x = self.proj(x) + return x + + +class ConvModule(nn.Module): + """Replacement for mmcv.cnn.ConvModule to avoid mmcv dependency.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: int, + padding: int = 0, + use_norm: bool = False, + bias: bool = True, + ): + super().__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias) + self.bn = nn.BatchNorm2d(out_channels) if use_norm else nn.Identity() + self.activate = nn.ReLU(inplace=True) + + def forward(self, x): + x = self.conv(x) + x = self.bn(x) + return self.activate(x) + + +class ResidualConvUnit(nn.Module): + """Residual convolution module.""" + + def __init__(self, features): + """Init. + Args: + features (int): number of features + """ + super().__init__() + + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True) + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True) + + self.relu = torch.nn.ReLU(inplace=True) + + def forward(self, x): + """Forward pass. 
+ Args: + x (tensor): input + Returns: + tensor: output + """ + out = self.relu(x) + out = self.conv1(out) + out = self.relu(out) + out = self.conv2(out) + return out + x + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block.""" + + def __init__(self, features, unit2only=False, upsample=True): + """Init. + Args: + features (int): number of features + """ + super().__init__() + self.upsample = upsample + + if not unit2only: + self.resConfUnit1 = ResidualConvUnit(features) + self.resConfUnit2 = ResidualConvUnit(features) + + def forward(self, *xs): + """Forward pass.""" + output = xs[0] + + if len(xs) == 2: + output = output + self.resConfUnit1(xs[1]) + + output = self.resConfUnit2(output) + + if self.upsample: + output = F.interpolate(output, scale_factor=2, mode="bilinear", align_corners=False) + + return output + + +class _DenseLayer(nn.Module): + def __init__(self, num_input_features, growth_rate, bn_size, drop_rate, memory_efficient): + super().__init__() + self.norm1 = nn.BatchNorm2d(num_input_features) + self.relu1 = nn.ReLU(inplace=True) + self.conv1 = nn.Conv2d( + num_input_features, bn_size * growth_rate, kernel_size=1, stride=1, bias=False + ) + + self.norm2 = nn.BatchNorm2d(bn_size * growth_rate) + self.relu2 = nn.ReLU(inplace=True) + self.conv2 = nn.Conv2d( + bn_size * growth_rate, growth_rate, kernel_size=3, stride=1, padding=1, bias=False + ) + + self.drop_rate = float(drop_rate) + self.memory_efficient = memory_efficient + + def bn_function(self, inputs): + concated_features = torch.cat(inputs, 1) + return self.conv1(self.relu1(self.norm1(concated_features))) + + def any_requires_grad(self, inp): + return any(tensor.requires_grad for tensor in inp) + + @torch.jit.unused # noqa: T484 + def call_checkpoint_bottleneck(self, inp): + def closure(*inputs): + return self.bn_function(inputs) + + return cp.checkpoint(closure, *inp) + + @torch.jit._overload_method # noqa: F811 + def forward(self, inp) -> Tensor: # noqa: F811 + pass + + @torch.jit._overload_method # noqa: F811 + def forward(self, inp): # noqa: F811 + pass + + # torchscript does not yet support *args, so we overload method + # allowing it to take either a List[Tensor] or single Tensor + def forward(self, inp): # noqa: F811 + prev_features = [inp] if isinstance(inp, Tensor) else inp + if self.memory_efficient and self.any_requires_grad(prev_features): + if torch.jit.is_scripting(): + raise Exception("Memory Efficient not supported in JIT") + + bottleneck_output = self.call_checkpoint_bottleneck(prev_features) + else: + bottleneck_output = self.bn_function(prev_features) + + new_features = self.conv2(self.relu2(self.norm2(bottleneck_output))) + if self.drop_rate > 0: + new_features = F.dropout(new_features, p=self.drop_rate, training=self.training) + return new_features + + +class _DenseBlock(nn.ModuleDict): + _version = 2 + + def __init__( + self, + num_layers, + num_input_features, + bn_size, + growth_rate, + drop_rate, + memory_efficient=False, + ): + super().__init__() + for i in range(num_layers): + layer = _DenseLayer( + num_input_features + i * growth_rate, + growth_rate=growth_rate, + bn_size=bn_size, + drop_rate=drop_rate, + memory_efficient=memory_efficient, + ) + self.add_module("denselayer%d" % (i + 1), layer) + + def forward(self, init_features): + features = [init_features] + for name, layer in self.items(): + new_features = layer(features) + features.append(new_features) + return torch.cat(features, 1) + + +class _Transition(nn.Sequential): + def __init__(self, num_input_features, 
num_output_features): + super().__init__() + self.norm = nn.BatchNorm2d(num_input_features) + self.relu = nn.ReLU(inplace=True) + self.conv = nn.Conv2d( + num_input_features, num_output_features, kernel_size=1, stride=1, bias=False + ) + self.pool = nn.AvgPool2d(kernel_size=2, stride=2) diff --git a/siclib/models/utils/perspective_encoding.py b/siclib/models/utils/perspective_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..1890fbd4caa5508f5768cc9111841f58d61e451d --- /dev/null +++ b/siclib/models/utils/perspective_encoding.py @@ -0,0 +1,82 @@ +"""Perspective field utilities. + +Adapted from https://github.com/jinlinyi/PerspectiveFields +""" + +import torch + +from siclib.utils.conversions import deg2rad, rad2deg + + +def encode_up_bin(vector_field: torch.Tensor, num_bin: int) -> torch.Tensor: + """Encode vector field into classification bins. + + Args: + vector_field (torch.Tensor): gravity field of shape (2, h, w), with channel 0 cos(theta) and + 1 sin(theta) + num_bin (int): number of classification bins + + Returns: + torch.Tensor: encoded bin indices of shape (1, h, w) + """ + angle = ( + torch.atan2(vector_field[1, :, :], vector_field[0, :, :]) / torch.pi * 180 + 180 + ) % 360 # [0,360) + angle_bin = torch.round(torch.div(angle, (360 / (num_bin - 1)))).long() + angle_bin[angle_bin == num_bin - 1] = 0 + invalid = (vector_field == 0).sum(0) == vector_field.size(0) + angle_bin[invalid] = num_bin - 1 + return deg2rad(angle_bin.type(torch.LongTensor)) + + +def decode_up_bin(angle_bin: torch.Tensor, num_bin: int) -> torch.Tensor: + """Decode classification bins into vector field. + + Args: + angle_bin (torch.Tensor): bin indices of shape (1, h, w) + num_bin (int): number of classification bins + + Returns: + torch.Tensor: decoded vector field of shape (2, h, w) + """ + angle = (angle_bin * (360 / (num_bin - 1)) - 180) / 180 * torch.pi + cos = torch.cos(angle) + sin = torch.sin(angle) + vector_field = torch.stack((cos, sin), dim=1) + invalid = angle_bin == num_bin - 1 + invalid = invalid.unsqueeze(1).repeat(1, 2, 1, 1) + vector_field[invalid] = 0 + return vector_field + + +def encode_bin_latitude(latimap: torch.Tensor, num_classes: int) -> torch.Tensor: + """Encode latitude map into classification bins. + + Args: + latimap (torch.Tensor): latitude map of shape (h, w) with values in [-90, 90] + num_classes (int): number of classes + + Returns: + torch.Tensor: encoded latitude bin indices + """ + boundaries = torch.arange(-90, 90, 180 / num_classes)[1:] + binmap = torch.bucketize(rad2deg(latimap), boundaries) + return binmap.type(torch.LongTensor) + + +def decode_bin_latitude(binmap: torch.Tensor, num_classes: int) -> torch.Tensor: + """Decode classification bins to latitude map. 
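+
+    For example, with num_classes=90 the bin size is 2 degrees and the bin centers are
+    -89, -87, ..., 89 degrees, so bin index 0 decodes to -89 degrees (returned in radians).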
+ + Args: + binmap (torch.Tensor): encoded classification bins + num_classes (int): number of classes + + Returns: + torch.Tensor: latitude map of shape (h, w) + """ + bin_size = 180 / num_classes + bin_centers = torch.arange(-90, 90, bin_size) + bin_size / 2 + bin_centers = bin_centers.to(binmap.device) + latimap = bin_centers[binmap] + + return deg2rad(latimap) diff --git a/siclib/pose_estimation.py b/siclib/pose_estimation.py new file mode 100644 index 0000000000000000000000000000000000000000..13854e940712f9a6fb48f6a3965ac97e176f3137 --- /dev/null +++ b/siclib/pose_estimation.py @@ -0,0 +1,148 @@ +import pickle +from pathlib import Path + +import numpy as np +import poselib +import pycolmap + +from siclib.models.extractor import VP + +from .models.extractor import GeoCalib +from .utils.image import load_image + +# flake8: noqa +# mypy: ignore-errors + + +class AbsolutePoseEstimator: + default_opts = { + "ransac": "poselib_gravity", # pycolmap, poselib, poselib_gravity + "refinement": "pycolmap_gravity", # pycolmap, pycolmap_gravity, none + "gravity_weight": 50000, + "max_reproj_error": 48.0, + "loss_function_scale": 1.0, + "use_vp": False, + "max_uncertainty": 10.0 / 180.0 * 3.1415, # radians + "cache_path": "../../outputs/inloc/calib.pickle", + } + + def __init__(self, pose_opts=None): + pose_opts = {} if pose_opts is None else pose_opts + self.opts = {**self.default_opts, **pose_opts} + self.device = "cuda" + + if self.opts["use_vp"]: + self.calib = VP().to(self.device) + self.cache_path = str(self.opts["cache_path"]).replace(".pickle", "_vp.pickle") + else: + self.calib = GeoCalib().to(self.device) + self.cache_path = str(self.opts["cache_path"]) + + # self.read_cache() + self.cache = {} + + def read_cache(self): + print(f"Reading cache from {self.cache_path} ({Path(self.cache_path).exists()})") + if not Path(self.cache_path).exists(): + self.cache = {} + return + with open(self.cache_path, "rb") as handle: + self.cache = pickle.load(handle) + + def write_cache(self): + with open(self.cache_path, "wb") as handle: + pickle.dump(self.cache, handle, protocol=pickle.HIGHEST_PROTOCOL) + + def __call__(self, query_path, p2d, p3d, camera_dict): + focal_length = pycolmap.Camera(camera_dict).mean_focal_length() + + if query_path in self.cache: + calib = self.cache[query_path] + else: + calib = self.calib.calibrate( + load_image(query_path).to(self.device), priors={"f": focal_length} + ) + calib = {k: v[0].detach().cpu().numpy() for k, v in calib.items()} + self.cache[query_path] = calib + # self.write_cache() + + if self.opts["ransac"] == "pycolmap": + ret = pycolmap.absolute_pose_estimation( + p2d, p3d, camera_dict, self.opts["max_reproj_error"] # , do_refine=False + ) + elif self.opts["ransac"] == "poselib": + M, ret = poselib.estimate_absolute_pose( + p2d, + p3d, + camera_dict, + ransac_opt={"max_reproj_error": self.opts["max_reproj_error"]}, + ) + ret["success"] = M is not None + ret["qvec"] = M.q + ret["tvec"] = M.t + elif self.opts["ransac"] == "poselib_gravity": + g_q = calib["gravity"].vec3d + g_qu = calib.get("gravity_uncertainty", self.opts["max_uncertainty"]) + M, ret = poselib.estimate_absolute_pose_gravity( + p2d, + p3d, + camera_dict, + g_q, + g_qu * 2 * 180 / 3.1415, # convert to scalar + ransac_opt={"max_reproj_error": self.opts["max_reproj_error"]}, + ) + ret["success"] = M is not None + ret["qvec"] = M.q + ret["tvec"] = M.t + else: + raise NotImplementedError(self.opts["ransac"]) + r_opts = { + "refine_focal_length": False, + "refine_extra_params": False, + "print_summary": 
False, + "loss_function_scale": self.opts["loss_function_scale"], + } + if self.opts["refinement"] == "pycolmap_gravity": + g_q = calib["gravity"].vec3d + g_qu = calib.get("gravity_uncertainty", self.opts["max_uncertainty"]) + if g_qu <= self.opts["max_uncertainty"]: + g_gt = np.array([0, 0, 1]) # world frame + ret_ref = pycolmap.pose_refinement_gravity( + ret["tvec"], + ret["qvec"], + p2d, + p3d, + ret["inliers"], + camera_dict, + g_q, + g_gt, + self.opts["gravity_weight"], + r_opts, + ) + else: + ret_ref = pycolmap.pose_refinement( + ret["tvec"], + ret["qvec"], + p2d, + p3d, + ret["inliers"], + camera_dict, + r_opts, + ) + elif self.opts["refinement"] == "pycolmap": + ret_ref = pycolmap.pose_refinement( + ret["tvec"], + ret["qvec"], + p2d, + p3d, + ret["inliers"], + camera_dict, + r_opts, + ) + elif self.opts["refinement"] == "none": + ret_ref = {} + else: + raise NotImplementedError(self.opts["refinement"]) + ret = {**ret, **ret_ref} + ret["camera_dict"] = camera_dict + return ret, calib diff --git a/siclib/pyproject.toml b/siclib/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..7ce9891101e33ce630d1f25bc130f9066bc64fe3 --- /dev/null +++ b/siclib/pyproject.toml @@ -0,0 +1,47 @@ +[build-system] +requires = ["setuptools", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "siclib" +version = "1.0" +description = "Training library for GeoCalib: Learning Single-image Calibration with Geometric Optimization" +authors = [ + { name = "Alexander Veicht" }, + { name = "Paul-Edouard Sarlin" }, + { name = "Philipp Lindenberger" }, +] +requires-python = ">=3.9" +license = { file = "LICENSE" } +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", +] +urls = { Repository = "https://github.com/cvg/GeoCalib" } + +dynamic = ["dependencies"] + +[project.optional-dependencies] +dev = ["black==23.9.1", "flake8", "isort==5.12.0"] + +[tool.setuptools.packages.find] +where = ["."] + +[tool.setuptools.dynamic] +dependencies = { file = ["requirements.txt"] } + +[tool.black] +line-length = 100 +exclude = "(venv/|docs/|third_party/)" + +[tool.isort] +profile = "black" +line_length = 100 +atomic = true + +[tool.flake8] +max-line-length = 100 +docstring-convention = "google" +ignore = ["E203", "W503", "E402"] +exclude = [".git", "__pycache__", "venv", "docs", "third_party", "scripts"] diff --git a/siclib/requirements.txt b/siclib/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..3b7748c50891d5ea018d02a7a12853780f06cde6 --- /dev/null +++ b/siclib/requirements.txt @@ -0,0 +1,14 @@ +torch +torchvision +opencv-python +kornia +matplotlib + +omegaconf +albumentations +h5py +hydra-core +pandas +tqdm +tensorboard +wandb diff --git a/siclib/settings.py b/siclib/settings.py new file mode 100644 index 0000000000000000000000000000000000000000..3606b983b6ba69d96c4a1e2b14d8b4f0f8c70854 --- /dev/null +++ b/siclib/settings.py @@ -0,0 +1,12 @@ +from pathlib import Path + +# flake8: noqa +# mypy: ignore-errors +try: + from settings import DATA_PATH, EVAL_PATH, TRAINING_PATH +except ModuleNotFoundError: + # @TODO: Add a way to patch paths + root = Path(__file__).parent.parent # top-level directory + DATA_PATH = root / "data/" # datasets and pretrained weights + TRAINING_PATH = root / "outputs/training/" # training checkpoints + EVAL_PATH = root / "outputs/results/" # evaluation results diff --git a/siclib/train.py b/siclib/train.py new file mode 
100644 index 0000000000000000000000000000000000000000..ce1ee134c56ee33ee9b13a452987dedc8b76a912 --- /dev/null +++ b/siclib/train.py @@ -0,0 +1,750 @@ +""" +A generic training script that works with any model and dataset. + +Author: Paul-Edouard Sarlin (skydes) +""" + +# Filter annoying warnings +import warnings + +warnings.simplefilter("ignore", UserWarning) + +import argparse +import copy +import re +import shutil +import signal +from collections import defaultdict +from pathlib import Path +from pydoc import locate + +import numpy as np +import torch +from hydra import compose, initialize +from omegaconf import OmegaConf +from torch.cuda.amp import GradScaler, autocast +from tqdm import tqdm + +from siclib import __module_name__, logger +from siclib.datasets import get_dataset +from siclib.eval import run_benchmark +from siclib.models import get_model +from siclib.settings import EVAL_PATH, TRAINING_PATH +from siclib.utils.experiments import get_best_checkpoint, get_last_checkpoint, save_experiment +from siclib.utils.stdout_capturing import capture_outputs +from siclib.utils.summary_writer import SummaryWriter +from siclib.utils.tensor import batch_to_device +from siclib.utils.tools import ( + AverageMetric, + MedianMetric, + PRMetric, + RecallMetric, + fork_rng, + get_device, + set_seed, +) + +# flake8: noqa +# mypy: ignore-errors + + +# TODO: Fix pbar pollution in logs +# TODO: add plotting during evaluation + +default_train_conf = { + "seed": "???", # training seed + "epochs": 1, # number of epochs + "num_steps": None, # number of steps, overwrites epochs + "optimizer": "adam", # name of optimizer in [adam, sgd, rmsprop] + "opt_regexp": None, # regular expression to filter parameters to optimize + "optimizer_options": {}, # optional arguments passed to the optimizer + "lr": 0.001, # learning rate + "lr_schedule": { + "type": None, + "start": 0, + "exp_div_10": 0, + "on_epoch": False, + "factor": 1.0, + }, + "lr_scaling": [(100, ["dampingnet.const"])], + "eval_every_iter": 1000, # interval for evaluation on the validation set + "save_every_iter": 5000, # interval for saving the current checkpoint + "log_every_iter": 200, # interval for logging the loss to the console + "log_grad_every_iter": None, # interval for logging gradient hists + "writer": "tensorboard", # tensorboard or wandb + "test_every_epoch": 1, # interval for evaluation on the test benchmarks + "keep_last_checkpoints": 10, # keep only the last X checkpoints + "load_experiment": None, # initialize the model from a previous experiment + "median_metrics": [], # add the median of some metrics + "recall_metrics": {}, # add the recall of some metrics + "pr_metrics": {}, # add pr curves, set labels/predictions/mask keys + "best_key": "loss/total", # key to use to select the best checkpoint + "dataset_callback_fn": None, # data func called at the start of each epoch + "dataset_callback_on_val": False, # call data func on val data? + "clip_grad": None, + "pr_curves": {}, + "plot": None, + "submodules": [], +} +default_train_conf = OmegaConf.create(default_train_conf) + + +def get_lr_scheduler(optimizer, conf): + """Get lr scheduler specified by conf.""" + # logger.info(f"Using lr scheduler with conf: {conf}") + if conf.type not in ["factor", "exp", None]: + if hasattr(conf.options, "schedulers"): + # Add option to chain multiple schedulers together + # This is useful for e.g. 
warmup, then cosine decay + """Example: { + "type": "SequentialLR", + "options": { + "milestones": [1_000], + "schedulers": [ + {"type": "LinearLR", "options": {"total_iters": 10, "start_factor": 0.001}}, + {"type": "MultiStepLR", "options": {"milestones": [40, 60], "gamma": 0.1}}, + ], + } + } + """ + schedulers = [] + for scheduler_conf in conf.options.schedulers: + scheduler = get_lr_scheduler(optimizer, scheduler_conf) + schedulers.append(scheduler) + + options = {k: v for k, v in conf.options.items() if k != "schedulers"} + return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, schedulers, **options) + + return getattr(torch.optim.lr_scheduler, conf.type)(optimizer, **conf.options) + + # backward compatibility + def lr_fn(it): # noqa: E306 + if conf.type is None: + return 1 + if conf.type == "factor": + return 1.0 if it < conf.start else conf.factor + if conf.type == "exp": + gam = 10 ** (-1 / conf.exp_div_10) + return 1.0 if it < conf.start else gam + else: + raise ValueError(conf.type) + + return torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_fn) + + +@torch.no_grad() +def do_evaluation(model, loader, device, loss_fn, conf, pbar=True): + model.eval() + results = {} + recall_results = {} + pr_metrics = defaultdict(PRMetric) + figures = [] + if conf.plot is not None: + n, plot_fn = conf.plot + plot_ids = np.random.choice(len(loader), min(len(loader), n), replace=False) + for i, data in enumerate(tqdm(loader, desc="Evaluation", ascii=True, disable=not pbar)): + data = batch_to_device(data, device, non_blocking=True) + with torch.no_grad(): + pred = model(data) + losses, metrics = loss_fn(pred, data) + if conf.plot is not None and i in plot_ids: + figures.append(locate(plot_fn)(pred, data)) + # add PR curves + for k, v in conf.pr_curves.items(): + pr_metrics[k].update( + pred[v["labels"]], + pred[v["predictions"]], + mask=pred[v["mask"]] if "mask" in v.keys() else None, + ) + del pred, data + + numbers = {**metrics, **{f"loss/{k}": v for k, v in losses.items()}} + for k, v in numbers.items(): + if k not in results: + results[k] = AverageMetric() + if k in conf.median_metrics: + results[f"{k}_median"] = MedianMetric() + + if k not in recall_results and k in conf.recall_metrics.keys(): + ths = conf.recall_metrics[k] + recall_results[k] = RecallMetric(ths) + + results[k].update(v) + if k in conf.median_metrics: + results[f"{k}_median"].update(v) + if k in conf.recall_metrics.keys(): + recall_results[k].update(v) + + del numbers + + results = {k: results[k].compute() for k in results} + + for k, v in recall_results.items(): + for th, recall in zip(conf.recall_metrics[k], v.compute()): + results[f"{k}_recall@{th}"] = recall + + return results, {k: v.compute() for k, v in pr_metrics.items()}, figures + + +def filter_parameters(params, regexp): + """Filter trainable parameters based on regular expressions.""" + + # Examples of regexp: + # '.*(weight|bias)$' + # 'cnn\.(enc0|enc1).*bias' + def filter_fn(x): + n, p = x + match = re.search(regexp, n) + if not match: + p.requires_grad = False + return match + + params = list(filter(filter_fn, params)) + assert len(params) > 0, regexp + logger.info("Selected parameters:\n" + "\n".join(n for n, p in params)) + return params + + +def pack_lr_parameters(params, base_lr, lr_scaling): + """Pack each group of parameters with the respective scaled learning rate.""" + filters, scales = tuple(zip(*[(n, s) for s, names in lr_scaling for n in names])) + scale2params = defaultdict(list) + for n, p in params: + scale = 1 + is_match = [f in n for 
f in filters] + if any(is_match): + scale = scales[is_match.index(True)] + scale2params[scale].append((n, p)) + logger.info( + "Parameters with scaled learning rate:\n%s", + {s: [n for n, _ in ps] for s, ps in scale2params.items() if s != 1}, + ) + return [ + {"lr": scale * base_lr, "params": [p for _, p in ps]} for scale, ps in scale2params.items() + ] + + +def training(rank, conf, output_dir, args): + if args.restore: + logger.info(f"Restoring from previous training of {args.experiment}") + try: + init_cp = get_last_checkpoint(args.experiment, allow_interrupted=False) + except AssertionError: + init_cp = get_best_checkpoint(args.experiment) + logger.info(f"Restoring from checkpoint {init_cp.name}") + init_cp = torch.load(str(init_cp), map_location="cpu") + conf = OmegaConf.merge(OmegaConf.create(init_cp["conf"]), conf) + conf.train = OmegaConf.merge(default_train_conf, conf.train) + epoch = init_cp["epoch"] + 1 + + # get the best loss or eval metric from the previous best checkpoint + best_cp = get_best_checkpoint(args.experiment) + best_cp = torch.load(str(best_cp), map_location="cpu") + best_eval = best_cp["eval"][conf.train.best_key] + del best_cp + else: + # we start a new, fresh training + conf.train = OmegaConf.merge(default_train_conf, conf.train) + epoch = 0 + best_eval = float("inf") + if conf.train.load_experiment: + logger.info(f"Will fine-tune from weights of {conf.train.load_experiment}") + # the user has to make sure that the weights are compatible + try: + init_cp = get_last_checkpoint(conf.train.load_experiment) + except AssertionError: + init_cp = get_best_checkpoint(conf.train.load_experiment) + # init_cp = get_last_checkpoint(conf.train.load_experiment) + init_cp = torch.load(str(init_cp), map_location="cpu") + # load the model config of the old setup, and overwrite with current config + conf.model = OmegaConf.merge(OmegaConf.create(init_cp["conf"]).model, conf.model) + print(conf.model) + else: + init_cp = None + + OmegaConf.set_struct(conf, True) # prevent access to unknown entries + set_seed(conf.train.seed) + if rank == 0: + writer = SummaryWriter(conf, args, str(output_dir)) + + data_conf = copy.deepcopy(conf.data) + if args.distributed: + logger.info(f"Training in distributed mode with {args.n_gpus} GPUs") + assert torch.cuda.is_available() + device = rank + torch.distributed.init_process_group( + backend="nccl", + world_size=args.n_gpus, + rank=device, + init_method="file://" + str(args.lock_file), + ) + torch.cuda.set_device(device) + + # adjust batch size and num of workers since these are per GPU + if "batch_size" in data_conf: + data_conf.batch_size = int(data_conf.batch_size / args.n_gpus) + if "train_batch_size" in data_conf: + data_conf.train_batch_size = int(data_conf.train_batch_size / args.n_gpus) + if "num_workers" in data_conf: + data_conf.num_workers = int((data_conf.num_workers + args.n_gpus - 1) / args.n_gpus) + else: + device = get_device() + logger.info(f"Using device {device}") + + dataset = get_dataset(data_conf.name)(data_conf) + + # Optionally load a different validation dataset than the training one + val_data_conf = conf.get("data_val", None) + if val_data_conf is None: + val_dataset = dataset + else: + val_dataset = get_dataset(val_data_conf.name)(val_data_conf) + + # @TODO: add test data loader + + if args.overfit: + # we train and eval with the same single training batch + logger.info("Data in overfitting mode") + assert not args.distributed + train_loader = dataset.get_overfit_loader("train") + val_loader = 
val_dataset.get_overfit_loader("val") + else: + train_loader = dataset.get_data_loader("train", distributed=args.distributed) + val_loader = val_dataset.get_data_loader("val") + if rank == 0: + logger.info(f"Training loader has {len(train_loader)} batches") + logger.info(f"Validation loader has {len(val_loader)} batches") + + # interrupts are caught and delayed for graceful termination + def sigint_handler(signal, frame): + logger.info("Caught keyboard interrupt signal, will terminate") + nonlocal stop + if stop: + raise KeyboardInterrupt + stop = True + + stop = False + signal.signal(signal.SIGINT, sigint_handler) + + model = get_model(conf.model.name)(conf.model).to(device) + if args.compile: + model = torch.compile(model, mode=args.compile) + loss_fn = model.loss + if init_cp is not None: + model.load_state_dict(init_cp["model"], strict=False) + if args.distributed: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device]) + if rank == 0 and args.print_arch: + logger.info(f"Model: \n{model}") + + torch.backends.cudnn.benchmark = True + if args.detect_anomaly: + logger.info("Enabling anomaly detection") + torch.autograd.set_detect_anomaly(True) + + optimizer_fn = { + "sgd": torch.optim.SGD, + "adam": torch.optim.Adam, + "adamw": torch.optim.AdamW, + "rmsprop": torch.optim.RMSprop, + }[conf.train.optimizer] + params = [(n, p) for n, p in model.named_parameters() if p.requires_grad] + if conf.train.opt_regexp: + params = filter_parameters(params, conf.train.opt_regexp) + all_params = [p for n, p in params] + logger.info(f"Num parameters: {sum(p.numel() for p in all_params)}") + + lr_params = pack_lr_parameters(params, conf.train.lr, conf.train.lr_scaling) + optimizer = optimizer_fn(lr_params, lr=conf.train.lr, **conf.train.optimizer_options) + scaler = GradScaler(enabled=args.mixed_precision is not None) + logger.info(f"Training with mixed_precision={args.mixed_precision}") + + mp_dtype = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + None: torch.float32, # we disable it anyway + }[args.mixed_precision] + + results = None # fix bug with it saving + + lr_scheduler = get_lr_scheduler(optimizer=optimizer, conf=conf.train.lr_schedule) + logger.info(f"Using lr scheduler of type {type(lr_scheduler)}") + + if args.restore: + optimizer.load_state_dict(init_cp["optimizer"]) + if "lr_scheduler" in init_cp: + lr_scheduler.load_state_dict(init_cp["lr_scheduler"]) + + if rank == 0: + logger.info("Starting training with configuration:\n%s", OmegaConf.to_yaml(conf)) + losses_ = None + + def trace_handler(p): + # torch.profiler.tensorboard_trace_handler(str(output_dir)) + output = p.key_averages().table(sort_by="self_cuda_time_total", row_limit=10) + print(output) + p.export_chrome_trace("trace_" + str(p.step_num) + ".json") + p.export_stacks("/tmp/profiler_stacks.txt", "self_cuda_time_total") + + if args.profile: + prof = torch.profiler.profile( + schedule=torch.profiler.schedule(wait=1, warmup=1, active=1, repeat=1), + on_trace_ready=torch.profiler.tensorboard_trace_handler(str(output_dir)), + record_shapes=True, + profile_memory=True, + with_stack=True, + ) + prof.__enter__() + + if conf.train.log_grad_every_iter: + writer.watch(model, log_freq=conf.train.log_grad_every_iter) + + if conf.train.num_steps is not None: + conf.train.epochs = conf.train.num_steps // len(train_loader) + 1 + conf.train.epochs = conf.train.epochs // (args.n_gpus if args.distributed else 1) + logger.info(f"Setting epochs to 
{conf.train.epochs} to match num_steps.") + + while epoch < conf.train.epochs and not stop: + tot_it = (len(train_loader) * epoch) * (args.n_gpus if args.distributed else 1) + tot_n_samples = tot_it * train_loader.batch_size + + if conf.train.num_steps is not None and tot_it > conf.train.num_steps: + logger.info(f"Reached max number of steps {conf.train.num_steps}") + stop = True + + if rank == 0: + logger.info(f"Starting epoch {epoch}") + + # we first run the eval + if ( + rank == 0 + and epoch % conf.train.test_every_epoch == 0 + and (epoch > 0 or not args.no_test_0) + ): + for bname, eval_conf in conf.get("benchmarks", {}).items(): + logger.info(f"Running eval on {bname}") + s, f, r = run_benchmark( + bname, + eval_conf, + EVAL_PATH / bname / args.experiment / str(epoch), + model.eval(), + ) + for metric_name, value in s.items(): + writer.add_scalar(f"test/{bname}/{metric_name}", value, step=tot_n_samples) + for fig_name, fig in f.items(): + writer.add_figure(f"figures/{bname}/{fig_name}", fig, step=tot_n_samples) + + str_results = [f"{k} {v:.3E}" for k, v in s.items() if isinstance(v, float)] + if rank == 0: + logger.info(f'[Test {bname}] {{{", ".join(str_results)}}}') + + # set the seed + set_seed(conf.train.seed + epoch) + + # update learning rate + if conf.train.lr_schedule.on_epoch and epoch > 0: + old_lr = optimizer.param_groups[0]["lr"] + lr_scheduler.step(epoch) + logger.info(f'lr changed from {old_lr} to {optimizer.param_groups[0]["lr"]}') + + if args.distributed: + train_loader.sampler.set_epoch(epoch) + if epoch > 0 and conf.train.dataset_callback_fn and not args.overfit: + loaders = [train_loader] + if conf.train.dataset_callback_on_val: + loaders += [val_loader] + for loader in loaders: + if isinstance(loader.dataset, torch.utils.data.Subset): + getattr(loader.dataset.dataset, conf.train.dataset_callback_fn)( + conf.train.seed + epoch + ) + else: + getattr(loader.dataset, conf.train.dataset_callback_fn)(conf.train.seed + epoch) + for it, data in enumerate(train_loader): + # logger.info(f"Starting iteration {it} - epoch {epoch} - rank {rank}") + tot_it = (len(train_loader) * epoch + it) * (args.n_gpus if args.distributed else 1) + tot_n_samples = tot_it + if not args.log_it: + # We normalize the x-axis of tensorboard to num samples! 
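+ # Illustrative arithmetic (hypothetical numbers, not from any config here): with 4 GPUs,
+ # a per-GPU batch size of 8 and 100 batches per epoch, iteration 10 of epoch 2 gives
+ # tot_it = (100 * 2 + 10) * 4 = 840 and, after the line below, tot_n_samples = 840 * 8 = 6720,
+ # i.e. roughly the number of samples seen globally so far (assuming a constant batch size).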
+ tot_n_samples *= train_loader.batch_size + + model.train() + optimizer.zero_grad() + + with autocast(enabled=args.mixed_precision is not None, dtype=mp_dtype): + data = batch_to_device(data, device, non_blocking=False) + pred = model(data) + losses, metrics = loss_fn(pred, data) + loss = torch.mean(losses["total"]) + + # Skip the iteration if any rank encountered a NaN + if loss_has_nan(loss, distributed=args.distributed): + logger.warning(f"Skipping iteration {it} due to NaN (rank {rank})") + del pred, data, loss, losses, metrics + torch.cuda.empty_cache() + continue + + do_backward = loss.requires_grad + if args.distributed: + do_backward = torch.tensor(do_backward).float().to(device) + torch.distributed.all_reduce(do_backward, torch.distributed.ReduceOp.PRODUCT) + do_backward = do_backward > 0 + + if do_backward: + scaler.scale(loss).backward() + if args.detect_anomaly: + # Check for params without any gradient which causes + # problems in distributed training with checkpointing + detected_anomaly = False + for name, param in model.named_parameters(): + if param.grad is None and param.requires_grad: + logger.warning(f"param {name} has no gradient.") + detected_anomaly = True + if detected_anomaly: + raise RuntimeError("Detected anomaly in training.") + + if conf.train.get("clip_grad", None): + scaler.unscale_(optimizer) + try: + torch.nn.utils.clip_grad_norm_( + all_params, + max_norm=conf.train.clip_grad, + error_if_nonfinite=True, + ) + scaler.step(optimizer) + except RuntimeError: + logger.warning("NaN detected in gradient clipping. Skipping iteration.") + scaler.update() + else: + scaler.step(optimizer) + scaler.update() + + if not conf.train.lr_schedule.on_epoch: + [lr_scheduler.step() for _ in range(args.n_gpus if args.distributed else 1)] + else: + if rank == 0: + logger.warning(f"Skip iteration {it} due to detach/nan. 
(rank {rank})") + + if args.profile: + prof.step() + + if it % conf.train.log_every_iter == 0: + train_results = metrics | losses + for k in sorted(train_results.keys()): + if args.distributed: + train_results[k] = train_results[k].sum(-1) + torch.distributed.reduce(train_results[k], dst=0) + train_results[k] /= train_loader.batch_size * args.n_gpus + train_results[k] = torch.mean(train_results[k], -1) + train_results[k] = train_results[k].item() + if rank == 0: + str_losses = [f"{k} {v:.3E}" for k, v in train_results.items()] + logger.info( + "[E {} | it {}] loss {{{}}}".format(epoch, it, ", ".join(str_losses)) + ) + for k, v in train_results.items(): + writer.add_scalar("training/" + k, v, tot_n_samples) + + writer.add_scalar("training/lr", optimizer.param_groups[0]["lr"], tot_n_samples) + writer.add_scalar("training/epoch", epoch, tot_n_samples) + + if ( + conf.train.log_grad_every_iter is not None + and it % conf.train.log_grad_every_iter == 0 + ): + grad_txt = "" + for name, param in model.named_parameters(): + if param.grad is not None and param.requires_grad: + if name.endswith("bias"): + continue + writer.add_histogram(f"grad/{name}", param.grad.detach(), tot_n_samples) + norm = torch.norm(param.grad.detach(), 2) + grad_txt += f"{name} {norm.item():.3f} \n" + writer.add_text(f"grad/summary", grad_txt, tot_n_samples) + del pred, data, loss, losses + + # Run validation + if ( + (it % conf.train.eval_every_iter == 0 and (it > 0 or epoch == -int(args.no_eval_0))) + or stop + or it == (len(train_loader) - 1) + ): + with fork_rng(seed=conf.train.seed): + results, pr_metrics, figures = do_evaluation( + model, + val_loader, + device, + loss_fn, + conf.train, + pbar=(rank == -1), + ) + + if rank == 0: + str_results = [ + f"{k} {v:.3E}" for k, v in results.items() if isinstance(v, float) + ] + logger.info(f'[Validation] {{{", ".join(str_results)}}}') + for k, v in results.items(): + if isinstance(v, dict): + writer.add_scalars(f"figure/val/{k}", v, tot_n_samples) + else: + writer.add_scalar("val/" + k, v, tot_n_samples) + for k, v in pr_metrics.items(): + writer.add_pr_curve("val/" + k, *v, tot_n_samples) + # @TODO: optional always save checkpoint + if results[conf.train.best_key] < best_eval: + best_eval = results[conf.train.best_key] + save_experiment( + model, + optimizer, + lr_scheduler, + conf, + losses_, + results, + best_eval, + epoch, + tot_it, + output_dir, + stop, + args.distributed, + cp_name="checkpoint_best.tar", + ) + logger.info(f"New best val: {conf.train.best_key}={best_eval}") + if len(figures) > 0: + for i, figs in enumerate(figures): + for name, fig in figs.items(): + writer.add_figure(f"figures/{i}_{name}", fig, tot_n_samples) + torch.cuda.empty_cache() # should be cleared at the first iter + + if (tot_it % conf.train.save_every_iter == 0 and tot_it > 0) and rank == 0: + if results is None: + results, _, _ = do_evaluation( + model, + val_loader, + device, + loss_fn, + conf.train, + pbar=(rank == -1), + ) + best_eval = results[conf.train.best_key] + best_eval = save_experiment( + model, + optimizer, + lr_scheduler, + conf, + losses_, + results, + best_eval, + epoch, + tot_it, + output_dir, + stop, + args.distributed, + ) + + if stop: + break + + if rank == 0: + best_eval = save_experiment( + model, + optimizer, + lr_scheduler, + conf, + losses_, + results, + best_eval, + epoch, + tot_it, + output_dir=output_dir, + stop=stop, + distributed=args.distributed, + ) + + epoch += 1 + + logger.info(f"Finished training on process {rank}.") + if rank == 0: + writer.close() + + +def 
loss_has_nan(loss: torch.Tensor, distributed: bool) -> bool: + """Check if any rank has encountered a NaN loss.""" + has_nan = torch.tensor([torch.isnan(loss).any().float()]).to(loss.device) + + # Synchronize the has_nan variable across all ranks + if distributed: + torch.distributed.all_reduce(has_nan, op=torch.distributed.ReduceOp.MAX) + + return has_nan.item() > 0.5 + + +def main_worker(rank, conf, output_dir, args): + if rank == 0: + with capture_outputs(output_dir / "log.txt"): + training(rank, conf, output_dir, args) + else: + training(rank, conf, output_dir, args) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("experiment", type=str) + parser.add_argument("--conf", type=str) + parser.add_argument( + "--mixed_precision", + "--mp", + default=None, + type=str, + choices=["float16", "bfloat16"], + ) + parser.add_argument( + "--compile", + default=None, + type=str, + choices=["default", "reduce-overhead", "max-autotune"], + ) + parser.add_argument("--overfit", action="store_true") + parser.add_argument("--restore", action="store_true") + parser.add_argument("--distributed", action="store_true") + parser.add_argument("--profile", action="store_true") + parser.add_argument("--print_arch", "--pa", action="store_true") + parser.add_argument("--detect_anomaly", "--da", action="store_true") + parser.add_argument("--log_it", "--log_it", action="store_true") + parser.add_argument("--no_eval_0", action="store_true") + parser.add_argument("--no_test_0", action="store_true") + parser.add_argument("dotlist", nargs="*") + args = parser.parse_intermixed_args() + + logger.info(f"Starting experiment {args.experiment}") + output_dir = Path(TRAINING_PATH, args.experiment) + output_dir.mkdir(exist_ok=True, parents=True) + + conf = OmegaConf.from_cli(args.dotlist) + + if args.conf: + initialize(version_base=None, config_path="configs") + conf = compose(config_name=args.conf, overrides=args.dotlist) + elif args.restore: + restore_conf = OmegaConf.load(output_dir / "config.yaml") + conf = OmegaConf.merge(restore_conf, conf) + + if not args.restore: + if conf.train.seed is None: + conf.train.seed = torch.initial_seed() & (2**32 - 1) + OmegaConf.save(conf, str(output_dir / "config.yaml")) + + # copy geocalib and submodule into output dir + for module in conf.train.submodules + [__module_name__]: + mod_dir = Path(__import__(str(module)).__file__).parent + shutil.copytree(mod_dir, output_dir / module, dirs_exist_ok=True) + + if args.distributed: + args.n_gpus = torch.cuda.device_count() + args.lock_file = output_dir / "distributed_lock" + if args.lock_file.exists(): + args.lock_file.unlink() + torch.multiprocessing.spawn(main_worker, nprocs=args.n_gpus, args=(conf, output_dir, args)) + else: + main_worker(0, conf, output_dir, args) diff --git a/siclib/utils/__init__.py b/siclib/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/siclib/utils/conversions.py b/siclib/utils/conversions.py new file mode 100644 index 0000000000000000000000000000000000000000..cd019d3c58fc76236543cebb985184a671f07146 --- /dev/null +++ b/siclib/utils/conversions.py @@ -0,0 +1,149 @@ +"""Utility functions for conversions between different representations.""" + +from typing import Optional + +import torch + + +def skew_symmetric(v: torch.Tensor) -> torch.Tensor: + """Create a skew-symmetric matrix from a (batched) vector of size (..., 3). + + Args: + (torch.Tensor): Vector of size (..., 3). 
+ + Returns: + (torch.Tensor): Skew-symmetric matrix of size (..., 3, 3). + """ + z = torch.zeros_like(v[..., 0]) + return torch.stack( + [ + z, + -v[..., 2], + v[..., 1], + v[..., 2], + z, + -v[..., 0], + -v[..., 1], + v[..., 0], + z, + ], + dim=-1, + ).reshape(v.shape[:-1] + (3, 3)) + + +def rad2rotmat( + roll: torch.Tensor, pitch: torch.Tensor, yaw: Optional[torch.Tensor] = None +) -> torch.Tensor: + """Convert (batched) roll, pitch, yaw angles (in radians) to rotation matrix. + + Args: + roll (torch.Tensor): Roll angle in radians. + pitch (torch.Tensor): Pitch angle in radians. + yaw (torch.Tensor, optional): Yaw angle in radians. Defaults to None. + + Returns: + torch.Tensor: Rotation matrix of shape (..., 3, 3). + """ + if yaw is None: + yaw = roll.new_zeros(roll.shape) + + Rx = pitch.new_zeros(pitch.shape + (3, 3)) + Rx[..., 0, 0] = 1 + Rx[..., 1, 1] = torch.cos(pitch) + Rx[..., 1, 2] = torch.sin(pitch) + Rx[..., 2, 1] = -torch.sin(pitch) + Rx[..., 2, 2] = torch.cos(pitch) + + Ry = yaw.new_zeros(yaw.shape + (3, 3)) + Ry[..., 0, 0] = torch.cos(yaw) + Ry[..., 0, 2] = -torch.sin(yaw) + Ry[..., 1, 1] = 1 + Ry[..., 2, 0] = torch.sin(yaw) + Ry[..., 2, 2] = torch.cos(yaw) + + Rz = roll.new_zeros(roll.shape + (3, 3)) + Rz[..., 0, 0] = torch.cos(roll) + Rz[..., 0, 1] = torch.sin(roll) + Rz[..., 1, 0] = -torch.sin(roll) + Rz[..., 1, 1] = torch.cos(roll) + Rz[..., 2, 2] = 1 + + return Rz @ Rx @ Ry + + +def fov2focal(fov: torch.Tensor, size: torch.Tensor) -> torch.Tensor: + """Compute focal length from (vertical/horizontal) field of view. + + Args: + fov (torch.Tensor): Field of view in radians. + size (torch.Tensor): Image height / width in pixels. + + Returns: + torch.Tensor: Focal length in pixels. + """ + return size / 2 / torch.tan(fov / 2) + + +def focal2fov(focal: torch.Tensor, size: torch.Tensor) -> torch.Tensor: + """Compute (vertical/horizontal) field of view from focal length. + + Args: + focal (torch.Tensor): Focal length in pixels. + size (torch.Tensor): Image height / width in pixels. + + Returns: + torch.Tensor: Field of view in radians. + """ + return 2 * torch.arctan(size / (2 * focal)) + + +def pitch2rho(pitch: torch.Tensor, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor: + """Compute the distance from principal point to the horizon. + + Args: + pitch (torch.Tensor): Pitch angle in radians. + f (torch.Tensor): Focal length in pixels. + h (torch.Tensor): Image height in pixels. + + Returns: + torch.Tensor: Relative distance to the horizon. + """ + return torch.tan(pitch) * f / h + + +def rho2pitch(rho: torch.Tensor, f: torch.Tensor, h: torch.Tensor) -> torch.Tensor: + """Compute the pitch angle from the distance to the horizon. + + Args: + rho (torch.Tensor): Relative distance to the horizon. + f (torch.Tensor): Focal length in pixels. + h (torch.Tensor): Image height in pixels. + + Returns: + torch.Tensor: Pitch angle in radians. + """ + return torch.atan(rho * h / f) + + +def rad2deg(rad: torch.Tensor) -> torch.Tensor: + """Convert radians to degrees. + + Args: + rad (torch.Tensor): Angle in radians. + + Returns: + torch.Tensor: Angle in degrees. + """ + return rad / torch.pi * 180 + + +def deg2rad(deg: torch.Tensor) -> torch.Tensor: + """Convert degrees to radians. + + Args: + deg (torch.Tensor): Angle in degrees. + + Returns: + torch.Tensor: Angle in radians. 
+ """ + return deg / 180 * torch.pi diff --git a/siclib/utils/experiments.py b/siclib/utils/experiments.py new file mode 100644 index 0000000000000000000000000000000000000000..c35a99953d09710a12015173cf70b53d43591ed5 --- /dev/null +++ b/siclib/utils/experiments.py @@ -0,0 +1,135 @@ +""" +A set of utilities to manage and load checkpoints of training experiments. + +Author: Paul-Edouard Sarlin (skydes) +""" + +import logging +import os +import re +import shutil +from pathlib import Path + +import torch +from omegaconf import OmegaConf + +from siclib.models import get_model +from siclib.settings import TRAINING_PATH + +logger = logging.getLogger(__name__) + +# flake8: noqa +# mypy: ignore-errors + + +def list_checkpoints(dir_): + """List all valid checkpoints in a given directory.""" + checkpoints = [] + for p in dir_.glob("checkpoint_*.tar"): + numbers = re.findall(r"(\d+)", p.name) + assert len(numbers) <= 2 + if len(numbers) == 0: + continue + if len(numbers) == 1: + checkpoints.append((int(numbers[0]), p)) + else: + checkpoints.append((int(numbers[1]), p)) + return checkpoints + + +def get_last_checkpoint(exper, allow_interrupted=True): + """Get the last saved checkpoint for a given experiment name.""" + ckpts = list_checkpoints(Path(TRAINING_PATH, exper)) + if not allow_interrupted: + ckpts = [(n, p) for (n, p) in ckpts if "_interrupted" not in p.name] + assert len(ckpts) > 0 + return sorted(ckpts)[-1][1] + + +def get_best_checkpoint(exper): + """Get the checkpoint with the best loss, for a given experiment name.""" + return Path(TRAINING_PATH, exper, "checkpoint_best.tar") + + +def delete_old_checkpoints(dir_, num_keep): + """Delete all but the num_keep last saved checkpoints.""" + ckpts = list_checkpoints(dir_) + ckpts = sorted(ckpts)[::-1] + kept = 0 + for ckpt in ckpts: + if ("_interrupted" in str(ckpt[1]) and kept > 0) or kept >= num_keep: + logger.info(f"Deleting checkpoint {ckpt[1].name}") + ckpt[1].unlink() + else: + kept += 1 + + +def load_experiment(exper, conf=None, get_last=False, ckpt=None): + """Load and return the model of a given experiment.""" + if conf is None: + conf = {} + + exper = Path(exper) + if exper.suffix != ".tar": + ckpt = get_last_checkpoint(exper) if get_last else get_best_checkpoint(exper) + else: + ckpt = exper + logger.info(f"Loading checkpoint {ckpt.name}") + ckpt = torch.load(str(ckpt), map_location="cpu") + + loaded_conf = OmegaConf.create(ckpt["conf"]) + OmegaConf.set_struct(loaded_conf, False) + conf = OmegaConf.merge(loaded_conf.model, OmegaConf.create(conf)) + model = get_model(conf.name)(conf).eval() + + state_dict = ckpt["model"] + + dict_params = set(state_dict.keys()) + model_params = set(map(lambda n: n[0], model.named_parameters())) + diff = model_params - dict_params + if len(diff) > 0: + subs = os.path.commonprefix(list(diff)).rstrip(".") + logger.warning(f"Missing {len(diff)} parameters in {subs}: {diff}") + model.load_state_dict(state_dict, strict=False) + return model + + +def save_experiment( + model, + optimizer, + lr_scheduler, + conf, + losses, + results, + best_eval, + epoch, + iter_i, + output_dir, + stop=False, + distributed=False, + cp_name=None, +): + """Save the current model to a checkpoint + and return the best result so far.""" + state = (model.module if distributed else model).state_dict() + checkpoint = { + "model": state, + "optimizer": optimizer.state_dict(), + "lr_scheduler": lr_scheduler.state_dict(), + "conf": OmegaConf.to_container(conf, resolve=True), + "epoch": epoch, + "losses": losses, + "eval": results, + } + 
if cp_name is None: + cp_name = f"checkpoint_{epoch}_{iter_i}" + ("_interrupted" if stop else "") + ".tar" + logger.info(f"Saving checkpoint {cp_name}") + cp_path = str(output_dir / cp_name) + torch.save(checkpoint, cp_path) + + if cp_name != "checkpoint_best.tar" and results[conf.train.best_key] < best_eval: + best_eval = results[conf.train.best_key] + logger.info(f"New best val: {conf.train.best_key}={best_eval}") + shutil.copy(cp_path, str(output_dir / "checkpoint_best.tar")) + delete_old_checkpoints(output_dir, conf.train.keep_last_checkpoints) + return best_eval diff --git a/siclib/utils/export_predictions.py b/siclib/utils/export_predictions.py new file mode 100644 index 0000000000000000000000000000000000000000..0c66369cdd57f278eb3167f32eefd6e53cd19189 --- /dev/null +++ b/siclib/utils/export_predictions.py @@ -0,0 +1,84 @@ +""" +Export the predictions of a model for a given dataloader (e.g. ImageFolder). +Use a standalone script with `python3 -m geocalib.scipts.export_predictions dir` +or call from another script. +""" + +import logging +from pathlib import Path + +import h5py +import numpy as np +import torch +from tqdm import tqdm + +from siclib.utils.tensor import batch_to_device +from siclib.utils.tools import get_device + +# flake8: noqa +# mypy: ignore-errors + +logger = logging.getLogger(__name__) + + +@torch.no_grad() +def export_predictions( + loader, + model, + output_file, + as_half=False, + keys="*", + callback_fn=None, + optional_keys=None, + verbose=True, +): # sourcery skip: low-code-quality + if optional_keys is None: + optional_keys = [] + + assert keys == "*" or isinstance(keys, (tuple, list)) + Path(output_file).parent.mkdir(exist_ok=True, parents=True) + hfile = h5py.File(str(output_file), "w") + device = get_device() + model = model.to(device).eval() + + if not verbose: + logger.info(f"Exporting predictions to {output_file}") + + for data_ in tqdm(loader, desc="Exporting", total=len(loader), ncols=80, disable=not verbose): + data = batch_to_device(data_, device, non_blocking=True) + pred = model(data) + if callback_fn is not None: + pred = {**callback_fn(pred, data), **pred} + if keys != "*": + if len(set(keys) - set(pred.keys())) > 0: + raise ValueError(f"Missing key {set(keys) - set(pred.keys())}") + pred = {k: v for k, v in pred.items() if k in keys + optional_keys} + + # assert len(pred) > 0, "No predictions found" + + for idx in range(len(data["name"])): + pred_ = {k: v[idx].cpu().numpy() for k, v in pred.items()} + + if as_half: + for k in pred_: + dt = pred_[k].dtype + if (dt == np.float32) and (dt != np.float16): + pred_[k] = pred_[k].astype(np.float16) + try: + name = data["name"][idx] + try: + grp = hfile.create_group(name) + except ValueError as e: + raise ValueError(f"Group already exists {name}") from e + + # grp = hfile.create_group(name) + for k, v in pred_.items(): + grp.create_dataset(k, data=v) + except RuntimeError: + print(f"Failed to export {name}") + continue + + del pred + + hfile.close() + return output_file diff --git a/siclib/utils/image.py b/siclib/utils/image.py new file mode 100644 index 0000000000000000000000000000000000000000..c13d7218b0cf28f3a16b9f43c3634c52beedfa5c --- /dev/null +++ b/siclib/utils/image.py @@ -0,0 +1,167 @@ +"""Image preprocessing utilities.""" + +import collections.abc as collections +from pathlib import Path +from typing import Optional, Tuple + +import cv2 +import kornia +import numpy as np +import torch +import torchvision +from omegaconf import OmegaConf + +from siclib.utils.tensor import 
fit_features_to_multiple + +# mypy: ignore-errors + + +class ImagePreprocessor: + """Preprocess images for calibration.""" + + default_conf = { + "resize": 320, # target edge length, None for no resizing + "edge_divisible_by": None, + "side": "short", + "interpolation": "bilinear", + "align_corners": None, + "antialias": True, + "square_crop": False, + "add_padding_mask": False, + "resize_backend": "kornia", # torchvision, kornia + } + + def __init__(self, conf) -> None: + """Initialize the image preprocessor.""" + super().__init__() + default_conf = OmegaConf.create(self.default_conf) + OmegaConf.set_struct(default_conf, True) + self.conf = OmegaConf.merge(default_conf, conf) + + def __call__(self, img: torch.Tensor, interpolation: Optional[str] = None) -> dict: + """Resize and preprocess an image, return image and resize scale.""" + h, w = img.shape[-2:] + size = h, w + + if self.conf.square_crop: + min_size = min(h, w) + offset = (h - min_size) // 2, (w - min_size) // 2 + img = img[:, offset[0] : offset[0] + min_size, offset[1] : offset[1] + min_size] + size = img.shape[-2:] + + if self.conf.resize is not None: + if interpolation is None: + interpolation = self.conf.interpolation + size = self.get_new_image_size(h, w) + img = self.resize(img, size, interpolation) + + scale = torch.Tensor([img.shape[-1] / w, img.shape[-2] / h]).to(img) + T = np.diag([scale[0].cpu(), scale[1].cpu(), 1]) + + data = { + "scales": scale, + "image_size": np.array(size[::-1]), + "transform": T, + "original_image_size": np.array([w, h]), + } + + if self.conf.edge_divisible_by is not None: + # crop to make the edge divisible by a number + w_, h_ = img.shape[-1], img.shape[-2] + img, _ = fit_features_to_multiple(img, self.conf.edge_divisible_by, crop=True) + crop_pad = torch.Tensor([img.shape[-1] - w_, img.shape[-2] - h_]).to(img) + data["crop_pad"] = crop_pad + data["image_size"] = np.array([img.shape[-1], img.shape[-2]]) + + data["image"] = img + return data + + def resize(self, img: torch.Tensor, size: Tuple[int, int], interpolation: str) -> torch.Tensor: + """Resize an image using the specified backend.""" + if self.conf.resize_backend == "kornia": + return kornia.geometry.transform.resize( + img, + size, + side=self.conf.side, + antialias=self.conf.antialias, + align_corners=self.conf.align_corners, + interpolation=interpolation, + ) + elif self.conf.resize_backend == "torchvision": + return torchvision.transforms.Resize(size, antialias=self.conf.antialias)(img) + else: + raise ValueError(f"{self.conf.resize_backend} not implemented.") + + def load_image(self, image_path: Path) -> dict: + """Load an image from a path and preprocess it.""" + return self(load_image(image_path)) + + def get_new_image_size(self, h: int, w: int) -> Tuple[int, int]: + """Get the new image size after resizing.""" + side = self.conf.side + if isinstance(self.conf.resize, collections.Iterable): + assert len(self.conf.resize) == 2 + return tuple(self.conf.resize) + side_size = self.conf.resize + aspect_ratio = w / h + if side not in ("short", "long", "vert", "horz"): + raise ValueError( + f"side can be one of 'short', 'long', 'vert', and 'horz'. 
Got '{side}'" + ) + return ( + (side_size, int(side_size * aspect_ratio)) + if side == "vert" or (side != "horz" and (side == "short") ^ (aspect_ratio < 1.0)) + else (int(side_size / aspect_ratio), side_size) + ) + + +def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor: + """Normalize the image tensor and reorder the dimensions.""" + if image.ndim == 3: + image = image.transpose((2, 0, 1)) # HxWxC to CxHxW + elif image.ndim == 2: + image = image[None] # add channel axis + else: + raise ValueError(f"Not an image: {image.shape}") + return torch.tensor(image / 255.0, dtype=torch.float) + + +def torch_image_to_numpy(image: torch.Tensor) -> np.ndarray: + """Normalize and reorder the dimensions of an image tensor.""" + if image.ndim == 3: + image = image.permute((1, 2, 0)) # CxHxW to HxWxC + elif image.ndim == 2: + image = image[None] # add channel axis + else: + raise ValueError(f"Not an image: {image.shape}") + return (image.cpu().detach().numpy() * 255).astype(np.uint8) + + +def read_image(path: Path, grayscale: bool = False) -> np.ndarray: + """Read an image from path as RGB or grayscale.""" + if not Path(path).exists(): + raise FileNotFoundError(f"No image at path {path}.") + mode = cv2.IMREAD_GRAYSCALE if grayscale else cv2.IMREAD_COLOR + image = cv2.imread(str(path), mode) + if image is None: + raise IOError(f"Could not read image at {path}.") + if not grayscale: + image = image[..., ::-1] + return image + + +def write_image(img: torch.Tensor, path: Path): + """Write an image tensor to a file.""" + img = torch_image_to_numpy(img) if isinstance(img, torch.Tensor) else img + cv2.imwrite(str(path), img[..., ::-1]) + + +def load_image(path: Path, grayscale: bool = False, return_tensor: bool = True) -> torch.Tensor: + """Load an image from a path and return as a tensor.""" + image = read_image(path, grayscale=grayscale) + if return_tensor: + return numpy_image_to_torch(image) + + assert image.ndim in [2, 3], f"Not an image: {image.shape}" + image = image[None] if image.ndim == 2 else image + return torch.tensor(image.copy(), dtype=torch.uint8) diff --git a/siclib/utils/stdout_capturing.py b/siclib/utils/stdout_capturing.py new file mode 100644 index 0000000000000000000000000000000000000000..19d6701acfeca592a7eecf36edcbc6f614eb6ad0 --- /dev/null +++ b/siclib/utils/stdout_capturing.py @@ -0,0 +1,132 @@ +""" +Based on sacred/stdout_capturing.py in project Sacred +https://github.com/IDSIA/sacred + +Author: Paul-Edouard Sarlin (skydes) +""" + +from __future__ import division, print_function, unicode_literals + +import os +import subprocess +import sys +from contextlib import contextmanager +from threading import Timer + +# flake8: noqa +# mypy: ignore-errors + + +def apply_backspaces_and_linefeeds(text): + """ + Interpret backspaces and linefeeds in text like a terminal would. + Interpret text like a terminal by removing backspace and linefeed + characters and applying them line by line. + If final line ends with a carriage it keeps it to be concatenable with next + output chunk. 
+ """ + orig_lines = text.split("\n") + orig_lines_len = len(orig_lines) + new_lines = [] + for orig_line_idx, orig_line in enumerate(orig_lines): + chars, cursor = [], 0 + orig_line_len = len(orig_line) + for orig_char_idx, orig_char in enumerate(orig_line): + if orig_char == "\r" and ( + orig_char_idx != orig_line_len - 1 or orig_line_idx != orig_lines_len - 1 + ): + cursor = 0 + elif orig_char == "\b": + cursor = max(0, cursor - 1) + else: + if orig_char == "\r": + cursor = len(chars) + if cursor == len(chars): + chars.append(orig_char) + else: + chars[cursor] = orig_char + cursor += 1 + new_lines.append("".join(chars)) + return "\n".join(new_lines) + + +def flush(): + """Try to flush all stdio buffers, both from python and from C.""" + try: + sys.stdout.flush() + sys.stderr.flush() + except (AttributeError, ValueError, IOError): + pass # unsupported + + +# Duplicate stdout and stderr to a file. Inspired by: +# http://eli.thegreenplace.net/2015/redirecting-all-kinds-of-stdout-in-python/ +# http://stackoverflow.com/a/651718/1388435 +# http://stackoverflow.com/a/22434262/1388435 +@contextmanager +def capture_outputs(filename): + """Duplicate stdout and stderr to a file on the file descriptor level.""" + with open(str(filename), "a+") as target: + original_stdout_fd = 1 + original_stderr_fd = 2 + target_fd = target.fileno() + + # Save a copy of the original stdout and stderr file descriptors + saved_stdout_fd = os.dup(original_stdout_fd) + saved_stderr_fd = os.dup(original_stderr_fd) + + tee_stdout = subprocess.Popen( + ["tee", "-a", "-i", "/dev/stderr"], + start_new_session=True, + stdin=subprocess.PIPE, + stderr=target_fd, + stdout=1, + ) + tee_stderr = subprocess.Popen( + ["tee", "-a", "-i", "/dev/stderr"], + start_new_session=True, + stdin=subprocess.PIPE, + stderr=target_fd, + stdout=2, + ) + + flush() + os.dup2(tee_stdout.stdin.fileno(), original_stdout_fd) + os.dup2(tee_stderr.stdin.fileno(), original_stderr_fd) + + try: + yield + finally: + flush() + + # then redirect stdout back to the saved fd + tee_stdout.stdin.close() + tee_stderr.stdin.close() + + # restore original fds + os.dup2(saved_stdout_fd, original_stdout_fd) + os.dup2(saved_stderr_fd, original_stderr_fd) + + # wait for completion of the tee processes with timeout + # implemented using a timer because timeout support is py3 only + def kill_tees(): + tee_stdout.kill() + tee_stderr.kill() + + tee_timer = Timer(1, kill_tees) + try: + tee_timer.start() + tee_stdout.wait() + tee_stderr.wait() + finally: + tee_timer.cancel() + + os.close(saved_stdout_fd) + os.close(saved_stderr_fd) + + # Cleanup log file + with open(str(filename), "r") as target: + text = target.read() + text = apply_backspaces_and_linefeeds(text) + with open(str(filename), "w") as target: + target.write(text) diff --git a/siclib/utils/summary_writer.py b/siclib/utils/summary_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..2542a4df5d27e777f99de6e1d6721b5a990104bc --- /dev/null +++ b/siclib/utils/summary_writer.py @@ -0,0 +1,116 @@ +"""This module implements the writer class for logging to tensorboard or wandb.""" + +import logging +import os +from typing import Any, Dict, Optional + +from omegaconf import DictConfig +from torch import nn +from torch.utils.tensorboard import SummaryWriter as TFSummaryWriter + +from siclib import __module_name__ + +logger = logging.getLogger(__name__) + +try: + import wandb +except ImportError: + logger.debug("Could not import wandb.") + wandb = None + +# mypy: ignore-errors + + +def 
dot_conf(conf: DictConfig) -> Dict[str, Any]: + """Recursively convert a DictConfig to a flat dict with keys joined by dots.""" + d = {} + for k, v in conf.items(): + if isinstance(v, DictConfig): + d |= {f"{k}.{k2}": v2 for k2, v2 in dot_conf(v).items()} + else: + d[k] = v + return d + + +class SummaryWriter: + """Writer class for logging to tensorboard or wandb.""" + + def __init__(self, conf: DictConfig, args: DictConfig, log_dir: str): + """Initialize the writer.""" + self.conf = conf + + if not conf.train.writer: + self.use_wandb = False + self.use_tensorboard = False + return + + self.use_wandb = "wandb" in conf.train.writer + self.use_tensorboard = "tensorboard" in conf.train.writer + + if self.use_wandb and not wandb: + raise ImportError("wandb not installed.") + + if self.use_tensorboard: + self.writer = TFSummaryWriter(log_dir=log_dir) + + if self.use_wandb: + os.environ["WANDB__SERVICE_WAIT"] = "300" + wandb.init(project=__module_name__, name=args.experiment, config=dot_conf(conf)) + + if conf.train.writer and not self.use_wandb and not self.use_tensorboard: + raise NotImplementedError(f"Writer {conf.train.writer} not implemented") + + def add_scalar(self, tag: str, value: float, step: Optional[int] = None): + """Log a scalar value to tensorboard or wandb.""" + if self.use_wandb: + step = 1 if step == 0 else step + wandb.log({tag: value}, step=step) + + if self.use_tensorboard: + self.writer.add_scalar(tag, value, step) + + def add_figure(self, tag: str, figure, step: Optional[int] = None): + """Log a figure to tensorboard or wandb.""" + if self.use_wandb: + step = 1 if step == 0 else step + wandb.log({tag: figure}, step=step) + if self.use_tensorboard: + self.writer.add_figure(tag, figure, step) + + def add_histogram(self, tag: str, values, step: Optional[int] = None): + """Log a histogram to tensorboard or wandb.""" + if self.use_tensorboard: + self.writer.add_histogram(tag, values, step) + + def add_text(self, tag: str, text: str, step: Optional[int] = None): + """Log text to tensorboard or wandb.""" + if self.use_tensorboard: + self.writer.add_text(tag, text, step) + + def add_pr_curve(self, tag: str, values, step: Optional[int] = None): + """Log a precision-recall curve to tensorboard or wandb.""" + if self.use_wandb: + step = 1 if step == 0 else step + # @TODO: check if this works + # wandb.log({"pr": wandb.plots.precision_recall(y_test, y_probas, labels)}) + wandb.log({tag: wandb.plots.precision_recall(values)}, step=step) + + if self.use_tensorboard: + self.writer.add_pr_curve(tag, values, step) + + def watch(self, model: nn.Module, log_freq: int = 1000): + """Watch a model for gradient updates.""" + if self.use_wandb: + wandb.watch( + model, + log="gradients", + log_freq=log_freq, + ) + + def close(self): + """Close the writer.""" + if self.use_wandb: + wandb.finish() + + if self.use_tensorboard: + self.writer.close() diff --git a/siclib/utils/tensor.py b/siclib/utils/tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..f8dd14e9cdd746225523f01e997dd57aed3673e2 --- /dev/null +++ b/siclib/utils/tensor.py @@ -0,0 +1,251 @@ +""" +Author: Paul-Edouard Sarlin (skydes) +""" + +import collections.abc as collections +import functools +import inspect +from typing import Callable, List, Tuple + +import numpy as np +import torch + +# flake8: noqa +# mypy: ignore-errors + + +string_classes = (str, bytes) + + +def autocast(func: Callable) -> Callable: + """Cast the inputs of a TensorWrapper method to PyTorch tensors if they are numpy arrays. 
+ + Use the device and dtype of the wrapper. + + Args: + func (Callable): Method of a TensorWrapper class. + + Returns: + Callable: Wrapped method. + """ + + @functools.wraps(func) + def wrap(self, *args): + device = torch.device("cpu") + dtype = None + if isinstance(self, TensorWrapper): + if self._data is not None: + device = self.device + dtype = self.dtype + elif not inspect.isclass(self) or not issubclass(self, TensorWrapper): + raise ValueError(self) + + cast_args = [] + for arg in args: + if isinstance(arg, np.ndarray): + arg = torch.from_numpy(arg) + arg = arg.to(device=device, dtype=dtype) + cast_args.append(arg) + return func(self, *cast_args) + + return wrap + + +class TensorWrapper: + """Wrapper for PyTorch tensors.""" + + _data = None + + @autocast + def __init__(self, data: torch.Tensor): + """Wrapper for PyTorch tensors.""" + self._data = data + + @property + def shape(self) -> torch.Size: + """Shape of the underlying tensor.""" + return self._data.shape[:-1] + + @property + def device(self) -> torch.device: + """Get the device of the underlying tensor.""" + return self._data.device + + @property + def dtype(self) -> torch.dtype: + """Get the dtype of the underlying tensor.""" + return self._data.dtype + + def __getitem__(self, index) -> torch.Tensor: + """Get the underlying tensor.""" + return self.__class__(self._data[index]) + + def __setitem__(self, index, item): + """Set the underlying tensor.""" + self._data[index] = item.data + + def to(self, *args, **kwargs): + """Move the underlying tensor to a new device.""" + return self.__class__(self._data.to(*args, **kwargs)) + + def cpu(self): + """Move the underlying tensor to the CPU.""" + return self.__class__(self._data.cpu()) + + def cuda(self): + """Move the underlying tensor to the GPU.""" + return self.__class__(self._data.cuda()) + + def pin_memory(self): + """Pin the underlying tensor to memory.""" + return self.__class__(self._data.pin_memory()) + + def float(self): + """Cast the underlying tensor to float.""" + return self.__class__(self._data.float()) + + def double(self): + """Cast the underlying tensor to double.""" + return self.__class__(self._data.double()) + + def detach(self): + """Detach the underlying tensor.""" + return self.__class__(self._data.detach()) + + def numpy(self): + """Convert the underlying tensor to a numpy array.""" + return self._data.detach().cpu().numpy() + + def new_tensor(self, *args, **kwargs): + """Create a new tensor of the same type and device.""" + return self._data.new_tensor(*args, **kwargs) + + def new_zeros(self, *args, **kwargs): + """Create a new tensor of the same type and device.""" + return self._data.new_zeros(*args, **kwargs) + + def new_ones(self, *args, **kwargs): + """Create a new tensor of the same type and device.""" + return self._data.new_ones(*args, **kwargs) + + def new_full(self, *args, **kwargs): + """Create a new tensor of the same type and device.""" + return self._data.new_full(*args, **kwargs) + + def new_empty(self, *args, **kwargs): + """Create a new tensor of the same type and device.""" + return self._data.new_empty(*args, **kwargs) + + def unsqueeze(self, *args, **kwargs): + """Create a new tensor of the same type and device.""" + return self.__class__(self._data.unsqueeze(*args, **kwargs)) + + def squeeze(self, *args, **kwargs): + """Create a new tensor of the same type and device.""" + return self.__class__(self._data.squeeze(*args, **kwargs)) + + @classmethod + def stack(cls, objects: List, dim=0, *, out=None): + """Stack a list of objects with 
the same type and shape.""" + data = torch.stack([obj._data for obj in objects], dim=dim, out=out) + return cls(data) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + """Support torch functions.""" + if kwargs is None: + kwargs = {} + return cls.stack(*args, **kwargs) if func is torch.stack else NotImplemented + + +def map_tensor(input_, func): + if isinstance(input_, string_classes): + return input_ + elif isinstance(input_, collections.Mapping): + return {k: map_tensor(sample, func) for k, sample in input_.items()} + elif isinstance(input_, collections.Sequence): + return [map_tensor(sample, func) for sample in input_] + elif input_ is None: + return None + else: + return func(input_) + + +def batch_to_numpy(batch): + return map_tensor(batch, lambda tensor: tensor.cpu().numpy()) + + +def batch_to_device(batch, device, non_blocking=True, detach=False): + def _func(tensor): + t = tensor.to(device=device, non_blocking=non_blocking, dtype=torch.float32) + return t.detach() if detach else t + + return map_tensor(batch, _func) + + +def remove_batch_dim(data: dict) -> dict: + """Remove batch dimension from elements in data""" + return { + k: v[0] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v for k, v in data.items() + } + + +def add_batch_dim(data: dict) -> dict: + """Add batch dimension to elements in data""" + return { + k: v[None] if isinstance(v, (torch.Tensor, np.ndarray, list)) else v + for k, v in data.items() + } + + +def fit_to_multiple(x: torch.Tensor, multiple: int, mode: str = "center", crop: bool = False): + """Get padding to make the image size a multiple of the given number. + + Args: + x (torch.Tensor): Input tensor. + multiple (int, optional): Multiple. + crop (bool, optional): Whether to crop or pad. Defaults to False. + + Returns: + torch.Tensor: Padding. + """ + h, w = x.shape[-2:] + + if crop: + pad_w = (w // multiple) * multiple - w + pad_h = (h // multiple) * multiple - h + else: + pad_w = (multiple - w % multiple) % multiple + pad_h = (multiple - h % multiple) % multiple + + if mode == "center": + pad_l = pad_w // 2 + pad_r = pad_w - pad_l + pad_t = pad_h // 2 + pad_b = pad_h - pad_t + elif mode == "left": + pad_l = 0 + pad_r = pad_w + pad_t = 0 + pad_b = pad_h + else: + raise ValueError(f"Unknown mode {mode}") + + return (pad_l, pad_r, pad_t, pad_b) + + +def fit_features_to_multiple( + features: torch.Tensor, multiple: int = 32, crop: bool = False +) -> Tuple[torch.Tensor, Tuple[int, int]]: + """Pad image to a multiple of the given number. + + Args: + features (torch.Tensor): Input features. + multiple (int, optional): Multiple. Defaults to 32. + crop (bool, optional): Whether to crop or pad. Defaults to False. + + Returns: + Tuple[torch.Tensor, Tuple[int, int]]: Padded features and padding. + """ + pad = fit_to_multiple(features, multiple, crop=crop) + return torch.nn.functional.pad(features, pad, mode="reflect"), pad diff --git a/siclib/utils/tools.py b/siclib/utils/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..b3869ecba347264be2112af1d1bc5f80359c716b --- /dev/null +++ b/siclib/utils/tools.py @@ -0,0 +1,309 @@ +""" +Various handy Python and PyTorch utils. 
+ +Author: Paul-Edouard Sarlin (skydes) +""" + +import os +import random +import time +from collections.abc import Iterable +from contextlib import contextmanager +from typing import Optional + +import numpy as np +import torch + +# flake8: noqa +# mypy: ignore-errors + + +class AverageMetric: + def __init__(self, elements=None): + if elements is None: + elements = [] + self._sum = 0 + self._num_examples = 0 + else: + mask = ~np.isnan(elements) + self._sum = sum(elements[mask]) + self._num_examples = len(elements[mask]) + + def update(self, tensor): + assert tensor.dim() == 1, tensor.shape + tensor = tensor[~torch.isnan(tensor)] + self._sum += tensor.sum().item() + self._num_examples += len(tensor) + + def compute(self): + return np.nan if self._num_examples == 0 else self._sum / self._num_examples + + +# same as AverageMetric, but tracks all elements +class FAverageMetric: + def __init__(self): + self._sum = 0 + self._num_examples = 0 + self._elements = [] + + def update(self, tensor): + self._elements += tensor.cpu().numpy().tolist() + assert tensor.dim() == 1, tensor.shape + tensor = tensor[~torch.isnan(tensor)] + self._sum += tensor.sum().item() + self._num_examples += len(tensor) + + def compute(self): + return np.nan if self._num_examples == 0 else self._sum / self._num_examples + + +class MedianMetric: + def __init__(self, elements=None): + if elements is None: + elements = [] + + self._elements = elements + + def update(self, tensor): + assert tensor.dim() == 1, tensor.shape + self._elements += tensor.cpu().numpy().tolist() + + def compute(self): + if len(self._elements) == 0: + return np.nan + + # set nan to inf to avoid error + self._elements = np.array(self._elements) + self._elements[np.isnan(self._elements)] = np.inf + return np.nanmedian(self._elements) + + +class PRMetric: + def __init__(self): + self.labels = [] + self.predictions = [] + + @torch.no_grad() + def update(self, labels, predictions, mask=None): + assert labels.shape == predictions.shape + self.labels += (labels[mask] if mask is not None else labels).cpu().numpy().tolist() + self.predictions += ( + (predictions[mask] if mask is not None else predictions).cpu().numpy().tolist() + ) + + @torch.no_grad() + def compute(self): + return np.array(self.labels), np.array(self.predictions) + + def reset(self): + self.labels = [] + self.predictions = [] + + +class QuantileMetric: + def __init__(self, q=0.05): + self._elements = [] + self.q = q + + def update(self, tensor): + assert tensor.dim() == 1 + self._elements += tensor.cpu().numpy().tolist() + + def compute(self): + if len(self._elements) == 0: + return np.nan + else: + return np.nanquantile(self._elements, self.q) + + +class RecallMetric: + def __init__(self, ths, elements=None): + if elements is None: + elements = [] + + self._elements = elements + self.ths = ths + + def update(self, tensor): + assert tensor.dim() == 1, tensor.shape + self._elements += tensor.cpu().numpy().tolist() + + def compute(self): + # set nan to inf to avoid error + self._elements = np.array(self._elements) + self._elements[np.isnan(self._elements)] = np.inf + + if isinstance(self.ths, Iterable): + return [self.compute_(th) for th in self.ths] + else: + return self.compute_(self.ths[0]) + + def compute_(self, th): + if len(self._elements) == 0: + return np.nan + + s = (np.array(self._elements) < th).sum() + return s / len(self._elements) + + +def compute_recall(errors): + num_elements = len(errors) + sort_idx = np.argsort(errors) + errors = np.array(errors.copy())[sort_idx] + recall = 
(np.arange(num_elements) + 1) / num_elements + return errors, recall + + +def compute_auc(errors, thresholds, min_error: Optional[float] = None): + errors, recall = compute_recall(errors) + + if min_error is not None: + min_index = np.searchsorted(errors, min_error, side="right") + min_score = min_index / len(errors) + recall = np.r_[min_score, min_score, recall[min_index:]] + errors = np.r_[0, min_error, errors[min_index:]] + else: + recall = np.r_[0, recall] + errors = np.r_[0, errors] + + aucs = [] + for t in thresholds: + last_index = np.searchsorted(errors, t, side="right") + r = np.r_[recall[:last_index], recall[last_index - 1]] + e = np.r_[errors[:last_index], t] + auc = np.trapz(r, x=e) / t + aucs.append(np.round(auc, 4)) + return aucs + + +class AUCMetric: + def __init__(self, thresholds, elements=None, min_error: Optional[float] = None): + self._elements = elements + self.thresholds = thresholds + self.min_error = min_error + if not isinstance(thresholds, list): + self.thresholds = [thresholds] + + def update(self, tensor): + assert tensor.dim() == 1, tensor.shape + self._elements += tensor.cpu().numpy().tolist() + + def compute(self): + if len(self._elements) == 0: + return np.nan + + # set nan to inf to avoid error + self._elements = np.array(self._elements) + self._elements[np.isnan(self._elements)] = np.inf + return compute_auc(self._elements, self.thresholds, self.min_error) + + +class Timer(object): + """A simpler timer context object. + Usage: + ``` + > with Timer('mytimer'): + > # some computations + [mytimer] Elapsed: X + ``` + """ + + def __init__(self, name=None): + self.name = name + + def __enter__(self): + self.tstart = time.time() + return self + + def __exit__(self, type, value, traceback): + self.duration = time.time() - self.tstart + if self.name is not None: + print(f"[{self.name}] Elapsed: {self.duration}") + + +def get_class(mod_path, BaseClass): + """Get the class object which inherits from BaseClass and is defined in + the module named mod_name, child of base_path. 
+ """ + import inspect + + mod = __import__(mod_path, fromlist=[""]) + classes = inspect.getmembers(mod, inspect.isclass) + # Filter classes defined in the module + classes = [c for c in classes if c[1].__module__ == mod_path] + # Filter classes inherited from BaseModel + classes = [c for c in classes if issubclass(c[1], BaseClass)] + assert len(classes) == 1, classes + return classes[0][1] + + +def set_num_threads(nt): + """Force numpy and other libraries to use a limited number of threads.""" + try: + import mkl # type: ignore + except ImportError: + pass + else: + mkl.set_num_threads(nt) + torch.set_num_threads(1) + os.environ["IPC_ENABLE"] = "1" + for o in [ + "OPENBLAS_NUM_THREADS", + "NUMEXPR_NUM_THREADS", + "OMP_NUM_THREADS", + "MKL_NUM_THREADS", + ]: + os.environ[o] = str(nt) + + +def set_seed(seed): + random.seed(seed) + torch.manual_seed(seed) + np.random.seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_random_state(with_cuda): + pth_state = torch.get_rng_state() + np_state = np.random.get_state() + py_state = random.getstate() + if torch.cuda.is_available() and with_cuda: + cuda_state = torch.cuda.get_rng_state_all() + else: + cuda_state = None + return pth_state, np_state, py_state, cuda_state + + +def set_random_state(state): + pth_state, np_state, py_state, cuda_state = state + torch.set_rng_state(pth_state) + np.random.set_state(np_state) + random.setstate(py_state) + if ( + cuda_state is not None + and torch.cuda.is_available() + and len(cuda_state) == torch.cuda.device_count() + ): + torch.cuda.set_rng_state_all(cuda_state) + + +@contextmanager +def fork_rng(seed=None, with_cuda=True): + state = get_random_state(with_cuda) + if seed is not None: + set_seed(seed) + try: + yield + finally: + set_random_state(state) + + +def get_device() -> str: + device = "cpu" + if torch.cuda.is_available(): + device = "cuda" + elif torch.backends.mps.is_available(): + device = "mps" + return device diff --git a/siclib/visualization/__init__.py b/siclib/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/siclib/visualization/global_frame.py b/siclib/visualization/global_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..f89ce91406d442ef388a573effd4f9f423250c0b --- /dev/null +++ b/siclib/visualization/global_frame.py @@ -0,0 +1,282 @@ +import functools +import traceback +from copy import deepcopy + +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.widgets import Button +from omegaconf import OmegaConf + +from ..datasets.base_dataset import collate +from ..models.cache_loader import CacheLoader +from .tools import RadioHideTool + +# flake8: noqa +# mypy: ignore-errors + + +class GlobalFrame: + default_conf = { + "x": "???", + "y": "???", + "diff": False, + "child": {}, + "remove_outliers": False, + } + + child_frame = None # MatchFrame + + childs = [] + + lines = [] + + scatters = {} + + def __init__(self, conf, results, loader, predictions, title=None, child_frame=None): + self.child_frame = child_frame + if self.child_frame is not None: + # We do NOT merge inside the child frame to keep settings across figs + self.default_conf["child"] = self.child_frame.default_conf + + self.conf = OmegaConf.merge(self.default_conf, conf) + self.results = results + self.loader = loader + self.predictions = predictions + self.metrics = set() + for k, v in results.items(): + 
self.metrics.update(v.keys()) + self.metrics = sorted(list(self.metrics)) + + self.conf.x = conf["x"] or self.metrics[0] + self.conf.y = conf["y"] or self.metrics[1] + + assert self.conf.x in self.metrics + assert self.conf.y in self.metrics + + self.names = list(results) + self.fig, self.axes = self.init_frame() + if title is not None: + self.fig.canvas.manager.set_window_title(title) + + self.xradios = self.fig.canvas.manager.toolmanager.add_tool( + "x", + RadioHideTool, + options=self.metrics, + callback_fn=self.update_x, + active=self.conf.x, + keymap="x", + ) + + self.yradios = self.fig.canvas.manager.toolmanager.add_tool( + "y", + RadioHideTool, + options=self.metrics, + callback_fn=self.update_y, + active=self.conf.y, + keymap="y", + ) + if self.fig.canvas.manager.toolbar is not None: + self.fig.canvas.manager.toolbar.add_tool("x", "navigation") + self.fig.canvas.manager.toolbar.add_tool("y", "navigation") + + def init_frame(self): + """initialize frame""" + fig, ax = plt.subplots() + ax.set_title("click on points") + diffb_ax = fig.add_axes([0.01, 0.02, 0.12, 0.06]) + self.diffb = Button(diffb_ax, label="diff_only") + self.diffb.on_clicked(self.diff_clicked) + fig.canvas.mpl_connect("pick_event", self.on_scatter_pick) + fig.canvas.mpl_connect("motion_notify_event", self.hover) + return fig, ax + + def draw(self): + """redraw content in frame""" + self.scatters = {} + self.axes.clear() + self.axes.set_xlabel(self.conf.x) + self.axes.set_ylabel(self.conf.y) + + refx = 0.0 + refy = 0.0 + x_cat = isinstance(self.results[self.names[0]][self.conf.x][0], (bytes, str)) + y_cat = isinstance(self.results[self.names[0]][self.conf.y][0], (bytes, str)) + + if self.conf.diff: + if not x_cat: + refx = np.array(self.results[self.names[0]][self.conf.x]) + if not y_cat: + refy = np.array(self.results[self.names[0]][self.conf.y]) + for name in list(self.results.keys()): + x = np.array(self.results[name][self.conf.x]) + y = np.array(self.results[name][self.conf.y]) + + if x_cat and np.char.isdigit(x.astype(str)).all(): + x = x.astype(int) + if y_cat and np.char.isdigit(y.astype(str)).all(): + y = y.astype(int) + + x = x if x_cat else x - refx + y = y if y_cat else y - refy + + (s,) = self.axes.plot(x, y, "o", markersize=3, label=name, picker=True, pickradius=5) + self.scatters[name] = s + + if x_cat and not y_cat: + xunique, ind, xinv, xbin = np.unique( + x, return_inverse=True, return_counts=True, return_index=True + ) + ybin = np.bincount(xinv, weights=y) + sort_ax = np.argsort(ind) + self.axes.step( + xunique[sort_ax], + (ybin / xbin)[sort_ax], + where="mid", + color=s.get_color(), + ) + + if not x_cat: + xavg = np.nan_to_num(x).mean() + self.axes.axvline(xavg, c=s.get_color(), zorder=1, alpha=1.0) + xmed = np.median(x - refx) + self.axes.axvline( + xmed, + c=s.get_color(), + zorder=0, + alpha=0.5, + linestyle="dashed", + visible=False, + ) + + if not y_cat: + yavg = np.nan_to_num(y).mean() + self.axes.axhline(yavg, c=s.get_color(), zorder=1, alpha=0.5) + ymed = np.median(y - refy) + self.axes.axhline( + ymed, + c=s.get_color(), + zorder=0, + alpha=0.5, + linestyle="dashed", + visible=False, + ) + if x_cat and x.dtype == object and xunique.shape[0] > 5: + self.axes.set_xticklabels(xunique[sort_ax], rotation=90) + self.axes.legend() + + def on_scatter_pick(self, handle): + try: + art = handle.artist + try: + event = handle.mouseevent.button.value + except AttributeError: + return + name = art.get_label() + ind = handle.ind[0] + # draw lines + self.spawn_child(name, ind, event=event) + except 
Exception: + traceback.print_exc() + exit(0) + + def spawn_child(self, model_name, ind, event=None): + [line.remove() for line in self.lines] + self.lines = [] + + x_source = self.scatters[model_name].get_xdata()[ind] + y_source = self.scatters[model_name].get_ydata()[ind] + for oname in self.names: + xn = self.scatters[oname].get_xdata()[ind] + yn = self.scatters[oname].get_ydata()[ind] + + (ln,) = self.axes.plot([x_source, xn], [y_source, yn], "r") + self.lines.append(ln) + + self.fig.canvas.draw_idle() + + if self.child_frame is None: + return + + data = collate([self.loader.dataset[ind]]) + + preds = { + name: CacheLoader({"path": str(pfile), "add_data_path": False})(data) + for name, pfile in self.predictions.items() + } + summaries_i = { + name: {k: v[ind] for k, v in res.items() if k != "names"} + for name, res in self.results.items() + } + frame = self.child_frame( + self.conf.child, + deepcopy(data), + preds, + title=str(data["name"][0]), + event=event, + summaries=summaries_i, + ) + + frame.fig.canvas.mpl_connect( + "key_press_event", + functools.partial(self.on_childframe_key_event, frame=frame, ind=ind, event=event), + ) + self.childs.append(frame) + self.childs[-1].fig.show() + + def hover(self, event): + if event.inaxes != self.axes: + return + + for _, s in self.scatters.items(): + cont, ind = s.contains(event) + if cont: + ind = ind["ind"][0] + xdata, ydata = s.get_data() + [line.remove() for line in self.lines] + self.lines = [] + + for oname in self.names: + xn = self.scatters[oname].get_xdata()[ind] + yn = self.scatters[oname].get_ydata()[ind] + + (ln,) = self.axes.plot( + [xdata[ind], xn], + [ydata[ind], yn], + "black", + zorder=0, + alpha=0.5, + ) + self.lines.append(ln) + self.fig.canvas.draw_idle() + break + + def diff_clicked(self, args): + self.conf.diff = not self.conf.diff + self.draw() + self.fig.canvas.draw_idle() + + def update_x(self, x): + self.conf.x = x + self.draw() + + def update_y(self, y): + self.conf.y = y + self.draw() + + def on_childframe_key_event(self, key_event, frame, ind, event): + if key_event.key == "delete": + plt.close(frame.fig) + self.childs.remove(frame) + elif key_event.key in ["left", "right", "shift+left", "shift+right"]: + key = key_event.key + if key.startswith("shift+"): + key = key.replace("shift+", "") + else: + plt.close(frame.fig) + self.childs.remove(frame) + new_ind = ind + 1 if key_event.key == "right" else ind - 1 + self.spawn_child( + self.names[0], + new_ind % len(self.loader), + event=event, + ) diff --git a/siclib/visualization/tools.py b/siclib/visualization/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..3f28b9cb47140fb63a56b10baaec9fc7e02438ef --- /dev/null +++ b/siclib/visualization/tools.py @@ -0,0 +1,472 @@ +import inspect +import sys +import warnings + +import matplotlib.pyplot as plt +import numpy as np +import torch +from matplotlib.backend_tools import ToolToggleBase +from matplotlib.widgets import Button, RadioButtons + +from siclib.geometry.camera import SimpleRadial as Camera +from siclib.geometry.gravity import Gravity +from siclib.geometry.perspective_fields import ( + get_latitude_field, + get_perspective_field, + get_up_field, +) +from siclib.models.utils.metrics import latitude_error, up_error +from siclib.utils.conversions import rad2deg +from siclib.visualization.viz2d import ( + add_text, + plot_confidences, + plot_heatmaps, + plot_horizon_lines, + plot_latitudes, + plot_vector_fields, +) + +# flake8: noqa +# mypy: ignore-errors + +with warnings.catch_warnings(): + 
warnings.simplefilter("ignore") + plt.rcParams["toolbar"] = "toolmanager" + + +class RadioHideTool(ToolToggleBase): + """Show lines with a given gid.""" + + default_keymap = "R" + description = "Show by gid" + default_toggled = False + radio_group = "default" + + def __init__(self, *args, options=[], active=None, callback_fn=None, keymap="R", **kwargs): + super().__init__(*args, **kwargs) + self.f = 1.0 + self.options = options + self.callback_fn = callback_fn + self.active = self.options.index(active) if active else 0 + self.default_keymap = keymap + + self.enabled = self.default_toggled + + def build_radios(self): + w = 0.2 + self.radios_ax = self.figure.add_axes([1.0 - w, 0.4, w, 0.5], zorder=1) + # self.radios_ax = self.figure.add_axes([0.5-w/2, 1.0-0.2, w, 0.2], zorder=1) + self.radios = RadioButtons(self.radios_ax, self.options, active=self.active) + self.radios.on_clicked(self.on_radio_clicked) + + def enable(self, *args): + size = self.figure.get_size_inches() + size[0] *= self.f + self.build_radios() + self.figure.canvas.draw_idle() + self.enabled = True + + def disable(self, *args): + size = self.figure.get_size_inches() + size[0] /= self.f + self.radios_ax.remove() + self.radios = None + self.figure.canvas.draw_idle() + self.enabled = False + + def on_radio_clicked(self, value): + self.active = self.options.index(value) + enabled = self.enabled + if enabled: + self.disable() + if self.callback_fn is not None: + self.callback_fn(value) + if enabled: + self.enable() + + +class ToggleTool(ToolToggleBase): + """Show lines with a given gid.""" + + default_keymap = "t" + description = "Show by gid" + + def __init__(self, *args, callback_fn=None, keymap="t", **kwargs): + super().__init__(*args, **kwargs) + self.f = 1.0 + self.callback_fn = callback_fn + self.default_keymap = keymap + self.enabled = self.default_toggled + + def enable(self, *args): + self.callback_fn(True) + + def disable(self, *args): + self.callback_fn(False) + + +def add_whitespace_left(fig, factor): + w, h = fig.get_size_inches() + left = fig.subplotpars.left + fig.set_size_inches([w * (1 + factor), h]) + fig.subplots_adjust(left=(factor + left) / (1 + factor)) + + +def add_whitespace_bottom(fig, factor): + w, h = fig.get_size_inches() + b = fig.subplotpars.bottom + fig.set_size_inches([w, h * (1 + factor)]) + fig.subplots_adjust(bottom=(factor + b) / (1 + factor)) + fig.canvas.draw_idle() + + +class ImagePlot: + plot_name = "image" + required_keys = ["image"] + + def __init__(self, fig, axes, data, preds): + pass + + +class HorizonLinePlot: + plot_name = "horizon_line" + required_keys = ["camera", "gravity"] + + def __init__(self, fig, axes, data, preds): + for idx, name in enumerate(preds): + pred = preds[name] + gt_cam = data["camera"][0].detach().cpu() + gt_gravity = data["gravity"][0].detach().cpu() + plot_horizon_lines([gt_cam], [gt_gravity], line_colors="r", ax=[axes[0][idx]]) + + if "camera" in pred and "gravity" in pred: + pred_cam = Camera(pred["camera"][0].detach().cpu()) + gravity = Gravity(pred["gravity"][0].detach().cpu()) + plot_horizon_lines([pred_cam], [gravity], line_colors="yellow", ax=[axes[0][idx]]) + + +class LatitudePlot: + plot_name = "latitude" + required_keys = ["latitude_field"] + + def __init__(self, fig, axes, data, preds): + self.artists = [] + self.gt_mode = False # Flag to track whether to display ground truth or predicted + self.text_objects = [] # To store text objects + + self.fig = fig + self.axes = axes + self.data = data + self.preds = preds + + # Create a toggle button on the 
lower left corner of the first axis + self.ax_button = self.fig.add_axes([0.01, 0.02, 0.2, 0.06]) + self.button = Button(self.ax_button, "Toggle GT") + self.button.on_clicked(self.toggle_display) + + self.update_plot() + + def toggle_display(self, event): + # Toggle between ground truth and predicted latitudes + self.gt_mode = not self.gt_mode + self.update_plot() + + def update_plot(self): + for x in self.artists: + x.remove() + for text in self.text_objects: + text.remove() + + self.artists = [] + self.text_objects = [] + + for idx, name in enumerate(self.preds): + pred = self.preds[name] + + if self.gt_mode: + latitude = self.data["latitude_field"][0][0] + text = "\nGT" + else: + if "latitude_field" not in pred: + continue + latitude = pred["latitude_field"][0][0] + text = "\nPrediction" + + self.artists += plot_latitudes([latitude], axes=[self.axes[0][idx]]) + + self.text_objects.append(add_text(idx, text)) + + # Update the plot + self.fig.canvas.draw() + + def clear(self): + # Remove the button + self.button.disconnect_events() + self.ax_button.remove() + + for x in self.artists: + x.remove() + for text in self.text_objects: + text.remove() + + self.artists = [] + self.text_objects = [] + + +class LatitudeErrorPlot: + plot_name = "latitude_error" + required_keys = ["latitude_field"] + + def __init__(self, fig, axes, data, preds): + self.artists = [] + for idx, name in enumerate(preds): + pred = preds[name] + gt = data["latitude_field"].detach().cpu() + + if "latitude_field" in pred: + lat = pred["latitude_field"].detach().cpu() + error = latitude_error(lat, gt)[0].numpy() + + if "latitude_confidence" in pred: + confidence = pred["latitude_confidence"].detach().cpu().numpy() + confidence = np.log10(confidence).clip(-5) + confidence = (confidence + 5) / (confidence.max() + 5) + arts = plot_heatmaps( + [error], cmap="turbo", axes=[axes[0][idx]], colorbar=True, a=confidence + ) + else: + arts = plot_heatmaps([error], cmap="turbo", axes=[axes[0][idx]], colorbar=True) + self.artists += arts + + def clear(self): + for x in self.artists: + x.remove() + x.colorbar.remove() + + self.artists = [] + + +class LatitudeConfidencePlot: + plot_name = "latitude_confidence" + required_keys = [] + # required_keys = ["latitude_confidence"] + + def __init__(self, fig, axes, data, preds): + self.artists = [] + for idx, name in enumerate(preds): + pred = preds[name] + + if "latitude_confidence" in pred: + arts = plot_confidences([pred["latitude_confidence"][0]], axes=[axes[0][idx]]) + self.artists += arts + + def clear(self): + for x in self.artists: + x.remove() + x.colorbar.remove() + + self.artists = [] + + +class UpPlot: + plot_name = "up" + required_keys = ["up_field"] + + def __init__(self, fig, axes, data, preds): + self.artists = [] + self.gt_mode = False # Flag to track whether to display ground truth or predicted + self.text_objects = [] # To store text objects + + self.fig = fig + self.axes = axes + self.data = data + self.preds = preds + + # Create a toggle button on the lower left corner of the first axis + self.ax_button = self.fig.add_axes([0.01, 0.02, 0.2, 0.06]) + self.button = Button(self.ax_button, "Toggle GT") + self.button.on_clicked(self.toggle_display) + + self.update_plot() + + def toggle_display(self, event): + # Toggle between ground truth and predicted latitudes + self.gt_mode = not self.gt_mode + self.update_plot() + + def update_plot(self): + for x in self.artists: + x.remove() + for text in self.text_objects: + text.remove() + + self.artists = [] + self.text_objects = [] + + for 
idx, name in enumerate(self.preds): + pred = self.preds[name] + + if self.gt_mode: + up = self.data["up_field"][0] + text = "\nGT" + else: + if "up_field" not in pred: + continue + up = pred["up_field"][0] + text = "\nPrediction" + + # Plot up + self.artists += plot_vector_fields([up], axes=[self.axes[0][idx]]) + + self.text_objects.append(add_text(idx, text)) + + # Update the plot + self.fig.canvas.draw() + + def clear(self): + # Remove the button + self.button.disconnect_events() + self.ax_button.remove() + + for x in self.artists: + x.remove() + for text in self.text_objects: + text.remove() + + self.artists = [] + self.text_objects = [] + + +class UpErrorPlot: + plot_name = "up_error" + required_keys = ["up_field"] + + def __init__(self, fig, axes, data, preds): + self.artists = [] + for idx, name in enumerate(preds): + pred = preds[name] + gt = data["up_field"].detach().cpu() + + if "up_field" in pred: + up = pred["up_field"].detach().cpu() + error = up_error(up, gt)[0].numpy() + + if "up_confidence" in pred: + confidence = pred["up_confidence"].detach().cpu().numpy() + confidence = np.log10(confidence).clip(-5) + confidence = (confidence + 5) / (confidence.max() + 5) + arts = plot_heatmaps( + [error], cmap="turbo", axes=[axes[0][idx]], colorbar=True, a=confidence + ) + else: + arts = plot_heatmaps([error], cmap="turbo", axes=[axes[0][idx]], colorbar=True) + self.artists += arts + + def clear(self): + for x in self.artists: + x.remove() + x.colorbar.remove() + + self.artists = [] + + +class UpConfidencePlot: + plot_name = "up_confidence" + required_keys = [] + # required_keys = ["up_confidence"] + + def __init__(self, fig, axes, data, preds): + self.artists = [] + for idx, name in enumerate(preds): + pred = preds[name] + + if "up_confidence" in pred: + arts = plot_confidences([pred["up_confidence"][0]], axes=[axes[0][idx]]) + self.artists += arts + + def clear(self): + for x in self.artists: + x.remove() + x.colorbar.remove() + + self.artists = [] + + +class PerspectiveField: + plot_name = "perspective_field" + required_keys = ["camera", "gravity"] + + def __init__(self, fig, axes, data, preds): + self.artists = [] + self.gt_mode = False # Flag to track whether to display ground truth or predicted + self.text_objects = [] # To store text objects + + self.fig = fig + self.axes = axes + self.data = data + self.preds = preds + + # Create a toggle button on the lower left corner of the first axis + self.ax_button = self.fig.add_axes([0.01, 0.02, 0.2, 0.06]) + self.button = Button(self.ax_button, "Toggle GT") + self.button.on_clicked(self.toggle_display) + + self.update_plot() + + def toggle_display(self, event): + # Toggle between ground truth and predicted latitudes + self.gt_mode = not self.gt_mode + self.update_plot() + + def update_plot(self): + for x in self.artists: + x.remove() + for text in self.text_objects: + text.remove() + + self.artists = [] + self.text_objects = [] + + for idx, name in enumerate(self.preds): + pred = self.preds[name] + + if self.gt_mode: + camera = self.data["camera"] + gravity = self.data["gravity"] + text = "\nGT" + else: + camera = pred["camera"] + gravity = pred["gravity"] + text = "\nPrediction" + camera = Camera(camera) + gravity = Gravity(gravity) + + up, latitude = get_perspective_field(camera, gravity) + + self.artists += plot_latitudes([latitude[0, 0]], axes=[self.axes[0][idx]]) + self.artists += plot_vector_fields([up[0]], axes=[self.axes[0][idx]]) + + self.text_objects.append(add_text(idx, text)) + + # Update the plot + self.fig.canvas.draw() + + 
def clear(self): + # Remove the button + self.button.disconnect_events() + self.ax_button.remove() + + for x in self.artists: + x.remove() + for text in self.text_objects: + text.remove() + + self.artists = [] + self.text_objects = [] + + +__plot_dict__ = { + obj.plot_name: obj + for _, obj in inspect.getmembers(sys.modules[__name__], predicate=inspect.isclass) + if hasattr(obj, "plot_name") +} diff --git a/siclib/visualization/two_view_frame.py b/siclib/visualization/two_view_frame.py new file mode 100644 index 0000000000000000000000000000000000000000..bff49861d536d0455f9bbf44b318ec37e18987e7 --- /dev/null +++ b/siclib/visualization/two_view_frame.py @@ -0,0 +1,139 @@ +import pprint + +import numpy as np + +from . import viz2d +from .tools import RadioHideTool, ToggleTool, __plot_dict__ + +# flake8: noqa +# mypy: ignore-errors + + +class FormatPrinter(pprint.PrettyPrinter): + def __init__(self, formats): + super(FormatPrinter, self).__init__() + self.formats = formats + + def format(self, obj, ctx, maxlvl, lvl): + if type(obj) in self.formats: + return self.formats[type(obj)] % obj, 1, 0 + return pprint.PrettyPrinter.format(self, obj, ctx, maxlvl, lvl) + + +class TwoViewFrame: + default_conf = { + "default": "image", + "summary_visible": False, + } + + plot_dict = __plot_dict__ + + childs = [] + + event_to_image = [None, "image", "horizon_line", "lat_pred", "lat_gt"] + + def __init__(self, conf, data, preds, title=None, event=1, summaries=None): + self.conf = conf + self.data = data + self.preds = preds + self.names = list(preds.keys()) + self.plot = self.event_to_image[event] + self.summaries = summaries + self.fig, self.axes, self.summary_arts = self.init_frame() + if title is not None: + self.fig.canvas.manager.set_window_title(title) + + keys = None + for _, pred in preds.items(): + keys = set(pred.keys()) if keys is None else keys.intersection(pred.keys()) + + keys = keys.union(data.keys()) + + self.options = [k for k, v in self.plot_dict.items() if set(v.required_keys).issubset(keys)] + self.handle = None + self.radios = self.fig.canvas.manager.toolmanager.add_tool( + "switch plot", + RadioHideTool, + options=self.options, + callback_fn=self.draw, + active=conf.default, + keymap="R", + ) + + self.toggle_summary = self.fig.canvas.manager.toolmanager.add_tool( + "toggle summary", + ToggleTool, + toggled=self.conf.summary_visible, + callback_fn=self.set_summary_visible, + keymap="t", + ) + + if self.fig.canvas.manager.toolbar is not None: + self.fig.canvas.manager.toolbar.add_tool("switch plot", "navigation") + self.draw(conf.default) + + def init_frame(self): + """initialize frame""" + imgs = [[self.data["image"][0].permute(1, 2, 0) for _ in self.names]] + # imgs = [imgs for _ in self.names] # repeat for each model + + fig, axes = viz2d.plot_image_grid(imgs, return_fig=True, titles=None, figs=5) + [viz2d.add_text(i, n, axes=axes[0]) for i, n in enumerate(self.names)] + + fig.canvas.mpl_connect("pick_event", self.click_artist) + if self.summaries is not None: + font_size = 7 + formatter = FormatPrinter({np.float32: "%.4f", np.float64: "%.4f"}) + toggle_artists = [ + viz2d.add_text( + i, + formatter.pformat(self.summaries[n]), + axes=axes[0], + pos=(0.01, 0.01), + va="bottom", + backgroundcolor=(0, 0, 0, 0.5), + visible=self.conf.summary_visible, + fs=font_size, + ) + for i, n in enumerate(self.names) + ] + else: + toggle_artists = [] + return fig, axes, toggle_artists + + def draw(self, value): + """redraw content in frame""" + self.clear() + self.conf.default = value + self.handle = 
self.plot_dict[value](self.fig, self.axes, self.data, self.preds) + return self.handle + + def clear(self): + if self.handle is not None: + try: + self.handle.clear() + except AttributeError: + pass + self.handle = None + for row in self.axes: + for ax in row: + [li.remove() for li in ax.lines] + [c.remove() for c in ax.collections] + self.fig.artists.clear() + self.fig.canvas.draw_idle() + self.handle = None + + def click_artist(self, event): + art = event.artist + select = art.get_arrowstyle().arrow == "-" + art.set_arrowstyle("<|-|>" if select else "-") + if select: + art.set_zorder(1) + if hasattr(self.handle, "click_artist"): + self.handle.click_artist(event) + self.fig.canvas.draw_idle() + + def set_summary_visible(self, visible): + self.conf.summary_visible = visible + [s.set_visible(visible) for s in self.summary_arts] + self.fig.canvas.draw_idle() diff --git a/siclib/visualization/visualize_batch.py b/siclib/visualization/visualize_batch.py new file mode 100644 index 0000000000000000000000000000000000000000..47d59d2bde94d4af4821165351714f1d2ab6be21 --- /dev/null +++ b/siclib/visualization/visualize_batch.py @@ -0,0 +1,191 @@ +"""Visualization of predicted and ground truth for a single batch.""" + +from typing import Any, Dict + +import numpy as np +import torch + +from siclib.geometry.perspective_fields import get_latitude_field +from siclib.models.utils.metrics import latitude_error, up_error +from siclib.utils.conversions import rad2deg +from siclib.utils.tensor import batch_to_device +from siclib.visualization.viz2d import ( + plot_confidences, + plot_heatmaps, + plot_image_grid, + plot_latitudes, + plot_vector_fields, +) + + +def make_up_figure( + pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2 +) -> Dict[str, Any]: + """Get predicted and ground truth up fields and errors. + + Args: + pred (Dict[str, torch.Tensor]): Predicted up field. + data (Dict[str, torch.Tensor]): Ground truth up field. + n_pairs (int): Number of pairs to visualize. + + Returns: + Dict[str, Any]: Dictionary with figure. + """ + pred = batch_to_device(pred, "cpu", detach=True) + data = batch_to_device(data, "cpu", detach=True) + + n_pairs = min(n_pairs, len(data["image"])) + + if "up_field" not in pred.keys(): + return {} + + errors = up_error(pred["up_field"], data["up_field"]) + + up_fields = [] + for i in range(n_pairs): + row = [data["up_field"][i], pred["up_field"][i], errors[i]] + titles = ["Up GT", "Up Pred", "Up Error"] + + if "up_confidence" in pred.keys(): + row += [pred["up_confidence"][i]] + titles += ["Up Confidence"] + + row = [r.float().numpy() if isinstance(r, torch.Tensor) else r for r in row] + up_fields.append(row) + + # create figure + N, M = len(up_fields), len(up_fields[0]) + 1 + imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)] + fig, ax = plot_image_grid(imgs, titles=[["Image"] + titles] * N, return_fig=True, set_lim=True) + ax = np.array(ax) + + for i in range(n_pairs): + plot_vector_fields(up_fields[i][:2], axes=ax[i, [1, 2]]) + plot_heatmaps([up_fields[i][2]], cmap="turbo", colorbar=True, axes=ax[i, [3]]) + + if "up_confidence" in pred.keys(): + plot_confidences([up_fields[i][3]], axes=ax[i, [4]]) + + return {"up": fig} + + +def make_latitude_figure( + pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2 +) -> Dict[str, Any]: + """Get predicted and ground truth latitude fields and errors. + + Args: + pred (Dict[str, torch.Tensor]): Predicted latitude field. 
+ data (Dict[str, torch.Tensor]): Ground truth latitude field. + n_pairs (int, optional): Number of pairs to visualize. Defaults to 2. + + Returns: + Dict[str, Any]: Dictionary with figure. + """ + pred = batch_to_device(pred, "cpu", detach=True) + data = batch_to_device(data, "cpu", detach=True) + + n_pairs = min(n_pairs, len(data["image"])) + latitude_fields = [] + + if "latitude_field" not in pred.keys(): + return {} + + errors = latitude_error(pred["latitude_field"], data["latitude_field"]) + for i in range(n_pairs): + row = [ + rad2deg(data["latitude_field"][i][0]), + rad2deg(pred["latitude_field"][i][0]), + errors[i], + ] + titles = ["Latitude GT", "Latitude Pred", "Latitude Error"] + + if "latitude_confidence" in pred.keys(): + row += [pred["latitude_confidence"][i]] + titles += ["Latitude Confidence"] + + row = [r.float().numpy() if isinstance(r, torch.Tensor) else r for r in row] + latitude_fields.append(row) + + # create figure + N, M = len(latitude_fields), len(latitude_fields[0]) + 1 + imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)] + fig, ax = plot_image_grid(imgs, titles=[["Image"] + titles] * N, return_fig=True, set_lim=True) + ax = np.array(ax) + + for i in range(n_pairs): + plot_latitudes(latitude_fields[i][:2], is_radians=False, axes=ax[i, [1, 2]]) + plot_heatmaps([latitude_fields[i][2]], cmap="turbo", colorbar=True, axes=ax[i, [3]]) + + if "latitude_confidence" in pred.keys(): + plot_confidences([latitude_fields[i][3]], axes=ax[i, [4]]) + + return {"latitude": fig} + + +def make_camera_figure( + pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2 +) -> Dict[str, Any]: + """Get predicted and ground truth camera parameters. + + Args: + pred (Dict[str, torch.Tensor]): Predicted camera parameters. + data (Dict[str, torch.Tensor]): Ground truth camera parameters. + n_pairs (int, optional): Number of pairs to visualize. Defaults to 2. + + Returns: + Dict[str, Any]: Dictionary with figure. + """ + pred = batch_to_device(pred, "cpu", detach=True) + data = batch_to_device(data, "cpu", detach=True) + + n_pairs = min(n_pairs, len(data["image"])) + + if "camera" not in pred.keys(): + return {} + + latitudes = [] + for i in range(n_pairs): + titles = ["Cameras GT"] + row = [get_latitude_field(data["camera"][i], data["gravity"][i])] + + if "camera" in pred.keys() and "gravity" in pred.keys(): + row += [get_latitude_field(pred["camera"][i], pred["gravity"][i])] + titles += ["Cameras Pred"] + + row = [rad2deg(r).squeeze(-1).float().numpy()[0] for r in row] + latitudes.append(row) + + # create figure + N, M = len(latitudes), len(latitudes[0]) + 1 + imgs = [[data["image"][i].permute(1, 2, 0).cpu().clip(0, 1)] * M for i in range(n_pairs)] + fig, ax = plot_image_grid(imgs, titles=[["Image"] + titles] * N, return_fig=True, set_lim=True) + ax = np.array(ax) + + for i in range(n_pairs): + plot_latitudes(latitudes[i], is_radians=False, axes=ax[i, 1:]) + + return {"camera": fig} + + +def make_perspective_figures( + pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor], n_pairs: int = 2 +) -> Dict[str, Any]: + """Get predicted and ground truth perspective fields. + + Args: + pred (Dict[str, torch.Tensor]): Predicted perspective fields. + data (Dict[str, torch.Tensor]): Ground truth perspective fields. + n_pairs (int, optional): Number of pairs to visualize. Defaults to 2. + + Returns: + Dict[str, Any]: Dictionary with figure. 
+ """ + n_pairs = min(n_pairs, len(data["image"])) + figures = make_up_figure(pred, data, n_pairs) + figures |= make_latitude_figure(pred, data, n_pairs) + figures |= make_camera_figure(pred, data, n_pairs) + + {f.tight_layout() for f in figures.values()} + + return figures diff --git a/siclib/visualization/viz2d.py b/siclib/visualization/viz2d.py new file mode 100644 index 0000000000000000000000000000000000000000..5a7d1773a7416e59d91c6be9bc0adbf58e29093b --- /dev/null +++ b/siclib/visualization/viz2d.py @@ -0,0 +1,520 @@ +""" +2D visualization primitives based on Matplotlib. +1) Plot images with `plot_images`. +2) Call TODO: add functions +3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`. +""" + +import matplotlib.patheffects as path_effects +import matplotlib.pyplot as plt +import numpy as np +import torch + +from siclib.geometry.perspective_fields import get_perspective_field +from siclib.utils.conversions import rad2deg + +# flake8: noqa +# mypy: ignore-errors + + +def cm_ranking(sc, ths=None): + if ths is None: + ths = [512, 1024, 2048, 4096] + + ls = sc.shape[0] + colors = ["red", "yellow", "lime", "cyan", "blue"] + out = ["gray"] * ls + for i in range(ls): + for c, th in zip(colors[: len(ths) + 1], ths + [ls]): + if i < th: + out[i] = c + break + sid = np.argsort(sc, axis=0).flip(0) + return np.array(out)[sid] + + +def cm_RdBl(x): + """Custom colormap: red (0) -> yellow (0.5) -> green (1).""" + x = np.clip(x, 0, 1)[..., None] * 2 + c = x * np.array([[0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0]]) + return np.clip(c, 0, 1) + + +def cm_RdGn(x): + """Custom colormap: red (0) -> yellow (0.5) -> green (1).""" + x = np.clip(x, 0, 1)[..., None] * 2 + c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]]) + return np.clip(c, 0, 1) + + +def cm_BlRdGn(x_): + """Custom colormap: blue (-1) -> red (0.0) -> green (1).""" + x = np.clip(x_, 0, 1)[..., None] * 2 + c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]]) + + xn = -np.clip(x_, -1, 0)[..., None] * 2 + cn = xn * np.array([[0, 1.0, 0, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]]) + return np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1) + + +def plot_images(imgs, titles=None, cmaps="gray", dpi=200, pad=0.5, adaptive=True): + """Plot a list of images. + + Args: + imgs (List[np.ndarray]): List of images to plot. + titles (List[str], optional): Titles. Defaults to None. + cmaps (str, optional): Colormaps. Defaults to "gray". + dpi (int, optional): Dots per inch. Defaults to 200. + pad (float, optional): Padding. Defaults to 0.5. + adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True. + + Returns: + plt.Figure: Figure of the images. + """ + n = len(imgs) + if not isinstance(cmaps, (list, tuple)): + cmaps = [cmaps] * n + + ratios = [i.shape[1] / i.shape[0] for i in imgs] if adaptive else [4 / 3] * n + figsize = [sum(ratios) * 4.5, 4.5] + fig, axs = plt.subplots(1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}) + if n == 1: + axs = [axs] + for i, (img, ax) in enumerate(zip(imgs, axs)): + ax.imshow(img, cmap=plt.get_cmap(cmaps[i])) + ax.set_axis_off() + if titles: + ax.set_title(titles[i]) + fig.tight_layout(pad=pad) + + return fig + + +def plot_image_grid( + imgs, + titles=None, + cmaps="gray", + dpi=100, + pad=0.5, + fig=None, + adaptive=True, + figs=3.0, + return_fig=False, + set_lim=False, +) -> plt.Figure: + """Plot a grid of images. + + Args: + imgs (List[np.ndarray]): List of images to plot. 
+ titles (List[str], optional): Titles. Defaults to None. + cmaps (str, optional): Colormaps. Defaults to "gray". + dpi (int, optional): Dots per inch. Defaults to 100. + pad (float, optional): Padding. Defaults to 0.5. + fig (_type_, optional): Figure to plot on. Defaults to None. + adaptive (bool, optional): Whether to adapt the aspect ratio. Defaults to True. + figs (float, optional): Figure size. Defaults to 3.0. + return_fig (bool, optional): Whether to return the figure. Defaults to False. + set_lim (bool, optional): Whether to set the limits. Defaults to False. + + Returns: + plt.Figure: Figure and axes or just axes. + """ + nr, n = len(imgs), len(imgs[0]) + if not isinstance(cmaps, (list, tuple)): + cmaps = [cmaps] * n + + if adaptive: + ratios = [i.shape[1] / i.shape[0] for i in imgs[0]] # W / H + else: + ratios = [4 / 3] * n + + figsize = [sum(ratios) * figs, nr * figs] + if fig is None: + fig, axs = plt.subplots( + nr, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios} + ) + else: + axs = fig.subplots(nr, n, gridspec_kw={"width_ratios": ratios}) + fig.figure.set_size_inches(figsize) + + if nr == 1 and n == 1: + axs = [[axs]] + elif n == 1: + axs = axs[:, None] + elif nr == 1: + axs = [axs] + + for j in range(nr): + for i in range(n): + ax = axs[j][i] + ax.imshow(imgs[j][i], cmap=plt.get_cmap(cmaps[i])) + ax.set_axis_off() + if set_lim: + ax.set_xlim([0, imgs[j][i].shape[1]]) + ax.set_ylim([imgs[j][i].shape[0], 0]) + if titles: + ax.set_title(titles[j][i]) + if isinstance(fig, plt.Figure): + fig.tight_layout(pad=pad) + return (fig, axs) if return_fig else axs + + +def add_text( + idx, + text, + pos=(0.01, 0.99), + fs=15, + color="w", + lcolor="k", + lwidth=4, + ha="left", + va="top", + axes=None, + **kwargs, +): + """Add text to a plot. + + Args: + idx (int): Index of the axes. + text (str): Text to add. + pos (tuple, optional): Text position. Defaults to (0.01, 0.99). + fs (int, optional): Font size. Defaults to 15. + color (str, optional): Text color. Defaults to "w". + lcolor (str, optional): Line color. Defaults to "k". + lwidth (int, optional): Line width. Defaults to 4. + ha (str, optional): Horizontal alignment. Defaults to "left". + va (str, optional): Vertical alignment. Defaults to "top". + axes (List[plt.Axes], optional): Axes to put text on. Defaults to None. + + Returns: + plt.Text: Text object. + """ + if axes is None: + axes = plt.gcf().axes + + ax = axes[idx] + + t = ax.text( + *pos, + text, + fontsize=fs, + ha=ha, + va=va, + color=color, + transform=ax.transAxes, + zorder=5, + **kwargs, + ) + if lcolor is not None: + t.set_path_effects( + [ + path_effects.Stroke(linewidth=lwidth, foreground=lcolor), + path_effects.Normal(), + ] + ) + return t + + +def plot_heatmaps( + heatmaps, + vmin=-1e-6, # include negative zero + vmax=None, + cmap="Spectral", + a=0.5, + axes=None, + contours_every=None, + contour_style="solid", + colorbar=False, +): + """Plot heatmaps with optional contours. + + To plot latitude field, set vmin=-90, vmax=90 and contours_every=15. + + Args: + heatmaps (List[np.ndarray | torch.Tensor]): List of 2D heatmaps. + vmin (float, optional): Min Value. Defaults to -1e-6. + vmax (float, optional): Max Value. Defaults to None. + cmap (str, optional): Colormap. Defaults to "Spectral". + a (float, optional): Alpha value. Defaults to 0.5. + axes (List[plt.Axes], optional): Axes to plot on. Defaults to None. + contours_every (int, optional): If not none, will draw contours. Defaults to None. 
+ contour_style (str, optional): Style of the contours. Defaults to "solid". + colorbar (bool, optional): Whether to show colorbar. Defaults to False. + + Returns: + List[plt.Artist]: List of artists. + """ + if axes is None: + axes = plt.gcf().axes + artists = [] + + for i in range(len(axes)): + a_ = a if isinstance(a, float) else a[i] + + if isinstance(heatmaps[i], torch.Tensor): + heatmaps[i] = heatmaps[i].detach().cpu().numpy() + + alpha = a_ + # Plot the heatmap + art = axes[i].imshow( + heatmaps[i], + alpha=alpha, + vmin=vmin, + vmax=vmax, + cmap=cmap, + ) + if colorbar: + cmax = vmax or np.percentile(heatmaps[i], 99) + art.set_clim(vmin, cmax) + cbar = plt.colorbar(art, ax=axes[i]) + artists.append(cbar) + + artists.append(art) + + if contours_every is not None: + # Add contour lines to the heatmap + contour_data = np.arange(vmin, vmax + contours_every, contours_every) + + # Get the colormap colors for contour lines + contour_colors = [ + plt.colormaps.get_cmap(cmap)(plt.Normalize(vmin=vmin, vmax=vmax)(level)) + for level in contour_data + ] + contours = axes[i].contour( + heatmaps[i], + levels=contour_data, + linewidths=2, + colors=contour_colors, + linestyles=contour_style, + ) + + contours.set_clim(vmin, vmax) + + fmt = { + level: f"{label}°" + for level, label in zip(contour_data, contour_data.astype(int).astype(str)) + } + t = axes[i].clabel(contours, inline=True, fmt=fmt, fontsize=16, colors="white") + + for label in t: + label.set_path_effects( + [ + path_effects.Stroke(linewidth=1, foreground="k"), + path_effects.Normal(), + ] + ) + artists.append(contours) + + return artists + + +def plot_horizon_lines( + cameras, gravities, line_colors="orange", lw=2, styles="solid", alpha=1.0, ax=None +): + """Plot horizon lines on the perspective field. + + Args: + cameras (List[Camera]): List of cameras. + gravities (List[Gravity]): Gravities. + line_colors (str, optional): Line Colors. Defaults to "orange". + lw (int, optional): Line width. Defaults to 2. + styles (str, optional): Line styles. Defaults to "solid". + alpha (float, optional): Alphas. Defaults to 1.0. + ax (List[plt.Axes], optional): Axes to draw horizon line on. Defaults to None. + """ + if not isinstance(line_colors, list): + line_colors = [line_colors] * len(cameras) + + if not isinstance(styles, list): + styles = [styles] * len(cameras) + + fig = plt.gcf() + ax = fig.gca() if ax is None else ax + + if isinstance(ax, plt.Axes): + ax = [ax] * len(cameras) + + assert len(ax) == len(cameras), f"{len(ax)}, {len(cameras)}" + + for i in range(len(cameras)): + _, lat = get_perspective_field(cameras[i], gravities[i]) + # horizon line is zero level of the latitude field + lat = lat[0, 0].cpu().numpy() + contours = ax[i].contour(lat, levels=[0], linewidths=lw, colors=line_colors[i]) + for contour_line in contours.collections: + contour_line.set_linestyle(styles[i]) + + +def plot_vector_fields( + vector_fields, + cmap="lime", + subsample=15, + scale=None, + lw=None, + alphas=0.8, + axes=None, +): + """Plot vector fields. + + Args: + vector_fields (List[torch.Tensor]): List of vector fields of shape (2, H, W). + cmap (str, optional): Color of the vectors. Defaults to "lime". + subsample (int, optional): Subsample the vector field. Defaults to 15. + scale (float, optional): Scale of the vectors. Defaults to None. + lw (float, optional): Line width of the vectors. Defaults to None. + alphas (float | np.ndarray, optional): Alpha per vector or global. Defaults to 0.8. + axes (List[plt.Axes], optional): List of axes to draw on. 
Defaults to None. + + Returns: + List[plt.Artist]: List of artists. + """ + if axes is None: + axes = plt.gcf().axes + + vector_fields = [v.cpu().numpy() if isinstance(v, torch.Tensor) else v for v in vector_fields] + + artists = [] + + H, W = vector_fields[0].shape[-2:] + if scale is None: + scale = subsample / min(H, W) + + if lw is None: + lw = 0.1 / subsample + + if alphas is None: + alphas = np.ones_like(vector_fields[0][0]) + alphas = np.stack([alphas] * len(vector_fields), 0) + elif isinstance(alphas, float): + alphas = np.ones_like(vector_fields[0][0]) * alphas + alphas = np.stack([alphas] * len(vector_fields), 0) + else: + alphas = np.array(alphas) + + subsample = min(W, H) // subsample + offset_x = ((W % subsample) + subsample) // 2 + + samples_x = np.arange(offset_x, W, subsample) + samples_y = np.arange(int(subsample * 0.9), H, subsample) + + x_grid, y_grid = np.meshgrid(samples_x, samples_y) + + for i in range(len(axes)): + # vector field of shape (2, H, W) with vectors of norm == 1 + vector_field = vector_fields[i] + + a = alphas[i][samples_y][:, samples_x] + x, y = vector_field[:, samples_y][:, :, samples_x] + + c = cmap + if not isinstance(cmap, str): + c = cmap[i][samples_y][:, samples_x].reshape(-1, 3) + + s = scale * min(H, W) + arrows = axes[i].quiver( + x_grid, + y_grid, + x, + y, + scale=s, + scale_units="width" if H > W else "height", + units="width" if H > W else "height", + alpha=a, + color=c, + angles="xy", + antialiased=True, + width=lw, + headaxislength=3.5, + zorder=5, + ) + + artists.append(arrows) + + return artists + + +def plot_latitudes( + latitude, + is_radians=True, + vmin=-90, + vmax=90, + cmap="seismic", + contours_every=15, + alpha=0.4, + axes=None, + **kwargs, +): + """Plot latitudes. + + Args: + latitude (List[torch.Tensor]): List of latitudes. + is_radians (bool, optional): Whether the latitudes are in radians. Defaults to True. + vmin (int, optional): Min value to clip to. Defaults to -90. + vmax (int, optional): Max value to clip to. Defaults to 90. + cmap (str, optional): Colormap. Defaults to "seismic". + contours_every (int, optional): Contours every. Defaults to 15. + alpha (float, optional): Alpha value. Defaults to 0.4. + axes (List[plt.Axes], optional): Axes to plot on. Defaults to None. + + Returns: + List[plt.Artist]: List of artists. + """ + if axes is None: + axes = plt.gcf().axes + + assert len(axes) == len(latitude), f"{len(axes)}, {len(latitude)}" + lat = [rad2deg(lat) for lat in latitude] if is_radians else latitude + return plot_heatmaps( + lat, + vmin=vmin, + vmax=vmax, + cmap=cmap, + a=alpha, + axes=axes, + contours_every=contours_every, + **kwargs, + ) + + +def plot_confidences( + confidence, + as_log=True, + vmin=-4, + vmax=0, + cmap="turbo", + alpha=0.4, + axes=None, + **kwargs, +): + """Plot confidences. + + Args: + confidence (List[torch.Tensor]): Confidence maps. + as_log (bool, optional): Whether to plot in log scale. Defaults to True. + vmin (int, optional): Min value to clip to. Defaults to -4. + vmax (int, optional): Max value to clip to. Defaults to 0. + cmap (str, optional): Colormap. Defaults to "turbo". + alpha (float, optional): Alpha value. Defaults to 0.4. + axes (List[plt.Axes], optional): Axes to plot on. Defaults to None. + + Returns: + List[plt.Artist]: List of artists. 
+ """ + if axes is None: + axes = plt.gcf().axes + + confidence = [c.cpu() if isinstance(c, torch.Tensor) else torch.tensor(c) for c in confidence] + + assert len(axes) == len(confidence), f"{len(axes)}, {len(confidence)}" + + if as_log: + confidence = [torch.log10(c.clip(1e-5)).clip(vmin, vmax) for c in confidence] + + # normalize to [0, 1] + confidence = [(c - c.min()) / (c.max() - c.min()) for c in confidence] + return plot_heatmaps(confidence, vmin=0, vmax=1, cmap=cmap, a=alpha, axes=axes, **kwargs) + + +def save_plot(path, **kw): + """Save the current figure without any white margin.""" + plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)