Commit 899c526
pablovela5620 committed
Parent(s): e32c92e

initial commit with working dpvo
This view is limited to 50 files because it contains too many changes.
- .gitignore +166 -1
- config/default.yaml +19 -0
- config/fast.yaml +20 -0
- mini_dpvo/__init__.py +0 -0
- mini_dpvo/altcorr/__init__.py +1 -0
- mini_dpvo/altcorr/correlation.cpp +63 -0
- mini_dpvo/altcorr/correlation.py +74 -0
- mini_dpvo/altcorr/correlation_kernel.cu +333 -0
- mini_dpvo/api/__init__.py +0 -0
- mini_dpvo/api/inference.py +190 -0
- mini_dpvo/ba.py +182 -0
- mini_dpvo/blocks.py +118 -0
- mini_dpvo/config.py +27 -0
- mini_dpvo/data_readers/__init__.py +1 -0
- mini_dpvo/data_readers/augmentation.py +66 -0
- mini_dpvo/data_readers/base.py +176 -0
- mini_dpvo/data_readers/factory.py +26 -0
- mini_dpvo/data_readers/frame_utils.py +164 -0
- mini_dpvo/data_readers/rgbd_utils.py +188 -0
- mini_dpvo/data_readers/tartan.py +110 -0
- mini_dpvo/data_readers/tartan_test.txt +32 -0
- mini_dpvo/dpvo.py +410 -0
- mini_dpvo/extractor.py +264 -0
- mini_dpvo/fastba/__init__.py +1 -0
- mini_dpvo/fastba/ba.cpp +157 -0
- mini_dpvo/fastba/ba.py +8 -0
- mini_dpvo/fastba/ba_cuda.cu +575 -0
- mini_dpvo/lietorch/__init__.py +2 -0
- mini_dpvo/lietorch/broadcasting.py +31 -0
- mini_dpvo/lietorch/gradcheck.py +592 -0
- mini_dpvo/lietorch/group_ops.py +102 -0
- mini_dpvo/lietorch/groups.py +322 -0
- mini_dpvo/lietorch/include/common.h +12 -0
- mini_dpvo/lietorch/include/dispatch.h +48 -0
- mini_dpvo/lietorch/include/lietorch_cpu.h +51 -0
- mini_dpvo/lietorch/include/lietorch_gpu.h +51 -0
- mini_dpvo/lietorch/include/rxso3.h +324 -0
- mini_dpvo/lietorch/include/se3.h +229 -0
- mini_dpvo/lietorch/include/sim3.h +217 -0
- mini_dpvo/lietorch/include/so3.h +229 -0
- mini_dpvo/lietorch/run_tests.py +302 -0
- mini_dpvo/lietorch/src/lietorch.cpp +317 -0
- mini_dpvo/lietorch/src/lietorch_cpu.cpp +657 -0
- mini_dpvo/lietorch/src/lietorch_gpu.cu +601 -0
- mini_dpvo/logger.py +58 -0
- mini_dpvo/net.py +270 -0
- mini_dpvo/plot_utils.py +52 -0
- mini_dpvo/projective_ops.py +121 -0
- mini_dpvo/stream.py +92 -0
- mini_dpvo/utils.py +92 -0
.gitignore
CHANGED
@@ -1,4 +1,169 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+.pdm.toml
+.pdm-python
+.pdm-build/
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
 # pixi environments
 .pixi
 *.egg-info
-
+thirdparty/*
+data/*
+checkpoints/*
config/default.yaml
ADDED
@@ -0,0 +1,19 @@
### DPVO Config File ###

# VO config (increase for better accuracy)
PATCHES_PER_FRAME: 96
REMOVAL_WINDOW: 22
OPTIMIZATION_WINDOW: 10
PATCH_LIFETIME: 13

# threshold for keyframe removal
KEYFRAME_THRESH: 15.0

# camera motion model
MOTION_MODEL: 'DAMPED_LINEAR'
MOTION_DAMPING: 0.5

# maybe use mixed precision for inference
MIXED_PRECISION: True

GRADIENT_BIAS: False
config/fast.yaml
ADDED
@@ -0,0 +1,20 @@

### DPVO Config File ###

# VO config (increase for better accuracy)
PATCHES_PER_FRAME: 48
REMOVAL_WINDOW: 16
OPTIMIZATION_WINDOW: 7
PATCH_LIFETIME: 11

# threshold for keyframe removal
KEYFRAME_THRESH: 15.0

# camera motion model
MOTION_MODEL: 'DAMPED_LINEAR'
MOTION_DAMPING: 0.5

# maybe use mixed precision for inference
MIXED_PRECISION: True

GRADIENT_BIAS: False
mini_dpvo/__init__.py
ADDED
File without changes
mini_dpvo/altcorr/__init__.py
ADDED
@@ -0,0 +1 @@
from .correlation import corr, patchify
mini_dpvo/altcorr/correlation.cpp
ADDED
@@ -0,0 +1,63 @@
#include <torch/extension.h>
#include <vector>

// CUDA forward declarations
std::vector<torch::Tensor> corr_cuda_forward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  torch::Tensor ii,
  torch::Tensor jj,
  int radius);

std::vector<torch::Tensor> corr_cuda_backward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  torch::Tensor ii,
  torch::Tensor jj,
  torch::Tensor corr_grad,
  int radius);

std::vector<torch::Tensor> patchify_cuda_forward(
  torch::Tensor net, torch::Tensor coords, int radius);

std::vector<torch::Tensor> patchify_cuda_backward(
  torch::Tensor net, torch::Tensor coords, torch::Tensor gradient, int radius);

std::vector<torch::Tensor> corr_forward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  torch::Tensor ii,
  torch::Tensor jj, int radius) {
    return corr_cuda_forward(fmap1, fmap2, coords, ii, jj, radius);
}

std::vector<torch::Tensor> corr_backward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  torch::Tensor ii,
  torch::Tensor jj,
  torch::Tensor corr_grad, int radius) {
    return corr_cuda_backward(fmap1, fmap2, coords, ii, jj, corr_grad, radius);
}

std::vector<torch::Tensor> patchify_forward(
  torch::Tensor net, torch::Tensor coords, int radius) {
    return patchify_cuda_forward(net, coords, radius);
}

std::vector<torch::Tensor> patchify_backward(
  torch::Tensor net, torch::Tensor coords, torch::Tensor gradient, int radius) {
    return patchify_cuda_backward(net, coords, gradient, radius);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &corr_forward, "CORR forward");
  m.def("backward", &corr_backward, "CORR backward");

  m.def("patchify_forward", &patchify_forward, "PATCHIFY forward");
  m.def("patchify_backward", &patchify_backward, "PATCHIFY backward");
}
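The binding above is compiled into a `cuda_corr` extension, which is the module name `correlation.py` imports. The build script itself is not among the 50 files shown in this view, so the following setup.py is only an assumed sketch of how the extension could be built with torch.utils.cpp_extension:

# Assumed build sketch (not part of this commit): compiles correlation.cpp and
# correlation_kernel.cu into the `cuda_corr` module imported by correlation.py.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="cuda_corr",
    ext_modules=[
        CUDAExtension(
            name="cuda_corr",
            sources=[
                "mini_dpvo/altcorr/correlation.cpp",
                "mini_dpvo/altcorr/correlation_kernel.cu",
            ],
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)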
mini_dpvo/altcorr/correlation.py
ADDED
@@ -0,0 +1,74 @@
import torch
import cuda_corr

class CorrLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fmap1, fmap2, coords, ii, jj, radius, dropout):
        """ forward correlation """
        ctx.save_for_backward(fmap1, fmap2, coords, ii, jj)
        ctx.radius = radius
        ctx.dropout = dropout
        corr, = cuda_corr.forward(fmap1, fmap2, coords, ii, jj, radius)

        return corr

    @staticmethod
    def backward(ctx, grad):
        """ backward correlation """
        fmap1, fmap2, coords, ii, jj = ctx.saved_tensors

        if ctx.dropout < 1:
            perm = torch.rand(len(ii), device="cuda") < ctx.dropout
            coords = coords[:,perm]
            grad = grad[:,perm]
            ii = ii[perm]
            jj = jj[perm]

        fmap1_grad, fmap2_grad = \
            cuda_corr.backward(fmap1, fmap2, coords, ii, jj, grad, ctx.radius)

        return fmap1_grad, fmap2_grad, None, None, None, None, None


class PatchLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, net, coords, radius):
        """ forward patchify """
        ctx.radius = radius
        ctx.save_for_backward(net, coords)

        patches, = cuda_corr.patchify_forward(net, coords, radius)
        return patches

    @staticmethod
    def backward(ctx, grad):
        """ backward patchify """
        net, coords = ctx.saved_tensors
        grad, = cuda_corr.patchify_backward(net, coords, grad, ctx.radius)

        return grad, None, None

def patchify(net, coords, radius, mode='bilinear'):
    """ extract patches """

    patches = PatchLayer.apply(net, coords, radius)

    if mode == 'bilinear':
        offset = (coords - coords.floor()).to(net.device)
        dx, dy = offset[:,:,None,None,None].unbind(dim=-1)

        d = 2 * radius + 1
        x00 = (1-dy) * (1-dx) * patches[...,:d,:d]
        x01 = (1-dy) * (  dx) * patches[...,:d,1:]
        x10 = (  dy) * (1-dx) * patches[...,1:,:d]
        x11 = (  dy) * (  dx) * patches[...,1:,1:]

        return x00 + x01 + x10 + x11

    return patches


def corr(fmap1, fmap2, coords, ii, jj, radius=1, dropout=1):
    return CorrLayer.apply(fmap1, fmap2, coords, ii, jj, radius, dropout)
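For reference, the bilinear blend in `patchify` can be reproduced on the CPU: the CUDA kernel returns a raw (2R+2)x(2R+2) window per patch, and the four shifted slices combine it into a (2R+1)x(2R+1) interpolated patch. The shapes below are illustrative assumptions, not values taken from the repo:

# CPU sketch of the bilinear blend applied inside patchify(); shapes are assumed.
import torch

radius = 1
d = 2 * radius + 1
patches = torch.rand(2, 8, 16, d + 1, d + 1)  # raw (B, M, C, 2R+2, 2R+2) window
offset = torch.rand(2, 8, 2)                  # fractional part of the patch coords

dx, dy = offset[:, :, None, None, None].unbind(dim=-1)
x00 = (1 - dy) * (1 - dx) * patches[..., :d, :d]
x01 = (1 - dy) * (dx) * patches[..., :d, 1:]
x10 = (dy) * (1 - dx) * patches[..., 1:, :d]
x11 = (dy) * (dx) * patches[..., 1:, 1:]
blended = x00 + x01 + x10 + x11               # (B, M, C, 2R+1, 2R+1)
print(blended.shape)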
mini_dpvo/altcorr/correlation_kernel.cu
ADDED
@@ -0,0 +1,333 @@
#include <torch/extension.h>
#include <THC/THCAtomics.cuh>
#include <vector>
#include <iostream>

using namespace torch::indexing;

#define THREADS 256
#define BLOCKS(n) (n + THREADS - 1) / THREADS

__forceinline__ __device__
bool within_bounds(int h, int w, int H, int W) {
  return h >= 0 && h < H && w >= 0 && w < W;
}

template <typename scalar_t>
__global__ void patchify_forward_kernel(int R,
    const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> net,
    const torch::PackedTensorAccessor32<float,3,torch::RestrictPtrTraits> coords,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> patches)
{
  // diameter
  const int D = 2*R + 2;

  const int B = coords.size(0);
  const int M = coords.size(1);
  const int C = net.size(1);
  const int H = net.size(2);
  const int W = net.size(3);

  int n = blockIdx.x * blockDim.x + threadIdx.x;
  if (n < B * M * D * D) {
    const int ii = n % D; n /= D;
    const int jj = n % D; n /= D;
    const int m = n % M; n /= M;

    const float x = coords[n][m][0];
    const float y = coords[n][m][1];
    const int i = static_cast<int>(floor(y)) + (ii - R);
    const int j = static_cast<int>(floor(x)) + (jj - R);

    if (within_bounds(i, j, H, W)) {
      for (int k=0; k<C; k++)
        patches[n][m][k][ii][jj] = net[n][k][i][j];
    }
  }
}

template <typename scalar_t>
__global__ void patchify_backward_kernel(int R,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> patch_gradient,
    const torch::PackedTensorAccessor32<float,3,torch::RestrictPtrTraits> coords,
    torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> gradient)
{
  // diameter
  const int D = 2*R + 2;

  const int B = coords.size(0);
  const int M = coords.size(1);
  const int C = gradient.size(1);
  const int H = gradient.size(2);
  const int W = gradient.size(3);

  int n = blockIdx.x * blockDim.x + threadIdx.x;
  if (n < B * M * D * D) {
    const int ii = n % D; n /= D;
    const int jj = n % D; n /= D;
    const int m = n % M; n /= M;

    const float x = coords[n][m][0];
    const float y = coords[n][m][1];
    const int i = static_cast<int>(floor(y)) + (ii - R);
    const int j = static_cast<int>(floor(x)) + (jj - R);

    if (within_bounds(i, j, H, W)) {
      for (int k=0; k<C; k++)
        atomicAdd(&gradient[n][k][i][j], patch_gradient[n][m][k][ii][jj]);
    }
  }
}

template <typename scalar_t>
__global__ void corr_forward_kernel(int R,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap1,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap2,
    const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> coords,
    const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> us,
    const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> vs,
    torch::PackedTensorAccessor32<scalar_t,6,torch::RestrictPtrTraits> corr)
{
  // diameter
  const int D = 2*R + 2;

  const int B = coords.size(0);
  const int M = coords.size(1);
  const int H = coords.size(3);
  const int W = coords.size(4);

  const int C = fmap1.size(2);
  const int H2 = fmap2.size(3);
  const int W2 = fmap2.size(4);

  int n = blockIdx.x * blockDim.x + threadIdx.x;

  if (n < B * M * H * W * D * D) {
    const int jj = n % D; n /= D;
    const int ii = n % D; n /= D;
    const int j0 = n % W; n /= W;
    const int i0 = n % H; n /= H;
    const int m = n % M; n /= M;

    const int ix = us[m];
    const int jx = vs[m];

    const float x = coords[n][m][0][i0][j0];
    const float y = coords[n][m][1][i0][j0];

    const int i1 = static_cast<int>(floor(y)) + (ii - R);
    const int j1 = static_cast<int>(floor(x)) + (jj - R);

    scalar_t s = 0;
    if (within_bounds(i1, j1, H2, W2)) {

      #pragma unroll 8
      for (int i=0; i<C; i+=8) {
        scalar_t f1[8]; for (int j=0; j<8; j++) f1[j] = fmap1[n][ix][i+j][i0][j0];
        scalar_t f2[8]; for (int j=0; j<8; j++) f2[j] = fmap2[n][jx][i+j][i1][j1];

        #pragma unroll
        for (int j=0; j<8; j++) s += f1[j] * f2[j];
      }
    }

    corr[n][m][ii][jj][i0][j0] = s;
  }
}


template <typename scalar_t>
__global__ void corr_backward_kernel(int R,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap1,
    const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap2,
    const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> coords,
    const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> us,
    const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> vs,
    const torch::PackedTensorAccessor32<float,6,torch::RestrictPtrTraits> corr_grad,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap1_grad,
    torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap2_grad)
{
  // diameter
  const int D = 2*R + 2;

  const int B = coords.size(0);
  const int M = coords.size(1);
  const int H = coords.size(3);
  const int W = coords.size(4);

  const int C = fmap1.size(2);
  const int H2 = fmap2.size(3);
  const int W2 = fmap2.size(4);

  int n = blockIdx.x * blockDim.x + threadIdx.x;

  if (n < B * M * H * W * D * D) {
    const int jj = n % D; n /= D;
    const int ii = n % D; n /= D;
    const int j0 = n % W; n /= W;
    const int i0 = n % H; n /= H;
    const int m = n % M; n /= M;

    const int ix = us[m];
    const int jx = vs[m];

    const float x = coords[n][m][0][i0][j0];
    const float y = coords[n][m][1][i0][j0];

    const int i1 = static_cast<int>(floor(y)) + (ii - R);
    const int j1 = static_cast<int>(floor(x)) + (jj - R);

    const scalar_t g = (scalar_t) corr_grad[n][m][ii][jj][i0][j0];

    if (within_bounds(i1, j1, H2, W2)) {
      #pragma unroll 32
      for (int i=0; i<C; i++) {
        atomicAdd(&fmap1_grad[n][ix][i][i0][j0], g * fmap2[n][jx][i][i1][j1]);
        atomicAdd(&fmap2_grad[n][jx][i][i1][j1], g * fmap1[n][ix][i][i0][j0]);
      }
    }
  }
}


std::vector<torch::Tensor> corr_cuda_forward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  torch::Tensor ii,
  torch::Tensor jj,
  int radius)
{
  const int B = coords.size(0);
  const int M = coords.size(1);

  const int H = coords.size(3);
  const int W = coords.size(4);
  const int D = 2 * radius + 2;

  auto opts = fmap1.options();
  auto corr = torch::empty({B, M, D, D, H, W}, opts);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(fmap1.type(), "corr_forward_kernel", ([&] {
    corr_forward_kernel<scalar_t><<<BLOCKS(B * M * H * W * D * D), THREADS>>>(radius,
      fmap1.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
      fmap2.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
      coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
      ii.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
      jj.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
      corr.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>());
  }));

  torch::Tensor x = coords.index({Slice(), Slice(), 0, None, None});
  torch::Tensor y = coords.index({Slice(), Slice(), 1, None, None});
  torch::Tensor dx = x - x.floor(); dx = dx.to(fmap1.dtype());
  torch::Tensor dy = y - y.floor(); dy = dy.to(fmap2.dtype());

  torch::Tensor out;
  out = (1 - dx) * (1 - dy) * corr.index({Slice(), Slice(), Slice(0, D-1), Slice(0, D-1)});
  out += (dx) * (1 - dy) * corr.index({Slice(), Slice(), Slice(0, D-1), Slice(1, D-0)});
  out += (1 - dx) * (dy) * corr.index({Slice(), Slice(), Slice(1, D-0), Slice(0, D-1)});
  out += (dx) * (dy) * corr.index({Slice(), Slice(), Slice(1, D-0), Slice(1, D-0)});

  return { out.permute({0,1,3,2,4,5}) };
}


std::vector<torch::Tensor> corr_cuda_backward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  torch::Tensor ii,
  torch::Tensor jj,
  torch::Tensor grad,
  int radius)
{
  const int B = coords.size(0);
  const int M = coords.size(1);

  const int H = coords.size(3);
  const int W = coords.size(4);
  const int D = 2 * radius + 2;

  grad = grad.permute({0,1,3,2,4,5}).contiguous();
  torch::Tensor x = coords.index({Slice(), Slice(), 0, None, None});
  torch::Tensor y = coords.index({Slice(), Slice(), 1, None, None});
  torch::Tensor dx = x - x.floor();
  torch::Tensor dy = y - y.floor();

  auto opts = torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);
  torch::Tensor g1 = torch::zeros({B, M, D, D, H, W}, grad.options());
  torch::Tensor g2 = torch::zeros({B, M, D, D, H, W}, grad.options());
  torch::Tensor g3 = torch::zeros({B, M, D, D, H, W}, grad.options());
  torch::Tensor g4 = torch::zeros({B, M, D, D, H, W}, grad.options());

  g1.index_put_({Slice(), Slice(), Slice(0, D-1), Slice(0, D-1)}, (1 - dx) * (1 - dy) * grad);
  g2.index_put_({Slice(), Slice(), Slice(0, D-1), Slice(1, D-0)}, (dx) * (1 - dy) * grad);
  g3.index_put_({Slice(), Slice(), Slice(1, D-0), Slice(0, D-1)}, (1 - dx) * (dy) * grad);
  g4.index_put_({Slice(), Slice(), Slice(1, D-0), Slice(1, D-0)}, (dx) * (dy) * grad);

  torch::Tensor corr_grad = g1 + g2 + g3 + g4;
  auto fmap1_grad = torch::zeros_like(fmap1);
  auto fmap2_grad = torch::zeros_like(fmap2);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(fmap1.type(), "corr_backward_kernel", ([&] {
    corr_backward_kernel<scalar_t><<<BLOCKS(B * M * H * W * D * D), THREADS>>>(radius,
      fmap1.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
      fmap2.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
      coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
      ii.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
      jj.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
      corr_grad.packed_accessor32<float,6,torch::RestrictPtrTraits>(),
      fmap1_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
      fmap2_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>());
  }));

  return {fmap1_grad, fmap2_grad};
}

std::vector<torch::Tensor> patchify_cuda_forward(
  torch::Tensor net, torch::Tensor coords, int radius)
{
  const int B = coords.size(0);
  const int M = coords.size(1);
  const int C = net.size(1);
  const int D = 2 * radius + 2;

  auto opts = net.options();
  auto patches = torch::zeros({B, M, C, D, D}, opts);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(net.type(), "patchify_forward_kernel", ([&] {
    patchify_forward_kernel<scalar_t><<<BLOCKS(B * M * D * D), THREADS>>>(radius,
      net.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
      coords.packed_accessor32<float,3,torch::RestrictPtrTraits>(),
      patches.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>());
  }));

  return { patches };
}


std::vector<torch::Tensor> patchify_cuda_backward(
  torch::Tensor net,
  torch::Tensor coords,
  torch::Tensor gradient,
  int radius)
{
  const int B = coords.size(0);
  const int M = coords.size(1);
  const int C = net.size(1);
  const int H = net.size(2);
  const int W = net.size(3);
  const int D = 2 * radius + 2;

  torch::Tensor net_gradient = torch::zeros_like(net);

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(net.type(), "patchify_backward_kernel", ([&] {
    patchify_backward_kernel<scalar_t><<<BLOCKS(B * M * D * D), THREADS>>>(radius,
      gradient.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
      coords.packed_accessor32<float,3,torch::RestrictPtrTraits>(),
      net_gradient.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>());
  }));

  return { net_gradient };
}
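Each kernel maps a flat thread id onto a multi-dimensional index by repeated modulo/divide. A small Python illustration of the decomposition used in `corr_forward_kernel` (the argument values below are arbitrary assumptions):

# Mirrors the index unpacking in corr_forward_kernel: a flat id n becomes
# (m, i0, j0, ii, jj) for patch m, pixel (i0, j0), and window offset (ii, jj).
def unpack_index(n: int, D: int, H: int, W: int, M: int):
    jj = n % D; n //= D
    ii = n % D; n //= D
    j0 = n % W; n //= W
    i0 = n % H; n //= H
    m = n % M; n //= M
    return m, i0, j0, ii, jj

print(unpack_index(12345, D=4, H=30, W=40, M=96))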
mini_dpvo/api/__init__.py
ADDED
File without changes
mini_dpvo/api/inference.py
ADDED
@@ -0,0 +1,190 @@
import numpy as np
import os
import torch
from pathlib import Path
from multiprocessing import Process, Queue
from yacs.config import CfgNode

from mini_dpvo.utils import Timer
from mini_dpvo.dpvo import DPVO
from mini_dpvo.stream import image_stream, video_stream

import rerun as rr
from jaxtyping import UInt8, Float64, Float32
from scipy.spatial.transform import Rotation
from dataclasses import dataclass

from timeit import default_timer as timer


@dataclass
class DPVOPrediction:
    final_poses: Float32[torch.Tensor, "num_keyframes 7"]  # noqa: F722
    tstamps: Float64[torch.Tensor, "num_keyframes"]  # noqa: F821
    final_points: Float32[torch.Tensor, "buffer_size*num_patches 3"]  # noqa: F722
    final_colors: UInt8[torch.Tensor, "buffer_size num_patches 3"]  # noqa: F722


def log_trajectory(
    parent_log_path: Path,
    poses: Float32[torch.Tensor, "buffer_size 7"],
    points: Float32[torch.Tensor, "buffer_size*num_patches 3"],
    colors: UInt8[torch.Tensor, "buffer_size num_patches 3"],
    intri_np: Float64[np.ndarray, "4"],
    bgr_hw3: UInt8[np.ndarray, "h w 3"],
):
    cam_log_path = f"{parent_log_path}/camera"
    rr.log(f"{cam_log_path}/pinhole/image", rr.Image(bgr_hw3[..., ::-1]))
    rr.log(
        f"{cam_log_path}/pinhole",
        rr.Pinhole(
            height=bgr_hw3.shape[0],
            width=bgr_hw3.shape[1],
            focal_length=[intri_np[0], intri_np[1]],
            principal_point=[intri_np[2], intri_np[3]],
        ),
    )

    poses_mask = ~(poses[:, :6] == 0).all(dim=1)
    points_mask = ~(points == 0).all(dim=1)

    nonzero_poses = poses[poses_mask]
    nonzero_points = points[points_mask]

    last_index = nonzero_poses.shape[0] - 1
    # get last non-zero pose, and the index of the last non-zero pose
    quat_pose = nonzero_poses[last_index].numpy(force=True)
    trans_quat = quat_pose[:3]
    rotation_quat = Rotation.from_quat(quat_pose[3:])

    mat3x3 = rotation_quat.as_matrix()
    rr.log(
        f"{cam_log_path}",
        rr.Transform3D(translation=trans_quat, mat3x3=mat3x3, from_parent=True),
    )

    # outlier removal
    trajectory_center = np.median(nonzero_poses[:, :3].numpy(force=True), axis=0)
    radii = lambda a: np.linalg.norm(a - trajectory_center, axis=1)
    points_np = nonzero_points.view(-1, 3).numpy(force=True)
    colors_np = colors.view(-1, 3)[points_mask].numpy(force=True)
    inlier_mask = (
        radii(points_np) < radii(nonzero_poses[:, :3].numpy(force=True)).max() * 5
    )
    points_filtered = points_np[inlier_mask]
    colors_filtered = colors_np[inlier_mask]

    # log all points and colors at the same time
    rr.log(
        f"{parent_log_path}/pointcloud",
        rr.Points3D(
            positions=points_filtered,
            colors=colors_filtered,
        ),
    )


def log_final(
    parent_log_path: Path,
    final_poses: Float32[torch.Tensor, "num_keyframes 7"],
    tstamps: Float64[torch.Tensor, "num_keyframes"],  # noqa: F821
    final_points: Float32[torch.Tensor, "buffer_size*num_patches 3"],
    final_colors: UInt8[torch.Tensor, "buffer_size num_patches 3"],
):
    for idx, (pose_quat, tstamp) in enumerate(zip(final_poses, tstamps)):
        cam_log_path = f"{parent_log_path}/camera_{idx}"
        trans_quat = pose_quat[:3]
        R_33 = Rotation.from_quat(pose_quat[3:]).as_matrix()
        rr.log(
            f"{cam_log_path}",
            rr.Transform3D(translation=trans_quat, mat3x3=R_33, from_parent=False),
        )


def create_reader(
    imagedir: str, calib: str, stride: int, skip: int, queue: Queue
) -> Process:
    if os.path.isdir(imagedir):
        reader = Process(
            target=image_stream, args=(queue, imagedir, calib, stride, skip)
        )
    else:
        reader = Process(
            target=video_stream, args=(queue, imagedir, calib, stride, skip)
        )

    return reader


@torch.no_grad()
def run(
    cfg: CfgNode,
    network_path: str,
    imagedir: str,
    calib: str,
    stride: int = 1,
    skip: int = 0,
    vis_during: bool = True,
    timeit: bool = False,
) -> tuple[DPVOPrediction, float]:
    slam = None
    queue = Queue(maxsize=8)
    reader: Process = create_reader(imagedir, calib, stride, skip, queue)
    reader.start()

    if vis_during:
        parent_log_path = Path("world")
        rr.log(f"{parent_log_path}", rr.ViewCoordinates.RDF, timeless=True)

    start = timer()

    while True:
        t: int
        bgr_hw3: UInt8[np.ndarray, "h w 3"]
        intri_np: Float64[np.ndarray, "4"]
        (t, bgr_hw3, intri_np) = queue.get()
        # queue will have a (-1, image, intrinsics) tuple when the reader is done
        if t < 0:
            break

        if vis_during:
            rr.set_time_sequence(timeline="timestep", sequence=t)

        bgr_3hw: UInt8[torch.Tensor, "h w 3"] = (
            torch.from_numpy(bgr_hw3).permute(2, 0, 1).cuda()
        )
        intri_torch: Float64[torch.Tensor, "4"] = torch.from_numpy(intri_np).cuda()

        if slam is None:
            slam = DPVO(cfg, network_path, ht=bgr_3hw.shape[1], wd=bgr_3hw.shape[2])

        with Timer("SLAM", enabled=timeit):
            slam(t, bgr_3hw, intri_torch)

        if slam.is_initialized and vis_during:
            poses: Float32[torch.Tensor, "buffer_size 7"] = slam.poses_
            points: Float32[torch.Tensor, "buffer_size*num_patches 3"] = slam.points_
            colors: UInt8[torch.Tensor, "buffer_size num_patches 3"] = slam.colors_
            log_trajectory(parent_log_path, poses, points, colors, intri_np, bgr_hw3)

    for _ in range(12):
        slam.update()

    total_time: float = timer() - start
    print(f"Total time: {total_time:.2f}s")

    reader.join()

    final_poses: Float32[torch.Tensor, "num_keyframes 7"]
    tstamps: Float64[torch.Tensor, "num_keyframes"]  # noqa: F821

    final_poses, tstamps = slam.terminate()
    final_points: Float32[torch.Tensor, "buffer_size*num_patches 3"] = slam.points_
    final_colors: UInt8[torch.Tensor, "buffer_size num_patches 3"] = slam.colors_
    dpvo_pred = DPVOPrediction(
        final_poses=final_poses,
        tstamps=tstamps,
        final_points=final_points,
        final_colors=final_colors,
    )
    return dpvo_pred, total_time
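A hedged usage sketch for `run`: the checkpoint, data, and calibration paths below are placeholders (they are not defined by this commit), and a rerun viewer is spawned only because `vis_during=True` makes `run` log to rerun.

# Assumed usage of mini_dpvo.api.inference.run; paths are placeholders.
import rerun as rr
from mini_dpvo.config import cfg
from mini_dpvo.api.inference import run

rr.init("mini_dpvo", spawn=True)
cfg.merge_from_file("config/default.yaml")

dpvo_pred, total_time = run(
    cfg=cfg,
    network_path="checkpoints/dpvo.pth",  # placeholder checkpoint path
    imagedir="data/example.mp4",          # image directory or video file
    calib="data/calib.txt",               # fx fy cx cy calibration (assumed format)
    stride=1,
    skip=0,
    vis_during=True,
)
print(dpvo_pred.final_poses.shape, f"{total_time:.2f}s")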
mini_dpvo/ba.py
ADDED
@@ -0,0 +1,182 @@
import torch
from torch_scatter import scatter_sum

from . import fastba
from . import lietorch
from .lietorch import SE3

from .utils import Timer

from . import projective_ops as pops

class CholeskySolver(torch.autograd.Function):
    @staticmethod
    def forward(ctx, H, b):
        # don't crash training if cholesky decomp fails
        U, info = torch.linalg.cholesky_ex(H)

        if torch.any(info):
            ctx.failed = True
            return torch.zeros_like(b)

        xs = torch.cholesky_solve(b, U)
        ctx.save_for_backward(U, xs)
        ctx.failed = False

        return xs

    @staticmethod
    def backward(ctx, grad_x):
        if ctx.failed:
            return None, None

        U, xs = ctx.saved_tensors
        dz = torch.cholesky_solve(grad_x, U)
        dH = -torch.matmul(xs, dz.transpose(-1,-2))

        return dH, dz

# utility functions for scattering ops
def safe_scatter_add_mat(A, ii, jj, n, m):
    v = (ii >= 0) & (jj >= 0) & (ii < n) & (jj < m)
    return scatter_sum(A[:,v], ii[v]*m + jj[v], dim=1, dim_size=n*m)

def safe_scatter_add_vec(b, ii, n):
    v = (ii >= 0) & (ii < n)
    return scatter_sum(b[:,v], ii[v], dim=1, dim_size=n)

# apply retraction operator to inv-depth maps
def disp_retr(disps, dz, ii):
    ii = ii.to(device=dz.device)
    return disps + scatter_sum(dz, ii, dim=1, dim_size=disps.shape[1])

# apply retraction operator to poses
def pose_retr(poses, dx, ii):
    ii = ii.to(device=dx.device)
    return poses.retr(scatter_sum(dx, ii, dim=1, dim_size=poses.shape[1]))

def block_matmul(A, B):
    """ block matrix multiply """
    b, n1, m1, p1, q1 = A.shape
    b, n2, m2, p2, q2 = B.shape
    A = A.permute(0, 1, 3, 2, 4).reshape(b, n1*p1, m1*q1)
    B = B.permute(0, 1, 3, 2, 4).reshape(b, n2*p2, m2*q2)
    return torch.matmul(A, B).reshape(b, n1, p1, m2, q2).permute(0, 1, 3, 2, 4)

def block_solve(A, B, ep=1.0, lm=1e-4):
    """ block matrix solve """
    b, n1, m1, p1, q1 = A.shape
    b, n2, m2, p2, q2 = B.shape
    A = A.permute(0, 1, 3, 2, 4).reshape(b, n1*p1, m1*q1)
    B = B.permute(0, 1, 3, 2, 4).reshape(b, n2*p2, m2*q2)

    A = A + (ep + lm * A) * torch.eye(n1*p1, device=A.device)

    X = CholeskySolver.apply(A, B)
    return X.reshape(b, n1, p1, m2, q2).permute(0, 1, 3, 2, 4)


def block_show(A):
    import matplotlib.pyplot as plt
    b, n1, m1, p1, q1 = A.shape
    A = A.permute(0, 1, 3, 2, 4).reshape(b, n1*p1, m1*q1)
    plt.imshow(A[0].detach().cpu().numpy())
    plt.show()

def BA(poses, patches, intrinsics, targets, weights, lmbda, ii, jj, kk, bounds, ep=100.0, PRINT=False, fixedp=1, structure_only=False):
    """ bundle adjustment """

    b = 1
    n = max(ii.max().item(), jj.max().item()) + 1

    coords, v, (Ji, Jj, Jz) = \
        pops.transform(poses, patches, intrinsics, ii, jj, kk, jacobian=True)

    p = coords.shape[3]
    r = targets - coords[...,p//2,p//2,:]

    v *= (r.norm(dim=-1) < 250).float()

    in_bounds = \
        (coords[...,p//2,p//2,0] > bounds[0]) & \
        (coords[...,p//2,p//2,1] > bounds[1]) & \
        (coords[...,p//2,p//2,0] < bounds[2]) & \
        (coords[...,p//2,p//2,1] < bounds[3])

    v *= in_bounds.float()

    if PRINT:
        print((r * v[...,None]).norm(dim=-1).mean().item())

    r = (v[...,None] * r).unsqueeze(dim=-1)
    weights = (v[...,None] * weights).unsqueeze(dim=-1)

    wJiT = (weights * Ji).transpose(2,3)
    wJjT = (weights * Jj).transpose(2,3)
    wJzT = (weights * Jz).transpose(2,3)

    Bii = torch.matmul(wJiT, Ji)
    Bij = torch.matmul(wJiT, Jj)
    Bji = torch.matmul(wJjT, Ji)
    Bjj = torch.matmul(wJjT, Jj)

    Eik = torch.matmul(wJiT, Jz)
    Ejk = torch.matmul(wJjT, Jz)

    vi = torch.matmul(wJiT, r)
    vj = torch.matmul(wJjT, r)

    # fix first pose
    ii = ii.clone()
    jj = jj.clone()

    n = n - fixedp
    ii = ii - fixedp
    jj = jj - fixedp

    kx, kk = torch.unique(kk, return_inverse=True, sorted=True)
    m = len(kx)

    B = safe_scatter_add_mat(Bii, ii, ii, n, n).view(b, n, n, 6, 6) + \
        safe_scatter_add_mat(Bij, ii, jj, n, n).view(b, n, n, 6, 6) + \
        safe_scatter_add_mat(Bji, jj, ii, n, n).view(b, n, n, 6, 6) + \
        safe_scatter_add_mat(Bjj, jj, jj, n, n).view(b, n, n, 6, 6)

    E = safe_scatter_add_mat(Eik, ii, kk, n, m).view(b, n, m, 6, 1) + \
        safe_scatter_add_mat(Ejk, jj, kk, n, m).view(b, n, m, 6, 1)

    C = safe_scatter_add_vec(torch.matmul(wJzT, Jz), kk, m)

    v = safe_scatter_add_vec(vi, ii, n).view(b, n, 1, 6, 1) + \
        safe_scatter_add_vec(vj, jj, n).view(b, n, 1, 6, 1)

    w = safe_scatter_add_vec(torch.matmul(wJzT, r), kk, m)

    if isinstance(lmbda, torch.Tensor):
        lmbda = lmbda.reshape(*C.shape)

    Q = 1.0 / (C + lmbda)

    ### solve w/ schur complement ###
    EQ = E * Q[:,None]

    if structure_only or n == 0:
        dZ = (Q * w).view(b, -1, 1, 1)

    else:
        S = B - block_matmul(EQ, E.permute(0,2,1,4,3))
        y = v - block_matmul(EQ, w.unsqueeze(dim=2))
        dX = block_solve(S, y, ep=ep, lm=1e-4)

        dZ = Q * (w - block_matmul(E.permute(0,2,1,4,3), dX).squeeze(dim=-1))
        dX = dX.view(b, -1, 6)
        dZ = dZ.view(b, -1, 1, 1)

    x, y, disps = patches.unbind(dim=2)
    disps = disp_retr(disps, dZ, kx).clamp(min=1e-3, max=10.0)
    patches = torch.stack([x, y, disps], dim=2)

    if not structure_only and n > 0:
        poses = pose_retr(poses, dX, fixedp + torch.arange(n))

    return poses, patches
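As a quick sanity check (not a test that ships with this commit), `CholeskySolver` should agree with a dense solver on a well-conditioned symmetric positive-definite system:

# Assumed sanity sketch for CholeskySolver: compare against torch.linalg.solve.
import torch
from mini_dpvo.ba import CholeskySolver

b, n = 2, 8
A = torch.randn(b, n, n, dtype=torch.double)
H = A @ A.transpose(-1, -2) + 1e-2 * torch.eye(n, dtype=torch.double)  # SPD system
rhs = torch.randn(b, n, 1, dtype=torch.double)

x = CholeskySolver.apply(H, rhs)
assert torch.allclose(x, torch.linalg.solve(H, rhs), atol=1e-8)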
mini_dpvo/blocks.py
ADDED
@@ -0,0 +1,118 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_scatter

class LayerNorm1D(nn.Module):
    def __init__(self, dim):
        super(LayerNorm1D, self).__init__()
        self.norm = nn.LayerNorm(dim, eps=1e-4)

    def forward(self, x):
        return self.norm(x.transpose(1,2)).transpose(1,2)

class GatedResidual(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.gate = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Sigmoid())

        self.res = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.gate(x) * self.res(x)

class SoftAgg(nn.Module):
    def __init__(self, dim=512, expand=True):
        super(SoftAgg, self).__init__()
        self.dim = dim
        self.expand = expand
        self.f = nn.Linear(self.dim, self.dim)
        self.g = nn.Linear(self.dim, self.dim)
        self.h = nn.Linear(self.dim, self.dim)

    def forward(self, x, ix):
        _, jx = torch.unique(ix, return_inverse=True)
        w = torch_scatter.scatter_softmax(self.g(x), jx, dim=1)
        y = torch_scatter.scatter_sum(self.f(x) * w, jx, dim=1)

        if self.expand:
            return self.h(y)[:,jx]

        return self.h(y)

class SoftAggBasic(nn.Module):
    def __init__(self, dim=512, expand=True):
        super(SoftAggBasic, self).__init__()
        self.dim = dim
        self.expand = expand
        self.f = nn.Linear(self.dim, self.dim)
        self.g = nn.Linear(self.dim, 1)
        self.h = nn.Linear(self.dim, self.dim)

    def forward(self, x, ix):
        _, jx = torch.unique(ix, return_inverse=True)
        w = torch_scatter.scatter_softmax(self.g(x), jx, dim=1)
        y = torch_scatter.scatter_sum(self.f(x) * w, jx, dim=1)

        if self.expand:
            return self.h(y)[:,jx]

        return self.h(y)


### Gradient Clipping and Zeroing Operations ###

GRAD_CLIP = 0.1

class GradClip(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_x):
        grad_x = torch.where(torch.isnan(grad_x), torch.zeros_like(grad_x), grad_x)
        return grad_x.clamp(min=-0.01, max=0.01)

class GradientClip(nn.Module):
    def __init__(self):
        super(GradientClip, self).__init__()

    def forward(self, x):
        return GradClip.apply(x)

class GradZero(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_x):
        grad_x = torch.where(torch.isnan(grad_x), torch.zeros_like(grad_x), grad_x)
        grad_x = torch.where(torch.abs(grad_x) > GRAD_CLIP, torch.zeros_like(grad_x), grad_x)
        return grad_x

class GradientZero(nn.Module):
    def __init__(self):
        super(GradientZero, self).__init__()

    def forward(self, x):
        return GradZero.apply(x)


class GradMag(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_x):
        print(grad_x.abs().mean())
        return grad_x
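A minimal usage sketch for `SoftAgg` (shapes and the index pattern are assumptions): features that share an id in `ix` are pooled with a learned softmax weighting and, with `expand=True`, gathered back to the original per-element layout.

# Assumed usage of SoftAgg; requires torch_scatter to be installed.
import torch
from mini_dpvo.blocks import SoftAgg

agg = SoftAgg(dim=32)
x = torch.randn(1, 10, 32)                          # (batch, num_edges, dim)
ix = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3])   # group id for each edge
y = agg(x, ix)
print(y.shape)  # torch.Size([1, 10, 32]) because expand=True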
mini_dpvo/config.py
ADDED
@@ -0,0 +1,27 @@
from yacs.config import CfgNode as CN

_C = CN()

# max number of keyframes
_C.BUFFER_SIZE = 2048

# bias patch selection towards high gradient regions?
_C.GRADIENT_BIAS = True

# VO config (increase for better accuracy)
_C.PATCHES_PER_FRAME = 80
_C.REMOVAL_WINDOW = 20
_C.OPTIMIZATION_WINDOW = 12
_C.PATCH_LIFETIME = 12

# threshold for keyframe removal
_C.KEYFRAME_INDEX = 4
_C.KEYFRAME_THRESH = 12.5

# camera motion model
_C.MOTION_MODEL = 'DAMPED_LINEAR'
_C.MOTION_DAMPING = 0.5

_C.MIXED_PRECISION = True

cfg = _C
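The YAML files under `config/` override these yacs defaults. The entry point that performs the merge is not shown in this 50-file view, so the snippet below is only an assumed illustration of the yacs API:

# Assumed illustration: merge config/fast.yaml onto the defaults in config.py.
from mini_dpvo.config import cfg

cfg.merge_from_file("config/fast.yaml")
cfg.freeze()
print(cfg.PATCHES_PER_FRAME)  # 48 in the "fast" profile vs. the default of 80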
mini_dpvo/data_readers/__init__.py
ADDED
@@ -0,0 +1 @@

mini_dpvo/data_readers/augmentation.py
ADDED
@@ -0,0 +1,66 @@
import torch
import torchvision.transforms as transforms
import numpy as np
import torch.nn.functional as F


class RGBDAugmentor:
    """ perform augmentation on RGB-D video """

    def __init__(self, crop_size):
        self.crop_size = crop_size
        self.augcolor = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2/3.14),
            transforms.RandomGrayscale(p=0.1),
            transforms.RandomInvert(p=0.1),
            transforms.ToTensor()])

        self.max_scale = 0.5

    def spatial_transform(self, images, depths, poses, intrinsics):
        """ cropping and resizing """
        ht, wd = images.shape[2:]

        max_scale = self.max_scale
        min_scale = np.log2(np.maximum(
            (self.crop_size[0] + 1) / float(ht),
            (self.crop_size[1] + 1) / float(wd)))

        scale = 1
        if np.random.rand() < 0.8:
            scale = 2 ** np.random.uniform(0.0, max_scale)

        intrinsics = scale * intrinsics

        ht1 = int(scale * ht)
        wd1 = int(scale * wd)

        depths = depths.unsqueeze(dim=1)

        images = F.interpolate(images, (ht1, wd1), mode='bicubic', align_corners=False)
        depths = F.interpolate(depths, (ht1, wd1), recompute_scale_factor=False)

        # always perform center crop (TODO: try non-center crops)
        y0 = (images.shape[2] - self.crop_size[0]) // 2
        x0 = (images.shape[3] - self.crop_size[1]) // 2

        intrinsics = intrinsics - torch.tensor([0.0, 0.0, x0, y0])
        images = images[:, :, y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        depths = depths[:, :, y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]

        depths = depths.squeeze(dim=1)
        return images, poses, depths, intrinsics

    def color_transform(self, images):
        """ color jittering """
        num, ch, ht, wd = images.shape
        images = images.permute(1, 2, 3, 0).reshape(ch, ht, wd*num)
        images = 255 * self.augcolor(images[[2,1,0]] / 255.0)
        return images[[2,1,0]].reshape(ch, ht, wd, num).permute(3,0,1,2).contiguous()

    def __call__(self, images, poses, depths, intrinsics):
        if np.random.rand() < 0.5:
            images = self.color_transform(images)

        return self.spatial_transform(images, depths, poses, intrinsics)
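A hedged example of driving `RGBDAugmentor` on random tensors; the shapes follow how `base.py` calls it (0-255, channel-first, BGR images), but the concrete numbers are assumptions:

# Assumed usage sketch for RGBDAugmentor; requires torchvision.
import torch
from mini_dpvo.data_readers.augmentation import RGBDAugmentor

aug = RGBDAugmentor(crop_size=[480, 640])
images = torch.rand(4, 3, 512, 688) * 255          # (num_frames, 3, H, W), 0-255 BGR
depths = torch.rand(4, 512, 688) + 0.5
poses = torch.rand(4, 7)                           # translation + quaternion
intrinsics = torch.tensor([320.0, 320.0, 344.0, 256.0]).repeat(4, 1)

images, poses, disps, intrinsics = aug(images, poses, 1.0 / depths, intrinsics)
print(images.shape)  # torch.Size([4, 3, 480, 640]) after resize + center crop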
mini_dpvo/data_readers/base.py
ADDED
@@ -0,0 +1,176 @@
import numpy as np
import torch
import torch.utils.data as data
import torch.nn.functional as F

import csv
import os
import cv2
import math
import random
import json
import pickle
import os.path as osp

from .augmentation import RGBDAugmentor
from .rgbd_utils import *

class RGBDDataset(data.Dataset):
    def __init__(self, name, datapath, n_frames=4, crop_size=[480,640], fmin=10.0, fmax=75.0, aug=True, sample=True):
        """ Base class for RGBD dataset """
        self.aug = None
        self.root = datapath
        self.name = name

        self.aug = aug
        self.sample = sample

        self.n_frames = n_frames
        self.fmin = fmin  # exclude very easy examples
        self.fmax = fmax  # exclude very hard examples

        if self.aug:
            self.aug = RGBDAugmentor(crop_size=crop_size)

        # building dataset is expensive, cache so only needs to be performed once
        cur_path = osp.dirname(osp.abspath(__file__))
        if not os.path.isdir(osp.join(cur_path, 'cache')):
            os.mkdir(osp.join(cur_path, 'cache'))

        self.scene_info = \
            pickle.load(open('datasets/TartanAir.pickle', 'rb'))[0]

        self._build_dataset_index()

    def _build_dataset_index(self):
        self.dataset_index = []
        for scene in self.scene_info:
            if not self.__class__.is_test_scene(scene):
                graph = self.scene_info[scene]['graph']
                for i in graph:
                    if i < len(graph) - 65:
                        self.dataset_index.append((scene, i))
            else:
                print("Reserving {} for validation".format(scene))

    @staticmethod
    def image_read(image_file):
        return cv2.imread(image_file)

    @staticmethod
    def depth_read(depth_file):
        return np.load(depth_file)

    def build_frame_graph(self, poses, depths, intrinsics, f=16, max_flow=256):
        """ compute optical flow distance between all pairs of frames """
        def read_disp(fn):
            depth = self.__class__.depth_read(fn)[f//2::f, f//2::f]
            depth[depth < 0.01] = np.mean(depth)
            return 1.0 / depth

        poses = np.array(poses)
        intrinsics = np.array(intrinsics) / f

        disps = np.stack(list(map(read_disp, depths)), 0)
        d = f * compute_distance_matrix_flow(poses, disps, intrinsics)

        graph = {}
        for i in range(d.shape[0]):
            j, = np.where(d[i] < max_flow)
            graph[i] = (j, d[i,j])

        return graph

    def __getitem__(self, index):
        """ return training video """

        index = index % len(self.dataset_index)
        scene_id, ix = self.dataset_index[index]

        frame_graph = self.scene_info[scene_id]['graph']
        images_list = self.scene_info[scene_id]['images']
        depths_list = self.scene_info[scene_id]['depths']
        poses_list = self.scene_info[scene_id]['poses']
        intrinsics_list = self.scene_info[scene_id]['intrinsics']

        # stride = np.random.choice([1,2,3])

        d = np.random.uniform(self.fmin, self.fmax)
        s = 1

        inds = [ ix ]

        while len(inds) < self.n_frames:
            # get other frames within flow threshold

            if self.sample:
                k = (frame_graph[ix][1] > self.fmin) & (frame_graph[ix][1] < self.fmax)
                frames = frame_graph[ix][0][k]

                # prefer frames forward in time
                if np.count_nonzero(frames[frames > ix]):
                    ix = np.random.choice(frames[frames > ix])

                elif ix + 1 < len(images_list):
                    ix = ix + 1

                elif np.count_nonzero(frames):
                    ix = np.random.choice(frames)

            else:
                i = frame_graph[ix][0].copy()
                g = frame_graph[ix][1].copy()

                g[g > d] = -1
                if s > 0:
                    g[i <= ix] = -1
                else:
                    g[i >= ix] = -1

                if len(g) > 0 and np.max(g) > 0:
                    ix = i[np.argmax(g)]
                else:
                    if ix + s >= len(images_list) or ix + s < 0:
                        s *= -1

                    ix = ix + s

            inds += [ ix ]


        images, depths, poses, intrinsics = [], [], [], []
        for i in inds:
            images.append(self.__class__.image_read(images_list[i]))
            depths.append(self.__class__.depth_read(depths_list[i]))
            poses.append(poses_list[i])
            intrinsics.append(intrinsics_list[i])

        images = np.stack(images).astype(np.float32)
        depths = np.stack(depths).astype(np.float32)
        poses = np.stack(poses).astype(np.float32)
        intrinsics = np.stack(intrinsics).astype(np.float32)

        images = torch.from_numpy(images).float()
        images = images.permute(0, 3, 1, 2)

        disps = torch.from_numpy(1.0 / depths)
        poses = torch.from_numpy(poses)
        intrinsics = torch.from_numpy(intrinsics)

        if self.aug:
            images, poses, disps, intrinsics = \
                self.aug(images, poses, disps, intrinsics)

        # normalize depth
        s = .7 * torch.quantile(disps, .98)
        disps = disps / s
        poses[...,:3] *= s

        return images, poses, disps, intrinsics

    def __len__(self):
        return len(self.dataset_index)

    def __imul__(self, x):
        self.dataset_index *= x
        return self
mini_dpvo/data_readers/factory.py
ADDED
@@ -0,0 +1,26 @@
import pickle
import os
import os.path as osp

# RGBD-Dataset
from .tartan import TartanAir

def dataset_factory(dataset_list, **kwargs):
    """ create a combined dataset """

    from torch.utils.data import ConcatDataset

    dataset_map = {
        'tartan': (TartanAir, ),
    }

    db_list = []
    for key in dataset_list:
        # cache datasets for faster future loading
        db = dataset_map[key][0](**kwargs)

        print("Dataset {} has {} images".format(key, len(db)))
        db_list.append(db)

    return ConcatDataset(db_list)
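
The factory above simply instantiates each requested dataset class and concatenates them. A minimal training-side usage sketch follows; the dataset path, batch size, and worker count are placeholders and not part of this commit.

# Hypothetical usage of dataset_factory (not part of this commit); assumes a
# local TartanAir download and the cached datasets/TartanAir.pickle index.
from torch.utils.data import DataLoader
from mini_dpvo.data_readers.factory import dataset_factory

db = dataset_factory(['tartan'], datapath='datasets/TartanAir',
                     n_frames=4, fmin=10.0, fmax=75.0)
loader = DataLoader(db, batch_size=1, shuffle=True, num_workers=4)

for images, poses, disps, intrinsics in loader:
    # images: (B, n_frames, 3, H, W); poses/disps/intrinsics are per-frame tensors
    break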
mini_dpvo/data_readers/frame_utils.py
ADDED
@@ -0,0 +1,164 @@
import numpy as np
from PIL import Image
from os.path import *
import re
import cv2
cv2.setNumThreads(0)

from scipy.spatial.transform import Rotation  # used by cam_read below (import was missing in the hunk)


TAG_CHAR = np.array([202021.25], np.float32)

def readFlowKITTI(filename):
    flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
    flow = flow[:,:,::-1].astype(np.float32)
    flow, valid = flow[:, :, :2], flow[:, :, 2]
    flow = (flow - 2**15) / 64.0
    return flow, valid

def readFlow(fn):
    """ Read .flo file in Middlebury format"""
    # Code adapted from:
    # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy

    # WARNING: this will work on little-endian architectures (eg Intel x86) only!
    # print 'fn = %s'%(fn)
    with open(fn, 'rb') as f:
        magic = np.fromfile(f, np.float32, count=1)
        if 202021.25 != magic:
            print('Magic number incorrect. Invalid .flo file')
            return None
        else:
            w = np.fromfile(f, np.int32, count=1)
            h = np.fromfile(f, np.int32, count=1)
            # print 'Reading %d x %d flo file\n' % (w, h)
            data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
            # Reshape data into 3D array (columns, rows, bands)
            # The reshape here is for visualization, the original code is (w,h,2)
            return np.resize(data, (int(h), int(w), 2))

def readPFM(file):
    file = open(file, 'rb')

    color = None
    width = None
    height = None
    scale = None
    endian = None

    header = file.readline().rstrip()
    if header == b'PF':
        color = True
    elif header == b'Pf':
        color = False
    else:
        raise Exception('Not a PFM file.')

    try:
        dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
    except:
        dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline())

    if dim_match:
        width, height = map(int, dim_match.groups())
    else:
        raise Exception('Malformed PFM header.')

    scale = float(file.readline().rstrip())
    if scale < 0: # little-endian
        endian = '<'
        scale = -scale
    else:
        endian = '>' # big-endian

    data = np.fromfile(file, endian + 'f')
    shape = (height, width, 3) if color else (height, width)

    data = np.reshape(data, shape)
    data = np.flipud(data)
    return data


def writeFlow(filename,uv,v=None):
    """ Write optical flow to file.

    If v is None, uv is assumed to contain both u and v channels,
    stacked in depth.
    Original code by Deqing Sun, adapted from Daniel Scharstein.
    """
    nBands = 2

    if v is None:
        assert(uv.ndim == 3)
        assert(uv.shape[2] == 2)
        u = uv[:,:,0]
        v = uv[:,:,1]
    else:
        u = uv

    assert(u.shape == v.shape)
    height,width = u.shape
    f = open(filename,'wb')
    # write the header
    f.write(TAG_CHAR)
    np.array(width).astype(np.int32).tofile(f)
    np.array(height).astype(np.int32).tofile(f)
    # arrange into matrix form
    tmp = np.zeros((height, width*nBands))
    tmp[:,np.arange(width)*2] = u
    tmp[:,np.arange(width)*2 + 1] = v
    tmp.astype(np.float32).tofile(f)
    f.close()


def readDPT(filename):
    """ Read depth data from file, return as numpy array. """
    f = open(filename,'rb')
    check = np.fromfile(f,dtype=np.float32,count=1)[0]
    TAG_FLOAT = 202021.25
    TAG_CHAR = 'PIEH'
    assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check)
    width = np.fromfile(f,dtype=np.int32,count=1)[0]
    height = np.fromfile(f,dtype=np.int32,count=1)[0]
    size = width*height
    assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height)
    depth = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width))
    return depth

def cam_read(filename):
    """ Read camera data, return (M,N) tuple.
    M is the intrinsic matrix, N is the extrinsic matrix, so that
    x = M*N*X,
    with x being a point in homogeneous image pixel coordinates, X being a
    point in homogeneous world coordinates."""
    f = open(filename,'rb')
    check = np.fromfile(f,dtype=np.float32,count=1)[0]
    M = np.fromfile(f,dtype='float64',count=9).reshape((3,3))
    N = np.fromfile(f,dtype='float64',count=12).reshape((3,4))

    E = np.eye(4)
    E[0:3,:] = N

    fx, fy, cx, cy = M[0,0], M[1,1], M[0,2], M[1,2]
    kvec = np.array([fx, fy, cx, cy])

    q = Rotation.from_matrix(E[:3,:3]).as_quat()
    pvec = np.concatenate([E[:3,3], q], 0)

    return pvec, kvec


def read_gen(file_name, pil=False):
    ext = splitext(file_name)[-1]
    if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
        return Image.open(file_name)
    elif ext == '.bin' or ext == '.raw':
        return np.load(file_name)
    elif ext == '.flo':
        return readFlow(file_name).astype(np.float32)
    elif ext == '.pfm':
        return readPFM(file_name).astype(np.float32)
    elif ext == '.dpt':
        return readDPT(file_name).astype(np.float32)
    elif ext == '.cam':
        return cam_read(file_name)
    return []
mini_dpvo/data_readers/rgbd_utils.py
ADDED
@@ -0,0 +1,188 @@
import numpy as np
import os.path as osp

import torch
from ..lietorch import SE3
from .. import projective_ops as pops  # used by compute_distance_matrix_flow (import was missing in the hunk)

from scipy.spatial.transform import Rotation

def parse_list(filepath, skiprows=0):
    """ read list data """
    data = np.loadtxt(filepath, delimiter=' ', dtype=np.unicode_, skiprows=skiprows)
    return data

def associate_frames(tstamp_image, tstamp_depth, tstamp_pose, max_dt=1.0):
    """ pair images, depths, and poses """
    associations = []
    for i, t in enumerate(tstamp_image):
        if tstamp_pose is None:
            j = np.argmin(np.abs(tstamp_depth - t))
            if (np.abs(tstamp_depth[j] - t) < max_dt):
                associations.append((i, j))

        else:
            j = np.argmin(np.abs(tstamp_depth - t))
            k = np.argmin(np.abs(tstamp_pose - t))

            if (np.abs(tstamp_depth[j] - t) < max_dt) and \
                    (np.abs(tstamp_pose[k] - t) < max_dt):
                associations.append((i, j, k))

    return associations

def loadtum(datapath, frame_rate=-1):
    """ read video data in tum-rgbd format """
    if osp.isfile(osp.join(datapath, 'groundtruth.txt')):
        pose_list = osp.join(datapath, 'groundtruth.txt')

    elif osp.isfile(osp.join(datapath, 'pose.txt')):
        pose_list = osp.join(datapath, 'pose.txt')

    else:
        return None, None, None, None

    image_list = osp.join(datapath, 'rgb.txt')
    depth_list = osp.join(datapath, 'depth.txt')

    calib_path = osp.join(datapath, 'calibration.txt')
    intrinsic = None
    if osp.isfile(calib_path):
        intrinsic = np.loadtxt(calib_path, delimiter=' ')
        intrinsic = intrinsic.astype(np.float64)

    image_data = parse_list(image_list)
    depth_data = parse_list(depth_list)
    pose_data = parse_list(pose_list, skiprows=1)
    pose_vecs = pose_data[:,1:].astype(np.float64)

    tstamp_image = image_data[:,0].astype(np.float64)
    tstamp_depth = depth_data[:,0].astype(np.float64)
    tstamp_pose = pose_data[:,0].astype(np.float64)
    associations = associate_frames(tstamp_image, tstamp_depth, tstamp_pose)

    # print(len(tstamp_image))
    # print(len(associations))

    indicies = range(len(associations))[::5]

    # indicies = [ 0 ]
    # for i in range(1, len(associations)):
    #     t0 = tstamp_image[associations[indicies[-1]][0]]
    #     t1 = tstamp_image[associations[i][0]]
    #     if t1 - t0 > 1.0 / frame_rate:
    #         indicies += [ i ]

    images, poses, depths, intrinsics, tstamps = [], [], [], [], []
    for ix in indicies:
        (i, j, k) = associations[ix]
        images += [ osp.join(datapath, image_data[i,1]) ]
        depths += [ osp.join(datapath, depth_data[j,1]) ]
        poses += [ pose_vecs[k] ]
        tstamps += [ tstamp_image[i] ]

        if intrinsic is not None:
            intrinsics += [ intrinsic ]

    return images, depths, poses, intrinsics, tstamps


def all_pairs_distance_matrix(poses, beta=2.5):
    """ compute distance matrix between all pairs of poses """
    poses = np.array(poses, dtype=np.float32)
    poses[:,:3] *= beta # scale to balance rot + trans
    poses = SE3(torch.from_numpy(poses))

    r = (poses[:,None].inv() * poses[None,:]).log()
    return r.norm(dim=-1).cpu().numpy()

def pose_matrix_to_quaternion(pose):
    """ convert 4x4 pose matrix to (t, q) """
    q = Rotation.from_matrix(pose[:3, :3]).as_quat()
    return np.concatenate([pose[:3, 3], q], axis=0)

def compute_distance_matrix_flow(poses, disps, intrinsics):
    """ compute flow magnitude between all pairs of frames """
    if not isinstance(poses, SE3):
        poses = torch.from_numpy(poses).float().cuda()[None]
        poses = SE3(poses).inv()

        disps = torch.from_numpy(disps).float().cuda()[None]
        intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]

    N = poses.shape[1]

    ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
    ii = ii.reshape(-1).cuda()
    jj = jj.reshape(-1).cuda()

    MAX_FLOW = 100.0
    matrix = np.zeros((N, N), dtype=np.float32)

    s = 2048
    for i in range(0, ii.shape[0], s):
        flow1, val1 = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
        flow2, val2 = pops.induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s])

        flow = torch.stack([flow1, flow2], dim=2)
        val = torch.stack([val1, val2], dim=2)

        mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
        mag = mag.view(mag.shape[1], -1)
        val = val.view(val.shape[1], -1)

        mag = (mag * val).mean(-1) / val.mean(-1)
        mag[val.mean(-1) < 0.7] = np.inf

        i1 = ii[i:i+s].cpu().numpy()
        j1 = jj[i:i+s].cpu().numpy()
        matrix[i1, j1] = mag.cpu().numpy()

    return matrix


def compute_distance_matrix_flow2(poses, disps, intrinsics, beta=0.4):
    """ compute flow magnitude between all pairs of frames """
    # if not isinstance(poses, SE3):
    #     poses = torch.from_numpy(poses).float().cuda()[None]
    #     poses = SE3(poses).inv()

    #     disps = torch.from_numpy(disps).float().cuda()[None]
    #     intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]

    N = poses.shape[1]

    ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
    ii = ii.reshape(-1)
    jj = jj.reshape(-1)

    MAX_FLOW = 128.0
    matrix = np.zeros((N, N), dtype=np.float32)

    s = 2048
    for i in range(0, ii.shape[0], s):
        flow1a, val1a = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s], tonly=True)
        flow1b, val1b = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
        flow2a, val2a = pops.induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s], tonly=True)
        flow2b, val2b = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])

        flow1 = flow1a + beta * flow1b
        val1 = val1a * val2b

        flow2 = flow2a + beta * flow2b
        val2 = val2a * val2b

        flow = torch.stack([flow1, flow2], dim=2)
        val = torch.stack([val1, val2], dim=2)

        mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
        mag = mag.view(mag.shape[1], -1)
        val = val.view(val.shape[1], -1)

        mag = (mag * val).mean(-1) / val.mean(-1)
        mag[val.mean(-1) < 0.8] = np.inf

        i1 = ii[i:i+s].cpu().numpy()
        j1 = jj[i:i+s].cpu().numpy()
        matrix[i1, j1] = mag.cpu().numpy()

    return matrix
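
For reference, `loadtum` and `all_pairs_distance_matrix` above compose as follows. A small sketch, assuming a local TUM-RGBD sequence directory; the path and the 0.1/1.0 thresholds are illustrative only and not part of this commit.

# Hypothetical usage (not part of this commit): pick frame pairs with a
# moderate relative-pose distance from a TUM-format sequence.
import numpy as np

images, depths, poses, intrinsics, tstamps = loadtum('datasets/TUM/rgbd_dataset_freiburg1_desk')
D = all_pairs_distance_matrix(poses)              # (N, N) SE3 geodesic distances, translation scaled by beta
candidate_pairs = np.argwhere((D > 0.1) & (D < 1.0))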
mini_dpvo/data_readers/tartan.py
ADDED
@@ -0,0 +1,110 @@
import numpy as np
import torch
import glob
import cv2
import os
import os.path as osp

from ..lietorch import SE3
from .base import RGBDDataset

# cur_path = osp.dirname(osp.abspath(__file__))
# test_split = osp.join(cur_path, 'tartan_test.txt')
# test_split = open(test_split).read().split()


test_split = [
    "abandonedfactory/abandonedfactory/Easy/P011",
    "abandonedfactory/abandonedfactory/Hard/P011",
    "abandonedfactory_night/abandonedfactory_night/Easy/P013",
    "abandonedfactory_night/abandonedfactory_night/Hard/P014",
    "amusement/amusement/Easy/P008",
    "amusement/amusement/Hard/P007",
    "carwelding/carwelding/Easy/P007",
    "endofworld/endofworld/Easy/P009",
    "gascola/gascola/Easy/P008",
    "gascola/gascola/Hard/P009",
    "hospital/hospital/Easy/P036",
    "hospital/hospital/Hard/P049",
    "japanesealley/japanesealley/Easy/P007",
    "japanesealley/japanesealley/Hard/P005",
    "neighborhood/neighborhood/Easy/P021",
    "neighborhood/neighborhood/Hard/P017",
    "ocean/ocean/Easy/P013",
    "ocean/ocean/Hard/P009",
    "office2/office2/Easy/P011",
    "office2/office2/Hard/P010",
    "office/office/Hard/P007",
    "oldtown/oldtown/Easy/P007",
    "oldtown/oldtown/Hard/P008",
    "seasidetown/seasidetown/Easy/P009",
    "seasonsforest/seasonsforest/Easy/P011",
    "seasonsforest/seasonsforest/Hard/P006",
    "seasonsforest_winter/seasonsforest_winter/Easy/P009",
    "seasonsforest_winter/seasonsforest_winter/Hard/P018",
    "soulcity/soulcity/Easy/P012",
    "soulcity/soulcity/Hard/P009",
    "westerndesert/westerndesert/Easy/P013",
    "westerndesert/westerndesert/Hard/P007",
]


class TartanAir(RGBDDataset):

    # scale depths to balance rot & trans
    DEPTH_SCALE = 5.0

    def __init__(self, mode='training', **kwargs):
        self.mode = mode
        self.n_frames = 2
        super(TartanAir, self).__init__(name='TartanAir', **kwargs)

    @staticmethod
    def is_test_scene(scene):
        # print(scene, any(x in scene for x in test_split))
        return any(x in scene for x in test_split)

    def _build_dataset(self):
        from tqdm import tqdm
        print("Building TartanAir dataset")

        scene_info = {}
        scenes = glob.glob(osp.join(self.root, '*/*/*/*'))
        for scene in tqdm(sorted(scenes)):
            images = sorted(glob.glob(osp.join(scene, 'image_left/*.png')))
            depths = sorted(glob.glob(osp.join(scene, 'depth_left/*.npy')))

            if len(images) != len(depths):
                continue

            poses = np.loadtxt(osp.join(scene, 'pose_left.txt'), delimiter=' ')
            poses = poses[:, [1, 2, 0, 4, 5, 3, 6]]
            poses[:,:3] /= TartanAir.DEPTH_SCALE
            intrinsics = [TartanAir.calib_read()] * len(images)

            # graph of co-visible frames based on flow
            graph = self.build_frame_graph(poses, depths, intrinsics)

            scene = '/'.join(scene.split('/'))
            scene_info[scene] = {'images': images, 'depths': depths,
                'poses': poses, 'intrinsics': intrinsics, 'graph': graph}

        return scene_info

    @staticmethod
    def calib_read():
        return np.array([320.0, 320.0, 320.0, 240.0])

    @staticmethod
    def image_read(image_file):
        return cv2.imread(image_file)

    @staticmethod
    def depth_read(depth_file):
        depth = np.load(depth_file) / TartanAir.DEPTH_SCALE
        # use np.isnan / np.isinf: `depth == np.nan` is always False, so the
        # original comparison never replaced invalid depth values
        depth[np.isnan(depth)] = 1.0
        depth[np.isinf(depth)] = 1.0
        return depth
mini_dpvo/data_readers/tartan_test.txt
ADDED
@@ -0,0 +1,32 @@
abandonedfactory/abandonedfactory/Easy/P011
abandonedfactory/abandonedfactory/Hard/P011
abandonedfactory_night/abandonedfactory_night/Easy/P013
abandonedfactory_night/abandonedfactory_night/Hard/P014
amusement/amusement/Easy/P008
amusement/amusement/Hard/P007
carwelding/carwelding/Easy/P007
endofworld/endofworld/Easy/P009
gascola/gascola/Easy/P008
gascola/gascola/Hard/P009
hospital/hospital/Easy/P036
hospital/hospital/Hard/P049
japanesealley/japanesealley/Easy/P007
japanesealley/japanesealley/Hard/P005
neighborhood/neighborhood/Easy/P021
neighborhood/neighborhood/Hard/P017
ocean/ocean/Easy/P013
ocean/ocean/Hard/P009
office2/office2/Easy/P011
office2/office2/Hard/P010
office/office/Hard/P007
oldtown/oldtown/Easy/P007
oldtown/oldtown/Hard/P008
seasidetown/seasidetown/Easy/P009
seasonsforest/seasonsforest/Easy/P011
seasonsforest/seasonsforest/Hard/P006
seasonsforest_winter/seasonsforest_winter/Easy/P009
seasonsforest_winter/seasonsforest_winter/Hard/P018
soulcity/soulcity/Easy/P012
soulcity/soulcity/Hard/P009
westerndesert/westerndesert/Easy/P013
westerndesert/westerndesert/Hard/P007
mini_dpvo/dpvo.py
ADDED
@@ -0,0 +1,410 @@
import torch
import numpy as np
import torch.nn.functional as F

from . import fastba
from . import altcorr
from . import lietorch
from .lietorch import SE3

from .net import VONet

from .utils import Timer, flatmeshgrid
from . import projective_ops as pops

autocast = torch.cuda.amp.autocast
Id = SE3.Identity(1, device="cuda")


class DPVO:
    def __init__(self, cfg, network, ht=480, wd=640):
        self.cfg = cfg
        self.load_weights(network)
        self.is_initialized = False
        self.enable_timing = False

        self.n = 0  # number of frames
        self.m = 0  # number of patches
        self.M = self.cfg.PATCHES_PER_FRAME
        self.N = self.cfg.BUFFER_SIZE

        self.ht = ht  # image height
        self.wd = wd  # image width

        DIM = self.DIM
        RES = self.RES

        ### state attributes ###
        self.tlist = []
        self.counter = 0

        # dummy image for visualization
        self.image_ = torch.zeros(self.ht, self.wd, 3, dtype=torch.uint8, device="cpu")

        self.tstamps_ = torch.zeros(self.N, dtype=torch.float64, device="cuda")
        self.poses_ = torch.zeros(self.N, 7, dtype=torch.float32, device="cuda")
        self.patches_ = torch.zeros(
            self.N, self.M, 3, self.P, self.P, dtype=torch.float, device="cuda"
        )
        self.intrinsics_ = torch.zeros(self.N, 4, dtype=torch.float, device="cuda")

        self.points_ = torch.zeros(self.N * self.M, 3, dtype=torch.float, device="cuda")
        self.colors_ = torch.zeros(self.N, self.M, 3, dtype=torch.uint8, device="cuda")

        self.index_ = torch.zeros(self.N, self.M, dtype=torch.long, device="cuda")
        self.index_map_ = torch.zeros(self.N, dtype=torch.long, device="cuda")

        ### network attributes ###
        self.mem = 32

        if self.cfg.MIXED_PRECISION:
            self.kwargs = kwargs = {"device": "cuda", "dtype": torch.half}
        else:
            self.kwargs = kwargs = {"device": "cuda", "dtype": torch.float}

        self.imap_ = torch.zeros(self.mem, self.M, DIM, **kwargs)
        self.gmap_ = torch.zeros(self.mem, self.M, 128, self.P, self.P, **kwargs)

        ht = ht // RES
        wd = wd // RES

        self.fmap1_ = torch.zeros(1, self.mem, 128, ht // 1, wd // 1, **kwargs)
        self.fmap2_ = torch.zeros(1, self.mem, 128, ht // 4, wd // 4, **kwargs)

        # feature pyramid
        self.pyramid = (self.fmap1_, self.fmap2_)

        self.net = torch.zeros(1, 0, DIM, **kwargs)
        self.ii = torch.as_tensor([], dtype=torch.long, device="cuda")
        self.jj = torch.as_tensor([], dtype=torch.long, device="cuda")
        self.kk = torch.as_tensor([], dtype=torch.long, device="cuda")

        # initialize poses to identity matrix
        self.poses_[:, 6] = 1.0

        # store relative poses for removed frames
        self.delta = {}

    def load_weights(self, network):
        # load network from checkpoint file
        if isinstance(network, str):
            from collections import OrderedDict

            state_dict = torch.load(network)
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                if "update.lmbda" not in k:
                    new_state_dict[k.replace("module.", "")] = v

            self.network = VONet()
            self.network.load_state_dict(new_state_dict)

        else:
            self.network = network

        # steal network attributes
        self.DIM = self.network.DIM
        self.RES = self.network.RES
        self.P = self.network.P

        self.network.cuda()
        self.network.eval()

        # if self.cfg.MIXED_PRECISION:
        #     self.network.half()

    @property
    def poses(self):
        return self.poses_.view(1, self.N, 7)

    @property
    def patches(self):
        return self.patches_.view(1, self.N * self.M, 3, 3, 3)

    @property
    def intrinsics(self):
        return self.intrinsics_.view(1, self.N, 4)

    @property
    def ix(self):
        return self.index_.view(-1)

    @property
    def imap(self):
        return self.imap_.view(1, self.mem * self.M, self.DIM)

    @property
    def gmap(self):
        return self.gmap_.view(1, self.mem * self.M, 128, 3, 3)

    def get_pose(self, t):
        if t in self.traj:
            return SE3(self.traj[t])

        t0, dP = self.delta[t]
        return dP * self.get_pose(t0)

    def terminate(self):
        """interpolate missing poses"""
        print("Terminating...")
        self.traj = {}
        for i in range(self.n):
            current_t: int = self.tstamps_[i].item()
            self.traj[current_t] = self.poses_[i]

        poses = [self.get_pose(t) for t in range(self.counter)]
        poses = lietorch.stack(poses, dim=0)
        poses = poses.inv().data.cpu().numpy()
        tstamps = np.array(self.tlist, dtype=np.float64)

        return poses, tstamps

    def corr(self, coords, indicies=None):
        """local correlation volume"""
        ii, jj = indicies if indicies is not None else (self.kk, self.jj)
        ii1 = ii % (self.M * self.mem)
        jj1 = jj % (self.mem)
        corr1 = altcorr.corr(self.gmap, self.pyramid[0], coords / 1, ii1, jj1, 3)
        corr2 = altcorr.corr(self.gmap, self.pyramid[1], coords / 4, ii1, jj1, 3)
        return torch.stack([corr1, corr2], -1).view(1, len(ii), -1)

    def reproject(self, indicies=None):
        """reproject patch k from i -> j"""
        (ii, jj, kk) = indicies if indicies is not None else (self.ii, self.jj, self.kk)
        coords = pops.transform(
            SE3(self.poses), self.patches, self.intrinsics, ii, jj, kk
        )
        return coords.permute(0, 1, 4, 2, 3).contiguous()

    def append_factors(self, ii, jj):
        self.jj = torch.cat([self.jj, jj])
        self.kk = torch.cat([self.kk, ii])
        self.ii = torch.cat([self.ii, self.ix[ii]])

        net = torch.zeros(1, len(ii), self.DIM, **self.kwargs)
        self.net = torch.cat([self.net, net], dim=1)

    def remove_factors(self, m):
        self.ii = self.ii[~m]
        self.jj = self.jj[~m]
        self.kk = self.kk[~m]
        self.net = self.net[:, ~m]

    def motion_probe(self):
        """kinda hacky way to ensure enough motion for initialization"""
        kk = torch.arange(self.m - self.M, self.m, device="cuda")
        jj = self.n * torch.ones_like(kk)
        ii = self.ix[kk]

        net = torch.zeros(1, len(ii), self.DIM, **self.kwargs)
        coords = self.reproject(indicies=(ii, jj, kk))

        with autocast(enabled=self.cfg.MIXED_PRECISION):
            corr = self.corr(coords, indicies=(kk, jj))
            ctx = self.imap[:, kk % (self.M * self.mem)]
            net, (delta, weight, _) = self.network.update(
                net, ctx, corr, None, ii, jj, kk
            )

        return torch.quantile(delta.norm(dim=-1).float(), 0.5)

    def motionmag(self, i, j):
        k = (self.ii == i) & (self.jj == j)
        ii = self.ii[k]
        jj = self.jj[k]
        kk = self.kk[k]

        flow = pops.flow_mag(
            SE3(self.poses), self.patches, self.intrinsics, ii, jj, kk, beta=0.5
        )
        return flow.mean().item()

    def keyframe(self):
        i = self.n - self.cfg.KEYFRAME_INDEX - 1
        j = self.n - self.cfg.KEYFRAME_INDEX + 1
        m = self.motionmag(i, j) + self.motionmag(j, i)

        if m / 2 < self.cfg.KEYFRAME_THRESH:
            k = self.n - self.cfg.KEYFRAME_INDEX
            t0 = self.tstamps_[k - 1].item()
            t1 = self.tstamps_[k].item()

            dP = SE3(self.poses_[k]) * SE3(self.poses_[k - 1]).inv()
            self.delta[t1] = (t0, dP)

            to_remove = (self.ii == k) | (self.jj == k)
            self.remove_factors(to_remove)

            self.kk[self.ii > k] -= self.M
            self.ii[self.ii > k] -= 1
            self.jj[self.jj > k] -= 1

            for i in range(k, self.n - 1):
                self.tstamps_[i] = self.tstamps_[i + 1]
                self.colors_[i] = self.colors_[i + 1]
                self.poses_[i] = self.poses_[i + 1]
                self.patches_[i] = self.patches_[i + 1]
                self.intrinsics_[i] = self.intrinsics_[i + 1]

                self.imap_[i % self.mem] = self.imap_[(i + 1) % self.mem]
                self.gmap_[i % self.mem] = self.gmap_[(i + 1) % self.mem]
                self.fmap1_[0, i % self.mem] = self.fmap1_[0, (i + 1) % self.mem]
                self.fmap2_[0, i % self.mem] = self.fmap2_[0, (i + 1) % self.mem]

            self.n -= 1
            self.m -= self.M

        to_remove = self.ix[self.kk] < self.n - self.cfg.REMOVAL_WINDOW
        self.remove_factors(to_remove)

    def update(self):
        with Timer("other", enabled=self.enable_timing):
            coords = self.reproject()

            with autocast(enabled=True):
                corr = self.corr(coords)
                ctx = self.imap[:, self.kk % (self.M * self.mem)]
                self.net, (delta, weight, _) = self.network.update(
                    self.net, ctx, corr, None, self.ii, self.jj, self.kk
                )

            lmbda = torch.as_tensor([1e-4], device="cuda")
            weight = weight.float()
            target = coords[..., self.P // 2, self.P // 2] + delta.float()

        with Timer("BA", enabled=self.enable_timing):
            t0 = self.n - self.cfg.OPTIMIZATION_WINDOW if self.is_initialized else 1
            t0 = max(t0, 1)

            try:
                fastba.BA(
                    self.poses,
                    self.patches,
                    self.intrinsics,
                    target,
                    weight,
                    lmbda,
                    self.ii,
                    self.jj,
                    self.kk,
                    t0,
                    self.n,
                    2,
                )
            except:
                print("Warning BA failed...")

            points = pops.point_cloud(
                SE3(self.poses),
                self.patches[:, : self.m],
                self.intrinsics,
                self.ix[: self.m],
            )
            points = (points[..., 1, 1, :3] / points[..., 1, 1, 3:]).reshape(-1, 3)
            self.points_[: len(points)] = points[:]

    def __edges_all(self):
        return flatmeshgrid(
            torch.arange(0, self.m, device="cuda"),
            torch.arange(0, self.n, device="cuda"),
            indexing="ij",
        )

    def __edges_forw(self):
        r = self.cfg.PATCH_LIFETIME
        t0 = self.M * max((self.n - r), 0)
        t1 = self.M * max((self.n - 1), 0)
        return flatmeshgrid(
            torch.arange(t0, t1, device="cuda"),
            torch.arange(self.n - 1, self.n, device="cuda"),
            indexing="ij",
        )

    def __edges_back(self):
        r = self.cfg.PATCH_LIFETIME
        t0 = self.M * max((self.n - 1), 0)
        t1 = self.M * max((self.n - 0), 0)
        return flatmeshgrid(
            torch.arange(t0, t1, device="cuda"),
            torch.arange(max(self.n - r, 0), self.n, device="cuda"),
            indexing="ij",
        )

    def __call__(self, tstamp: int, image, intrinsics) -> None:
        """track new frame"""

        if (self.n + 1) >= self.N:
            raise Exception(
                f'The buffer size is too small. You can increase it using "--buffer {self.N*2}"'
            )

        image = 2 * (image[None, None] / 255.0) - 0.5

        with autocast(enabled=self.cfg.MIXED_PRECISION):
            fmap, gmap, imap, patches, _, clr = self.network.patchify(
                image,
                patches_per_image=self.cfg.PATCHES_PER_FRAME,
                gradient_bias=self.cfg.GRADIENT_BIAS,
                return_color=True,
            )

        ### update state attributes ###
        self.tlist.append(tstamp)
        self.tstamps_[self.n] = self.counter
        self.intrinsics_[self.n] = intrinsics / self.RES

        # color info for visualization
        clr = (clr[0, :, [2, 1, 0]] + 0.5) * (255.0 / 2)
        self.colors_[self.n] = clr.to(torch.uint8)

        self.index_[self.n + 1] = self.n + 1
        self.index_map_[self.n + 1] = self.m + self.M

        if self.n > 1:
            if self.cfg.MOTION_MODEL == "DAMPED_LINEAR":
                P1 = SE3(self.poses_[self.n - 1])
                P2 = SE3(self.poses_[self.n - 2])

                xi = self.cfg.MOTION_DAMPING * (P1 * P2.inv()).log()
                tvec_qvec = (SE3.exp(xi) * P1).data
                self.poses_[self.n] = tvec_qvec
            else:
                tvec_qvec = self.poses[self.n - 1]
                self.poses_[self.n] = tvec_qvec

        # TODO better depth initialization
        patches[:, :, 2] = torch.rand_like(patches[:, :, 2, 0, 0, None, None])
        if self.is_initialized:
            s = torch.median(self.patches_[self.n - 3 : self.n, :, 2])
            patches[:, :, 2] = s

        self.patches_[self.n] = patches

        ### update network attributes ###
        self.imap_[self.n % self.mem] = imap.squeeze()
        self.gmap_[self.n % self.mem] = gmap.squeeze()
        self.fmap1_[:, self.n % self.mem] = F.avg_pool2d(fmap[0], 1, 1)
        self.fmap2_[:, self.n % self.mem] = F.avg_pool2d(fmap[0], 4, 4)

        self.counter += 1
        if self.n > 0 and not self.is_initialized:
            if self.motion_probe() < 2.0:
                self.delta[self.counter - 1] = (self.counter - 2, Id[0])
                return

        self.n += 1
        self.m += self.M

        # relative pose
        self.append_factors(*self.__edges_forw())
        self.append_factors(*self.__edges_back())

        if self.n == 8 and not self.is_initialized:
            self.is_initialized = True

            for itr in range(12):
                self.update()

        elif self.is_initialized:
            self.update()
            self.keyframe()
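
The class above is driven one frame at a time through `__call__` and finished with `terminate()`. A minimal driver sketch follows, assuming the caller supplies a config object exposing the fields referenced above (PATCHES_PER_FRAME, BUFFER_SIZE, ...), a checkpoint path, and an iterator of `(tstamp, image, intrinsics)` tuples; none of these helpers are defined by this commit.

# Hypothetical driver loop for DPVO (not part of this commit).
import torch

def run_visual_odometry(cfg, checkpoint_path, frames, ht=480, wd=640):
    """frames yields (tstamp, rgb uint8 HxWx3 numpy array, [fx, fy, cx, cy])."""
    slam = DPVO(cfg, checkpoint_path, ht=ht, wd=wd)
    for tstamp, image, intrinsics in frames:
        image = torch.from_numpy(image).permute(2, 0, 1).cuda()              # (3, H, W)
        intrinsics = torch.as_tensor(intrinsics, dtype=torch.float, device="cuda")
        with torch.no_grad():
            slam(tstamp, image, intrinsics)                                  # track one frame
    poses, tstamps = slam.terminate()                                        # interpolates skipped frames
    return poses, tstamps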
mini_dpvo/extractor.py
ADDED
@@ -0,0 +1,264 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ResidualBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes)
            self.norm2 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes)
            self.norm2 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)


class BottleneckBlock(nn.Module):
    def __init__(self, in_planes, planes, norm_fn='group', stride=1):
        super(BottleneckBlock, self).__init__()

        self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
        self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
        self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
        self.relu = nn.ReLU(inplace=True)

        num_groups = planes // 8

        if norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
            self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
            if not stride == 1:
                self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)

        elif norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(planes//4)
            self.norm2 = nn.BatchNorm2d(planes//4)
            self.norm3 = nn.BatchNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.BatchNorm2d(planes)

        elif norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(planes//4)
            self.norm2 = nn.InstanceNorm2d(planes//4)
            self.norm3 = nn.InstanceNorm2d(planes)
            if not stride == 1:
                self.norm4 = nn.InstanceNorm2d(planes)

        elif norm_fn == 'none':
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            self.norm3 = nn.Sequential()
            if not stride == 1:
                self.norm4 = nn.Sequential()

        if stride == 1:
            self.downsample = None

        else:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        y = self.relu(self.norm3(self.conv3(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x+y)

DIM=32

class BasicEncoder(nn.Module):
    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, multidim=False):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn
        self.multidim = multidim

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=DIM)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(DIM)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(DIM)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, DIM, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = DIM
        self.layer1 = self._make_layer(DIM, stride=1)
        self.layer2 = self._make_layer(2*DIM, stride=2)
        self.layer3 = self._make_layer(4*DIM, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(4*DIM, output_dim, kernel_size=1)

        if self.multidim:
            self.layer4 = self._make_layer(256, stride=2)
            self.layer5 = self._make_layer(512, stride=2)

            self.in_planes = 256
            self.layer6 = self._make_layer(256, stride=1)

            self.in_planes = 128
            self.layer7 = self._make_layer(128, stride=1)

            self.up1 = nn.Conv2d(512, 256, 1)
            self.up2 = nn.Conv2d(256, 128, 1)
            self.conv3 = nn.Conv2d(128, output_dim, kernel_size=1)

        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)
        else:
            self.dropout = None

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        b, n, c1, h1, w1 = x.shape
        x = x.view(b*n, c1, h1, w1)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        _, c2, h2, w2 = x.shape
        return x.view(b, n, c2, h2, w2)


class BasicEncoder4(nn.Module):
    def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, multidim=False):
        super(BasicEncoder4, self).__init__()
        self.norm_fn = norm_fn
        self.multidim = multidim

        if self.norm_fn == 'group':
            self.norm1 = nn.GroupNorm(num_groups=8, num_channels=DIM)

        elif self.norm_fn == 'batch':
            self.norm1 = nn.BatchNorm2d(DIM)

        elif self.norm_fn == 'instance':
            self.norm1 = nn.InstanceNorm2d(DIM)

        elif self.norm_fn == 'none':
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2d(3, DIM, kernel_size=7, stride=2, padding=3)
        self.relu1 = nn.ReLU(inplace=True)

        self.in_planes = DIM
        self.layer1 = self._make_layer(DIM, stride=1)
        self.layer2 = self._make_layer(2*DIM, stride=2)

        # output convolution
        self.conv2 = nn.Conv2d(2*DIM, output_dim, kernel_size=1)

        if dropout > 0:
            self.dropout = nn.Dropout2d(p=dropout)
        else:
            self.dropout = None

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        b, n, c1, h1, w1 = x.shape
        x = x.view(b*n, c1, h1, w1)

        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)

        x = self.conv2(x)

        _, c2, h2, w2 = x.shape
        return x.view(b, n, c2, h2, w2)
mini_dpvo/fastba/__init__.py
ADDED
@@ -0,0 +1 @@
from .ba import BA, neighbors, reproject
mini_dpvo/fastba/ba.cpp
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#include <torch/extension.h>
|
2 |
+
#include <vector>
|
3 |
+
#include <unordered_map>
|
4 |
+
#include <algorithm>
|
5 |
+
#include <iostream>
|
6 |
+
|
7 |
+
|
8 |
+
std::vector<torch::Tensor> cuda_ba(
|
9 |
+
torch::Tensor poses,
|
10 |
+
torch::Tensor patches,
|
11 |
+
torch::Tensor intrinsics,
|
12 |
+
torch::Tensor target,
|
13 |
+
torch::Tensor weight,
|
14 |
+
torch::Tensor lmbda,
|
15 |
+
torch::Tensor ii,
|
16 |
+
torch::Tensor jj,
|
17 |
+
torch::Tensor kk,
|
18 |
+
int t0, int t1, int iterations);
|
19 |
+
|
20 |
+
|
21 |
+
torch::Tensor cuda_reproject(
|
22 |
+
torch::Tensor poses,
|
23 |
+
torch::Tensor patches,
|
24 |
+
torch::Tensor intrinsics,
|
25 |
+
torch::Tensor ii,
|
26 |
+
torch::Tensor jj,
|
27 |
+
torch::Tensor kk);
|
28 |
+
|
29 |
+
std::vector<torch::Tensor> ba(
|
30 |
+
torch::Tensor poses,
|
31 |
+
torch::Tensor patches,
|
32 |
+
torch::Tensor intrinsics,
|
33 |
+
torch::Tensor target,
|
34 |
+
torch::Tensor weight,
|
35 |
+
torch::Tensor lmbda,
|
36 |
+
torch::Tensor ii,
|
37 |
+
torch::Tensor jj,
|
    torch::Tensor kk,
    int t0, int t1, int iterations) {
  return cuda_ba(poses, patches, intrinsics, target, weight, lmbda, ii, jj, kk, t0, t1, iterations);
}


torch::Tensor reproject(
    torch::Tensor poses,
    torch::Tensor patches,
    torch::Tensor intrinsics,
    torch::Tensor ii,
    torch::Tensor jj,
    torch::Tensor kk) {
  return cuda_reproject(poses, patches, intrinsics, ii, jj, kk);
}

// Earlier CPU implementation of neighbors(), kept for reference:
// std::vector<torch::Tensor> neighbors(torch::Tensor ii, torch::Tensor jj)
// {
//   ii = ii.to(torch::kCPU);
//   jj = jj.to(torch::kCPU);
//   auto ii_data = ii.accessor<long,1>();
//   auto jj_data = jj.accessor<long,1>();
//
//   std::unordered_map<long, std::vector<long>> graph;
//   std::unordered_map<long, std::vector<long>> index;
//   for (int i=0; i < ii.size(0); i++) {
//     const long ix = ii_data[i];
//     const long jx = jj_data[i];
//     if (graph.find(ix) == graph.end()) {
//       graph[ix] = std::vector<long>();
//       index[ix] = std::vector<long>();
//     }
//     graph[ix].push_back(jx);
//     index[ix].push_back( i);
//   }
//
//   auto opts = torch::TensorOptions().dtype(torch::kInt64);
//   torch::Tensor ix = torch::empty({ii.size(0)}, opts);
//   torch::Tensor jx = torch::empty({jj.size(0)}, opts);
//
//   auto ix_data = ix.accessor<long,1>();
//   auto jx_data = jx.accessor<long,1>();
//
//   for (std::pair<long, std::vector<long>> element : graph) {
//     std::vector<long>& v = graph[element.first];
//     std::vector<long>& idx = index[element.first];
//
//     std::stable_sort(idx.begin(), idx.end(),
//       [&v](size_t i, size_t j) {return v[i] < v[j];});
//
//     ix_data[idx.front()] = -1;
//     jx_data[idx.back()] = -1;
//
//     for (int i=0; i < idx.size(); i++) {
//       ix_data[idx[i]] = (i > 0) ? idx[i-1] : -1;
//       jx_data[idx[i]] = (i < idx.size() - 1) ? idx[i+1] : -1;
//     }
//   }
//
//   ix = ix.to(torch::kCUDA);
//   jx = jx.to(torch::kCUDA);
//
//   return {ix, jx};
// }


// For every edge (ii[n], jj[n]), find the previous and next edge that shares the same
// source frame ii[n], ordered by target frame jj. Entries without a neighbor are -1.
std::vector<torch::Tensor> neighbors(torch::Tensor ii, torch::Tensor jj)
{
  auto tup = torch::_unique(ii, true, true);
  torch::Tensor uniq = std::get<0>(tup).to(torch::kCPU);
  torch::Tensor perm = std::get<1>(tup).to(torch::kCPU);

  jj = jj.to(torch::kCPU);
  auto jj_accessor = jj.accessor<long,1>();

  auto perm_accessor = perm.accessor<long,1>();
  std::vector<std::vector<long>> index(uniq.size(0));
  for (int i=0; i < ii.size(0); i++) {
    index[perm_accessor[i]].push_back(i);
  }

  auto opts = torch::TensorOptions().dtype(torch::kInt64);
  torch::Tensor ix = torch::empty({ii.size(0)}, opts);
  torch::Tensor jx = torch::empty({ii.size(0)}, opts);

  auto ix_accessor = ix.accessor<long,1>();
  auto jx_accessor = jx.accessor<long,1>();

  for (int i=0; i<uniq.size(0); i++) {
    std::vector<long>& idx = index[i];
    std::stable_sort(idx.begin(), idx.end(),
      [&jj_accessor](size_t i, size_t j) {return jj_accessor[i] < jj_accessor[j];});

    for (int i=0; i < idx.size(); i++) {
      ix_accessor[idx[i]] = (i > 0) ? idx[i-1] : -1;
      jx_accessor[idx[i]] = (i < idx.size() - 1) ? idx[i+1] : -1;
    }
  }

  // for (int i=0; i<ii.size(0); i++) {
  //   std::cout << jj_accessor[i] << " ";
  //   if (ix_accessor[i] >= 0) std::cout << jj_accessor[ix_accessor[i]] << " ";
  //   if (jx_accessor[i] >= 0) std::cout << jj_accessor[jx_accessor[i]] << " ";
  //   std::cout << std::endl;
  // }

  ix = ix.to(torch::kCUDA);
  jx = jx.to(torch::kCUDA);

  return {ix, jx};
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &ba, "BA forward operator");
  m.def("neighbors", &neighbors, "temporal neighbor indices");
  m.def("reproject", &reproject, "patch reprojection operator");

}
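As a quick illustration of the `neighbors` binding above, here is a minimal sketch, assuming the extension has been built under the module name `cuda_ba` (the name used by `ba.py` below) and a CUDA device is available; the edge lists are made-up values. For each edge (ii[n], jj[n]) the function returns the index of the previous and the next edge that share the same source frame ii[n], with -1 where no such neighbor exists.

import torch
import cuda_ba  # extension compiled from ba.cpp / ba_cuda.cu

# two source frames (0 and 1), each connected to several target frames
ii = torch.tensor([0, 0, 0, 1, 1], dtype=torch.long, device="cuda")
jj = torch.tensor([1, 2, 3, 2, 3], dtype=torch.long, device="cuda")

ix, jx = cuda_ba.neighbors(ii, jj)
# ix[n] / jx[n] hold the index of the previous / next edge with the same ii[n], or -1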
mini_dpvo/fastba/ba.py
ADDED
@@ -0,0 +1,8 @@
import torch
import cuda_ba

neighbors = cuda_ba.neighbors
reproject = cuda_ba.reproject

def BA(poses, patches, intrinsics, target, weight, lmbda, ii, jj, kk, t0, t1, iterations=2):
    return cuda_ba.forward(poses.data, patches, intrinsics, target, weight, lmbda, ii, jj, kk, t0, t1, iterations)
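The wrapper passes `poses.data`, which suggests callers hold the trajectory as a lietorch SE3 object whose backing tensor is (..., 7) = [tx, ty, tz, qx, qy, qz, qw]; a plain tensor also works since `Tensor.data` exists. The sketch below is only a hedged illustration of the tensor shapes, inferred from the `.view()` calls in ba_cuda.cu, with placeholder values rather than a meaningful reconstruction problem; it assumes the extension is built and a GPU is present.

import torch
from mini_dpvo.fastba.ba import BA

N, M, P = 4, 8, 3                                    # frames, patches, patch size
opts = dict(dtype=torch.float32, device="cuda")

poses = torch.zeros(1, N, 7, **opts); poses[..., 6] = 1.0            # identity quaternions
patches = torch.rand(1, M, 3, P, P, **opts)                          # x, y and inverse-depth channels (placeholders)
intrinsics = torch.tensor([[[320.0, 320.0, 320.0, 240.0]]], **opts)  # fx, fy, cx, cy

ii = torch.zeros(M, dtype=torch.long, device="cuda")                 # source frame of each edge
jj = torch.ones(M, dtype=torch.long, device="cuda")                  # target frame of each edge
kk = torch.arange(M, device="cuda")                                  # patch index of each edge
target = torch.rand(1, M, 2, **opts) * 100.0                         # desired reprojections (pixels)
weight = torch.ones(1, M, 2, **opts)                                 # per-edge confidence
lmbda = torch.full((1,), 1e-4, **opts)                               # LM damping

# poses and patches are updated in place by the CUDA kernels
BA(poses, patches, intrinsics, target, weight, lmbda, ii, jj, kk, t0=1, t1=N)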
mini_dpvo/fastba/ba_cuda.cu
ADDED
@@ -0,0 +1,575 @@
1 |
+
#include <torch/extension.h>
|
2 |
+
#include <vector>
|
3 |
+
#include <iostream>
|
4 |
+
|
5 |
+
#include <ATen/ATen.h>
|
6 |
+
#include <ATen/NativeFunctions.h>
|
7 |
+
#include <ATen/Parallel.h>
|
8 |
+
|
9 |
+
|
10 |
+
#define GPU_1D_KERNEL_LOOP(i, n) \
|
11 |
+
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i<n; i += blockDim.x * gridDim.x)
|
12 |
+
|
13 |
+
|
14 |
+
#define NUM_THREADS 256
|
15 |
+
#define NUM_BLOCKS(batch_size) ((batch_size + NUM_THREADS - 1) / NUM_THREADS)
|
16 |
+
|
17 |
+
|
18 |
+
__device__ void
|
19 |
+
actSO3(const float *q, const float *X, float *Y) {
|
20 |
+
float uv[3];
|
21 |
+
uv[0] = 2.0 * (q[1]*X[2] - q[2]*X[1]);
|
22 |
+
uv[1] = 2.0 * (q[2]*X[0] - q[0]*X[2]);
|
23 |
+
uv[2] = 2.0 * (q[0]*X[1] - q[1]*X[0]);
|
24 |
+
|
25 |
+
Y[0] = X[0] + q[3]*uv[0] + (q[1]*uv[2] - q[2]*uv[1]);
|
26 |
+
Y[1] = X[1] + q[3]*uv[1] + (q[2]*uv[0] - q[0]*uv[2]);
|
27 |
+
Y[2] = X[2] + q[3]*uv[2] + (q[0]*uv[1] - q[1]*uv[0]);
|
28 |
+
}
|
29 |
+
|
30 |
+
__device__ void
|
31 |
+
actSE3(const float *t, const float *q, const float *X, float *Y) {
|
32 |
+
actSO3(q, X, Y);
|
33 |
+
Y[3] = X[3];
|
34 |
+
Y[0] += X[3] * t[0];
|
35 |
+
Y[1] += X[3] * t[1];
|
36 |
+
Y[2] += X[3] * t[2];
|
37 |
+
}
|
38 |
+
|
39 |
+
__device__ void
|
40 |
+
adjSE3(const float *t, const float *q, const float *X, float *Y) {
|
41 |
+
float qinv[4] = {-q[0], -q[1], -q[2], q[3]};
|
42 |
+
actSO3(qinv, &X[0], &Y[0]);
|
43 |
+
actSO3(qinv, &X[3], &Y[3]);
|
44 |
+
|
45 |
+
float u[3], v[3];
|
46 |
+
u[0] = t[2]*X[1] - t[1]*X[2];
|
47 |
+
u[1] = t[0]*X[2] - t[2]*X[0];
|
48 |
+
u[2] = t[1]*X[0] - t[0]*X[1];
|
49 |
+
|
50 |
+
actSO3(qinv, u, v);
|
51 |
+
Y[3] += v[0];
|
52 |
+
Y[4] += v[1];
|
53 |
+
Y[5] += v[2];
|
54 |
+
}
|
55 |
+
|
56 |
+
__device__ void
|
57 |
+
relSE3(const float *ti, const float *qi, const float *tj, const float *qj, float *tij, float *qij) {
|
58 |
+
qij[0] = -qj[3] * qi[0] + qj[0] * qi[3] - qj[1] * qi[2] + qj[2] * qi[1],
|
59 |
+
qij[1] = -qj[3] * qi[1] + qj[1] * qi[3] - qj[2] * qi[0] + qj[0] * qi[2],
|
60 |
+
qij[2] = -qj[3] * qi[2] + qj[2] * qi[3] - qj[0] * qi[1] + qj[1] * qi[0],
|
61 |
+
qij[3] = qj[3] * qi[3] + qj[0] * qi[0] + qj[1] * qi[1] + qj[2] * qi[2],
|
62 |
+
|
63 |
+
actSO3(qij, ti, tij);
|
64 |
+
tij[0] = tj[0] - tij[0];
|
65 |
+
tij[1] = tj[1] - tij[1];
|
66 |
+
tij[2] = tj[2] - tij[2];
|
67 |
+
}
|
68 |
+
|
69 |
+
|
70 |
+
__device__ void
|
71 |
+
expSO3(const float *phi, float* q) {
|
72 |
+
// SO3 exponential map
|
73 |
+
float theta_sq = phi[0]*phi[0] + phi[1]*phi[1] + phi[2]*phi[2];
|
74 |
+
float theta_p4 = theta_sq * theta_sq;
|
75 |
+
|
76 |
+
float theta = sqrtf(theta_sq);
|
77 |
+
float imag, real;
|
78 |
+
|
79 |
+
if (theta_sq < 1e-8) {
|
80 |
+
imag = 0.5 - (1.0/48.0)*theta_sq + (1.0/3840.0)*theta_p4;
|
81 |
+
real = 1.0 - (1.0/ 8.0)*theta_sq + (1.0/ 384.0)*theta_p4;
|
82 |
+
} else {
|
83 |
+
imag = sinf(0.5 * theta) / theta;
|
84 |
+
real = cosf(0.5 * theta);
|
85 |
+
}
|
86 |
+
|
87 |
+
q[0] = imag * phi[0];
|
88 |
+
q[1] = imag * phi[1];
|
89 |
+
q[2] = imag * phi[2];
|
90 |
+
q[3] = real;
|
91 |
+
|
92 |
+
}
|
93 |
+
|
94 |
+
__device__ void
|
95 |
+
crossInplace(const float* a, float *b) {
|
96 |
+
float x[3] = {
|
97 |
+
a[1]*b[2] - a[2]*b[1],
|
98 |
+
a[2]*b[0] - a[0]*b[2],
|
99 |
+
a[0]*b[1] - a[1]*b[0],
|
100 |
+
};
|
101 |
+
|
102 |
+
b[0] = x[0];
|
103 |
+
b[1] = x[1];
|
104 |
+
b[2] = x[2];
|
105 |
+
}
|
106 |
+
|
107 |
+
__device__ void
|
108 |
+
expSE3(const float *xi, float* t, float* q) {
|
109 |
+
// SE3 exponential map
|
110 |
+
|
111 |
+
expSO3(xi + 3, q);
|
112 |
+
float tau[3] = {xi[0], xi[1], xi[2]};
|
113 |
+
float phi[3] = {xi[3], xi[4], xi[5]};
|
114 |
+
|
115 |
+
float theta_sq = phi[0]*phi[0] + phi[1]*phi[1] + phi[2]*phi[2];
|
116 |
+
float theta = sqrtf(theta_sq);
|
117 |
+
|
118 |
+
t[0] = tau[0];
|
119 |
+
t[1] = tau[1];
|
120 |
+
t[2] = tau[2];
|
121 |
+
|
122 |
+
if (theta > 1e-4) {
|
123 |
+
float a = (1 - cosf(theta)) / theta_sq;
|
124 |
+
crossInplace(phi, tau);
|
125 |
+
t[0] += a * tau[0];
|
126 |
+
t[1] += a * tau[1];
|
127 |
+
t[2] += a * tau[2];
|
128 |
+
|
129 |
+
float b = (theta - sinf(theta)) / (theta * theta_sq);
|
130 |
+
crossInplace(phi, tau);
|
131 |
+
t[0] += b * tau[0];
|
132 |
+
t[1] += b * tau[1];
|
133 |
+
t[2] += b * tau[2];
|
134 |
+
}
|
135 |
+
}
|
136 |
+
|
137 |
+
|
138 |
+
__device__ void
|
139 |
+
retrSE3(const float *xi, const float* t, const float* q, float* t1, float* q1) {
|
140 |
+
// retraction on SE3 manifold
|
141 |
+
|
142 |
+
float dt[3] = {0, 0, 0};
|
143 |
+
float dq[4] = {0, 0, 0, 1};
|
144 |
+
|
145 |
+
expSE3(xi, dt, dq);
|
146 |
+
|
147 |
+
q1[0] = dq[3] * q[0] + dq[0] * q[3] + dq[1] * q[2] - dq[2] * q[1];
|
148 |
+
q1[1] = dq[3] * q[1] + dq[1] * q[3] + dq[2] * q[0] - dq[0] * q[2];
|
149 |
+
q1[2] = dq[3] * q[2] + dq[2] * q[3] + dq[0] * q[1] - dq[1] * q[0];
|
150 |
+
q1[3] = dq[3] * q[3] - dq[0] * q[0] - dq[1] * q[1] - dq[2] * q[2];
|
151 |
+
|
152 |
+
actSO3(dq, t, t1);
|
153 |
+
t1[0] += dt[0];
|
154 |
+
t1[1] += dt[1];
|
155 |
+
t1[2] += dt[2];
|
156 |
+
}
|
157 |
+
|
158 |
+
|
159 |
+
|
160 |
+
__global__ void pose_retr_kernel(const int t0, const int t1,
|
161 |
+
torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> poses,
|
162 |
+
torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> update)
|
163 |
+
{
|
164 |
+
GPU_1D_KERNEL_LOOP(i, t1 - t0) {
|
165 |
+
const float t = t0 + i;
|
166 |
+
float t1[3], t0[3] = { poses[t][0], poses[t][1], poses[t][2] };
|
167 |
+
float q1[4], q0[4] = { poses[t][3], poses[t][4], poses[t][5], poses[t][6] };
|
168 |
+
|
169 |
+
float xi[6] = {
|
170 |
+
update[i][0],
|
171 |
+
update[i][1],
|
172 |
+
update[i][2],
|
173 |
+
update[i][3],
|
174 |
+
update[i][4],
|
175 |
+
update[i][5],
|
176 |
+
};
|
177 |
+
|
178 |
+
retrSE3(xi, t0, q0, t1, q1);
|
179 |
+
|
180 |
+
poses[t][0] = t1[0];
|
181 |
+
poses[t][1] = t1[1];
|
182 |
+
poses[t][2] = t1[2];
|
183 |
+
poses[t][3] = q1[0];
|
184 |
+
poses[t][4] = q1[1];
|
185 |
+
poses[t][5] = q1[2];
|
186 |
+
poses[t][6] = q1[3];
|
187 |
+
}
|
188 |
+
}
|
189 |
+
|
190 |
+
|
191 |
+
__global__ void patch_retr_kernel(
|
192 |
+
torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> index,
|
193 |
+
torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> patches,
|
194 |
+
torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> update)
|
195 |
+
{
|
196 |
+
GPU_1D_KERNEL_LOOP(n, index.size(0)) {
|
197 |
+
const int p = patches.size(2);
|
198 |
+
const int ix = index[n];
|
199 |
+
|
200 |
+
float d = patches[ix][2][0][0];
|
201 |
+
d = d + update[n];
|
202 |
+
d = (d > 20) ? 1.0 : d;
|
203 |
+
d = max(d, 1e-4);
|
204 |
+
|
205 |
+
for (int i=0; i<p; i++) {
|
206 |
+
for (int j=0; j<p; j++) {
|
207 |
+
patches[ix][2][i][j] = d;
|
208 |
+
}
|
209 |
+
}
|
210 |
+
}
|
211 |
+
}
|
212 |
+
|
213 |
+
|
214 |
+
__global__ void reprojection_residuals_and_hessian(
|
215 |
+
const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> poses,
|
216 |
+
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> patches,
|
217 |
+
const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> intrinsics,
|
218 |
+
const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> target,
|
219 |
+
const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> weight,
|
220 |
+
const torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> lmbda,
|
221 |
+
const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> ii,
|
222 |
+
const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> jj,
|
223 |
+
const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> kk,
|
224 |
+
const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> ku,
|
225 |
+
torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> B,
|
226 |
+
torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> E,
|
227 |
+
torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> C,
|
228 |
+
torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> v,
|
229 |
+
torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> u, const int t0)
|
230 |
+
{
|
231 |
+
|
232 |
+
__shared__ float fx, fy, cx, cy;
|
233 |
+
if (threadIdx.x == 0) {
|
234 |
+
fx = intrinsics[0][0];
|
235 |
+
fy = intrinsics[0][1];
|
236 |
+
cx = intrinsics[0][2];
|
237 |
+
cy = intrinsics[0][3];
|
238 |
+
}
|
239 |
+
|
240 |
+
__syncthreads();
|
241 |
+
|
242 |
+
GPU_1D_KERNEL_LOOP(n, ii.size(0)) {
|
243 |
+
int k = ku[n];
|
244 |
+
int ix = ii[n];
|
245 |
+
int jx = jj[n];
|
246 |
+
int kx = kk[n];
|
247 |
+
|
248 |
+
float ti[3] = { poses[ix][0], poses[ix][1], poses[ix][2] };
|
249 |
+
float tj[3] = { poses[jx][0], poses[jx][1], poses[jx][2] };
|
250 |
+
float qi[4] = { poses[ix][3], poses[ix][4], poses[ix][5], poses[ix][6] };
|
251 |
+
float qj[4] = { poses[jx][3], poses[jx][4], poses[jx][5], poses[jx][6] };
|
252 |
+
|
253 |
+
float Xi[4], Xj[4];
|
254 |
+
Xi[0] = (patches[kx][0][1][1] - cx) / fx;
|
255 |
+
Xi[1] = (patches[kx][1][1][1] - cy) / fy;
|
256 |
+
Xi[2] = 1.0;
|
257 |
+
Xi[3] = patches[kx][2][1][1];
|
258 |
+
|
259 |
+
float tij[3], qij[4];
|
260 |
+
relSE3(ti, qi, tj, qj, tij, qij);
|
261 |
+
actSE3(tij, qij, Xi, Xj);
|
262 |
+
|
263 |
+
const float X = Xj[0];
|
264 |
+
const float Y = Xj[1];
|
265 |
+
const float Z = Xj[2];
|
266 |
+
const float W = Xj[3];
|
267 |
+
|
268 |
+
const float d = (Z >= 0.2) ? 1.0 / Z : 0.0;
|
269 |
+
const float d2 = d * d;
|
270 |
+
|
271 |
+
const float x1 = fx * (X / Z) + cx;
|
272 |
+
const float y1 = fy * (Y / Z) + cy;
|
273 |
+
|
274 |
+
const float rx = target[n][0] - x1;
|
275 |
+
const float ry = target[n][1] - y1;
|
276 |
+
|
277 |
+
const bool in_bounds = (sqrt(rx*rx + ry*ry) < 128) && (Z > 0.2) &&
|
278 |
+
(x1 > -64) && (y1 > -64) && (x1 < 2*cx + 64) && (y1 < 2*cy + 64);
|
279 |
+
|
280 |
+
const float mask = in_bounds ? 1.0 : 0.0;
|
281 |
+
|
282 |
+
ix = ix - t0;
|
283 |
+
jx = jx - t0;
|
284 |
+
|
285 |
+
{
|
286 |
+
const float r = target[n][0] - x1;
|
287 |
+
const float w = mask * weight[n][0];
|
288 |
+
|
289 |
+
float Jz = fx * (tij[0] * d - tij[2] * (X * d2));
|
290 |
+
float Ji[6], Jj[6] = {fx*W*d, 0, fx*-X*W*d2, fx*-X*Y*d2, fx*(1+X*X*d2), fx*-Y*d};
|
291 |
+
|
292 |
+
adjSE3(tij, qij, Jj, Ji);
|
293 |
+
|
294 |
+
for (int i=0; i<6; i++) {
|
295 |
+
for (int j=0; j<6; j++) {
|
296 |
+
if (ix >= 0)
|
297 |
+
atomicAdd(&B[6*ix+i][6*ix+j], w * Ji[i] * Ji[j]);
|
298 |
+
if (jx >= 0)
|
299 |
+
atomicAdd(&B[6*jx+i][6*jx+j], w * Jj[i] * Jj[j]);
|
300 |
+
if (ix >= 0 && jx >= 0) {
|
301 |
+
atomicAdd(&B[6*ix+i][6*jx+j], -w * Ji[i] * Jj[j]);
|
302 |
+
atomicAdd(&B[6*jx+i][6*ix+j], -w * Jj[i] * Ji[j]);
|
303 |
+
}
|
304 |
+
}
|
305 |
+
}
|
306 |
+
|
307 |
+
for (int i=0; i<6; i++) {
|
308 |
+
if (ix >= 0)
|
309 |
+
atomicAdd(&E[6*ix+i][k], -w * Jz * Ji[i]);
|
310 |
+
if (jx >= 0)
|
311 |
+
atomicAdd(&E[6*jx+i][k], w * Jz * Jj[i]);
|
312 |
+
}
|
313 |
+
|
314 |
+
for (int i=0; i<6; i++) {
|
315 |
+
if (ix >= 0)
|
316 |
+
atomicAdd(&v[6*ix+i], -w * r * Ji[i]);
|
317 |
+
if (jx >= 0)
|
318 |
+
atomicAdd(&v[6*jx+i], w * r * Jj[i]);
|
319 |
+
}
|
320 |
+
|
321 |
+
atomicAdd(&C[k], w * Jz * Jz);
|
322 |
+
atomicAdd(&u[k], w * r * Jz);
|
323 |
+
}
|
324 |
+
|
325 |
+
{
|
326 |
+
const float r = target[n][1] - y1;
|
327 |
+
const float w = mask * weight[n][1];
|
328 |
+
|
329 |
+
float Jz = fy * (tij[1] * d - tij[2] * (Y * d2));
|
330 |
+
float Ji[6], Jj[6] = {0, fy*W*d, fy*-Y*W*d2, fy*(-1-Y*Y*d2), fy*(X*Y*d2), fy*X*d};
|
331 |
+
|
332 |
+
adjSE3(tij, qij, Jj, Ji);
|
333 |
+
|
334 |
+
for (int i=0; i<6; i++) {
|
335 |
+
for (int j=0; j<6; j++) {
|
336 |
+
if (ix >= 0)
|
337 |
+
atomicAdd(&B[6*ix+i][6*ix+j], w * Ji[i] * Ji[j]);
|
338 |
+
if (jx >= 0)
|
339 |
+
atomicAdd(&B[6*jx+i][6*jx+j], w * Jj[i] * Jj[j]);
|
340 |
+
if (ix >= 0 && jx >= 0) {
|
341 |
+
atomicAdd(&B[6*ix+i][6*jx+j], -w * Ji[i] * Jj[j]);
|
342 |
+
atomicAdd(&B[6*jx+i][6*ix+j], -w * Jj[i] * Ji[j]);
|
343 |
+
}
|
344 |
+
}
|
345 |
+
}
|
346 |
+
|
347 |
+
for (int i=0; i<6; i++) {
|
348 |
+
if (ix >= 0)
|
349 |
+
atomicAdd(&E[6*ix+i][k], -w * Jz * Ji[i]);
|
350 |
+
if (jx >= 0)
|
351 |
+
atomicAdd(&E[6*jx+i][k], w * Jz * Jj[i]);
|
352 |
+
}
|
353 |
+
|
354 |
+
for (int i=0; i<6; i++) {
|
355 |
+
if (ix >= 0)
|
356 |
+
atomicAdd(&v[6*ix+i], -w * r * Ji[i]);
|
357 |
+
if (jx >= 0)
|
358 |
+
atomicAdd(&v[6*jx+i], w * r * Jj[i]);
|
359 |
+
}
|
360 |
+
|
361 |
+
atomicAdd(&C[k], w * Jz * Jz);
|
362 |
+
atomicAdd(&u[k], w * r * Jz);
|
363 |
+
}
|
364 |
+
}
|
365 |
+
}
|
366 |
+
|
367 |
+
|
368 |
+
__global__ void reproject(
|
369 |
+
const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> poses,
|
370 |
+
const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> patches,
|
371 |
+
const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> intrinsics,
|
372 |
+
const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> ii,
|
373 |
+
const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> jj,
|
374 |
+
const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> kk,
|
375 |
+
torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> coords) {
|
376 |
+
|
377 |
+
__shared__ float fx, fy, cx, cy;
|
378 |
+
if (threadIdx.x == 0) {
|
379 |
+
fx = intrinsics[0][0];
|
380 |
+
fy = intrinsics[0][1];
|
381 |
+
cx = intrinsics[0][2];
|
382 |
+
cy = intrinsics[0][3];
|
383 |
+
}
|
384 |
+
|
385 |
+
__syncthreads();
|
386 |
+
|
387 |
+
GPU_1D_KERNEL_LOOP(n, ii.size(0)) {
|
388 |
+
int ix = ii[n];
|
389 |
+
int jx = jj[n];
|
390 |
+
int kx = kk[n];
|
391 |
+
|
392 |
+
float ti[3] = { poses[ix][0], poses[ix][1], poses[ix][2] };
|
393 |
+
float tj[3] = { poses[jx][0], poses[jx][1], poses[jx][2] };
|
394 |
+
float qi[4] = { poses[ix][3], poses[ix][4], poses[ix][5], poses[ix][6] };
|
395 |
+
float qj[4] = { poses[jx][3], poses[jx][4], poses[jx][5], poses[jx][6] };
|
396 |
+
|
397 |
+
float tij[3], qij[4];
|
398 |
+
relSE3(ti, qi, tj, qj, tij, qij);
|
399 |
+
|
400 |
+
float Xi[4], Xj[4];
|
401 |
+
for (int i=0; i<patches.size(2); i++) {
|
402 |
+
for (int j=0; j<patches.size(3); j++) {
|
403 |
+
|
404 |
+
Xi[0] = (patches[kx][0][i][j] - cx) / fx;
|
405 |
+
Xi[1] = (patches[kx][1][i][j] - cy) / fy;
|
406 |
+
Xi[2] = 1.0;
|
407 |
+
Xi[3] = patches[kx][2][i][j];
|
408 |
+
|
409 |
+
actSE3(tij, qij, Xi, Xj);
|
410 |
+
|
411 |
+
coords[n][0][i][j] = fx * (Xj[0] / Xj[2]) + cx;
|
412 |
+
coords[n][1][i][j] = fy * (Xj[1] / Xj[2]) + cy;
|
413 |
+
// coords[n][2][i][j] = 1.0 / Xj[2];
|
414 |
+
|
415 |
+
}
|
416 |
+
}
|
417 |
+
}
|
418 |
+
}
|
419 |
+
|
420 |
+
|
421 |
+
|
422 |
+
std::vector<torch::Tensor> cuda_ba(
|
423 |
+
torch::Tensor poses,
|
424 |
+
torch::Tensor patches,
|
425 |
+
torch::Tensor intrinsics,
|
426 |
+
torch::Tensor target,
|
427 |
+
torch::Tensor weight,
|
428 |
+
torch::Tensor lmbda,
|
429 |
+
torch::Tensor ii,
|
430 |
+
torch::Tensor jj,
|
431 |
+
torch::Tensor kk,
|
432 |
+
const int t0, const int t1, const int iterations)
|
433 |
+
{
|
434 |
+
|
435 |
+
auto ktuple = torch::_unique(kk, true, true);
|
436 |
+
torch::Tensor kx = std::get<0>(ktuple);
|
437 |
+
torch::Tensor ku = std::get<1>(ktuple);
|
438 |
+
|
439 |
+
const int N = t1 - t0; // number of poses
|
440 |
+
const int M = kx.size(0); // number of patches
|
441 |
+
const int P = patches.size(3); // patch size
|
442 |
+
|
443 |
+
auto opts = torch::TensorOptions()
|
444 |
+
.dtype(torch::kFloat32).device(torch::kCUDA);
|
445 |
+
|
446 |
+
poses = poses.view({-1, 7});
|
447 |
+
patches = patches.view({-1,3,P,P});
|
448 |
+
intrinsics = intrinsics.view({-1, 4});
|
449 |
+
|
450 |
+
target = target.view({-1, 2});
|
451 |
+
weight = weight.view({-1, 2});
|
452 |
+
|
453 |
+
const int num = ii.size(0);
|
454 |
+
torch::Tensor B = torch::empty({6*N, 6*N}, opts);
|
455 |
+
torch::Tensor E = torch::empty({6*N, 1*M}, opts);
|
456 |
+
torch::Tensor C = torch::empty({M}, opts);
|
457 |
+
|
458 |
+
torch::Tensor v = torch::empty({6*N}, opts);
|
459 |
+
torch::Tensor u = torch::empty({1*M}, opts);
|
460 |
+
|
461 |
+
for (int itr=0; itr < iterations; itr++) {
|
462 |
+
|
463 |
+
B.zero_();
|
464 |
+
E.zero_();
|
465 |
+
C.zero_();
|
466 |
+
v.zero_();
|
467 |
+
u.zero_();
|
468 |
+
|
469 |
+
v = v.view({6*N});
|
470 |
+
u = u.view({1*M});
|
471 |
+
|
472 |
+
reprojection_residuals_and_hessian<<<NUM_BLOCKS(ii.size(0)), NUM_THREADS>>>(
|
473 |
+
poses.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
|
474 |
+
patches.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
|
475 |
+
intrinsics.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
|
476 |
+
target.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
|
477 |
+
weight.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
|
478 |
+
lmbda.packed_accessor32<float,1,torch::RestrictPtrTraits>(),
|
479 |
+
ii.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
|
480 |
+
jj.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
|
481 |
+
kk.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
|
482 |
+
ku.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
|
483 |
+
B.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
|
484 |
+
E.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
|
485 |
+
C.packed_accessor32<float,1,torch::RestrictPtrTraits>(),
|
486 |
+
v.packed_accessor32<float,1,torch::RestrictPtrTraits>(),
|
487 |
+
u.packed_accessor32<float,1,torch::RestrictPtrTraits>(), t0);
|
488 |
+
|
489 |
+
v = v.view({6*N, 1});
|
490 |
+
u = u.view({1*M, 1});
|
491 |
+
|
492 |
+
torch::Tensor Q = 1.0 / (C + lmbda).view({1, M});
|
493 |
+
|
494 |
+
if (t1 - t0 == 0) {
|
495 |
+
|
496 |
+
torch::Tensor Qt = torch::transpose(Q, 0, 1);
|
497 |
+
torch::Tensor dZ = Qt * u;
|
498 |
+
|
499 |
+
dZ = dZ.view({M});
|
500 |
+
|
501 |
+
patch_retr_kernel<<<NUM_BLOCKS(M), NUM_THREADS>>>(
|
502 |
+
kx.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
|
503 |
+
patches.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
|
504 |
+
dZ.packed_accessor32<float,1,torch::RestrictPtrTraits>());
|
505 |
+
|
506 |
+
}
|
507 |
+
|
508 |
+
else {
|
509 |
+
|
510 |
+
torch::Tensor EQ = E * Q;
|
511 |
+
torch::Tensor Et = torch::transpose(E, 0, 1);
|
512 |
+
torch::Tensor Qt = torch::transpose(Q, 0, 1);
|
513 |
+
|
514 |
+
torch::Tensor S = B - torch::matmul(EQ, Et);
|
515 |
+
torch::Tensor y = v - torch::matmul(EQ, u);
|
516 |
+
|
517 |
+
torch::Tensor I = torch::eye(6*N, opts);
|
518 |
+
S += I * (1e-4 * S + 1.0);
|
519 |
+
|
520 |
+
|
521 |
+
torch::Tensor U = torch::linalg::cholesky(S);
|
522 |
+
torch::Tensor dX = torch::cholesky_solve(y, U);
|
523 |
+
torch::Tensor dZ = Qt * (u - torch::matmul(Et, dX));
|
524 |
+
|
525 |
+
dX = dX.view({N, 6});
|
526 |
+
dZ = dZ.view({M});
|
527 |
+
|
528 |
+
pose_retr_kernel<<<NUM_BLOCKS(N), NUM_THREADS>>>(t0, t1,
|
529 |
+
poses.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
|
530 |
+
dX.packed_accessor32<float,2,torch::RestrictPtrTraits>());
|
531 |
+
|
532 |
+
patch_retr_kernel<<<NUM_BLOCKS(M), NUM_THREADS>>>(
|
533 |
+
kx.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
|
534 |
+
patches.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
|
535 |
+
dZ.packed_accessor32<float,1,torch::RestrictPtrTraits>());
|
536 |
+
}
|
537 |
+
}
|
538 |
+
|
539 |
+
return {};
|
540 |
+
}
|
541 |
+
|
542 |
+
|
543 |
+
torch::Tensor cuda_reproject(
|
544 |
+
torch::Tensor poses,
|
545 |
+
torch::Tensor patches,
|
546 |
+
torch::Tensor intrinsics,
|
547 |
+
torch::Tensor ii,
|
548 |
+
torch::Tensor jj,
|
549 |
+
torch::Tensor kk)
|
550 |
+
{
|
551 |
+
|
552 |
+
const int N = ii.size(0);
|
553 |
+
const int P = patches.size(3); // patch size
|
554 |
+
|
555 |
+
poses = poses.view({-1, 7});
|
556 |
+
patches = patches.view({-1,3,P,P});
|
557 |
+
intrinsics = intrinsics.view({-1, 4});
|
558 |
+
|
559 |
+
auto opts = torch::TensorOptions()
|
560 |
+
.dtype(torch::kFloat32).device(torch::kCUDA);
|
561 |
+
|
562 |
+
torch::Tensor coords = torch::empty({N, 2, P, P}, opts);
|
563 |
+
|
564 |
+
reproject<<<NUM_BLOCKS(N), NUM_THREADS>>>(
|
565 |
+
poses.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
|
566 |
+
patches.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
|
567 |
+
intrinsics.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
|
568 |
+
ii.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
|
569 |
+
jj.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
|
570 |
+
kk.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
|
571 |
+
coords.packed_accessor32<float,4,torch::RestrictPtrTraits>());
|
572 |
+
|
573 |
+
return coords.view({1, N, 2, P, P});
|
574 |
+
|
575 |
+
}
|
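For readers skimming `cuda_ba`: the per-patch inverse depths are eliminated with a Schur complement before the pose update. Below is a dense CPU sketch of that linear-algebra step with random stand-ins for the accumulated blocks B, E, C, v, u (B is constructed so the system stays positive definite); it is for illustration only, the real code keeps everything on the GPU and re-accumulates the blocks every iteration.

import torch

N, M = 5, 40                                  # optimized poses, unique patches
lmbda = 1e-4

E = torch.rand(6 * N, M)                      # pose-depth coupling block
C = torch.rand(M) + 1.0                       # diagonal depth block (one scalar per patch)
B = E @ E.t() + torch.eye(6 * N)              # pose-pose block (chosen so S below is PD)
v = torch.rand(6 * N, 1)                      # pose right-hand side
u = torch.rand(M, 1)                          # depth right-hand side

Q = (1.0 / (C + lmbda)).view(1, M)            # invert the damped diagonal depth block
EQ = E * Q
S = B - EQ @ E.t()                            # Schur complement: reduced pose system
S = S + torch.eye(6 * N) * (1e-4 * S + 1.0)   # same diagonal damping as in cuda_ba

dX = torch.cholesky_solve(v - EQ @ u, torch.linalg.cholesky(S))   # pose increments (6N x 1)
dZ = Q.t() * (u - E.t() @ dX)                 # back-substituted inverse-depth increments (M x 1)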
mini_dpvo/lietorch/__init__.py
ADDED
@@ -0,0 +1,2 @@
__all__ = ['groups']
from .groups import LieGroupParameter, SO3, RxSO3, SE3, Sim3, cat, stack
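A small hedged example of what the re-exported classes provide, assuming the usual lietorch API implemented in groups.py and a built lietorch_backends extension; the pose values are arbitrary. An SE3 wraps an (..., 7) tensor [tx, ty, tz, qx, qy, qz, qw] and supports composition, inversion and conversion to 4x4 matrices.

import torch
from mini_dpvo.lietorch import SE3

vec = torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
                    [1.0, 2.0, 3.0, 0.0, 0.0, 0.0, 1.0]], device="cuda")
T = SE3(vec)                    # two poses: identity and a pure translation

rel = T[1:2] * T[0:1].inv()     # relative transform between the two poses
print(rel.matrix())             # (1, 4, 4) homogeneous matrix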
mini_dpvo/lietorch/broadcasting.py
ADDED
@@ -0,0 +1,31 @@
import torch
import numpy as np

def check_broadcastable(x, y):
    assert len(x.shape) == len(y.shape)
    for (n, m) in zip(x.shape[:-1], y.shape[:-1]):
        assert n == m or n == 1 or m == 1

def broadcast_inputs(x, y):
    """ Automatic broadcasting of missing dimensions """
    if y is None:
        xs, xd = x.shape[:-1], x.shape[-1]
        return (x.view(-1, xd).contiguous(), ), x.shape[:-1]

    check_broadcastable(x, y)

    xs, xd = x.shape[:-1], x.shape[-1]
    ys, yd = y.shape[:-1], y.shape[-1]
    out_shape = [max(n, m) for (n, m) in zip(xs, ys)]

    if x.shape[:-1] == y.shape[:-1]:
        x1 = x.view(-1, xd)
        y1 = y.view(-1, yd)

    else:
        x_expand = [m if n == 1 else 1 for (n, m) in zip(xs, ys)]
        y_expand = [n if m == 1 else 1 for (n, m) in zip(xs, ys)]
        x1 = x.repeat(x_expand + [1]).reshape(-1, xd).contiguous()
        y1 = y.repeat(y_expand + [1]).reshape(-1, yd).contiguous()

    return (x1, y1), tuple(out_shape)
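For illustration, this is what `broadcast_inputs` does with two batched embeddings whose leading dimensions broadcast against each other (the shapes are arbitrary): both inputs are expanded to the common batch shape and flattened to (-1, 7).

import torch
from mini_dpvo.lietorch.broadcasting import broadcast_inputs

x = torch.randn(1, 4, 7)     # e.g. 4 SE3 elements
y = torch.randn(3, 1, 7)     # e.g. 3 SE3 elements

(x1, y1), out_shape = broadcast_inputs(x, y)
print(x1.shape, y1.shape, out_shape)   # torch.Size([12, 7]) torch.Size([12, 7]) (3, 4)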
mini_dpvo/lietorch/gradcheck.py
ADDED
@@ -0,0 +1,592 @@
1 |
+
import torch
|
2 |
+
|
3 |
+
TORCH_MAJOR = int(torch.__version__.split('.')[0])
|
4 |
+
TORCH_MINOR = int(torch.__version__.split('.')[1])
|
5 |
+
|
6 |
+
from torch.types import _TensorOrTensors
|
7 |
+
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
|
8 |
+
from torch._six import container_abcs, istuple
|
9 |
+
else:
|
10 |
+
import collections.abc as container_abcs
|
11 |
+
|
12 |
+
import torch.testing
|
13 |
+
from torch.overrides import is_tensor_like
|
14 |
+
from itertools import product
|
15 |
+
import warnings
|
16 |
+
from typing import Callable, Union, Optional, Iterable, List
|
17 |
+
|
18 |
+
def zero_gradients(x):
|
19 |
+
if isinstance(x, torch.Tensor):
|
20 |
+
if x.grad is not None:
|
21 |
+
x.grad.detach_()
|
22 |
+
x.grad.zero_()
|
23 |
+
elif isinstance(x, container_abcs.Iterable):
|
24 |
+
for elem in x:
|
25 |
+
zero_gradients(elem)
|
26 |
+
|
27 |
+
|
28 |
+
def make_jacobian(input, num_out):
|
29 |
+
if is_tensor_like(input):
|
30 |
+
if not input.is_floating_point() and not input.is_complex():
|
31 |
+
return None
|
32 |
+
if not input.requires_grad:
|
33 |
+
return None
|
34 |
+
return input.new_zeros((input.nelement(), num_out), dtype=input.dtype, layout=torch.strided)
|
35 |
+
elif isinstance(input, container_abcs.Iterable) and not isinstance(input, str):
|
36 |
+
jacobians = list(filter(
|
37 |
+
lambda x: x is not None, (make_jacobian(elem, num_out) for elem in input)))
|
38 |
+
if not jacobians:
|
39 |
+
return None
|
40 |
+
return type(input)(jacobians) # type: ignore
|
41 |
+
else:
|
42 |
+
return None
|
43 |
+
|
44 |
+
|
45 |
+
def iter_tensors(x: Union[torch.Tensor, Iterable[torch.Tensor]], only_requiring_grad: bool = False) -> Iterable[torch.Tensor]:
|
46 |
+
if is_tensor_like(x):
|
47 |
+
# mypy doesn't narrow type of `x` to torch.Tensor
|
48 |
+
if x.requires_grad or not only_requiring_grad: # type: ignore
|
49 |
+
yield x # type: ignore
|
50 |
+
elif isinstance(x, container_abcs.Iterable) and not isinstance(x, str):
|
51 |
+
for elem in x:
|
52 |
+
for result in iter_tensors(elem, only_requiring_grad):
|
53 |
+
yield result
|
54 |
+
|
55 |
+
def get_numerical_jacobian(fn, input, target=None, eps=1e-3, grad_out=1.0):
|
56 |
+
"""
|
57 |
+
input: input to `fn`
|
58 |
+
target: the Tensors wrt whom Jacobians are calculated (default=`input`)
|
59 |
+
grad_out: grad output value used to calculate gradients.
|
60 |
+
|
61 |
+
Note that `target` may not even be part of `input` to `fn`, so please be
|
62 |
+
**very careful** in this to not clone `target`.
|
63 |
+
"""
|
64 |
+
if target is None:
|
65 |
+
target = input
|
66 |
+
output_size = fn(input).numel()
|
67 |
+
jacobian = make_jacobian(target, output_size)
|
68 |
+
|
69 |
+
# It's much easier to iterate over flattened lists of tensors.
|
70 |
+
# These are reference to the same objects in jacobian, so any changes
|
71 |
+
# will be reflected in it as well.
|
72 |
+
x_tensors = iter_tensors(target, True)
|
73 |
+
j_tensors = iter_tensors(jacobian)
|
74 |
+
|
75 |
+
def update_jacobians(x, idx, d, d_idx, is_mkldnn=False):
|
76 |
+
|
77 |
+
# compute_jacobian only works for pure real
|
78 |
+
# or pure imaginary delta
|
79 |
+
def compute_gradient(delta):
|
80 |
+
# we currently assume that the norm of delta equals eps
|
81 |
+
assert(delta == eps or delta == (eps * 1j))
|
82 |
+
|
83 |
+
def fn_out():
|
84 |
+
if not is_mkldnn:
|
85 |
+
# x is a view into input and so this works
|
86 |
+
return fn(input).clone()
|
87 |
+
else:
|
88 |
+
# convert the dense tensor back to have mkldnn layout
|
89 |
+
return fn([x.to_mkldnn()])
|
90 |
+
|
91 |
+
orig = x[idx].item()
|
92 |
+
x[idx] = orig - delta
|
93 |
+
outa = fn_out()
|
94 |
+
x[idx] = orig + delta
|
95 |
+
outb = fn_out()
|
96 |
+
x[idx] = orig
|
97 |
+
r = (outb - outa) / (2 * eps)
|
98 |
+
return r.detach().reshape(-1)
|
99 |
+
|
100 |
+
# for details on the algorithm used here, refer:
|
101 |
+
# Section 3.5.3 https://arxiv.org/pdf/1701.00392.pdf
|
102 |
+
# s = fn(z) where z = x for real valued input
|
103 |
+
# and z = x + yj for complex valued input
|
104 |
+
ds_dx = compute_gradient(eps)
|
105 |
+
if x.is_complex(): # C -> C, C -> R
|
106 |
+
ds_dy = compute_gradient(eps * 1j)
|
107 |
+
# conjugate wirtinger derivative
|
108 |
+
conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
|
109 |
+
# wirtinger derivative
|
110 |
+
w_d = 0.5 * (ds_dx - ds_dy * 1j)
|
111 |
+
d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()
|
112 |
+
elif ds_dx.is_complex(): # R -> C
|
113 |
+
# w_d = conj_w_d = 0.5 * ds_dx
|
114 |
+
# dL_dz_conj = 0.5 * [grad_out.conj() * ds_dx + grad_out * ds_dx.conj()]
|
115 |
+
# = 0.5 * [grad_out.conj() * ds_dx + (grad_out.conj() * ds_dx).conj()]
|
116 |
+
# = 0.5 * 2 * real(grad_out.conj() * ds_dx)
|
117 |
+
# = real(grad_out.conj() * ds_dx)
|
118 |
+
d[d_idx] = torch.real(grad_out.conjugate() * ds_dx)
|
119 |
+
else: # R -> R
|
120 |
+
d[d_idx] = ds_dx * grad_out
|
121 |
+
|
122 |
+
# TODO: compare structure
|
123 |
+
for x_tensor, d_tensor in zip(x_tensors, j_tensors):
|
124 |
+
if x_tensor.is_sparse:
|
125 |
+
def get_stride(size):
|
126 |
+
dim = len(size)
|
127 |
+
tmp = 1
|
128 |
+
stride = [0] * dim
|
129 |
+
for i in reversed(range(dim)):
|
130 |
+
stride[i] = tmp
|
131 |
+
tmp *= size[i]
|
132 |
+
return stride
|
133 |
+
|
134 |
+
x_nnz = x_tensor._nnz()
|
135 |
+
x_size = list(x_tensor.size())
|
136 |
+
x_indices = x_tensor._indices().t()
|
137 |
+
x_values = x_tensor._values()
|
138 |
+
x_stride = get_stride(x_size)
|
139 |
+
|
140 |
+
# Use .data here to get around the version check
|
141 |
+
x_values = x_values.data
|
142 |
+
|
143 |
+
for i in range(x_nnz):
|
144 |
+
x_value = x_values[i]
|
145 |
+
for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
|
146 |
+
indices = x_indices[i].tolist() + list(x_idx)
|
147 |
+
d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
|
148 |
+
update_jacobians(x_value, x_idx, d_tensor, d_idx)
|
149 |
+
elif x_tensor.layout == torch._mkldnn: # type: ignore
|
150 |
+
# Use .data here to get around the version check
|
151 |
+
x_tensor = x_tensor.data
|
152 |
+
if len(input) != 1:
|
153 |
+
raise ValueError('gradcheck currently only supports functions with 1 input, but got: ',
|
154 |
+
len(input))
|
155 |
+
for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
|
156 |
+
# this is really inefficient, but without indexing implemented, there's
|
157 |
+
# not really a better way than converting back and forth
|
158 |
+
x_tensor_dense = x_tensor.to_dense()
|
159 |
+
update_jacobians(x_tensor_dense, x_idx, d_tensor, d_idx, is_mkldnn=True)
|
160 |
+
else:
|
161 |
+
# Use .data here to get around the version check
|
162 |
+
x_tensor = x_tensor.data
|
163 |
+
for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
|
164 |
+
update_jacobians(x_tensor, x_idx, d_tensor, d_idx)
|
165 |
+
|
166 |
+
return jacobian
|
167 |
+
|
168 |
+
|
169 |
+
def get_analytical_jacobian(input, output, nondet_tol=0.0, grad_out=1.0):
|
170 |
+
# it is easier to call to_dense() on the sparse output than
|
171 |
+
# to modify analytical jacobian
|
172 |
+
if output.is_sparse:
|
173 |
+
raise ValueError('Sparse output is not supported at gradcheck yet. '
|
174 |
+
'Please call to_dense() on the output of fn for gradcheck.')
|
175 |
+
if output.layout == torch._mkldnn: # type: ignore
|
176 |
+
raise ValueError('MKLDNN output is not supported at gradcheck yet. '
|
177 |
+
'Please call to_dense() on the output of fn for gradcheck.')
|
178 |
+
diff_input_list = list(iter_tensors(input, True))
|
179 |
+
jacobian = make_jacobian(input, output.numel())
|
180 |
+
jacobian_reentrant = make_jacobian(input, output.numel())
|
181 |
+
grad_output = torch.zeros_like(output, memory_format=torch.legacy_contiguous_format)
|
182 |
+
flat_grad_output = grad_output.view(-1)
|
183 |
+
reentrant = True
|
184 |
+
correct_grad_sizes = True
|
185 |
+
correct_grad_types = True
|
186 |
+
|
187 |
+
for i in range(flat_grad_output.numel()):
|
188 |
+
flat_grad_output.zero_()
|
189 |
+
flat_grad_output[i] = grad_out
|
190 |
+
for jacobian_c in (jacobian, jacobian_reentrant):
|
191 |
+
grads_input = torch.autograd.grad(output, diff_input_list, grad_output,
|
192 |
+
retain_graph=True, allow_unused=True)
|
193 |
+
for jacobian_x, d_x, x in zip(jacobian_c, grads_input, diff_input_list):
|
194 |
+
if d_x is not None and d_x.size() != x.size():
|
195 |
+
correct_grad_sizes = False
|
196 |
+
elif d_x is not None and d_x.dtype != x.dtype:
|
197 |
+
correct_grad_types = False
|
198 |
+
elif jacobian_x.numel() != 0:
|
199 |
+
if d_x is None:
|
200 |
+
jacobian_x[:, i].zero_()
|
201 |
+
else:
|
202 |
+
d_x_dense = d_x.to_dense() if not d_x.layout == torch.strided else d_x
|
203 |
+
assert jacobian_x[:, i].numel() == d_x_dense.numel()
|
204 |
+
jacobian_x[:, i] = d_x_dense.contiguous().view(-1)
|
205 |
+
|
206 |
+
for jacobian_x, jacobian_reentrant_x in zip(jacobian, jacobian_reentrant):
|
207 |
+
if jacobian_x.numel() != 0 and (jacobian_x - jacobian_reentrant_x).abs().max() > nondet_tol:
|
208 |
+
reentrant = False
|
209 |
+
|
210 |
+
return jacobian, reentrant, correct_grad_sizes, correct_grad_types
|
211 |
+
|
212 |
+
|
213 |
+
def _as_tuple(x):
|
214 |
+
if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
|
215 |
+
b_tuple = istuple(x)
|
216 |
+
else:
|
217 |
+
b_tuple = isinstance(x, tuple)
|
218 |
+
|
219 |
+
if b_tuple:
|
220 |
+
return x
|
221 |
+
elif isinstance(x, list):
|
222 |
+
return tuple(x)
|
223 |
+
else:
|
224 |
+
return x,
|
225 |
+
|
226 |
+
|
227 |
+
|
228 |
+
def _differentiable_outputs(x):
|
229 |
+
return tuple(o for o in _as_tuple(x) if o.requires_grad)
|
230 |
+
|
231 |
+
|
232 |
+
# Note [VarArg of Tensors]
|
233 |
+
# ~~~~~~~~~~~~~~~~~~~~~~~~
|
234 |
+
# 'func' accepts a vararg of tensors, which isn't expressible in the type system at the moment.
|
235 |
+
# If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
|
236 |
+
# the '...' first argument of Callable can be replaced with VarArg(Tensor).
|
237 |
+
# For now, we permit any input.
|
238 |
+
# the '...' first argument of Callable can be replaced with VarArg(Tensor).
|
239 |
+
# For now, we permit any input.
|
240 |
+
|
241 |
+
def gradcheck(
|
242 |
+
func: Callable[..., Union[_TensorOrTensors]], # See Note [VarArg of Tensors]
|
243 |
+
inputs: _TensorOrTensors,
|
244 |
+
eps: float = 1e-6,
|
245 |
+
atol: float = 1e-5,
|
246 |
+
rtol: float = 1e-3,
|
247 |
+
raise_exception: bool = True,
|
248 |
+
check_sparse_nnz: bool = False,
|
249 |
+
nondet_tol: float = 0.0,
|
250 |
+
check_undefined_grad: bool = True,
|
251 |
+
check_grad_dtypes: bool = False
|
252 |
+
) -> bool:
|
253 |
+
r"""Check gradients computed via small finite differences against analytical
|
254 |
+
gradients w.r.t. tensors in :attr:`inputs` that are of floating point or complex type
|
255 |
+
and with ``requires_grad=True``.
|
256 |
+
|
257 |
+
The check between numerical and analytical gradients uses :func:`~torch.allclose`.
|
258 |
+
|
259 |
+
For complex functions, no notion of Jacobian exists. Gradcheck verifies if the numerical and
|
260 |
+
analytical values of Wirtinger and Conjugate Wirtinger derivative are consistent. The gradient
|
261 |
+
computation is done under the assumption that the overall function has a real valued output.
|
262 |
+
For functions with complex output, gradcheck compares the numerical and analytical gradients
|
263 |
+
for two values of :attr:`grad_output`: 1 and 1j. For more details, check out
|
264 |
+
:ref:`complex_autograd-doc`.
|
265 |
+
|
266 |
+
.. note::
|
267 |
+
The default values are designed for :attr:`input` of double precision.
|
268 |
+
This check will likely fail if :attr:`input` is of less precision, e.g.,
|
269 |
+
``FloatTensor``.
|
270 |
+
|
271 |
+
.. warning::
|
272 |
+
If any checked tensor in :attr:`input` has overlapping memory, i.e.,
|
273 |
+
different indices pointing to the same memory address (e.g., from
|
274 |
+
:func:`torch.expand`), this check will likely fail because the numerical
|
275 |
+
gradients computed by point perturbation at such indices will change
|
276 |
+
values at all other indices that share the same memory address.
|
277 |
+
|
278 |
+
Args:
|
279 |
+
func (function): a Python function that takes Tensor inputs and returns
|
280 |
+
a Tensor or a tuple of Tensors
|
281 |
+
inputs (tuple of Tensor or Tensor): inputs to the function
|
282 |
+
eps (float, optional): perturbation for finite differences
|
283 |
+
atol (float, optional): absolute tolerance
|
284 |
+
rtol (float, optional): relative tolerance
|
285 |
+
raise_exception (bool, optional): indicating whether to raise an exception if
|
286 |
+
the check fails. The exception gives more information about the
|
287 |
+
exact nature of the failure. This is helpful when debugging gradchecks.
|
288 |
+
check_sparse_nnz (bool, optional): if True, gradcheck allows for SparseTensor input,
|
289 |
+
and for any SparseTensor at input, gradcheck will perform check at nnz positions only.
|
290 |
+
nondet_tol (float, optional): tolerance for non-determinism. When running
|
291 |
+
identical inputs through the differentiation, the results must either match
|
292 |
+
exactly (default, 0.0) or be within this tolerance.
|
293 |
+
check_undefined_grad (bool, options): if True, check if undefined output grads
|
294 |
+
are supported and treated as zeros, for ``Tensor`` outputs.
|
295 |
+
|
296 |
+
Returns:
|
297 |
+
True if all differences satisfy allclose condition
|
298 |
+
"""
|
299 |
+
def fail_test(msg):
|
300 |
+
if raise_exception:
|
301 |
+
raise RuntimeError(msg)
|
302 |
+
return False
|
303 |
+
|
304 |
+
tupled_inputs = _as_tuple(inputs)
|
305 |
+
if not check_sparse_nnz and any(t.is_sparse for t in tupled_inputs if isinstance(t, torch.Tensor)):
|
306 |
+
return fail_test('gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False.')
|
307 |
+
|
308 |
+
# Make sure that gradients are saved for at least one input
|
309 |
+
any_input_requiring_grad = False
|
310 |
+
for idx, inp in enumerate(tupled_inputs):
|
311 |
+
if is_tensor_like(inp) and inp.requires_grad:
|
312 |
+
if not (inp.dtype == torch.float64 or inp.dtype == torch.complex128):
|
313 |
+
warnings.warn(
|
314 |
+
f'Input #{idx} requires gradient and '
|
315 |
+
'is not a double precision floating point or complex. '
|
316 |
+
'This check will likely fail if all the inputs are '
|
317 |
+
'not of double precision floating point or complex. ')
|
318 |
+
content = inp._values() if inp.is_sparse else inp
|
319 |
+
# TODO: To cover more problematic cases, replace stride = 0 check with
|
320 |
+
# "any overlap in memory" once we have a proper function to check it.
|
321 |
+
if content.layout is not torch._mkldnn: # type: ignore
|
322 |
+
if not all(st > 0 or sz <= 1 for st, sz in zip(content.stride(), content.size())):
|
323 |
+
raise RuntimeError(
|
324 |
+
'The {}th input has a dimension with stride 0. gradcheck only '
|
325 |
+
'supports inputs that are non-overlapping to be able to '
|
326 |
+
'compute the numerical gradients correctly. You should call '
|
327 |
+
'.contiguous on the input before passing it to gradcheck.')
|
328 |
+
any_input_requiring_grad = True
|
329 |
+
inp.retain_grad()
|
330 |
+
if not any_input_requiring_grad:
|
331 |
+
raise ValueError(
|
332 |
+
'gradcheck expects at least one input tensor to require gradient, '
|
333 |
+
'but none of the them have requires_grad=True.')
|
334 |
+
|
335 |
+
func_out = func(*tupled_inputs)
|
336 |
+
output = _differentiable_outputs(func_out)
|
337 |
+
|
338 |
+
if not output:
|
339 |
+
for i, o in enumerate(func_out):
|
340 |
+
def fn(input):
|
341 |
+
return _as_tuple(func(*input))[i]
|
342 |
+
numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
|
343 |
+
for n in numerical:
|
344 |
+
if torch.ne(n, 0).sum() > 0:
|
345 |
+
return fail_test('Numerical gradient for function expected to be zero')
|
346 |
+
return True
|
347 |
+
|
348 |
+
for i, o in enumerate(output):
|
349 |
+
if not o.requires_grad:
|
350 |
+
continue
|
351 |
+
|
352 |
+
def fn(input):
|
353 |
+
return _as_tuple(func(*input))[i]
|
354 |
+
|
355 |
+
analytical, reentrant, correct_grad_sizes, correct_grad_types = get_analytical_jacobian(tupled_inputs,
|
356 |
+
o,
|
357 |
+
nondet_tol=nondet_tol)
|
358 |
+
numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
|
359 |
+
|
360 |
+
return analytical, numerical
|
361 |
+
|
362 |
+
out_is_complex = o.is_complex()
|
363 |
+
|
364 |
+
if out_is_complex:
|
365 |
+
# analytical vjp with grad_out = 1.0j
|
366 |
+
analytical_with_imag_grad_out, reentrant_with_imag_grad_out, \
|
367 |
+
correct_grad_sizes_with_imag_grad_out, correct_grad_types_with_imag_grad_out \
|
368 |
+
= get_analytical_jacobian(tupled_inputs, o, nondet_tol=nondet_tol, grad_out=1j)
|
369 |
+
numerical_with_imag_grad_out = get_numerical_jacobian(fn, tupled_inputs, eps=eps, grad_out=1j)
|
370 |
+
|
371 |
+
if not correct_grad_types and check_grad_dtypes:
|
372 |
+
return fail_test('Gradient has dtype mismatch')
|
373 |
+
|
374 |
+
if out_is_complex and not correct_grad_types_with_imag_grad_out and check_grad_dtypes:
|
375 |
+
return fail_test('Gradient (calculated using complex valued grad output) has dtype mismatch')
|
376 |
+
|
377 |
+
if not correct_grad_sizes:
|
378 |
+
return fail_test('Analytical gradient has incorrect size')
|
379 |
+
|
380 |
+
if out_is_complex and not correct_grad_sizes_with_imag_grad_out:
|
381 |
+
return fail_test('Analytical gradient (calculated using complex valued grad output) has incorrect size')
|
382 |
+
|
383 |
+
def checkIfNumericalAnalyticAreClose(a, n, j, error_str=''):
|
384 |
+
if not torch.allclose(a, n, rtol, atol):
|
385 |
+
return fail_test(error_str + 'Jacobian mismatch for output %d with respect to input %d,\n'
|
386 |
+
'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
|
387 |
+
|
388 |
+
inp_tensors = iter_tensors(tupled_inputs, True)
|
389 |
+
|
390 |
+
for j, (a, n, inp) in enumerate(zip(analytical, numerical, inp_tensors)):
|
391 |
+
if a.numel() != 0 or n.numel() != 0:
|
392 |
+
if o.is_complex():
|
393 |
+
# C -> C, R -> C
|
394 |
+
a_with_imag_grad_out = analytical_with_imag_grad_out[j]
|
395 |
+
n_with_imag_grad_out = numerical_with_imag_grad_out[j]
|
396 |
+
checkIfNumericalAnalyticAreClose(a_with_imag_grad_out, n_with_imag_grad_out, j,
|
397 |
+
"Gradients failed to compare equal for grad output = 1j. ")
|
398 |
+
if inp.is_complex():
|
399 |
+
# C -> R, C -> C
|
400 |
+
checkIfNumericalAnalyticAreClose(a, n, j,
|
401 |
+
"Gradients failed to compare equal for grad output = 1. ")
|
402 |
+
else:
|
403 |
+
# R -> R, R -> C
|
404 |
+
checkIfNumericalAnalyticAreClose(a, n, j)
|
405 |
+
|
406 |
+
|
407 |
+
def not_reentrant_error(error_str=''):
|
408 |
+
error_msg = "Backward" + error_str + " is not reentrant, i.e., running backward with same \
|
409 |
+
input and grad_output multiple times gives different values, \
|
410 |
+
although analytical gradient matches numerical gradient. \
|
411 |
+
The tolerance for nondeterminism was {}.".format(nondet_tol)
|
412 |
+
return fail_test(error_msg)
|
413 |
+
|
414 |
+
if not reentrant:
|
415 |
+
return not_reentrant_error()
|
416 |
+
|
417 |
+
if out_is_complex and not reentrant_with_imag_grad_out:
|
418 |
+
return not_reentrant_error(' (calculated using complex valued grad output)')
|
419 |
+
|
420 |
+
# check if the backward multiplies by grad_output
|
421 |
+
output = _differentiable_outputs(func(*tupled_inputs))
|
422 |
+
if any([o.requires_grad for o in output]):
|
423 |
+
diff_input_list: List[torch.Tensor] = list(iter_tensors(tupled_inputs, True))
|
424 |
+
if not diff_input_list:
|
425 |
+
raise RuntimeError("no Tensors requiring grad found in input")
|
426 |
+
grads_input = torch.autograd.grad(output, diff_input_list,
|
427 |
+
[torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) for o in output],
|
428 |
+
allow_unused=True)
|
429 |
+
for gi, di in zip(grads_input, diff_input_list):
|
430 |
+
if gi is None:
|
431 |
+
continue
|
432 |
+
if isinstance(gi, torch.Tensor) and gi.layout != torch.strided:
|
433 |
+
if gi.layout != di.layout:
|
434 |
+
return fail_test('grad is incorrect layout (' + str(gi.layout) + ' is not ' + str(di.layout) + ')')
|
435 |
+
if gi.layout == torch.sparse_coo:
|
436 |
+
if gi.sparse_dim() != di.sparse_dim():
|
437 |
+
return fail_test('grad is sparse tensor, but has incorrect sparse_dim')
|
438 |
+
if gi.dense_dim() != di.dense_dim():
|
439 |
+
return fail_test('grad is sparse tensor, but has incorrect dense_dim')
|
440 |
+
gi = gi.to_dense()
|
441 |
+
di = di.to_dense()
|
442 |
+
if not gi.eq(0).all():
|
443 |
+
return fail_test('backward not multiplied by grad_output')
|
444 |
+
if gi.dtype != di.dtype or gi.device != di.device or gi.is_sparse != di.is_sparse:
|
445 |
+
return fail_test("grad is incorrect type")
|
446 |
+
if gi.size() != di.size():
|
447 |
+
return fail_test('grad is incorrect size')
|
448 |
+
|
449 |
+
if check_undefined_grad:
|
450 |
+
def warn_bc_breaking():
|
451 |
+
warnings.warn((
|
452 |
+
'Backwards compatibility: New undefined gradient support checking '
|
453 |
+
'feature is enabled by default, but it may break existing callers '
|
454 |
+
'of this function. If this is true for you, you can call this '
|
455 |
+
'function with "check_undefined_grad=False" to disable the feature'))
|
456 |
+
|
457 |
+
def check_undefined_grad_support(output_to_check):
|
458 |
+
grads_output = [torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) for o in output_to_check]
|
459 |
+
try:
|
460 |
+
grads_input = torch.autograd.grad(output_to_check,
|
461 |
+
diff_input_list,
|
462 |
+
grads_output,
|
463 |
+
allow_unused=True)
|
464 |
+
except RuntimeError:
|
465 |
+
warn_bc_breaking()
|
466 |
+
return fail_test((
|
467 |
+
'Expected backward function to handle undefined output grads. '
|
468 |
+
'Please look at "Notes about undefined output gradients" in '
|
469 |
+
'"tools/autograd/derivatives.yaml"'))
|
470 |
+
|
471 |
+
for gi, i in zip(grads_input, diff_input_list):
|
472 |
+
if (gi is not None) and (not gi.eq(0).all()):
|
473 |
+
warn_bc_breaking()
|
474 |
+
return fail_test((
|
475 |
+
'Expected all input grads to be undefined or zero when all output grads are undefined '
|
476 |
+
'or zero. Please look at "Notes about undefined output gradients" in '
|
477 |
+
'"tools/autograd/derivatives.yaml"'))
|
478 |
+
return True
|
479 |
+
|
480 |
+
# All backward functions must work properly if all output grads are undefined
|
481 |
+
outputs_to_check = [[
|
482 |
+
torch._C._functions.UndefinedGrad()(o) for o in _differentiable_outputs(func(*tupled_inputs))
|
483 |
+
# This check filters out Tensor-likes that aren't instances of Tensor.
|
484 |
+
if isinstance(o, torch.Tensor)
|
485 |
+
]]
|
486 |
+
|
487 |
+
# If there are multiple output grads, we should be able to undef one at a time without error
|
488 |
+
if len(outputs_to_check[0]) > 1:
|
489 |
+
for undef_grad_idx in range(len(output)):
|
490 |
+
output_to_check = _differentiable_outputs(func(*tupled_inputs))
|
491 |
+
outputs_to_check.append([
|
492 |
+
torch._C._functions.UndefinedGrad()(o) if idx == undef_grad_idx else o
|
493 |
+
for idx, o in enumerate(output_to_check)])
|
494 |
+
|
495 |
+
for output_to_check in outputs_to_check:
|
496 |
+
if not check_undefined_grad_support(output_to_check):
|
497 |
+
return False
|
498 |
+
|
499 |
+
return True
|
500 |
+
|
501 |
+
|
502 |
+
def gradgradcheck(
|
503 |
+
func: Callable[..., _TensorOrTensors], # See Note [VarArg of Tensors]
|
504 |
+
inputs: _TensorOrTensors,
|
505 |
+
grad_outputs: Optional[_TensorOrTensors] = None,
|
506 |
+
eps: float = 1e-6,
|
507 |
+
atol: float = 1e-5,
|
508 |
+
rtol: float = 1e-3,
|
509 |
+
gen_non_contig_grad_outputs: bool = False,
|
510 |
+
raise_exception: bool = True,
|
511 |
+
nondet_tol: float = 0.0,
|
512 |
+
check_undefined_grad: bool = True,
|
513 |
+
check_grad_dtypes: bool = False
|
514 |
+
) -> bool:
|
515 |
+
r"""Check gradients of gradients computed via small finite differences
|
516 |
+
against analytical gradients w.r.t. tensors in :attr:`inputs` and
|
517 |
+
:attr:`grad_outputs` that are of floating point or complex type and with
|
518 |
+
``requires_grad=True``.
|
519 |
+
|
520 |
+
This function checks that backpropagating through the gradients computed
|
521 |
+
to the given :attr:`grad_outputs` are correct.
|
522 |
+
|
523 |
+
The check between numerical and analytical gradients uses :func:`~torch.allclose`.
|
524 |
+
|
525 |
+
.. note::
|
526 |
+
The default values are designed for :attr:`input` and
|
527 |
+
:attr:`grad_outputs` of double precision. This check will likely fail if
|
528 |
+
they are of less precision, e.g., ``FloatTensor``.
|
529 |
+
|
530 |
+
.. warning::
|
531 |
+
If any checked tensor in :attr:`input` and :attr:`grad_outputs` has
|
532 |
+
overlapping memory, i.e., different indices pointing to the same memory
|
533 |
+
address (e.g., from :func:`torch.expand`), this check will likely fail
|
534 |
+
because the numerical gradients computed by point perturbation at such
|
535 |
+
indices will change values at all other indices that share the same
|
536 |
+
memory address.
|
537 |
+
|
538 |
+
Args:
|
539 |
+
func (function): a Python function that takes Tensor inputs and returns
|
540 |
+
a Tensor or a tuple of Tensors
|
541 |
+
inputs (tuple of Tensor or Tensor): inputs to the function
|
542 |
+
grad_outputs (tuple of Tensor or Tensor, optional): The gradients with
|
543 |
+
respect to the function's outputs.
|
544 |
+
eps (float, optional): perturbation for finite differences
|
545 |
+
atol (float, optional): absolute tolerance
|
546 |
+
rtol (float, optional): relative tolerance
|
547 |
+
gen_non_contig_grad_outputs (bool, optional): if :attr:`grad_outputs` is
|
548 |
+
``None`` and :attr:`gen_non_contig_grad_outputs` is ``True``, the
|
549 |
+
randomly generated gradient outputs are made to be noncontiguous
|
550 |
+
raise_exception (bool, optional): indicating whether to raise an exception if
|
551 |
+
the check fails. The exception gives more information about the
|
552 |
+
exact nature of the failure. This is helpful when debugging gradchecks.
|
553 |
+
nondet_tol (float, optional): tolerance for non-determinism. When running
|
554 |
+
identical inputs through the differentiation, the results must either match
|
555 |
+
exactly (default, 0.0) or be within this tolerance. Note that a small amount
|
556 |
+
of nondeterminism in the gradient will lead to larger inaccuracies in
|
557 |
+
the second derivative.
|
558 |
+
check_undefined_grad (bool, options): if True, check if undefined output grads
|
559 |
+
are supported and treated as zeros
|
560 |
+
|
561 |
+
Returns:
|
562 |
+
True if all differences satisfy allclose condition
|
563 |
+
"""
|
564 |
+
tupled_inputs = _as_tuple(inputs)
|
565 |
+
|
566 |
+
if grad_outputs is None:
|
567 |
+
# If grad_outputs is not specified, create random Tensors of the same
|
568 |
+
# shape, type, and device as the outputs
|
569 |
+
def randn_like(x):
|
570 |
+
y = torch.testing.randn_like(
|
571 |
+
x if (x.is_floating_point() or x.is_complex()) else x.double(), memory_format=torch.legacy_contiguous_format)
|
572 |
+
if gen_non_contig_grad_outputs:
|
573 |
+
y = torch.testing.make_non_contiguous(y)
|
574 |
+
return y.requires_grad_()
|
575 |
+
outputs = _as_tuple(func(*tupled_inputs))
|
576 |
+
tupled_grad_outputs = tuple(randn_like(x) for x in outputs)
|
577 |
+
else:
|
578 |
+
tupled_grad_outputs = _as_tuple(grad_outputs)
|
579 |
+
|
580 |
+
num_outputs = len(tupled_grad_outputs)
|
581 |
+
|
582 |
+
def new_func(*args):
|
583 |
+
input_args = args[:-num_outputs]
|
584 |
+
grad_outputs = args[-num_outputs:]
|
585 |
+
outputs = _differentiable_outputs(func(*input_args))
|
586 |
+
input_args = tuple(x for x in input_args if isinstance(x, torch.Tensor) and x.requires_grad)
|
587 |
+
grad_inputs = torch.autograd.grad(outputs, input_args, grad_outputs, create_graph=True)
|
588 |
+
return grad_inputs
|
589 |
+
|
590 |
+
return gradcheck(new_func, tupled_inputs + tupled_grad_outputs, eps, atol, rtol, raise_exception,
|
591 |
+
nondet_tol=nondet_tol, check_undefined_grad=check_undefined_grad,
|
592 |
+
check_grad_dtypes=check_grad_dtypes)
|
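Note that this copy of gradcheck is modified: the `return analytical, numerical` inside the output loop hands back the raw Jacobians instead of a boolean, so callers compare them explicitly. A minimal hedged example with an ordinary tensor function (double precision inputs, as the docstring recommends):

import torch
from mini_dpvo.lietorch.gradcheck import gradcheck

def fn(x):
    return (x ** 3).sum()

x = torch.randn(5, dtype=torch.float64, requires_grad=True)
analytical, numerical = gradcheck(fn, [x], eps=1e-6)

print(torch.allclose(analytical[0], numerical[0], atol=1e-5, rtol=1e-3))  # True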
mini_dpvo/lietorch/group_ops.py
ADDED
@@ -0,0 +1,102 @@
1 |
+
import lietorch_backends
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
|
5 |
+
|
6 |
+
|
7 |
+
class GroupOp(torch.autograd.Function):
|
8 |
+
""" group operation base class """
|
9 |
+
|
10 |
+
@classmethod
|
11 |
+
def forward(cls, ctx, group_id, *inputs):
|
12 |
+
ctx.group_id = group_id
|
13 |
+
ctx.save_for_backward(*inputs)
|
14 |
+
out = cls.forward_op(ctx.group_id, *inputs)
|
15 |
+
return out
|
16 |
+
|
17 |
+
@classmethod
|
18 |
+
def backward(cls, ctx, grad):
|
19 |
+
error_str = "Backward operation not implemented for {}".format(cls)
|
20 |
+
assert cls.backward_op is not None, error_str
|
21 |
+
|
22 |
+
inputs = ctx.saved_tensors
|
23 |
+
grad = grad.contiguous()
|
24 |
+
grad_inputs = cls.backward_op(ctx.group_id, grad, *inputs)
|
25 |
+
return (None, ) + tuple(grad_inputs)
|
26 |
+
|
27 |
+
|
28 |
+
class Exp(GroupOp):
|
29 |
+
""" exponential map """
|
30 |
+
forward_op, backward_op = lietorch_backends.expm, lietorch_backends.expm_backward
|
31 |
+
|
32 |
+
class Log(GroupOp):
|
33 |
+
""" logarithm map """
|
34 |
+
forward_op, backward_op = lietorch_backends.logm, lietorch_backends.logm_backward
|
35 |
+
|
36 |
+
class Inv(GroupOp):
|
37 |
+
""" group inverse """
|
38 |
+
forward_op, backward_op = lietorch_backends.inv, lietorch_backends.inv_backward
|
39 |
+
|
40 |
+
class Mul(GroupOp):
|
41 |
+
""" group multiplication """
|
42 |
+
forward_op, backward_op = lietorch_backends.mul, lietorch_backends.mul_backward
|
43 |
+
|
44 |
+
class Adj(GroupOp):
|
45 |
+
""" adjoint operator """
|
46 |
+
forward_op, backward_op = lietorch_backends.adj, lietorch_backends.adj_backward
|
47 |
+
|
48 |
+
class AdjT(GroupOp):
|
49 |
+
""" adjoint operator """
|
50 |
+
forward_op, backward_op = lietorch_backends.adjT, lietorch_backends.adjT_backward
|
51 |
+
|
52 |
+
class Act3(GroupOp):
|
53 |
+
""" action on point """
|
54 |
+
forward_op, backward_op = lietorch_backends.act, lietorch_backends.act_backward
|
55 |
+
|
56 |
+
class Act4(GroupOp):
|
57 |
+
""" action on point """
|
58 |
+
forward_op, backward_op = lietorch_backends.act4, lietorch_backends.act4_backward
|
59 |
+
|
60 |
+
class Jinv(GroupOp):
|
61 |
+
""" adjoint operator """
|
62 |
+
forward_op, backward_op = lietorch_backends.Jinv, None
|
63 |
+
|
64 |
+
class ToMatrix(GroupOp):
|
65 |
+
""" convert to matrix representation """
|
66 |
+
forward_op, backward_op = lietorch_backends.as_matrix, None
|
67 |
+
|
68 |
+
|
69 |
+
|
70 |
+
|
71 |
+
### conversion operations to/from Euclidean embeddings ###
|
72 |
+
|
73 |
+
class FromVec(torch.autograd.Function):
|
74 |
+
""" convert vector into group object """
|
75 |
+
|
76 |
+
@classmethod
|
77 |
+
def forward(cls, ctx, group_id, *inputs):
|
78 |
+
ctx.group_id = group_id
|
79 |
+
ctx.save_for_backward(*inputs)
|
80 |
+
return inputs[0]
|
81 |
+
|
82 |
+
@classmethod
|
83 |
+
def backward(cls, ctx, grad):
|
84 |
+
inputs = ctx.saved_tensors
|
85 |
+
J = lietorch_backends.projector(ctx.group_id, *inputs)
|
86 |
+
return None, torch.matmul(grad.unsqueeze(-2), torch.linalg.pinv(J)).squeeze(-2)
|
87 |
+
|
88 |
+
class ToVec(torch.autograd.Function):
|
89 |
+
""" convert group object to vector """
|
90 |
+
|
91 |
+
@classmethod
|
92 |
+
def forward(cls, ctx, group_id, *inputs):
|
93 |
+
ctx.group_id = group_id
|
94 |
+
ctx.save_for_backward(*inputs)
|
95 |
+
return inputs[0]
|
96 |
+
|
97 |
+
@classmethod
|
98 |
+
def backward(cls, ctx, grad):
|
99 |
+
inputs = ctx.saved_tensors
|
100 |
+
J = lietorch_backends.projector(ctx.group_id, *inputs)
|
101 |
+
return None, torch.matmul(grad.unsqueeze(-2), J).squeeze(-2)
|
102 |
+
|
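Each GroupOp subclass above simply binds a forward/backward kernel pair from the compiled lietorch_backends extension into a torch.autograd.Function; groups.py below drives them through LieGroup.apply_op rather than directly. A minimal sketch of a direct call (assuming the extension is built, a CUDA device is available, and that the backend accepts a 2-D batch of tangent vectors; group id 3 selects SE3 in the C++ dispatch):

import torch
from mini_dpvo.lietorch.group_ops import Exp

xi = torch.zeros(8, 6, device="cuda", dtype=torch.float32)  # SE3 tangent vectors
X = Exp.apply(3, xi)  # 7-dim embeddings (tx, ty, tz, qx, qy, qz, qw); zeros map to identity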
mini_dpvo/lietorch/groups.py
ADDED
@@ -0,0 +1,322 @@
import torch
import numpy as np

# group operations implemented in cuda
from .group_ops import Exp, Log, Inv, Mul, Adj, AdjT, Jinv, Act3, Act4, ToMatrix, ToVec, FromVec
from .broadcasting import broadcast_inputs


class LieGroupParameter(torch.Tensor):
    """ Wrapper class for LieGroup """

    from torch._C import _disabled_torch_function_impl
    __torch_function__ = _disabled_torch_function_impl

    def __new__(cls, group, requires_grad=True):
        data = torch.zeros(group.tangent_shape,
            device=group.data.device,
            dtype=group.data.dtype,
            requires_grad=True)

        return torch.Tensor._make_subclass(cls, data, requires_grad)

    def __init__(self, group):
        self.group = group

    def retr(self):
        return self.group.retr(self)

    def log(self):
        return self.retr().log()

    def inv(self):
        return self.retr().inv()

    def adj(self, a):
        return self.retr().adj(a)

    def __mul__(self, other):
        if isinstance(other, LieGroupParameter):
            return self.retr() * other.retr()
        else:
            return self.retr() * other

    def add_(self, update, alpha):
        self.group = self.group.exp(alpha*update) * self.group

    def __getitem__(self, index):
        return self.retr().__getitem__(index)


class LieGroup:
    """ Base class for Lie Group """

    def __init__(self, data):
        self.data = data

    def __repr__(self):
        return "{}: size={}, device={}, dtype={}".format(
            self.group_name, self.shape, self.device, self.dtype)

    @property
    def shape(self):
        return self.data.shape[:-1]

    @property
    def device(self):
        return self.data.device

    @property
    def dtype(self):
        return self.data.dtype

    def vec(self):
        return self.apply_op(ToVec, self.data)

    @property
    def tangent_shape(self):
        return self.data.shape[:-1] + (self.manifold_dim,)

    @classmethod
    def Identity(cls, *batch_shape, **kwargs):
        """ Construct identity element with batch shape """

        if isinstance(batch_shape[0], tuple):
            batch_shape = batch_shape[0]

        elif isinstance(batch_shape[0], list):
            batch_shape = tuple(batch_shape[0])

        numel = np.prod(batch_shape)
        data = cls.id_elem.reshape(1,-1)

        if 'device' in kwargs:
            data = data.to(kwargs['device'])

        if 'dtype' in kwargs:
            data = data.type(kwargs['dtype'])

        data = data.repeat(numel, 1)
        return cls(data).view(batch_shape)

    @classmethod
    def IdentityLike(cls, G):
        return cls.Identity(G.shape, device=G.data.device, dtype=G.data.dtype)

    @classmethod
    def InitFromVec(cls, data):
        return cls(cls.apply_op(FromVec, data))

    @classmethod
    def Random(cls, *batch_shape, sigma=1.0, **kwargs):
        """ Construct random element with batch_shape by random sampling in tangent space"""

        if isinstance(batch_shape[0], tuple):
            batch_shape = batch_shape[0]

        elif isinstance(batch_shape[0], list):
            batch_shape = tuple(batch_shape[0])

        tangent_shape = batch_shape + (cls.manifold_dim,)
        xi = torch.randn(tangent_shape, **kwargs)
        return cls.exp(sigma * xi)

    @classmethod
    def apply_op(cls, op, x, y=None):
        """ Apply group operator """
        inputs, out_shape = broadcast_inputs(x, y)

        data = op.apply(cls.group_id, *inputs)
        return data.view(out_shape + (-1,))

    @classmethod
    def exp(cls, x):
        """ exponential map: x -> X """
        return cls(cls.apply_op(Exp, x))

    def quaternion(self):
        """ extract quaternion """
        return self.apply_op(Quat, self.data)

    def log(self):
        """ logarithm map """
        return self.apply_op(Log, self.data)

    def inv(self):
        """ group inverse """
        return self.__class__(self.apply_op(Inv, self.data))

    def mul(self, other):
        """ group multiplication """
        return self.__class__(self.apply_op(Mul, self.data, other.data))

    def retr(self, a):
        """ retraction: Exp(a) * X """
        dX = self.__class__.apply_op(Exp, a)
        return self.__class__(self.apply_op(Mul, dX, self.data))

    def adj(self, a):
        """ adjoint operator: b = A(X) * a """
        return self.apply_op(Adj, self.data, a)

    def adjT(self, a):
        """ transposed adjoint operator: b = a * A(X) """
        return self.apply_op(AdjT, self.data, a)

    def Jinv(self, a):
        return self.apply_op(Jinv, self.data, a)

    def act(self, p):
        """ action on a point cloud """

        # action on point
        if p.shape[-1] == 3:
            return self.apply_op(Act3, self.data, p)

        # action on homogeneous point
        elif p.shape[-1] == 4:
            return self.apply_op(Act4, self.data, p)

    def matrix(self):
        """ convert element to 4x4 matrix """
        I = torch.eye(4, dtype=self.dtype, device=self.device)
        I = I.view([1] * (len(self.data.shape) - 1) + [4, 4])
        return self.__class__(self.data[...,None,:]).act(I).transpose(-1,-2)

    def translation(self):
        """ extract translation component """
        p = torch.as_tensor([0.0, 0.0, 0.0, 1.0], dtype=self.dtype, device=self.device)
        p = p.view([1] * (len(self.data.shape) - 1) + [4,])
        return self.apply_op(Act4, self.data, p)

    def detach(self):
        return self.__class__(self.data.detach())

    def view(self, dims):
        data_reshaped = self.data.view(dims + (self.embedded_dim,))
        return self.__class__(data_reshaped)

    def __mul__(self, other):
        # group multiplication
        if isinstance(other, LieGroup):
            return self.mul(other)

        # action on point
        elif isinstance(other, torch.Tensor):
            return self.act(other)

    def __getitem__(self, index):
        return self.__class__(self.data[index])

    def __setitem__(self, index, item):
        self.data[index] = item.data

    def to(self, *args, **kwargs):
        return self.__class__(self.data.to(*args, **kwargs))

    def cpu(self):
        return self.__class__(self.data.cpu())

    def cuda(self):
        return self.__class__(self.data.cuda())

    def float(self, device):
        return self.__class__(self.data.float())

    def double(self, device):
        return self.__class__(self.data.double())

    def unbind(self, dim=0):
        return [self.__class__(x) for x in self.data.unbind(dim=dim)]


class SO3(LieGroup):
    group_name = 'SO3'
    group_id = 1
    manifold_dim = 3
    embedded_dim = 4

    # unit quaternion
    id_elem = torch.as_tensor([0.0, 0.0, 0.0, 1.0])

    def __init__(self, data):
        if isinstance(data, SE3):
            data = data.data[..., 3:7]

        super(SO3, self).__init__(data)


class RxSO3(LieGroup):
    group_name = 'RxSO3'
    group_id = 2
    manifold_dim = 4
    embedded_dim = 5

    # unit quaternion, scale
    id_elem = torch.as_tensor([0.0, 0.0, 0.0, 1.0, 1.0])

    def __init__(self, data):
        if isinstance(data, Sim3):
            data = data.data[..., 3:8]

        super(RxSO3, self).__init__(data)


class SE3(LieGroup):
    group_name = 'SE3'
    group_id = 3
    manifold_dim = 6
    embedded_dim = 7

    # translation, unit quaternion
    id_elem = torch.as_tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])

    def __init__(self, data):
        if isinstance(data, SO3):
            translation = torch.zeros_like(data.data[...,:3])
            data = torch.cat([translation, data.data], -1)

        super(SE3, self).__init__(data)

    def scale(self, s):
        t, q = self.data.split([3,4], -1)
        t = t * s.unsqueeze(-1)
        return SE3(torch.cat([t, q], dim=-1))


class Sim3(LieGroup):
    group_name = 'Sim3'
    group_id = 4
    manifold_dim = 7
    embedded_dim = 8

    # translation, unit quaternion, scale
    id_elem = torch.as_tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0])

    def __init__(self, data):

        if isinstance(data, SO3):
            scale = torch.ones_like(data.data[...,:1])
            translation = torch.zeros_like(data.data[...,:3])
            data = torch.cat([translation, data.data, scale], -1)

        elif isinstance(data, SE3):
            scale = torch.ones_like(data.data[...,:1])
            data = torch.cat([data.data, scale], -1)

        elif isinstance(data, Sim3):
            data = data.data

        super(Sim3, self).__init__(data)


def cat(group_list, dim):
    """ Concatenate groups along dimension """
    data = torch.cat([X.data for X in group_list], dim=dim)
    return group_list[0].__class__(data)

def stack(group_list, dim):
    """ Stack groups along dimension """
    data = torch.stack([X.data for X in group_list], dim=dim)
    return group_list[0].__class__(data)
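Taken together, the classes above give a tensor-like SE3/Sim3 API over the CUDA ops. A short usage sketch (a minimal illustration, assuming the compiled lietorch_backends extension and a CUDA device; tensor values are illustrative):

import torch
from mini_dpvo.lietorch.groups import SE3

# one pose as a 7-vector: translation followed by unit quaternion (x, y, z, w)
vec = torch.tensor([[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]], device="cuda")
T = SE3.InitFromVec(vec)                       # batch of one identity pose

pts = torch.rand(1, 3, device="cuda")
same_pts = T * pts                             # __mul__ on a tensor applies act()
dT = SE3.exp(0.01 * torch.randn(1, 6, device="cuda"))
T2 = dT * T                                    # group multiplication
mat = T2.matrix()                              # (1, 4, 4) homogeneous matrix
tau = T2.log()                                 # back to the 6-dof tangent vector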
mini_dpvo/lietorch/include/common.h
ADDED
@@ -0,0 +1,12 @@
#ifndef COMMON_H
#define COMMON_H

#define EIGEN_DEFAULT_DENSE_INDEX_TYPE int
#define EIGEN_RUNTIME_NO_MALLOC

#define EPS 1e-6
#define PI 3.14159265358979323846


#endif

mini_dpvo/lietorch/include/dispatch.h
ADDED
@@ -0,0 +1,48 @@
#ifndef DISPATCH_H
#define DISPATCH_H

#include <torch/extension.h>

#include "so3.h"
#include "rxso3.h"
#include "se3.h"
#include "sim3.h"


#define PRIVATE_CASE_TYPE(group_index, enum_type, type, ...) \
  case enum_type: { \
    using scalar_t = type; \
    switch (group_index) { \
      case 1: { \
        using group_t = SO3<type>; \
        return __VA_ARGS__(); \
      } \
      case 2: { \
        using group_t = RxSO3<type>; \
        return __VA_ARGS__(); \
      } \
      case 3: { \
        using group_t = SE3<type>; \
        return __VA_ARGS__(); \
      } \
      case 4: { \
        using group_t = Sim3<type>; \
        return __VA_ARGS__(); \
      } \
    } \
  } \

#define DISPATCH_GROUP_AND_FLOATING_TYPES(GROUP_INDEX, TYPE, NAME, ...) \
  [&] { \
    const auto& the_type = TYPE; \
    /* don't use TYPE again in case it is an expensive or side-effect op */ \
    at::ScalarType _st = ::detail::scalar_type(the_type); \
    switch (_st) { \
      PRIVATE_CASE_TYPE(GROUP_INDEX, at::ScalarType::Double, double, __VA_ARGS__) \
      PRIVATE_CASE_TYPE(GROUP_INDEX, at::ScalarType::Float, float, __VA_ARGS__) \
      default: break; \
    } \
  }()

#endif

mini_dpvo/lietorch/include/lietorch_cpu.h
ADDED
@@ -0,0 +1,51 @@

#ifndef LIETORCH_CPU_H_
#define LIETORCH_CPU_H_

#include <vector>
#include <torch/extension.h>


// unary operations
torch::Tensor exp_forward_cpu(int, torch::Tensor);
std::vector<torch::Tensor> exp_backward_cpu(int, torch::Tensor, torch::Tensor);

torch::Tensor log_forward_cpu(int, torch::Tensor);
std::vector<torch::Tensor> log_backward_cpu(int, torch::Tensor, torch::Tensor);

torch::Tensor inv_forward_cpu(int, torch::Tensor);
std::vector<torch::Tensor> inv_backward_cpu(int, torch::Tensor, torch::Tensor);

// binary operations
torch::Tensor mul_forward_cpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> mul_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);

torch::Tensor adj_forward_cpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> adj_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);

torch::Tensor adjT_forward_cpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> adjT_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);

torch::Tensor act_forward_cpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> act_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);

torch::Tensor act4_forward_cpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> act4_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);


// conversion operations
// std::vector<torch::Tensor> to_vec_backward_cpu(int, torch::Tensor, torch::Tensor);
// std::vector<torch::Tensor> from_vec_backward_cpu(int, torch::Tensor, torch::Tensor);

// utility operations
torch::Tensor orthogonal_projector_cpu(int, torch::Tensor);

torch::Tensor as_matrix_forward_cpu(int, torch::Tensor);

torch::Tensor jleft_forward_cpu(int, torch::Tensor, torch::Tensor);


#endif


mini_dpvo/lietorch/include/lietorch_gpu.h
ADDED
@@ -0,0 +1,51 @@

#ifndef LIETORCH_GPU_H_
#define LIETORCH_GPU_H_

#include <vector>
#include <torch/extension.h>
#include <cuda.h>
#include <cuda_runtime.h>


// unary operations
torch::Tensor exp_forward_gpu(int, torch::Tensor);
std::vector<torch::Tensor> exp_backward_gpu(int, torch::Tensor, torch::Tensor);

torch::Tensor log_forward_gpu(int, torch::Tensor);
std::vector<torch::Tensor> log_backward_gpu(int, torch::Tensor, torch::Tensor);

torch::Tensor inv_forward_gpu(int, torch::Tensor);
std::vector<torch::Tensor> inv_backward_gpu(int, torch::Tensor, torch::Tensor);

// binary operations
torch::Tensor mul_forward_gpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> mul_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor);

torch::Tensor adj_forward_gpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> adj_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor);

torch::Tensor adjT_forward_gpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> adjT_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor);

torch::Tensor act_forward_gpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> act_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor);

torch::Tensor act4_forward_gpu(int, torch::Tensor, torch::Tensor);
std::vector<torch::Tensor> act4_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor);

// conversion operations
// std::vector<torch::Tensor> to_vec_backward_gpu(int, torch::Tensor, torch::Tensor);
// std::vector<torch::Tensor> from_vec_backward_gpu(int, torch::Tensor, torch::Tensor);

// utility operators
torch::Tensor orthogonal_projector_gpu(int, torch::Tensor);

torch::Tensor as_matrix_forward_gpu(int, torch::Tensor);

torch::Tensor jleft_forward_gpu(int, torch::Tensor, torch::Tensor);

#endif


mini_dpvo/lietorch/include/rxso3.h
ADDED
@@ -0,0 +1,324 @@

#ifndef RxSO3_HEADER
#define RxSO3_HEADER

#include <stdio.h>
#include <Eigen/Dense>
#include <Eigen/Geometry>

#include "common.h"

template <typename Scalar>
class RxSO3 {
  public:
    const static int constexpr K = 4; // manifold dimension
    const static int constexpr N = 5; // embedding dimension

    using Vector3 = Eigen::Matrix<Scalar,3,1>;
    using Vector4 = Eigen::Matrix<Scalar,4,1>;
    using Matrix3 = Eigen::Matrix<Scalar,3,3>;

    using Tangent = Eigen::Matrix<Scalar,K,1>;
    using Data = Eigen::Matrix<Scalar,N,1>;

    using Point = Eigen::Matrix<Scalar,3,1>;
    using Point4 = Eigen::Matrix<Scalar,4,1>;

    using Quaternion = Eigen::Quaternion<Scalar>;
    using Transformation = Eigen::Matrix<Scalar,3,3>;
    using Adjoint = Eigen::Matrix<Scalar,4,4>;


    EIGEN_DEVICE_FUNC RxSO3(Quaternion const& q, Scalar const s)
        : unit_quaternion(q), scale(s) {
      unit_quaternion.normalize();
    };

    EIGEN_DEVICE_FUNC RxSO3(const Scalar *data) : unit_quaternion(data), scale(data[4]) {
      unit_quaternion.normalize();
    };

    EIGEN_DEVICE_FUNC RxSO3() {
      unit_quaternion = Quaternion::Identity();
      scale = Scalar(1.0);
    }

    EIGEN_DEVICE_FUNC RxSO3<Scalar> inv() {
      return RxSO3<Scalar>(unit_quaternion.conjugate(), 1.0/scale);
    }

    EIGEN_DEVICE_FUNC Data data() const {
      Data data_vec; data_vec << unit_quaternion.coeffs(), scale;
      return data_vec;
    }

    EIGEN_DEVICE_FUNC RxSO3<Scalar> operator*(RxSO3<Scalar> const& other) {
      return RxSO3<Scalar>(unit_quaternion * other.unit_quaternion, scale * other.scale);
    }

    EIGEN_DEVICE_FUNC Point operator*(Point const& p) const {
      const Quaternion& q = unit_quaternion;
      Point uv = q.vec().cross(p); uv += uv;
      return scale * (p + q.w()*uv + q.vec().cross(uv));
    }

    EIGEN_DEVICE_FUNC Point4 act4(Point4 const& p) const {
      Point4 p1; p1 << this->operator*(p.template segment<3>(0)), p(3);
      return p1;
    }

    EIGEN_DEVICE_FUNC Adjoint Adj() const {
      Adjoint Ad = Adjoint::Identity();
      Ad.template block<3,3>(0,0) = unit_quaternion.toRotationMatrix();
      return Ad;
    }

    EIGEN_DEVICE_FUNC Transformation Matrix() const {
      return scale * unit_quaternion.toRotationMatrix();
    }

    EIGEN_DEVICE_FUNC Eigen::Matrix<Scalar,4,4> Matrix4x4() const {
      Eigen::Matrix<Scalar,4,4> T;
      T = Eigen::Matrix<Scalar,4,4>::Identity();
      T.template block<3,3>(0,0) = Matrix();
      return T;
    }

    EIGEN_DEVICE_FUNC Eigen::Matrix<Scalar,5,5> orthogonal_projector() const {
      // jacobian action on a point
      Eigen::Matrix<Scalar,5,5> J = Eigen::Matrix<Scalar,5,5>::Zero();

      J.template block<3,3>(0,0) = 0.5 * (
          unit_quaternion.w() * Matrix3::Identity() +
          SO3<Scalar>::hat(-unit_quaternion.vec())
        );

      J.template block<1,3>(3,0) = 0.5 * (-unit_quaternion.vec());

      // scale
      J(4,3) = scale;

      return J;
    }

    EIGEN_DEVICE_FUNC Transformation Rotation() const {
      return unit_quaternion.toRotationMatrix();
    }

    EIGEN_DEVICE_FUNC Tangent Adj(Tangent const& a) const {
      return Adj() * a;
    }

    EIGEN_DEVICE_FUNC Tangent AdjT(Tangent const& a) const {
      return Adj().transpose() * a;
    }

    EIGEN_DEVICE_FUNC static Transformation hat(Tangent const& phi_sigma) {
      Vector3 const phi = phi_sigma.template segment<3>(0);
      // sigma is the fourth (scale) component of the tangent vector
      return SO3<Scalar>::hat(phi) + phi_sigma(3) * Transformation::Identity();
    }

    EIGEN_DEVICE_FUNC static Adjoint adj(Tangent const& phi_sigma) {
      Vector3 const phi = phi_sigma.template segment<3>(0);
      Matrix3 const Phi = SO3<Scalar>::hat(phi);

      Adjoint ad = Adjoint::Zero();
      ad.template block<3,3>(0,0) = Phi;

      return ad;
    }

    EIGEN_DEVICE_FUNC Tangent Log() const {
      using std::abs;
      using std::atan;
      using std::sqrt;

      Scalar squared_n = unit_quaternion.vec().squaredNorm();
      Scalar w = unit_quaternion.w();
      Scalar two_atan_nbyw_by_n;

      /// Atan-based log thanks to
      ///
      /// C. Hertzberg et al.:
      /// "Integrating Generic Sensor Fusion Algorithms with Sound State
      /// Representation through Encapsulation of Manifolds"
      /// Information Fusion, 2011

      if (squared_n < EPS * EPS) {
        two_atan_nbyw_by_n = Scalar(2) / w - Scalar(2.0/3.0) * (squared_n) / (w * w * w);
      } else {
        Scalar n = sqrt(squared_n);
        if (abs(w) < EPS) {
          if (w > Scalar(0)) {
            two_atan_nbyw_by_n = PI / n;
          } else {
            two_atan_nbyw_by_n = -PI / n;
          }
        } else {
          two_atan_nbyw_by_n = Scalar(2) * atan(n / w) / n;
        }
      }

      Tangent phi_sigma;
      phi_sigma << two_atan_nbyw_by_n * unit_quaternion.vec(), log(scale);

      return phi_sigma;
    }

    EIGEN_DEVICE_FUNC static RxSO3<Scalar> Exp(Tangent const& phi_sigma) {
      Vector3 phi = phi_sigma.template segment<3>(0);
      Scalar scale = exp(phi_sigma(3));

      Scalar theta2 = phi.squaredNorm();
      Scalar theta = sqrt(theta2);
      Scalar imag_factor;
      Scalar real_factor;

      if (theta < EPS) {
        Scalar theta4 = theta2 * theta2;
        imag_factor = Scalar(0.5) - Scalar(1.0/48.0) * theta2 + Scalar(1.0/3840.0) * theta4;
        real_factor = Scalar(1) - Scalar(1.0/8.0) * theta2 + Scalar(1.0/384.0) * theta4;
      } else {
        imag_factor = sin(.5 * theta) / theta;
        real_factor = cos(.5 * theta);
      }

      Quaternion q(real_factor, imag_factor*phi.x(), imag_factor*phi.y(), imag_factor*phi.z());
      return RxSO3<Scalar>(q, scale);
    }

    EIGEN_DEVICE_FUNC static Matrix3 calcW(Tangent const& phi_sigma) {
      // left jacobian
      using std::abs;
      Matrix3 const I = Matrix3::Identity();
      Scalar const one(1);
      Scalar const half(0.5);

      Vector3 const phi = phi_sigma.template segment<3>(0);
      Scalar const sigma = phi_sigma(3);
      Scalar const theta = phi.norm();

      Matrix3 const Phi = SO3<Scalar>::hat(phi);
      Matrix3 const Phi2 = Phi * Phi;
      Scalar const scale = exp(sigma);

      Scalar A, B, C;
      if (abs(sigma) < EPS) {
        C = one;
        if (abs(theta) < EPS) {
          A = half;
          B = Scalar(1. / 6.);
        } else {
          Scalar theta_sq = theta * theta;
          A = (one - cos(theta)) / theta_sq;
          B = (theta - sin(theta)) / (theta_sq * theta);
        }
      } else {
        C = (scale - one) / sigma;
        if (abs(theta) < EPS) {
          Scalar sigma_sq = sigma * sigma;
          A = ((sigma - one) * scale + one) / sigma_sq;
          B = (scale * half * sigma_sq + scale - one - sigma * scale) /
              (sigma_sq * sigma);
        } else {
          Scalar theta_sq = theta * theta;
          Scalar a = scale * sin(theta);
          Scalar b = scale * cos(theta);
          Scalar c = theta_sq + sigma * sigma;
          A = (a * sigma + (one - b) * theta) / (theta * c);
          B = (C - ((b - one) * sigma + a * theta) / (c)) * one / (theta_sq);
        }
      }
      return A * Phi + B * Phi2 + C * I;
    }

    EIGEN_DEVICE_FUNC static Matrix3 calcWInv(Tangent const& phi_sigma) {
      // left jacobian inverse
      Matrix3 const I = Matrix3::Identity();
      Scalar const half(0.5);
      Scalar const one(1);
      Scalar const two(2);

      Vector3 const phi = phi_sigma.template segment<3>(0);
      Scalar const sigma = phi_sigma(3);
      Scalar const theta = phi.norm();
      Scalar const scale = exp(sigma);

      Matrix3 const Phi = SO3<Scalar>::hat(phi);
      Matrix3 const Phi2 = Phi * Phi;
      Scalar const scale_sq = scale * scale;
      Scalar const theta_sq = theta * theta;
      Scalar const sin_theta = sin(theta);
      Scalar const cos_theta = cos(theta);

      Scalar a, b, c;
      if (abs(sigma * sigma) < EPS) {
        c = one - half * sigma;
        a = -half;
        if (abs(theta_sq) < EPS) {
          b = Scalar(1. / 12.);
        } else {
          b = (theta * sin_theta + two * cos_theta - two) /
              (two * theta_sq * (cos_theta - one));
        }
      } else {
        Scalar const scale_cu = scale_sq * scale;
        c = sigma / (scale - one);
        if (abs(theta_sq) < EPS) {
          a = (-sigma * scale + scale - one) / ((scale - one) * (scale - one));
          b = (scale_sq * sigma - two * scale_sq + scale * sigma + two * scale) /
              (two * scale_cu - Scalar(6) * scale_sq + Scalar(6) * scale - two);
        } else {
          Scalar const s_sin_theta = scale * sin_theta;
          Scalar const s_cos_theta = scale * cos_theta;
          a = (theta * s_cos_theta - theta - sigma * s_sin_theta) /
              (theta * (scale_sq - two * s_cos_theta + one));
          b = -scale *
              (theta * s_sin_theta - theta * sin_theta + sigma * s_cos_theta -
               scale * sigma + sigma * cos_theta - sigma) /
              (theta_sq * (scale_cu - two * scale * s_cos_theta - scale_sq +
                           two * s_cos_theta + scale - one));
        }
      }
      return a * Phi + b * Phi2 + c * I;
    }

    EIGEN_DEVICE_FUNC static Adjoint left_jacobian(Tangent const& phi_sigma) {
      // left jacobian
      Adjoint J = Adjoint::Identity();
      Vector3 phi = phi_sigma.template segment<3>(0);
      J.template block<3,3>(0,0) = SO3<Scalar>::left_jacobian(phi);
      return J;
    }

    EIGEN_DEVICE_FUNC static Adjoint left_jacobian_inverse(Tangent const& phi_sigma) {
      // left jacobian inverse
      Adjoint Jinv = Adjoint::Identity();
      Vector3 phi = phi_sigma.template segment<3>(0);
      Jinv.template block<3,3>(0,0) = SO3<Scalar>::left_jacobian_inverse(phi);
      return Jinv;
    }

    EIGEN_DEVICE_FUNC static Eigen::Matrix<Scalar,3,4> act_jacobian(Point const& p) {
      // jacobian action on a point
      Eigen::Matrix<Scalar,3,4> Ja;
      Ja << SO3<Scalar>::hat(-p), p;
      return Ja;
    }

    EIGEN_DEVICE_FUNC static Eigen::Matrix<Scalar,4,4> act4_jacobian(Point4 const& p) {
      // jacobian action on a point
      Eigen::Matrix<Scalar,4,4> J = Eigen::Matrix<Scalar,4,4>::Zero();
      J.template block<3,3>(0,0) = SO3<Scalar>::hat(-p.template segment<3>(0));
      J.template block<3,1>(0,3) = p.template segment<3>(0);
      return J;
    }

  private:
    Quaternion unit_quaternion;
    Scalar scale;
};

#endif


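For reference, the exponential map implemented by RxSO3<Scalar>::Exp above maps a tangent vector $(\phi, \sigma)$ to a scaled rotation; a compact statement of what the code computes (standard result, stated here for orientation rather than copied from the source):

$$ \mathrm{Exp}(\phi, \sigma) = (q,\ s), \qquad q = \Big(\cos\tfrac{\theta}{2},\ \tfrac{\sin(\theta/2)}{\theta}\,\phi\Big), \qquad s = e^{\sigma}, \qquad \theta = \lVert\phi\rVert, $$

with the small-angle Taylor expansions of $\sin(\theta/2)/\theta$ and $\cos(\theta/2)$ substituted when $\theta$ falls below EPS.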
mini_dpvo/lietorch/include/se3.h
ADDED
@@ -0,0 +1,229 @@

#ifndef SE3_HEADER
#define SE3_HEADER

#include <stdio.h>
#include <Eigen/Dense>
#include <Eigen/Geometry>

#include "common.h"
#include "so3.h"


template <typename Scalar>
class SE3 {
  public:
    const static int constexpr K = 6; // manifold dimension
    const static int constexpr N = 7; // embedding dimension

    using Vector3 = Eigen::Matrix<Scalar,3,1>;
    using Vector4 = Eigen::Matrix<Scalar,4,1>;
    using Matrix3 = Eigen::Matrix<Scalar,3,3>;

    using Tangent = Eigen::Matrix<Scalar,K,1>;
    using Point = Eigen::Matrix<Scalar,3,1>;
    using Point4 = Eigen::Matrix<Scalar,4,1>;
    using Data = Eigen::Matrix<Scalar,N,1>;
    using Transformation = Eigen::Matrix<Scalar,4,4>;
    using Adjoint = Eigen::Matrix<Scalar,K,K>;

    EIGEN_DEVICE_FUNC SE3() { translation = Vector3::Zero(); }

    EIGEN_DEVICE_FUNC SE3(SO3<Scalar> const& so3, Vector3 const& t) : so3(so3), translation(t) {};

    EIGEN_DEVICE_FUNC SE3(const Scalar *data) : translation(data), so3(data+3) {};

    EIGEN_DEVICE_FUNC SE3<Scalar> inv() {
      return SE3(so3.inv(), -(so3.inv()*translation));
    }

    EIGEN_DEVICE_FUNC Data data() const {
      Data data_vec; data_vec << translation, so3.data();
      return data_vec;
    }

    EIGEN_DEVICE_FUNC SE3<Scalar> operator*(SE3<Scalar> const& other) {
      return SE3(so3 * other.so3, translation + so3 * other.translation);
    }

    EIGEN_DEVICE_FUNC Point operator*(Point const& p) const {
      return so3 * p + translation;
    }

    EIGEN_DEVICE_FUNC Point4 act4(Point4 const& p) const {
      Point4 p1; p1 << so3 * p.template segment<3>(0) + translation * p(3), p(3);
      return p1;
    }

    EIGEN_DEVICE_FUNC Adjoint Adj() const {
      Matrix3 R = so3.Matrix();
      Matrix3 tx = SO3<Scalar>::hat(translation);
      Matrix3 Zer = Matrix3::Zero();

      Adjoint Ad;
      Ad << R, tx*R, Zer, R;

      return Ad;
    }

    EIGEN_DEVICE_FUNC Transformation Matrix() const {
      Transformation T = Transformation::Identity();
      T.template block<3,3>(0,0) = so3.Matrix();
      T.template block<3,1>(0,3) = translation;
      return T;
    }

    EIGEN_DEVICE_FUNC Transformation Matrix4x4() const {
      return Matrix();
    }

    EIGEN_DEVICE_FUNC Tangent Adj(Tangent const& a) const {
      return Adj() * a;
    }

    EIGEN_DEVICE_FUNC Tangent AdjT(Tangent const& a) const {
      return Adj().transpose() * a;
    }


    EIGEN_DEVICE_FUNC static Transformation hat(Tangent const& tau_phi) {
      Vector3 tau = tau_phi.template segment<3>(0);
      Vector3 phi = tau_phi.template segment<3>(3);

      Transformation TauPhi = Transformation::Zero();
      TauPhi.template block<3,3>(0,0) = SO3<Scalar>::hat(phi);
      TauPhi.template block<3,1>(0,3) = tau;

      return TauPhi;
    }

    EIGEN_DEVICE_FUNC static Adjoint adj(Tangent const& tau_phi) {
      Vector3 tau = tau_phi.template segment<3>(0);
      Vector3 phi = tau_phi.template segment<3>(3);

      Matrix3 Tau = SO3<Scalar>::hat(tau);
      Matrix3 Phi = SO3<Scalar>::hat(phi);
      Matrix3 Zer = Matrix3::Zero();

      Adjoint ad;
      ad << Phi, Tau, Zer, Phi;

      return ad;
    }

    EIGEN_DEVICE_FUNC Eigen::Matrix<Scalar,7,7> orthogonal_projector() const {
      // jacobian action on a point
      Eigen::Matrix<Scalar,7,7> J = Eigen::Matrix<Scalar,7,7>::Zero();
      J.template block<3,3>(0,0) = Matrix3::Identity();
      J.template block<3,3>(0,3) = SO3<Scalar>::hat(-translation);
      J.template block<4,4>(3,3) = so3.orthogonal_projector();

      return J;
    }

    EIGEN_DEVICE_FUNC Tangent Log() const {
      Vector3 phi = so3.Log();
      Matrix3 Vinv = SO3<Scalar>::left_jacobian_inverse(phi);

      Tangent tau_phi;
      tau_phi << Vinv * translation, phi;

      return tau_phi;
    }

    EIGEN_DEVICE_FUNC static SE3<Scalar> Exp(Tangent const& tau_phi) {
      Vector3 tau = tau_phi.template segment<3>(0);
      Vector3 phi = tau_phi.template segment<3>(3);

      SO3<Scalar> so3 = SO3<Scalar>::Exp(phi);
      Vector3 t = SO3<Scalar>::left_jacobian(phi) * tau;

      return SE3<Scalar>(so3, t);
    }

    EIGEN_DEVICE_FUNC static Matrix3 calcQ(Tangent const& tau_phi) {
      // Q matrix
      Vector3 tau = tau_phi.template segment<3>(0);
      Vector3 phi = tau_phi.template segment<3>(3);
      Matrix3 Tau = SO3<Scalar>::hat(tau);
      Matrix3 Phi = SO3<Scalar>::hat(phi);

      Scalar theta = phi.norm();
      Scalar theta_pow2 = theta * theta;
      Scalar theta_pow4 = theta_pow2 * theta_pow2;

      Scalar coef1 = (theta < EPS) ?
        Scalar(1.0/6.0) - Scalar(1.0/120.0) * theta_pow2 :
        (theta - sin(theta)) / (theta_pow2 * theta);

      Scalar coef2 = (theta < EPS) ?
        Scalar(1.0/24.0) - Scalar(1.0/720.0) * theta_pow2 :
        (theta_pow2 + 2*cos(theta) - 2) / (2 * theta_pow4);

      Scalar coef3 = (theta < EPS) ?
        Scalar(1.0/120.0) - Scalar(1.0/2520.0) * theta_pow2 :
        (2*theta - 3*sin(theta) + theta*cos(theta)) / (2 * theta_pow4 * theta);

      Matrix3 Q = Scalar(0.5) * Tau +
        coef1 * (Phi*Tau + Tau*Phi + Phi*Tau*Phi) +
        coef2 * (Phi*Phi*Tau + Tau*Phi*Phi - 3*Phi*Tau*Phi) +
        coef3 * (Phi*Tau*Phi*Phi + Phi*Phi*Tau*Phi);

      return Q;
    }

    EIGEN_DEVICE_FUNC static Adjoint left_jacobian(Tangent const& tau_phi) {
      // left jacobian
      Vector3 phi = tau_phi.template segment<3>(3);
      Matrix3 J = SO3<Scalar>::left_jacobian(phi);
      Matrix3 Q = SE3<Scalar>::calcQ(tau_phi);
      Matrix3 Zer = Matrix3::Zero();

      Adjoint J6x6;
      J6x6 << J, Q, Zer, J;

      return J6x6;
    }

    EIGEN_DEVICE_FUNC static Adjoint left_jacobian_inverse(Tangent const& tau_phi) {
      // left jacobian inverse
      Vector3 tau = tau_phi.template segment<3>(0);
      Vector3 phi = tau_phi.template segment<3>(3);
      Matrix3 Jinv = SO3<Scalar>::left_jacobian_inverse(phi);
      Matrix3 Q = SE3<Scalar>::calcQ(tau_phi);
      Matrix3 Zer = Matrix3::Zero();

      Adjoint J6x6;
      J6x6 << Jinv, -Jinv * Q * Jinv, Zer, Jinv;

      return J6x6;
    }

    EIGEN_DEVICE_FUNC static Eigen::Matrix<Scalar,3,6> act_jacobian(Point const& p) {
      // jacobian action on a point
      Eigen::Matrix<Scalar,3,6> J;
      J.template block<3,3>(0,0) = Matrix3::Identity();
      J.template block<3,3>(0,3) = SO3<Scalar>::hat(-p);
      return J;
    }

    EIGEN_DEVICE_FUNC static Eigen::Matrix<Scalar,4,6> act4_jacobian(Point4 const& p) {
      // jacobian action on a point
      Eigen::Matrix<Scalar,4,6> J = Eigen::Matrix<Scalar,4,6>::Zero();
      J.template block<3,3>(0,0) = p(3) * Matrix3::Identity();
      J.template block<3,3>(0,3) = SO3<Scalar>::hat(-p.template segment<3>(0));
      return J;
    }

  private:
    SO3<Scalar> so3;
    Vector3 translation;

};

#endif

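SE3<Scalar>::Exp and Log above follow the usual left-Jacobian construction; in the notation of the code, with $\tau$ the translational and $\phi$ the rotational part of the twist:

$$ \mathrm{Exp}(\tau, \phi) = (R,\ t), \quad R = \mathrm{Exp}_{SO(3)}(\phi), \quad t = J_l(\phi)\,\tau; \qquad \mathrm{Log}(R, t) = \big(J_l(\phi)^{-1}\,t,\ \phi\big), \quad \phi = \mathrm{Log}_{SO(3)}(R). $$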
mini_dpvo/lietorch/include/sim3.h
ADDED
@@ -0,0 +1,217 @@

#ifndef Sim3_HEADER
#define Sim3_HEADER

#include <stdio.h>
#include <iostream>

#include <Eigen/Dense>
#include <Eigen/Geometry>

#include "common.h"
#include "so3.h"
#include "rxso3.h"


template <typename Scalar>
class Sim3 {
  public:
    const static int constexpr K = 7; // manifold dimension
    const static int constexpr N = 8; // embedding dimension

    using Vector3 = Eigen::Matrix<Scalar,3,1>;
    using Vector4 = Eigen::Matrix<Scalar,4,1>;
    using Matrix3 = Eigen::Matrix<Scalar,3,3>;

    using Tangent = Eigen::Matrix<Scalar,K,1>;
    using Point = Eigen::Matrix<Scalar,3,1>;
    using Point4 = Eigen::Matrix<Scalar,4,1>;
    using Data = Eigen::Matrix<Scalar,N,1>;
    using Transformation = Eigen::Matrix<Scalar,4,4>;
    using Adjoint = Eigen::Matrix<Scalar,K,K>;

    EIGEN_DEVICE_FUNC Sim3() {
      translation = Vector3::Zero();
    }

    EIGEN_DEVICE_FUNC Sim3(RxSO3<Scalar> const& rxso3, Vector3 const& t)
      : rxso3(rxso3), translation(t) {};

    EIGEN_DEVICE_FUNC Sim3(const Scalar *data)
      : translation(data), rxso3(data+3) {};

    EIGEN_DEVICE_FUNC Sim3<Scalar> inv() {
      return Sim3<Scalar>(rxso3.inv(), -(rxso3.inv() * translation));
    }

    EIGEN_DEVICE_FUNC Data data() const {
      Data data_vec; data_vec << translation, rxso3.data();
      return data_vec;
    }

    EIGEN_DEVICE_FUNC Sim3<Scalar> operator*(Sim3<Scalar> const& other) {
      return Sim3(rxso3 * other.rxso3, translation + rxso3 * other.translation);
    }

    EIGEN_DEVICE_FUNC Point operator*(Point const& p) const {
      return (rxso3 * p) + translation;
    }

    EIGEN_DEVICE_FUNC Point4 act4(Point4 const& p) const {
      Point4 p1; p1 << rxso3 * p.template segment<3>(0) + p(3) * translation , p(3);
      return p1;
    }

    EIGEN_DEVICE_FUNC Transformation Matrix() const {
      Transformation T = Transformation::Identity();
      T.template block<3,3>(0,0) = rxso3.Matrix();
      T.template block<3,1>(0,3) = translation;
      return T;
    }

    EIGEN_DEVICE_FUNC Transformation Matrix4x4() const {
      Transformation T = Transformation::Identity();
      T.template block<3,3>(0,0) = rxso3.Matrix();
      T.template block<3,1>(0,3) = translation;
      return T;
    }

    EIGEN_DEVICE_FUNC Eigen::Matrix<Scalar,8,8> orthogonal_projector() const {
      // jacobian action on a point
      Eigen::Matrix<Scalar,8,8> J = Eigen::Matrix<Scalar,8,8>::Zero();
      J.template block<3,3>(0,0) = Matrix3::Identity();
      J.template block<3,3>(0,3) = SO3<Scalar>::hat(-translation);
      J.template block<3,1>(0,6) = translation;
      J.template block<5,5>(3,3) = rxso3.orthogonal_projector();
      return J;
    }

    EIGEN_DEVICE_FUNC Adjoint Adj() const {
      Adjoint Ad = Adjoint::Identity();
      Matrix3 sR = rxso3.Matrix();
      Matrix3 tx = SO3<Scalar>::hat(translation);
      Matrix3 R = rxso3.Rotation();

      Ad.template block<3,3>(0,0) = sR;
      Ad.template block<3,3>(0,3) = tx * R;
      Ad.template block<3,1>(0,6) = -translation;
      Ad.template block<3,3>(3,3) = R;

      return Ad;
    }

    EIGEN_DEVICE_FUNC Tangent Adj(Tangent const& a) const {
      return Adj() * a;
    }

    EIGEN_DEVICE_FUNC Tangent AdjT(Tangent const& a) const {
      return Adj().transpose() * a;
    }

    EIGEN_DEVICE_FUNC static Transformation hat(Tangent const& tau_phi_sigma) {
      Vector3 tau = tau_phi_sigma.template segment<3>(0);
      Vector3 phi = tau_phi_sigma.template segment<3>(3);
      Scalar sigma = tau_phi_sigma(6);

      Matrix3 Phi = SO3<Scalar>::hat(phi);
      Matrix3 I = Matrix3::Identity();

      Transformation Omega = Transformation::Zero();
      Omega.template block<3,3>(0,0) = Phi + sigma * I;
      Omega.template block<3,1>(0,3) = tau;

      return Omega;
    }

    EIGEN_DEVICE_FUNC static Adjoint adj(Tangent const& tau_phi_sigma) {
      Adjoint ad = Adjoint::Zero();
      Vector3 tau = tau_phi_sigma.template segment<3>(0);
      Vector3 phi = tau_phi_sigma.template segment<3>(3);
      Scalar sigma = tau_phi_sigma(6);

      Matrix3 Tau = SO3<Scalar>::hat(tau);
      Matrix3 Phi = SO3<Scalar>::hat(phi);
      Matrix3 I = Matrix3::Identity();

      ad.template block<3,3>(0,0) = Phi + sigma * I;
      ad.template block<3,3>(0,3) = Tau;
      ad.template block<3,1>(0,6) = -tau;
      ad.template block<3,3>(3,3) = Phi;

      return ad;
    }


    EIGEN_DEVICE_FUNC Tangent Log() const {
      // logarithm map
      Vector4 phi_sigma = rxso3.Log();
      Matrix3 W = RxSO3<Scalar>::calcW(phi_sigma);

      Tangent tau_phi_sigma;
      tau_phi_sigma << W.inverse() * translation, phi_sigma;

      return tau_phi_sigma;
    }

    EIGEN_DEVICE_FUNC static Sim3<Scalar> Exp(Tangent const& tau_phi_sigma) {
      // exponential map
      Vector3 tau = tau_phi_sigma.template segment<3>(0);
      Vector4 phi_sigma = tau_phi_sigma.template segment<4>(3);

      RxSO3<Scalar> rxso3 = RxSO3<Scalar>::Exp(phi_sigma);
      Matrix3 W = RxSO3<Scalar>::calcW(phi_sigma);

      return Sim3<Scalar>(rxso3, W*tau);
    }

    EIGEN_DEVICE_FUNC static Adjoint left_jacobian(Tangent const& tau_phi_sigma) {
      // left jacobian (truncated series)
      Adjoint const Xi = adj(tau_phi_sigma);
      Adjoint const Xi2 = Xi * Xi;
      Adjoint const Xi4 = Xi2 * Xi2;

      return Adjoint::Identity()
        + Scalar(1.0/2.0)*Xi
        + Scalar(1.0/6.0)*Xi2
        + Scalar(1.0/24.0)*Xi*Xi2
        + Scalar(1.0/120.0)*Xi4
        + Scalar(1.0/720.0)*Xi*Xi4;
    }

    EIGEN_DEVICE_FUNC static Adjoint left_jacobian_inverse(Tangent const& tau_phi_sigma) {
      // left jacobian inverse
      Adjoint const Xi = adj(tau_phi_sigma);
      Adjoint const Xi2 = Xi * Xi;
      Adjoint const Xi4 = Xi2 * Xi2;

      return Adjoint::Identity()
        - Scalar(1.0/2.0)*Xi
        + Scalar(1.0/12.0)*Xi2
        - Scalar(1.0/720.0)*Xi4;
    }

    EIGEN_DEVICE_FUNC static Eigen::Matrix<Scalar,3,7> act_jacobian(Point const& p) {
      // jacobian action on a point
      Eigen::Matrix<Scalar,3,7> J;
      J.template block<3,3>(0,0) = Matrix3::Identity();
      J.template block<3,3>(0,3) = SO3<Scalar>::hat(-p);
      J.template block<3,1>(0,6) = p;
      return J;
    }

    EIGEN_DEVICE_FUNC static Eigen::Matrix<Scalar,4,7> act4_jacobian(Point4 const& p) {
      // jacobian action on a point
      Eigen::Matrix<Scalar,4,7> J = Eigen::Matrix<Scalar,4,7>::Zero();
      J.template block<3,3>(0,0) = p(3) * Matrix3::Identity();
      J.template block<3,3>(0,3) = SO3<Scalar>::hat(-p.template segment<3>(0));
      J.template block<3,1>(0,6) = p.template segment<3>(0);
      return J;
    }

  private:
    Vector3 translation;
    RxSO3<Scalar> rxso3;
};

#endif

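Sim3 stores its 8-dim embedding as translation, unit quaternion, scale; its exponential reuses the RxSO3 map together with the $W$ matrix from calcW. Schematically (a summary of what Exp and Log above compute, not additional functionality):

$$ \mathrm{Exp}(\tau, \phi, \sigma) = \big(\,W(\phi,\sigma)\,\tau,\ \mathrm{Exp}_{R^{+}\times SO(3)}(\phi,\sigma)\,\big), \qquad \mathrm{Log} \text{ recovers } \tau = W(\phi,\sigma)^{-1}\,t. $$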
mini_dpvo/lietorch/include/so3.h
ADDED
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifndef SO3_HEADER
#define SO3_HEADER

#include <cuda.h>
#include <stdio.h>
#include <Eigen/Dense>
#include <Eigen/Geometry>

#include "common.h"

template <typename Scalar>
class SO3 {
  public:
    const static int constexpr K = 3; // manifold dimension
    const static int constexpr N = 4; // embedding dimension

    using Vector3 = Eigen::Matrix<Scalar,3,1>;
    using Vector4 = Eigen::Matrix<Scalar,4,1>;
    using Matrix3 = Eigen::Matrix<Scalar,3,3>;

    using Tangent = Eigen::Matrix<Scalar,K,1>;
    using Data = Eigen::Matrix<Scalar,N,1>;

    using Point = Eigen::Matrix<Scalar,3,1>;
    using Point4 = Eigen::Matrix<Scalar,4,1>;
    using Transformation = Eigen::Matrix<Scalar,3,3>;
    using Adjoint = Eigen::Matrix<Scalar,K,K>;
    using Quaternion = Eigen::Quaternion<Scalar>;

    EIGEN_DEVICE_FUNC SO3(Quaternion const& q) : unit_quaternion(q) {
      unit_quaternion.normalize();
    };

    EIGEN_DEVICE_FUNC SO3(const Scalar *data) : unit_quaternion(data) {
      unit_quaternion.normalize();
    };

    EIGEN_DEVICE_FUNC SO3() {
      unit_quaternion = Quaternion::Identity();
    }

    EIGEN_DEVICE_FUNC SO3<Scalar> inv() {
      return SO3<Scalar>(unit_quaternion.conjugate());
    }

    EIGEN_DEVICE_FUNC Data data() const {
      return unit_quaternion.coeffs();
    }

    EIGEN_DEVICE_FUNC SO3<Scalar> operator*(SO3<Scalar> const& other) {
      return SO3(unit_quaternion * other.unit_quaternion);
    }

    EIGEN_DEVICE_FUNC Point operator*(Point const& p) const {
      const Quaternion& q = unit_quaternion;
      Point uv = q.vec().cross(p);
      uv += uv;
      return p + q.w()*uv + q.vec().cross(uv);
    }

    EIGEN_DEVICE_FUNC Point4 act4(Point4 const& p) const {
      Point4 p1; p1 << this->operator*(p.template segment<3>(0)), p(3);
      return p1;
    }

    EIGEN_DEVICE_FUNC Adjoint Adj() const {
      return unit_quaternion.toRotationMatrix();
    }

    EIGEN_DEVICE_FUNC Transformation Matrix() const {
      return unit_quaternion.toRotationMatrix();
    }

    EIGEN_DEVICE_FUNC Eigen::Matrix<Scalar,4,4> Matrix4x4() const {
      Eigen::Matrix<Scalar,4,4> T = Eigen::Matrix<Scalar,4,4>::Identity();
      T.template block<3,3>(0,0) = Matrix();
      return T;
    }

    EIGEN_DEVICE_FUNC Eigen::Matrix<Scalar,4,4> orthogonal_projector() const {
      // jacobian action on a point
      Eigen::Matrix<Scalar,4,4> J = Eigen::Matrix<Scalar,4,4>::Zero();
      J.template block<3,3>(0,0) = 0.5 * (
          unit_quaternion.w() * Matrix3::Identity() +
          SO3<Scalar>::hat(-unit_quaternion.vec())
        );

      J.template block<1,3>(3,0) = 0.5 * (-unit_quaternion.vec());
      return J;
    }

    EIGEN_DEVICE_FUNC Tangent Adj(Tangent const& a) const {
      return Adj() * a;
    }

    EIGEN_DEVICE_FUNC Tangent AdjT(Tangent const& a) const {
      return Adj().transpose() * a;
    }

    EIGEN_DEVICE_FUNC static Transformation hat(Tangent const& phi) {
      Transformation Phi;
      Phi <<
          0.0, -phi(2), phi(1),
          phi(2), 0.0, -phi(0),
          -phi(1), phi(0), 0.0;

      return Phi;
    }

    EIGEN_DEVICE_FUNC static Adjoint adj(Tangent const& phi) {
      return SO3<Scalar>::hat(phi);
    }

    EIGEN_DEVICE_FUNC Tangent Log() const {
      using std::abs;
      using std::atan;
      using std::sqrt;
      Scalar squared_n = unit_quaternion.vec().squaredNorm();
      Scalar w = unit_quaternion.w();

      Scalar two_atan_nbyw_by_n;

      /// Atan-based log thanks to
      ///
      /// C. Hertzberg et al.:
      /// "Integrating Generic Sensor Fusion Algorithms with Sound State
      /// Representation through Encapsulation of Manifolds"
      /// Information Fusion, 2011

      if (squared_n < EPS * EPS) {
        // If quaternion is normalized and n=0, then w should be 1;
        // w=0 should never happen here!
        Scalar squared_w = w * w;
        two_atan_nbyw_by_n =
            Scalar(2) / w - Scalar(2.0/3.0) * (squared_n) / (w * squared_w);
      } else {
        Scalar n = sqrt(squared_n);
        if (abs(w) < EPS) {
          if (w > Scalar(0)) {
            two_atan_nbyw_by_n = Scalar(PI) / n;
          } else {
            two_atan_nbyw_by_n = -Scalar(PI) / n;
          }
        } else {
          two_atan_nbyw_by_n = Scalar(2) * atan(n / w) / n;
        }
      }

      return two_atan_nbyw_by_n * unit_quaternion.vec();
    }

    EIGEN_DEVICE_FUNC static SO3<Scalar> Exp(Tangent const& phi) {
      Scalar theta2 = phi.squaredNorm();
      Scalar theta = sqrt(theta2);
      Scalar imag_factor;
      Scalar real_factor;

      if (theta < EPS) {
        Scalar theta4 = theta2 * theta2;
        imag_factor = Scalar(0.5) - Scalar(1.0/48.0) * theta2 + Scalar(1.0/3840.0) * theta4;
        real_factor = Scalar(1) - Scalar(1.0/8.0) * theta2 + Scalar(1.0/384.0) * theta4;
      } else {
        imag_factor = sin(.5 * theta) / theta;
        real_factor = cos(.5 * theta);
      }

      Quaternion q(real_factor, imag_factor*phi.x(), imag_factor*phi.y(), imag_factor*phi.z());
      return SO3<Scalar>(q);
    }

    EIGEN_DEVICE_FUNC static Adjoint left_jacobian(Tangent const& phi) {
      // left jacobian
      Matrix3 I = Matrix3::Identity();
      Matrix3 Phi = SO3<Scalar>::hat(phi);
      Matrix3 Phi2 = Phi * Phi;

      Scalar theta2 = phi.squaredNorm();
      Scalar theta = sqrt(theta2);

      Scalar coef1 = (theta < EPS) ?
          Scalar(1.0/2.0) - Scalar(1.0/24.0) * theta2 :
          (1.0 - cos(theta)) / theta2;

      Scalar coef2 = (theta < EPS) ?
          Scalar(1.0/6.0) - Scalar(1.0/120.0) * theta2 :
          (theta - sin(theta)) / (theta2 * theta);

      return I + coef1 * Phi + coef2 * Phi2;
    }

    EIGEN_DEVICE_FUNC static Adjoint left_jacobian_inverse(Tangent const& phi) {
      // left jacobian inverse
      Matrix3 I = Matrix3::Identity();
      Matrix3 Phi = SO3<Scalar>::hat(phi);
      Matrix3 Phi2 = Phi * Phi;

      Scalar theta2 = phi.squaredNorm();
      Scalar theta = sqrt(theta2);
      Scalar half_theta = Scalar(.5) * theta;

      Scalar coef2 = (theta < EPS) ? Scalar(1.0/12.0) :
          (Scalar(1) -
            theta * cos(half_theta) / (Scalar(2) * sin(half_theta))) /
          (theta * theta);

      return I + Scalar(-0.5) * Phi + coef2 * Phi2;
    }

    EIGEN_DEVICE_FUNC static Eigen::Matrix<Scalar,3,3> act_jacobian(Point const& p) {
      // jacobian action on a point
      return SO3<Scalar>::hat(-p);
    }

    EIGEN_DEVICE_FUNC static Eigen::Matrix<Scalar,4,3> act4_jacobian(Point4 const& p) {
      // jacobian action on a point
      Eigen::Matrix<Scalar,4,3> J = Eigen::Matrix<Scalar,4,3>::Zero();
      J.template block<3,3>(0,0) = SO3<Scalar>::hat(-p.template segment<3>(0));
      return J;
    }

  private:
    Quaternion unit_quaternion;

};

#endif
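The header above implements the SO(3) Exp/Log maps and hat operator at the C++/CUDA level. A quick way to sanity-check the round trip from Python is through the lietorch wrapper, mirroring what run_tests.py below does; this is a minimal sketch and assumes the extension in this repo is built and importable:

# Minimal sketch (assumes the repo's lietorch extension is built and on PYTHONPATH).
import torch
from lietorch import SO3

phi = 0.2 * torch.randn(4, 3).double()          # axis-angle tangent vectors
R = SO3.exp(phi)                                # Exp: tangent -> unit quaternion
assert torch.allclose(R.log(), phi, atol=1e-8)  # Log(Exp(phi)) == phi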
mini_dpvo/lietorch/run_tests.py
ADDED
@@ -0,0 +1,302 @@
import torch
import lietorch

from lietorch import SO3, RxSO3, SE3, Sim3
from gradcheck import gradcheck, get_analytical_jacobian


### forward tests ###

def make_homogeneous(p):
    return torch.cat([p, torch.ones_like(p[...,:1])], dim=-1)

def matv(A, b):
    return torch.matmul(A, b[...,None])[..., 0]

def test_exp_log(Group, device='cuda'):
    """ check Log(Exp(x)) == x """
    a = .2*torch.randn(2,3,4,5,6,7,Group.manifold_dim, device=device).double()
    b = Group.exp(a).log()
    assert torch.allclose(a,b,atol=1e-8), "should be identity"
    print("\t-", Group, "Passed exp-log test")

def test_inv(Group, device='cuda'):
    """ check X * X^{-1} == 0 """
    X = Group.exp(.1*torch.randn(2,3,4,5,Group.manifold_dim, device=device).double())
    a = (X * X.inv()).log()
    assert torch.allclose(a, torch.zeros_like(a), atol=1e-8), "should be 0"
    print("\t-", Group, "Passed inv test")

def test_adj(Group, device='cuda'):
    """ check X * Exp(a) == Exp(Adj(X,a)) * X 0 """
    X = Group.exp(torch.randn(2,3,4,5, Group.manifold_dim, device=device).double())
    a = torch.randn(2,3,4,5, Group.manifold_dim, device=device).double()

    b = X.adj(a)
    Y1 = X * Group.exp(a)
    Y2 = Group.exp(b) * X

    c = (Y1 * Y2.inv()).log()
    assert torch.allclose(c, torch.zeros_like(c), atol=1e-8), "should be 0"
    print("\t-", Group, "Passed adj test")


def test_act(Group, device='cuda'):
    X = Group.exp(torch.randn(1, Group.manifold_dim, device=device).double())
    p = torch.randn(1,3,device=device).double()

    p1 = X.act(p)
    p2 = matv(X.matrix(), make_homogeneous(p))

    assert torch.allclose(p1, p2[...,:3], atol=1e-8), "should be 0"
    print("\t-", Group, "Passed act test")


### backward tests ###
def test_exp_log_grad(Group, device='cuda', tol=1e-8):

    D = Group.manifold_dim

    def fn(a):
        return Group.exp(a).log()

    a = torch.zeros(1, Group.manifold_dim, requires_grad=True, device=device).double()
    analytical, reentrant, correct_grad_sizes, correct_grad_types = \
        get_analytical_jacobian((a,), fn(a))

    assert torch.allclose(analytical[0], torch.eye(D, device=device).double(), atol=tol)

    a = .2 * torch.randn(1, Group.manifold_dim, requires_grad=True, device=device).double()
    analytical, reentrant, correct_grad_sizes, correct_grad_types = \
        get_analytical_jacobian((a,), fn(a))

    assert torch.allclose(analytical[0], torch.eye(D, device=device).double(), atol=tol)

    print("\t-", Group, "Passed eye-grad test")


def test_inv_log_grad(Group, device='cuda', tol=1e-8):

    D = Group.manifold_dim
    X = Group.exp(.2*torch.randn(1,D,device=device).double())

    def fn(a):
        return (Group.exp(a) * X).inv().log()

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    analytical, numerical = gradcheck(fn, [a], eps=1e-4)

    # assert torch.allclose(analytical[0], numerical[0], atol=tol)
    if not torch.allclose(analytical[0], numerical[0], atol=tol):
        print(analytical[0])
        print(numerical[0])

    print("\t-", Group, "Passed inv-grad test")


def test_adj_grad(Group, device='cuda'):
    D = Group.manifold_dim
    X = Group.exp(.5*torch.randn(1,Group.manifold_dim, device=device).double())

    def fn(a, b):
        return (Group.exp(a) * X).adj(b)

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    b = torch.randn(1, D, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a, b], eps=1e-4)
    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
    assert torch.allclose(analytical[1], numerical[1], atol=1e-8)

    print("\t-", Group, "Passed adj-grad test")


def test_adjT_grad(Group, device='cuda'):
    D = Group.manifold_dim
    X = Group.exp(.5*torch.randn(1,Group.manifold_dim, device=device).double())

    def fn(a, b):
        return (Group.exp(a) * X).adjT(b)

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    b = torch.randn(1, D, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a, b], eps=1e-4)

    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
    assert torch.allclose(analytical[1], numerical[1], atol=1e-8)

    print("\t-", Group, "Passed adjT-grad test")


def test_act_grad(Group, device='cuda'):
    D = Group.manifold_dim
    X = Group.exp(5*torch.randn(1,D, device=device).double())

    def fn(a, b):
        return (X*Group.exp(a)).act(b)

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    b = torch.randn(1, 3, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a, b], eps=1e-4)

    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
    assert torch.allclose(analytical[1], numerical[1], atol=1e-8)

    print("\t-", Group, "Passed act-grad test")


def test_matrix_grad(Group, device='cuda'):
    D = Group.manifold_dim
    X = Group.exp(torch.randn(1, D, device=device).double())

    def fn(a):
        return (Group.exp(a) * X).matrix()

    a = torch.zeros(1, D, requires_grad=True, device=device).double()
    analytical, numerical = gradcheck(fn, [a], eps=1e-4)
    assert torch.allclose(analytical[0], numerical[0], atol=1e-6)

    print("\t-", Group, "Passed matrix-grad test")


def extract_translation_grad(Group, device='cuda'):
    """ prototype function """

    D = Group.manifold_dim
    X = Group.exp(5*torch.randn(1,D, device=device).double())

    def fn(a):
        return (Group.exp(a)*X).translation()

    a = torch.zeros(1, D, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a], eps=1e-4)

    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
    print("\t-", Group, "Passed translation grad test")


def test_vec_grad(Group, device='cuda', tol=1e-6):

    D = Group.manifold_dim
    X = Group.exp(5*torch.randn(1,D, device=device).double())

    def fn(a):
        return (Group.exp(a)*X).vec()

    a = torch.zeros(1, D, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a], eps=1e-4)

    assert torch.allclose(analytical[0], numerical[0], atol=tol)
    print("\t-", Group, "Passed tovec grad test")


def test_fromvec_grad(Group, device='cuda', tol=1e-6):

    def fn(a):
        if Group == SO3:
            a = a / a.norm(dim=-1, keepdim=True)

        elif Group == RxSO3:
            q, s = a.split([4, 1], dim=-1)
            q = q / q.norm(dim=-1, keepdim=True)
            a = torch.cat([q, s.exp()], dim=-1)

        elif Group == SE3:
            t, q = a.split([3, 4], dim=-1)
            q = q / q.norm(dim=-1, keepdim=True)
            a = torch.cat([t, q], dim=-1)

        elif Group == Sim3:
            t, q, s = a.split([3, 4, 1], dim=-1)
            q = q / q.norm(dim=-1, keepdim=True)
            a = torch.cat([t, q, s.exp()], dim=-1)

        return Group.InitFromVec(a).vec()

    D = Group.embedded_dim
    a = torch.randn(1, 2, D, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a], eps=1e-4)

    assert torch.allclose(analytical[0], numerical[0], atol=tol)
    print("\t-", Group, "Passed fromvec grad test")


def scale(device='cuda'):

    def fn(a, s):
        X = SE3.exp(a)
        X.scale(s)
        return X.log()

    s = torch.rand(1, requires_grad=True, device=device).double()
    a = torch.randn(1, 6, requires_grad=True, device=device).double()

    analytical, numerical = gradcheck(fn, [a, s], eps=1e-3)
    print(analytical[1])
    print(numerical[1])

    assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
    assert torch.allclose(analytical[1], numerical[1], atol=1e-8)

    print("\t-", "Passed se3-to-sim3 test")


if __name__ == '__main__':

    print("Testing lietorch forward pass (CPU) ...")
    for Group in [SO3, RxSO3, SE3, Sim3]:
        test_exp_log(Group, device='cpu')
        test_inv(Group, device='cpu')
        test_adj(Group, device='cpu')
        test_act(Group, device='cpu')

    print("Testing lietorch backward pass (CPU)...")
    for Group in [SO3, RxSO3, SE3, Sim3]:
        if Group == Sim3:
            tol = 1e-3
        else:
            tol = 1e-8

        test_exp_log_grad(Group, device='cpu', tol=tol)
        test_inv_log_grad(Group, device='cpu', tol=tol)
        test_adj_grad(Group, device='cpu')
        test_adjT_grad(Group, device='cpu')
        test_act_grad(Group, device='cpu')
        test_matrix_grad(Group, device='cpu')
        extract_translation_grad(Group, device='cpu')
        test_vec_grad(Group, device='cpu')
        test_fromvec_grad(Group, device='cpu')

    print("Testing lietorch forward pass (GPU) ...")
    for Group in [SO3, RxSO3, SE3, Sim3]:
        test_exp_log(Group, device='cuda')
        test_inv(Group, device='cuda')
        test_adj(Group, device='cuda')
        test_act(Group, device='cuda')

    print("Testing lietorch backward pass (GPU)...")
    for Group in [SO3, RxSO3, SE3, Sim3]:
        if Group == Sim3:
            tol = 1e-3
        else:
            tol = 1e-8

        test_exp_log_grad(Group, device='cuda', tol=tol)
        test_inv_log_grad(Group, device='cuda', tol=tol)
        test_adj_grad(Group, device='cuda')
        test_adjT_grad(Group, device='cuda')
        test_act_grad(Group, device='cuda')
        test_matrix_grad(Group, device='cuda')
        extract_translation_grad(Group, device='cuda')
        test_vec_grad(Group, device='cuda')
        test_fromvec_grad(Group, device='cuda')
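The script above runs the full forward/backward suite over SO3, RxSO3, SE3 and Sim3, first on CPU and then on CUDA (so the GPU half needs a CUDA build). For a quick post-build smoke test, the same pattern can be reduced to a single check; this is a sketch using names from the script above, not part of the repo:

# Quick smoke test mirroring test_exp_log, CPU only.
import torch
from lietorch import SE3

a = 0.2 * torch.randn(8, SE3.manifold_dim).double()
assert torch.allclose(SE3.exp(a).log(), a, atol=1e-8)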
mini_dpvo/lietorch/src/lietorch.cpp
ADDED
@@ -0,0 +1,317 @@
#include <torch/extension.h>
#include <vector>
#include "lietorch_gpu.h"
#include "lietorch_cpu.h"


#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")


/* Interface for cuda and c++ group operations

    enum group_t { SO3=1, SE3=2, Sim3=3 };
    X, Y, Z: (uppercase) Lie Group Elements
    a, b, c: (lowercase) Lie Algebra Elements
*/

// Unary operations
torch::Tensor expm(int group_index, torch::Tensor a) {
  CHECK_CONTIGUOUS(a);
  if (a.device().type() == torch::DeviceType::CPU) {
    return exp_forward_cpu(group_index, a);

  } else if (a.device().type() == torch::DeviceType::CUDA) {
    return exp_forward_gpu(group_index, a);
  }

  return a;
}

std::vector<torch::Tensor> expm_backward(int group_index, torch::Tensor grad, torch::Tensor a) {
  CHECK_CONTIGUOUS(a);
  CHECK_CONTIGUOUS(grad);
  if (a.device().type() == torch::DeviceType::CPU) {
    return exp_backward_cpu(group_index, grad, a);

  } else if (a.device().type() == torch::DeviceType::CUDA) {
    return exp_backward_gpu(group_index, grad, a);
  }

  return {};
}

torch::Tensor logm(int group_index, torch::Tensor X) {
  CHECK_CONTIGUOUS(X);
  if (X.device().type() == torch::DeviceType::CPU) {
    return log_forward_cpu(group_index, X);

  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return log_forward_gpu(group_index, X);
  }

  return X;
}

std::vector<torch::Tensor> logm_backward(int group_index, torch::Tensor grad, torch::Tensor X) {
  CHECK_CONTIGUOUS(X);
  CHECK_CONTIGUOUS(grad);

  if (X.device().type() == torch::DeviceType::CPU) {
    return log_backward_cpu(group_index, grad, X);

  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return log_backward_gpu(group_index, grad, X);
  }

  return {};
}

torch::Tensor inv(int group_index, torch::Tensor X) {
  CHECK_CONTIGUOUS(X);

  if (X.device().type() == torch::DeviceType::CPU) {
    return inv_forward_cpu(group_index, X);
  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return inv_forward_gpu(group_index, X);
  }

  return X;
}

std::vector<torch::Tensor> inv_backward(int group_index, torch::Tensor grad, torch::Tensor X) {
  CHECK_CONTIGUOUS(X);
  CHECK_CONTIGUOUS(grad);

  if (X.device().type() == torch::DeviceType::CPU) {
    return inv_backward_cpu(group_index, grad, X);

  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return inv_backward_gpu(group_index, grad, X);
  }

  return {};
}

// Binary operations

torch::Tensor mul(int group_index, torch::Tensor X, torch::Tensor Y) {
  CHECK_CONTIGUOUS(X);
  CHECK_CONTIGUOUS(Y);

  if (X.device().type() == torch::DeviceType::CPU) {
    return mul_forward_cpu(group_index, X, Y);

  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return mul_forward_gpu(group_index, X, Y);
  }

  return X;
}

std::vector<torch::Tensor> mul_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor Y) {
  CHECK_CONTIGUOUS(X);
  CHECK_CONTIGUOUS(Y);
  CHECK_CONTIGUOUS(grad);

  if (X.device().type() == torch::DeviceType::CPU) {
    return mul_backward_cpu(group_index, grad, X, Y);

  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return mul_backward_gpu(group_index, grad, X, Y);
  }

  return {};
}

torch::Tensor adj(int group_index, torch::Tensor X, torch::Tensor a) {
  CHECK_CONTIGUOUS(X);
  CHECK_CONTIGUOUS(a);

  if (X.device().type() == torch::DeviceType::CPU) {
    return adj_forward_cpu(group_index, X, a);

  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return adj_forward_gpu(group_index, X, a);
  }

  return X;
}

std::vector<torch::Tensor> adj_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor a) {
  CHECK_CONTIGUOUS(X);
  CHECK_CONTIGUOUS(a);
  CHECK_CONTIGUOUS(grad);

  if (X.device().type() == torch::DeviceType::CPU) {
    return adj_backward_cpu(group_index, grad, X, a);

  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return adj_backward_gpu(group_index, grad, X, a);
  }

  return {};
}

torch::Tensor adjT(int group_index, torch::Tensor X, torch::Tensor a) {
  CHECK_CONTIGUOUS(X);
  CHECK_CONTIGUOUS(a);

  if (X.device().type() == torch::DeviceType::CPU) {
    return adjT_forward_cpu(group_index, X, a);

  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return adjT_forward_gpu(group_index, X, a);
  }

  return X;
}

std::vector<torch::Tensor> adjT_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor a) {
  CHECK_CONTIGUOUS(X);
  CHECK_CONTIGUOUS(a);
  CHECK_CONTIGUOUS(grad);

  if (X.device().type() == torch::DeviceType::CPU) {
    return adjT_backward_cpu(group_index, grad, X, a);

  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return adjT_backward_gpu(group_index, grad, X, a);
  }

  return {};
}


torch::Tensor act(int group_index, torch::Tensor X, torch::Tensor p) {
  CHECK_CONTIGUOUS(X);
  CHECK_CONTIGUOUS(p);

  if (X.device().type() == torch::DeviceType::CPU) {
    return act_forward_cpu(group_index, X, p);

  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return act_forward_gpu(group_index, X, p);
  }

  return X;
}

std::vector<torch::Tensor> act_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor p) {
  CHECK_CONTIGUOUS(X);
  CHECK_CONTIGUOUS(p);
  CHECK_CONTIGUOUS(grad);

  if (X.device().type() == torch::DeviceType::CPU) {
    return act_backward_cpu(group_index, grad, X, p);

  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return act_backward_gpu(group_index, grad, X, p);
  }

  return {};
}

torch::Tensor act4(int group_index, torch::Tensor X, torch::Tensor p) {
  CHECK_CONTIGUOUS(X);
  CHECK_CONTIGUOUS(p);

  if (X.device().type() == torch::DeviceType::CPU) {
    return act4_forward_cpu(group_index, X, p);

  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return act4_forward_gpu(group_index, X, p);
  }

  return X;
}

std::vector<torch::Tensor> act4_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor p) {
  CHECK_CONTIGUOUS(X);
  CHECK_CONTIGUOUS(p);
  CHECK_CONTIGUOUS(grad);

  if (X.device().type() == torch::DeviceType::CPU) {
    return act4_backward_cpu(group_index, grad, X, p);

  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return act4_backward_gpu(group_index, grad, X, p);
  }

  return {};
}


torch::Tensor projector(int group_index, torch::Tensor X) {
  CHECK_CONTIGUOUS(X);

  if (X.device().type() == torch::DeviceType::CPU) {
    return orthogonal_projector_cpu(group_index, X);

  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return orthogonal_projector_gpu(group_index, X);
  }

  return X;
}


torch::Tensor as_matrix(int group_index, torch::Tensor X) {
  CHECK_CONTIGUOUS(X);

  if (X.device().type() == torch::DeviceType::CPU) {
    return as_matrix_forward_cpu(group_index, X);

  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return as_matrix_forward_gpu(group_index, X);
  }

  return X;
}

torch::Tensor Jinv(int group_index, torch::Tensor X, torch::Tensor a) {
  CHECK_CONTIGUOUS(X);
  CHECK_CONTIGUOUS(a);

  if (X.device().type() == torch::DeviceType::CPU) {
    return jleft_forward_cpu(group_index, X, a);

  } else if (X.device().type() == torch::DeviceType::CUDA) {
    return jleft_forward_gpu(group_index, X, a);
  }

  return a;
}

// {exp, log, inv, mul, adj, adjT, act, act4} forward/backward bindings
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("expm", &expm, "exp map forward");
  m.def("expm_backward", &expm_backward, "exp map backward");

  m.def("logm", &logm, "log map forward");
  m.def("logm_backward", &logm_backward, "log map backward");

  m.def("inv", &inv, "inverse operator");
  m.def("inv_backward", &inv_backward, "inverse operator backward");

  m.def("mul", &mul, "group operator");
  m.def("mul_backward", &mul_backward, "group operator backward");

  m.def("adj", &adj, "adjoint operator");
  m.def("adj_backward", &adj_backward, "adjoint operator backward");

  m.def("adjT", &adjT, "transposed adjoint operator");
  m.def("adjT_backward", &adjT_backward, "transposed adjoint operator backward");

  m.def("act", &act, "action on point");
  m.def("act_backward", &act_backward, "action on point backward");

  m.def("act4", &act4, "action on homogeneous point");
  m.def("act4_backward", &act4_backward, "action on homogeneous point backward");

  // functions with no gradient
  m.def("as_matrix", &as_matrix, "convert to matrix");
  m.def("projector", &projector, "orthogonal projection matrix");
  m.def("Jinv", &Jinv, "left inverse jacobian operator");

};
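The bindings above expose paired forward/backward operators (expm/expm_backward, logm/logm_backward, ...) that dispatch on the tensor's device; on the Python side they are typically wrapped in torch.autograd.Function subclasses (the repo's group_ops.py plays that role). Below is a hedged sketch of such a wrapper for the exponential map only; the extension module name lietorch_backends is an assumption here, and the repo's own wrappers are the authoritative version.

# Hedged sketch of a Python autograd wrapper over the expm/expm_backward bindings.
# `lietorch_backends` is an assumed name for the extension built from lietorch.cpp.
import torch
import lietorch_backends

class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, group_id, a):
        ctx.group_id = group_id
        ctx.save_for_backward(a)
        return lietorch_backends.expm(group_id, a.contiguous())

    @staticmethod
    def backward(ctx, grad):
        a, = ctx.saved_tensors
        da, = lietorch_backends.expm_backward(ctx.group_id, grad.contiguous(), a)
        return None, da  # no gradient for the group id, tangent gradient for a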
mini_dpvo/lietorch/src/lietorch_cpu.cpp
ADDED
@@ -0,0 +1,657 @@
1 |
+
|
2 |
+
#include "lietorch_cpu.h"
|
3 |
+
#include <Eigen/Dense>
|
4 |
+
|
5 |
+
#include <iostream>
|
6 |
+
#include "common.h"
|
7 |
+
#include "dispatch.h"
|
8 |
+
|
9 |
+
#include "so3.h"
|
10 |
+
#include "rxso3.h"
|
11 |
+
#include "se3.h"
|
12 |
+
#include "sim3.h"
|
13 |
+
|
14 |
+
|
15 |
+
template <typename Group, typename scalar_t>
|
16 |
+
void exp_forward_kernel(const scalar_t* a_ptr, scalar_t* X_ptr, int batch_size) {
|
17 |
+
// exponential map forward kernel
|
18 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
19 |
+
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
|
20 |
+
|
21 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
22 |
+
for (int64_t i=start; i<end; i++) {
|
23 |
+
Tangent a(a_ptr + i*Group::K);
|
24 |
+
Eigen::Map<Data>(X_ptr + i*Group::N) = Group::Exp(a).data();
|
25 |
+
}
|
26 |
+
});
|
27 |
+
}
|
28 |
+
|
29 |
+
template <typename Group, typename scalar_t>
|
30 |
+
void exp_backward_kernel(const scalar_t* grad, const scalar_t* a_ptr, scalar_t* da, int batch_size) {
|
31 |
+
// exponential map backward kernel
|
32 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
33 |
+
using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
|
34 |
+
|
35 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
36 |
+
for (int64_t i=start; i<end; i++) {
|
37 |
+
Tangent a(a_ptr + i*Group::K);
|
38 |
+
Grad dX(grad + i*Group::N);
|
39 |
+
Eigen::Map<Grad>(da + i*Group::K) = dX * Group::left_jacobian(a);
|
40 |
+
}
|
41 |
+
});
|
42 |
+
}
|
43 |
+
|
44 |
+
template <typename Group, typename scalar_t>
|
45 |
+
void log_forward_kernel(const scalar_t* X_ptr, scalar_t* a_ptr, int batch_size) {
|
46 |
+
// logarithm map forward kernel
|
47 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
48 |
+
|
49 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
50 |
+
for (int64_t i=start; i<end; i++) {
|
51 |
+
Tangent a = Group(X_ptr + i*Group::N).Log();
|
52 |
+
Eigen::Map<Tangent>(a_ptr + i*Group::K) = a;
|
53 |
+
}
|
54 |
+
});
|
55 |
+
}
|
56 |
+
|
57 |
+
template <typename Group, typename scalar_t>
|
58 |
+
void log_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, scalar_t* dX, int batch_size) {
|
59 |
+
// logarithm map backward kernel
|
60 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
61 |
+
using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
|
62 |
+
|
63 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
64 |
+
for (int64_t i=start; i<end; i++) {
|
65 |
+
Tangent a = Group(X_ptr + i*Group::N).Log();
|
66 |
+
Grad da(grad + i*Group::K);
|
67 |
+
Eigen::Map<Grad>(dX + i*Group::N) = da * Group::left_jacobian_inverse(a);
|
68 |
+
}
|
69 |
+
});
|
70 |
+
}
|
71 |
+
|
72 |
+
template <typename Group, typename scalar_t>
|
73 |
+
void inv_forward_kernel(const scalar_t* X_ptr, scalar_t* Y_ptr, int batch_size) {
|
74 |
+
// group inverse forward kernel
|
75 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
76 |
+
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
|
77 |
+
|
78 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
79 |
+
for (int64_t i=start; i<end; i++) {
|
80 |
+
Group X(X_ptr + i*Group::N);
|
81 |
+
Eigen::Map<Data>(Y_ptr + i*Group::N) = X.inv().data();
|
82 |
+
}
|
83 |
+
});
|
84 |
+
}
|
85 |
+
|
86 |
+
template <typename Group, typename scalar_t>
|
87 |
+
void inv_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, scalar_t *dX, int batch_size) {
|
88 |
+
// group inverse backward kernel
|
89 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
90 |
+
using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
|
91 |
+
|
92 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
93 |
+
for (int64_t i=start; i<end; i++) {
|
94 |
+
Group Y = Group(X_ptr + i*Group::N).inv();
|
95 |
+
Grad dY(grad + i*Group::N);
|
96 |
+
Eigen::Map<Grad>(dX + i*Group::N) = -dY * Y.Adj();
|
97 |
+
}
|
98 |
+
});
|
99 |
+
}
|
100 |
+
|
101 |
+
template <typename Group, typename scalar_t>
|
102 |
+
void mul_forward_kernel(const scalar_t* X_ptr, const scalar_t* Y_ptr, scalar_t* Z_ptr, int batch_size) {
|
103 |
+
// group multiplication forward kernel
|
104 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
105 |
+
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
|
106 |
+
|
107 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
108 |
+
for (int64_t i=start; i<end; i++) {
|
109 |
+
Group Z = Group(X_ptr + i*Group::N) * Group(Y_ptr + i*Group::N);
|
110 |
+
Eigen::Map<Data>(Z_ptr + i*Group::N) = Z.data();
|
111 |
+
}
|
112 |
+
});
|
113 |
+
}
|
114 |
+
|
115 |
+
template <class Group, typename scalar_t>
|
116 |
+
void mul_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* Y_ptr, scalar_t* dX, scalar_t* dY, int batch_size) {
|
117 |
+
// group multiplication backward kernel
|
118 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
119 |
+
using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
|
120 |
+
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
|
121 |
+
|
122 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
123 |
+
for (int64_t i=start; i<end; i++) {
|
124 |
+
Grad dZ(grad + i*Group::N);
|
125 |
+
Group X(X_ptr + i*Group::N);
|
126 |
+
Eigen::Map<Grad>(dX + i*Group::N) = dZ;
|
127 |
+
Eigen::Map<Grad>(dY + i*Group::N) = dZ * X.Adj();
|
128 |
+
}
|
129 |
+
});
|
130 |
+
}
|
131 |
+
|
132 |
+
template <typename Group, typename scalar_t>
|
133 |
+
void adj_forward_kernel(const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* b_ptr, int batch_size) {
|
134 |
+
// adjoint forward kernel
|
135 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
136 |
+
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
|
137 |
+
|
138 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
139 |
+
for (int64_t i=start; i<end; i++) {
|
140 |
+
Group X(X_ptr + i*Group::N);
|
141 |
+
Tangent a(a_ptr + i*Group::K);
|
142 |
+
Eigen::Map<Tangent>(b_ptr + i*Group::K) = X.Adj(a);
|
143 |
+
}
|
144 |
+
});
|
145 |
+
}
|
146 |
+
|
147 |
+
template <typename Group, typename scalar_t>
|
148 |
+
void adj_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* dX, scalar_t* da, int batch_size) {
|
149 |
+
// adjoint backward kernel
|
150 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
151 |
+
using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
|
152 |
+
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
|
153 |
+
|
154 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
155 |
+
for (int64_t i=start; i<end; i++) {
|
156 |
+
Group X(X_ptr + i*Group::N);
|
157 |
+
Grad db(grad + i*Group::K);
|
158 |
+
|
159 |
+
Tangent a(a_ptr + i*Group::K);
|
160 |
+
Tangent b = X.Adj() * a;
|
161 |
+
|
162 |
+
Eigen::Map<Grad>(da + i*Group::K) = db * X.Adj();
|
163 |
+
Eigen::Map<Grad>(dX + i*Group::N) = -db * Group::adj(b);
|
164 |
+
}
|
165 |
+
});
|
166 |
+
}
|
167 |
+
|
168 |
+
template <typename Group, typename scalar_t>
|
169 |
+
void adjT_forward_kernel(const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* b_ptr, int batch_size) {
|
170 |
+
// adjoint forward kernel
|
171 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
172 |
+
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
|
173 |
+
|
174 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
175 |
+
for (int64_t i=start; i<end; i++) {
|
176 |
+
Group X(X_ptr + i*Group::N);
|
177 |
+
Tangent a(a_ptr + i*Group::K);
|
178 |
+
Eigen::Map<Tangent>(b_ptr + i*Group::K) = X.AdjT(a);
|
179 |
+
}
|
180 |
+
});
|
181 |
+
}
|
182 |
+
|
183 |
+
template <typename Group, typename scalar_t>
|
184 |
+
void adjT_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* dX, scalar_t* da, int batch_size) {
|
185 |
+
// adjoint backward kernel
|
186 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
187 |
+
using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
|
188 |
+
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
|
189 |
+
|
190 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
191 |
+
for (int64_t i=start; i<end; i++) {
|
192 |
+
Group X(X_ptr + i*Group::N);
|
193 |
+
Tangent db(grad + i*Group::K);
|
194 |
+
Grad a(a_ptr + i*Group::K);
|
195 |
+
|
196 |
+
Eigen::Map<Tangent>(da + i*Group::K) = X.Adj(db);
|
197 |
+
Eigen::Map<Grad>(dX + i*Group::N) = -a * Group::adj(X.Adj(db));
|
198 |
+
}
|
199 |
+
});
|
200 |
+
}
|
201 |
+
|
202 |
+
|
203 |
+
template <typename Group, typename scalar_t>
|
204 |
+
void act_forward_kernel(const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* q_ptr, int batch_size) {
|
205 |
+
// action on point forward kernel
|
206 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
207 |
+
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
|
208 |
+
using Point = Eigen::Matrix<scalar_t,3,1>;
|
209 |
+
|
210 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
211 |
+
for (int64_t i=start; i<end; i++) {
|
212 |
+
Group X(X_ptr + i*Group::N);
|
213 |
+
Point p(p_ptr + i*3);
|
214 |
+
Eigen::Map<Point>(q_ptr + i*3) = X * p;
|
215 |
+
}
|
216 |
+
});
|
217 |
+
}
|
218 |
+
|
219 |
+
template <typename Group, typename scalar_t>
|
220 |
+
void act_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* dX, scalar_t* dp, int batch_size) {
|
221 |
+
// adjoint backward kernel
|
222 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
223 |
+
using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
|
224 |
+
using Point = Eigen::Matrix<scalar_t,3,1>;
|
225 |
+
using PointGrad = Eigen::Matrix<scalar_t,1,3>;
|
226 |
+
using Transformation = Eigen::Matrix<scalar_t,4,4>;
|
227 |
+
|
228 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
229 |
+
for (int64_t i=start; i<end; i++) {
|
230 |
+
Group X(X_ptr + i*Group::N);
|
231 |
+
Point p(p_ptr + i*3);
|
232 |
+
PointGrad dq(grad + i*3);
|
233 |
+
|
234 |
+
Eigen::Map<PointGrad>(dp + i*3) = dq * X.Matrix().template block<3,3>(0,0);
|
235 |
+
Eigen::Map<Grad>(dX + i*Group::N) = dq * Group::act_jacobian(X*p);
|
236 |
+
}
|
237 |
+
});
|
238 |
+
}
|
239 |
+
|
240 |
+
|
241 |
+
// template <typename Group, typename scalar_t>
|
242 |
+
// void tovec_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, scalar_t* dX, int batch_size) {
|
243 |
+
// // group inverse forward kernel
|
244 |
+
// using Data = Eigen::Matrix<scalar_t,Group::N,1>;
|
245 |
+
// using Grad = Eigen::Matrix<scalar_t,1,Group::N>;
|
246 |
+
|
247 |
+
// at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
248 |
+
// for (int64_t i=start; i<end; i++) {
|
249 |
+
// Group X(X_ptr + i*Group::N);
|
250 |
+
// Grad g(grad + i*Group::N);
|
251 |
+
// Eigen::Map<Grad>(dX + i*Group::N) = g * X.vec_jacobian();
|
252 |
+
// }
|
253 |
+
// });
|
254 |
+
// }
|
255 |
+
|
256 |
+
// template <typename Group, typename scalar_t>
|
257 |
+
// void fromvec_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, scalar_t* dX, int batch_size) {
|
258 |
+
// // group inverse forward kernel
|
259 |
+
// using Data = Eigen::Matrix<scalar_t,Group::N,1>;
|
260 |
+
// using Grad = Eigen::Matrix<scalar_t,1,Group::N>;
|
261 |
+
|
262 |
+
// at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
263 |
+
// for (int64_t i=start; i<end; i++) {
|
264 |
+
// Group X(X_ptr + i*Group::N);
|
265 |
+
// Grad g(grad + i*Group::N);
|
266 |
+
// Eigen::Map<Grad>(dX + i*Group::N) = g * X.vec_jacobian();
|
267 |
+
// }
|
268 |
+
// });
|
269 |
+
// }
|
270 |
+
|
271 |
+
|
272 |
+
template <typename Group, typename scalar_t>
|
273 |
+
void act4_forward_kernel(const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* q_ptr, int batch_size) {
|
274 |
+
// action on homogeneous point forward kernel
|
275 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
276 |
+
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
|
277 |
+
using Point = Eigen::Matrix<scalar_t,4,1>;
|
278 |
+
|
279 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
280 |
+
for (int64_t i=start; i<end; i++) {
|
281 |
+
Group X(X_ptr + i*Group::N);
|
282 |
+
Point p(p_ptr + i*4);
|
283 |
+
Eigen::Map<Point>(q_ptr + i*4) = X.act4(p);
|
284 |
+
}
|
285 |
+
});
|
286 |
+
}
|
287 |
+
|
288 |
+
template <typename Group, typename scalar_t>
|
289 |
+
void act4_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* dX, scalar_t* dp, int batch_size) {
|
290 |
+
// action on homogeneous point backward kernel
|
291 |
+
|
292 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
293 |
+
using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
|
294 |
+
using Point = Eigen::Matrix<scalar_t,4,1>;
|
295 |
+
using PointGrad = Eigen::Matrix<scalar_t,1,4>;
|
296 |
+
using Transformation = Eigen::Matrix<scalar_t,4,4>;
|
297 |
+
|
298 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
299 |
+
for (int64_t i=start; i<end; i++) {
|
300 |
+
Group X(X_ptr + i*Group::N);
|
301 |
+
Point p(p_ptr + i*4);
|
302 |
+
PointGrad dq(grad + i*4);
|
303 |
+
|
304 |
+
Eigen::Map<PointGrad>(dp + i*4) = dq * X.Matrix4x4();
|
305 |
+
const Point q = X.act4(p);
|
306 |
+
Eigen::Map<Grad>(dX + i*Group::N) = dq * Group::act4_jacobian(q);
|
307 |
+
}
|
308 |
+
});
|
309 |
+
}
|
310 |
+
|
311 |
+
template <typename Group, typename scalar_t>
|
312 |
+
void as_matrix_forward_kernel(const scalar_t* X_ptr, scalar_t* T_ptr, int batch_size) {
|
313 |
+
// group inverse forward kernel
|
314 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
315 |
+
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
|
316 |
+
using Matrix4 = Eigen::Matrix<scalar_t,4,4,Eigen::RowMajor>;
|
317 |
+
|
318 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
319 |
+
for (int64_t i=start; i<end; i++) {
|
320 |
+
Group X(X_ptr + i*Group::N);
|
321 |
+
Eigen::Map<Matrix4>(T_ptr + i*16) = X.Matrix4x4();
|
322 |
+
}
|
323 |
+
});
|
324 |
+
}
|
325 |
+
|
326 |
+
template <typename Group, typename scalar_t>
|
327 |
+
void orthogonal_projector_kernel(const scalar_t* X_ptr, scalar_t* P_ptr, int batch_size) {
|
328 |
+
// group inverse forward kernel
|
329 |
+
using Proj = Eigen::Matrix<scalar_t,Group::N,Group::N,Eigen::RowMajor>;
|
330 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
331 |
+
for (int64_t i=start; i<end; i++) {
|
332 |
+
Group X(X_ptr + i*Group::N);
|
333 |
+
Eigen::Map<Proj>(P_ptr + i*Group::N*Group::N) = X.orthogonal_projector();
|
334 |
+
}
|
335 |
+
});
|
336 |
+
}
|
337 |
+
|
338 |
+
template <typename Group, typename scalar_t>
|
339 |
+
void jleft_forward_kernel(const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* b_ptr, int batch_size) {
|
340 |
+
// left-jacobian inverse action
|
341 |
+
using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
|
342 |
+
using Data = Eigen::Matrix<scalar_t,Group::N,1>;
|
343 |
+
|
344 |
+
at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
|
345 |
+
for (int64_t i=start; i<end; i++) {
|
346 |
+
Group X(X_ptr + i*Group::N);
|
347 |
+
Tangent a(a_ptr + i*Group::K);
|
348 |
+
Tangent b = Group::left_jacobian_inverse(X.Log()) * a;
|
349 |
+
Eigen::Map<Tangent>(b_ptr + i*Group::K) = b;
|
350 |
+
}
|
351 |
+
});
|
352 |
+
}
|
353 |
+
|
354 |
+
// unary operations
|
355 |
+
|
356 |
+
torch::Tensor exp_forward_cpu(int group_id, torch::Tensor a) {
|
357 |
+
int batch_size = a.size(0);
|
358 |
+
torch::Tensor X;
|
359 |
+
|
360 |
+
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.type(), "exp_forward_kernel", ([&] {
|
361 |
+
X = torch::zeros({batch_size, group_t::N}, a.options());
|
362 |
+
exp_forward_kernel<group_t, scalar_t>(
|
363 |
+
a.data_ptr<scalar_t>(),
|
364 |
+
X.data_ptr<scalar_t>(),
|
365 |
+
batch_size);
|
366 |
+
}));
|
367 |
+
|
368 |
+
return X;
|
369 |
+
}
|
370 |
+
|
371 |
+
std::vector<torch::Tensor> exp_backward_cpu(int group_id, torch::Tensor grad, torch::Tensor a) {
|
372 |
+
int batch_size = a.size(0);
|
373 |
+
torch::Tensor da = torch::zeros(a.sizes(), grad.options());
|
374 |
+
|
375 |
+
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.type(), "exp_backward_kernel", ([&] {
|
376 |
+
exp_backward_kernel<group_t, scalar_t>(
|
377 |
+
grad.data_ptr<scalar_t>(),
|
378 |
+
a.data_ptr<scalar_t>(),
|
379 |
+
da.data_ptr<scalar_t>(),
|
380 |
+
batch_size);
|
381 |
+
}));
|
382 |
+
|
383 |
+
return {da};
|
384 |
+
}
|
385 |
+
|
386 |
+
torch::Tensor log_forward_cpu(int group_id, torch::Tensor X) {
|
387 |
+
int batch_size = X.size(0);
|
388 |
+
torch::Tensor a;
|
389 |
+
|
390 |
+
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "log_forward_kernel", ([&] {
|
391 |
+
a = torch::zeros({batch_size, group_t::K}, X.options());
|
392 |
+
log_forward_kernel<group_t, scalar_t>(
|
393 |
+
X.data_ptr<scalar_t>(),
|
394 |
+
a.data_ptr<scalar_t>(),
|
395 |
+
batch_size);
|
396 |
+
}));
|
397 |
+
|
398 |
+
return a;
|
399 |
+
}
|
400 |
+
|
401 |
+
std::vector<torch::Tensor> log_backward_cpu(int group_id, torch::Tensor grad, torch::Tensor X) {
|
402 |
+
int batch_size = X.size(0);
|
403 |
+
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
|
404 |
+
|
405 |
+
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "log_backward_kernel", ([&] {
|
406 |
+
log_backward_kernel<group_t, scalar_t>(
|
407 |
+
grad.data_ptr<scalar_t>(),
|
408 |
+
X.data_ptr<scalar_t>(),
|
409 |
+
dX.data_ptr<scalar_t>(),
|
410 |
+
batch_size);
|
411 |
+
}));
|
412 |
+
|
413 |
+
return {dX};
|
414 |
+
}
|
415 |
+
|
416 |
+
torch::Tensor inv_forward_cpu(int group_id, torch::Tensor X) {
|
417 |
+
int batch_size = X.size(0);
|
418 |
+
torch::Tensor Y = torch::zeros_like(X);
|
419 |
+
|
420 |
+
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "inv_forward_kernel", ([&] {
|
421 |
+
inv_forward_kernel<group_t, scalar_t>(
|
422 |
+
X.data_ptr<scalar_t>(),
|
423 |
+
Y.data_ptr<scalar_t>(),
|
424 |
+
batch_size);
|
425 |
+
}));
|
426 |
+
|
427 |
+
return Y;
|
428 |
+
}
|
429 |
+
|
430 |
+
std::vector<torch::Tensor> inv_backward_cpu(int group_id, torch::Tensor grad, torch::Tensor X) {
|
431 |
+
int batch_size = X.size(0);
|
432 |
+
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
|
433 |
+
|
434 |
+
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "inv_backward_kernel", ([&] {
|
435 |
+
inv_backward_kernel<group_t, scalar_t>(
|
436 |
+
grad.data_ptr<scalar_t>(),
|
437 |
+
X.data_ptr<scalar_t>(),
|
438 |
+
dX.data_ptr<scalar_t>(),
|
439 |
+
batch_size);
|
440 |
+
}));
|
441 |
+
|
442 |
+
return {dX};
|
443 |
+
}
|
444 |
+
|
445 |
+
// binary operations
|
446 |
+
torch::Tensor mul_forward_cpu(int group_id, torch::Tensor X, torch::Tensor Y) {
|
447 |
+
int batch_size = X.size(0);
|
448 |
+
torch::Tensor Z = torch::zeros_like(X);
|
449 |
+
|
450 |
+
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "mul_forward_kernel", ([&] {
|
451 |
+
mul_forward_kernel<group_t, scalar_t>(
|
452 |
+
X.data_ptr<scalar_t>(),
|
453 |
+
Y.data_ptr<scalar_t>(),
|
454 |
+
Z.data_ptr<scalar_t>(),
|
455 |
+
batch_size);
|
456 |
+
}));
|
457 |
+
|
458 |
+
return Z;
|
459 |
+
}
|
460 |
+
|
461 |
+
std::vector<torch::Tensor> mul_backward_cpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor Y) {
|
462 |
+
int batch_size = X.size(0);
|
463 |
+
torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
|
464 |
+
torch::Tensor dY = torch::zeros(Y.sizes(), grad.options());
|
465 |
+
|
466 |
+
DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "mul_backward_kernel", ([&] {
|
467 |
+
mul_backward_kernel<group_t, scalar_t>(
|
468 |
+
grad.data_ptr<scalar_t>(),
|
469 |
+
X.data_ptr<scalar_t>(),
|
470 |
+
Y.data_ptr<scalar_t>(),
|
471 |
+
dX.data_ptr<scalar_t>(),
|
472 |
+
dY.data_ptr<scalar_t>(),
|
473 |
+
batch_size);
|
474 |
+
}));
|
475 |
+
|
476 |
+
return {dX, dY};
|
477 |
+
}
|
478 |
+
|
479 |
+
torch::Tensor adj_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a) {
  int batch_size = X.size(0);
  torch::Tensor b = torch::zeros(a.sizes(), a.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adj_forward_kernel", ([&] {
    adj_forward_kernel<group_t, scalar_t>(
      X.data_ptr<scalar_t>(),
      a.data_ptr<scalar_t>(),
      b.data_ptr<scalar_t>(),
      batch_size);
  }));

  return b;
}

std::vector<torch::Tensor> adj_backward_cpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor a) {
  int batch_size = X.size(0);
  torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
  torch::Tensor da = torch::zeros(a.sizes(), grad.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adj_backward_kernel", ([&] {
    adj_backward_kernel<group_t, scalar_t>(
      grad.data_ptr<scalar_t>(),
      X.data_ptr<scalar_t>(),
      a.data_ptr<scalar_t>(),
      dX.data_ptr<scalar_t>(),
      da.data_ptr<scalar_t>(),
      batch_size);
  }));

  return {dX, da};
}


torch::Tensor adjT_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a) {
  int batch_size = X.size(0);
  torch::Tensor b = torch::zeros(a.sizes(), a.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adjT_forward_kernel", ([&] {
    adjT_forward_kernel<group_t, scalar_t>(
      X.data_ptr<scalar_t>(),
      a.data_ptr<scalar_t>(),
      b.data_ptr<scalar_t>(),
      batch_size);
  }));

  return b;
}

std::vector<torch::Tensor> adjT_backward_cpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor a) {
  int batch_size = X.size(0);
  torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
  torch::Tensor da = torch::zeros(a.sizes(), grad.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adjT_backward_kernel", ([&] {
    adjT_backward_kernel<group_t, scalar_t>(
      grad.data_ptr<scalar_t>(),
      X.data_ptr<scalar_t>(),
      a.data_ptr<scalar_t>(),
      dX.data_ptr<scalar_t>(),
      da.data_ptr<scalar_t>(),
      batch_size);
  }));

  return {dX, da};
}


torch::Tensor act_forward_cpu(int group_id, torch::Tensor X, torch::Tensor p) {
  int batch_size = X.size(0);
  torch::Tensor q = torch::zeros(p.sizes(), p.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act_forward_kernel", ([&] {
    act_forward_kernel<group_t, scalar_t>(
      X.data_ptr<scalar_t>(),
      p.data_ptr<scalar_t>(),
      q.data_ptr<scalar_t>(),
      batch_size);
  }));

  return q;
}

std::vector<torch::Tensor> act_backward_cpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor p) {
  int batch_size = X.size(0);
  torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
  torch::Tensor dp = torch::zeros(p.sizes(), grad.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act_backward_kernel", ([&] {
    act_backward_kernel<group_t, scalar_t>(
      grad.data_ptr<scalar_t>(),
      X.data_ptr<scalar_t>(),
      p.data_ptr<scalar_t>(),
      dX.data_ptr<scalar_t>(),
      dp.data_ptr<scalar_t>(),
      batch_size);
  }));

  return {dX, dp};
}


torch::Tensor act4_forward_cpu(int group_id, torch::Tensor X, torch::Tensor p) {
  int batch_size = X.size(0);
  torch::Tensor q = torch::zeros(p.sizes(), p.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act4_forward_kernel", ([&] {
    act4_forward_kernel<group_t, scalar_t>(
      X.data_ptr<scalar_t>(),
      p.data_ptr<scalar_t>(),
      q.data_ptr<scalar_t>(),
      batch_size);
  }));

  return q;
}

std::vector<torch::Tensor> act4_backward_cpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor p) {
  int batch_size = X.size(0);
  torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
  torch::Tensor dp = torch::zeros(p.sizes(), grad.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act4_backward_kernel", ([&] {
    act4_backward_kernel<group_t, scalar_t>(
      grad.data_ptr<scalar_t>(),
      X.data_ptr<scalar_t>(),
      p.data_ptr<scalar_t>(),
      dX.data_ptr<scalar_t>(),
      dp.data_ptr<scalar_t>(),
      batch_size);
  }));

  return {dX, dp};
}


torch::Tensor as_matrix_forward_cpu(int group_id, torch::Tensor X) {
  int batch_size = X.size(0);
  torch::Tensor T4x4 = torch::zeros({X.size(0), 4, 4}, X.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "as_matrix_forward_kernel", ([&] {
    as_matrix_forward_kernel<group_t, scalar_t>(
      X.data_ptr<scalar_t>(),
      T4x4.data_ptr<scalar_t>(),
      batch_size);
  }));

  return T4x4;
}


torch::Tensor orthogonal_projector_cpu(int group_id, torch::Tensor X) {
  int batch_size = X.size(0);
  torch::Tensor P;

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "orthogonal_projector_kernel", ([&] {
    P = torch::zeros({X.size(0), group_t::N, group_t::N}, X.options());
    orthogonal_projector_kernel<group_t, scalar_t>(X.data_ptr<scalar_t>(), P.data_ptr<scalar_t>(), batch_size);
  }));

  return P;
}


torch::Tensor jleft_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a) {
  int batch_size = X.size(0);
  torch::Tensor b = torch::zeros(a.sizes(), a.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "jleft_forward_kernel", ([&] {
    jleft_forward_kernel<group_t, scalar_t>(
      X.data_ptr<scalar_t>(),
      a.data_ptr<scalar_t>(),
      b.data_ptr<scalar_t>(),
      batch_size);
  }));

  return b;
}
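Each CPU wrapper above follows the same pattern: allocate the output, let DISPATCH_GROUP_AND_FLOATING_TYPES bind group_t and scalar_t, and pass raw pointers to a templated kernel. As a minimal sketch (not part of the commit, assuming numpy and scipy are available), the identity implemented by the adj_* ops can be checked numerically for SO(3), where Adj(X) is simply the rotation matrix:

# Numerical check of Exp(Adj_X(a)) == X * Exp(a) * X^{-1}, SO(3) only.
import numpy as np
from scipy.spatial.transform import Rotation as R

a = np.array([0.05, -0.02, 0.08])           # tangent vector (axis-angle)
X = R.from_euler("xyz", [0.3, -0.1, 0.7])   # group element

lhs = R.from_rotvec(X.as_matrix() @ a)      # Exp(Adj_X(a)); for SO(3), Adj(X) acts as the rotation matrix
rhs = X * R.from_rotvec(a) * X.inv()        # X * Exp(a) * X^{-1}

print(np.allclose(lhs.as_matrix(), rhs.as_matrix()))  # True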
mini_dpvo/lietorch/src/lietorch_gpu.cu
ADDED
@@ -0,0 +1,601 @@
#include "lietorch_gpu.h"
#include <Eigen/Dense>

#include "common.h"
#include "dispatch.h"

#include "so3.h"
#include "rxso3.h"
#include "se3.h"
#include "sim3.h"

#define GPU_1D_KERNEL_LOOP(i, n) \
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i<n; i += blockDim.x * gridDim.x)

#define NUM_THREADS 256
#define NUM_BLOCKS(batch_size) ((batch_size + NUM_THREADS - 1) / NUM_THREADS)


template <typename Group, typename scalar_t>
__global__ void exp_forward_kernel(const scalar_t* a_ptr, scalar_t* X_ptr, int num_threads) {
  // exponential map forward kernel
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Data = Eigen::Matrix<scalar_t,Group::N,1>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Tangent a(a_ptr + i*Group::K);
    Eigen::Map<Data>(X_ptr + i*Group::N) = Group::Exp(a).data();
  }
}

template <typename Group, typename scalar_t>
__global__ void exp_backward_kernel(const scalar_t* grad, const scalar_t* a_ptr, scalar_t* da, int num_threads) {
  // exponential map backward kernel
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
  using Data = Eigen::Matrix<scalar_t,Group::N,1>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Tangent a(a_ptr + i*Group::K);
    Grad dX(grad + i*Group::N);
    Eigen::Map<Grad>(da + i*Group::K) = dX * Group::left_jacobian(a);
  }
}

template <typename Group, typename scalar_t>
__global__ void log_forward_kernel(const scalar_t* X_ptr, scalar_t* a_ptr, int num_threads) {
  // logarithm map forward kernel
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Data = Eigen::Matrix<scalar_t,Group::N,1>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Tangent a = Group(X_ptr + i*Group::N).Log();
    Eigen::Map<Tangent>(a_ptr + i*Group::K) = a;
  }
}

template <typename Group, typename scalar_t>
__global__ void log_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, scalar_t* dX, int num_threads) {
  // logarithm map backward kernel
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
  using Data = Eigen::Matrix<scalar_t,Group::N,1>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Tangent a = Group(X_ptr + i*Group::N).Log();
    Grad da(grad + i*Group::K);
    Eigen::Map<Grad>(dX + i*Group::N) = da * Group::left_jacobian_inverse(a);
  }
}

template <typename Group, typename scalar_t>
__global__ void inv_forward_kernel(const scalar_t* X_ptr, scalar_t* Y_ptr, int num_threads) {
  // group inverse forward kernel
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Data = Eigen::Matrix<scalar_t,Group::N,1>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Group X(X_ptr + i*Group::N);
    Eigen::Map<Data>(Y_ptr + i*Group::N) = X.inv().data();
  }
}


template <typename Group, typename scalar_t>
__global__ void inv_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, scalar_t *dX, int num_threads) {
  // group inverse backward kernel
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
  using Data = Eigen::Matrix<scalar_t,Group::N,1>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Group Y = Group(X_ptr + i*Group::N).inv();
    Grad dY(grad + i*Group::N);
    Eigen::Map<Grad>(dX + i*Group::N) = -dY * Y.Adj();
  }
}


template <typename Group, typename scalar_t>
__global__ void mul_forward_kernel(const scalar_t* X_ptr, const scalar_t* Y_ptr, scalar_t* Z_ptr, int num_threads) {
  // group multiplication forward kernel
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Data = Eigen::Matrix<scalar_t,Group::N,1>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Group Z = Group(X_ptr + i*Group::N) * Group(Y_ptr + i*Group::N);
    Eigen::Map<Data>(Z_ptr + i*Group::N) = Z.data();
  }
}

template <class Group, typename scalar_t>
__global__ void mul_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* Y_ptr, scalar_t* dX, scalar_t* dY, int num_threads) {
  // group multiplication backward kernel
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
  using Data = Eigen::Matrix<scalar_t,Group::N,1>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Grad dZ(grad + i*Group::N);
    Group X(X_ptr + i*Group::N);
    Eigen::Map<Grad>(dX + i*Group::N) = dZ;
    Eigen::Map<Grad>(dY + i*Group::N) = dZ * X.Adj();
  }
}

template <typename Group, typename scalar_t>
__global__ void adj_forward_kernel(const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* b_ptr, int num_threads) {
  // adjoint forward kernel
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Data = Eigen::Matrix<scalar_t,Group::N,1>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Group X(X_ptr + i*Group::N);
    Tangent a(a_ptr + i*Group::K);
    Eigen::Map<Tangent>(b_ptr + i*Group::K) = X.Adj(a);
  }
}

template <typename Group, typename scalar_t>
__global__ void adj_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* dX, scalar_t* da, int num_threads) {
  // adjoint backward kernel
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
  using Data = Eigen::Matrix<scalar_t,Group::N,1>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Group X(X_ptr + i*Group::N);
    Grad db(grad + i*Group::K);

    Tangent a(a_ptr + i*Group::K);
    Tangent b = X.Adj() * a;

    Eigen::Map<Grad>(da + i*Group::K) = db * X.Adj();
    Eigen::Map<Grad>(dX + i*Group::N) = -db * Group::adj(b);
  }
}


template <typename Group, typename scalar_t>
__global__ void adjT_forward_kernel(const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* b_ptr, int num_threads) {
  // adjoint forward kernel
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Data = Eigen::Matrix<scalar_t,Group::N,1>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Group X(X_ptr + i*Group::N);
    Tangent a(a_ptr + i*Group::K);
    Eigen::Map<Tangent>(b_ptr + i*Group::K) = X.AdjT(a);
  }
}

template <typename Group, typename scalar_t>
__global__ void adjT_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* dX, scalar_t* da, int num_threads) {
  // adjoint backward kernel
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
  using Data = Eigen::Matrix<scalar_t,Group::N,1>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Group X(X_ptr + i*Group::N);
    Tangent db(grad + i*Group::K);
    Grad a(a_ptr + i*Group::K);

    Eigen::Map<Tangent>(da + i*Group::K) = X.Adj(db);
    Eigen::Map<Grad>(dX + i*Group::N) = -a * Group::adj(X.Adj(db));
  }
}

template <typename Group, typename scalar_t>
__global__ void act_forward_kernel(const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* q_ptr, int num_threads) {
  // action on point forward kernel
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Data = Eigen::Matrix<scalar_t,Group::N,1>;
  using Point = Eigen::Matrix<scalar_t,3,1>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Group X(X_ptr + i*Group::N);
    Point p(p_ptr + i*3);
    Eigen::Map<Point>(q_ptr + i*3) = X * p;
  }
}

template <typename Group, typename scalar_t>
__global__ void act_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* dX, scalar_t* dp, int num_threads) {
  // adjoint backward kernel
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
  using Point = Eigen::Matrix<scalar_t,3,1>;
  using PointGrad = Eigen::Matrix<scalar_t,1,3>;
  using Transformation = Eigen::Matrix<scalar_t,4,4>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Group X(X_ptr + i*Group::N);
    Point p(p_ptr + i*3);
    PointGrad dq(grad + i*3);

    Eigen::Map<PointGrad>(dp + i*3) = dq * X.Matrix4x4().block<3,3>(0,0);
    Eigen::Map<Grad>(dX + i*Group::N) = dq * Group::act_jacobian(X*p);
  }
}


template <typename Group, typename scalar_t>
__global__ void act4_forward_kernel(const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* q_ptr, int num_threads) {
  // action on point forward kernel
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Data = Eigen::Matrix<scalar_t,Group::N,1>;
  using Point = Eigen::Matrix<scalar_t,4,1>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Group X(X_ptr + i*Group::N);
    Point p(p_ptr + i*4);
    Eigen::Map<Point>(q_ptr + i*4) = X.act4(p);
  }
}

template <typename Group, typename scalar_t>
__global__ void act4_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* dX, scalar_t* dp, int num_threads) {
  // adjoint backward kernel
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
  using Point = Eigen::Matrix<scalar_t,4,1>;
  using PointGrad = Eigen::Matrix<scalar_t,1,4>;
  using Transformation = Eigen::Matrix<scalar_t,4,4>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Group X(X_ptr + i*Group::N);
    Point p(p_ptr + i*4);
    PointGrad dq(grad + i*4);

    Eigen::Map<PointGrad>(dp + i*4) = dq * X.Matrix4x4();
    const Point q = X.act4(p);
    Eigen::Map<Grad>(dX + i*Group::N) = dq * Group::act4_jacobian(q);
  }
}

template <typename Group, typename scalar_t>
__global__ void as_matrix_forward_kernel(const scalar_t* X_ptr, scalar_t* T_ptr, int num_threads) {
  // convert to 4x4 matrix representation
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Data = Eigen::Matrix<scalar_t,Group::N,1>;
  using Matrix4 = Eigen::Matrix<scalar_t,4,4,Eigen::RowMajor>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Group X(X_ptr + i*Group::N);
    Eigen::Map<Matrix4>(T_ptr + i*16) = X.Matrix4x4();
  }
}

template <typename Group, typename scalar_t>
__global__ void orthogonal_projector_kernel(const scalar_t* X_ptr, scalar_t* P_ptr, int num_threads) {
  // orthogonal projection matrix
  using Proj = Eigen::Matrix<scalar_t,Group::N,Group::N,Eigen::RowMajor>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Group X(X_ptr + i*Group::N);
    Eigen::Map<Proj>(P_ptr + i*Group::N*Group::N) = X.orthogonal_projector();
  }
}

template <typename Group, typename scalar_t>
__global__ void jleft_forward_kernel(const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* b_ptr, int num_threads) {
  // left jacobian inverse action
  using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
  using Data = Eigen::Matrix<scalar_t,Group::N,1>;

  GPU_1D_KERNEL_LOOP(i, num_threads) {
    Group X(X_ptr + i*Group::N);
    Tangent a(a_ptr + i*Group::K);
    Tangent b = Group::left_jacobian_inverse(X.Log()) * a;
    Eigen::Map<Tangent>(b_ptr + i*Group::K) = b;
  }
}

// unary operations

torch::Tensor exp_forward_gpu(int group_id, torch::Tensor a) {
  int batch_size = a.size(0);
  torch::Tensor X;

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.type(), "exp_forward_kernel", ([&] {
    X = torch::zeros({batch_size, group_t::N}, a.options());
    exp_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      a.data_ptr<scalar_t>(),
      X.data_ptr<scalar_t>(),
      batch_size);
  }));

  return X;
}

std::vector<torch::Tensor> exp_backward_gpu(int group_id, torch::Tensor grad, torch::Tensor a) {
  int batch_size = a.size(0);
  torch::Tensor da = torch::zeros(a.sizes(), grad.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.type(), "exp_backward_kernel", ([&] {
    exp_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      grad.data_ptr<scalar_t>(),
      a.data_ptr<scalar_t>(),
      da.data_ptr<scalar_t>(),
      batch_size);
  }));

  return {da};
}

torch::Tensor log_forward_gpu(int group_id, torch::Tensor X) {
  int batch_size = X.size(0);
  torch::Tensor a;

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "log_forward_kernel", ([&] {
    a = torch::zeros({batch_size, group_t::K}, X.options());
    log_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      X.data_ptr<scalar_t>(),
      a.data_ptr<scalar_t>(),
      batch_size);
  }));

  return a;
}

std::vector<torch::Tensor> log_backward_gpu(int group_id, torch::Tensor grad, torch::Tensor X) {
  int batch_size = X.size(0);
  torch::Tensor dX = torch::zeros(X.sizes(), grad.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "log_backward_kernel", ([&] {
    log_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      grad.data_ptr<scalar_t>(),
      X.data_ptr<scalar_t>(),
      dX.data_ptr<scalar_t>(),
      batch_size);
  }));

  return {dX};
}

torch::Tensor inv_forward_gpu(int group_id, torch::Tensor X) {
  int batch_size = X.size(0);
  torch::Tensor Y = torch::zeros_like(X);

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "inv_forward_kernel", ([&] {
    inv_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      X.data_ptr<scalar_t>(),
      Y.data_ptr<scalar_t>(),
      batch_size);
  }));

  return Y;
}

std::vector<torch::Tensor> inv_backward_gpu(int group_id, torch::Tensor grad, torch::Tensor X) {
  int batch_size = X.size(0);
  torch::Tensor dX = torch::zeros(X.sizes(), grad.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "inv_backward_kernel", ([&] {
    inv_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      grad.data_ptr<scalar_t>(),
      X.data_ptr<scalar_t>(),
      dX.data_ptr<scalar_t>(),
      batch_size);
  }));

  return {dX};
}

// binary operations
torch::Tensor mul_forward_gpu(int group_id, torch::Tensor X, torch::Tensor Y) {
  int batch_size = X.size(0);
  torch::Tensor Z = torch::zeros_like(X);

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "mul_forward_kernel", ([&] {
    mul_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      X.data_ptr<scalar_t>(),
      Y.data_ptr<scalar_t>(),
      Z.data_ptr<scalar_t>(),
      batch_size);
  }));

  return Z;
}

std::vector<torch::Tensor> mul_backward_gpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor Y) {
  int batch_size = X.size(0);
  torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
  torch::Tensor dY = torch::zeros(Y.sizes(), grad.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "mul_backward_kernel", ([&] {
    mul_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      grad.data_ptr<scalar_t>(),
      X.data_ptr<scalar_t>(),
      Y.data_ptr<scalar_t>(),
      dX.data_ptr<scalar_t>(),
      dY.data_ptr<scalar_t>(),
      batch_size);
  }));

  return {dX, dY};
}

torch::Tensor adj_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a) {
  int batch_size = X.size(0);
  torch::Tensor b = torch::zeros(a.sizes(), a.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adj_forward_kernel", ([&] {
    adj_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      X.data_ptr<scalar_t>(),
      a.data_ptr<scalar_t>(),
      b.data_ptr<scalar_t>(),
      batch_size);
  }));

  return b;
}

std::vector<torch::Tensor> adj_backward_gpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor a) {
  int batch_size = X.size(0);
  torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
  torch::Tensor da = torch::zeros(a.sizes(), grad.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adj_backward_kernel", ([&] {
    adj_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      grad.data_ptr<scalar_t>(),
      X.data_ptr<scalar_t>(),
      a.data_ptr<scalar_t>(),
      dX.data_ptr<scalar_t>(),
      da.data_ptr<scalar_t>(),
      batch_size);
  }));

  return {dX, da};
}


torch::Tensor adjT_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a) {
  int batch_size = X.size(0);
  torch::Tensor b = torch::zeros(a.sizes(), a.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adjT_forward_kernel", ([&] {
    adjT_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      X.data_ptr<scalar_t>(),
      a.data_ptr<scalar_t>(),
      b.data_ptr<scalar_t>(),
      batch_size);
  }));

  return b;
}

std::vector<torch::Tensor> adjT_backward_gpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor a) {
  int batch_size = X.size(0);
  torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
  torch::Tensor da = torch::zeros(a.sizes(), grad.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adjT_backward_kernel", ([&] {
    adjT_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      grad.data_ptr<scalar_t>(),
      X.data_ptr<scalar_t>(),
      a.data_ptr<scalar_t>(),
      dX.data_ptr<scalar_t>(),
      da.data_ptr<scalar_t>(),
      batch_size);
  }));

  return {dX, da};
}


torch::Tensor act_forward_gpu(int group_id, torch::Tensor X, torch::Tensor p) {
  int batch_size = X.size(0);
  torch::Tensor q = torch::zeros(p.sizes(), p.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act_forward_kernel", ([&] {
    act_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      X.data_ptr<scalar_t>(),
      p.data_ptr<scalar_t>(),
      q.data_ptr<scalar_t>(),
      batch_size);
  }));

  return q;
}

std::vector<torch::Tensor> act_backward_gpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor p) {
  int batch_size = X.size(0);
  torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
  torch::Tensor dp = torch::zeros(p.sizes(), grad.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act_backward_kernel", ([&] {
    act_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      grad.data_ptr<scalar_t>(),
      X.data_ptr<scalar_t>(),
      p.data_ptr<scalar_t>(),
      dX.data_ptr<scalar_t>(),
      dp.data_ptr<scalar_t>(),
      batch_size);
  }));

  return {dX, dp};
}

torch::Tensor act4_forward_gpu(int group_id, torch::Tensor X, torch::Tensor p) {
  int batch_size = X.size(0);
  torch::Tensor q = torch::zeros(p.sizes(), p.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act4_forward_kernel", ([&] {
    act4_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      X.data_ptr<scalar_t>(),
      p.data_ptr<scalar_t>(),
      q.data_ptr<scalar_t>(),
      batch_size);
  }));

  return q;
}

std::vector<torch::Tensor> act4_backward_gpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor p) {
  int batch_size = X.size(0);
  torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
  torch::Tensor dp = torch::zeros(p.sizes(), grad.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act4_backward_kernel", ([&] {
    act4_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      grad.data_ptr<scalar_t>(),
      X.data_ptr<scalar_t>(),
      p.data_ptr<scalar_t>(),
      dX.data_ptr<scalar_t>(),
      dp.data_ptr<scalar_t>(),
      batch_size);
  }));

  return {dX, dp};
}


torch::Tensor as_matrix_forward_gpu(int group_id, torch::Tensor X) {
  int batch_size = X.size(0);
  torch::Tensor T4x4 = torch::zeros({X.size(0), 4, 4}, X.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "as_matrix_forward_kernel", ([&] {
    as_matrix_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      X.data_ptr<scalar_t>(),
      T4x4.data_ptr<scalar_t>(),
      batch_size);
  }));

  return T4x4;
}


torch::Tensor orthogonal_projector_gpu(int group_id, torch::Tensor X) {
  int batch_size = X.size(0);
  torch::Tensor P;

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "orthogonal_projector_kernel", ([&] {
    P = torch::zeros({X.size(0), group_t::N, group_t::N}, X.options());
    orthogonal_projector_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      X.data_ptr<scalar_t>(),
      P.data_ptr<scalar_t>(),
      batch_size);
  }));

  return P;
}


torch::Tensor jleft_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a) {
  int batch_size = X.size(0);
  torch::Tensor b = torch::zeros(a.sizes(), a.options());

  DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "jleft_forward_kernel", ([&] {
    jleft_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
      X.data_ptr<scalar_t>(),
      a.data_ptr<scalar_t>(),
      b.data_ptr<scalar_t>(),
      batch_size);
  }));

  return b;
}
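Every *_gpu wrapper above launches NUM_BLOCKS(batch_size) blocks of NUM_THREADS threads and relies on the GPU_1D_KERNEL_LOOP grid-stride loop so a fixed grid covers any batch size. A pure-Python sketch of that launch arithmetic (not part of the commit, just illustrating the index pattern):

# Mirrors the NUM_BLOCKS macro and GPU_1D_KERNEL_LOOP indexing on the host side.
NUM_THREADS = 256

def num_blocks(batch_size: int) -> int:
    # ceil(batch_size / NUM_THREADS), same as the C macro
    return (batch_size + NUM_THREADS - 1) // NUM_THREADS

def elements_visited_by(thread_id: int, block_id: int, batch_size: int) -> list:
    # which batch elements a given (thread, block) processes under the grid-stride loop
    grid_dim = num_blocks(batch_size)
    start = block_id * NUM_THREADS + thread_id
    return list(range(start, batch_size, NUM_THREADS * grid_dim))

print(num_blocks(1000))                 # 4 blocks of 256 threads
print(elements_visited_by(0, 0, 1000))  # [0] -- each element is handled exactly once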
mini_dpvo/logger.py
ADDED
@@ -0,0 +1,58 @@
import torch
from torch.utils.tensorboard import SummaryWriter


SUM_FREQ = 100

class Logger:
    def __init__(self, name, scheduler):
        self.total_steps = 0
        self.running_loss = {}
        self.writer = None
        self.name = name
        self.scheduler = scheduler

    def _print_training_status(self):
        if self.writer is None:
            self.writer = SummaryWriter("runs/{}".format(self.name))
            print([k for k in self.running_loss])

        lr = self.scheduler.get_lr().pop()
        metrics_data = [self.running_loss[k]/SUM_FREQ for k in self.running_loss.keys()]
        training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, lr)
        metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)

        # print the training status
        print(training_str + metrics_str)

        for key in self.running_loss:
            val = self.running_loss[key] / SUM_FREQ
            self.writer.add_scalar(key, val, self.total_steps)
            self.running_loss[key] = 0.0

    def push(self, metrics):
        for key in metrics:
            if key not in self.running_loss:
                self.running_loss[key] = 0.0

            self.running_loss[key] += metrics[key]

        if self.total_steps % SUM_FREQ == SUM_FREQ-1:
            self._print_training_status()
            self.running_loss = {}

        self.total_steps += 1

    def write_dict(self, results):
        if self.writer is None:
            self.writer = SummaryWriter("runs/{}".format(self.name))
            print([k for k in self.running_loss])

        for key in results:
            self.writer.add_scalar(key, results[key], self.total_steps)

    def close(self):
        self.writer.close()
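A hypothetical usage sketch (not part of the commit): Logger only needs a scheduler exposing get_lr() and a dict of scalar metrics pushed once per step; every SUM_FREQ (=100) steps the running averages are printed and written to TensorBoard.

import torch
from mini_dpvo.logger import Logger

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.Adam(params, lr=1e-3)
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.99)

logger = Logger(name="demo", scheduler=sched)
for step in range(300):
    # push scalar training metrics; keys become TensorBoard tags
    logger.push({"loss": 1.0 / (step + 1), "rot_error": 0.1})
logger.close()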
mini_dpvo/net.py
ADDED
@@ -0,0 +1,270 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import OrderedDict

import torch_scatter
from torch_scatter import scatter_sum

from . import fastba
from . import altcorr
from . import lietorch
from .lietorch import SE3

from .extractor import BasicEncoder, BasicEncoder4
from .blocks import GradientClip, GatedResidual, SoftAgg

from .utils import *
from .ba import BA
from . import projective_ops as pops

autocast = torch.cuda.amp.autocast
import matplotlib.pyplot as plt

DIM = 384

class Update(nn.Module):
    def __init__(self, p):
        super(Update, self).__init__()

        self.c1 = nn.Sequential(
            nn.Linear(DIM, DIM),
            nn.ReLU(inplace=True),
            nn.Linear(DIM, DIM))

        self.c2 = nn.Sequential(
            nn.Linear(DIM, DIM),
            nn.ReLU(inplace=True),
            nn.Linear(DIM, DIM))

        self.norm = nn.LayerNorm(DIM, eps=1e-3)

        self.agg_kk = SoftAgg(DIM)
        self.agg_ij = SoftAgg(DIM)

        self.gru = nn.Sequential(
            nn.LayerNorm(DIM, eps=1e-3),
            GatedResidual(DIM),
            nn.LayerNorm(DIM, eps=1e-3),
            GatedResidual(DIM),
        )

        self.corr = nn.Sequential(
            nn.Linear(2*49*p*p, DIM),
            nn.ReLU(inplace=True),
            nn.Linear(DIM, DIM),
            nn.LayerNorm(DIM, eps=1e-3),
            nn.ReLU(inplace=True),
            nn.Linear(DIM, DIM),
        )

        self.d = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Linear(DIM, 2),
            GradientClip())

        self.w = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Linear(DIM, 2),
            GradientClip(),
            nn.Sigmoid())


    def forward(self, net, inp, corr, flow, ii, jj, kk):
        """ update operator """

        net = net + inp + self.corr(corr)
        net = self.norm(net)

        ix, jx = fastba.neighbors(kk, jj)
        mask_ix = (ix >= 0).float().reshape(1, -1, 1)
        mask_jx = (jx >= 0).float().reshape(1, -1, 1)

        net = net + self.c1(mask_ix * net[:,ix])
        net = net + self.c2(mask_jx * net[:,jx])

        net = net + self.agg_kk(net, kk)
        net = net + self.agg_ij(net, ii*12345 + jj)

        net = self.gru(net)

        return net, (self.d(net), self.w(net), None)


class Patchifier(nn.Module):
    def __init__(self, patch_size=3):
        super(Patchifier, self).__init__()
        self.patch_size = patch_size
        self.fnet = BasicEncoder4(output_dim=128, norm_fn='instance')
        self.inet = BasicEncoder4(output_dim=DIM, norm_fn='none')

    def __image_gradient(self, images):
        gray = ((images + 0.5) * (255.0 / 2)).sum(dim=2)
        dx = gray[...,:-1,1:] - gray[...,:-1,:-1]
        dy = gray[...,1:,:-1] - gray[...,:-1,:-1]
        g = torch.sqrt(dx**2 + dy**2)
        g = F.avg_pool2d(g, 4, 4)
        return g

    def forward(self, images, patches_per_image=80, disps=None, gradient_bias=False, return_color=False):
        """ extract patches from input images """
        fmap = self.fnet(images) / 4.0
        imap = self.inet(images) / 4.0

        b, n, c, h, w = fmap.shape
        P = self.patch_size

        # bias patch selection towards regions with high gradient
        if gradient_bias:
            g = self.__image_gradient(images)
            x = torch.randint(1, w-1, size=[n, 3*patches_per_image], device="cuda")
            y = torch.randint(1, h-1, size=[n, 3*patches_per_image], device="cuda")

            coords = torch.stack([x, y], dim=-1).float()
            g = altcorr.patchify(g[0,:,None], coords, 0).view(n, 3 * patches_per_image)

            ix = torch.argsort(g, dim=1)
            x = torch.gather(x, 1, ix[:, -patches_per_image:])
            y = torch.gather(y, 1, ix[:, -patches_per_image:])

        else:
            x = torch.randint(1, w-1, size=[n, patches_per_image], device="cuda")
            y = torch.randint(1, h-1, size=[n, patches_per_image], device="cuda")

        coords = torch.stack([x, y], dim=-1).float()
        imap = altcorr.patchify(imap[0], coords, 0).view(b, -1, DIM, 1, 1)
        gmap = altcorr.patchify(fmap[0], coords, P//2).view(b, -1, 128, P, P)

        if return_color:
            clr = altcorr.patchify(images[0], 4*(coords + 0.5), 0).view(b, -1, 3)

        if disps is None:
            disps = torch.ones(b, n, h, w, device="cuda")

        grid, _ = coords_grid_with_index(disps, device=fmap.device)
        patches = altcorr.patchify(grid[0], coords, P//2).view(b, -1, 3, P, P)

        index = torch.arange(n, device="cuda").view(n, 1)
        index = index.repeat(1, patches_per_image).reshape(-1)

        if return_color:
            return fmap, gmap, imap, patches, index, clr

        return fmap, gmap, imap, patches, index


class CorrBlock:
    def __init__(self, fmap, gmap, radius=3, dropout=0.2, levels=[1,4]):
        self.dropout = dropout
        self.radius = radius
        self.levels = levels

        self.gmap = gmap
        self.pyramid = pyramidify(fmap, lvls=levels)

    def __call__(self, ii, jj, coords):
        corrs = []
        for i in range(len(self.levels)):
            corrs += [ altcorr.corr(self.gmap, self.pyramid[i], coords / self.levels[i], ii, jj, self.radius, self.dropout) ]
        return torch.stack(corrs, -1).view(1, len(ii), -1)


class VONet(nn.Module):
    def __init__(self, use_viewer=False):
        super(VONet, self).__init__()
        self.P = 3
        self.patchify = Patchifier(self.P)
        self.update = Update(self.P)

        self.DIM = DIM
        self.RES = 4


    @autocast(enabled=False)
    def forward(self, images, poses, disps, intrinsics, M=1024, STEPS=12, P=1, structure_only=False, rescale=False):
        """ Estimates SE3 or Sim3 between pair of frames """

        images = 2 * (images / 255.0) - 0.5
        intrinsics = intrinsics / 4.0
        disps = disps[:, :, 1::4, 1::4].float()

        fmap, gmap, imap, patches, ix = self.patchify(images, disps=disps)

        corr_fn = CorrBlock(fmap, gmap)

        b, N, c, h, w = fmap.shape
        p = self.P

        patches_gt = patches.clone()
        Ps = poses

        d = patches[..., 2, p//2, p//2]
        patches = set_depth(patches, torch.rand_like(d))

        kk, jj = flatmeshgrid(torch.where(ix < 8)[0], torch.arange(0,8, device="cuda"))
        ii = ix[kk]

        imap = imap.view(b, -1, DIM)
        net = torch.zeros(b, len(kk), DIM, device="cuda", dtype=torch.float)

        Gs = SE3.IdentityLike(poses)

        if structure_only:
            Gs.data[:] = poses.data[:]

        traj = []
        bounds = [-64, -64, w + 64, h + 64]

        while len(traj) < STEPS:
            Gs = Gs.detach()
            patches = patches.detach()

            n = ii.max() + 1
            if len(traj) >= 8 and n < images.shape[1]:
                if not structure_only: Gs.data[:,n] = Gs.data[:,n-1]
                kk1, jj1 = flatmeshgrid(torch.where(ix < n)[0], torch.arange(n, n+1, device="cuda"))
                kk2, jj2 = flatmeshgrid(torch.where(ix == n)[0], torch.arange(0, n+1, device="cuda"))

                ii = torch.cat([ix[kk1], ix[kk2], ii])
                jj = torch.cat([jj1, jj2, jj])
                kk = torch.cat([kk1, kk2, kk])

                net1 = torch.zeros(b, len(kk1) + len(kk2), DIM, device="cuda")
                net = torch.cat([net1, net], dim=1)

                if np.random.rand() < 0.1:
                    k = (ii != (n - 4)) & (jj != (n - 4))
                    ii = ii[k]
                    jj = jj[k]
                    kk = kk[k]
                    net = net[:,k]

                patches[:,ix==n,2] = torch.median(patches[:,(ix == n-1) | (ix == n-2),2])
                n = ii.max() + 1

            coords = pops.transform(Gs, patches, intrinsics, ii, jj, kk)
            coords1 = coords.permute(0, 1, 4, 2, 3).contiguous()

            corr = corr_fn(kk, jj, coords1)
            net, (delta, weight, _) = self.update(net, imap[:,kk], corr, None, ii, jj, kk)

            lmbda = 1e-4
            target = coords[...,p//2,p//2,:] + delta

            ep = 10
            for itr in range(2):
                Gs, patches = BA(Gs, patches, intrinsics, target, weight, lmbda, ii, jj, kk,
                    bounds, ep=ep, fixedp=1, structure_only=structure_only)

            kl = torch.as_tensor(0)
            dij = (ii - jj).abs()
            k = (dij > 0) & (dij <= 2)

            coords = pops.transform(Gs, patches, intrinsics, ii[k], jj[k], kk[k])
            coords_gt, valid, _ = pops.transform(Ps, patches_gt, intrinsics, ii[k], jj[k], kk[k], jacobian=True)

            traj.append((valid, coords, coords_gt, Gs[:,:n], Ps[:,:n], kl))

        return traj
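A hypothetical training-style call (not part of the commit), assuming the CUDA extensions are built and a GPU is present. Shapes are inferred from VONet.forward: images (b, n, 3, H, W) in [0, 255], disps (b, n, H, W), intrinsics (b, n, 4) as [fx, fy, cx, cy], and poses as a lietorch SE3 object with one pose per frame.

import torch
from mini_dpvo.net import VONet
from mini_dpvo.lietorch import SE3

b, n, H, W = 1, 15, 480, 640
images = torch.randint(0, 255, (b, n, 3, H, W), device="cuda").float()
disps = torch.ones(b, n, H, W, device="cuda")
intrinsics = torch.tensor([[320.0, 320.0, 320.0, 240.0]], device="cuda").repeat(b, n, 1)
poses = SE3.Identity(b, n, device="cuda")

net = VONet().cuda()
traj = net(images, poses, disps, intrinsics, STEPS=12)
# traj is a list of (valid, coords, coords_gt, Gs, Ps, kl) tuples, one per refinement
# step, which a training loss would consume.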
mini_dpvo/plot_utils.py
ADDED
@@ -0,0 +1,52 @@
from copy import deepcopy

import matplotlib.pyplot as plt
import numpy as np
from evo.core import sync
from evo.core.trajectory import PoseTrajectory3D
from evo.tools import plot  # required below (PlotMode, PlotCollection, prepare_axis, traj)
from pathlib import Path


def make_traj(args) -> PoseTrajectory3D:
    if isinstance(args, tuple):
        traj, tstamps = args
        return PoseTrajectory3D(positions_xyz=traj[:,:3], orientations_quat_wxyz=traj[:,3:], timestamps=tstamps)
    assert isinstance(args, PoseTrajectory3D), type(args)
    return deepcopy(args)

def best_plotmode(traj):
    _, i1, i2 = np.argsort(np.var(traj.positions_xyz, axis=0))
    plot_axes = "xyz"[i2] + "xyz"[i1]
    return getattr(plot.PlotMode, plot_axes)

def plot_trajectory(pred_traj, gt_traj=None, title="", filename="", align=True, correct_scale=True):
    pred_traj = make_traj(pred_traj)

    if gt_traj is not None:
        gt_traj = make_traj(gt_traj)
        gt_traj, pred_traj = sync.associate_trajectories(gt_traj, pred_traj)

        if align:
            pred_traj.align(gt_traj, correct_scale=correct_scale)

    plot_collection = plot.PlotCollection("PlotCol")
    fig = plt.figure(figsize=(8, 8))
    plot_mode = best_plotmode(gt_traj if (gt_traj is not None) else pred_traj)
    ax = plot.prepare_axis(fig, plot_mode)
    ax.set_title(title)
    if gt_traj is not None:
        plot.traj(ax, plot_mode, gt_traj, '--', 'gray', "Ground Truth")
    plot.traj(ax, plot_mode, pred_traj, '-', 'blue', "Predicted")
    plot_collection.add_figure("traj (error)", fig)
    plot_collection.export(filename, confirm_overwrite=False)
    plt.close(fig=fig)
    print(f"Saved {filename}")

def save_trajectory_tum_format(traj, filename):
    traj = make_traj(traj)
    tostr = lambda a: ' '.join(map(str, a))
    with Path(filename).open('w') as f:
        for i in range(traj.num_poses):
            f.write(f"{traj.timestamps[i]} {tostr(traj.positions_xyz[i])} {tostr(traj.orientations_quat_wxyz[i][[1,2,3,0]])}\n")
    print(f"Saved {filename}")
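Hypothetical usage (not part of the commit): both helpers accept a (poses, timestamps) tuple, where poses is (N, 7) laid out as [x, y, z, qw, qx, qy, qz] (wxyz quaternion, matching make_traj above) and timestamps is (N,). Assumes evo and matplotlib are installed.

import numpy as np
from mini_dpvo.plot_utils import plot_trajectory, save_trajectory_tum_format

N = 100
tstamps = np.arange(N, dtype=np.float64)
poses = np.zeros((N, 7))
poses[:, 0] = np.linspace(0.0, 5.0, N)   # translate along x
poses[:, 3] = 1.0                        # identity rotation (qw = 1)

save_trajectory_tum_format((poses, tstamps), "demo_traj.txt")
plot_trajectory((poses, tstamps), title="demo", filename="demo_traj.pdf")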
mini_dpvo/projective_ops.py
ADDED
@@ -0,0 +1,121 @@
import torch
import torch.nn.functional as F

from .lietorch import SE3, Sim3

MIN_DEPTH = 0.2

def extract_intrinsics(intrinsics):
    return intrinsics[...,None,None,:].unbind(dim=-1)

def coords_grid(ht, wd, **kwargs):
    y, x = torch.meshgrid(
        torch.arange(ht).to(**kwargs).float(),
        torch.arange(wd).to(**kwargs).float())

    return torch.stack([x, y], dim=-1)


def iproj(patches, intrinsics):
    """ inverse projection """
    x, y, d = patches.unbind(dim=2)
    fx, fy, cx, cy = intrinsics[...,None,None].unbind(dim=2)

    i = torch.ones_like(d)
    xn = (x - cx) / fx
    yn = (y - cy) / fy

    X = torch.stack([xn, yn, i, d], dim=-1)
    return X


def proj(X, intrinsics, depth=False):
    """ projection """

    X, Y, Z, W = X.unbind(dim=-1)
    fx, fy, cx, cy = intrinsics[...,None,None].unbind(dim=2)

    # d = 0.01 * torch.ones_like(Z)
    # d[Z > 0.01] = 1.0 / Z[Z > 0.01]
    # d = torch.ones_like(Z)
    # d[Z.abs() > 0.1] = 1.0 / Z[Z.abs() > 0.1]

    d = 1.0 / Z.clamp(min=0.1)
    x = fx * (d * X) + cx
    y = fy * (d * Y) + cy

    if depth:
        return torch.stack([x, y, d], dim=-1)

    return torch.stack([x, y], dim=-1)


def transform(poses, patches, intrinsics, ii, jj, kk, depth=False, valid=False, jacobian=False, tonly=False):
    """ projective transform """

    # backproject
    X0 = iproj(patches[:,kk], intrinsics[:,ii])

    # transform
    Gij = poses[:, jj] * poses[:, ii].inv()

    if tonly:
        Gij[...,3:] = torch.as_tensor([0,0,0,1], device=Gij.device)

    X1 = Gij[:,:,None,None] * X0

    # project
    x1 = proj(X1, intrinsics[:,jj], depth)

    if jacobian:
        p = X1.shape[2]
        X, Y, Z, H = X1[...,p//2,p//2,:].unbind(dim=-1)
        o = torch.zeros_like(H)
        i = torch.zeros_like(H)

        fx, fy, cx, cy = intrinsics[:,jj].unbind(dim=-1)

        d = torch.zeros_like(Z)
        d[Z.abs() > 0.2] = 1.0 / Z[Z.abs() > 0.2]

        Ja = torch.stack([
            H,  o,  o,  o,  Z, -Y,
            o,  H,  o, -Z,  o,  X,
            o,  o,  H,  Y, -X,  o,
            o,  o,  o,  o,  o,  o,
        ], dim=-1).view(1, len(ii), 4, 6)

        Jp = torch.stack([
            fx*d,     o, -fx*X*d*d,  o,
               o,  fy*d, -fy*Y*d*d,  o,
        ], dim=-1).view(1, len(ii), 2, 4)

        Jj = torch.matmul(Jp, Ja)
        Ji = -Gij[:,:,None].adjT(Jj)

        Jz = torch.matmul(Jp, Gij.matrix()[...,:,3:])

        return x1, (Z > 0.2).float(), (Ji, Jj, Jz)

    if valid:
        return x1, (X1[...,2] > 0.2).float()

    return x1

def point_cloud(poses, patches, intrinsics, ix):
    """ generate point cloud from patches """
    return poses[:,ix,None,None].inv() * iproj(patches, intrinsics[:,ix])


def flow_mag(poses, patches, intrinsics, ii, jj, kk, beta=0.3):
    """ projective transform """

    coords0 = transform(poses, patches, intrinsics, ii, ii, kk)
    coords1 = transform(poses, patches, intrinsics, ii, jj, kk, tonly=False)
    coords2 = transform(poses, patches, intrinsics, ii, jj, kk, tonly=True)

    flow1 = (coords1 - coords0).norm(dim=-1)
    flow2 = (coords2 - coords0).norm(dim=-1)

    return beta * flow1 + (1-beta) * flow2
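A small numeric sketch (not in the repo) of the pinhole arithmetic that iproj and proj implement: backproject a pixel at inverse depth d into the homogeneous point [xn, yn, 1, d], then reproject it, recovering the same pixel since the identity transform leaves Z = 1.

import torch

fx, fy, cx, cy = 320.0, 320.0, 320.0, 240.0
x_pix, y_pix, d = 400.0, 100.0, 0.5           # pixel coordinates and inverse depth

# iproj: X = [(x - cx)/fx, (y - cy)/fy, 1, d]
X = torch.tensor([(x_pix - cx) / fx, (y_pix - cy) / fy, 1.0, d])

# proj: divide by Z (clamped at 0.1) and re-apply the intrinsics
Z = X[2].clamp(min=0.1)
x_out = fx * (X[0] / Z) + cx
y_out = fy * (X[1] / Z) + cy
print(x_out.item(), y_out.item())             # 400.0 100.0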
mini_dpvo/stream.py
ADDED
@@ -0,0 +1,92 @@
import cv2
import numpy as np
from pathlib import Path
from itertools import chain
from multiprocessing import Queue


def image_stream(
    queue: Queue, imagedir: str, calib: str, stride: int, skip: int = 0
) -> None:
    """image generator"""

    calib = np.loadtxt(calib, delimiter=" ")
    fx, fy, cx, cy = calib[:4]

    K = np.eye(3)
    K[0, 0] = fx
    K[0, 2] = cx
    K[1, 1] = fy
    K[1, 2] = cy

    img_exts = ["*.png", "*.jpeg", "*.jpg"]
    image_list = sorted(chain.from_iterable(Path(imagedir).glob(e) for e in img_exts))[
        skip::stride
    ]

    for t, imfile in enumerate(image_list):
        image = cv2.imread(str(imfile))
        if len(calib) > 4:
            image = cv2.undistort(image, K, calib[4:])

        if 0:
            image = cv2.resize(image, None, fx=0.5, fy=0.5)
            intrinsics = np.array([fx / 2, fy / 2, cx / 2, cy / 2])

        else:
            intrinsics = np.array([fx, fy, cx, cy])

        h, w, _ = image.shape
        image = image[: h - h % 16, : w - w % 16]

        queue.put((t, image, intrinsics))

    queue.put((-1, image, intrinsics))


def video_stream(
    queue: Queue, imagedir: str, calib: str, stride: int, skip: int = 0
) -> None:
    """video generator"""

    calib = np.loadtxt(calib, delimiter=" ")
    fx, fy, cx, cy = calib[:4]

    K = np.eye(3)
    K[0, 0] = fx
    K[0, 2] = cx
    K[1, 1] = fy
    K[1, 2] = cy

    cap = cv2.VideoCapture(imagedir)

    t = 0

    for _ in range(skip):
        ret, image = cap.read()

    while True:
        # Capture frame-by-frame
        for _ in range(stride):
            ret, image = cap.read()
            # if frame is read correctly ret is True
            if not ret:
                break

        if not ret:
            break

        if len(calib) > 4:
            image = cv2.undistort(image, K, calib[4:])

        image = cv2.resize(image, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_AREA)
        h, w, _ = image.shape
        image = image[: h - h % 16, : w - w % 16]

        intrinsics = np.array([fx * 0.5, fy * 0.5, cx * 0.5, cy * 0.5])
        queue.put((t, image, intrinsics))

        t += 1

    queue.put((-1, image, intrinsics))
    cap.release()
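A hypothetical driver (not part of the commit): both streams are meant to run in a background process and feed (t, image, intrinsics) tuples into a Queue, with t == -1 marking the end of the sequence. The paths below are placeholders.

from multiprocessing import Process, Queue
from mini_dpvo.stream import image_stream

queue = Queue(maxsize=8)
reader = Process(target=image_stream, args=(queue, "data/images", "data/calib.txt", 2, 0))
reader.start()

while True:
    t, image, intrinsics = queue.get()
    if t < 0:
        break
    # hand the frame to the odometry front-end here
    print(t, image.shape, intrinsics)

reader.join()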
mini_dpvo/utils.py
ADDED
@@ -0,0 +1,92 @@
import torch
import torch.nn.functional as F


all_times = []


class Timer:
    def __init__(self, name: str, enabled: bool = True):
        self.name = name
        self.enabled = enabled

        if self.enabled:
            self.start = torch.cuda.Event(enable_timing=True)
            self.end = torch.cuda.Event(enable_timing=True)

    def __enter__(self):
        if self.enabled:
            self.start.record()

    def __exit__(self, type, value, traceback):
        global all_times
        if self.enabled:
            self.end.record()
            torch.cuda.synchronize()

            elapsed = self.start.elapsed_time(self.end)
            all_times.append(elapsed)
            print(f"{self.name}: {elapsed:.2f}ms")


def coords_grid(b, n, h, w, **kwargs):
    """coordinate grid"""
    x = torch.arange(0, w, dtype=torch.float, **kwargs)
    y = torch.arange(0, h, dtype=torch.float, **kwargs)
    coords = torch.stack(torch.meshgrid(y, x, indexing="ij"))
    return coords[[1, 0]].view(1, 1, 2, h, w).repeat(b, n, 1, 1, 1)


def coords_grid_with_index(d, **kwargs):
    """coordinate grid with frame index"""
    b, n, h, w = d.shape
    i = torch.ones_like(d)
    x = torch.arange(0, w, dtype=torch.float, **kwargs)
    y = torch.arange(0, h, dtype=torch.float, **kwargs)

    y, x = torch.stack(torch.meshgrid(y, x, indexing="ij"))
    y = y.view(1, 1, h, w).repeat(b, n, 1, 1)
    x = x.view(1, 1, h, w).repeat(b, n, 1, 1)

    coords = torch.stack([x, y, d], dim=2)
    index = torch.arange(0, n, dtype=torch.float, **kwargs)
    index = index.view(1, n, 1, 1, 1).repeat(b, 1, 1, h, w)

    return coords, index


def patchify(x, patch_size=3):
    """extract patches from video"""
    b, n, c, h, w = x.shape
    x = x.view(b * n, c, h, w)
    y = F.unfold(x, patch_size)
    y = y.transpose(1, 2)
    return y.reshape(b, -1, c, patch_size, patch_size)


def pyramidify(fmap, lvls=[1]):
    """turn fmap into a pyramid"""
    b, n, c, h, w = fmap.shape

    pyramid = []
    for lvl in lvls:
        gmap = F.avg_pool2d(fmap.view(b * n, c, h, w), lvl, stride=lvl)
        pyramid += [gmap.view(b, n, c, h // lvl, w // lvl)]

    return pyramid


def all_pairs_exclusive(n, **kwargs):
    ii, jj = torch.meshgrid(torch.arange(n, **kwargs), torch.arange(n, **kwargs))
    k = ii != jj
    return ii[k].reshape(-1), jj[k].reshape(-1)


def set_depth(patches, depth):
    patches[..., 2, :, :] = depth[..., None, None]
    return patches


def flatmeshgrid(*args, **kwargs):
    grid = torch.meshgrid(*args, **kwargs)
    return (x.reshape(-1) for x in grid)
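Hypothetical usage (not part of the commit): Timer wraps a CUDA-timed region, and flatmeshgrid builds the flat (ii, jj) index pairs used throughout the package. The Timer branch shown here assumes a CUDA device is available.

import torch
from mini_dpvo.utils import Timer, flatmeshgrid

x = torch.randn(1024, 1024, device="cuda")
with Timer("matmul", enabled=True):
    y = x @ x                                  # prints "matmul: ...ms" on exit

ii, jj = flatmeshgrid(torch.arange(3), torch.arange(2), indexing="ij")
print(ii.tolist())                             # [0, 0, 1, 1, 2, 2]
print(jj.tolist())                             # [0, 1, 0, 1, 0, 1]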