pablovela5620 committed
Commit 899c526
1 parent: e32c92e

initial commit with working dpvo

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +166 -1
  2. config/default.yaml +19 -0
  3. config/fast.yaml +20 -0
  4. mini_dpvo/__init__.py +0 -0
  5. mini_dpvo/altcorr/__init__.py +1 -0
  6. mini_dpvo/altcorr/correlation.cpp +63 -0
  7. mini_dpvo/altcorr/correlation.py +74 -0
  8. mini_dpvo/altcorr/correlation_kernel.cu +333 -0
  9. mini_dpvo/api/__init__.py +0 -0
  10. mini_dpvo/api/inference.py +190 -0
  11. mini_dpvo/ba.py +182 -0
  12. mini_dpvo/blocks.py +118 -0
  13. mini_dpvo/config.py +27 -0
  14. mini_dpvo/data_readers/__init__.py +1 -0
  15. mini_dpvo/data_readers/augmentation.py +66 -0
  16. mini_dpvo/data_readers/base.py +176 -0
  17. mini_dpvo/data_readers/factory.py +26 -0
  18. mini_dpvo/data_readers/frame_utils.py +164 -0
  19. mini_dpvo/data_readers/rgbd_utils.py +188 -0
  20. mini_dpvo/data_readers/tartan.py +110 -0
  21. mini_dpvo/data_readers/tartan_test.txt +32 -0
  22. mini_dpvo/dpvo.py +410 -0
  23. mini_dpvo/extractor.py +264 -0
  24. mini_dpvo/fastba/__init__.py +1 -0
  25. mini_dpvo/fastba/ba.cpp +157 -0
  26. mini_dpvo/fastba/ba.py +8 -0
  27. mini_dpvo/fastba/ba_cuda.cu +575 -0
  28. mini_dpvo/lietorch/__init__.py +2 -0
  29. mini_dpvo/lietorch/broadcasting.py +31 -0
  30. mini_dpvo/lietorch/gradcheck.py +592 -0
  31. mini_dpvo/lietorch/group_ops.py +102 -0
  32. mini_dpvo/lietorch/groups.py +322 -0
  33. mini_dpvo/lietorch/include/common.h +12 -0
  34. mini_dpvo/lietorch/include/dispatch.h +48 -0
  35. mini_dpvo/lietorch/include/lietorch_cpu.h +51 -0
  36. mini_dpvo/lietorch/include/lietorch_gpu.h +51 -0
  37. mini_dpvo/lietorch/include/rxso3.h +324 -0
  38. mini_dpvo/lietorch/include/se3.h +229 -0
  39. mini_dpvo/lietorch/include/sim3.h +217 -0
  40. mini_dpvo/lietorch/include/so3.h +229 -0
  41. mini_dpvo/lietorch/run_tests.py +302 -0
  42. mini_dpvo/lietorch/src/lietorch.cpp +317 -0
  43. mini_dpvo/lietorch/src/lietorch_cpu.cpp +657 -0
  44. mini_dpvo/lietorch/src/lietorch_gpu.cu +601 -0
  45. mini_dpvo/logger.py +58 -0
  46. mini_dpvo/net.py +270 -0
  47. mini_dpvo/plot_utils.py +52 -0
  48. mini_dpvo/projective_ops.py +121 -0
  49. mini_dpvo/stream.py +92 -0
  50. mini_dpvo/utils.py +92 -0
.gitignore CHANGED
@@ -1,4 +1,169 @@
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
  # pixi environments
  .pixi
  *.egg-info
-
+ thirdparty/*
+ data/*
+ checkpoints/*
config/default.yaml ADDED
@@ -0,0 +1,19 @@
### DPVO Config File ###

# VO config (increase for better accuracy)
PATCHES_PER_FRAME: 96
REMOVAL_WINDOW: 22
OPTIMIZATION_WINDOW: 10
PATCH_LIFETIME: 13

# threshold for keyframe removal
KEYFRAME_THRESH: 15.0

# camera motion model
MOTION_MODEL: 'DAMPED_LINEAR'
MOTION_DAMPING: 0.5

# maybe use mixed precision for inference
MIXED_PRECISION: True

GRADIENT_BIAS: False
config/fast.yaml ADDED
@@ -0,0 +1,20 @@

### DPVO Config File ###

# VO config (increase for better accuracy)
PATCHES_PER_FRAME: 48
REMOVAL_WINDOW: 16
OPTIMIZATION_WINDOW: 7
PATCH_LIFETIME: 11

# threshold for keyframe removal
KEYFRAME_THRESH: 15.0

# camera motion model
MOTION_MODEL: 'DAMPED_LINEAR'
MOTION_DAMPING: 0.5

# maybe use mixed precision for inference
MIXED_PRECISION: True

GRADIENT_BIAS: False
mini_dpvo/__init__.py ADDED
File without changes
mini_dpvo/altcorr/__init__.py ADDED
@@ -0,0 +1 @@
from .correlation import corr, patchify
mini_dpvo/altcorr/correlation.cpp ADDED
@@ -0,0 +1,63 @@
#include <torch/extension.h>
#include <vector>

// CUDA forward declarations
std::vector<torch::Tensor> corr_cuda_forward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    torch::Tensor ii,
    torch::Tensor jj,
    int radius);

std::vector<torch::Tensor> corr_cuda_backward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    torch::Tensor ii,
    torch::Tensor jj,
    torch::Tensor corr_grad,
    int radius);

std::vector<torch::Tensor> patchify_cuda_forward(
    torch::Tensor net, torch::Tensor coords, int radius);

std::vector<torch::Tensor> patchify_cuda_backward(
    torch::Tensor net, torch::Tensor coords, torch::Tensor gradient, int radius);

std::vector<torch::Tensor> corr_forward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    torch::Tensor ii,
    torch::Tensor jj, int radius) {
  return corr_cuda_forward(fmap1, fmap2, coords, ii, jj, radius);
}

std::vector<torch::Tensor> corr_backward(
    torch::Tensor fmap1,
    torch::Tensor fmap2,
    torch::Tensor coords,
    torch::Tensor ii,
    torch::Tensor jj,
    torch::Tensor corr_grad, int radius) {
  return corr_cuda_backward(fmap1, fmap2, coords, ii, jj, corr_grad, radius);
}

std::vector<torch::Tensor> patchify_forward(
    torch::Tensor net, torch::Tensor coords, int radius) {
  return patchify_cuda_forward(net, coords, radius);
}

std::vector<torch::Tensor> patchify_backward(
    torch::Tensor net, torch::Tensor coords, torch::Tensor gradient, int radius) {
  return patchify_cuda_backward(net, coords, gradient, radius);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &corr_forward, "CORR forward");
  m.def("backward", &corr_backward, "CORR backward");

  m.def("patchify_forward", &patchify_forward, "PATCHIFY forward");
  m.def("patchify_backward", &patchify_backward, "PATCHIFY backward");
}
mini_dpvo/altcorr/correlation.py ADDED
@@ -0,0 +1,74 @@
import torch
import cuda_corr

class CorrLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, fmap1, fmap2, coords, ii, jj, radius, dropout):
        """ forward correlation """
        ctx.save_for_backward(fmap1, fmap2, coords, ii, jj)
        ctx.radius = radius
        ctx.dropout = dropout
        corr, = cuda_corr.forward(fmap1, fmap2, coords, ii, jj, radius)

        return corr

    @staticmethod
    def backward(ctx, grad):
        """ backward correlation """
        fmap1, fmap2, coords, ii, jj = ctx.saved_tensors

        if ctx.dropout < 1:
            perm = torch.rand(len(ii), device="cuda") < ctx.dropout
            coords = coords[:,perm]
            grad = grad[:,perm]
            ii = ii[perm]
            jj = jj[perm]

        fmap1_grad, fmap2_grad = \
            cuda_corr.backward(fmap1, fmap2, coords, ii, jj, grad, ctx.radius)

        return fmap1_grad, fmap2_grad, None, None, None, None, None


class PatchLayer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, net, coords, radius):
        """ forward patchify """
        ctx.radius = radius
        ctx.save_for_backward(net, coords)

        patches, = cuda_corr.patchify_forward(net, coords, radius)
        return patches

    @staticmethod
    def backward(ctx, grad):
        """ backward patchify """
        net, coords = ctx.saved_tensors
        grad, = cuda_corr.patchify_backward(net, coords, grad, ctx.radius)

        return grad, None, None

def patchify(net, coords, radius, mode='bilinear'):
    """ extract patches """

    patches = PatchLayer.apply(net, coords, radius)

    if mode == 'bilinear':
        offset = (coords - coords.floor()).to(net.device)
        dx, dy = offset[:,:,None,None,None].unbind(dim=-1)

        d = 2 * radius + 1
        x00 = (1-dy) * (1-dx) * patches[...,:d,:d]
        x01 = (1-dy) * ( dx) * patches[...,:d,1:]
        x10 = ( dy) * (1-dx) * patches[...,1:,:d]
        x11 = ( dy) * ( dx) * patches[...,1:,1:]

        return x00 + x01 + x10 + x11

    return patches


def corr(fmap1, fmap2, coords, ii, jj, radius=1, dropout=1):
    return CorrLayer.apply(fmap1, fmap2, coords, ii, jj, radius, dropout)
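correlation.py imports a compiled module named cuda_corr, whose sources are the correlation.cpp and correlation_kernel.cu files in this commit; the build recipe itself is not part of the diff. A hedged sketch of one way to compile it, using PyTorch's standard CUDAExtension machinery (the package name and file layout are assumptions):

# setup.py sketch (not in this commit): builds the `cuda_corr` module that
# mini_dpvo/altcorr/correlation.py imports.
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="mini-dpvo-altcorr",  # hypothetical package name
    ext_modules=[
        CUDAExtension(
            name="cuda_corr",  # must match `import cuda_corr` in correlation.py
            sources=[
                "mini_dpvo/altcorr/correlation.cpp",
                "mini_dpvo/altcorr/correlation_kernel.cu",
            ],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)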
mini_dpvo/altcorr/correlation_kernel.cu ADDED
@@ -0,0 +1,333 @@
1
+ #include <torch/extension.h>
2
+ #include <THC/THCAtomics.cuh>
3
+ #include <vector>
4
+ #include <iostream>
5
+
6
+ using namespace torch::indexing;
7
+
8
+ #define THREADS 256
9
+ #define BLOCKS(n) (n + THREADS - 1) / THREADS
10
+
11
+ __forceinline__ __device__
12
+ bool within_bounds(int h, int w, int H, int W) {
13
+ return h >= 0 && h < H && w >= 0 && w < W;
14
+ }
15
+
16
+ template <typename scalar_t>
17
+ __global__ void patchify_forward_kernel(int R,
18
+ const torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> net,
19
+ const torch::PackedTensorAccessor32<float,3,torch::RestrictPtrTraits> coords,
20
+ torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> patches)
21
+ {
22
+ // diameter
23
+ const int D = 2*R + 2;
24
+
25
+ const int B = coords.size(0);
26
+ const int M = coords.size(1);
27
+ const int C = net.size(1);
28
+ const int H = net.size(2);
29
+ const int W = net.size(3);
30
+
31
+ int n = blockIdx.x * blockDim.x + threadIdx.x;
32
+ if (n < B * M * D * D) {
33
+ const int ii = n % D; n /= D;
34
+ const int jj = n % D; n /= D;
35
+ const int m = n % M; n /= M;
36
+
37
+ const float x = coords[n][m][0];
38
+ const float y = coords[n][m][1];
39
+ const int i = static_cast<int>(floor(y)) + (ii - R);
40
+ const int j = static_cast<int>(floor(x)) + (jj - R);
41
+
42
+ if (within_bounds(i, j, H, W)) {
43
+ for (int k=0; k<C; k++)
44
+ patches[n][m][k][ii][jj] = net[n][k][i][j];
45
+ }
46
+ }
47
+ }
48
+
49
+ template <typename scalar_t>
50
+ __global__ void patchify_backward_kernel(int R,
51
+ const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> patch_gradient,
52
+ const torch::PackedTensorAccessor32<float,3,torch::RestrictPtrTraits> coords,
53
+ torch::PackedTensorAccessor32<scalar_t,4,torch::RestrictPtrTraits> gradient)
54
+ {
55
+ // diameter
56
+ const int D = 2*R + 2;
57
+
58
+ const int B = coords.size(0);
59
+ const int M = coords.size(1);
60
+ const int C = gradient.size(1);
61
+ const int H = gradient.size(2);
62
+ const int W = gradient.size(3);
63
+
64
+ int n = blockIdx.x * blockDim.x + threadIdx.x;
65
+ if (n < B * M * D * D) {
66
+ const int ii = n % D; n /= D;
67
+ const int jj = n % D; n /= D;
68
+ const int m = n % M; n /= M;
69
+
70
+ const float x = coords[n][m][0];
71
+ const float y = coords[n][m][1];
72
+ const int i = static_cast<int>(floor(y)) + (ii - R);
73
+ const int j = static_cast<int>(floor(x)) + (jj - R);
74
+
75
+ if (within_bounds(i, j, H, W)) {
76
+ for (int k=0; k<C; k++)
77
+ atomicAdd(&gradient[n][k][i][j], patch_gradient[n][m][k][ii][jj]);
78
+ }
79
+ }
80
+ }
81
+
82
+ template <typename scalar_t>
83
+ __global__ void corr_forward_kernel(int R,
84
+ const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap1,
85
+ const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap2,
86
+ const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> coords,
87
+ const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> us,
88
+ const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> vs,
89
+ torch::PackedTensorAccessor32<scalar_t,6,torch::RestrictPtrTraits> corr)
90
+ {
91
+ // diameter
92
+ const int D = 2*R + 2;
93
+
94
+ const int B = coords.size(0);
95
+ const int M = coords.size(1);
96
+ const int H = coords.size(3);
97
+ const int W = coords.size(4);
98
+
99
+ const int C = fmap1.size(2);
100
+ const int H2 = fmap2.size(3);
101
+ const int W2 = fmap2.size(4);
102
+
103
+ int n = blockIdx.x * blockDim.x + threadIdx.x;
104
+
105
+ if (n < B * M * H * W * D * D) {
106
+ const int jj = n % D; n /= D;
107
+ const int ii = n % D; n /= D;
108
+ const int j0 = n % W; n /= W;
109
+ const int i0 = n % H; n /= H;
110
+ const int m = n % M; n /= M;
111
+
112
+ const int ix = us[m];
113
+ const int jx = vs[m];
114
+
115
+ const float x = coords[n][m][0][i0][j0];
116
+ const float y = coords[n][m][1][i0][j0];
117
+
118
+ const int i1 = static_cast<int>(floor(y)) + (ii - R);
119
+ const int j1 = static_cast<int>(floor(x)) + (jj - R);
120
+
121
+ scalar_t s = 0;
122
+ if (within_bounds(i1, j1, H2, W2)) {
123
+
124
+ #pragma unroll 8
125
+ for (int i=0; i<C; i+=8) {
126
+ scalar_t f1[8]; for (int j=0; j<8; j++) f1[j] = fmap1[n][ix][i+j][i0][j0];
127
+ scalar_t f2[8]; for (int j=0; j<8; j++) f2[j] = fmap2[n][jx][i+j][i1][j1];
128
+
129
+ #pragma unroll
130
+ for (int j=0; j<8; j++) s += f1[j] * f2[j];
131
+ }
132
+ }
133
+
134
+ corr[n][m][ii][jj][i0][j0] = s;
135
+ }
136
+ }
137
+
138
+
139
+ template <typename scalar_t>
140
+ __global__ void corr_backward_kernel(int R,
141
+ const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap1,
142
+ const torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap2,
143
+ const torch::PackedTensorAccessor32<float,5,torch::RestrictPtrTraits> coords,
144
+ const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> us,
145
+ const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> vs,
146
+ const torch::PackedTensorAccessor32<float,6,torch::RestrictPtrTraits> corr_grad,
147
+ torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap1_grad,
148
+ torch::PackedTensorAccessor32<scalar_t,5,torch::RestrictPtrTraits> fmap2_grad)
149
+ {
150
+ // diameter
151
+ const int D = 2*R + 2;
152
+
153
+ const int B = coords.size(0);
154
+ const int M = coords.size(1);
155
+ const int H = coords.size(3);
156
+ const int W = coords.size(4);
157
+
158
+ const int C = fmap1.size(2);
159
+ const int H2 = fmap2.size(3);
160
+ const int W2 = fmap2.size(4);
161
+
162
+ int n = blockIdx.x * blockDim.x + threadIdx.x;
163
+
164
+ if (n < B * M * H * W * D * D) {
165
+ const int jj = n % D; n /= D;
166
+ const int ii = n % D; n /= D;
167
+ const int j0 = n % W; n /= W;
168
+ const int i0 = n % H; n /= H;
169
+ const int m = n % M; n /= M;
170
+
171
+ const int ix = us[m];
172
+ const int jx = vs[m];
173
+
174
+ const float x = coords[n][m][0][i0][j0];
175
+ const float y = coords[n][m][1][i0][j0];
176
+
177
+ const int i1 = static_cast<int>(floor(y)) + (ii - R);
178
+ const int j1 = static_cast<int>(floor(x)) + (jj - R);
179
+
180
+ const scalar_t g = (scalar_t) corr_grad[n][m][ii][jj][i0][j0];
181
+
182
+ if (within_bounds(i1, j1, H2, W2)) {
183
+ #pragma unroll 32
184
+ for (int i=0; i<C; i++) {
185
+ atomicAdd(&fmap1_grad[n][ix][i][i0][j0], g * fmap2[n][jx][i][i1][j1]);
186
+ atomicAdd(&fmap2_grad[n][jx][i][i1][j1], g * fmap1[n][ix][i][i0][j0]);
187
+ }
188
+ }
189
+ }
190
+ }
191
+
192
+
193
+ std::vector<torch::Tensor> corr_cuda_forward(
194
+ torch::Tensor fmap1,
195
+ torch::Tensor fmap2,
196
+ torch::Tensor coords,
197
+ torch::Tensor ii,
198
+ torch::Tensor jj,
199
+ int radius)
200
+ {
201
+ const int B = coords.size(0);
202
+ const int M = coords.size(1);
203
+
204
+ const int H = coords.size(3);
205
+ const int W = coords.size(4);
206
+ const int D = 2 * radius + 2;
207
+
208
+ auto opts = fmap1.options();
209
+ auto corr = torch::empty({B, M, D, D, H, W}, opts);
210
+
211
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(fmap1.type(), "corr_forward_kernel", ([&] {
212
+ corr_forward_kernel<scalar_t><<<BLOCKS(B * M * H * W * D * D), THREADS>>>(radius,
213
+ fmap1.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
214
+ fmap2.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
215
+ coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
216
+ ii.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
217
+ jj.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
218
+ corr.packed_accessor32<scalar_t,6,torch::RestrictPtrTraits>());
219
+ }));
220
+
221
+ torch::Tensor x = coords.index({Slice(), Slice(), 0, None, None});
222
+ torch::Tensor y = coords.index({Slice(), Slice(), 1, None, None});
223
+ torch::Tensor dx = x - x.floor(); dx = dx.to(fmap1.dtype());
224
+ torch::Tensor dy = y - y.floor(); dy = dy.to(fmap2.dtype());
225
+
226
+ torch::Tensor out;
227
+ out = (1 - dx) * (1 - dy) * corr.index({Slice(), Slice(), Slice(0, D-1), Slice(0, D-1)});
228
+ out += (dx) * (1 - dy) * corr.index({Slice(), Slice(), Slice(0, D-1), Slice(1, D-0)});
229
+ out += (1 - dx) * (dy) * corr.index({Slice(), Slice(), Slice(1, D-0), Slice(0, D-1)});
230
+ out += (dx) * (dy) * corr.index({Slice(), Slice(), Slice(1, D-0), Slice(1, D-0)});
231
+
232
+ return { out.permute({0,1,3,2,4,5}) };
233
+ }
234
+
235
+
236
+ std::vector<torch::Tensor> corr_cuda_backward(
237
+ torch::Tensor fmap1,
238
+ torch::Tensor fmap2,
239
+ torch::Tensor coords,
240
+ torch::Tensor ii,
241
+ torch::Tensor jj,
242
+ torch::Tensor grad,
243
+ int radius)
244
+ {
245
+ const int B = coords.size(0);
246
+ const int M = coords.size(1);
247
+
248
+ const int H = coords.size(3);
249
+ const int W = coords.size(4);
250
+ const int D = 2 * radius + 2;
251
+
252
+ grad = grad.permute({0,1,3,2,4,5}).contiguous();
253
+ torch::Tensor x = coords.index({Slice(), Slice(), 0, None, None});
254
+ torch::Tensor y = coords.index({Slice(), Slice(), 1, None, None});
255
+ torch::Tensor dx = x - x.floor();
256
+ torch::Tensor dy = y - y.floor();
257
+
258
+ auto opts = torch::TensorOptions().dtype(torch::kFloat).device(torch::kCUDA);
259
+ torch::Tensor g1 = torch::zeros({B, M, D, D, H, W}, grad.options());
260
+ torch::Tensor g2 = torch::zeros({B, M, D, D, H, W}, grad.options());
261
+ torch::Tensor g3 = torch::zeros({B, M, D, D, H, W}, grad.options());
262
+ torch::Tensor g4 = torch::zeros({B, M, D, D, H, W}, grad.options());
263
+
264
+ g1.index_put_({Slice(), Slice(), Slice(0, D-1), Slice(0, D-1)}, (1 - dx) * (1 - dy) * grad);
265
+ g2.index_put_({Slice(), Slice(), Slice(0, D-1), Slice(1, D-0)}, (dx) * (1 - dy) * grad);
266
+ g3.index_put_({Slice(), Slice(), Slice(1, D-0), Slice(0, D-1)}, (1 - dx) * (dy) * grad);
267
+ g4.index_put_({Slice(), Slice(), Slice(1, D-0), Slice(1, D-0)}, (dx) * (dy) * grad);
268
+
269
+ torch::Tensor corr_grad = g1 + g2 + g3 + g4;
270
+ auto fmap1_grad = torch::zeros_like(fmap1);
271
+ auto fmap2_grad = torch::zeros_like(fmap2);
272
+
273
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(fmap1.type(), "corr_backward_kernel", ([&] {
274
+ corr_backward_kernel<scalar_t><<<BLOCKS(B * M * H * W * D * D), THREADS>>>(radius,
275
+ fmap1.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
276
+ fmap2.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
277
+ coords.packed_accessor32<float,5,torch::RestrictPtrTraits>(),
278
+ ii.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
279
+ jj.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
280
+ corr_grad.packed_accessor32<float,6,torch::RestrictPtrTraits>(),
281
+ fmap1_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
282
+ fmap2_grad.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>());
283
+ }));
284
+
285
+ return {fmap1_grad, fmap2_grad};
286
+ }
287
+
288
+ std::vector<torch::Tensor> patchify_cuda_forward(
289
+ torch::Tensor net, torch::Tensor coords, int radius)
290
+ {
291
+ const int B = coords.size(0);
292
+ const int M = coords.size(1);
293
+ const int C = net.size(1);
294
+ const int D = 2 * radius + 2;
295
+
296
+ auto opts = net.options();
297
+ auto patches = torch::zeros({B, M, C, D, D}, opts);
298
+
299
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(net.type(), "patchify_forward_kernel", ([&] {
300
+ patchify_forward_kernel<scalar_t><<<BLOCKS(B * M * D * D), THREADS>>>(radius,
301
+ net.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>(),
302
+ coords.packed_accessor32<float,3,torch::RestrictPtrTraits>(),
303
+ patches.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>());
304
+ }));
305
+
306
+ return { patches };
307
+ }
308
+
309
+
310
+ std::vector<torch::Tensor> patchify_cuda_backward(
311
+ torch::Tensor net,
312
+ torch::Tensor coords,
313
+ torch::Tensor gradient,
314
+ int radius)
315
+ {
316
+ const int B = coords.size(0);
317
+ const int M = coords.size(1);
318
+ const int C = net.size(1);
319
+ const int H = net.size(2);
320
+ const int W = net.size(3);
321
+ const int D = 2 * radius + 2;
322
+
323
+ torch::Tensor net_gradient = torch::zeros_like(net);
324
+
325
+ AT_DISPATCH_FLOATING_TYPES_AND_HALF(net.type(), "patchify_backward_kernel", ([&] {
326
+ patchify_backward_kernel<scalar_t><<<BLOCKS(B * M * D * D), THREADS>>>(radius,
327
+ gradient.packed_accessor32<scalar_t,5,torch::RestrictPtrTraits>(),
328
+ coords.packed_accessor32<float,3,torch::RestrictPtrTraits>(),
329
+ net_gradient.packed_accessor32<scalar_t,4,torch::RestrictPtrTraits>());
330
+ }));
331
+
332
+ return { net_gradient };
333
+ }
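The correlation kernel above evaluates scores on a padded (2R+2) x (2R+2) integer window around (floor(x), floor(y)) — hence D = 2R + 2 — and corr_cuda_forward then bilinearly blends four (2R+1) x (2R+1) shifted sub-windows. A compact restatement of that resampling step, in notation chosen here for readability (d_x = x - floor(x), d_y = y - floor(y), and C-tilde the raw window written by the kernel):

% bilinear resampling of the padded correlation window (u indexes the y offset, v the x offset)
C(u,v) = (1-d_x)(1-d_y)\,\tilde{C}(u,v) + d_x(1-d_y)\,\tilde{C}(u,v{+}1)
       + (1-d_x)\,d_y\,\tilde{C}(u{+}1,v) + d_x\,d_y\,\tilde{C}(u{+}1,v{+}1),
\qquad u,v \in \{0,\dots,2R\}

The extra row and column of the padded window exist only to feed this interpolation; the returned volume has the usual (2R+1)-sized lookup radius.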
mini_dpvo/api/__init__.py ADDED
File without changes
mini_dpvo/api/inference.py ADDED
@@ -0,0 +1,190 @@
import numpy as np
import os
import torch
from pathlib import Path
from multiprocessing import Process, Queue
from yacs.config import CfgNode

from mini_dpvo.utils import Timer
from mini_dpvo.dpvo import DPVO
from mini_dpvo.stream import image_stream, video_stream

import rerun as rr
from jaxtyping import UInt8, Float64, Float32
from scipy.spatial.transform import Rotation
from dataclasses import dataclass

from timeit import default_timer as timer


@dataclass
class DPVOPrediction:
    final_poses: Float32[torch.Tensor, "num_keyframes 7"]  # noqa: F722
    tstamps: Float64[torch.Tensor, "num_keyframes"]  # noqa: F821
    final_points: Float32[torch.Tensor, "buffer_size*num_patches 3"]  # noqa: F722
    final_colors: UInt8[torch.Tensor, "buffer_size num_patches 3"]  # noqa: F722


def log_trajectory(
    parent_log_path: Path,
    poses: Float32[torch.Tensor, "buffer_size 7"],
    points: Float32[torch.Tensor, "buffer_size*num_patches 3"],
    colors: UInt8[torch.Tensor, "buffer_size num_patches 3"],
    intri_np: Float64[np.ndarray, "4"],
    bgr_hw3: UInt8[np.ndarray, "h w 3"],
):
    cam_log_path = f"{parent_log_path}/camera"
    rr.log(f"{cam_log_path}/pinhole/image", rr.Image(bgr_hw3[..., ::-1]))
    rr.log(
        f"{cam_log_path}/pinhole",
        rr.Pinhole(
            height=bgr_hw3.shape[0],
            width=bgr_hw3.shape[1],
            focal_length=[intri_np[0], intri_np[1]],
            principal_point=[intri_np[2], intri_np[3]],
        ),
    )

    poses_mask = ~(poses[:, :6] == 0).all(dim=1)
    points_mask = ~(points == 0).all(dim=1)

    nonzero_poses = poses[poses_mask]
    nonzero_points = points[points_mask]

    last_index = nonzero_poses.shape[0] - 1
    # get last non-zero pose, and the index of the last non-zero pose
    quat_pose = nonzero_poses[last_index].numpy(force=True)
    trans_quat = quat_pose[:3]
    rotation_quat = Rotation.from_quat(quat_pose[3:])

    mat3x3 = rotation_quat.as_matrix()
    rr.log(
        f"{cam_log_path}",
        rr.Transform3D(translation=trans_quat, mat3x3=mat3x3, from_parent=True),
    )

    # outlier removal
    trajectory_center = np.median(nonzero_poses[:, :3].numpy(force=True), axis=0)
    radii = lambda a: np.linalg.norm(a - trajectory_center, axis=1)
    points_np = nonzero_points.view(-1, 3).numpy(force=True)
    colors_np = colors.view(-1, 3)[points_mask].numpy(force=True)
    inlier_mask = (
        radii(points_np) < radii(nonzero_poses[:, :3].numpy(force=True)).max() * 5
    )
    points_filtered = points_np[inlier_mask]
    colors_filtered = colors_np[inlier_mask]

    # log all points and colors at the same time
    rr.log(
        f"{parent_log_path}/pointcloud",
        rr.Points3D(
            positions=points_filtered,
            colors=colors_filtered,
        ),
    )


def log_final(
    parent_log_path: Path,
    final_poses: Float32[torch.Tensor, "num_keyframes 7"],
    tstamps: Float64[torch.Tensor, "num_keyframes"],  # noqa: F821
    final_points: Float32[torch.Tensor, "buffer_size*num_patches 3"],
    final_colors: UInt8[torch.Tensor, "buffer_size num_patches 3"],
):
    for idx, (pose_quat, tstamp) in enumerate(zip(final_poses, tstamps)):
        cam_log_path = f"{parent_log_path}/camera_{idx}"
        trans_quat = pose_quat[:3]
        R_33 = Rotation.from_quat(pose_quat[3:]).as_matrix()
        rr.log(
            f"{cam_log_path}",
            rr.Transform3D(translation=trans_quat, mat3x3=R_33, from_parent=False),
        )


def create_reader(
    imagedir: str, calib: str, stride: int, skip: int, queue: Queue
) -> Process:
    if os.path.isdir(imagedir):
        reader = Process(
            target=image_stream, args=(queue, imagedir, calib, stride, skip)
        )
    else:
        reader = Process(
            target=video_stream, args=(queue, imagedir, calib, stride, skip)
        )

    return reader


@torch.no_grad()
def run(
    cfg: CfgNode,
    network_path: str,
    imagedir: str,
    calib: str,
    stride: int = 1,
    skip: int = 0,
    vis_during: bool = True,
    timeit: bool = False,
) -> tuple[DPVOPrediction, float]:
    slam = None
    queue = Queue(maxsize=8)
    reader: Process = create_reader(imagedir, calib, stride, skip, queue)
    reader.start()

    if vis_during:
        parent_log_path = Path("world")
        rr.log(f"{parent_log_path}", rr.ViewCoordinates.RDF, timeless=True)

    start = timer()

    while True:
        t: int
        bgr_hw3: UInt8[np.ndarray, "h w 3"]
        intri_np: Float64[np.ndarray, "4"]
        (t, bgr_hw3, intri_np) = queue.get()
        # queue will have a (-1, image, intrinsics) tuple when the reader is done
        if t < 0:
            break

        if vis_during:
            rr.set_time_sequence(timeline="timestep", sequence=t)

        bgr_3hw: UInt8[torch.Tensor, "h w 3"] = (
            torch.from_numpy(bgr_hw3).permute(2, 0, 1).cuda()
        )
        intri_torch: Float64[torch.Tensor, "4"] = torch.from_numpy(intri_np).cuda()

        if slam is None:
            slam = DPVO(cfg, network_path, ht=bgr_3hw.shape[1], wd=bgr_3hw.shape[2])

        with Timer("SLAM", enabled=timeit):
            slam(t, bgr_3hw, intri_torch)

        if slam.is_initialized and vis_during:
            poses: Float32[torch.Tensor, "buffer_size 7"] = slam.poses_
            points: Float32[torch.Tensor, "buffer_size*num_patches 3"] = slam.points_
            colors: UInt8[torch.Tensor, "buffer_size num_patches 3"] = slam.colors_
            log_trajectory(parent_log_path, poses, points, colors, intri_np, bgr_hw3)

    for _ in range(12):
        slam.update()

    total_time: float = timer() - start
    print(f"Total time: {total_time:.2f}s")

    reader.join()

    final_poses: Float32[torch.Tensor, "num_keyframes 7"]
    tstamps: Float64[torch.Tensor, "num_keyframes"]  # noqa: F821

    final_poses, tstamps = slam.terminate()
    final_points: Float32[torch.Tensor, "buffer_size*num_patches 3"] = slam.points_
    final_colors: UInt8[torch.Tensor, "buffer_size num_patches 3"] = slam.colors_
    dpvo_pred = DPVOPrediction(
        final_poses=final_poses,
        tstamps=tstamps,
        final_points=final_points,
        final_colors=final_colors,
    )
    return dpvo_pred, total_time
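run() above is the public entry point: it spawns a reader process, feeds frames to DPVO, optionally streams the trajectory to Rerun, and returns a DPVOPrediction. A minimal driver sketch — the checkpoint, video, and calibration paths are placeholders, not files from this commit:

# Hypothetical driver; all paths are placeholders.
import rerun as rr
from mini_dpvo.config import cfg
from mini_dpvo.api.inference import run

cfg.merge_from_file("config/default.yaml")
rr.init("mini_dpvo", spawn=True)  # open a viewer for the vis_during logging

dpvo_pred, total_time = run(
    cfg=cfg,
    network_path="checkpoints/dpvo.pth",  # pretrained DPVO weights (placeholder path)
    imagedir="data/room_scan.mp4",        # a video file or a directory of images
    calib="data/room_scan_calib.txt",     # intrinsics; a 4-vector fx fy cx cy per the code above
    stride=2,
)
print(dpvo_pred.final_poses.shape, f"{total_time:.1f}s")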
mini_dpvo/ba.py ADDED
@@ -0,0 +1,182 @@
import torch
from torch_scatter import scatter_sum

from . import fastba
from . import lietorch
from .lietorch import SE3

from .utils import Timer

from . import projective_ops as pops

class CholeskySolver(torch.autograd.Function):
    @staticmethod
    def forward(ctx, H, b):
        # don't crash training if cholesky decomp fails
        U, info = torch.linalg.cholesky_ex(H)

        if torch.any(info):
            ctx.failed = True
            return torch.zeros_like(b)

        xs = torch.cholesky_solve(b, U)
        ctx.save_for_backward(U, xs)
        ctx.failed = False

        return xs

    @staticmethod
    def backward(ctx, grad_x):
        if ctx.failed:
            return None, None

        U, xs = ctx.saved_tensors
        dz = torch.cholesky_solve(grad_x, U)
        dH = -torch.matmul(xs, dz.transpose(-1,-2))

        return dH, dz

# utility functions for scattering ops
def safe_scatter_add_mat(A, ii, jj, n, m):
    v = (ii >= 0) & (jj >= 0) & (ii < n) & (jj < m)
    return scatter_sum(A[:,v], ii[v]*m + jj[v], dim=1, dim_size=n*m)

def safe_scatter_add_vec(b, ii, n):
    v = (ii >= 0) & (ii < n)
    return scatter_sum(b[:,v], ii[v], dim=1, dim_size=n)

# apply retraction operator to inv-depth maps
def disp_retr(disps, dz, ii):
    ii = ii.to(device=dz.device)
    return disps + scatter_sum(dz, ii, dim=1, dim_size=disps.shape[1])

# apply retraction operator to poses
def pose_retr(poses, dx, ii):
    ii = ii.to(device=dx.device)
    return poses.retr(scatter_sum(dx, ii, dim=1, dim_size=poses.shape[1]))

def block_matmul(A, B):
    """ block matrix multiply """
    b, n1, m1, p1, q1 = A.shape
    b, n2, m2, p2, q2 = B.shape
    A = A.permute(0, 1, 3, 2, 4).reshape(b, n1*p1, m1*q1)
    B = B.permute(0, 1, 3, 2, 4).reshape(b, n2*p2, m2*q2)
    return torch.matmul(A, B).reshape(b, n1, p1, m2, q2).permute(0, 1, 3, 2, 4)

def block_solve(A, B, ep=1.0, lm=1e-4):
    """ block matrix solve """
    b, n1, m1, p1, q1 = A.shape
    b, n2, m2, p2, q2 = B.shape
    A = A.permute(0, 1, 3, 2, 4).reshape(b, n1*p1, m1*q1)
    B = B.permute(0, 1, 3, 2, 4).reshape(b, n2*p2, m2*q2)

    A = A + (ep + lm * A) * torch.eye(n1*p1, device=A.device)

    X = CholeskySolver.apply(A, B)
    return X.reshape(b, n1, p1, m2, q2).permute(0, 1, 3, 2, 4)


def block_show(A):
    import matplotlib.pyplot as plt
    b, n1, m1, p1, q1 = A.shape
    A = A.permute(0, 1, 3, 2, 4).reshape(b, n1*p1, m1*q1)
    plt.imshow(A[0].detach().cpu().numpy())
    plt.show()

def BA(poses, patches, intrinsics, targets, weights, lmbda, ii, jj, kk, bounds, ep=100.0, PRINT=False, fixedp=1, structure_only=False):
    """ bundle adjustment """

    b = 1
    n = max(ii.max().item(), jj.max().item()) + 1

    coords, v, (Ji, Jj, Jz) = \
        pops.transform(poses, patches, intrinsics, ii, jj, kk, jacobian=True)

    p = coords.shape[3]
    r = targets - coords[...,p//2,p//2,:]

    v *= (r.norm(dim=-1) < 250).float()

    in_bounds = \
        (coords[...,p//2,p//2,0] > bounds[0]) & \
        (coords[...,p//2,p//2,1] > bounds[1]) & \
        (coords[...,p//2,p//2,0] < bounds[2]) & \
        (coords[...,p//2,p//2,1] < bounds[3])

    v *= in_bounds.float()

    if PRINT:
        print((r * v[...,None]).norm(dim=-1).mean().item())

    r = (v[...,None] * r).unsqueeze(dim=-1)
    weights = (v[...,None] * weights).unsqueeze(dim=-1)

    wJiT = (weights * Ji).transpose(2,3)
    wJjT = (weights * Jj).transpose(2,3)
    wJzT = (weights * Jz).transpose(2,3)

    Bii = torch.matmul(wJiT, Ji)
    Bij = torch.matmul(wJiT, Jj)
    Bji = torch.matmul(wJjT, Ji)
    Bjj = torch.matmul(wJjT, Jj)

    Eik = torch.matmul(wJiT, Jz)
    Ejk = torch.matmul(wJjT, Jz)

    vi = torch.matmul(wJiT, r)
    vj = torch.matmul(wJjT, r)

    # fix first pose
    ii = ii.clone()
    jj = jj.clone()

    n = n - fixedp
    ii = ii - fixedp
    jj = jj - fixedp

    kx, kk = torch.unique(kk, return_inverse=True, sorted=True)
    m = len(kx)

    B = safe_scatter_add_mat(Bii, ii, ii, n, n).view(b, n, n, 6, 6) + \
        safe_scatter_add_mat(Bij, ii, jj, n, n).view(b, n, n, 6, 6) + \
        safe_scatter_add_mat(Bji, jj, ii, n, n).view(b, n, n, 6, 6) + \
        safe_scatter_add_mat(Bjj, jj, jj, n, n).view(b, n, n, 6, 6)

    E = safe_scatter_add_mat(Eik, ii, kk, n, m).view(b, n, m, 6, 1) + \
        safe_scatter_add_mat(Ejk, jj, kk, n, m).view(b, n, m, 6, 1)

    C = safe_scatter_add_vec(torch.matmul(wJzT, Jz), kk, m)

    v = safe_scatter_add_vec(vi, ii, n).view(b, n, 1, 6, 1) + \
        safe_scatter_add_vec(vj, jj, n).view(b, n, 1, 6, 1)

    w = safe_scatter_add_vec(torch.matmul(wJzT, r), kk, m)

    if isinstance(lmbda, torch.Tensor):
        lmbda = lmbda.reshape(*C.shape)

    Q = 1.0 / (C + lmbda)

    ### solve w/ schur complement ###
    EQ = E * Q[:,None]

    if structure_only or n == 0:
        dZ = (Q * w).view(b, -1, 1, 1)

    else:
        S = B - block_matmul(EQ, E.permute(0,2,1,4,3))
        y = v - block_matmul(EQ, w.unsqueeze(dim=2))
        dX = block_solve(S, y, ep=ep, lm=1e-4)

        dZ = Q * (w - block_matmul(E.permute(0,2,1,4,3), dX).squeeze(dim=-1))
        dX = dX.view(b, -1, 6)
        dZ = dZ.view(b, -1, 1, 1)

    x, y, disps = patches.unbind(dim=2)
    disps = disp_retr(disps, dZ, kx).clamp(min=1e-3, max=10.0)
    patches = torch.stack([x, y, disps], dim=2)

    if not structure_only and n > 0:
        poses = pose_retr(poses, dX, fixedp + torch.arange(n))

    return poses, patches
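BA() above assembles the pose block B, the pose-depth coupling E, the per-patch depth block C, and the right-hand sides v and w, then eliminates the depth updates with a Schur complement before the damped Cholesky solve in block_solve(). In standard notation, with lambda the damping term lmbda:

\begin{bmatrix} B & E \\ E^{\top} & C + \lambda \end{bmatrix}
\begin{bmatrix} \Delta\xi \\ \Delta z \end{bmatrix}
=
\begin{bmatrix} v \\ w \end{bmatrix},
\qquad Q = (C+\lambda)^{-1}

S = B - E\,Q\,E^{\top}, \qquad
S\,\Delta\xi = v - E\,Q\,w, \qquad
\Delta z = Q\,\bigl(w - E^{\top}\Delta\xi\bigr)

Because each patch contributes a single inverse-depth variable, C is diagonal and Q is computed element-wise as Q = 1.0 / (C + lmbda); only the much smaller 6n x 6n reduced system S is handed to the Cholesky solver.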
mini_dpvo/blocks.py ADDED
@@ -0,0 +1,118 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_scatter

class LayerNorm1D(nn.Module):
    def __init__(self, dim):
        super(LayerNorm1D, self).__init__()
        self.norm = nn.LayerNorm(dim, eps=1e-4)

    def forward(self, x):
        return self.norm(x.transpose(1,2)).transpose(1,2)

class GatedResidual(nn.Module):
    def __init__(self, dim):
        super().__init__()

        self.gate = nn.Sequential(
            nn.Linear(dim, dim),
            nn.Sigmoid())

        self.res = nn.Sequential(
            nn.Linear(dim, dim),
            nn.ReLU(inplace=True),
            nn.Linear(dim, dim))

    def forward(self, x):
        return x + self.gate(x) * self.res(x)

class SoftAgg(nn.Module):
    def __init__(self, dim=512, expand=True):
        super(SoftAgg, self).__init__()
        self.dim = dim
        self.expand = expand
        self.f = nn.Linear(self.dim, self.dim)
        self.g = nn.Linear(self.dim, self.dim)
        self.h = nn.Linear(self.dim, self.dim)

    def forward(self, x, ix):
        _, jx = torch.unique(ix, return_inverse=True)
        w = torch_scatter.scatter_softmax(self.g(x), jx, dim=1)
        y = torch_scatter.scatter_sum(self.f(x) * w, jx, dim=1)

        if self.expand:
            return self.h(y)[:,jx]

        return self.h(y)

class SoftAggBasic(nn.Module):
    def __init__(self, dim=512, expand=True):
        super(SoftAggBasic, self).__init__()
        self.dim = dim
        self.expand = expand
        self.f = nn.Linear(self.dim, self.dim)
        self.g = nn.Linear(self.dim, 1)
        self.h = nn.Linear(self.dim, self.dim)

    def forward(self, x, ix):
        _, jx = torch.unique(ix, return_inverse=True)
        w = torch_scatter.scatter_softmax(self.g(x), jx, dim=1)
        y = torch_scatter.scatter_sum(self.f(x) * w, jx, dim=1)

        if self.expand:
            return self.h(y)[:,jx]

        return self.h(y)


### Gradient Clipping and Zeroing Operations ###

GRAD_CLIP = 0.1

class GradClip(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_x):
        grad_x = torch.where(torch.isnan(grad_x), torch.zeros_like(grad_x), grad_x)
        return grad_x.clamp(min=-0.01, max=0.01)

class GradientClip(nn.Module):
    def __init__(self):
        super(GradientClip, self).__init__()

    def forward(self, x):
        return GradClip.apply(x)

class GradZero(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_x):
        grad_x = torch.where(torch.isnan(grad_x), torch.zeros_like(grad_x), grad_x)
        grad_x = torch.where(torch.abs(grad_x) > GRAD_CLIP, torch.zeros_like(grad_x), grad_x)
        return grad_x

class GradientZero(nn.Module):
    def __init__(self):
        super(GradientZero, self).__init__()

    def forward(self, x):
        return GradZero.apply(x)


class GradMag(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x

    @staticmethod
    def backward(ctx, grad_x):
        print(grad_x.abs().mean())
        return grad_x
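GradientClip and GradientZero above are identity maps in the forward pass and only act on gradients flowing backward. A small sketch of the usual pattern, appending one to a prediction head (the dimensions are made up for illustration):

# Sketch: the head behaves normally in the forward pass; gradients flowing back
# through GradientClip are NaN-scrubbed and clamped to [-0.01, 0.01].
import torch
import torch.nn as nn
from mini_dpvo.blocks import GradientClip

head = nn.Sequential(nn.Linear(384, 2), GradientClip())
out = head(torch.randn(8, 384, requires_grad=True))
out.sum().backward()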
mini_dpvo/config.py ADDED
@@ -0,0 +1,27 @@
from yacs.config import CfgNode as CN

_C = CN()

# max number of keyframes
_C.BUFFER_SIZE = 2048

# bias patch selection towards high gradient regions?
_C.GRADIENT_BIAS = True

# VO config (increase for better accuracy)
_C.PATCHES_PER_FRAME = 80
_C.REMOVAL_WINDOW = 20
_C.OPTIMIZATION_WINDOW = 12
_C.PATCH_LIFETIME = 12

# threshold for keyframe removal
_C.KEYFRAME_INDEX = 4
_C.KEYFRAME_THRESH = 12.5

# camera motion model
_C.MOTION_MODEL = 'DAMPED_LINEAR'
_C.MOTION_DAMPING = 0.5

_C.MIXED_PRECISION = True

cfg = _C
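cfg here is a plain yacs CfgNode, so the YAML presets under config/ are meant to be layered on top of these defaults at run time. A small sketch (the override value is an arbitrary example, not from this commit):

# Sketch: start from the built-in defaults, then merge a preset and an ad-hoc override.
from mini_dpvo.config import cfg

cfg.merge_from_file("config/fast.yaml")          # preset shown earlier in this commit
cfg.merge_from_list(["PATCHES_PER_FRAME", 64])   # yacs key/value override
print(cfg.PATCHES_PER_FRAME, cfg.GRADIENT_BIAS)  # 64 False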
mini_dpvo/data_readers/__init__.py ADDED
@@ -0,0 +1 @@
+
mini_dpvo/data_readers/augmentation.py ADDED
@@ -0,0 +1,66 @@
import torch
import torchvision.transforms as transforms
import numpy as np
import torch.nn.functional as F


class RGBDAugmentor:
    """ perform augmentation on RGB-D video """

    def __init__(self, crop_size):
        self.crop_size = crop_size
        self.augcolor = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2/3.14),
            transforms.RandomGrayscale(p=0.1),
            transforms.RandomInvert(p=0.1),
            transforms.ToTensor()])

        self.max_scale = 0.5

    def spatial_transform(self, images, depths, poses, intrinsics):
        """ cropping and resizing """
        ht, wd = images.shape[2:]

        max_scale = self.max_scale
        min_scale = np.log2(np.maximum(
            (self.crop_size[0] + 1) / float(ht),
            (self.crop_size[1] + 1) / float(wd)))

        scale = 1
        if np.random.rand() < 0.8:
            scale = 2 ** np.random.uniform(0.0, max_scale)

        intrinsics = scale * intrinsics

        ht1 = int(scale * ht)
        wd1 = int(scale * wd)

        depths = depths.unsqueeze(dim=1)

        images = F.interpolate(images, (ht1, wd1), mode='bicubic', align_corners=False)
        depths = F.interpolate(depths, (ht1, wd1), recompute_scale_factor=False)

        # always perform center crop (TODO: try non-center crops)
        y0 = (images.shape[2] - self.crop_size[0]) // 2
        x0 = (images.shape[3] - self.crop_size[1]) // 2

        intrinsics = intrinsics - torch.tensor([0.0, 0.0, x0, y0])
        images = images[:, :, y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
        depths = depths[:, :, y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]

        depths = depths.squeeze(dim=1)
        return images, poses, depths, intrinsics

    def color_transform(self, images):
        """ color jittering """
        num, ch, ht, wd = images.shape
        images = images.permute(1, 2, 3, 0).reshape(ch, ht, wd*num)
        images = 255 * self.augcolor(images[[2,1,0]] / 255.0)
        return images[[2,1,0]].reshape(ch, ht, wd, num).permute(3,0,1,2).contiguous()

    def __call__(self, images, poses, depths, intrinsics):
        if np.random.rand() < 0.5:
            images = self.color_transform(images)

        return self.spatial_transform(images, depths, poses, intrinsics)
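RGBDAugmentor is applied to whole clips at once. A toy invocation with made-up shapes (a sketch, not taken from this commit):

# Sketch: jitter colour and apply a random scale + centre crop to a 4-frame clip.
import torch
from mini_dpvo.data_readers.augmentation import RGBDAugmentor

aug = RGBDAugmentor(crop_size=[480, 640])
images = torch.randint(0, 256, (4, 3, 480, 640)).float()     # N x 3 x H x W, 0-255
poses = torch.randn(4, 7)                                     # translation + quaternion
disps = torch.rand(4, 480, 640) + 0.1                         # inverse depth
intrinsics = torch.tensor([320.0, 320.0, 320.0, 240.0]).repeat(4, 1)
images, poses, disps, intrinsics = aug(images, poses, disps, intrinsics)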
mini_dpvo/data_readers/base.py ADDED
@@ -0,0 +1,176 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.utils.data as data
4
+ import torch.nn.functional as F
5
+
6
+ import csv
7
+ import os
8
+ import cv2
9
+ import math
10
+ import random
11
+ import json
12
+ import pickle
13
+ import os.path as osp
14
+
15
+ from .augmentation import RGBDAugmentor
16
+ from .rgbd_utils import *
17
+
18
+ class RGBDDataset(data.Dataset):
19
+ def __init__(self, name, datapath, n_frames=4, crop_size=[480,640], fmin=10.0, fmax=75.0, aug=True, sample=True):
20
+ """ Base class for RGBD dataset """
21
+ self.aug = None
22
+ self.root = datapath
23
+ self.name = name
24
+
25
+ self.aug = aug
26
+ self.sample = sample
27
+
28
+ self.n_frames = n_frames
29
+ self.fmin = fmin # exclude very easy examples
30
+ self.fmax = fmax # exclude very hard examples
31
+
32
+ if self.aug:
33
+ self.aug = RGBDAugmentor(crop_size=crop_size)
34
+
35
+ # building dataset is expensive, cache so only needs to be performed once
36
+ cur_path = osp.dirname(osp.abspath(__file__))
37
+ if not os.path.isdir(osp.join(cur_path, 'cache')):
38
+ os.mkdir(osp.join(cur_path, 'cache'))
39
+
40
+ self.scene_info = \
41
+ pickle.load(open('datasets/TartanAir.pickle', 'rb'))[0]
42
+
43
+ self._build_dataset_index()
44
+
45
+ def _build_dataset_index(self):
46
+ self.dataset_index = []
47
+ for scene in self.scene_info:
48
+ if not self.__class__.is_test_scene(scene):
49
+ graph = self.scene_info[scene]['graph']
50
+ for i in graph:
51
+ if i < len(graph) - 65:
52
+ self.dataset_index.append((scene, i))
53
+ else:
54
+ print("Reserving {} for validation".format(scene))
55
+
56
+ @staticmethod
57
+ def image_read(image_file):
58
+ return cv2.imread(image_file)
59
+
60
+ @staticmethod
61
+ def depth_read(depth_file):
62
+ return np.load(depth_file)
63
+
64
+ def build_frame_graph(self, poses, depths, intrinsics, f=16, max_flow=256):
65
+ """ compute optical flow distance between all pairs of frames """
66
+ def read_disp(fn):
67
+ depth = self.__class__.depth_read(fn)[f//2::f, f//2::f]
68
+ depth[depth < 0.01] = np.mean(depth)
69
+ return 1.0 / depth
70
+
71
+ poses = np.array(poses)
72
+ intrinsics = np.array(intrinsics) / f
73
+
74
+ disps = np.stack(list(map(read_disp, depths)), 0)
75
+ d = f * compute_distance_matrix_flow(poses, disps, intrinsics)
76
+
77
+ graph = {}
78
+ for i in range(d.shape[0]):
79
+ j, = np.where(d[i] < max_flow)
80
+ graph[i] = (j, d[i,j])
81
+
82
+ return graph
83
+
84
+ def __getitem__(self, index):
85
+ """ return training video """
86
+
87
+ index = index % len(self.dataset_index)
88
+ scene_id, ix = self.dataset_index[index]
89
+
90
+ frame_graph = self.scene_info[scene_id]['graph']
91
+ images_list = self.scene_info[scene_id]['images']
92
+ depths_list = self.scene_info[scene_id]['depths']
93
+ poses_list = self.scene_info[scene_id]['poses']
94
+ intrinsics_list = self.scene_info[scene_id]['intrinsics']
95
+
96
+ # stride = np.random.choice([1,2,3])
97
+
98
+ d = np.random.uniform(self.fmin, self.fmax)
99
+ s = 1
100
+
101
+ inds = [ ix ]
102
+
103
+ while len(inds) < self.n_frames:
104
+ # get other frames within flow threshold
105
+
106
+ if self.sample:
107
+ k = (frame_graph[ix][1] > self.fmin) & (frame_graph[ix][1] < self.fmax)
108
+ frames = frame_graph[ix][0][k]
109
+
110
+ # prefer frames forward in time
111
+ if np.count_nonzero(frames[frames > ix]):
112
+ ix = np.random.choice(frames[frames > ix])
113
+
114
+ elif ix + 1 < len(images_list):
115
+ ix = ix + 1
116
+
117
+ elif np.count_nonzero(frames):
118
+ ix = np.random.choice(frames)
119
+
120
+ else:
121
+ i = frame_graph[ix][0].copy()
122
+ g = frame_graph[ix][1].copy()
123
+
124
+ g[g > d] = -1
125
+ if s > 0:
126
+ g[i <= ix] = -1
127
+ else:
128
+ g[i >= ix] = -1
129
+
130
+ if len(g) > 0 and np.max(g) > 0:
131
+ ix = i[np.argmax(g)]
132
+ else:
133
+ if ix + s >= len(images_list) or ix + s < 0:
134
+ s *= -1
135
+
136
+ ix = ix + s
137
+
138
+ inds += [ ix ]
139
+
140
+
141
+ images, depths, poses, intrinsics = [], [], [], []
142
+ for i in inds:
143
+ images.append(self.__class__.image_read(images_list[i]))
144
+ depths.append(self.__class__.depth_read(depths_list[i]))
145
+ poses.append(poses_list[i])
146
+ intrinsics.append(intrinsics_list[i])
147
+
148
+ images = np.stack(images).astype(np.float32)
149
+ depths = np.stack(depths).astype(np.float32)
150
+ poses = np.stack(poses).astype(np.float32)
151
+ intrinsics = np.stack(intrinsics).astype(np.float32)
152
+
153
+ images = torch.from_numpy(images).float()
154
+ images = images.permute(0, 3, 1, 2)
155
+
156
+ disps = torch.from_numpy(1.0 / depths)
157
+ poses = torch.from_numpy(poses)
158
+ intrinsics = torch.from_numpy(intrinsics)
159
+
160
+ if self.aug:
161
+ images, poses, disps, intrinsics = \
162
+ self.aug(images, poses, disps, intrinsics)
163
+
164
+ # normalize depth
165
+ s = .7 * torch.quantile(disps, .98)
166
+ disps = disps / s
167
+ poses[...,:3] *= s
168
+
169
+ return images, poses, disps, intrinsics
170
+
171
+ def __len__(self):
172
+ return len(self.dataset_index)
173
+
174
+ def __imul__(self, x):
175
+ self.dataset_index *= x
176
+ return self
mini_dpvo/data_readers/factory.py ADDED
@@ -0,0 +1,26 @@

import pickle
import os
import os.path as osp

# RGBD-Dataset
from .tartan import TartanAir

def dataset_factory(dataset_list, **kwargs):
    """ create a combined dataset """

    from torch.utils.data import ConcatDataset

    dataset_map = {
        'tartan': (TartanAir, ),
    }

    db_list = []
    for key in dataset_list:
        # cache datasets for faster future loading
        db = dataset_map[key][0](**kwargs)

        print("Dataset {} has {} images".format(key, len(db)))
        db_list.append(db)

    return ConcatDataset(db_list)
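dataset_factory() simply instantiates each named dataset and concatenates them, forwarding the keyword arguments untouched to the dataset constructor. A hedged sketch of how a training script might call it — the keyword names follow the RGBDDataset base class above and are assumptions, since tartan.py is listed in the files above but its contents are not shown here:

# Sketch: build a TartanAir training set; kwargs are guesses based on RGBDDataset.__init__.
from mini_dpvo.data_readers.factory import dataset_factory

db = dataset_factory(
    ["tartan"],
    datapath="datasets/TartanAir",  # placeholder dataset root
    n_frames=4,
    fmin=10.0,
    fmax=75.0,
)
print(len(db))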
mini_dpvo/data_readers/frame_utils.py ADDED
@@ -0,0 +1,164 @@
1
+ import numpy as np
2
+ from PIL import Image
3
+ from os.path import *
4
+ import re
5
+ import cv2
6
+ cv2.setNumThreads(0)
7
+
8
+
9
+ TAG_CHAR = np.array([202021.25], np.float32)
10
+
11
+ def readFlowKITTI(filename):
12
+ flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
13
+ flow = flow[:,:,::-1].astype(np.float32)
14
+ flow, valid = flow[:, :, :2], flow[:, :, 2]
15
+ flow = (flow - 2**15) / 64.0
16
+ return flow, valid
17
+
18
+ def readFlow(fn):
19
+ """ Read .flo file in Middlebury format"""
20
+ # Code adapted from:
21
+ # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
22
+
23
+ # WARNING: this will work on little-endian architectures (eg Intel x86) only!
24
+ # print 'fn = %s'%(fn)
25
+ with open(fn, 'rb') as f:
26
+ magic = np.fromfile(f, np.float32, count=1)
27
+ if 202021.25 != magic:
28
+ print('Magic number incorrect. Invalid .flo file')
29
+ return None
30
+ else:
31
+ w = np.fromfile(f, np.int32, count=1)
32
+ h = np.fromfile(f, np.int32, count=1)
33
+ # print 'Reading %d x %d flo file\n' % (w, h)
34
+ data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
35
+ # Reshape data into 3D array (columns, rows, bands)
36
+ # The reshape here is for visualization, the original code is (w,h,2)
37
+ return np.resize(data, (int(h), int(w), 2))
38
+
39
+ def readPFM(file):
40
+ file = open(file, 'rb')
41
+
42
+ color = None
43
+ width = None
44
+ height = None
45
+ scale = None
46
+ endian = None
47
+
48
+ header = file.readline().rstrip()
49
+ if header == b'PF':
50
+ color = True
51
+ elif header == b'Pf':
52
+ color = False
53
+ else:
54
+ raise Exception('Not a PFM file.')
55
+
56
+ try:
57
+ dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
58
+ except:
59
+ dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline())
60
+
61
+ if dim_match:
62
+ width, height = map(int, dim_match.groups())
63
+ else:
64
+ raise Exception('Malformed PFM header.')
65
+
66
+ scale = float(file.readline().rstrip())
67
+ if scale < 0: # little-endian
68
+ endian = '<'
69
+ scale = -scale
70
+ else:
71
+ endian = '>' # big-endian
72
+
73
+ data = np.fromfile(file, endian + 'f')
74
+ shape = (height, width, 3) if color else (height, width)
75
+
76
+ data = np.reshape(data, shape)
77
+ data = np.flipud(data)
78
+ return data
79
+
80
+
81
+ def writeFlow(filename,uv,v=None):
82
+ """ Write optical flow to file.
83
+
84
+ If v is None, uv is assumed to contain both u and v channels,
85
+ stacked in depth.
86
+ Original code by Deqing Sun, adapted from Daniel Scharstein.
87
+ """
88
+ nBands = 2
89
+
90
+ if v is None:
91
+ assert(uv.ndim == 3)
92
+ assert(uv.shape[2] == 2)
93
+ u = uv[:,:,0]
94
+ v = uv[:,:,1]
95
+ else:
96
+ u = uv
97
+
98
+ assert(u.shape == v.shape)
99
+ height,width = u.shape
100
+ f = open(filename,'wb')
101
+ # write the header
102
+ f.write(TAG_CHAR)
103
+ np.array(width).astype(np.int32).tofile(f)
104
+ np.array(height).astype(np.int32).tofile(f)
105
+ # arrange into matrix form
106
+ tmp = np.zeros((height, width*nBands))
107
+ tmp[:,np.arange(width)*2] = u
108
+ tmp[:,np.arange(width)*2 + 1] = v
109
+ tmp.astype(np.float32).tofile(f)
110
+ f.close()
111
+
112
+
113
+ def readDPT(filename):
114
+ """ Read depth data from file, return as numpy array. """
115
+ f = open(filename,'rb')
116
+ check = np.fromfile(f,dtype=np.float32,count=1)[0]
117
+ TAG_FLOAT = 202021.25
118
+ TAG_CHAR = 'PIEH'
119
+ assert check == TAG_FLOAT, ' depth_read:: Wrong tag in flow file (should be: {0}, is: {1}). Big-endian machine? '.format(TAG_FLOAT,check)
120
+ width = np.fromfile(f,dtype=np.int32,count=1)[0]
121
+ height = np.fromfile(f,dtype=np.int32,count=1)[0]
122
+ size = width*height
123
+ assert width > 0 and height > 0 and size > 1 and size < 100000000, ' depth_read:: Wrong input size (width = {0}, height = {1}).'.format(width,height)
124
+ depth = np.fromfile(f,dtype=np.float32,count=-1).reshape((height,width))
125
+ return depth
126
+
127
+ def cam_read(filename):
128
+ """ Read camera data, return (M,N) tuple.
129
+ M is the intrinsic matrix, N is the extrinsic matrix, so that
130
+ x = M*N*X,
131
+ with x being a point in homogeneous image pixel coordinates, X being a
132
+ point in homogeneous world coordinates."""
133
+ f = open(filename,'rb')
134
+ check = np.fromfile(f,dtype=np.float32,count=1)[0]
135
+ M = np.fromfile(f,dtype='float64',count=9).reshape((3,3))
136
+ N = np.fromfile(f,dtype='float64',count=12).reshape((3,4))
137
+
138
+ E = np.eye(4)
139
+ E[0:3,:] = N
140
+
141
+ fx, fy, cx, cy = M[0,0], M[1,1], M[0,2], M[1,2]
142
+ kvec = np.array([fx, fy, cx, cy])
143
+
144
+ q = Rotation.from_matrix(E[:3,:3]).as_quat()
145
+ pvec = np.concatenate([E[:3,3], q], 0)
146
+
147
+ return pvec, kvec
148
+
149
+
150
+ def read_gen(file_name, pil=False):
151
+ ext = splitext(file_name)[-1]
152
+ if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
153
+ return Image.open(file_name)
154
+ elif ext == '.bin' or ext == '.raw':
155
+ return np.load(file_name)
156
+ elif ext == '.flo':
157
+ return readFlow(file_name).astype(np.float32)
158
+ elif ext == '.pfm':
159
+ return readPFM(file_name).astype(np.float32)
160
+ elif ext == '.dpt':
161
+ return readDPT(file_name).astype(np.float32)
162
+ elif ext == '.cam':
163
+ return cam_read(file_name)
164
+ return []
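For orientation, a minimal usage sketch of the readers above; the file paths are placeholders, and `read_gen` simply dispatches on the file extension:

# Hypothetical usage of the readers above (paths are placeholders).
from mini_dpvo.data_readers.frame_utils import readFlow, readDPT, cam_read, read_gen

flow = readFlow("frame_0001.flo")            # (H, W, 2) float32 flow, or None if the magic number is wrong
depth = readDPT("frame_0001.dpt")            # (H, W) float32 depth map
pose_vec, kvec = cam_read("frame_0001.cam")  # (tx,ty,tz,qx,qy,qz,qw) and (fx,fy,cx,cy)
image = read_gen("frame_0001.png")           # PIL.Image, since dispatch is purely by extension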
mini_dpvo/data_readers/rgbd_utils.py ADDED
@@ -0,0 +1,188 @@
1
+ import numpy as np
2
+ import os.path as osp
3
+
4
+ import torch
5
+ from ..lietorch import SE3
6
+
7
+ from scipy.spatial.transform import Rotation
+ from .. import projective_ops as pops  # needed by compute_distance_matrix_flow below
8
+
9
+ def parse_list(filepath, skiprows=0):
10
+ """ read list data """
11
+ data = np.loadtxt(filepath, delimiter=' ', dtype=np.unicode_, skiprows=skiprows)
12
+ return data
13
+
14
+ def associate_frames(tstamp_image, tstamp_depth, tstamp_pose, max_dt=1.0):
15
+ """ pair images, depths, and poses """
16
+ associations = []
17
+ for i, t in enumerate(tstamp_image):
18
+ if tstamp_pose is None:
19
+ j = np.argmin(np.abs(tstamp_depth - t))
20
+ if (np.abs(tstamp_depth[j] - t) < max_dt):
21
+ associations.append((i, j))
22
+
23
+ else:
24
+ j = np.argmin(np.abs(tstamp_depth - t))
25
+ k = np.argmin(np.abs(tstamp_pose - t))
26
+
27
+ if (np.abs(tstamp_depth[j] - t) < max_dt) and \
28
+ (np.abs(tstamp_pose[k] - t) < max_dt):
29
+ associations.append((i, j, k))
30
+
31
+ return associations
32
+
33
+ def loadtum(datapath, frame_rate=-1):
34
+ """ read video data in tum-rgbd format """
35
+ if osp.isfile(osp.join(datapath, 'groundtruth.txt')):
36
+ pose_list = osp.join(datapath, 'groundtruth.txt')
37
+
38
+ elif osp.isfile(osp.join(datapath, 'pose.txt')):
39
+ pose_list = osp.join(datapath, 'pose.txt')
40
+
41
+ else:
42
+ return None, None, None, None
43
+
44
+ image_list = osp.join(datapath, 'rgb.txt')
45
+ depth_list = osp.join(datapath, 'depth.txt')
46
+
47
+ calib_path = osp.join(datapath, 'calibration.txt')
48
+ intrinsic = None
49
+ if osp.isfile(calib_path):
50
+ intrinsic = np.loadtxt(calib_path, delimiter=' ')
51
+ intrinsic = intrinsic.astype(np.float64)
52
+
53
+ image_data = parse_list(image_list)
54
+ depth_data = parse_list(depth_list)
55
+ pose_data = parse_list(pose_list, skiprows=1)
56
+ pose_vecs = pose_data[:,1:].astype(np.float64)
57
+
58
+ tstamp_image = image_data[:,0].astype(np.float64)
59
+ tstamp_depth = depth_data[:,0].astype(np.float64)
60
+ tstamp_pose = pose_data[:,0].astype(np.float64)
61
+ associations = associate_frames(tstamp_image, tstamp_depth, tstamp_pose)
62
+
63
+ # print(len(tstamp_image))
64
+ # print(len(associations))
65
+
66
+ indicies = range(len(associations))[::5]
67
+
68
+ # indicies = [ 0 ]
69
+ # for i in range(1, len(associations)):
70
+ # t0 = tstamp_image[associations[indicies[-1]][0]]
71
+ # t1 = tstamp_image[associations[i][0]]
72
+ # if t1 - t0 > 1.0 / frame_rate:
73
+ # indicies += [ i ]
74
+
75
+ images, poses, depths, intrinsics, tstamps = [], [], [], [], []
76
+ for ix in indicies:
77
+ (i, j, k) = associations[ix]
78
+ images += [ osp.join(datapath, image_data[i,1]) ]
79
+ depths += [ osp.join(datapath, depth_data[j,1]) ]
80
+ poses += [ pose_vecs[k] ]
81
+ tstamps += [ tstamp_image[i] ]
82
+
83
+ if intrinsic is not None:
84
+ intrinsics += [ intrinsic ]
85
+
86
+ return images, depths, poses, intrinsics, tstamps
87
+
88
+
89
+ def all_pairs_distance_matrix(poses, beta=2.5):
90
+ """ compute distance matrix between all pairs of poses """
91
+ poses = np.array(poses, dtype=np.float32)
92
+ poses[:,:3] *= beta # scale to balance rot + trans
93
+ poses = SE3(torch.from_numpy(poses))
94
+
95
+ r = (poses[:,None].inv() * poses[None,:]).log()
96
+ return r.norm(dim=-1).cpu().numpy()
97
+
98
+ def pose_matrix_to_quaternion(pose):
99
+ """ convert 4x4 pose matrix to (t, q) """
100
+ q = Rotation.from_matrix(pose[:3, :3]).as_quat()
101
+ return np.concatenate([pose[:3, 3], q], axis=0)
102
+
103
+ def compute_distance_matrix_flow(poses, disps, intrinsics):
104
+ """ compute flow magnitude between all pairs of frames """
105
+ if not isinstance(poses, SE3):
106
+ poses = torch.from_numpy(poses).float().cuda()[None]
107
+ poses = SE3(poses).inv()
108
+
109
+ disps = torch.from_numpy(disps).float().cuda()[None]
110
+ intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]
111
+
112
+ N = poses.shape[1]
113
+
114
+ ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
115
+ ii = ii.reshape(-1).cuda()
116
+ jj = jj.reshape(-1).cuda()
117
+
118
+ MAX_FLOW = 100.0
119
+ matrix = np.zeros((N, N), dtype=np.float32)
120
+
121
+ s = 2048
122
+ for i in range(0, ii.shape[0], s):
123
+ flow1, val1 = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
124
+ flow2, val2 = pops.induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s])
125
+
126
+ flow = torch.stack([flow1, flow2], dim=2)
127
+ val = torch.stack([val1, val2], dim=2)
128
+
129
+ mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
130
+ mag = mag.view(mag.shape[1], -1)
131
+ val = val.view(val.shape[1], -1)
132
+
133
+ mag = (mag * val).mean(-1) / val.mean(-1)
134
+ mag[val.mean(-1) < 0.7] = np.inf
135
+
136
+ i1 = ii[i:i+s].cpu().numpy()
137
+ j1 = jj[i:i+s].cpu().numpy()
138
+ matrix[i1, j1] = mag.cpu().numpy()
139
+
140
+ return matrix
141
+
142
+
143
+ def compute_distance_matrix_flow2(poses, disps, intrinsics, beta=0.4):
144
+ """ compute flow magnitude between all pairs of frames """
145
+ # if not isinstance(poses, SE3):
146
+ # poses = torch.from_numpy(poses).float().cuda()[None]
147
+ # poses = SE3(poses).inv()
148
+
149
+ # disps = torch.from_numpy(disps).float().cuda()[None]
150
+ # intrinsics = torch.from_numpy(intrinsics).float().cuda()[None]
151
+
152
+ N = poses.shape[1]
153
+
154
+ ii, jj = torch.meshgrid(torch.arange(N), torch.arange(N))
155
+ ii = ii.reshape(-1)
156
+ jj = jj.reshape(-1)
157
+
158
+ MAX_FLOW = 128.0
159
+ matrix = np.zeros((N, N), dtype=np.float32)
160
+
161
+ s = 2048
162
+ for i in range(0, ii.shape[0], s):
163
+ flow1a, val1a = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s], tonly=True)
164
+ flow1b, val1b = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
165
+ flow2a, val2a = pops.induced_flow(poses, disps, intrinsics, jj[i:i+s], ii[i:i+s], tonly=True)
166
+ flow2b, val2b = pops.induced_flow(poses, disps, intrinsics, ii[i:i+s], jj[i:i+s])
167
+
168
+ flow1 = flow1a + beta * flow1b
169
+ val1 = val1a * val2b
170
+
171
+ flow2 = flow2a + beta * flow2b
172
+ val2 = val2a * val2b
173
+
174
+ flow = torch.stack([flow1, flow2], dim=2)
175
+ val = torch.stack([val1, val2], dim=2)
176
+
177
+ mag = flow.norm(dim=-1).clamp(max=MAX_FLOW)
178
+ mag = mag.view(mag.shape[1], -1)
179
+ val = val.view(val.shape[1], -1)
180
+
181
+ mag = (mag * val).mean(-1) / val.mean(-1)
182
+ mag[val.mean(-1) < 0.8] = np.inf
183
+
184
+ i1 = ii[i:i+s].cpu().numpy()
185
+ j1 = jj[i:i+s].cpu().numpy()
186
+ matrix[i1, j1] = mag.cpu().numpy()
187
+
188
+ return matrix
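A minimal sketch of how the TUM-RGBD loader above might be driven; the dataset path is a placeholder, and `loadtum` returns parallel lists of image paths, depth paths, pose vectors, intrinsics, and timestamps:

# Hypothetical usage of loadtum / all_pairs_distance_matrix (the dataset path is a placeholder).
import numpy as np
from mini_dpvo.data_readers.rgbd_utils import loadtum, all_pairs_distance_matrix

images, depths, poses, intrinsics, tstamps = loadtum("/data/tum/rgbd_dataset_freiburg1_desk")
if images is not None:
    dist = all_pairs_distance_matrix(np.stack(poses))  # (N, N) pose-distance matrix
    print(len(images), dist.shape)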
mini_dpvo/data_readers/tartan.py ADDED
@@ -0,0 +1,110 @@
1
+
2
+ import numpy as np
3
+ import torch
4
+ import glob
5
+ import cv2
6
+ import os
7
+ import os.path as osp
8
+
9
+ from ..lietorch import SE3
10
+ from .base import RGBDDataset
11
+
12
+ # cur_path = osp.dirname(osp.abspath(__file__))
13
+ # test_split = osp.join(cur_path, 'tartan_test.txt')
14
+ # test_split = open(test_split).read().split()
15
+
16
+
17
+ test_split = [
18
+ "abandonedfactory/abandonedfactory/Easy/P011",
19
+ "abandonedfactory/abandonedfactory/Hard/P011",
20
+ "abandonedfactory_night/abandonedfactory_night/Easy/P013",
21
+ "abandonedfactory_night/abandonedfactory_night/Hard/P014",
22
+ "amusement/amusement/Easy/P008",
23
+ "amusement/amusement/Hard/P007",
24
+ "carwelding/carwelding/Easy/P007",
25
+ "endofworld/endofworld/Easy/P009",
26
+ "gascola/gascola/Easy/P008",
27
+ "gascola/gascola/Hard/P009",
28
+ "hospital/hospital/Easy/P036",
29
+ "hospital/hospital/Hard/P049",
30
+ "japanesealley/japanesealley/Easy/P007",
31
+ "japanesealley/japanesealley/Hard/P005",
32
+ "neighborhood/neighborhood/Easy/P021",
33
+ "neighborhood/neighborhood/Hard/P017",
34
+ "ocean/ocean/Easy/P013",
35
+ "ocean/ocean/Hard/P009",
36
+ "office2/office2/Easy/P011",
37
+ "office2/office2/Hard/P010",
38
+ "office/office/Hard/P007",
39
+ "oldtown/oldtown/Easy/P007",
40
+ "oldtown/oldtown/Hard/P008",
41
+ "seasidetown/seasidetown/Easy/P009",
42
+ "seasonsforest/seasonsforest/Easy/P011",
43
+ "seasonsforest/seasonsforest/Hard/P006",
44
+ "seasonsforest_winter/seasonsforest_winter/Easy/P009",
45
+ "seasonsforest_winter/seasonsforest_winter/Hard/P018",
46
+ "soulcity/soulcity/Easy/P012",
47
+ "soulcity/soulcity/Hard/P009",
48
+ "westerndesert/westerndesert/Easy/P013",
49
+ "westerndesert/westerndesert/Hard/P007",
50
+ ]
51
+
52
+
53
+ class TartanAir(RGBDDataset):
54
+
55
+ # scale depths to balance rot & trans
56
+ DEPTH_SCALE = 5.0
57
+
58
+ def __init__(self, mode='training', **kwargs):
59
+ self.mode = mode
60
+ self.n_frames = 2
61
+ super(TartanAir, self).__init__(name='TartanAir', **kwargs)
62
+
63
+ @staticmethod
64
+ def is_test_scene(scene):
65
+ # print(scene, any(x in scene for x in test_split))
66
+ return any(x in scene for x in test_split)
67
+
68
+ def _build_dataset(self):
69
+ from tqdm import tqdm
70
+ print("Building TartanAir dataset")
71
+
72
+ scene_info = {}
73
+ scenes = glob.glob(osp.join(self.root, '*/*/*/*'))
74
+ for scene in tqdm(sorted(scenes)):
75
+ images = sorted(glob.glob(osp.join(scene, 'image_left/*.png')))
76
+ depths = sorted(glob.glob(osp.join(scene, 'depth_left/*.npy')))
77
+
78
+ if len(images) != len(depths):
79
+ continue
80
+
81
+ poses = np.loadtxt(osp.join(scene, 'pose_left.txt'), delimiter=' ')
82
+ poses = poses[:, [1, 2, 0, 4, 5, 3, 6]]
83
+ poses[:,:3] /= TartanAir.DEPTH_SCALE
84
+ intrinsics = [TartanAir.calib_read()] * len(images)
85
+
86
+ # graph of co-visible frames based on flow
87
+ graph = self.build_frame_graph(poses, depths, intrinsics)
88
+
89
+ scene = '/'.join(scene.split('/'))
90
+ scene_info[scene] = {'images': images, 'depths': depths,
91
+ 'poses': poses, 'intrinsics': intrinsics, 'graph': graph}
92
+
93
+ return scene_info
94
+
95
+ @staticmethod
96
+ def calib_read():
97
+ return np.array([320.0, 320.0, 320.0, 240.0])
98
+
99
+ @staticmethod
100
+ def image_read(image_file):
101
+ return cv2.imread(image_file)
102
+
103
+ @staticmethod
104
+ def depth_read(depth_file):
105
+ depth = np.load(depth_file) / TartanAir.DEPTH_SCALE
106
+ depth[np.isnan(depth)] = 1.0  # `depth == np.nan` is always False; use np.isnan
107
+ depth[np.isinf(depth)] = 1.0
108
+ return depth
109
+
110
+
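As a quick reference, a hedged sketch of the per-frame readers defined above; the file paths are placeholders:

# Hypothetical use of the TartanAir static readers (paths are placeholders).
from mini_dpvo.data_readers.tartan import TartanAir

rgb = TartanAir.image_read("P001/image_left/000000_left.png")          # BGR uint8 image via cv2
depth = TartanAir.depth_read("P001/depth_left/000000_left_depth.npy")  # depth divided by DEPTH_SCALE
fx, fy, cx, cy = TartanAir.calib_read()                                # fixed 640x480 TartanAir pinhole intrinsics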
mini_dpvo/data_readers/tartan_test.txt ADDED
@@ -0,0 +1,32 @@
1
+ abandonedfactory/abandonedfactory/Easy/P011
2
+ abandonedfactory/abandonedfactory/Hard/P011
3
+ abandonedfactory_night/abandonedfactory_night/Easy/P013
4
+ abandonedfactory_night/abandonedfactory_night/Hard/P014
5
+ amusement/amusement/Easy/P008
6
+ amusement/amusement/Hard/P007
7
+ carwelding/carwelding/Easy/P007
8
+ endofworld/endofworld/Easy/P009
9
+ gascola/gascola/Easy/P008
10
+ gascola/gascola/Hard/P009
11
+ hospital/hospital/Easy/P036
12
+ hospital/hospital/Hard/P049
13
+ japanesealley/japanesealley/Easy/P007
14
+ japanesealley/japanesealley/Hard/P005
15
+ neighborhood/neighborhood/Easy/P021
16
+ neighborhood/neighborhood/Hard/P017
17
+ ocean/ocean/Easy/P013
18
+ ocean/ocean/Hard/P009
19
+ office2/office2/Easy/P011
20
+ office2/office2/Hard/P010
21
+ office/office/Hard/P007
22
+ oldtown/oldtown/Easy/P007
23
+ oldtown/oldtown/Hard/P008
24
+ seasidetown/seasidetown/Easy/P009
25
+ seasonsforest/seasonsforest/Easy/P011
26
+ seasonsforest/seasonsforest/Hard/P006
27
+ seasonsforest_winter/seasonsforest_winter/Easy/P009
28
+ seasonsforest_winter/seasonsforest_winter/Hard/P018
29
+ soulcity/soulcity/Easy/P012
30
+ soulcity/soulcity/Hard/P009
31
+ westerndesert/westerndesert/Easy/P013
32
+ westerndesert/westerndesert/Hard/P007
mini_dpvo/dpvo.py ADDED
@@ -0,0 +1,410 @@
1
+ import torch
2
+ import numpy as np
3
+ import torch.nn.functional as F
4
+
5
+ from . import fastba
6
+ from . import altcorr
7
+ from . import lietorch
8
+ from .lietorch import SE3
9
+
10
+ from .net import VONet
11
+
12
+ from .utils import Timer, flatmeshgrid
13
+ from . import projective_ops as pops
14
+
15
+ autocast = torch.cuda.amp.autocast
16
+ Id = SE3.Identity(1, device="cuda")
17
+
18
+
19
+ class DPVO:
20
+ def __init__(self, cfg, network, ht=480, wd=640):
21
+ self.cfg = cfg
22
+ self.load_weights(network)
23
+ self.is_initialized = False
24
+ self.enable_timing = False
25
+
26
+ self.n = 0 # number of frames
27
+ self.m = 0 # number of patches
28
+ self.M = self.cfg.PATCHES_PER_FRAME
29
+ self.N = self.cfg.BUFFER_SIZE
30
+
31
+ self.ht = ht # image height
32
+ self.wd = wd # image width
33
+
34
+ DIM = self.DIM
35
+ RES = self.RES
36
+
37
+ ### state attributes ###
38
+ self.tlist = []
39
+ self.counter = 0
40
+
41
+ # dummy image for visualization
42
+ self.image_ = torch.zeros(self.ht, self.wd, 3, dtype=torch.uint8, device="cpu")
43
+
44
+ self.tstamps_ = torch.zeros(self.N, dtype=torch.float64, device="cuda")
45
+ self.poses_ = torch.zeros(self.N, 7, dtype=torch.float32, device="cuda")
46
+ self.patches_ = torch.zeros(
47
+ self.N, self.M, 3, self.P, self.P, dtype=torch.float, device="cuda"
48
+ )
49
+ self.intrinsics_ = torch.zeros(self.N, 4, dtype=torch.float, device="cuda")
50
+
51
+ self.points_ = torch.zeros(self.N * self.M, 3, dtype=torch.float, device="cuda")
52
+ self.colors_ = torch.zeros(self.N, self.M, 3, dtype=torch.uint8, device="cuda")
53
+
54
+ self.index_ = torch.zeros(self.N, self.M, dtype=torch.long, device="cuda")
55
+ self.index_map_ = torch.zeros(self.N, dtype=torch.long, device="cuda")
56
+
57
+ ### network attributes ###
58
+ self.mem = 32
59
+
60
+ if self.cfg.MIXED_PRECISION:
61
+ self.kwargs = kwargs = {"device": "cuda", "dtype": torch.half}
62
+ else:
63
+ self.kwargs = kwargs = {"device": "cuda", "dtype": torch.float}
64
+
65
+ self.imap_ = torch.zeros(self.mem, self.M, DIM, **kwargs)
66
+ self.gmap_ = torch.zeros(self.mem, self.M, 128, self.P, self.P, **kwargs)
67
+
68
+ ht = ht // RES
69
+ wd = wd // RES
70
+
71
+ self.fmap1_ = torch.zeros(1, self.mem, 128, ht // 1, wd // 1, **kwargs)
72
+ self.fmap2_ = torch.zeros(1, self.mem, 128, ht // 4, wd // 4, **kwargs)
73
+
74
+ # feature pyramid
75
+ self.pyramid = (self.fmap1_, self.fmap2_)
76
+
77
+ self.net = torch.zeros(1, 0, DIM, **kwargs)
78
+ self.ii = torch.as_tensor([], dtype=torch.long, device="cuda")
79
+ self.jj = torch.as_tensor([], dtype=torch.long, device="cuda")
80
+ self.kk = torch.as_tensor([], dtype=torch.long, device="cuda")
81
+
82
+ # initialize poses to identity matrix
83
+ self.poses_[:, 6] = 1.0
84
+
85
+ # store relative poses for removed frames
86
+ self.delta = {}
87
+
88
+ def load_weights(self, network):
89
+ # load network from checkpoint file
90
+ if isinstance(network, str):
91
+ from collections import OrderedDict
92
+
93
+ state_dict = torch.load(network)
94
+ new_state_dict = OrderedDict()
95
+ for k, v in state_dict.items():
96
+ if "update.lmbda" not in k:
97
+ new_state_dict[k.replace("module.", "")] = v
98
+
99
+ self.network = VONet()
100
+ self.network.load_state_dict(new_state_dict)
101
+
102
+ else:
103
+ self.network = network
104
+
105
+ # steal network attributes
106
+ self.DIM = self.network.DIM
107
+ self.RES = self.network.RES
108
+ self.P = self.network.P
109
+
110
+ self.network.cuda()
111
+ self.network.eval()
112
+
113
+ # if self.cfg.MIXED_PRECISION:
114
+ # self.network.half()
115
+
116
+ @property
117
+ def poses(self):
118
+ return self.poses_.view(1, self.N, 7)
119
+
120
+ @property
121
+ def patches(self):
122
+ return self.patches_.view(1, self.N * self.M, 3, 3, 3)
123
+
124
+ @property
125
+ def intrinsics(self):
126
+ return self.intrinsics_.view(1, self.N, 4)
127
+
128
+ @property
129
+ def ix(self):
130
+ return self.index_.view(-1)
131
+
132
+ @property
133
+ def imap(self):
134
+ return self.imap_.view(1, self.mem * self.M, self.DIM)
135
+
136
+ @property
137
+ def gmap(self):
138
+ return self.gmap_.view(1, self.mem * self.M, 128, 3, 3)
139
+
140
+ def get_pose(self, t):
141
+ if t in self.traj:
142
+ return SE3(self.traj[t])
143
+
144
+ t0, dP = self.delta[t]
145
+ return dP * self.get_pose(t0)
146
+
147
+ def terminate(self):
148
+ """interpolate missing poses"""
149
+ print("Terminating...")
150
+ self.traj = {}
151
+ for i in range(self.n):
152
+ current_t: int = self.tstamps_[i].item()
153
+ self.traj[current_t] = self.poses_[i]
154
+
155
+ poses = [self.get_pose(t) for t in range(self.counter)]
156
+ poses = lietorch.stack(poses, dim=0)
157
+ poses = poses.inv().data.cpu().numpy()
158
+ tstamps = np.array(self.tlist, dtype=np.float64)
159
+
160
+ return poses, tstamps
161
+
162
+ def corr(self, coords, indicies=None):
163
+ """local correlation volume"""
164
+ ii, jj = indicies if indicies is not None else (self.kk, self.jj)
165
+ ii1 = ii % (self.M * self.mem)
166
+ jj1 = jj % (self.mem)
167
+ corr1 = altcorr.corr(self.gmap, self.pyramid[0], coords / 1, ii1, jj1, 3)
168
+ corr2 = altcorr.corr(self.gmap, self.pyramid[1], coords / 4, ii1, jj1, 3)
169
+ return torch.stack([corr1, corr2], -1).view(1, len(ii), -1)
170
+
171
+ def reproject(self, indicies=None):
172
+ """reproject patch k from i -> j"""
173
+ (ii, jj, kk) = indicies if indicies is not None else (self.ii, self.jj, self.kk)
174
+ coords = pops.transform(
175
+ SE3(self.poses), self.patches, self.intrinsics, ii, jj, kk
176
+ )
177
+ return coords.permute(0, 1, 4, 2, 3).contiguous()
178
+
179
+ def append_factors(self, ii, jj):
180
+ self.jj = torch.cat([self.jj, jj])
181
+ self.kk = torch.cat([self.kk, ii])
182
+ self.ii = torch.cat([self.ii, self.ix[ii]])
183
+
184
+ net = torch.zeros(1, len(ii), self.DIM, **self.kwargs)
185
+ self.net = torch.cat([self.net, net], dim=1)
186
+
187
+ def remove_factors(self, m):
188
+ self.ii = self.ii[~m]
189
+ self.jj = self.jj[~m]
190
+ self.kk = self.kk[~m]
191
+ self.net = self.net[:, ~m]
192
+
193
+ def motion_probe(self):
194
+ """kinda hacky way to ensure enough motion for initialization"""
195
+ kk = torch.arange(self.m - self.M, self.m, device="cuda")
196
+ jj = self.n * torch.ones_like(kk)
197
+ ii = self.ix[kk]
198
+
199
+ net = torch.zeros(1, len(ii), self.DIM, **self.kwargs)
200
+ coords = self.reproject(indicies=(ii, jj, kk))
201
+
202
+ with autocast(enabled=self.cfg.MIXED_PRECISION):
203
+ corr = self.corr(coords, indicies=(kk, jj))
204
+ ctx = self.imap[:, kk % (self.M * self.mem)]
205
+ net, (delta, weight, _) = self.network.update(
206
+ net, ctx, corr, None, ii, jj, kk
207
+ )
208
+
209
+ return torch.quantile(delta.norm(dim=-1).float(), 0.5)
210
+
211
+ def motionmag(self, i, j):
212
+ k = (self.ii == i) & (self.jj == j)
213
+ ii = self.ii[k]
214
+ jj = self.jj[k]
215
+ kk = self.kk[k]
216
+
217
+ flow = pops.flow_mag(
218
+ SE3(self.poses), self.patches, self.intrinsics, ii, jj, kk, beta=0.5
219
+ )
220
+ return flow.mean().item()
221
+
222
+ def keyframe(self):
223
+ i = self.n - self.cfg.KEYFRAME_INDEX - 1
224
+ j = self.n - self.cfg.KEYFRAME_INDEX + 1
225
+ m = self.motionmag(i, j) + self.motionmag(j, i)
226
+
227
+ if m / 2 < self.cfg.KEYFRAME_THRESH:
228
+ k = self.n - self.cfg.KEYFRAME_INDEX
229
+ t0 = self.tstamps_[k - 1].item()
230
+ t1 = self.tstamps_[k].item()
231
+
232
+ dP = SE3(self.poses_[k]) * SE3(self.poses_[k - 1]).inv()
233
+ self.delta[t1] = (t0, dP)
234
+
235
+ to_remove = (self.ii == k) | (self.jj == k)
236
+ self.remove_factors(to_remove)
237
+
238
+ self.kk[self.ii > k] -= self.M
239
+ self.ii[self.ii > k] -= 1
240
+ self.jj[self.jj > k] -= 1
241
+
242
+ for i in range(k, self.n - 1):
243
+ self.tstamps_[i] = self.tstamps_[i + 1]
244
+ self.colors_[i] = self.colors_[i + 1]
245
+ self.poses_[i] = self.poses_[i + 1]
246
+ self.patches_[i] = self.patches_[i + 1]
247
+ self.intrinsics_[i] = self.intrinsics_[i + 1]
248
+
249
+ self.imap_[i % self.mem] = self.imap_[(i + 1) % self.mem]
250
+ self.gmap_[i % self.mem] = self.gmap_[(i + 1) % self.mem]
251
+ self.fmap1_[0, i % self.mem] = self.fmap1_[0, (i + 1) % self.mem]
252
+ self.fmap2_[0, i % self.mem] = self.fmap2_[0, (i + 1) % self.mem]
253
+
254
+ self.n -= 1
255
+ self.m -= self.M
256
+
257
+ to_remove = self.ix[self.kk] < self.n - self.cfg.REMOVAL_WINDOW
258
+ self.remove_factors(to_remove)
259
+
260
+ def update(self):
261
+ with Timer("other", enabled=self.enable_timing):
262
+ coords = self.reproject()
263
+
264
+ with autocast(enabled=True):
265
+ corr = self.corr(coords)
266
+ ctx = self.imap[:, self.kk % (self.M * self.mem)]
267
+ self.net, (delta, weight, _) = self.network.update(
268
+ self.net, ctx, corr, None, self.ii, self.jj, self.kk
269
+ )
270
+
271
+ lmbda = torch.as_tensor([1e-4], device="cuda")
272
+ weight = weight.float()
273
+ target = coords[..., self.P // 2, self.P // 2] + delta.float()
274
+
275
+ with Timer("BA", enabled=self.enable_timing):
276
+ t0 = self.n - self.cfg.OPTIMIZATION_WINDOW if self.is_initialized else 1
277
+ t0 = max(t0, 1)
278
+
279
+ try:
280
+ fastba.BA(
281
+ self.poses,
282
+ self.patches,
283
+ self.intrinsics,
284
+ target,
285
+ weight,
286
+ lmbda,
287
+ self.ii,
288
+ self.jj,
289
+ self.kk,
290
+ t0,
291
+ self.n,
292
+ 2,
293
+ )
294
+ except Exception:
295
+ print("Warning: BA failed...")
296
+
297
+ points = pops.point_cloud(
298
+ SE3(self.poses),
299
+ self.patches[:, : self.m],
300
+ self.intrinsics,
301
+ self.ix[: self.m],
302
+ )
303
+ points = (points[..., 1, 1, :3] / points[..., 1, 1, 3:]).reshape(-1, 3)
304
+ self.points_[: len(points)] = points[:]
305
+
306
+ def __edges_all(self):
307
+ return flatmeshgrid(
308
+ torch.arange(0, self.m, device="cuda"),
309
+ torch.arange(0, self.n, device="cuda"),
310
+ indexing="ij",
311
+ )
312
+
313
+ def __edges_forw(self):
314
+ r = self.cfg.PATCH_LIFETIME
315
+ t0 = self.M * max((self.n - r), 0)
316
+ t1 = self.M * max((self.n - 1), 0)
317
+ return flatmeshgrid(
318
+ torch.arange(t0, t1, device="cuda"),
319
+ torch.arange(self.n - 1, self.n, device="cuda"),
320
+ indexing="ij",
321
+ )
322
+
323
+ def __edges_back(self):
324
+ r = self.cfg.PATCH_LIFETIME
325
+ t0 = self.M * max((self.n - 1), 0)
326
+ t1 = self.M * max((self.n - 0), 0)
327
+ return flatmeshgrid(
328
+ torch.arange(t0, t1, device="cuda"),
329
+ torch.arange(max(self.n - r, 0), self.n, device="cuda"),
330
+ indexing="ij",
331
+ )
332
+
333
+ def __call__(self, tstamp: int, image, intrinsics) -> None:
334
+ """track new frame"""
335
+
336
+ if (self.n + 1) >= self.N:
337
+ raise Exception(
338
+ f'The buffer size is too small. You can increase it using "--buffer {self.N*2}"'
339
+ )
340
+
341
+ image = 2 * (image[None, None] / 255.0) - 0.5
342
+
343
+ with autocast(enabled=self.cfg.MIXED_PRECISION):
344
+ fmap, gmap, imap, patches, _, clr = self.network.patchify(
345
+ image,
346
+ patches_per_image=self.cfg.PATCHES_PER_FRAME,
347
+ gradient_bias=self.cfg.GRADIENT_BIAS,
348
+ return_color=True,
349
+ )
350
+
351
+ ### update state attributes ###
352
+ self.tlist.append(tstamp)
353
+ self.tstamps_[self.n] = self.counter
354
+ self.intrinsics_[self.n] = intrinsics / self.RES
355
+
356
+ # color info for visualization
357
+ clr = (clr[0, :, [2, 1, 0]] + 0.5) * (255.0 / 2)
358
+ self.colors_[self.n] = clr.to(torch.uint8)
359
+
360
+ self.index_[self.n + 1] = self.n + 1
361
+ self.index_map_[self.n + 1] = self.m + self.M
362
+
363
+ if self.n > 1:
364
+ if self.cfg.MOTION_MODEL == "DAMPED_LINEAR":
365
+ P1 = SE3(self.poses_[self.n - 1])
366
+ P2 = SE3(self.poses_[self.n - 2])
367
+
368
+ xi = self.cfg.MOTION_DAMPING * (P1 * P2.inv()).log()
369
+ tvec_qvec = (SE3.exp(xi) * P1).data
370
+ self.poses_[self.n] = tvec_qvec
371
+ else:
372
+ tvec_qvec = self.poses[self.n - 1]
373
+ self.poses_[self.n] = tvec_qvec
374
+
375
+ # TODO better depth initialization
376
+ patches[:, :, 2] = torch.rand_like(patches[:, :, 2, 0, 0, None, None])
377
+ if self.is_initialized:
378
+ s = torch.median(self.patches_[self.n - 3 : self.n, :, 2])
379
+ patches[:, :, 2] = s
380
+
381
+ self.patches_[self.n] = patches
382
+
383
+ ### update network attributes ###
384
+ self.imap_[self.n % self.mem] = imap.squeeze()
385
+ self.gmap_[self.n % self.mem] = gmap.squeeze()
386
+ self.fmap1_[:, self.n % self.mem] = F.avg_pool2d(fmap[0], 1, 1)
387
+ self.fmap2_[:, self.n % self.mem] = F.avg_pool2d(fmap[0], 4, 4)
388
+
389
+ self.counter += 1
390
+ if self.n > 0 and not self.is_initialized:
391
+ if self.motion_probe() < 2.0:
392
+ self.delta[self.counter - 1] = (self.counter - 2, Id[0])
393
+ return
394
+
395
+ self.n += 1
396
+ self.m += self.M
397
+
398
+ # relative pose
399
+ self.append_factors(*self.__edges_forw())
400
+ self.append_factors(*self.__edges_back())
401
+
402
+ if self.n == 8 and not self.is_initialized:
403
+ self.is_initialized = True
404
+
405
+ for itr in range(12):
406
+ self.update()
407
+
408
+ elif self.is_initialized:
409
+ self.update()
410
+ self.keyframe()
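To make the calling convention concrete, here is a minimal driver sketch; the config import, checkpoint path, and `stream` iterable are assumptions, not part of this file:

# Hypothetical driver loop for the DPVO class above.
# Assumptions: `cfg` is the project's yacs config, "dpvo.pth" a checkpoint path,
# and `stream` yields (HxWx3 uint8 image, (fx, fy, cx, cy)) pairs.
import torch
from mini_dpvo.config import cfg
from mini_dpvo.dpvo import DPVO

slam = None
for t, (image, intrinsics) in enumerate(stream):
    image = torch.from_numpy(image).permute(2, 0, 1).cuda()          # (3, H, W)
    intrinsics = torch.as_tensor(intrinsics, dtype=torch.float, device="cuda")
    if slam is None:
        slam = DPVO(cfg, "dpvo.pth", ht=image.shape[1], wd=image.shape[2])
    with torch.no_grad():
        slam(t, image, intrinsics)

poses, tstamps = slam.terminate()   # (N, 7) pose vectors (tx,ty,tz,qx,qy,qz,qw) and timestamps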
mini_dpvo/extractor.py ADDED
@@ -0,0 +1,264 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class ResidualBlock(nn.Module):
7
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8
+ super(ResidualBlock, self).__init__()
9
+
10
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12
+ self.relu = nn.ReLU(inplace=True)
13
+
14
+ num_groups = planes // 8
15
+
16
+ if norm_fn == 'group':
17
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19
+ if not stride == 1:
20
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21
+
22
+ elif norm_fn == 'batch':
23
+ self.norm1 = nn.BatchNorm2d(planes)
24
+ self.norm2 = nn.BatchNorm2d(planes)
25
+ if not stride == 1:
26
+ self.norm3 = nn.BatchNorm2d(planes)
27
+
28
+ elif norm_fn == 'instance':
29
+ self.norm1 = nn.InstanceNorm2d(planes)
30
+ self.norm2 = nn.InstanceNorm2d(planes)
31
+ if not stride == 1:
32
+ self.norm3 = nn.InstanceNorm2d(planes)
33
+
34
+ elif norm_fn == 'none':
35
+ self.norm1 = nn.Sequential()
36
+ self.norm2 = nn.Sequential()
37
+ if not stride == 1:
38
+ self.norm3 = nn.Sequential()
39
+
40
+ if stride == 1:
41
+ self.downsample = None
42
+
43
+ else:
44
+ self.downsample = nn.Sequential(
45
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46
+
47
+ def forward(self, x):
48
+ y = x
49
+ y = self.relu(self.norm1(self.conv1(y)))
50
+ y = self.relu(self.norm2(self.conv2(y)))
51
+
52
+ if self.downsample is not None:
53
+ x = self.downsample(x)
54
+
55
+ return self.relu(x+y)
56
+
57
+
58
+ class BottleneckBlock(nn.Module):
59
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
60
+ super(BottleneckBlock, self).__init__()
61
+
62
+ self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
63
+ self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
64
+ self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
65
+ self.relu = nn.ReLU(inplace=True)
66
+
67
+ num_groups = planes // 8
68
+
69
+ if norm_fn == 'group':
70
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
71
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
72
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
73
+ if not stride == 1:
74
+ self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
75
+
76
+ elif norm_fn == 'batch':
77
+ self.norm1 = nn.BatchNorm2d(planes//4)
78
+ self.norm2 = nn.BatchNorm2d(planes//4)
79
+ self.norm3 = nn.BatchNorm2d(planes)
80
+ if not stride == 1:
81
+ self.norm4 = nn.BatchNorm2d(planes)
82
+
83
+ elif norm_fn == 'instance':
84
+ self.norm1 = nn.InstanceNorm2d(planes//4)
85
+ self.norm2 = nn.InstanceNorm2d(planes//4)
86
+ self.norm3 = nn.InstanceNorm2d(planes)
87
+ if not stride == 1:
88
+ self.norm4 = nn.InstanceNorm2d(planes)
89
+
90
+ elif norm_fn == 'none':
91
+ self.norm1 = nn.Sequential()
92
+ self.norm2 = nn.Sequential()
93
+ self.norm3 = nn.Sequential()
94
+ if not stride == 1:
95
+ self.norm4 = nn.Sequential()
96
+
97
+ if stride == 1:
98
+ self.downsample = None
99
+
100
+ else:
101
+ self.downsample = nn.Sequential(
102
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
103
+
104
+ def forward(self, x):
105
+ y = x
106
+ y = self.relu(self.norm1(self.conv1(y)))
107
+ y = self.relu(self.norm2(self.conv2(y)))
108
+ y = self.relu(self.norm3(self.conv3(y)))
109
+
110
+ if self.downsample is not None:
111
+ x = self.downsample(x)
112
+
113
+ return self.relu(x+y)
114
+
115
+ DIM=32
116
+
117
+ class BasicEncoder(nn.Module):
118
+ def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, multidim=False):
119
+ super(BasicEncoder, self).__init__()
120
+ self.norm_fn = norm_fn
121
+ self.multidim = multidim
122
+
123
+ if self.norm_fn == 'group':
124
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=DIM)
125
+
126
+ elif self.norm_fn == 'batch':
127
+ self.norm1 = nn.BatchNorm2d(DIM)
128
+
129
+ elif self.norm_fn == 'instance':
130
+ self.norm1 = nn.InstanceNorm2d(DIM)
131
+
132
+ elif self.norm_fn == 'none':
133
+ self.norm1 = nn.Sequential()
134
+
135
+ self.conv1 = nn.Conv2d(3, DIM, kernel_size=7, stride=2, padding=3)
136
+ self.relu1 = nn.ReLU(inplace=True)
137
+
138
+ self.in_planes = DIM
139
+ self.layer1 = self._make_layer(DIM, stride=1)
140
+ self.layer2 = self._make_layer(2*DIM, stride=2)
141
+ self.layer3 = self._make_layer(4*DIM, stride=2)
142
+
143
+ # output convolution
144
+ self.conv2 = nn.Conv2d(4*DIM, output_dim, kernel_size=1)
145
+
146
+ if self.multidim:
147
+ self.layer4 = self._make_layer(256, stride=2)
148
+ self.layer5 = self._make_layer(512, stride=2)
149
+
150
+ self.in_planes = 256
151
+ self.layer6 = self._make_layer(256, stride=1)
152
+
153
+ self.in_planes = 128
154
+ self.layer7 = self._make_layer(128, stride=1)
155
+
156
+ self.up1 = nn.Conv2d(512, 256, 1)
157
+ self.up2 = nn.Conv2d(256, 128, 1)
158
+ self.conv3 = nn.Conv2d(128, output_dim, kernel_size=1)
159
+
160
+ if dropout > 0:
161
+ self.dropout = nn.Dropout2d(p=dropout)
162
+ else:
163
+ self.dropout = None
164
+
165
+ for m in self.modules():
166
+ if isinstance(m, nn.Conv2d):
167
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
168
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
169
+ if m.weight is not None:
170
+ nn.init.constant_(m.weight, 1)
171
+ if m.bias is not None:
172
+ nn.init.constant_(m.bias, 0)
173
+
174
+ def _make_layer(self, dim, stride=1):
175
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
176
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
177
+ layers = (layer1, layer2)
178
+
179
+ self.in_planes = dim
180
+ return nn.Sequential(*layers)
181
+
182
+ def forward(self, x):
183
+ b, n, c1, h1, w1 = x.shape
184
+ x = x.view(b*n, c1, h1, w1)
185
+
186
+ x = self.conv1(x)
187
+ x = self.norm1(x)
188
+ x = self.relu1(x)
189
+
190
+ x = self.layer1(x)
191
+ x = self.layer2(x)
192
+ x = self.layer3(x)
193
+
194
+ x = self.conv2(x)
195
+
196
+ _, c2, h2, w2 = x.shape
197
+ return x.view(b, n, c2, h2, w2)
198
+
199
+
200
+ class BasicEncoder4(nn.Module):
201
+ def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0, multidim=False):
202
+ super(BasicEncoder4, self).__init__()
203
+ self.norm_fn = norm_fn
204
+ self.multidim = multidim
205
+
206
+ if self.norm_fn == 'group':
207
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=DIM)
208
+
209
+ elif self.norm_fn == 'batch':
210
+ self.norm1 = nn.BatchNorm2d(DIM)
211
+
212
+ elif self.norm_fn == 'instance':
213
+ self.norm1 = nn.InstanceNorm2d(DIM)
214
+
215
+ elif self.norm_fn == 'none':
216
+ self.norm1 = nn.Sequential()
217
+
218
+ self.conv1 = nn.Conv2d(3, DIM, kernel_size=7, stride=2, padding=3)
219
+ self.relu1 = nn.ReLU(inplace=True)
220
+
221
+ self.in_planes = DIM
222
+ self.layer1 = self._make_layer(DIM, stride=1)
223
+ self.layer2 = self._make_layer(2*DIM, stride=2)
224
+
225
+ # output convolution
226
+ self.conv2 = nn.Conv2d(2*DIM, output_dim, kernel_size=1)
227
+
228
+ if dropout > 0:
229
+ self.dropout = nn.Dropout2d(p=dropout)
230
+ else:
231
+ self.dropout = None
232
+
233
+ for m in self.modules():
234
+ if isinstance(m, nn.Conv2d):
235
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
236
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
237
+ if m.weight is not None:
238
+ nn.init.constant_(m.weight, 1)
239
+ if m.bias is not None:
240
+ nn.init.constant_(m.bias, 0)
241
+
242
+ def _make_layer(self, dim, stride=1):
243
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
244
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
245
+ layers = (layer1, layer2)
246
+
247
+ self.in_planes = dim
248
+ return nn.Sequential(*layers)
249
+
250
+ def forward(self, x):
251
+ b, n, c1, h1, w1 = x.shape
252
+ x = x.view(b*n, c1, h1, w1)
253
+
254
+ x = self.conv1(x)
255
+ x = self.norm1(x)
256
+ x = self.relu1(x)
257
+
258
+ x = self.layer1(x)
259
+ x = self.layer2(x)
260
+
261
+ x = self.conv2(x)
262
+
263
+ _, c2, h2, w2 = x.shape
264
+ return x.view(b, n, c2, h2, w2)
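For orientation, a quick shape sketch of the two encoders; the input resolution is arbitrary and chosen only to illustrate the stride-8 vs. stride-4 outputs:

# Hypothetical shape check for the encoders above (input size is arbitrary).
import torch

x = torch.randn(1, 2, 3, 480, 640)            # (batch, frames, RGB, H, W)
f8 = BasicEncoder(output_dim=128)(x)          # stride-8 features: (1, 2, 128, 60, 80)
f4 = BasicEncoder4(output_dim=128)(x)         # stride-4 features: (1, 2, 128, 120, 160)
print(f8.shape, f4.shape)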
mini_dpvo/fastba/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .ba import BA, neighbors, reproject
mini_dpvo/fastba/ba.cpp ADDED
@@ -0,0 +1,157 @@
1
+ #include <torch/extension.h>
2
+ #include <vector>
3
+ #include <unordered_map>
4
+ #include <algorithm>
5
+ #include <iostream>
6
+
7
+
8
+ std::vector<torch::Tensor> cuda_ba(
9
+ torch::Tensor poses,
10
+ torch::Tensor patches,
11
+ torch::Tensor intrinsics,
12
+ torch::Tensor target,
13
+ torch::Tensor weight,
14
+ torch::Tensor lmbda,
15
+ torch::Tensor ii,
16
+ torch::Tensor jj,
17
+ torch::Tensor kk,
18
+ int t0, int t1, int iterations);
19
+
20
+
21
+ torch::Tensor cuda_reproject(
22
+ torch::Tensor poses,
23
+ torch::Tensor patches,
24
+ torch::Tensor intrinsics,
25
+ torch::Tensor ii,
26
+ torch::Tensor jj,
27
+ torch::Tensor kk);
28
+
29
+ std::vector<torch::Tensor> ba(
30
+ torch::Tensor poses,
31
+ torch::Tensor patches,
32
+ torch::Tensor intrinsics,
33
+ torch::Tensor target,
34
+ torch::Tensor weight,
35
+ torch::Tensor lmbda,
36
+ torch::Tensor ii,
37
+ torch::Tensor jj,
38
+ torch::Tensor kk,
39
+ int t0, int t1, int iterations) {
40
+ return cuda_ba(poses, patches, intrinsics, target, weight, lmbda, ii, jj, kk, t0, t1, iterations);
41
+ }
42
+
43
+
44
+ torch::Tensor reproject(
45
+ torch::Tensor poses,
46
+ torch::Tensor patches,
47
+ torch::Tensor intrinsics,
48
+ torch::Tensor ii,
49
+ torch::Tensor jj,
50
+ torch::Tensor kk) {
51
+ return cuda_reproject(poses, patches, intrinsics, ii, jj, kk);
52
+ }
53
+
54
+ // std::vector<torch::Tensor> neighbors(torch::Tensor ii, torch::Tensor jj)
55
+ // {
56
+ // ii = ii.to(torch::kCPU);
57
+ // jj = jj.to(torch::kCPU);
58
+ // auto ii_data = ii.accessor<long,1>();
59
+ // auto jj_data = jj.accessor<long,1>();
60
+
61
+ // std::unordered_map<long, std::vector<long>> graph;
62
+ // std::unordered_map<long, std::vector<long>> index;
63
+ // for (int i=0; i < ii.size(0); i++) {
64
+ // const long ix = ii_data[i];
65
+ // const long jx = jj_data[i];
66
+ // if (graph.find(ix) == graph.end()) {
67
+ // graph[ix] = std::vector<long>();
68
+ // index[ix] = std::vector<long>();
69
+ // }
70
+ // graph[ix].push_back(jx);
71
+ // index[ix].push_back( i);
72
+ // }
73
+
74
+ // auto opts = torch::TensorOptions().dtype(torch::kInt64);
75
+ // torch::Tensor ix = torch::empty({ii.size(0)}, opts);
76
+ // torch::Tensor jx = torch::empty({jj.size(0)}, opts);
77
+
78
+ // auto ix_data = ix.accessor<long,1>();
79
+ // auto jx_data = jx.accessor<long,1>();
80
+
81
+ // for (std::pair<long, std::vector<long>> element : graph) {
82
+ // std::vector<long>& v = graph[element.first];
83
+ // std::vector<long>& idx = index[element.first];
84
+
85
+ // std::stable_sort(idx.begin(), idx.end(),
86
+ // [&v](size_t i, size_t j) {return v[i] < v[j];});
87
+
88
+ // ix_data[idx.front()] = -1;
89
+ // jx_data[idx.back()] = -1;
90
+
91
+ // for (int i=0; i < idx.size(); i++) {
92
+ // ix_data[idx[i]] = (i > 0) ? idx[i-1] : -1;
93
+ // jx_data[idx[i]] = (i < idx.size() - 1) ? idx[i+1] : -1;
94
+ // }
95
+ // }
96
+
97
+ // ix = ix.to(torch::kCUDA);
98
+ // jx = jx.to(torch::kCUDA);
99
+
100
+ // return {ix, jx};
101
+ // }
102
+
103
+
104
+ std::vector<torch::Tensor> neighbors(torch::Tensor ii, torch::Tensor jj)
105
+ {
106
+
107
+ auto tup = torch::_unique(ii, true, true);
108
+ torch::Tensor uniq = std::get<0>(tup).to(torch::kCPU);
109
+ torch::Tensor perm = std::get<1>(tup).to(torch::kCPU);
110
+
111
+ jj = jj.to(torch::kCPU);
112
+ auto jj_accessor = jj.accessor<long,1>();
113
+
114
+ auto perm_accessor = perm.accessor<long,1>();
115
+ std::vector<std::vector<long>> index(uniq.size(0));
116
+ for (int i=0; i < ii.size(0); i++) {
117
+ index[perm_accessor[i]].push_back(i);
118
+ }
119
+
120
+ auto opts = torch::TensorOptions().dtype(torch::kInt64);
121
+ torch::Tensor ix = torch::empty({ii.size(0)}, opts);
122
+ torch::Tensor jx = torch::empty({ii.size(0)}, opts);
123
+
124
+ auto ix_accessor = ix.accessor<long,1>();
125
+ auto jx_accessor = jx.accessor<long,1>();
126
+
127
+ for (int i=0; i<uniq.size(0); i++) {
128
+ std::vector<long>& idx = index[i];
129
+ std::stable_sort(idx.begin(), idx.end(),
130
+ [&jj_accessor](size_t i, size_t j) {return jj_accessor[i] < jj_accessor[j];});
131
+
132
+ for (int i=0; i < idx.size(); i++) {
133
+ ix_accessor[idx[i]] = (i > 0) ? idx[i-1] : -1;
134
+ jx_accessor[idx[i]] = (i < idx.size() - 1) ? idx[i+1] : -1;
135
+ }
136
+ }
137
+
138
+ // for (int i=0; i<ii.size(0); i++) {
139
+ // std::cout << jj_accessor[i] << " ";
140
+ // if (ix_accessor[i] >= 0) std::cout << jj_accessor[ix_accessor[i]] << " ";
141
+ // if (jx_accessor[i] >= 0) std::cout << jj_accessor[jx_accessor[i]] << " ";
142
+ // std::cout << std::endl;
143
+ // }
144
+
145
+ ix = ix.to(torch::kCUDA);
146
+ jx = jx.to(torch::kCUDA);
147
+
148
+ return {ix, jx};
149
+ }
150
+
151
+
152
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
153
+ m.def("forward", &ba, "BA forward operator");
154
+ m.def("neighbors", &neighbors, "temporal neighboor indicies");
155
+ m.def("reproject", &reproject, "temporal neighboor indicies");
156
+
157
+ }
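The bindings above are compiled into a module named `cuda_ba` (as imported by `fastba/ba.py` below). The project's actual build is driven elsewhere, but a representative extension entry might look roughly like this; the source paths and packaging details are assumptions:

# Hypothetical build sketch for the cuda_ba extension (paths/packaging are assumptions).
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name="cuda_ba",
    ext_modules=[
        CUDAExtension(
            name="cuda_ba",
            sources=["mini_dpvo/fastba/ba.cpp", "mini_dpvo/fastba/ba_cuda.cu"],
        ),
    ],
    cmdclass={"build_ext": BuildExtension},
)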
mini_dpvo/fastba/ba.py ADDED
@@ -0,0 +1,8 @@
1
+ import torch
2
+ import cuda_ba
3
+
4
+ neighbors = cuda_ba.neighbors
5
+ reproject = cuda_ba.reproject
6
+
7
+ def BA(poses, patches, intrinsics, target, weight, lmbda, ii, jj, kk, t0, t1, iterations=2):
8
+ return cuda_ba.forward(poses.data, patches, intrinsics, target, weight, lmbda, ii, jj, kk, t0, t1, iterations)
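For reference, the tensor shapes the `BA` wrapper expects, as inferred from how `DPVO.update()` calls it; this is a descriptive sketch, not extension documentation:

# Shape sketch for fastba.BA (inferred from DPVO.update(); illustrative only).
# poses      : (1, N, 7)          SE3 data per frame, (tx, ty, tz, qx, qy, qz, qw)
# patches    : (1, N*M, 3, P, P)  per-patch (x, y, inverse depth) maps
# intrinsics : (1, N, 4)          (fx, fy, cx, cy)
# target     : (1, E, 2)          predicted reprojection targets per edge
# weight     : (1, E, 2)          per-edge confidence weights
# lmbda      : (1,)               damping term
# ii, jj, kk : (E,)               source frame, target frame, and patch index per edge
# t0, t1     : first/last frame of the optimization window; `iterations` solver iterations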
mini_dpvo/fastba/ba_cuda.cu ADDED
@@ -0,0 +1,575 @@
1
+ #include <torch/extension.h>
2
+ #include <vector>
3
+ #include <iostream>
4
+
5
+ #include <ATen/ATen.h>
6
+ #include <ATen/NativeFunctions.h>
7
+ #include <ATen/Parallel.h>
8
+
9
+
10
+ #define GPU_1D_KERNEL_LOOP(i, n) \
11
+ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i<n; i += blockDim.x * gridDim.x)
12
+
13
+
14
+ #define NUM_THREADS 256
15
+ #define NUM_BLOCKS(batch_size) ((batch_size + NUM_THREADS - 1) / NUM_THREADS)
16
+
17
+
18
+ __device__ void
19
+ actSO3(const float *q, const float *X, float *Y) {
20
+ float uv[3];
21
+ uv[0] = 2.0 * (q[1]*X[2] - q[2]*X[1]);
22
+ uv[1] = 2.0 * (q[2]*X[0] - q[0]*X[2]);
23
+ uv[2] = 2.0 * (q[0]*X[1] - q[1]*X[0]);
24
+
25
+ Y[0] = X[0] + q[3]*uv[0] + (q[1]*uv[2] - q[2]*uv[1]);
26
+ Y[1] = X[1] + q[3]*uv[1] + (q[2]*uv[0] - q[0]*uv[2]);
27
+ Y[2] = X[2] + q[3]*uv[2] + (q[0]*uv[1] - q[1]*uv[0]);
28
+ }
29
+
30
+ __device__ void
31
+ actSE3(const float *t, const float *q, const float *X, float *Y) {
32
+ actSO3(q, X, Y);
33
+ Y[3] = X[3];
34
+ Y[0] += X[3] * t[0];
35
+ Y[1] += X[3] * t[1];
36
+ Y[2] += X[3] * t[2];
37
+ }
38
+
39
+ __device__ void
40
+ adjSE3(const float *t, const float *q, const float *X, float *Y) {
41
+ float qinv[4] = {-q[0], -q[1], -q[2], q[3]};
42
+ actSO3(qinv, &X[0], &Y[0]);
43
+ actSO3(qinv, &X[3], &Y[3]);
44
+
45
+ float u[3], v[3];
46
+ u[0] = t[2]*X[1] - t[1]*X[2];
47
+ u[1] = t[0]*X[2] - t[2]*X[0];
48
+ u[2] = t[1]*X[0] - t[0]*X[1];
49
+
50
+ actSO3(qinv, u, v);
51
+ Y[3] += v[0];
52
+ Y[4] += v[1];
53
+ Y[5] += v[2];
54
+ }
55
+
56
+ __device__ void
57
+ relSE3(const float *ti, const float *qi, const float *tj, const float *qj, float *tij, float *qij) {
58
+ qij[0] = -qj[3] * qi[0] + qj[0] * qi[3] - qj[1] * qi[2] + qj[2] * qi[1],
59
+ qij[1] = -qj[3] * qi[1] + qj[1] * qi[3] - qj[2] * qi[0] + qj[0] * qi[2],
60
+ qij[2] = -qj[3] * qi[2] + qj[2] * qi[3] - qj[0] * qi[1] + qj[1] * qi[0],
61
+ qij[3] = qj[3] * qi[3] + qj[0] * qi[0] + qj[1] * qi[1] + qj[2] * qi[2],
62
+
63
+ actSO3(qij, ti, tij);
64
+ tij[0] = tj[0] - tij[0];
65
+ tij[1] = tj[1] - tij[1];
66
+ tij[2] = tj[2] - tij[2];
67
+ }
68
+
69
+
70
+ __device__ void
71
+ expSO3(const float *phi, float* q) {
72
+ // SO3 exponential map
73
+ float theta_sq = phi[0]*phi[0] + phi[1]*phi[1] + phi[2]*phi[2];
74
+ float theta_p4 = theta_sq * theta_sq;
75
+
76
+ float theta = sqrtf(theta_sq);
77
+ float imag, real;
78
+
79
+ if (theta_sq < 1e-8) {
80
+ imag = 0.5 - (1.0/48.0)*theta_sq + (1.0/3840.0)*theta_p4;
81
+ real = 1.0 - (1.0/ 8.0)*theta_sq + (1.0/ 384.0)*theta_p4;
82
+ } else {
83
+ imag = sinf(0.5 * theta) / theta;
84
+ real = cosf(0.5 * theta);
85
+ }
86
+
87
+ q[0] = imag * phi[0];
88
+ q[1] = imag * phi[1];
89
+ q[2] = imag * phi[2];
90
+ q[3] = real;
91
+
92
+ }
93
+
94
+ __device__ void
95
+ crossInplace(const float* a, float *b) {
96
+ float x[3] = {
97
+ a[1]*b[2] - a[2]*b[1],
98
+ a[2]*b[0] - a[0]*b[2],
99
+ a[0]*b[1] - a[1]*b[0],
100
+ };
101
+
102
+ b[0] = x[0];
103
+ b[1] = x[1];
104
+ b[2] = x[2];
105
+ }
106
+
107
+ __device__ void
108
+ expSE3(const float *xi, float* t, float* q) {
109
+ // SE3 exponential map
110
+
111
+ expSO3(xi + 3, q);
112
+ float tau[3] = {xi[0], xi[1], xi[2]};
113
+ float phi[3] = {xi[3], xi[4], xi[5]};
114
+
115
+ float theta_sq = phi[0]*phi[0] + phi[1]*phi[1] + phi[2]*phi[2];
116
+ float theta = sqrtf(theta_sq);
117
+
118
+ t[0] = tau[0];
119
+ t[1] = tau[1];
120
+ t[2] = tau[2];
121
+
122
+ if (theta > 1e-4) {
123
+ float a = (1 - cosf(theta)) / theta_sq;
124
+ crossInplace(phi, tau);
125
+ t[0] += a * tau[0];
126
+ t[1] += a * tau[1];
127
+ t[2] += a * tau[2];
128
+
129
+ float b = (theta - sinf(theta)) / (theta * theta_sq);
130
+ crossInplace(phi, tau);
131
+ t[0] += b * tau[0];
132
+ t[1] += b * tau[1];
133
+ t[2] += b * tau[2];
134
+ }
135
+ }
136
+
137
+
138
+ __device__ void
139
+ retrSE3(const float *xi, const float* t, const float* q, float* t1, float* q1) {
140
+ // retraction on SE3 manifold
141
+
142
+ float dt[3] = {0, 0, 0};
143
+ float dq[4] = {0, 0, 0, 1};
144
+
145
+ expSE3(xi, dt, dq);
146
+
147
+ q1[0] = dq[3] * q[0] + dq[0] * q[3] + dq[1] * q[2] - dq[2] * q[1];
148
+ q1[1] = dq[3] * q[1] + dq[1] * q[3] + dq[2] * q[0] - dq[0] * q[2];
149
+ q1[2] = dq[3] * q[2] + dq[2] * q[3] + dq[0] * q[1] - dq[1] * q[0];
150
+ q1[3] = dq[3] * q[3] - dq[0] * q[0] - dq[1] * q[1] - dq[2] * q[2];
151
+
152
+ actSO3(dq, t, t1);
153
+ t1[0] += dt[0];
154
+ t1[1] += dt[1];
155
+ t1[2] += dt[2];
156
+ }
157
+
158
+
159
+
160
+ __global__ void pose_retr_kernel(const int t0, const int t1,
161
+ torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> poses,
162
+ torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> update)
163
+ {
164
+ GPU_1D_KERNEL_LOOP(i, t1 - t0) {
165
+ const float t = t0 + i;
166
+ float t1[3], t0[3] = { poses[t][0], poses[t][1], poses[t][2] };
167
+ float q1[4], q0[4] = { poses[t][3], poses[t][4], poses[t][5], poses[t][6] };
168
+
169
+ float xi[6] = {
170
+ update[i][0],
171
+ update[i][1],
172
+ update[i][2],
173
+ update[i][3],
174
+ update[i][4],
175
+ update[i][5],
176
+ };
177
+
178
+ retrSE3(xi, t0, q0, t1, q1);
179
+
180
+ poses[t][0] = t1[0];
181
+ poses[t][1] = t1[1];
182
+ poses[t][2] = t1[2];
183
+ poses[t][3] = q1[0];
184
+ poses[t][4] = q1[1];
185
+ poses[t][5] = q1[2];
186
+ poses[t][6] = q1[3];
187
+ }
188
+ }
189
+
190
+
191
+ __global__ void patch_retr_kernel(
192
+ torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> index,
193
+ torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> patches,
194
+ torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> update)
195
+ {
196
+ GPU_1D_KERNEL_LOOP(n, index.size(0)) {
197
+ const int p = patches.size(2);
198
+ const int ix = index[n];
199
+
200
+ float d = patches[ix][2][0][0];
201
+ d = d + update[n];
202
+ d = (d > 20) ? 1.0 : d;
203
+ d = max(d, 1e-4);
204
+
205
+ for (int i=0; i<p; i++) {
206
+ for (int j=0; j<p; j++) {
207
+ patches[ix][2][i][j] = d;
208
+ }
209
+ }
210
+ }
211
+ }
212
+
213
+
214
+ __global__ void reprojection_residuals_and_hessian(
215
+ const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> poses,
216
+ const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> patches,
217
+ const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> intrinsics,
218
+ const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> target,
219
+ const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> weight,
220
+ const torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> lmbda,
221
+ const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> ii,
222
+ const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> jj,
223
+ const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> kk,
224
+ const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> ku,
225
+ torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> B,
226
+ torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> E,
227
+ torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> C,
228
+ torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> v,
229
+ torch::PackedTensorAccessor32<float,1,torch::RestrictPtrTraits> u, const int t0)
230
+ {
231
+
232
+ __shared__ float fx, fy, cx, cy;
233
+ if (threadIdx.x == 0) {
234
+ fx = intrinsics[0][0];
235
+ fy = intrinsics[0][1];
236
+ cx = intrinsics[0][2];
237
+ cy = intrinsics[0][3];
238
+ }
239
+
240
+ __syncthreads();
241
+
242
+ GPU_1D_KERNEL_LOOP(n, ii.size(0)) {
243
+ int k = ku[n];
244
+ int ix = ii[n];
245
+ int jx = jj[n];
246
+ int kx = kk[n];
247
+
248
+ float ti[3] = { poses[ix][0], poses[ix][1], poses[ix][2] };
249
+ float tj[3] = { poses[jx][0], poses[jx][1], poses[jx][2] };
250
+ float qi[4] = { poses[ix][3], poses[ix][4], poses[ix][5], poses[ix][6] };
251
+ float qj[4] = { poses[jx][3], poses[jx][4], poses[jx][5], poses[jx][6] };
252
+
253
+ float Xi[4], Xj[4];
254
+ Xi[0] = (patches[kx][0][1][1] - cx) / fx;
255
+ Xi[1] = (patches[kx][1][1][1] - cy) / fy;
256
+ Xi[2] = 1.0;
257
+ Xi[3] = patches[kx][2][1][1];
258
+
259
+ float tij[3], qij[4];
260
+ relSE3(ti, qi, tj, qj, tij, qij);
261
+ actSE3(tij, qij, Xi, Xj);
262
+
263
+ const float X = Xj[0];
264
+ const float Y = Xj[1];
265
+ const float Z = Xj[2];
266
+ const float W = Xj[3];
267
+
268
+ const float d = (Z >= 0.2) ? 1.0 / Z : 0.0;
269
+ const float d2 = d * d;
270
+
271
+ const float x1 = fx * (X / Z) + cx;
272
+ const float y1 = fy * (Y / Z) + cy;
273
+
274
+ const float rx = target[n][0] - x1;
275
+ const float ry = target[n][1] - y1;
276
+
277
+ const bool in_bounds = (sqrt(rx*rx + ry*ry) < 128) && (Z > 0.2) &&
278
+ (x1 > -64) && (y1 > -64) && (x1 < 2*cx + 64) && (y1 < 2*cy + 64);
279
+
280
+ const float mask = in_bounds ? 1.0 : 0.0;
281
+
282
+ ix = ix - t0;
283
+ jx = jx - t0;
284
+
285
+ {
286
+ const float r = target[n][0] - x1;
287
+ const float w = mask * weight[n][0];
288
+
289
+ float Jz = fx * (tij[0] * d - tij[2] * (X * d2));
290
+ float Ji[6], Jj[6] = {fx*W*d, 0, fx*-X*W*d2, fx*-X*Y*d2, fx*(1+X*X*d2), fx*-Y*d};
291
+
292
+ adjSE3(tij, qij, Jj, Ji);
293
+
294
+ for (int i=0; i<6; i++) {
295
+ for (int j=0; j<6; j++) {
296
+ if (ix >= 0)
297
+ atomicAdd(&B[6*ix+i][6*ix+j], w * Ji[i] * Ji[j]);
298
+ if (jx >= 0)
299
+ atomicAdd(&B[6*jx+i][6*jx+j], w * Jj[i] * Jj[j]);
300
+ if (ix >= 0 && jx >= 0) {
301
+ atomicAdd(&B[6*ix+i][6*jx+j], -w * Ji[i] * Jj[j]);
302
+ atomicAdd(&B[6*jx+i][6*ix+j], -w * Jj[i] * Ji[j]);
303
+ }
304
+ }
305
+ }
306
+
307
+ for (int i=0; i<6; i++) {
308
+ if (ix >= 0)
309
+ atomicAdd(&E[6*ix+i][k], -w * Jz * Ji[i]);
310
+ if (jx >= 0)
311
+ atomicAdd(&E[6*jx+i][k], w * Jz * Jj[i]);
312
+ }
313
+
314
+ for (int i=0; i<6; i++) {
315
+ if (ix >= 0)
316
+ atomicAdd(&v[6*ix+i], -w * r * Ji[i]);
317
+ if (jx >= 0)
318
+ atomicAdd(&v[6*jx+i], w * r * Jj[i]);
319
+ }
320
+
321
+ atomicAdd(&C[k], w * Jz * Jz);
322
+ atomicAdd(&u[k], w * r * Jz);
323
+ }
324
+
325
+ {
326
+ const float r = target[n][1] - y1;
327
+ const float w = mask * weight[n][1];
328
+
329
+ float Jz = fy * (tij[1] * d - tij[2] * (Y * d2));
330
+ float Ji[6], Jj[6] = {0, fy*W*d, fy*-Y*W*d2, fy*(-1-Y*Y*d2), fy*(X*Y*d2), fy*X*d};
331
+
332
+ adjSE3(tij, qij, Jj, Ji);
333
+
334
+ for (int i=0; i<6; i++) {
335
+ for (int j=0; j<6; j++) {
336
+ if (ix >= 0)
337
+ atomicAdd(&B[6*ix+i][6*ix+j], w * Ji[i] * Ji[j]);
338
+ if (jx >= 0)
339
+ atomicAdd(&B[6*jx+i][6*jx+j], w * Jj[i] * Jj[j]);
340
+ if (ix >= 0 && jx >= 0) {
341
+ atomicAdd(&B[6*ix+i][6*jx+j], -w * Ji[i] * Jj[j]);
342
+ atomicAdd(&B[6*jx+i][6*ix+j], -w * Jj[i] * Ji[j]);
343
+ }
344
+ }
345
+ }
346
+
347
+ for (int i=0; i<6; i++) {
348
+ if (ix >= 0)
349
+ atomicAdd(&E[6*ix+i][k], -w * Jz * Ji[i]);
350
+ if (jx >= 0)
351
+ atomicAdd(&E[6*jx+i][k], w * Jz * Jj[i]);
352
+ }
353
+
354
+ for (int i=0; i<6; i++) {
355
+ if (ix >= 0)
356
+ atomicAdd(&v[6*ix+i], -w * r * Ji[i]);
357
+ if (jx >= 0)
358
+ atomicAdd(&v[6*jx+i], w * r * Jj[i]);
359
+ }
360
+
361
+ atomicAdd(&C[k], w * Jz * Jz);
362
+ atomicAdd(&u[k], w * r * Jz);
363
+ }
364
+ }
365
+ }
366
+
367
+
368
+ __global__ void reproject(
369
+ const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> poses,
370
+ const torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> patches,
371
+ const torch::PackedTensorAccessor32<float,2,torch::RestrictPtrTraits> intrinsics,
372
+ const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> ii,
373
+ const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> jj,
374
+ const torch::PackedTensorAccessor32<long,1,torch::RestrictPtrTraits> kk,
375
+ torch::PackedTensorAccessor32<float,4,torch::RestrictPtrTraits> coords) {
376
+
377
+ __shared__ float fx, fy, cx, cy;
378
+ if (threadIdx.x == 0) {
379
+ fx = intrinsics[0][0];
380
+ fy = intrinsics[0][1];
381
+ cx = intrinsics[0][2];
382
+ cy = intrinsics[0][3];
383
+ }
384
+
385
+ __syncthreads();
386
+
387
+ GPU_1D_KERNEL_LOOP(n, ii.size(0)) {
388
+ int ix = ii[n];
389
+ int jx = jj[n];
390
+ int kx = kk[n];
391
+
392
+ float ti[3] = { poses[ix][0], poses[ix][1], poses[ix][2] };
393
+ float tj[3] = { poses[jx][0], poses[jx][1], poses[jx][2] };
394
+ float qi[4] = { poses[ix][3], poses[ix][4], poses[ix][5], poses[ix][6] };
395
+ float qj[4] = { poses[jx][3], poses[jx][4], poses[jx][5], poses[jx][6] };
396
+
397
+ float tij[3], qij[4];
398
+ relSE3(ti, qi, tj, qj, tij, qij);
399
+
400
+ float Xi[4], Xj[4];
401
+ for (int i=0; i<patches.size(2); i++) {
402
+ for (int j=0; j<patches.size(3); j++) {
403
+
404
+ Xi[0] = (patches[kx][0][i][j] - cx) / fx;
405
+ Xi[1] = (patches[kx][1][i][j] - cy) / fy;
406
+ Xi[2] = 1.0;
407
+ Xi[3] = patches[kx][2][i][j];
408
+
409
+ actSE3(tij, qij, Xi, Xj);
410
+
411
+ coords[n][0][i][j] = fx * (Xj[0] / Xj[2]) + cx;
412
+ coords[n][1][i][j] = fy * (Xj[1] / Xj[2]) + cy;
413
+ // coords[n][2][i][j] = 1.0 / Xj[2];
414
+
415
+ }
416
+ }
417
+ }
418
+ }
419
+
420
+
421
+
422
+ std::vector<torch::Tensor> cuda_ba(
423
+ torch::Tensor poses,
424
+ torch::Tensor patches,
425
+ torch::Tensor intrinsics,
426
+ torch::Tensor target,
427
+ torch::Tensor weight,
428
+ torch::Tensor lmbda,
429
+ torch::Tensor ii,
430
+ torch::Tensor jj,
431
+ torch::Tensor kk,
432
+ const int t0, const int t1, const int iterations)
433
+ {
434
+
435
+ auto ktuple = torch::_unique(kk, true, true);
436
+ torch::Tensor kx = std::get<0>(ktuple);
437
+ torch::Tensor ku = std::get<1>(ktuple);
438
+
439
+ const int N = t1 - t0; // number of poses
440
+ const int M = kx.size(0); // number of patches
441
+ const int P = patches.size(3); // patch size
442
+
443
+ auto opts = torch::TensorOptions()
444
+ .dtype(torch::kFloat32).device(torch::kCUDA);
445
+
446
+ poses = poses.view({-1, 7});
447
+ patches = patches.view({-1,3,P,P});
448
+ intrinsics = intrinsics.view({-1, 4});
449
+
450
+ target = target.view({-1, 2});
451
+ weight = weight.view({-1, 2});
452
+
453
+ const int num = ii.size(0);
454
+ torch::Tensor B = torch::empty({6*N, 6*N}, opts);
455
+ torch::Tensor E = torch::empty({6*N, 1*M}, opts);
456
+ torch::Tensor C = torch::empty({M}, opts);
457
+
458
+ torch::Tensor v = torch::empty({6*N}, opts);
459
+ torch::Tensor u = torch::empty({1*M}, opts);
460
+
461
+ for (int itr=0; itr < iterations; itr++) {
462
+
463
+ B.zero_();
464
+ E.zero_();
465
+ C.zero_();
466
+ v.zero_();
467
+ u.zero_();
468
+
469
+ v = v.view({6*N});
470
+ u = u.view({1*M});
471
+
472
+ reprojection_residuals_and_hessian<<<NUM_BLOCKS(ii.size(0)), NUM_THREADS>>>(
473
+ poses.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
474
+ patches.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
475
+ intrinsics.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
476
+ target.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
477
+ weight.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
478
+ lmbda.packed_accessor32<float,1,torch::RestrictPtrTraits>(),
479
+ ii.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
480
+ jj.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
481
+ kk.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
482
+ ku.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
483
+ B.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
484
+ E.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
485
+ C.packed_accessor32<float,1,torch::RestrictPtrTraits>(),
486
+ v.packed_accessor32<float,1,torch::RestrictPtrTraits>(),
487
+ u.packed_accessor32<float,1,torch::RestrictPtrTraits>(), t0);
488
+
489
+ v = v.view({6*N, 1});
490
+ u = u.view({1*M, 1});
491
+
492
+ torch::Tensor Q = 1.0 / (C + lmbda).view({1, M});
493
+
494
+ if (t1 - t0 == 0) {
495
+
496
+ torch::Tensor Qt = torch::transpose(Q, 0, 1);
497
+ torch::Tensor dZ = Qt * u;
498
+
499
+ dZ = dZ.view({M});
500
+
501
+ patch_retr_kernel<<<NUM_BLOCKS(M), NUM_THREADS>>>(
502
+ kx.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
503
+ patches.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
504
+ dZ.packed_accessor32<float,1,torch::RestrictPtrTraits>());
505
+
506
+ }
507
+
508
+ else {
509
+
510
+ torch::Tensor EQ = E * Q;
511
+ torch::Tensor Et = torch::transpose(E, 0, 1);
512
+ torch::Tensor Qt = torch::transpose(Q, 0, 1);
513
+
514
+ torch::Tensor S = B - torch::matmul(EQ, Et);
515
+ torch::Tensor y = v - torch::matmul(EQ, u);
516
+
517
+ torch::Tensor I = torch::eye(6*N, opts);
518
+ S += I * (1e-4 * S + 1.0);
519
+
520
+
521
+ torch::Tensor U = torch::linalg::cholesky(S);
522
+ torch::Tensor dX = torch::cholesky_solve(y, U);
523
+ torch::Tensor dZ = Qt * (u - torch::matmul(Et, dX));
524
+
525
+ dX = dX.view({N, 6});
526
+ dZ = dZ.view({M});
527
+
528
+ pose_retr_kernel<<<NUM_BLOCKS(N), NUM_THREADS>>>(t0, t1,
529
+ poses.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
530
+ dX.packed_accessor32<float,2,torch::RestrictPtrTraits>());
531
+
532
+ patch_retr_kernel<<<NUM_BLOCKS(M), NUM_THREADS>>>(
533
+ kx.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
534
+ patches.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
535
+ dZ.packed_accessor32<float,1,torch::RestrictPtrTraits>());
536
+ }
537
+ }
538
+
539
+ return {};
540
+ }
541
+
542
+
543
+ torch::Tensor cuda_reproject(
544
+ torch::Tensor poses,
545
+ torch::Tensor patches,
546
+ torch::Tensor intrinsics,
547
+ torch::Tensor ii,
548
+ torch::Tensor jj,
549
+ torch::Tensor kk)
550
+ {
551
+
552
+ const int N = ii.size(0);
553
+ const int P = patches.size(3); // patch size
554
+
555
+ poses = poses.view({-1, 7});
556
+ patches = patches.view({-1,3,P,P});
557
+ intrinsics = intrinsics.view({-1, 4});
558
+
559
+ auto opts = torch::TensorOptions()
560
+ .dtype(torch::kFloat32).device(torch::kCUDA);
561
+
562
+ torch::Tensor coords = torch::empty({N, 2, P, P}, opts);
563
+
564
+ reproject<<<NUM_BLOCKS(N), NUM_THREADS>>>(
565
+ poses.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
566
+ patches.packed_accessor32<float,4,torch::RestrictPtrTraits>(),
567
+ intrinsics.packed_accessor32<float,2,torch::RestrictPtrTraits>(),
568
+ ii.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
569
+ jj.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
570
+ kk.packed_accessor32<long,1,torch::RestrictPtrTraits>(),
571
+ coords.packed_accessor32<float,4,torch::RestrictPtrTraits>());
572
+
573
+ return coords.view({1, N, 2, P, P});
574
+
575
+ }
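
Editor's note: the solve inside `cuda_ba` above is a standard two-block Gauss-Newton step. The kernel fills the pose block `B`, the pose/depth coupling `E`, the diagonal depth block `C`, and the gradients `v`, `u`; the depth variables are then eliminated with a Schur complement before a Cholesky solve over the poses. Below is a minimal dense PyTorch sketch of that elimination only; the helper name `schur_step` and the 1-D layout of `v`, `u` are ours, not part of the code above.

```python
import torch

def schur_step(B, E, C, v, u, lmbda=1e-4):
    """Dense sketch of the pose/depth elimination performed inside cuda_ba."""
    N6 = B.shape[0]
    Q = 1.0 / (C + lmbda)                      # inverse of the diagonal depth block
    EQ = E * Q[None, :]                        # E @ diag(Q)

    S = B - EQ @ E.t()                         # Schur complement on the poses
    y = v - EQ @ u
    S = S + torch.eye(N6, dtype=S.dtype, device=S.device) * (1e-4 * S + 1.0)  # damping, as in the kernel

    L = torch.linalg.cholesky(S)
    dX = torch.cholesky_solve(y.unsqueeze(-1), L).squeeze(-1)  # pose update, length 6N
    dZ = Q * (u - E.t() @ dX)                  # back-substituted inverse-depth update, length M
    return dX, dZ
```

In the real code `B`, `E`, `C`, `v`, `u` are the tensors accumulated by `reprojection_residuals_and_hessian`, and `dX`, `dZ` are applied by `pose_retr_kernel` and `patch_retr_kernel`.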
mini_dpvo/lietorch/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ __all__ = ['groups']
2
+ from .groups import LieGroupParameter, SO3, RxSO3, SE3, Sim3, cat, stack
mini_dpvo/lietorch/broadcasting.py ADDED
@@ -0,0 +1,31 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ def check_broadcastable(x, y):
5
+ assert len(x.shape) == len(y.shape)
6
+ for (n, m) in zip(x.shape[:-1], y.shape[:-1]):
7
+ assert n==m or n==1 or m==1
8
+
9
+ def broadcast_inputs(x, y):
10
+ """ Automatic broadcasting of missing dimensions """
11
+ if y is None:
12
+ xs, xd = x.shape[:-1], x.shape[-1]
13
+ return (x.view(-1, xd).contiguous(), ), x.shape[:-1]
14
+
15
+ check_broadcastable(x, y)
16
+
17
+ xs, xd = x.shape[:-1], x.shape[-1]
18
+ ys, yd = y.shape[:-1], y.shape[-1]
19
+ out_shape = [max(n,m) for (n,m) in zip(xs,ys)]
20
+
21
+ if x.shape[:-1] == y.shape[:-1]:  # batch shapes already match
22
+ x1 = x.view(-1, xd)
23
+ y1 = y.view(-1, yd)
24
+
25
+ else:
26
+ x_expand = [m if n==1 else 1 for (n,m) in zip(xs, ys)]
27
+ y_expand = [n if m==1 else 1 for (n,m) in zip(xs, ys)]
28
+ x1 = x.repeat(x_expand + [1]).reshape(-1, xd).contiguous()
29
+ y1 = y.repeat(y_expand + [1]).reshape(-1, yd).contiguous()
30
+
31
+ return (x1, y1), tuple(out_shape)
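
Editor's note: a quick illustration of what `broadcast_inputs` does for `LieGroup.apply_op` in `groups.py`: batch dimensions are matched and both operands are flattened to `(batch, dim)` before being handed to the compiled kernels. The shapes below are made up, and the snippet assumes the package and its compiled extension are importable.

```python
import torch
from mini_dpvo.lietorch.broadcasting import broadcast_inputs

x = torch.randn(4, 1, 7)   # e.g. a (4, 1) batch of SE3 elements, 7 numbers each
y = torch.randn(1, 5, 7)   # a (1, 5) batch

(x1, y1), out_shape = broadcast_inputs(x, y)
print(x1.shape, y1.shape, out_shape)   # torch.Size([20, 7]) torch.Size([20, 7]) (4, 5)
```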
mini_dpvo/lietorch/gradcheck.py ADDED
@@ -0,0 +1,592 @@
1
+ import torch
2
+
3
+ TORCH_MAJOR = int(torch.__version__.split('.')[0])
4
+ TORCH_MINOR = int(torch.__version__.split('.')[1])
5
+
6
+ from torch.types import _TensorOrTensors
7
+ if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
8
+ from torch._six import container_abcs, istuple
9
+ else:
10
+ import collections.abc as container_abcs
11
+
12
+ import torch.testing
13
+ from torch.overrides import is_tensor_like
14
+ from itertools import product
15
+ import warnings
16
+ from typing import Callable, Union, Optional, Iterable, List
17
+
18
+ def zero_gradients(x):
19
+ if isinstance(x, torch.Tensor):
20
+ if x.grad is not None:
21
+ x.grad.detach_()
22
+ x.grad.zero_()
23
+ elif isinstance(x, container_abcs.Iterable):
24
+ for elem in x:
25
+ zero_gradients(elem)
26
+
27
+
28
+ def make_jacobian(input, num_out):
29
+ if is_tensor_like(input):
30
+ if not input.is_floating_point() and not input.is_complex():
31
+ return None
32
+ if not input.requires_grad:
33
+ return None
34
+ return input.new_zeros((input.nelement(), num_out), dtype=input.dtype, layout=torch.strided)
35
+ elif isinstance(input, container_abcs.Iterable) and not isinstance(input, str):
36
+ jacobians = list(filter(
37
+ lambda x: x is not None, (make_jacobian(elem, num_out) for elem in input)))
38
+ if not jacobians:
39
+ return None
40
+ return type(input)(jacobians) # type: ignore
41
+ else:
42
+ return None
43
+
44
+
45
+ def iter_tensors(x: Union[torch.Tensor, Iterable[torch.Tensor]], only_requiring_grad: bool = False) -> Iterable[torch.Tensor]:
46
+ if is_tensor_like(x):
47
+ # mypy doesn't narrow type of `x` to torch.Tensor
48
+ if x.requires_grad or not only_requiring_grad: # type: ignore
49
+ yield x # type: ignore
50
+ elif isinstance(x, container_abcs.Iterable) and not isinstance(x, str):
51
+ for elem in x:
52
+ for result in iter_tensors(elem, only_requiring_grad):
53
+ yield result
54
+
55
+ def get_numerical_jacobian(fn, input, target=None, eps=1e-3, grad_out=1.0):
56
+ """
57
+ input: input to `fn`
58
+ target: the Tensors wrt whom Jacobians are calculated (default=`input`)
59
+ grad_out: grad output value used to calculate gradients.
60
+
61
+ Note that `target` may not even be part of `input` to `fn`, so please be
62
+ **very careful** in this to not clone `target`.
63
+ """
64
+ if target is None:
65
+ target = input
66
+ output_size = fn(input).numel()
67
+ jacobian = make_jacobian(target, output_size)
68
+
69
+ # It's much easier to iterate over flattened lists of tensors.
70
+ # These are reference to the same objects in jacobian, so any changes
71
+ # will be reflected in it as well.
72
+ x_tensors = iter_tensors(target, True)
73
+ j_tensors = iter_tensors(jacobian)
74
+
75
+ def update_jacobians(x, idx, d, d_idx, is_mkldnn=False):
76
+
77
+ # compute_jacobian only works for pure real
78
+ # or pure imaginary delta
79
+ def compute_gradient(delta):
80
+ # we currently assume that the norm of delta equals eps
81
+ assert(delta == eps or delta == (eps * 1j))
82
+
83
+ def fn_out():
84
+ if not is_mkldnn:
85
+ # x is a view into input and so this works
86
+ return fn(input).clone()
87
+ else:
88
+ # convert the dense tensor back to have mkldnn layout
89
+ return fn([x.to_mkldnn()])
90
+
91
+ orig = x[idx].item()
92
+ x[idx] = orig - delta
93
+ outa = fn_out()
94
+ x[idx] = orig + delta
95
+ outb = fn_out()
96
+ x[idx] = orig
97
+ r = (outb - outa) / (2 * eps)
98
+ return r.detach().reshape(-1)
99
+
100
+ # for details on the algorithm used here, refer:
101
+ # Section 3.5.3 https://arxiv.org/pdf/1701.00392.pdf
102
+ # s = fn(z) where z = x for real valued input
103
+ # and z = x + yj for complex valued input
104
+ ds_dx = compute_gradient(eps)
105
+ if x.is_complex(): # C -> C, C -> R
106
+ ds_dy = compute_gradient(eps * 1j)
107
+ # conjugate wirtinger derivative
108
+ conj_w_d = 0.5 * (ds_dx + ds_dy * 1j)
109
+ # wirtinger derivative
110
+ w_d = 0.5 * (ds_dx - ds_dy * 1j)
111
+ d[d_idx] = grad_out.conjugate() * conj_w_d + grad_out * w_d.conj()
112
+ elif ds_dx.is_complex(): # R -> C
113
+ # w_d = conj_w_d = 0.5 * ds_dx
114
+ # dL_dz_conj = 0.5 * [grad_out.conj() * ds_dx + grad_out * ds_dx.conj()]
115
+ # = 0.5 * [grad_out.conj() * ds_dx + (grad_out.conj() * ds_dx).conj()]
116
+ # = 0.5 * 2 * real(grad_out.conj() * ds_dx)
117
+ # = real(grad_out.conj() * ds_dx)
118
+ d[d_idx] = torch.real(grad_out.conjugate() * ds_dx)
119
+ else: # R -> R
120
+ d[d_idx] = ds_dx * grad_out
121
+
122
+ # TODO: compare structure
123
+ for x_tensor, d_tensor in zip(x_tensors, j_tensors):
124
+ if x_tensor.is_sparse:
125
+ def get_stride(size):
126
+ dim = len(size)
127
+ tmp = 1
128
+ stride = [0] * dim
129
+ for i in reversed(range(dim)):
130
+ stride[i] = tmp
131
+ tmp *= size[i]
132
+ return stride
133
+
134
+ x_nnz = x_tensor._nnz()
135
+ x_size = list(x_tensor.size())
136
+ x_indices = x_tensor._indices().t()
137
+ x_values = x_tensor._values()
138
+ x_stride = get_stride(x_size)
139
+
140
+ # Use .data here to get around the version check
141
+ x_values = x_values.data
142
+
143
+ for i in range(x_nnz):
144
+ x_value = x_values[i]
145
+ for x_idx in product(*[range(m) for m in x_values.size()[1:]]):
146
+ indices = x_indices[i].tolist() + list(x_idx)
147
+ d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
148
+ update_jacobians(x_value, x_idx, d_tensor, d_idx)
149
+ elif x_tensor.layout == torch._mkldnn: # type: ignore
150
+ # Use .data here to get around the version check
151
+ x_tensor = x_tensor.data
152
+ if len(input) != 1:
153
+ raise ValueError('gradcheck currently only supports functions with 1 input, but got: ',
154
+ len(input))
155
+ for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
156
+ # this is really inefficient, but without indexing implemented, there's
157
+ # not really a better way than converting back and forth
158
+ x_tensor_dense = x_tensor.to_dense()
159
+ update_jacobians(x_tensor_dense, x_idx, d_tensor, d_idx, is_mkldnn=True)
160
+ else:
161
+ # Use .data here to get around the version check
162
+ x_tensor = x_tensor.data
163
+ for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])):
164
+ update_jacobians(x_tensor, x_idx, d_tensor, d_idx)
165
+
166
+ return jacobian
167
+
168
+
169
+ def get_analytical_jacobian(input, output, nondet_tol=0.0, grad_out=1.0):
170
+ # it is easier to call to_dense() on the sparse output than
171
+ # to modify analytical jacobian
172
+ if output.is_sparse:
173
+ raise ValueError('Sparse output is not supported at gradcheck yet. '
174
+ 'Please call to_dense() on the output of fn for gradcheck.')
175
+ if output.layout == torch._mkldnn: # type: ignore
176
+ raise ValueError('MKLDNN output is not supported at gradcheck yet. '
177
+ 'Please call to_dense() on the output of fn for gradcheck.')
178
+ diff_input_list = list(iter_tensors(input, True))
179
+ jacobian = make_jacobian(input, output.numel())
180
+ jacobian_reentrant = make_jacobian(input, output.numel())
181
+ grad_output = torch.zeros_like(output, memory_format=torch.legacy_contiguous_format)
182
+ flat_grad_output = grad_output.view(-1)
183
+ reentrant = True
184
+ correct_grad_sizes = True
185
+ correct_grad_types = True
186
+
187
+ for i in range(flat_grad_output.numel()):
188
+ flat_grad_output.zero_()
189
+ flat_grad_output[i] = grad_out
190
+ for jacobian_c in (jacobian, jacobian_reentrant):
191
+ grads_input = torch.autograd.grad(output, diff_input_list, grad_output,
192
+ retain_graph=True, allow_unused=True)
193
+ for jacobian_x, d_x, x in zip(jacobian_c, grads_input, diff_input_list):
194
+ if d_x is not None and d_x.size() != x.size():
195
+ correct_grad_sizes = False
196
+ elif d_x is not None and d_x.dtype != x.dtype:
197
+ correct_grad_types = False
198
+ elif jacobian_x.numel() != 0:
199
+ if d_x is None:
200
+ jacobian_x[:, i].zero_()
201
+ else:
202
+ d_x_dense = d_x.to_dense() if not d_x.layout == torch.strided else d_x
203
+ assert jacobian_x[:, i].numel() == d_x_dense.numel()
204
+ jacobian_x[:, i] = d_x_dense.contiguous().view(-1)
205
+
206
+ for jacobian_x, jacobian_reentrant_x in zip(jacobian, jacobian_reentrant):
207
+ if jacobian_x.numel() != 0 and (jacobian_x - jacobian_reentrant_x).abs().max() > nondet_tol:
208
+ reentrant = False
209
+
210
+ return jacobian, reentrant, correct_grad_sizes, correct_grad_types
211
+
212
+
213
+ def _as_tuple(x):
214
+ if TORCH_MAJOR == 1 and TORCH_MINOR < 8:
215
+ b_tuple = istuple(x)
216
+ else:
217
+ b_tuple = isinstance(x, tuple)
218
+
219
+ if b_tuple:
220
+ return x
221
+ elif isinstance(x, list):
222
+ return tuple(x)
223
+ else:
224
+ return x,
225
+
226
+
227
+
228
+ def _differentiable_outputs(x):
229
+ return tuple(o for o in _as_tuple(x) if o.requires_grad)
230
+
231
+
232
+ # Note [VarArg of Tensors]
233
+ # ~~~~~~~~~~~~~~~~~~~~~~~~
234
+ # 'func' accepts a vararg of tensors, which isn't expressable in the type system at the moment.
235
+ # If https://mypy.readthedocs.io/en/latest/additional_features.html?highlight=callable#extended-callable-types is accepted,
236
+ # the '...' first argument of Callable can be replaced with VarArg(Tensor).
237
+ # For now, we permit any input.
238
+ # the '...' first argument of Callable can be replaced with VarArg(Tensor).
239
+ # For now, we permit any input.
240
+
241
+ def gradcheck(
242
+ func: Callable[..., Union[_TensorOrTensors]], # See Note [VarArg of Tensors]
243
+ inputs: _TensorOrTensors,
244
+ eps: float = 1e-6,
245
+ atol: float = 1e-5,
246
+ rtol: float = 1e-3,
247
+ raise_exception: bool = True,
248
+ check_sparse_nnz: bool = False,
249
+ nondet_tol: float = 0.0,
250
+ check_undefined_grad: bool = True,
251
+ check_grad_dtypes: bool = False
252
+ ) -> bool:
253
+ r"""Check gradients computed via small finite differences against analytical
254
+ gradients w.r.t. tensors in :attr:`inputs` that are of floating point or complex type
255
+ and with ``requires_grad=True``.
256
+
257
+ The check between numerical and analytical gradients uses :func:`~torch.allclose`.
258
+
259
+ For complex functions, no notion of Jacobian exists. Gradcheck verifies if the numerical and
260
+ analytical values of Wirtinger and Conjugate Wirtinger derivative are consistent. The gradient
261
+ computation is done under the assumption that the overall function has a real valued output.
262
+ For functions with complex output, gradcheck compares the numerical and analytical gradients
263
+ for two values of :attr:`grad_output`: 1 and 1j. For more details, check out
264
+ :ref:`complex_autograd-doc`.
265
+
266
+ .. note::
267
+ The default values are designed for :attr:`input` of double precision.
268
+ This check will likely fail if :attr:`input` is of less precision, e.g.,
269
+ ``FloatTensor``.
270
+
271
+ .. warning::
272
+ If any checked tensor in :attr:`input` has overlapping memory, i.e.,
273
+ different indices pointing to the same memory address (e.g., from
274
+ :func:`torch.expand`), this check will likely fail because the numerical
275
+ gradients computed by point perturbation at such indices will change
276
+ values at all other indices that share the same memory address.
277
+
278
+ Args:
279
+ func (function): a Python function that takes Tensor inputs and returns
280
+ a Tensor or a tuple of Tensors
281
+ inputs (tuple of Tensor or Tensor): inputs to the function
282
+ eps (float, optional): perturbation for finite differences
283
+ atol (float, optional): absolute tolerance
284
+ rtol (float, optional): relative tolerance
285
+ raise_exception (bool, optional): indicating whether to raise an exception if
286
+ the check fails. The exception gives more information about the
287
+ exact nature of the failure. This is helpful when debugging gradchecks.
288
+ check_sparse_nnz (bool, optional): if True, gradcheck allows for SparseTensor input,
289
+ and for any SparseTensor at input, gradcheck will perform check at nnz positions only.
290
+ nondet_tol (float, optional): tolerance for non-determinism. When running
291
+ identical inputs through the differentiation, the results must either match
292
+ exactly (default, 0.0) or be within this tolerance.
293
+ check_undefined_grad (bool, options): if True, check if undefined output grads
294
+ are supported and treated as zeros, for ``Tensor`` outputs.
295
+
296
+ Returns:
297
+ True if all differences satisfy allclose condition
298
+ """
299
+ def fail_test(msg):
300
+ if raise_exception:
301
+ raise RuntimeError(msg)
302
+ return False
303
+
304
+ tupled_inputs = _as_tuple(inputs)
305
+ if not check_sparse_nnz and any(t.is_sparse for t in tupled_inputs if isinstance(t, torch.Tensor)):
306
+ return fail_test('gradcheck expects all tensor inputs are dense when check_sparse_nnz is set to False.')
307
+
308
+ # Make sure that gradients are saved for at least one input
309
+ any_input_requiring_grad = False
310
+ for idx, inp in enumerate(tupled_inputs):
311
+ if is_tensor_like(inp) and inp.requires_grad:
312
+ if not (inp.dtype == torch.float64 or inp.dtype == torch.complex128):
313
+ warnings.warn(
314
+ f'Input #{idx} requires gradient and '
315
+ 'is not a double precision floating point or complex. '
316
+ 'This check will likely fail if all the inputs are '
317
+ 'not of double precision floating point or complex. ')
318
+ content = inp._values() if inp.is_sparse else inp
319
+ # TODO: To cover more problematic cases, replace stride = 0 check with
320
+ # "any overlap in memory" once we have a proper function to check it.
321
+ if content.layout is not torch._mkldnn: # type: ignore
322
+ if not all(st > 0 or sz <= 1 for st, sz in zip(content.stride(), content.size())):
323
+ raise RuntimeError(
324
+ 'The {}th input has a dimension with stride 0. gradcheck only '
325
+ 'supports inputs that are non-overlapping to be able to '
326
+ 'compute the numerical gradients correctly. You should call '
327
+ '.contiguous on the input before passing it to gradcheck.')
328
+ any_input_requiring_grad = True
329
+ inp.retain_grad()
330
+ if not any_input_requiring_grad:
331
+ raise ValueError(
332
+ 'gradcheck expects at least one input tensor to require gradient, '
333
+ 'but none of the them have requires_grad=True.')
334
+
335
+ func_out = func(*tupled_inputs)
336
+ output = _differentiable_outputs(func_out)
337
+
338
+ if not output:
339
+ for i, o in enumerate(func_out):
340
+ def fn(input):
341
+ return _as_tuple(func(*input))[i]
342
+ numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
343
+ for n in numerical:
344
+ if torch.ne(n, 0).sum() > 0:
345
+ return fail_test('Numerical gradient for function expected to be zero')
346
+ return True
347
+
348
+ for i, o in enumerate(output):
349
+ if not o.requires_grad:
350
+ continue
351
+
352
+ def fn(input):
353
+ return _as_tuple(func(*input))[i]
354
+
355
+ analytical, reentrant, correct_grad_sizes, correct_grad_types = get_analytical_jacobian(tupled_inputs,
356
+ o,
357
+ nondet_tol=nondet_tol)
358
+ numerical = get_numerical_jacobian(fn, tupled_inputs, eps=eps)
359
+
360
+ return analytical, numerical
361
+
362
+ out_is_complex = o.is_complex()
363
+
364
+ if out_is_complex:
365
+ # analytical vjp with grad_out = 1.0j
366
+ analytical_with_imag_grad_out, reentrant_with_imag_grad_out, \
367
+ correct_grad_sizes_with_imag_grad_out, correct_grad_types_with_imag_grad_out \
368
+ = get_analytical_jacobian(tupled_inputs, o, nondet_tol=nondet_tol, grad_out=1j)
369
+ numerical_with_imag_grad_out = get_numerical_jacobian(fn, tupled_inputs, eps=eps, grad_out=1j)
370
+
371
+ if not correct_grad_types and check_grad_dtypes:
372
+ return fail_test('Gradient has dtype mismatch')
373
+
374
+ if out_is_complex and not correct_grad_types_with_imag_grad_out and check_grad_dtypes:
375
+ return fail_test('Gradient (calculated using complex valued grad output) has dtype mismatch')
376
+
377
+ if not correct_grad_sizes:
378
+ return fail_test('Analytical gradient has incorrect size')
379
+
380
+ if out_is_complex and not correct_grad_sizes_with_imag_grad_out:
381
+ return fail_test('Analytical gradient (calculated using complex valued grad output) has incorrect size')
382
+
383
+ def checkIfNumericalAnalyticAreClose(a, n, j, error_str=''):
384
+ if not torch.allclose(a, n, rtol, atol):
385
+ return fail_test(error_str + 'Jacobian mismatch for output %d with respect to input %d,\n'
386
+ 'numerical:%s\nanalytical:%s\n' % (i, j, n, a))
387
+
388
+ inp_tensors = iter_tensors(tupled_inputs, True)
389
+
390
+ for j, (a, n, inp) in enumerate(zip(analytical, numerical, inp_tensors)):
391
+ if a.numel() != 0 or n.numel() != 0:
392
+ if o.is_complex():
393
+ # C -> C, R -> C
394
+ a_with_imag_grad_out = analytical_with_imag_grad_out[j]
395
+ n_with_imag_grad_out = numerical_with_imag_grad_out[j]
396
+ checkIfNumericalAnalyticAreClose(a_with_imag_grad_out, n_with_imag_grad_out, j,
397
+ "Gradients failed to compare equal for grad output = 1j. ")
398
+ if inp.is_complex():
399
+ # C -> R, C -> C
400
+ checkIfNumericalAnalyticAreClose(a, n, j,
401
+ "Gradients failed to compare equal for grad output = 1. ")
402
+ else:
403
+ # R -> R, R -> C
404
+ checkIfNumericalAnalyticAreClose(a, n, j)
405
+
406
+
407
+ def not_reentrant_error(error_str=''):
408
+ error_msg = "Backward" + error_str + " is not reentrant, i.e., running backward with same \
409
+ input and grad_output multiple times gives different values, \
410
+ although analytical gradient matches numerical gradient. \
411
+ The tolerance for nondeterminism was {}.".format(nondet_tol)
412
+ return fail_test(error_msg)
413
+
414
+ if not reentrant:
415
+ return not_reentrant_error()
416
+
417
+ if out_is_complex and not reentrant_with_imag_grad_out:
418
+ return not_reentrant_error(' (calculated using complex valued grad output)')
419
+
420
+ # check if the backward multiplies by grad_output
421
+ output = _differentiable_outputs(func(*tupled_inputs))
422
+ if any([o.requires_grad for o in output]):
423
+ diff_input_list: List[torch.Tensor] = list(iter_tensors(tupled_inputs, True))
424
+ if not diff_input_list:
425
+ raise RuntimeError("no Tensors requiring grad found in input")
426
+ grads_input = torch.autograd.grad(output, diff_input_list,
427
+ [torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) for o in output],
428
+ allow_unused=True)
429
+ for gi, di in zip(grads_input, diff_input_list):
430
+ if gi is None:
431
+ continue
432
+ if isinstance(gi, torch.Tensor) and gi.layout != torch.strided:
433
+ if gi.layout != di.layout:
434
+ return fail_test('grad is incorrect layout (' + str(gi.layout) + ' is not ' + str(di.layout) + ')')
435
+ if gi.layout == torch.sparse_coo:
436
+ if gi.sparse_dim() != di.sparse_dim():
437
+ return fail_test('grad is sparse tensor, but has incorrect sparse_dim')
438
+ if gi.dense_dim() != di.dense_dim():
439
+ return fail_test('grad is sparse tensor, but has incorrect dense_dim')
440
+ gi = gi.to_dense()
441
+ di = di.to_dense()
442
+ if not gi.eq(0).all():
443
+ return fail_test('backward not multiplied by grad_output')
444
+ if gi.dtype != di.dtype or gi.device != di.device or gi.is_sparse != di.is_sparse:
445
+ return fail_test("grad is incorrect type")
446
+ if gi.size() != di.size():
447
+ return fail_test('grad is incorrect size')
448
+
449
+ if check_undefined_grad:
450
+ def warn_bc_breaking():
451
+ warnings.warn((
452
+ 'Backwards compatibility: New undefined gradient support checking '
453
+ 'feature is enabled by default, but it may break existing callers '
454
+ 'of this function. If this is true for you, you can call this '
455
+ 'function with "check_undefined_grad=False" to disable the feature'))
456
+
457
+ def check_undefined_grad_support(output_to_check):
458
+ grads_output = [torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) for o in output_to_check]
459
+ try:
460
+ grads_input = torch.autograd.grad(output_to_check,
461
+ diff_input_list,
462
+ grads_output,
463
+ allow_unused=True)
464
+ except RuntimeError:
465
+ warn_bc_breaking()
466
+ return fail_test((
467
+ 'Expected backward function to handle undefined output grads. '
468
+ 'Please look at "Notes about undefined output gradients" in '
469
+ '"tools/autograd/derivatives.yaml"'))
470
+
471
+ for gi, i in zip(grads_input, diff_input_list):
472
+ if (gi is not None) and (not gi.eq(0).all()):
473
+ warn_bc_breaking()
474
+ return fail_test((
475
+ 'Expected all input grads to be undefined or zero when all output grads are undefined '
476
+ 'or zero. Please look at "Notes about undefined output gradients" in '
477
+ '"tools/autograd/derivatives.yaml"'))
478
+ return True
479
+
480
+ # All backward functions must work properly if all output grads are undefined
481
+ outputs_to_check = [[
482
+ torch._C._functions.UndefinedGrad()(o) for o in _differentiable_outputs(func(*tupled_inputs))
483
+ # This check filters out Tensor-likes that aren't instances of Tensor.
484
+ if isinstance(o, torch.Tensor)
485
+ ]]
486
+
487
+ # If there are multiple output grads, we should be able to undef one at a time without error
488
+ if len(outputs_to_check[0]) > 1:
489
+ for undef_grad_idx in range(len(output)):
490
+ output_to_check = _differentiable_outputs(func(*tupled_inputs))
491
+ outputs_to_check.append([
492
+ torch._C._functions.UndefinedGrad()(o) if idx == undef_grad_idx else o
493
+ for idx, o in enumerate(output_to_check)])
494
+
495
+ for output_to_check in outputs_to_check:
496
+ if not check_undefined_grad_support(output_to_check):
497
+ return False
498
+
499
+ return True
500
+
501
+
502
+ def gradgradcheck(
503
+ func: Callable[..., _TensorOrTensors], # See Note [VarArg of Tensors]
504
+ inputs: _TensorOrTensors,
505
+ grad_outputs: Optional[_TensorOrTensors] = None,
506
+ eps: float = 1e-6,
507
+ atol: float = 1e-5,
508
+ rtol: float = 1e-3,
509
+ gen_non_contig_grad_outputs: bool = False,
510
+ raise_exception: bool = True,
511
+ nondet_tol: float = 0.0,
512
+ check_undefined_grad: bool = True,
513
+ check_grad_dtypes: bool = False
514
+ ) -> bool:
515
+ r"""Check gradients of gradients computed via small finite differences
516
+ against analytical gradients w.r.t. tensors in :attr:`inputs` and
517
+ :attr:`grad_outputs` that are of floating point or complex type and with
518
+ ``requires_grad=True``.
519
+
520
+ This function checks that backpropagating through the gradients computed
521
+ to the given :attr:`grad_outputs` are correct.
522
+
523
+ The check between numerical and analytical gradients uses :func:`~torch.allclose`.
524
+
525
+ .. note::
526
+ The default values are designed for :attr:`input` and
527
+ :attr:`grad_outputs` of double precision. This check will likely fail if
528
+ they are of less precision, e.g., ``FloatTensor``.
529
+
530
+ .. warning::
531
+ If any checked tensor in :attr:`input` and :attr:`grad_outputs` has
532
+ overlapping memory, i.e., different indices pointing to the same memory
533
+ address (e.g., from :func:`torch.expand`), this check will likely fail
534
+ because the numerical gradients computed by point perturbation at such
535
+ indices will change values at all other indices that share the same
536
+ memory address.
537
+
538
+ Args:
539
+ func (function): a Python function that takes Tensor inputs and returns
540
+ a Tensor or a tuple of Tensors
541
+ inputs (tuple of Tensor or Tensor): inputs to the function
542
+ grad_outputs (tuple of Tensor or Tensor, optional): The gradients with
543
+ respect to the function's outputs.
544
+ eps (float, optional): perturbation for finite differences
545
+ atol (float, optional): absolute tolerance
546
+ rtol (float, optional): relative tolerance
547
+ gen_non_contig_grad_outputs (bool, optional): if :attr:`grad_outputs` is
548
+ ``None`` and :attr:`gen_non_contig_grad_outputs` is ``True``, the
549
+ randomly generated gradient outputs are made to be noncontiguous
550
+ raise_exception (bool, optional): indicating whether to raise an exception if
551
+ the check fails. The exception gives more information about the
552
+ exact nature of the failure. This is helpful when debugging gradchecks.
553
+ nondet_tol (float, optional): tolerance for non-determinism. When running
554
+ identical inputs through the differentiation, the results must either match
555
+ exactly (default, 0.0) or be within this tolerance. Note that a small amount
556
+ of nondeterminism in the gradient will lead to larger inaccuracies in
557
+ the second derivative.
558
+ check_undefined_grad (bool, options): if True, check if undefined output grads
559
+ are supported and treated as zeros
560
+
561
+ Returns:
562
+ True if all differences satisfy allclose condition
563
+ """
564
+ tupled_inputs = _as_tuple(inputs)
565
+
566
+ if grad_outputs is None:
567
+ # If grad_outputs is not specified, create random Tensors of the same
568
+ # shape, type, and device as the outputs
569
+ def randn_like(x):
570
+ y = torch.testing.randn_like(
571
+ x if (x.is_floating_point() or x.is_complex()) else x.double(), memory_format=torch.legacy_contiguous_format)
572
+ if gen_non_contig_grad_outputs:
573
+ y = torch.testing.make_non_contiguous(y)
574
+ return y.requires_grad_()
575
+ outputs = _as_tuple(func(*tupled_inputs))
576
+ tupled_grad_outputs = tuple(randn_like(x) for x in outputs)
577
+ else:
578
+ tupled_grad_outputs = _as_tuple(grad_outputs)
579
+
580
+ num_outputs = len(tupled_grad_outputs)
581
+
582
+ def new_func(*args):
583
+ input_args = args[:-num_outputs]
584
+ grad_outputs = args[-num_outputs:]
585
+ outputs = _differentiable_outputs(func(*input_args))
586
+ input_args = tuple(x for x in input_args if isinstance(x, torch.Tensor) and x.requires_grad)
587
+ grad_inputs = torch.autograd.grad(outputs, input_args, grad_outputs, create_graph=True)
588
+ return grad_inputs
589
+
590
+ return gradcheck(new_func, tupled_inputs + tupled_grad_outputs, eps, atol, rtol, raise_exception,
591
+ nondet_tol=nondet_tol, check_undefined_grad=check_undefined_grad,
592
+ check_grad_dtypes=check_grad_dtypes)
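
Editor's note: this file is essentially a vendored copy of PyTorch's finite-difference `gradcheck`; the notable local change is the early `return analytical, numerical` inside the main loop, so the caller gets the raw Jacobian pairs back instead of a boolean. A hedged usage sketch with double-precision inputs, as the docstring recommends; the checked function is an arbitrary example of ours.

```python
import torch
from mini_dpvo.lietorch.gradcheck import gradcheck

x = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)

# one analytical/numerical Jacobian per input tensor
analytical, numerical = gradcheck(lambda t: (t * t).sum(dim=-1), (x,), eps=1e-6)

for a, n in zip(analytical, numerical):
    assert torch.allclose(a, n, atol=1e-5)
```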
mini_dpvo/lietorch/group_ops.py ADDED
@@ -0,0 +1,102 @@
1
+ import lietorch_backends
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+
6
+
7
+ class GroupOp(torch.autograd.Function):
8
+ """ group operation base class """
9
+
10
+ @classmethod
11
+ def forward(cls, ctx, group_id, *inputs):
12
+ ctx.group_id = group_id
13
+ ctx.save_for_backward(*inputs)
14
+ out = cls.forward_op(ctx.group_id, *inputs)
15
+ return out
16
+
17
+ @classmethod
18
+ def backward(cls, ctx, grad):
19
+ error_str = "Backward operation not implemented for {}".format(cls)
20
+ assert cls.backward_op is not None, error_str
21
+
22
+ inputs = ctx.saved_tensors
23
+ grad = grad.contiguous()
24
+ grad_inputs = cls.backward_op(ctx.group_id, grad, *inputs)
25
+ return (None, ) + tuple(grad_inputs)
26
+
27
+
28
+ class Exp(GroupOp):
29
+ """ exponential map """
30
+ forward_op, backward_op = lietorch_backends.expm, lietorch_backends.expm_backward
31
+
32
+ class Log(GroupOp):
33
+ """ logarithm map """
34
+ forward_op, backward_op = lietorch_backends.logm, lietorch_backends.logm_backward
35
+
36
+ class Inv(GroupOp):
37
+ """ group inverse """
38
+ forward_op, backward_op = lietorch_backends.inv, lietorch_backends.inv_backward
39
+
40
+ class Mul(GroupOp):
41
+ """ group multiplication """
42
+ forward_op, backward_op = lietorch_backends.mul, lietorch_backends.mul_backward
43
+
44
+ class Adj(GroupOp):
45
+ """ adjoint operator """
46
+ forward_op, backward_op = lietorch_backends.adj, lietorch_backends.adj_backward
47
+
48
+ class AdjT(GroupOp):
49
+ """ adjoint operator """
50
+ forward_op, backward_op = lietorch_backends.adjT, lietorch_backends.adjT_backward
51
+
52
+ class Act3(GroupOp):
53
+ """ action on point """
54
+ forward_op, backward_op = lietorch_backends.act, lietorch_backends.act_backward
55
+
56
+ class Act4(GroupOp):
57
+ """ action on point """
58
+ forward_op, backward_op = lietorch_backends.act4, lietorch_backends.act4_backward
59
+
60
+ class Jinv(GroupOp):
61
+ """ adjoint operator """
62
+ forward_op, backward_op = lietorch_backends.Jinv, None
63
+
64
+ class ToMatrix(GroupOp):
65
+ """ convert to matrix representation """
66
+ forward_op, backward_op = lietorch_backends.as_matrix, None
67
+
68
+
69
+
70
+
71
+ ### conversion operations to/from Euclidean embeddings ###
72
+
73
+ class FromVec(torch.autograd.Function):
74
+ """ convert vector into group object """
75
+
76
+ @classmethod
77
+ def forward(cls, ctx, group_id, *inputs):
78
+ ctx.group_id = group_id
79
+ ctx.save_for_backward(*inputs)
80
+ return inputs[0]
81
+
82
+ @classmethod
83
+ def backward(cls, ctx, grad):
84
+ inputs = ctx.saved_tensors
85
+ J = lietorch_backends.projector(ctx.group_id, *inputs)
86
+ return None, torch.matmul(grad.unsqueeze(-2), torch.linalg.pinv(J)).squeeze(-2)
87
+
88
+ class ToVec(torch.autograd.Function):
89
+ """ convert group object to vector """
90
+
91
+ @classmethod
92
+ def forward(cls, ctx, group_id, *inputs):
93
+ ctx.group_id = group_id
94
+ ctx.save_for_backward(*inputs)
95
+ return inputs[0]
96
+
97
+ @classmethod
98
+ def backward(cls, ctx, grad):
99
+ inputs = ctx.saved_tensors
100
+ J = lietorch_backends.projector(ctx.group_id, *inputs)
101
+ return None, torch.matmul(grad.unsqueeze(-2), J).squeeze(-2)
102
+
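
Editor's note: each class above just binds a forward/backward kernel pair from the compiled `lietorch_backends` extension and is invoked through `torch.autograd.Function.apply`, which is what `LieGroup.apply_op` in `groups.py` (next file) does. A small sketch, assuming the extension is built; group id 3 is SE3, per `groups.py`.

```python
import torch
from mini_dpvo.lietorch.group_ops import Exp

# a zero twist (one 6-dof tangent vector) maps to the identity pose
xi = torch.zeros(1, 6, dtype=torch.float32)
g = Exp.apply(3, xi)   # layout: tx ty tz qx qy qz qw
print(g)               # tensor([[0., 0., 0., 0., 0., 0., 1.]])
```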
mini_dpvo/lietorch/groups.py ADDED
@@ -0,0 +1,322 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ # group operations implemented in cuda
5
+ from .group_ops import Exp, Log, Inv, Mul, Adj, AdjT, Jinv, Act3, Act4, ToMatrix, ToVec, FromVec
6
+ from .broadcasting import broadcast_inputs
7
+
8
+
9
+ class LieGroupParameter(torch.Tensor):
10
+ """ Wrapper class for LieGroup """
11
+
12
+ from torch._C import _disabled_torch_function_impl
13
+ __torch_function__ = _disabled_torch_function_impl
14
+
15
+ def __new__(cls, group, requires_grad=True):
16
+ data = torch.zeros(group.tangent_shape,
17
+ device=group.data.device,
18
+ dtype=group.data.dtype,
19
+ requires_grad=True)
20
+
21
+ return torch.Tensor._make_subclass(cls, data, requires_grad)
22
+
23
+ def __init__(self, group):
24
+ self.group = group
25
+
26
+ def retr(self):
27
+ return self.group.retr(self)
28
+
29
+ def log(self):
30
+ return self.retr().log()
31
+
32
+ def inv(self):
33
+ return self.retr().inv()
34
+
35
+ def adj(self, a):
36
+ return self.retr().adj(a)
37
+
38
+ def __mul__(self, other):
39
+ if isinstance(other, LieGroupParameter):
40
+ return self.retr() * other.retr()
41
+ else:
42
+ return self.retr() * other
43
+
44
+ def add_(self, update, alpha):
45
+ self.group = self.group.exp(alpha*update) * self.group
46
+
47
+ def __getitem__(self, index):
48
+ return self.retr().__getitem__(index)
49
+
50
+
51
+ class LieGroup:
52
+ """ Base class for Lie Group """
53
+
54
+ def __init__(self, data):
55
+ self.data = data
56
+
57
+ def __repr__(self):
58
+ return "{}: size={}, device={}, dtype={}".format(
59
+ self.group_name, self.shape, self.device, self.dtype)
60
+
61
+ @property
62
+ def shape(self):
63
+ return self.data.shape[:-1]
64
+
65
+ @property
66
+ def device(self):
67
+ return self.data.device
68
+
69
+ @property
70
+ def dtype(self):
71
+ return self.data.dtype
72
+
73
+ def vec(self):
74
+ return self.apply_op(ToVec, self.data)
75
+
76
+ @property
77
+ def tangent_shape(self):
78
+ return self.data.shape[:-1] + (self.manifold_dim,)
79
+
80
+ @classmethod
81
+ def Identity(cls, *batch_shape, **kwargs):
82
+ """ Construct identity element with batch shape """
83
+
84
+ if isinstance(batch_shape[0], tuple):
85
+ batch_shape = batch_shape[0]
86
+
87
+ elif isinstance(batch_shape[0], list):
88
+ batch_shape = tuple(batch_shape[0])
89
+
90
+ numel = np.prod(batch_shape)
91
+ data = cls.id_elem.reshape(1,-1)
92
+
93
+ if 'device' in kwargs:
94
+ data = data.to(kwargs['device'])
95
+
96
+ if 'dtype' in kwargs:
97
+ data = data.type(kwargs['dtype'])
98
+
99
+ data = data.repeat(numel, 1)
100
+ return cls(data).view(batch_shape)
101
+
102
+ @classmethod
103
+ def IdentityLike(cls, G):
104
+ return cls.Identity(G.shape, device=G.data.device, dtype=G.data.dtype)
105
+
106
+ @classmethod
107
+ def InitFromVec(cls, data):
108
+ return cls(cls.apply_op(FromVec, data))
109
+
110
+ @classmethod
111
+ def Random(cls, *batch_shape, sigma=1.0, **kwargs):
112
+ """ Construct random element with batch_shape by random sampling in tangent space"""
113
+
114
+ if isinstance(batch_shape[0], tuple):
115
+ batch_shape = batch_shape[0]
116
+
117
+ elif isinstance(batch_shape[0], list):
118
+ batch_shape = tuple(batch_shape[0])
119
+
120
+ tangent_shape = batch_shape + (cls.manifold_dim,)
121
+ xi = torch.randn(tangent_shape, **kwargs)
122
+ return cls.exp(sigma * xi)
123
+
124
+ @classmethod
125
+ def apply_op(cls, op, x, y=None):
126
+ """ Apply group operator """
127
+ inputs, out_shape = broadcast_inputs(x, y)
128
+
129
+ data = op.apply(cls.group_id, *inputs)
130
+ return data.view(out_shape + (-1,))
131
+
132
+ @classmethod
133
+ def exp(cls, x):
134
+ """ exponential map: x -> X """
135
+ return cls(cls.apply_op(Exp, x))
136
+
137
+ def quaternion(self):
138
+ """ extract quaternion """
139
+ return self.apply_op(Quat, self.data)
140
+
141
+ def log(self):
142
+ """ logarithm map """
143
+ return self.apply_op(Log, self.data)
144
+
145
+ def inv(self):
146
+ """ group inverse """
147
+ return self.__class__(self.apply_op(Inv, self.data))
148
+
149
+ def mul(self, other):
150
+ """ group multiplication """
151
+ return self.__class__(self.apply_op(Mul, self.data, other.data))
152
+
153
+ def retr(self, a):
154
+ """ retraction: Exp(a) * X """
155
+ dX = self.__class__.apply_op(Exp, a)
156
+ return self.__class__(self.apply_op(Mul, dX, self.data))
157
+
158
+ def adj(self, a):
159
+ """ adjoint operator: b = A(X) * a """
160
+ return self.apply_op(Adj, self.data, a)
161
+
162
+ def adjT(self, a):
163
+ """ transposed adjoint operator: b = a * A(X) """
164
+ return self.apply_op(AdjT, self.data, a)
165
+
166
+ def Jinv(self, a):
167
+ return self.apply_op(Jinv, self.data, a)
168
+
169
+ def act(self, p):
170
+ """ action on a point cloud """
171
+
172
+ # action on point
173
+ if p.shape[-1] == 3:
174
+ return self.apply_op(Act3, self.data, p)
175
+
176
+ # action on homogeneous point
177
+ elif p.shape[-1] == 4:
178
+ return self.apply_op(Act4, self.data, p)
179
+
180
+ def matrix(self):
181
+ """ convert element to 4x4 matrix """
182
+ I = torch.eye(4, dtype=self.dtype, device=self.device)
183
+ I = I.view([1] * (len(self.data.shape) - 1) + [4, 4])
184
+ return self.__class__(self.data[...,None,:]).act(I).transpose(-1,-2)
185
+
186
+ def translation(self):
187
+ """ extract translation component """
188
+ p = torch.as_tensor([0.0, 0.0, 0.0, 1.0], dtype=self.dtype, device=self.device)
189
+ p = p.view([1] * (len(self.data.shape) - 1) + [4,])
190
+ return self.apply_op(Act4, self.data, p)
191
+
192
+ def detach(self):
193
+ return self.__class__(self.data.detach())
194
+
195
+ def view(self, dims):
196
+ data_reshaped = self.data.view(dims + (self.embedded_dim,))
197
+ return self.__class__(data_reshaped)
198
+
199
+ def __mul__(self, other):
200
+ # group multiplication
201
+
202
+ if isinstance(other, LieGroup):
203
+ return self.mul(other)
204
+
205
+ # action on point
206
+ elif isinstance(other, torch.Tensor):
207
+ return self.act(other)
208
+
209
+ def __getitem__(self, index):
210
+ return self.__class__(self.data[index])
211
+
212
+ def __setitem__(self, index, item):
213
+ self.data[index] = item.data
214
+
215
+ def to(self, *args, **kwargs):
216
+ return self.__class__(self.data.to(*args, **kwargs))
217
+
218
+ def cpu(self):
219
+ return self.__class__(self.data.cpu())
220
+
221
+ def cuda(self):
222
+ return self.__class__(self.data.cuda())
223
+
224
+ def float(self, device):
225
+ return self.__class__(self.data.float())
226
+
227
+ def double(self, device):
228
+ return self.__class__(self.data.double())
229
+
230
+ def unbind(self, dim=0):
231
+ return [self.__class__(x) for x in self.data.unbind(dim=dim)]
232
+
233
+
234
+ class SO3(LieGroup):
235
+ group_name = 'SO3'
236
+ group_id = 1
237
+ manifold_dim = 3
238
+ embedded_dim = 4
239
+
240
+ # unit quaternion
241
+ id_elem = torch.as_tensor([0.0, 0.0, 0.0, 1.0])
242
+
243
+ def __init__(self, data):
244
+ if isinstance(data, SE3):
245
+ data = data.data[..., 3:7]
246
+
247
+ super(SO3, self).__init__(data)
248
+
249
+
250
+ class RxSO3(LieGroup):
251
+ group_name = 'RxSO3'
252
+ group_id = 2
253
+ manifold_dim = 4
254
+ embedded_dim = 5
255
+
256
+ # unit quaternion
257
+ id_elem = torch.as_tensor([0.0, 0.0, 0.0, 1.0, 1.0])
258
+
259
+ def __init__(self, data):
260
+ if isinstance(data, Sim3):
261
+ data = data.data[..., 3:8]
262
+
263
+ super(RxSO3, self).__init__(data)
264
+
265
+
266
+ class SE3(LieGroup):
267
+ group_name = 'SE3'
268
+ group_id = 3
269
+ manifold_dim = 6
270
+ embedded_dim = 7
271
+
272
+ # translation, unit quaternion
273
+ id_elem = torch.as_tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0])
274
+
275
+ def __init__(self, data):
276
+ if isinstance(data, SO3):
277
+ translation = torch.zeros_like(data.data[...,:3])
278
+ data = torch.cat([translation, data.data], -1)
279
+
280
+ super(SE3, self).__init__(data)
281
+
282
+ def scale(self, s):
283
+ t, q = self.data.split([3,4], -1)
284
+ t = t * s.unsqueeze(-1)
285
+ return SE3(torch.cat([t, q], dim=-1))
286
+
287
+
288
+ class Sim3(LieGroup):
289
+ group_name = 'Sim3'
290
+ group_id = 4
291
+ manifold_dim = 7
292
+ embedded_dim = 8
293
+
294
+ # translation, unit quaternion, scale
295
+ id_elem = torch.as_tensor([0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0])
296
+
297
+ def __init__(self, data):
298
+
299
+ if isinstance(data, SO3):
300
+ scale = torch.ones_like(data.data[...,:1])
301
+ translation = torch.zeros_like(data.data[...,:3])
302
+ data = torch.cat([translation, data.data, scale], -1)
303
+
304
+ elif isinstance(data, SE3):
305
+ scale = torch.ones_like(data.data[...,:1])
306
+ data = torch.cat([data.data, scale], -1)
307
+
308
+ elif isinstance(data, Sim3):
309
+ data = data.data
310
+
311
+ super(Sim3, self).__init__(data)
312
+
313
+
314
+ def cat(group_list, dim):
315
+ """ Concatenate groups along dimension """
316
+ data = torch.cat([X.data for X in group_list], dim=dim)
317
+ return group_list[0].__class__(data)
318
+
319
+ def stack(group_list, dim):
320
+ """ Concatenate groups along dimension """
321
+ data = torch.stack([X.data for X in group_list], dim=dim)
322
+ return group_list[0].__class__(data)
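
Editor's note: to make the interface concrete, each group object wraps a `(..., 7)` data tensor for SE3 (translation plus unit quaternion) and defers all math to the compiled extension. A short usage sketch, assuming `lietorch_backends` is built; the shapes are illustrative.

```python
import torch
from mini_dpvo.lietorch import SE3

T = SE3.Random(2, 8, dtype=torch.float64)       # a (2, 8) batch of random poses
p = torch.randn(2, 8, 3, dtype=torch.float64)   # one 3D point per pose

q = T * p                    # SE3 acting on points -> (2, 8, 3)
dT = T[1:] * T[:1].inv()     # relative poses between the two sub-batches
xi = dT.log()                # 6-dof tangent vectors -> (1, 8, 6)
T1 = SE3.exp(xi) * T[:1]     # retraction; recovers T[1:] up to numerics
```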
mini_dpvo/lietorch/include/common.h ADDED
@@ -0,0 +1,12 @@
1
+ #ifndef COMMON_H
2
+ #define COMMON_H
3
+
4
+ #define EIGEN_DEFAULT_DENSE_INDEX_TYPE int
5
+ #define EIGEN_RUNTIME_NO_MALLOC
6
+
7
+ #define EPS 1e-6
8
+ #define PI 3.14159265358979323846
9
+
10
+
11
+ #endif
12
+
mini_dpvo/lietorch/include/dispatch.h ADDED
@@ -0,0 +1,48 @@
1
+ #ifndef DISPATCH_H
2
+ #define DISPATCH_H
3
+
4
+ #include <torch/extension.h>
5
+
6
+ #include "so3.h"
7
+ #include "rxso3.h"
8
+ #include "se3.h"
9
+ #include "sim3.h"
10
+
11
+
12
+ #define PRIVATE_CASE_TYPE(group_index, enum_type, type, ...) \
13
+ case enum_type: { \
14
+ using scalar_t = type; \
15
+ switch (group_index) { \
16
+ case 1: { \
17
+ using group_t = SO3<type>; \
18
+ return __VA_ARGS__(); \
19
+ } \
20
+ case 2: { \
21
+ using group_t = RxSO3<type>; \
22
+ return __VA_ARGS__(); \
23
+ } \
24
+ case 3: { \
25
+ using group_t = SE3<type>; \
26
+ return __VA_ARGS__(); \
27
+ } \
28
+ case 4: { \
29
+ using group_t = Sim3<type>; \
30
+ return __VA_ARGS__(); \
31
+ } \
32
+ } \
33
+ } \
34
+
35
+ #define DISPATCH_GROUP_AND_FLOATING_TYPES(GROUP_INDEX, TYPE, NAME, ...) \
36
+ [&] { \
37
+ const auto& the_type = TYPE; \
38
+ /* don't use TYPE again in case it is an expensive or side-effect op */ \
39
+ at::ScalarType _st = ::detail::scalar_type(the_type); \
40
+ switch (_st) { \
41
+ PRIVATE_CASE_TYPE(GROUP_INDEX, at::ScalarType::Double, double, __VA_ARGS__) \
42
+ PRIVATE_CASE_TYPE(GROUP_INDEX, at::ScalarType::Float, float, __VA_ARGS__) \
43
+ default: break; \
44
+ } \
45
+ }()
46
+
47
+ #endif
48
+
mini_dpvo/lietorch/include/lietorch_cpu.h ADDED
@@ -0,0 +1,51 @@
1
+
2
+ #ifndef LIETORCH_CPU_H_
3
+ #define LIETORCH_CPU_H_
4
+
5
+ #include <vector>
6
+ #include <torch/extension.h>
7
+
8
+
9
+ // unary operations
10
+ torch::Tensor exp_forward_cpu(int, torch::Tensor);
11
+ std::vector<torch::Tensor> exp_backward_cpu(int, torch::Tensor, torch::Tensor);
12
+
13
+ torch::Tensor log_forward_cpu(int, torch::Tensor);
14
+ std::vector<torch::Tensor> log_backward_cpu(int, torch::Tensor, torch::Tensor);
15
+
16
+ torch::Tensor inv_forward_cpu(int, torch::Tensor);
17
+ std::vector<torch::Tensor> inv_backward_cpu(int, torch::Tensor, torch::Tensor);
18
+
19
+ // binary operations
20
+ torch::Tensor mul_forward_cpu(int, torch::Tensor, torch::Tensor);
21
+ std::vector<torch::Tensor> mul_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
22
+
23
+ torch::Tensor adj_forward_cpu(int, torch::Tensor, torch::Tensor);
24
+ std::vector<torch::Tensor> adj_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
25
+
26
+ torch::Tensor adjT_forward_cpu(int, torch::Tensor, torch::Tensor);
27
+ std::vector<torch::Tensor> adjT_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
28
+
29
+ torch::Tensor act_forward_cpu(int, torch::Tensor, torch::Tensor);
30
+ std::vector<torch::Tensor> act_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
31
+
32
+ torch::Tensor act4_forward_cpu(int, torch::Tensor, torch::Tensor);
33
+ std::vector<torch::Tensor> act4_backward_cpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
34
+
35
+
36
+ // conversion operations
37
+ // std::vector<torch::Tensor> to_vec_backward_cpu(int, torch::Tensor, torch::Tensor);
38
+ // std::vector<torch::Tensor> from_vec_backward_cpu(int, torch::Tensor, torch::Tensor);
39
+
40
+ // utility operations
41
+ torch::Tensor orthogonal_projector_cpu(int, torch::Tensor);
42
+
43
+ torch::Tensor as_matrix_forward_cpu(int, torch::Tensor);
44
+
45
+ torch::Tensor jleft_forward_cpu(int, torch::Tensor, torch::Tensor);
46
+
47
+
48
+ #endif
49
+
50
+
51
+
mini_dpvo/lietorch/include/lietorch_gpu.h ADDED
@@ -0,0 +1,51 @@
1
+
2
+ #ifndef LIETORCH_GPU_H_
3
+ #define LIETORCH_GPU_H_
4
+
5
+ #include <vector>
6
+ #include <torch/extension.h>
7
+ #include <cuda.h>
8
+ #include <cuda_runtime.h>
9
+
10
+
11
+ // unary operations
12
+ torch::Tensor exp_forward_gpu(int, torch::Tensor);
13
+ std::vector<torch::Tensor> exp_backward_gpu(int, torch::Tensor, torch::Tensor);
14
+
15
+ torch::Tensor log_forward_gpu(int, torch::Tensor);
16
+ std::vector<torch::Tensor> log_backward_gpu(int, torch::Tensor, torch::Tensor);
17
+
18
+ torch::Tensor inv_forward_gpu(int, torch::Tensor);
19
+ std::vector<torch::Tensor> inv_backward_gpu(int, torch::Tensor, torch::Tensor);
20
+
21
+ // binary operations
22
+ torch::Tensor mul_forward_gpu(int, torch::Tensor, torch::Tensor);
23
+ std::vector<torch::Tensor> mul_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
24
+
25
+ torch::Tensor adj_forward_gpu(int, torch::Tensor, torch::Tensor);
26
+ std::vector<torch::Tensor> adj_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
27
+
28
+ torch::Tensor adjT_forward_gpu(int, torch::Tensor, torch::Tensor);
29
+ std::vector<torch::Tensor> adjT_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
30
+
31
+ torch::Tensor act_forward_gpu(int, torch::Tensor, torch::Tensor);
32
+ std::vector<torch::Tensor> act_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
33
+
34
+ torch::Tensor act4_forward_gpu(int, torch::Tensor, torch::Tensor);
35
+ std::vector<torch::Tensor> act4_backward_gpu(int, torch::Tensor, torch::Tensor, torch::Tensor);
36
+
37
+ // conversion operations
38
+ // std::vector<torch::Tensor> to_vec_backward_gpu(int, torch::Tensor, torch::Tensor);
39
+ // std::vector<torch::Tensor> from_vec_backward_gpu(int, torch::Tensor, torch::Tensor);
40
+
41
+ // utility operators
42
+ torch::Tensor orthogonal_projector_gpu(int, torch::Tensor);
43
+
44
+ torch::Tensor as_matrix_forward_gpu(int, torch::Tensor);
45
+
46
+ torch::Tensor jleft_forward_gpu(int, torch::Tensor, torch::Tensor);
47
+
48
+ #endif
49
+
50
+
51
+
mini_dpvo/lietorch/include/rxso3.h ADDED
@@ -0,0 +1,324 @@
1
+
2
+ #ifndef RxSO3_HEADER
3
+ #define RxSO3_HEADER
4
+
5
+ #include <stdio.h>
6
+ #include <Eigen/Dense>
7
+ #include <Eigen/Geometry>
8
+
9
+ #include "common.h"
10
+
11
+ template <typename Scalar>
12
+ class RxSO3 {
13
+ public:
14
+ const static int constexpr K = 4; // manifold dimension
15
+ const static int constexpr N = 5; // embedding dimension
16
+
17
+ using Vector3 = Eigen::Matrix<Scalar,3,1>;
18
+ using Vector4 = Eigen::Matrix<Scalar,4,1>;
19
+ using Matrix3 = Eigen::Matrix<Scalar,3,3>;
20
+
21
+ using Tangent = Eigen::Matrix<Scalar,K,1>;
22
+ using Data = Eigen::Matrix<Scalar,N,1>;
23
+
24
+ using Point = Eigen::Matrix<Scalar,3,1>;
25
+ using Point4 = Eigen::Matrix<Scalar,4,1>;
26
+
27
+ using Quaternion = Eigen::Quaternion<Scalar>;
28
+ using Transformation = Eigen::Matrix<Scalar,3,3>;
29
+ using Adjoint = Eigen::Matrix<Scalar,4,4>;
30
+
31
+
32
+ EIGEN_DEVICE_FUNC RxSO3(Quaternion const& q, Scalar const s)
33
+ : unit_quaternion(q), scale(s) {
34
+ unit_quaternion.normalize();
35
+ };
36
+
37
+ EIGEN_DEVICE_FUNC RxSO3(const Scalar *data) : unit_quaternion(data), scale(data[4]) {
38
+ unit_quaternion.normalize();
39
+ };
40
+
41
+ EIGEN_DEVICE_FUNC RxSO3() {
42
+ unit_quaternion = Quaternion::Identity();
43
+ scale = Scalar(1.0);
44
+ }
45
+
46
+ EIGEN_DEVICE_FUNC RxSO3<Scalar> inv() {
47
+ return RxSO3<Scalar>(unit_quaternion.conjugate(), 1.0/scale);
48
+ }
49
+
50
+ EIGEN_DEVICE_FUNC Data data() const {
51
+ Data data_vec; data_vec << unit_quaternion.coeffs(), scale;
52
+ return data_vec;
53
+ }
54
+
55
+ EIGEN_DEVICE_FUNC RxSO3<Scalar> operator*(RxSO3<Scalar> const& other) {
56
+ return RxSO3<Scalar>(unit_quaternion * other.unit_quaternion, scale * other.scale);
57
+ }
58
+
59
+ EIGEN_DEVICE_FUNC Point operator*(Point const& p) const {
60
+ const Quaternion& q = unit_quaternion;
61
+ Point uv = q.vec().cross(p); uv += uv;
62
+ return scale * (p + q.w()*uv + q.vec().cross(uv));
63
+ }
64
+
65
+ EIGEN_DEVICE_FUNC Point4 act4(Point4 const& p) const {
66
+ Point4 p1; p1 << this->operator*(p.template segment<3>(0)), p(3);
67
+ return p1;
68
+ }
69
+
70
+ EIGEN_DEVICE_FUNC Adjoint Adj() const {
71
+ Adjoint Ad = Adjoint::Identity();
72
+ Ad.template block<3,3>(0,0) = unit_quaternion.toRotationMatrix();
73
+ return Ad;
74
+ }
75
+
76
+ EIGEN_DEVICE_FUNC Transformation Matrix() const {
77
+ return scale * unit_quaternion.toRotationMatrix();
78
+ }
79
+
80
+ EIGEN_DEVICE_FUNC Eigen::Matrix<Scalar,4,4> Matrix4x4() const {
81
+ Eigen::Matrix<Scalar,4,4> T;
82
+ T = Eigen::Matrix<Scalar,4,4>::Identity();
83
+ T.template block<3,3>(0,0) = Matrix();
84
+ return T;
85
+ }
86
+
87
+ EIGEN_DEVICE_FUNC Eigen::Matrix<Scalar,5,5> orthogonal_projector() const {
88
+ // jacobian action on a point
89
+ Eigen::Matrix<Scalar,5,5> J = Eigen::Matrix<Scalar,5,5>::Zero();
90
+
91
+ J.template block<3,3>(0,0) = 0.5 * (
92
+ unit_quaternion.w() * Matrix3::Identity() +
93
+ SO3<Scalar>::hat(-unit_quaternion.vec())
94
+ );
95
+
96
+ J.template block<1,3>(3,0) = 0.5 * (-unit_quaternion.vec());
97
+
98
+ // scale
99
+ J(4,3) = scale;
100
+
101
+ return J;
102
+ }
103
+
104
+ EIGEN_DEVICE_FUNC Transformation Rotation() const {
105
+ return unit_quaternion.toRotationMatrix();
106
+ }
107
+
108
+ EIGEN_DEVICE_FUNC Tangent Adj(Tangent const& a) const {
109
+ return Adj() * a;
110
+ }
111
+
112
+ EIGEN_DEVICE_FUNC Tangent AdjT(Tangent const& a) const {
113
+ return Adj().transpose() * a;
114
+ }
115
+
116
+ EIGEN_DEVICE_FUNC static Transformation hat(Tangent const& phi_sigma) {
117
+ Vector3 const phi = phi_sigma.template segment<3>(0);
118
+ return SO3<Scalar>::hat(phi) + phi(3) * Transformation::Identity();
119
+ }
120
+
121
+ EIGEN_DEVICE_FUNC static Adjoint adj(Tangent const& phi_sigma) {
122
+ Vector3 const phi = phi_sigma.template segment<3>(0);
123
+ Matrix3 const Phi = SO3<Scalar>::hat(phi);
124
+
125
+ Adjoint ad = Adjoint::Zero();
126
+ ad.template block<3,3>(0,0) = Phi;
127
+
128
+ return ad;
129
+ }
130
+
131
+ EIGEN_DEVICE_FUNC Tangent Log() const {
132
+ using std::abs;
133
+ using std::atan;
134
+ using std::sqrt;
135
+
136
+ Scalar squared_n = unit_quaternion.vec().squaredNorm();
137
+ Scalar w = unit_quaternion.w();
138
+ Scalar two_atan_nbyw_by_n;
139
+
140
+ /// Atan-based log thanks to
141
+ ///
142
+ /// C. Hertzberg et al.:
143
+ /// "Integrating Generic Sensor Fusion Algorithms with Sound State
144
+ /// Representation through Encapsulation of Manifolds"
145
+ /// Information Fusion, 2011
146
+
147
+ if (squared_n < EPS * EPS) {
148
+ two_atan_nbyw_by_n = Scalar(2) / w - Scalar(2.0/3.0) * (squared_n) / (w * w * w);
149
+ } else {
150
+ Scalar n = sqrt(squared_n);
151
+ if (abs(w) < EPS) {
152
+ if (w > Scalar(0)) {
153
+ two_atan_nbyw_by_n = PI / n;
154
+ } else {
155
+ two_atan_nbyw_by_n = -PI / n;
156
+ }
157
+ } else {
158
+ two_atan_nbyw_by_n = Scalar(2) * atan(n / w) / n;
159
+ }
160
+ }
161
+
162
+ Tangent phi_sigma;
163
+ phi_sigma << two_atan_nbyw_by_n * unit_quaternion.vec(), log(scale);
164
+
165
+ return phi_sigma;
166
+ }
167
+
168
+ EIGEN_DEVICE_FUNC static RxSO3<Scalar> Exp(Tangent const& phi_sigma) {
169
+ Vector3 phi = phi_sigma.template segment<3>(0);
170
+ Scalar scale = exp(phi_sigma(3));
171
+
172
+ Scalar theta2 = phi.squaredNorm();
173
+ Scalar theta = sqrt(theta2);
174
+ Scalar imag_factor;
175
+ Scalar real_factor;
176
+
177
+ if (theta < EPS) {
178
+ Scalar theta4 = theta2 * theta2;
179
+ imag_factor = Scalar(0.5) - Scalar(1.0/48.0) * theta2 + Scalar(1.0/3840.0) * theta4;
180
+ real_factor = Scalar(1) - Scalar(1.0/8.0) * theta2 + Scalar(1.0/384.0) * theta4;
181
+ } else {
182
+ imag_factor = sin(.5 * theta) / theta;
183
+ real_factor = cos(.5 * theta);
184
+ }
185
+
186
+ Quaternion q(real_factor, imag_factor*phi.x(), imag_factor*phi.y(), imag_factor*phi.z());
187
+ return RxSO3<Scalar>(q, scale);
188
+ }
189
+
190
+ EIGEN_DEVICE_FUNC static Matrix3 calcW(Tangent const& phi_sigma) {
191
+ // left jacobian
192
+ using std::abs;
193
+ Matrix3 const I = Matrix3::Identity();
194
+ Scalar const one(1);
195
+ Scalar const half(0.5);
196
+
197
+ Vector3 const phi = phi_sigma.template segment<3>(0);
198
+ Scalar const sigma = phi_sigma(3);
199
+ Scalar const theta = phi.norm();
200
+
201
+ Matrix3 const Phi = SO3<Scalar>::hat(phi);
202
+ Matrix3 const Phi2 = Phi * Phi;
203
+ Scalar const scale = exp(sigma);
204
+
205
+ Scalar A, B, C;
206
+ if (abs(sigma) < EPS) {
207
+ C = one;
208
+ if (abs(theta) < EPS) {
209
+ A = half;
210
+ B = Scalar(1. / 6.);
211
+ } else {
212
+ Scalar theta_sq = theta * theta;
213
+ A = (one - cos(theta)) / theta_sq;
214
+ B = (theta - sin(theta)) / (theta_sq * theta);
215
+ }
216
+ } else {
217
+ C = (scale - one) / sigma;
218
+ if (abs(theta) < EPS) {
219
+ Scalar sigma_sq = sigma * sigma;
220
+ A = ((sigma - one) * scale + one) / sigma_sq;
221
+ B = (scale * half * sigma_sq + scale - one - sigma * scale) /
222
+ (sigma_sq * sigma);
223
+ } else {
224
+ Scalar theta_sq = theta * theta;
225
+ Scalar a = scale * sin(theta);
226
+ Scalar b = scale * cos(theta);
227
+ Scalar c = theta_sq + sigma * sigma;
228
+ A = (a * sigma + (one - b) * theta) / (theta * c);
229
+ B = (C - ((b - one) * sigma + a * theta) / (c)) * one / (theta_sq);
230
+ }
231
+ }
232
+ return A * Phi + B * Phi2 + C * I;
233
+ }
234
+
235
+ EIGEN_DEVICE_FUNC static Matrix3 calcWInv(Tangent const& phi_sigma) {
236
+ // left jacobian inverse
237
+ Matrix3 const I = Matrix3::Identity();
238
+ Scalar const half(0.5);
239
+ Scalar const one(1);
240
+ Scalar const two(2);
241
+
242
+ Vector3 const phi = phi_sigma.template segment<3>(0);
243
+ Scalar const sigma = phi_sigma(3);
244
+ Scalar const theta = phi.norm();
245
+ Scalar const scale = exp(sigma);
246
+
247
+ Matrix3 const Phi = SO3<Scalar>::hat(phi);
248
+ Matrix3 const Phi2 = Phi * Phi;
249
+ Scalar const scale_sq = scale * scale;
250
+ Scalar const theta_sq = theta * theta;
251
+ Scalar const sin_theta = sin(theta);
252
+ Scalar const cos_theta = cos(theta);
253
+
254
+ Scalar a, b, c;
255
+ if (abs(sigma * sigma) < EPS) {
256
+ c = one - half * sigma;
257
+ a = -half;
258
+ if (abs(theta_sq) < EPS) {
259
+ b = Scalar(1. / 12.);
260
+ } else {
261
+ b = (theta * sin_theta + two * cos_theta - two) /
262
+ (two * theta_sq * (cos_theta - one));
263
+ }
264
+ } else {
265
+ Scalar const scale_cu = scale_sq * scale;
266
+ c = sigma / (scale - one);
267
+ if (abs(theta_sq) < EPS) {
268
+ a = (-sigma * scale + scale - one) / ((scale - one) * (scale - one));
269
+ b = (scale_sq * sigma - two * scale_sq + scale * sigma + two * scale) /
270
+ (two * scale_cu - Scalar(6) * scale_sq + Scalar(6) * scale - two);
271
+ } else {
272
+ Scalar const s_sin_theta = scale * sin_theta;
273
+ Scalar const s_cos_theta = scale * cos_theta;
274
+ a = (theta * s_cos_theta - theta - sigma * s_sin_theta) /
275
+ (theta * (scale_sq - two * s_cos_theta + one));
276
+ b = -scale *
277
+ (theta * s_sin_theta - theta * sin_theta + sigma * s_cos_theta -
278
+ scale * sigma + sigma * cos_theta - sigma) /
279
+ (theta_sq * (scale_cu - two * scale * s_cos_theta - scale_sq +
280
+ two * s_cos_theta + scale - one));
281
+ }
282
+ }
283
+ return a * Phi + b * Phi2 + c * I;
284
+ }
285
+
286
+ EIGEN_DEVICE_FUNC static Adjoint left_jacobian(Tangent const& phi_sigma) {
287
+ // left jacobian
288
+ Adjoint J = Adjoint::Identity();
289
+ Vector3 phi = phi_sigma.template segment<3>(0);
290
+ J.template block<3,3>(0,0) = SO3<Scalar>::left_jacobian(phi);
291
+ return J;
292
+ }
293
+
294
+ EIGEN_DEVICE_FUNC static Adjoint left_jacobian_inverse(Tangent const& phi_sigma) {
295
+ // left jacobian inverse
296
+ Adjoint Jinv = Adjoint::Identity();
297
+ Vector3 phi = phi_sigma.template segment<3>(0);
298
+ Jinv.template block<3,3>(0,0) = SO3<Scalar>::left_jacobian_inverse(phi);
299
+ return Jinv;
300
+ }
301
+
302
+ EIGEN_DEVICE_FUNC static Eigen::Matrix<Scalar,3,4> act_jacobian(Point const& p) {
303
+ // jacobian action on a point
304
+ Eigen::Matrix<Scalar,3,4> Ja;
305
+ Ja << SO3<Scalar>::hat(-p), p;
306
+ return Ja;
307
+ }
308
+
309
+ EIGEN_DEVICE_FUNC static Eigen::Matrix<Scalar,4,4> act4_jacobian(Point4 const& p) {
310
+ // jacobian action on a point
311
+ Eigen::Matrix<Scalar,4,4> J = Eigen::Matrix<Scalar,4,4>::Zero();
312
+ J.template block<3,3>(0,0) = SO3<Scalar>::hat(-p.template segment<3>(0));
313
+ J.template block<3,1>(0,3) = p.template segment<3>(0);
314
+ return J;
315
+ }
316
+
317
+ private:
318
+ Quaternion unit_quaternion;
319
+ Scalar scale;
320
+ };
321
+
322
+ #endif
323
+
324
+
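The RxSO3 header above backs the RxSO3 group used from Python later in this commit (run_tests.py). A minimal round-trip sketch, assuming the built lietorch extension is importable:

import torch
from lietorch import RxSO3

# tangent vector [phi_x, phi_y, phi_z, log_scale]; the last entry is the log of the scale
a = torch.tensor([[0.1, -0.2, 0.3, 0.5]], dtype=torch.float64)

X = RxSO3.exp(a)                                  # Exp: tangent -> (unit quaternion, scale)
assert torch.allclose(X.log(), a, atol=1e-8)      # Log inverts Exp, as in the C++ Log() above

p = torch.randn(1, 3, dtype=torch.float64)
q = X.act(p)                                      # rotate and scale p, matching operator*(Point)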
mini_dpvo/lietorch/include/se3.h ADDED
@@ -0,0 +1,229 @@
1
+
2
+ #ifndef SE3_HEADER
3
+ #define SE3_HEADER
4
+
5
+ #include <stdio.h>
6
+ #include <Eigen/Dense>
7
+ #include <Eigen/Geometry>
8
+
9
+ #include "common.h"
10
+ #include "so3.h"
11
+
12
+
13
+ template <typename Scalar>
14
+ class SE3 {
15
+ public:
16
+ const static int constexpr K = 6; // manifold dimension
17
+ const static int constexpr N = 7; // embedding dimension
18
+
19
+ using Vector3 = Eigen::Matrix<Scalar,3,1>;
20
+ using Vector4 = Eigen::Matrix<Scalar,4,1>;
21
+ using Matrix3 = Eigen::Matrix<Scalar,3,3>;
22
+
23
+ using Tangent = Eigen::Matrix<Scalar,K,1>;
24
+ using Point = Eigen::Matrix<Scalar,3,1>;
25
+ using Point4 = Eigen::Matrix<Scalar,4,1>;
26
+ using Data = Eigen::Matrix<Scalar,N,1>;
27
+ using Transformation = Eigen::Matrix<Scalar,4,4>;
28
+ using Adjoint = Eigen::Matrix<Scalar,K,K>;
29
+
30
+ EIGEN_DEVICE_FUNC SE3() { translation = Vector3::Zero(); }
31
+
32
+ EIGEN_DEVICE_FUNC SE3(SO3<Scalar> const& so3, Vector3 const& t) : so3(so3), translation(t) {};
33
+
34
+ EIGEN_DEVICE_FUNC SE3(const Scalar *data) : translation(data), so3(data+3) {};
35
+
36
+ EIGEN_DEVICE_FUNC SE3<Scalar> inv() {
37
+ return SE3(so3.inv(), -(so3.inv()*translation));
38
+ }
39
+
40
+ EIGEN_DEVICE_FUNC Data data() const {
41
+ Data data_vec; data_vec << translation, so3.data();
42
+ return data_vec;
43
+ }
44
+
45
+ EIGEN_DEVICE_FUNC SE3<Scalar> operator*(SE3<Scalar> const& other) {
46
+ return SE3(so3 * other.so3, translation + so3 * other.translation);
47
+ }
48
+
49
+ EIGEN_DEVICE_FUNC Point operator*(Point const& p) const {
50
+ return so3 * p + translation;
51
+ }
52
+
53
+ EIGEN_DEVICE_FUNC Point4 act4(Point4 const& p) const {
54
+ Point4 p1; p1 << so3 * p.template segment<3>(0) + translation * p(3), p(3);
55
+ return p1;
56
+ }
57
+
58
+ EIGEN_DEVICE_FUNC Adjoint Adj() const {
59
+ Matrix3 R = so3.Matrix();
60
+ Matrix3 tx = SO3<Scalar>::hat(translation);
61
+ Matrix3 Zer = Matrix3::Zero();
62
+
63
+ Adjoint Ad;
64
+ Ad << R, tx*R, Zer, R;
65
+
66
+ return Ad;
67
+ }
68
+
69
+ EIGEN_DEVICE_FUNC Transformation Matrix() const {
70
+ Transformation T = Transformation::Identity();
71
+ T.template block<3,3>(0,0) = so3.Matrix();
72
+ T.template block<3,1>(0,3) = translation;
73
+ return T;
74
+ }
75
+
76
+ EIGEN_DEVICE_FUNC Transformation Matrix4x4() const {
77
+ return Matrix();
78
+ }
79
+
80
+ EIGEN_DEVICE_FUNC Tangent Adj(Tangent const& a) const {
81
+ return Adj() * a;
82
+ }
83
+
84
+ EIGEN_DEVICE_FUNC Tangent AdjT(Tangent const& a) const {
85
+ return Adj().transpose() * a;
86
+ }
87
+
88
+
89
+ EIGEN_DEVICE_FUNC static Transformation hat(Tangent const& tau_phi) {
90
+ Vector3 tau = tau_phi.template segment<3>(0);
91
+ Vector3 phi = tau_phi.template segment<3>(3);
92
+
93
+ Transformation TauPhi = Transformation::Zero();
94
+ TauPhi.template block<3,3>(0,0) = SO3<Scalar>::hat(phi);
95
+ TauPhi.template block<3,1>(0,3) = tau;
96
+
97
+ return TauPhi;
98
+ }
99
+
100
+ EIGEN_DEVICE_FUNC static Adjoint adj(Tangent const& tau_phi) {
101
+ Vector3 tau = tau_phi.template segment<3>(0);
102
+ Vector3 phi = tau_phi.template segment<3>(3);
103
+
104
+ Matrix3 Tau = SO3<Scalar>::hat(tau);
105
+ Matrix3 Phi = SO3<Scalar>::hat(phi);
106
+ Matrix3 Zer = Matrix3::Zero();
107
+
108
+ Adjoint ad;
109
+ ad << Phi, Tau, Zer, Phi;
110
+
111
+ return ad;
112
+ }
113
+
114
+ EIGEN_DEVICE_FUNC Eigen::Matrix<Scalar,7,7> orthogonal_projector() const {
115
+ // jacobian action on a point
116
+ Eigen::Matrix<Scalar,7,7> J = Eigen::Matrix<Scalar,7,7>::Zero();
117
+ J.template block<3,3>(0,0) = Matrix3::Identity();
118
+ J.template block<3,3>(0,3) = SO3<Scalar>::hat(-translation);
119
+ J.template block<4,4>(3,3) = so3.orthogonal_projector();
120
+
121
+ return J;
122
+ }
123
+
124
+ EIGEN_DEVICE_FUNC Tangent Log() const {
125
+ Vector3 phi = so3.Log();
126
+ Matrix3 Vinv = SO3<Scalar>::left_jacobian_inverse(phi);
127
+
128
+ Tangent tau_phi;
129
+ tau_phi << Vinv * translation, phi;
130
+
131
+ return tau_phi;
132
+ }
133
+
134
+ EIGEN_DEVICE_FUNC static SE3<Scalar> Exp(Tangent const& tau_phi) {
135
+ Vector3 tau = tau_phi.template segment<3>(0);
136
+ Vector3 phi = tau_phi.template segment<3>(3);
137
+
138
+ SO3<Scalar> so3 = SO3<Scalar>::Exp(phi);
139
+ Vector3 t = SO3<Scalar>::left_jacobian(phi) * tau;
140
+
141
+ return SE3<Scalar>(so3, t);
142
+ }
143
+
144
+ EIGEN_DEVICE_FUNC static Matrix3 calcQ(Tangent const& tau_phi) {
145
+ // Q matrix
146
+ Vector3 tau = tau_phi.template segment<3>(0);
147
+ Vector3 phi = tau_phi.template segment<3>(3);
148
+ Matrix3 Tau = SO3<Scalar>::hat(tau);
149
+ Matrix3 Phi = SO3<Scalar>::hat(phi);
150
+
151
+ Scalar theta = phi.norm();
152
+ Scalar theta_pow2 = theta * theta;
153
+ Scalar theta_pow4 = theta_pow2 * theta_pow2;
154
+
155
+ Scalar coef1 = (theta < EPS) ?
156
+ Scalar(1.0/6.0) - Scalar(1.0/120.0) * theta_pow2 :
157
+ (theta - sin(theta)) / (theta_pow2 * theta);
158
+
159
+ Scalar coef2 = (theta < EPS) ?
160
+ Scalar(1.0/24.0) - Scalar(1.0/720.0) * theta_pow2 :
161
+ (theta_pow2 + 2*cos(theta) - 2) / (2 * theta_pow4);
162
+
163
+ Scalar coef3 = (theta < EPS) ?
164
+ Scalar(1.0/120.0) - Scalar(1.0/2520.0) * theta_pow2 :
165
+ (2*theta - 3*sin(theta) + theta*cos(theta)) / (2 * theta_pow4 * theta);
166
+
167
+ Matrix3 Q = Scalar(0.5) * Tau +
168
+ coef1 * (Phi*Tau + Tau*Phi + Phi*Tau*Phi) +
169
+ coef2 * (Phi*Phi*Tau + Tau*Phi*Phi - 3*Phi*Tau*Phi) +
170
+ coef3 * (Phi*Tau*Phi*Phi + Phi*Phi*Tau*Phi);
171
+
172
+ return Q;
173
+ }
174
+
175
+ EIGEN_DEVICE_FUNC static Adjoint left_jacobian(Tangent const& tau_phi) {
176
+ // left jacobian
177
+ Vector3 phi = tau_phi.template segment<3>(3);
178
+ Matrix3 J = SO3<Scalar>::left_jacobian(phi);
179
+ Matrix3 Q = SE3<Scalar>::calcQ(tau_phi);
180
+ Matrix3 Zer = Matrix3::Zero();
181
+
182
+ Adjoint J6x6;
183
+ J6x6 << J, Q, Zer, J;
184
+
185
+ return J6x6;
186
+ }
187
+
188
+ EIGEN_DEVICE_FUNC static Adjoint left_jacobian_inverse(Tangent const& tau_phi) {
189
+ // left jacobian inverse
190
+ Vector3 tau = tau_phi.template segment<3>(0);
191
+ Vector3 phi = tau_phi.template segment<3>(3);
192
+ Matrix3 Jinv = SO3<Scalar>::left_jacobian_inverse(phi);
193
+ Matrix3 Q = SE3<Scalar>::calcQ(tau_phi);
194
+ Matrix3 Zer = Matrix3::Zero();
195
+
196
+ Adjoint J6x6;
197
+ J6x6 << Jinv, -Jinv * Q * Jinv, Zer, Jinv;
198
+
199
+ return J6x6;
200
+
201
+ }
202
+
203
+ EIGEN_DEVICE_FUNC static Eigen::Matrix<Scalar,3,6> act_jacobian(Point const& p) {
204
+ // jacobian action on a point
205
+ Eigen::Matrix<Scalar,3,6> J;
206
+ J.template block<3,3>(0,0) = Matrix3::Identity();
207
+ J.template block<3,3>(0,3) = SO3<Scalar>::hat(-p);
208
+ return J;
209
+ }
210
+
211
+ EIGEN_DEVICE_FUNC static Eigen::Matrix<Scalar,4,6> act4_jacobian(Point4 const& p) {
212
+ // jacobian action on a point
213
+ Eigen::Matrix<Scalar,4,6> J = Eigen::Matrix<Scalar,4,6>::Zero();
214
+ J.template block<3,3>(0,0) = p(3) * Matrix3::Identity();
215
+ J.template block<3,3>(0,3) = SO3<Scalar>::hat(-p.template segment<3>(0));
216
+ return J;
217
+ }
218
+
219
+
220
+
221
+
222
+ private:
223
+ SO3<Scalar> so3;
224
+ Vector3 translation;
225
+
226
+ };
227
+
228
+ #endif
229
+
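se3.h follows the same layout with a rotation plus a translation (embedding [t, q], tangent [tau, phi]). A small Python-side sketch of composition, inverse, and the 4x4 homogeneous matrix, under the same assumption that the lietorch extension is built:

import torch
from lietorch import SE3

a = 0.2 * torch.randn(1, 6, dtype=torch.float64)      # [tau (3), phi (3)]
b = 0.2 * torch.randn(1, 6, dtype=torch.float64)
X, Y = SE3.exp(a), SE3.exp(b)

Z = X * Y                                             # group composition (operator* above)
r = (Z * Z.inv()).log()                               # Z * Z^{-1} is the identity
assert torch.allclose(r, torch.zeros_like(r), atol=1e-8)

p = torch.randn(1, 3, dtype=torch.float64)
p_h = torch.cat([p, torch.ones(1, 1, dtype=torch.float64)], dim=-1)
q = (X.matrix() @ p_h[..., None])[..., :3, 0]         # homogeneous 4x4 action
assert torch.allclose(X.act(p), q, atol=1e-8)         # agrees with act(), as in test_act below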
mini_dpvo/lietorch/include/sim3.h ADDED
@@ -0,0 +1,217 @@
1
+
2
+ #ifndef Sim3_HEADER
3
+ #define Sim3_HEADER
4
+
5
+ #include <stdio.h>
6
+ #include <iostream>
7
+
8
+ #include <Eigen/Dense>
9
+ #include <Eigen/Geometry>
10
+
11
+ #include "common.h"
12
+ #include "so3.h"
13
+ #include "rxso3.h"
14
+
15
+
16
+ template <typename Scalar>
17
+ class Sim3 {
18
+ public:
19
+ const static int constexpr K = 7; // manifold dimension
20
+ const static int constexpr N = 8; // embedding dimension
21
+
22
+ using Vector3 = Eigen::Matrix<Scalar,3,1>;
23
+ using Vector4 = Eigen::Matrix<Scalar,4,1>;
24
+ using Matrix3 = Eigen::Matrix<Scalar,3,3>;
25
+
26
+ using Tangent = Eigen::Matrix<Scalar,K,1>;
27
+ using Point = Eigen::Matrix<Scalar,3,1>;
28
+ using Point4 = Eigen::Matrix<Scalar,4,1>;
29
+ using Data = Eigen::Matrix<Scalar,N,1>;
30
+ using Transformation = Eigen::Matrix<Scalar,4,4>;
31
+ using Adjoint = Eigen::Matrix<Scalar,K,K>;
32
+
33
+ EIGEN_DEVICE_FUNC Sim3() {
34
+ translation = Vector3::Zero();
35
+ }
36
+
37
+ EIGEN_DEVICE_FUNC Sim3(RxSO3<Scalar> const& rxso3, Vector3 const& t)
38
+ : rxso3(rxso3), translation(t) {};
39
+
40
+ EIGEN_DEVICE_FUNC Sim3(const Scalar *data)
41
+ : translation(data), rxso3(data+3) {};
42
+
43
+ EIGEN_DEVICE_FUNC Sim3<Scalar> inv() {
44
+ return Sim3<Scalar>(rxso3.inv(), -(rxso3.inv() * translation));
45
+ }
46
+
47
+ EIGEN_DEVICE_FUNC Data data() const {
48
+ Data data_vec; data_vec << translation, rxso3.data();
49
+ return data_vec;
50
+ }
51
+
52
+ EIGEN_DEVICE_FUNC Sim3<Scalar> operator*(Sim3<Scalar> const& other) {
53
+ return Sim3(rxso3 * other.rxso3, translation + rxso3 * other.translation);
54
+ }
55
+
56
+ EIGEN_DEVICE_FUNC Point operator*(Point const& p) const {
57
+ return (rxso3 * p) + translation;
58
+ }
59
+
60
+ EIGEN_DEVICE_FUNC Point4 act4(Point4 const& p) const {
61
+ Point4 p1; p1 << rxso3 * p.template segment<3>(0) + p(3) * translation , p(3);
62
+ return p1;
63
+ }
64
+
65
+ EIGEN_DEVICE_FUNC Transformation Matrix() const {
66
+ Transformation T = Transformation::Identity();
67
+ T.template block<3,3>(0,0) = rxso3.Matrix();
68
+ T.template block<3,1>(0,3) = translation;
69
+ return T;
70
+ }
71
+
72
+ EIGEN_DEVICE_FUNC Transformation Matrix4x4() const {
73
+ Transformation T = Transformation::Identity();
74
+ T.template block<3,3>(0,0) = rxso3.Matrix();
75
+ T.template block<3,1>(0,3) = translation;
76
+ return T;
77
+ }
78
+
79
+ EIGEN_DEVICE_FUNC Eigen::Matrix<Scalar,8,8> orthogonal_projector() const {
80
+ // jacobian action on a point
81
+ Eigen::Matrix<Scalar,8,8> J = Eigen::Matrix<Scalar,8,8>::Zero();
82
+ J.template block<3,3>(0,0) = Matrix3::Identity();
83
+ J.template block<3,3>(0,3) = SO3<Scalar>::hat(-translation);
84
+ J.template block<3,1>(0,6) = translation;
85
+ J.template block<5,5>(3,3) = rxso3.orthogonal_projector();
86
+ return J;
87
+ }
88
+
89
+ EIGEN_DEVICE_FUNC Adjoint Adj() const {
90
+ Adjoint Ad = Adjoint::Identity();
91
+ Matrix3 sR = rxso3.Matrix();
92
+ Matrix3 tx = SO3<Scalar>::hat(translation);
93
+ Matrix3 R = rxso3.Rotation();
94
+
95
+ Ad.template block<3,3>(0,0) = sR;
96
+ Ad.template block<3,3>(0,3) = tx * R;
97
+ Ad.template block<3,1>(0,6) = -translation;
98
+ Ad.template block<3,3>(3,3) = R;
99
+
100
+ return Ad;
101
+ }
102
+
103
+ EIGEN_DEVICE_FUNC Tangent Adj(Tangent const& a) const {
104
+ return Adj() * a;
105
+ }
106
+
107
+ EIGEN_DEVICE_FUNC Tangent AdjT(Tangent const& a) const {
108
+ return Adj().transpose() * a;
109
+ }
110
+
111
+ EIGEN_DEVICE_FUNC static Transformation hat(Tangent const& tau_phi_sigma) {
112
+ Vector3 tau = tau_phi_sigma.template segment<3>(0);
113
+ Vector3 phi = tau_phi_sigma.template segment<3>(3);
114
+ Scalar sigma = tau_phi_sigma(6);
115
+
116
+ Matrix3 Phi = SO3<Scalar>::hat(phi);
117
+ Matrix3 I = Matrix3::Identity();
118
+
119
+ Transformation Omega = Transformation::Zero();
120
+ Omega.template block<3,3>(0,0) = Phi + sigma * I;
121
+ Omega.template block<3,1>(0,3) = tau;
122
+
123
+ return Omega;
124
+ }
125
+
126
+ EIGEN_DEVICE_FUNC static Adjoint adj(Tangent const& tau_phi_sigma) {
127
+ Adjoint ad = Adjoint::Zero();
128
+ Vector3 tau = tau_phi_sigma.template segment<3>(0);
129
+ Vector3 phi = tau_phi_sigma.template segment<3>(3);
130
+ Scalar sigma = tau_phi_sigma(6);
131
+
132
+ Matrix3 Tau = SO3<Scalar>::hat(tau);
133
+ Matrix3 Phi = SO3<Scalar>::hat(phi);
134
+ Matrix3 I = Matrix3::Identity();
135
+
136
+ ad.template block<3,3>(0,0) = Phi + sigma * I;
137
+ ad.template block<3,3>(0,3) = Tau;
138
+ ad.template block<3,1>(0,6) = -tau;
139
+ ad.template block<3,3>(3,3) = Phi;
140
+
141
+ return ad;
142
+ }
143
+
144
+
145
+ EIGEN_DEVICE_FUNC Tangent Log() const {
146
+ // logarithm map
147
+ Vector4 phi_sigma = rxso3.Log();
148
+ Matrix3 W = RxSO3<Scalar>::calcW(phi_sigma);
149
+
150
+ Tangent tau_phi_sigma;
151
+ tau_phi_sigma << W.inverse() * translation, phi_sigma;
152
+
153
+ return tau_phi_sigma;
154
+ }
155
+
156
+ EIGEN_DEVICE_FUNC static Sim3<Scalar> Exp(Tangent const& tau_phi_sigma) {
157
+ // exponential map
158
+ Vector3 tau = tau_phi_sigma.template segment<3>(0);
159
+ Vector4 phi_sigma = tau_phi_sigma.template segment<4>(3);
160
+
161
+ RxSO3<Scalar> rxso3 = RxSO3<Scalar>::Exp(phi_sigma);
162
+ Matrix3 W = RxSO3<Scalar>::calcW(phi_sigma);
163
+
164
+ return Sim3<Scalar>(rxso3, W*tau);
165
+ }
166
+
167
+ EIGEN_DEVICE_FUNC static Adjoint left_jacobian(Tangent const& tau_phi_sigma) {
168
+ // left jacobian
169
+ Adjoint const Xi = adj(tau_phi_sigma);
170
+ Adjoint const Xi2 = Xi * Xi;
171
+ Adjoint const Xi4 = Xi2 * Xi2;
172
+
173
+ return Adjoint::Identity()
174
+ + Scalar(1.0/2.0)*Xi
175
+ + Scalar(1.0/6.0)*Xi2
176
+ + Scalar(1.0/24.0)*Xi*Xi2
177
+ + Scalar(1.0/120.0)*Xi4;
178
+ + Scalar(1.0/720.0)*Xi*Xi4;
179
+ }
180
+
181
+ EIGEN_DEVICE_FUNC static Adjoint left_jacobian_inverse(Tangent const& tau_phi_sigma) {
182
+ // left jacobian inverse
183
+ Adjoint const Xi = adj(tau_phi_sigma);
184
+ Adjoint const Xi2 = Xi * Xi;
185
+ Adjoint const Xi4 = Xi2 * Xi2;
186
+
187
+ return Adjoint::Identity()
188
+ - Scalar(1.0/2.0)*Xi
189
+ + Scalar(1.0/12.0)*Xi2
190
+ - Scalar(1.0/720.0)*Xi4;
191
+ }
192
+
193
+ EIGEN_DEVICE_FUNC static Eigen::Matrix<Scalar,3,7> act_jacobian(Point const& p) {
194
+ // jacobian action on a point
195
+ Eigen::Matrix<Scalar,3,7> J;
196
+ J.template block<3,3>(0,0) = Matrix3::Identity();
197
+ J.template block<3,3>(0,3) = SO3<Scalar>::hat(-p);
198
+ J.template block<3,1>(0,6) = p;
199
+ return J;
200
+ }
201
+
202
+ EIGEN_DEVICE_FUNC static Eigen::Matrix<Scalar,4,7> act4_jacobian(Point4 const& p) {
203
+ // jacobian action on a point
204
+ Eigen::Matrix<Scalar,4,7> J = Eigen::Matrix<Scalar,4,7>::Zero();
205
+ J.template block<3,3>(0,0) = p(3) * Matrix3::Identity();
206
+ J.template block<3,3>(0,3) = SO3<Scalar>::hat(-p.template segment<3>(0));
207
+ J.template block<3,1>(0,6) = p.template segment<3>(0);
208
+ return J;
209
+ }
210
+
211
+ private:
212
+ Vector3 translation;
213
+ RxSO3<Scalar> rxso3;
214
+ };
215
+
216
+ #endif
217
+
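sim3.h combines RxSO3 with a translation, so the tangent is the 7-vector [tau, phi, sigma] with sigma the log of the similarity scale. Note that left_jacobian and left_jacobian_inverse above are truncated power series in adj(.), not closed forms, which is presumably why the Sim3 gradient tests below run with a relaxed tolerance. A usage sketch, again assuming the built extension:

import torch
from lietorch import Sim3

a = 0.2 * torch.randn(1, 7, dtype=torch.float64)   # [tau (3), phi (3), sigma (1)]
X = Sim3.exp(a)
assert torch.allclose(X.log(), a, atol=1e-8)       # Exp/Log round trip (closed form via calcW)

p = torch.randn(1, 3, dtype=torch.float64)
q = X.act(p)                                       # scaled rotation plus translation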
mini_dpvo/lietorch/include/so3.h ADDED
@@ -0,0 +1,229 @@
1
+
2
+ #ifndef SO3_HEADER
3
+ #define SO3_HEADER
4
+
5
+ #include <cuda.h>
6
+ #include <stdio.h>
7
+ #include <Eigen/Dense>
8
+ #include <Eigen/Geometry>
9
+
10
+ #include "common.h"
11
+
12
+ template <typename Scalar>
13
+ class SO3 {
14
+ public:
15
+ const static int constexpr K = 3; // manifold dimension
16
+ const static int constexpr N = 4; // embedding dimension
17
+
18
+ using Vector3 = Eigen::Matrix<Scalar,3,1>;
19
+ using Vector4 = Eigen::Matrix<Scalar,4,1>;
20
+ using Matrix3 = Eigen::Matrix<Scalar,3,3>;
21
+
22
+ using Tangent = Eigen::Matrix<Scalar,K,1>;
23
+ using Data = Eigen::Matrix<Scalar,N,1>;
24
+
25
+ using Point = Eigen::Matrix<Scalar,3,1>;
26
+ using Point4 = Eigen::Matrix<Scalar,4,1>;
27
+ using Transformation = Eigen::Matrix<Scalar,3,3>;
28
+ using Adjoint = Eigen::Matrix<Scalar,K,K>;
29
+ using Quaternion = Eigen::Quaternion<Scalar>;
30
+
31
+ EIGEN_DEVICE_FUNC SO3(Quaternion const& q) : unit_quaternion(q) {
32
+ unit_quaternion.normalize();
33
+ };
34
+
35
+ EIGEN_DEVICE_FUNC SO3(const Scalar *data) : unit_quaternion(data) {
36
+ unit_quaternion.normalize();
37
+ };
38
+
39
+ EIGEN_DEVICE_FUNC SO3() {
40
+ unit_quaternion = Quaternion::Identity();
41
+ }
42
+
43
+ EIGEN_DEVICE_FUNC SO3<Scalar> inv() {
44
+ return SO3<Scalar>(unit_quaternion.conjugate());
45
+ }
46
+
47
+ EIGEN_DEVICE_FUNC Data data() const {
48
+ return unit_quaternion.coeffs();
49
+ }
50
+
51
+ EIGEN_DEVICE_FUNC SO3<Scalar> operator*(SO3<Scalar> const& other) {
52
+ return SO3(unit_quaternion * other.unit_quaternion);
53
+ }
54
+
55
+ EIGEN_DEVICE_FUNC Point operator*(Point const& p) const {
56
+ const Quaternion& q = unit_quaternion;
57
+ Point uv = q.vec().cross(p);
58
+ uv += uv;
59
+ return p + q.w()*uv + q.vec().cross(uv);
60
+ }
61
+
62
+ EIGEN_DEVICE_FUNC Point4 act4(Point4 const& p) const {
63
+ Point4 p1; p1 << this->operator*(p.template segment<3>(0)), p(3);
64
+ return p1;
65
+ }
66
+
67
+ EIGEN_DEVICE_FUNC Adjoint Adj() const {
68
+ return unit_quaternion.toRotationMatrix();
69
+ }
70
+
71
+ EIGEN_DEVICE_FUNC Transformation Matrix() const {
72
+ return unit_quaternion.toRotationMatrix();
73
+ }
74
+
75
+ EIGEN_DEVICE_FUNC Eigen::Matrix<Scalar,4,4> Matrix4x4() const {
76
+ Eigen::Matrix<Scalar,4,4> T = Eigen::Matrix<Scalar,4,4>::Identity();
77
+ T.template block<3,3>(0,0) = Matrix();
78
+ return T;
79
+ }
80
+
81
+ EIGEN_DEVICE_FUNC Eigen::Matrix<Scalar,4,4> orthogonal_projector() const {
82
+ // jacobian action on a point
83
+ Eigen::Matrix<Scalar,4,4> J = Eigen::Matrix<Scalar,4,4>::Zero();
84
+ J.template block<3,3>(0,0) = 0.5 * (
85
+ unit_quaternion.w() * Matrix3::Identity() +
86
+ SO3<Scalar>::hat(-unit_quaternion.vec())
87
+ );
88
+
89
+ J.template block<1,3>(3,0) = 0.5 * (-unit_quaternion.vec());
90
+ return J;
91
+ }
92
+
93
+ EIGEN_DEVICE_FUNC Tangent Adj(Tangent const& a) const {
94
+ return Adj() * a;
95
+ }
96
+
97
+ EIGEN_DEVICE_FUNC Tangent AdjT(Tangent const& a) const {
98
+ return Adj().transpose() * a;
99
+ }
100
+
101
+ EIGEN_DEVICE_FUNC static Transformation hat(Tangent const& phi) {
102
+ Transformation Phi;
103
+ Phi <<
104
+ 0.0, -phi(2), phi(1),
105
+ phi(2), 0.0, -phi(0),
106
+ -phi(1), phi(0), 0.0;
107
+
108
+ return Phi;
109
+ }
110
+
111
+ EIGEN_DEVICE_FUNC static Adjoint adj(Tangent const& phi) {
112
+ return SO3<Scalar>::hat(phi);
113
+ }
114
+
115
+ EIGEN_DEVICE_FUNC Tangent Log() const {
116
+ using std::abs;
117
+ using std::atan;
118
+ using std::sqrt;
119
+ Scalar squared_n = unit_quaternion.vec().squaredNorm();
120
+ Scalar w = unit_quaternion.w();
121
+
122
+ Scalar two_atan_nbyw_by_n;
123
+
124
+ /// Atan-based log thanks to
125
+ ///
126
+ /// C. Hertzberg et al.:
127
+ /// "Integrating Generic Sensor Fusion Algorithms with Sound State
128
+ /// Representation through Encapsulation of Manifolds"
129
+ /// Information Fusion, 2011
130
+
131
+ if (squared_n < EPS * EPS) {
132
+ // If quaternion is normalized and n=0, then w should be 1;
133
+ // w=0 should never happen here!
134
+ Scalar squared_w = w * w;
135
+ two_atan_nbyw_by_n =
136
+ Scalar(2) / w - Scalar(2.0/3.0) * (squared_n) / (w * squared_w);
137
+ } else {
138
+ Scalar n = sqrt(squared_n);
139
+ if (abs(w) < EPS) {
140
+ if (w > Scalar(0)) {
141
+ two_atan_nbyw_by_n = Scalar(PI) / n;
142
+ } else {
143
+ two_atan_nbyw_by_n = -Scalar(PI) / n;
144
+ }
145
+ } else {
146
+ two_atan_nbyw_by_n = Scalar(2) * atan(n / w) / n;
147
+ }
148
+ }
149
+
150
+ return two_atan_nbyw_by_n * unit_quaternion.vec();
151
+ }
152
+
153
+ EIGEN_DEVICE_FUNC static SO3<Scalar> Exp(Tangent const& phi) {
154
+ Scalar theta2 = phi.squaredNorm();
155
+ Scalar theta = sqrt(theta2);
156
+ Scalar imag_factor;
157
+ Scalar real_factor;
158
+
159
+ if (theta < EPS) {
160
+ Scalar theta4 = theta2 * theta2;
161
+ imag_factor = Scalar(0.5) - Scalar(1.0/48.0) * theta2 + Scalar(1.0/3840.0) * theta4;
162
+ real_factor = Scalar(1) - Scalar(1.0/8.0) * theta2 + Scalar(1.0/384.0) * theta4;
163
+ } else {
164
+ imag_factor = sin(.5 * theta) / theta;
165
+ real_factor = cos(.5 * theta);
166
+ }
167
+
168
+ Quaternion q(real_factor, imag_factor*phi.x(), imag_factor*phi.y(), imag_factor*phi.z());
169
+ return SO3<Scalar>(q);
170
+ }
171
+
172
+ EIGEN_DEVICE_FUNC static Adjoint left_jacobian(Tangent const& phi) {
173
+ // left jacobian
174
+ Matrix3 I = Matrix3::Identity();
175
+ Matrix3 Phi = SO3<Scalar>::hat(phi);
176
+ Matrix3 Phi2 = Phi * Phi;
177
+
178
+ Scalar theta2 = phi.squaredNorm();
179
+ Scalar theta = sqrt(theta2);
180
+
181
+ Scalar coef1 = (theta < EPS) ?
182
+ Scalar(1.0/2.0) - Scalar(1.0/24.0) * theta2 :
183
+ (1.0 - cos(theta)) / theta2;
184
+
185
+ Scalar coef2 = (theta < EPS) ?
186
+ Scalar(1.0/6.0) - Scalar(1.0/120.0) * theta2 :
187
+ (theta - sin(theta)) / (theta2 * theta);
188
+
189
+ return I + coef1 * Phi + coef2 * Phi2;
190
+ }
191
+
192
+ EIGEN_DEVICE_FUNC static Adjoint left_jacobian_inverse(Tangent const& phi) {
193
+ // left jacobian inverse
194
+ Matrix3 I = Matrix3::Identity();
195
+ Matrix3 Phi = SO3<Scalar>::hat(phi);
196
+ Matrix3 Phi2 = Phi * Phi;
197
+
198
+ Scalar theta2 = phi.squaredNorm();
199
+ Scalar theta = sqrt(theta2);
200
+ Scalar half_theta = Scalar(.5) * theta ;
201
+
202
+ Scalar coef2 = (theta < EPS) ? Scalar(1.0/12.0) :
203
+ (Scalar(1) -
204
+ theta * cos(half_theta) / (Scalar(2) * sin(half_theta))) /
205
+ (theta * theta);
206
+
207
+ return I + Scalar(-0.5) * Phi + coef2 * Phi2;
208
+ }
209
+
210
+ EIGEN_DEVICE_FUNC static Eigen::Matrix<Scalar,3,3> act_jacobian(Point const& p) {
211
+ // jacobian action on a point
212
+ return SO3<Scalar>::hat(-p);
213
+ }
214
+
215
+ EIGEN_DEVICE_FUNC static Eigen::Matrix<Scalar,4,3> act4_jacobian(Point4 const& p) {
216
+ // jacobian action on a point
217
+ Eigen::Matrix<Scalar,4,3> J = Eigen::Matrix<Scalar,4,3>::Zero();
218
+ J.template block<3,3>(0,0) = SO3<Scalar>::hat(-p.template segment<3>(0));
219
+ return J;
220
+ }
221
+
222
+ private:
223
+ Quaternion unit_quaternion;
224
+
225
+ };
226
+
227
+ #endif
228
+
229
+
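so3.h is the base case the other headers delegate to (hat, left_jacobian, and the quaternion storage). A short sketch of the corresponding Python group, with the embedded vector being the quaternion (x, y, z, w) exactly as data() returns it:

import math
import torch
from lietorch import SO3

phi = torch.tensor([[0.0, 0.0, math.pi / 2]], dtype=torch.float64)   # 90 deg about z
R = SO3.exp(phi)

e_x = torch.tensor([[1.0, 0.0, 0.0]], dtype=torch.float64)
e_y = torch.tensor([[0.0, 1.0, 0.0]], dtype=torch.float64)
assert torch.allclose(R.act(e_x), e_y, atol=1e-8)    # rotating x by 90 deg about z gives y

q = R.vec()                       # embedded representation: unit quaternion (x, y, z, w)
R2 = SO3.InitFromVec(q)           # rebuild the group element from the raw quaternion
assert torch.allclose(R2.log(), phi, atol=1e-8)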
mini_dpvo/lietorch/run_tests.py ADDED
@@ -0,0 +1,302 @@
1
+ import torch
2
+ import lietorch
3
+
4
+ from lietorch import SO3, RxSO3, SE3, Sim3
5
+ from gradcheck import gradcheck, get_analytical_jacobian
6
+
7
+
8
+ ### forward tests ###
9
+
10
+ def make_homogeneous(p):
11
+ return torch.cat([p, torch.ones_like(p[...,:1])], dim=-1)
12
+
13
+ def matv(A, b):
14
+ return torch.matmul(A, b[...,None])[..., 0]
15
+
16
+ def test_exp_log(Group, device='cuda'):
17
+ """ check Log(Exp(x)) == x """
18
+ a = .2*torch.randn(2,3,4,5,6,7,Group.manifold_dim, device=device).double()
19
+ b = Group.exp(a).log()
20
+ assert torch.allclose(a,b,atol=1e-8), "should be identity"
21
+ print("\t-", Group, "Passed exp-log test")
22
+
23
+ def test_inv(Group, device='cuda'):
24
+ """ check X * X^{-1} == 0 """
25
+ X = Group.exp(.1*torch.randn(2,3,4,5,Group.manifold_dim, device=device).double())
26
+ a = (X * X.inv()).log()
27
+ assert torch.allclose(a, torch.zeros_like(a), atol=1e-8), "should be 0"
28
+ print("\t-", Group, "Passed inv test")
29
+
30
+ def test_adj(Group, device='cuda'):
31
+ """ check X * Exp(a) == Exp(Adj(X,a)) * X 0 """
32
+ X = Group.exp(torch.randn(2,3,4,5, Group.manifold_dim, device=device).double())
33
+ a = torch.randn(2,3,4,5, Group.manifold_dim, device=device).double()
34
+
35
+ b = X.adj(a)
36
+ Y1 = X * Group.exp(a)
37
+ Y2 = Group.exp(b) * X
38
+
39
+ c = (Y1 * Y2.inv()).log()
40
+ assert torch.allclose(c, torch.zeros_like(c), atol=1e-8), "should be 0"
41
+ print("\t-", Group, "Passed adj test")
42
+
43
+
44
+ def test_act(Group, device='cuda'):
45
+ X = Group.exp(torch.randn(1, Group.manifold_dim, device=device).double())
46
+ p = torch.randn(1,3,device=device).double()
47
+
48
+ p1 = X.act(p)
49
+ p2 = matv(X.matrix(), make_homogeneous(p))
50
+
51
+ assert torch.allclose(p1, p2[...,:3], atol=1e-8), "act should match matrix action"
52
+ print("\t-", Group, "Passed act test")
53
+
54
+
55
+ ### backward tests ###
56
+ def test_exp_log_grad(Group, device='cuda', tol=1e-8):
57
+
58
+ D = Group.manifold_dim
59
+
60
+ def fn(a):
61
+ return Group.exp(a).log()
62
+
63
+ a = torch.zeros(1, Group.manifold_dim, requires_grad=True, device=device).double()
64
+ analytical, reentrant, correct_grad_sizes, correct_grad_types = \
65
+ get_analytical_jacobian((a,), fn(a))
66
+
67
+ assert torch.allclose(analytical[0], torch.eye(D, device=device).double(), atol=tol)
68
+
69
+ a = .2 * torch.randn(1, Group.manifold_dim, requires_grad=True, device=device).double()
70
+ analytical, reentrant, correct_grad_sizes, correct_grad_types = \
71
+ get_analytical_jacobian((a,), fn(a))
72
+
73
+ assert torch.allclose(analytical[0], torch.eye(D, device=device).double(), atol=tol)
74
+
75
+ print("\t-", Group, "Passed eye-grad test")
76
+
77
+
78
+ def test_inv_log_grad(Group, device='cuda', tol=1e-8):
79
+
80
+ D = Group.manifold_dim
81
+ X = Group.exp(.2*torch.randn(1,D,device=device).double())
82
+
83
+ def fn(a):
84
+ return (Group.exp(a) * X).inv().log()
85
+
86
+ a = torch.zeros(1, D, requires_grad=True, device=device).double()
87
+ analytical, numerical = gradcheck(fn, [a], eps=1e-4)
88
+
89
+ # assert torch.allclose(analytical[0], numerical[0], atol=tol)
90
+ if not torch.allclose(analytical[0], numerical[0], atol=tol):
91
+ print(analytical[0])
92
+ print(numerical[0])
93
+
94
+ print("\t-", Group, "Passed inv-grad test")
95
+
96
+
97
+ def test_adj_grad(Group, device='cuda'):
98
+ D = Group.manifold_dim
99
+ X = Group.exp(.5*torch.randn(1,Group.manifold_dim, device=device).double())
100
+
101
+ def fn(a, b):
102
+ return (Group.exp(a) * X).adj(b)
103
+
104
+ a = torch.zeros(1, D, requires_grad=True, device=device).double()
105
+ b = torch.randn(1, D, requires_grad=True, device=device).double()
106
+
107
+ analytical, numerical = gradcheck(fn, [a, b], eps=1e-4)
108
+ assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
109
+ assert torch.allclose(analytical[1], numerical[1], atol=1e-8)
110
+
111
+ print("\t-", Group, "Passed adj-grad test")
112
+
113
+
114
+ def test_adjT_grad(Group, device='cuda'):
115
+ D = Group.manifold_dim
116
+ X = Group.exp(.5*torch.randn(1,Group.manifold_dim, device=device).double())
117
+
118
+ def fn(a, b):
119
+ return (Group.exp(a) * X).adjT(b)
120
+
121
+ a = torch.zeros(1, D, requires_grad=True, device=device).double()
122
+ b = torch.randn(1, D, requires_grad=True, device=device).double()
123
+
124
+ analytical, numerical = gradcheck(fn, [a, b], eps=1e-4)
125
+
126
+ assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
127
+ assert torch.allclose(analytical[1], numerical[1], atol=1e-8)
128
+
129
+ print("\t-", Group, "Passed adjT-grad test")
130
+
131
+
132
+ def test_act_grad(Group, device='cuda'):
133
+ D = Group.manifold_dim
134
+ X = Group.exp(5*torch.randn(1,D, device=device).double())
135
+
136
+ def fn(a, b):
137
+ return (X*Group.exp(a)).act(b)
138
+
139
+ a = torch.zeros(1, D, requires_grad=True, device=device).double()
140
+ b = torch.randn(1, 3, requires_grad=True, device=device).double()
141
+
142
+ analytical, numerical = gradcheck(fn, [a, b], eps=1e-4)
143
+
144
+ assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
145
+ assert torch.allclose(analytical[1], numerical[1], atol=1e-8)
146
+
147
+ print("\t-", Group, "Passed act-grad test")
148
+
149
+
150
+ def test_matrix_grad(Group, device='cuda'):
151
+ D = Group.manifold_dim
152
+ X = Group.exp(torch.randn(1, D, device=device).double())
153
+
154
+ def fn(a):
155
+ return (Group.exp(a) * X).matrix()
156
+
157
+ a = torch.zeros(1, D, requires_grad=True, device=device).double()
158
+ analytical, numerical = gradcheck(fn, [a], eps=1e-4)
159
+ assert torch.allclose(analytical[0], numerical[0], atol=1e-6)
160
+
161
+ print("\t-", Group, "Passed matrix-grad test")
162
+
163
+
164
+ def extract_translation_grad(Group, device='cuda'):
165
+ """ prototype function """
166
+
167
+ D = Group.manifold_dim
168
+ X = Group.exp(5*torch.randn(1,D, device=device).double())
169
+
170
+ def fn(a):
171
+ return (Group.exp(a)*X).translation()
172
+
173
+ a = torch.zeros(1, D, requires_grad=True, device=device).double()
174
+
175
+ analytical, numerical = gradcheck(fn, [a], eps=1e-4)
176
+
177
+ assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
178
+ print("\t-", Group, "Passed translation grad test")
179
+
180
+
181
+ def test_vec_grad(Group, device='cuda', tol=1e-6):
182
+
183
+ D = Group.manifold_dim
184
+ X = Group.exp(5*torch.randn(1,D, device=device).double())
185
+
186
+ def fn(a):
187
+ return (Group.exp(a)*X).vec()
188
+
189
+ a = torch.zeros(1, D, requires_grad=True, device=device).double()
190
+
191
+ analytical, numerical = gradcheck(fn, [a], eps=1e-4)
192
+
193
+ assert torch.allclose(analytical[0], numerical[0], atol=tol)
194
+ print("\t-", Group, "Passed tovec grad test")
195
+
196
+
197
+ def test_fromvec_grad(Group, device='cuda', tol=1e-6):
198
+
199
+ def fn(a):
200
+ if Group == SO3:
201
+ a = a / a.norm(dim=-1, keepdim=True)
202
+
203
+ elif Group == RxSO3:
204
+ q, s = a.split([4, 1], dim=-1)
205
+ q = q / q.norm(dim=-1, keepdim=True)
206
+ a = torch.cat([q, s.exp()], dim=-1)
207
+
208
+ elif Group == SE3:
209
+ t, q = a.split([3, 4], dim=-1)
210
+ q = q / q.norm(dim=-1, keepdim=True)
211
+ a = torch.cat([t, q], dim=-1)
212
+
213
+ elif Group == Sim3:
214
+ t, q, s = a.split([3, 4, 1], dim=-1)
215
+ q = q / q.norm(dim=-1, keepdim=True)
216
+ a = torch.cat([t, q, s.exp()], dim=-1)
217
+
218
+ return Group.InitFromVec(a).vec()
219
+
220
+ D = Group.embedded_dim
221
+ a = torch.randn(1, 2, D, requires_grad=True, device=device).double()
222
+
223
+ analytical, numerical = gradcheck(fn, [a], eps=1e-4)
224
+
225
+ assert torch.allclose(analytical[0], numerical[0], atol=tol)
226
+ print("\t-", Group, "Passed fromvec grad test")
227
+
228
+
229
+
230
+ def scale(device='cuda'):
231
+
232
+ def fn(a, s):
233
+ X = SE3.exp(a)
234
+ X.scale(s)
235
+ return X.log()
236
+
237
+ s = torch.rand(1, requires_grad=True, device=device).double()
238
+ a = torch.randn(1, 6, requires_grad=True, device=device).double()
239
+
240
+ analytical, numerical = gradcheck(fn, [a, s], eps=1e-3)
241
+ print(analytical[1])
242
+ print(numerical[1])
243
+
244
+
245
+ assert torch.allclose(analytical[0], numerical[0], atol=1e-8)
246
+ assert torch.allclose(analytical[1], numerical[1], atol=1e-8)
247
+
248
+ print("\t-", "Passed se3-to-sim3 test")
249
+
250
+
251
+ if __name__ == '__main__':
252
+
253
+
254
+ print("Testing lietorch forward pass (CPU) ...")
255
+ for Group in [SO3, RxSO3, SE3, Sim3]:
256
+ test_exp_log(Group, device='cpu')
257
+ test_inv(Group, device='cpu')
258
+ test_adj(Group, device='cpu')
259
+ test_act(Group, device='cpu')
260
+
261
+ print("Testing lietorch backward pass (CPU)...")
262
+ for Group in [SO3, RxSO3, SE3, Sim3]:
263
+ if Group == Sim3:
264
+ tol = 1e-3
265
+ else:
266
+ tol = 1e-8
267
+
268
+ test_exp_log_grad(Group, device='cpu', tol=tol)
269
+ test_inv_log_grad(Group, device='cpu', tol=tol)
270
+ test_adj_grad(Group, device='cpu')
271
+ test_adjT_grad(Group, device='cpu')
272
+ test_act_grad(Group, device='cpu')
273
+ test_matrix_grad(Group, device='cpu')
274
+ extract_translation_grad(Group, device='cpu')
275
+ test_vec_grad(Group, device='cpu')
276
+ test_fromvec_grad(Group, device='cpu')
277
+
278
+ print("Testing lietorch forward pass (GPU) ...")
279
+ for Group in [SO3, RxSO3, SE3, Sim3]:
280
+ test_exp_log(Group, device='cuda')
281
+ test_inv(Group, device='cuda')
282
+ test_adj(Group, device='cuda')
283
+ test_act(Group, device='cuda')
284
+
285
+ print("Testing lietorch backward pass (GPU)...")
286
+ for Group in [SO3, RxSO3, SE3, Sim3]:
287
+ if Group == Sim3:
288
+ tol = 1e-3
289
+ else:
290
+ tol = 1e-8
291
+
292
+ test_exp_log_grad(Group, device='cuda', tol=tol)
293
+ test_inv_log_grad(Group, device='cuda', tol=tol)
294
+ test_adj_grad(Group, device='cuda')
295
+ test_adjT_grad(Group, device='cuda')
296
+ test_act_grad(Group, device='cuda')
297
+ test_matrix_grad(Group, device='cuda')
298
+ extract_translation_grad(Group, device='cuda')
299
+ test_vec_grad(Group, device='cuda')
300
+ test_fromvec_grad(Group, device='cuda')
301
+
302
+
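run_tests.py imports gradcheck as a sibling module, so it is intended to be launched from inside mini_dpvo/lietorch/ with the compiled lietorch extension on the path; the __main__ block then sweeps every group on CPU and GPU. A sketch of running only a subset while iterating, assuming that working directory and a successful build:

# from inside mini_dpvo/lietorch/:  python run_tests.py   (full CPU + GPU sweep)
# or, to exercise a single group on CPU only:
import run_tests                       # sibling import, same convention as gradcheck above
from lietorch import SE3

run_tests.test_exp_log(SE3, device='cpu')
run_tests.test_inv(SE3, device='cpu')
run_tests.test_act(SE3, device='cpu')
run_tests.test_act_grad(SE3, device='cpu')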
mini_dpvo/lietorch/src/lietorch.cpp ADDED
@@ -0,0 +1,317 @@
1
+ #include <torch/extension.h>
2
+ #include <vector>
3
+ #include "lietorch_gpu.h"
4
+ #include "lietorch_cpu.h"
5
+
6
+
7
+ #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
8
+
9
+
10
+ /* Interface for cuda and c++ group operations
11
+
12
+ enum group_t { SO3=1, SE3=2, Sim3=3 };
13
+ X, Y, Z: (uppercase) Lie Group Elements
14
+ a, b, c: (lowercase) Lie Algebra Elements
15
+ */
16
+
17
+ // Unary operations
18
+ torch::Tensor expm(int group_index, torch::Tensor a) {
19
+ CHECK_CONTIGUOUS(a);
20
+ if (a.device().type() == torch::DeviceType::CPU) {
21
+ return exp_forward_cpu(group_index, a);
22
+
23
+ } else if (a.device().type() == torch::DeviceType::CUDA) {
24
+ return exp_forward_gpu(group_index, a);
25
+ }
26
+
27
+ return a;
28
+ }
29
+
30
+ std::vector<torch::Tensor> expm_backward(int group_index, torch::Tensor grad, torch::Tensor a) {
31
+ CHECK_CONTIGUOUS(a);
32
+ CHECK_CONTIGUOUS(grad);
33
+ if (a.device().type() == torch::DeviceType::CPU) {
34
+ return exp_backward_cpu(group_index, grad, a);
35
+
36
+ } else if (a.device().type() == torch::DeviceType::CUDA) {
37
+ return exp_backward_gpu(group_index, grad, a);
38
+ }
39
+
40
+ return {};
41
+ }
42
+
43
+ torch::Tensor logm(int group_index, torch::Tensor X) {
44
+ CHECK_CONTIGUOUS(X);
45
+ if (X.device().type() == torch::DeviceType::CPU) {
46
+ return log_forward_cpu(group_index, X);
47
+
48
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
49
+ return log_forward_gpu(group_index, X);
50
+ }
51
+
52
+ return X;
53
+ }
54
+
55
+ std::vector<torch::Tensor> logm_backward(int group_index, torch::Tensor grad, torch::Tensor X) {
56
+ CHECK_CONTIGUOUS(X);
57
+ CHECK_CONTIGUOUS(grad);
58
+
59
+ if (X.device().type() == torch::DeviceType::CPU) {
60
+ return log_backward_cpu(group_index, grad, X);
61
+
62
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
63
+ return log_backward_gpu(group_index, grad, X);
64
+ }
65
+
66
+ return {};
67
+ }
68
+
69
+ torch::Tensor inv(int group_index, torch::Tensor X) {
70
+ CHECK_CONTIGUOUS(X);
71
+
72
+ if (X.device().type() == torch::DeviceType::CPU) {
73
+ return inv_forward_cpu(group_index, X);
74
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
75
+ return inv_forward_gpu(group_index, X);
76
+ }
77
+
78
+ return X;
79
+ }
80
+
81
+ std::vector<torch::Tensor> inv_backward(int group_index, torch::Tensor grad, torch::Tensor X) {
82
+ CHECK_CONTIGUOUS(X);
83
+ CHECK_CONTIGUOUS(grad);
84
+
85
+ if (X.device().type() == torch::DeviceType::CPU) {
86
+ return inv_backward_cpu(group_index, grad, X);
87
+
88
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
89
+ return inv_backward_gpu(group_index, grad, X);
90
+ }
91
+
92
+ return {};
93
+ }
94
+
95
+ // Binary operations
96
+
97
+ torch::Tensor mul(int group_index, torch::Tensor X, torch::Tensor Y) {
98
+ CHECK_CONTIGUOUS(X);
99
+ CHECK_CONTIGUOUS(Y);
100
+
101
+ if (X.device().type() == torch::DeviceType::CPU) {
102
+ return mul_forward_cpu(group_index, X, Y);
103
+
104
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
105
+ return mul_forward_gpu(group_index, X, Y);
106
+ }
107
+
108
+ return X;
109
+ }
110
+
111
+ std::vector<torch::Tensor> mul_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor Y) {
112
+ CHECK_CONTIGUOUS(X);
113
+ CHECK_CONTIGUOUS(Y);
114
+ CHECK_CONTIGUOUS(grad);
115
+
116
+ if (X.device().type() == torch::DeviceType::CPU) {
117
+ return mul_backward_cpu(group_index, grad, X, Y);
118
+
119
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
120
+ return mul_backward_gpu(group_index, grad, X, Y);
121
+ }
122
+
123
+ return {};
124
+ }
125
+
126
+ torch::Tensor adj(int group_index, torch::Tensor X, torch::Tensor a) {
127
+ CHECK_CONTIGUOUS(X);
128
+ CHECK_CONTIGUOUS(a);
129
+
130
+ if (X.device().type() == torch::DeviceType::CPU) {
131
+ return adj_forward_cpu(group_index, X, a);
132
+
133
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
134
+ return adj_forward_gpu(group_index, X, a);
135
+ }
136
+
137
+ return X;
138
+ }
139
+
140
+ std::vector<torch::Tensor> adj_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor a) {
141
+ CHECK_CONTIGUOUS(X);
142
+ CHECK_CONTIGUOUS(a);
143
+ CHECK_CONTIGUOUS(grad);
144
+
145
+ if (X.device().type() == torch::DeviceType::CPU) {
146
+ return adj_backward_cpu(group_index, grad, X, a);
147
+
148
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
149
+ return adj_backward_gpu(group_index, grad, X, a);
150
+ }
151
+
152
+ return {};
153
+ }
154
+
155
+ torch::Tensor adjT(int group_index, torch::Tensor X, torch::Tensor a) {
156
+ CHECK_CONTIGUOUS(X);
157
+ CHECK_CONTIGUOUS(a);
158
+
159
+ if (X.device().type() == torch::DeviceType::CPU) {
160
+ return adjT_forward_cpu(group_index, X, a);
161
+
162
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
163
+ return adjT_forward_gpu(group_index, X, a);
164
+ }
165
+
166
+ return X;
167
+ }
168
+
169
+ std::vector<torch::Tensor> adjT_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor a) {
170
+ CHECK_CONTIGUOUS(X);
171
+ CHECK_CONTIGUOUS(a);
172
+ CHECK_CONTIGUOUS(grad);
173
+
174
+ if (X.device().type() == torch::DeviceType::CPU) {
175
+ return adjT_backward_cpu(group_index, grad, X, a);
176
+
177
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
178
+ return adjT_backward_gpu(group_index, grad, X, a);
179
+ }
180
+
181
+ return {};
182
+ }
183
+
184
+
185
+ torch::Tensor act(int group_index, torch::Tensor X, torch::Tensor p) {
186
+ CHECK_CONTIGUOUS(X);
187
+ CHECK_CONTIGUOUS(p);
188
+
189
+ if (X.device().type() == torch::DeviceType::CPU) {
190
+ return act_forward_cpu(group_index, X, p);
191
+
192
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
193
+ return act_forward_gpu(group_index, X, p);
194
+ }
195
+
196
+ return X;
197
+ }
198
+
199
+ std::vector<torch::Tensor> act_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor p) {
200
+ CHECK_CONTIGUOUS(X);
201
+ CHECK_CONTIGUOUS(p);
202
+ CHECK_CONTIGUOUS(grad);
203
+
204
+ if (X.device().type() == torch::DeviceType::CPU) {
205
+ return act_backward_cpu(group_index, grad, X, p);
206
+
207
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
208
+ return act_backward_gpu(group_index, grad, X, p);
209
+ }
210
+
211
+ return {};
212
+ }
213
+
214
+ torch::Tensor act4(int group_index, torch::Tensor X, torch::Tensor p) {
215
+ CHECK_CONTIGUOUS(X);
216
+ CHECK_CONTIGUOUS(p);
217
+
218
+ if (X.device().type() == torch::DeviceType::CPU) {
219
+ return act4_forward_cpu(group_index, X, p);
220
+
221
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
222
+ return act4_forward_gpu(group_index, X, p);
223
+ }
224
+
225
+ return X;
226
+ }
227
+
228
+ std::vector<torch::Tensor> act4_backward(int group_index, torch::Tensor grad, torch::Tensor X, torch::Tensor p) {
229
+ CHECK_CONTIGUOUS(X);
230
+ CHECK_CONTIGUOUS(p);
231
+ CHECK_CONTIGUOUS(grad);
232
+
233
+ if (X.device().type() == torch::DeviceType::CPU) {
234
+ return act4_backward_cpu(group_index, grad, X, p);
235
+
236
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
237
+ return act4_backward_gpu(group_index, grad, X, p);
238
+ }
239
+
240
+ return {};
241
+ }
242
+
243
+
244
+ torch::Tensor projector(int group_index, torch::Tensor X) {
245
+ CHECK_CONTIGUOUS(X);
246
+
247
+ if (X.device().type() == torch::DeviceType::CPU) {
248
+ return orthogonal_projector_cpu(group_index, X);
249
+
250
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
251
+ return orthogonal_projector_gpu(group_index, X);
252
+ }
253
+
254
+ return X;
255
+ }
256
+
257
+
258
+ torch::Tensor as_matrix(int group_index, torch::Tensor X) {
259
+ CHECK_CONTIGUOUS(X);
260
+
261
+ if (X.device().type() == torch::DeviceType::CPU) {
262
+ return as_matrix_forward_cpu(group_index, X);
263
+
264
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
265
+ return as_matrix_forward_gpu(group_index, X);
266
+ }
267
+
268
+ return X;
269
+ }
270
+
271
+ torch::Tensor Jinv(int group_index, torch::Tensor X, torch::Tensor a) {
272
+ CHECK_CONTIGUOUS(X);
273
+ CHECK_CONTIGUOUS(a);
274
+
275
+ if (X.device().type() == torch::DeviceType::CPU) {
276
+ return jleft_forward_cpu(group_index, X, a);
277
+
278
+ } else if (X.device().type() == torch::DeviceType::CUDA) {
279
+ return jleft_forward_gpu(group_index, X, a);
280
+ }
281
+
282
+ return a;
283
+ }
284
+
285
+ // {exp, log, inv, mul, adj, adjT, act, act4} forward/backward bindings
286
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
287
+ m.def("expm", &expm, "exp map forward");
288
+ m.def("expm_backward", &expm_backward, "exp map backward");
289
+
290
+ m.def("logm", &logm, "log map forward");
291
+ m.def("logm_backward", &logm_backward, "log map backward");
292
+
293
+ m.def("inv", &inv, "inverse operator");
294
+ m.def("inv_backward", &inv_backward, "inverse operator backward");
295
+
296
+ m.def("mul", &mul, "group operator");
297
+ m.def("mul_backward", &mul_backward, "group operator backward");
298
+
299
+ m.def("adj", &adj, "adjoint operator");
300
+ m.def("adj_backward", &adj_backward, "adjoint operator backward");
301
+
302
+ m.def("adjT", &adjT, "transposed adjoint operator");
303
+ m.def("adjT_backward", &adjT_backward, "transposed adjoint operator backward");
304
+
305
+ m.def("act", &act, "action on point");
306
+ m.def("act_backward", &act_backward, "action on point backward");
307
+
308
+ m.def("act4", &act4, "action on homogeneous point");
309
+ m.def("act4_backward", &act4_backward, "action on homogeneous point backward");
310
+
311
+ // functions with no gradient
312
+ m.def("as_matrix", &as_matrix, "convert to matrix");
313
+ m.def("projector", &projector, "orthogonal projection matrix");
314
+ m.def("Jinv", &Jinv, "left inverse jacobian operator");
315
+
316
+ };
317
+
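lietorch.cpp only checks contiguity and dispatches on device and an integer group id; the Python group classes are expected to call these bindings with the flat embedded tensors. A rough illustration of that calling convention, where the extension module name is an assumption (TORCH_EXTENSION_NAME is set by the build, not by this file) and the group id comes from the group_t comment above:

import torch
import lietorch_backends                      # hypothetical name for the compiled extension

SE3_ID = 2                                    # per the comment: SO3=1, SE3=2, Sim3=3
a = torch.zeros(1, 6, dtype=torch.float64)    # se3 tangent [tau, phi]

X = lietorch_backends.expm(SE3_ID, a.contiguous())    # 7-dim embedding [t, qx, qy, qz, qw]
b = lietorch_backends.logm(SE3_ID, X.contiguous())    # back to the 6-dim tangent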
mini_dpvo/lietorch/src/lietorch_cpu.cpp ADDED
@@ -0,0 +1,657 @@
1
+
2
+ #include "lietorch_cpu.h"
3
+ #include <Eigen/Dense>
4
+
5
+ #include <iostream>
6
+ #include "common.h"
7
+ #include "dispatch.h"
8
+
9
+ #include "so3.h"
10
+ #include "rxso3.h"
11
+ #include "se3.h"
12
+ #include "sim3.h"
13
+
14
+
15
+ template <typename Group, typename scalar_t>
16
+ void exp_forward_kernel(const scalar_t* a_ptr, scalar_t* X_ptr, int batch_size) {
17
+ // exponential map forward kernel
18
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
19
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
20
+
21
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
22
+ for (int64_t i=start; i<end; i++) {
23
+ Tangent a(a_ptr + i*Group::K);
24
+ Eigen::Map<Data>(X_ptr + i*Group::N) = Group::Exp(a).data();
25
+ }
26
+ });
27
+ }
28
+
29
+ template <typename Group, typename scalar_t>
30
+ void exp_backward_kernel(const scalar_t* grad, const scalar_t* a_ptr, scalar_t* da, int batch_size) {
31
+ // exponential map backward kernel
32
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
33
+ using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
34
+
35
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
36
+ for (int64_t i=start; i<end; i++) {
37
+ Tangent a(a_ptr + i*Group::K);
38
+ Grad dX(grad + i*Group::N);
39
+ Eigen::Map<Grad>(da + i*Group::K) = dX * Group::left_jacobian(a);
40
+ }
41
+ });
42
+ }
43
+
44
+ template <typename Group, typename scalar_t>
45
+ void log_forward_kernel(const scalar_t* X_ptr, scalar_t* a_ptr, int batch_size) {
46
+ // logarithm map forward kernel
47
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
48
+
49
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
50
+ for (int64_t i=start; i<end; i++) {
51
+ Tangent a = Group(X_ptr + i*Group::N).Log();
52
+ Eigen::Map<Tangent>(a_ptr + i*Group::K) = a;
53
+ }
54
+ });
55
+ }
56
+
57
+ template <typename Group, typename scalar_t>
58
+ void log_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, scalar_t* dX, int batch_size) {
59
+ // logarithm map backward kernel
60
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
61
+ using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
62
+
63
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
64
+ for (int64_t i=start; i<end; i++) {
65
+ Tangent a = Group(X_ptr + i*Group::N).Log();
66
+ Grad da(grad + i*Group::K);
67
+ Eigen::Map<Grad>(dX + i*Group::N) = da * Group::left_jacobian_inverse(a);
68
+ }
69
+ });
70
+ }
71
+
72
+ template <typename Group, typename scalar_t>
73
+ void inv_forward_kernel(const scalar_t* X_ptr, scalar_t* Y_ptr, int batch_size) {
74
+ // group inverse forward kernel
75
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
76
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
77
+
78
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
79
+ for (int64_t i=start; i<end; i++) {
80
+ Group X(X_ptr + i*Group::N);
81
+ Eigen::Map<Data>(Y_ptr + i*Group::N) = X.inv().data();
82
+ }
83
+ });
84
+ }
85
+
86
+ template <typename Group, typename scalar_t>
87
+ void inv_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, scalar_t *dX, int batch_size) {
88
+ // group inverse backward kernel
89
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
90
+ using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
91
+
92
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
93
+ for (int64_t i=start; i<end; i++) {
94
+ Group Y = Group(X_ptr + i*Group::N).inv();
95
+ Grad dY(grad + i*Group::N);
96
+ Eigen::Map<Grad>(dX + i*Group::N) = -dY * Y.Adj();
97
+ }
98
+ });
99
+ }
100
+
101
+ template <typename Group, typename scalar_t>
102
+ void mul_forward_kernel(const scalar_t* X_ptr, const scalar_t* Y_ptr, scalar_t* Z_ptr, int batch_size) {
103
+ // group multiplication forward kernel
104
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
105
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
106
+
107
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
108
+ for (int64_t i=start; i<end; i++) {
109
+ Group Z = Group(X_ptr + i*Group::N) * Group(Y_ptr + i*Group::N);
110
+ Eigen::Map<Data>(Z_ptr + i*Group::N) = Z.data();
111
+ }
112
+ });
113
+ }
114
+
115
+ template <class Group, typename scalar_t>
116
+ void mul_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* Y_ptr, scalar_t* dX, scalar_t* dY, int batch_size) {
117
+ // group multiplication backward kernel
118
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
119
+ using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
120
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
121
+
122
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
123
+ for (int64_t i=start; i<end; i++) {
124
+ Grad dZ(grad + i*Group::N);
125
+ Group X(X_ptr + i*Group::N);
126
+ Eigen::Map<Grad>(dX + i*Group::N) = dZ;
127
+ Eigen::Map<Grad>(dY + i*Group::N) = dZ * X.Adj();
128
+ }
129
+ });
130
+ }
131
+
132
+ template <typename Group, typename scalar_t>
133
+ void adj_forward_kernel(const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* b_ptr, int batch_size) {
134
+ // adjoint forward kernel
135
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
136
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
137
+
138
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
139
+ for (int64_t i=start; i<end; i++) {
140
+ Group X(X_ptr + i*Group::N);
141
+ Tangent a(a_ptr + i*Group::K);
142
+ Eigen::Map<Tangent>(b_ptr + i*Group::K) = X.Adj(a);
143
+ }
144
+ });
145
+ }
146
+
147
+ template <typename Group, typename scalar_t>
148
+ void adj_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* dX, scalar_t* da, int batch_size) {
149
+ // adjoint backward kernel
150
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
151
+ using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
152
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
153
+
154
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
155
+ for (int64_t i=start; i<end; i++) {
156
+ Group X(X_ptr + i*Group::N);
157
+ Grad db(grad + i*Group::K);
158
+
159
+ Tangent a(a_ptr + i*Group::K);
160
+ Tangent b = X.Adj() * a;
161
+
162
+ Eigen::Map<Grad>(da + i*Group::K) = db * X.Adj();
163
+ Eigen::Map<Grad>(dX + i*Group::N) = -db * Group::adj(b);
164
+ }
165
+ });
166
+ }
167
+
168
+ template <typename Group, typename scalar_t>
169
+ void adjT_forward_kernel(const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* b_ptr, int batch_size) {
170
+ // adjoint forward kernel
171
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
172
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
173
+
174
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
175
+ for (int64_t i=start; i<end; i++) {
176
+ Group X(X_ptr + i*Group::N);
177
+ Tangent a(a_ptr + i*Group::K);
178
+ Eigen::Map<Tangent>(b_ptr + i*Group::K) = X.AdjT(a);
179
+ }
180
+ });
181
+ }
182
+
183
+ template <typename Group, typename scalar_t>
184
+ void adjT_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* dX, scalar_t* da, int batch_size) {
185
+ // adjoint backward kernel
186
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
187
+ using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
188
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
189
+
190
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
191
+ for (int64_t i=start; i<end; i++) {
192
+ Group X(X_ptr + i*Group::N);
193
+ Tangent db(grad + i*Group::K);
194
+ Grad a(a_ptr + i*Group::K);
195
+
196
+ Eigen::Map<Tangent>(da + i*Group::K) = X.Adj(db);
197
+ Eigen::Map<Grad>(dX + i*Group::N) = -a * Group::adj(X.Adj(db));
198
+ }
199
+ });
200
+ }
201
+
202
+
203
+ template <typename Group, typename scalar_t>
204
+ void act_forward_kernel(const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* q_ptr, int batch_size) {
205
+ // action on point forward kernel
206
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
207
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
208
+ using Point = Eigen::Matrix<scalar_t,3,1>;
209
+
210
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
211
+ for (int64_t i=start; i<end; i++) {
212
+ Group X(X_ptr + i*Group::N);
213
+ Point p(p_ptr + i*3);
214
+ Eigen::Map<Point>(q_ptr + i*3) = X * p;
215
+ }
216
+ });
217
+ }
218
+
219
+ template <typename Group, typename scalar_t>
220
+ void act_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* dX, scalar_t* dp, int batch_size) {
221
+ // action on point backward kernel
222
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
223
+ using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
224
+ using Point = Eigen::Matrix<scalar_t,3,1>;
225
+ using PointGrad = Eigen::Matrix<scalar_t,1,3>;
226
+ using Transformation = Eigen::Matrix<scalar_t,4,4>;
227
+
228
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
229
+ for (int64_t i=start; i<end; i++) {
230
+ Group X(X_ptr + i*Group::N);
231
+ Point p(p_ptr + i*3);
232
+ PointGrad dq(grad + i*3);
233
+
234
+ Eigen::Map<PointGrad>(dp + i*3) = dq * X.Matrix().template block<3,3>(0,0);
235
+ Eigen::Map<Grad>(dX + i*Group::N) = dq * Group::act_jacobian(X*p);
236
+ }
237
+ });
238
+ }
239
+
240
+
241
+ // template <typename Group, typename scalar_t>
242
+ // void tovec_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, scalar_t* dX, int batch_size) {
243
+ // // group inverse forward kernel
244
+ // using Data = Eigen::Matrix<scalar_t,Group::N,1>;
245
+ // using Grad = Eigen::Matrix<scalar_t,1,Group::N>;
246
+
247
+ // at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
248
+ // for (int64_t i=start; i<end; i++) {
249
+ // Group X(X_ptr + i*Group::N);
250
+ // Grad g(grad + i*Group::N);
251
+ // Eigen::Map<Grad>(dX + i*Group::N) = g * X.vec_jacobian();
252
+ // }
253
+ // });
254
+ // }
255
+
256
+ // template <typename Group, typename scalar_t>
257
+ // void fromvec_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, scalar_t* dX, int batch_size) {
258
+ // // group inverse forward kernel
259
+ // using Data = Eigen::Matrix<scalar_t,Group::N,1>;
260
+ // using Grad = Eigen::Matrix<scalar_t,1,Group::N>;
261
+
262
+ // at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
263
+ // for (int64_t i=start; i<end; i++) {
264
+ // Group X(X_ptr + i*Group::N);
265
+ // Grad g(grad + i*Group::N);
266
+ // Eigen::Map<Grad>(dX + i*Group::N) = g * X.vec_jacobian();
267
+ // }
268
+ // });
269
+ // }
270
+
271
+
272
+ template <typename Group, typename scalar_t>
273
+ void act4_forward_kernel(const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* q_ptr, int batch_size) {
274
+ // action on homogeneous point forward kernel
275
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
276
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
277
+ using Point = Eigen::Matrix<scalar_t,4,1>;
278
+
279
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
280
+ for (int64_t i=start; i<end; i++) {
281
+ Group X(X_ptr + i*Group::N);
282
+ Point p(p_ptr + i*4);
283
+ Eigen::Map<Point>(q_ptr + i*4) = X.act4(p);
284
+ }
285
+ });
286
+ }
287
+
288
+ template <typename Group, typename scalar_t>
289
+ void act4_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* dX, scalar_t* dp, int batch_size) {
290
+ // action on homogeneous point backward kernel
291
+
292
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
293
+ using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
294
+ using Point = Eigen::Matrix<scalar_t,4,1>;
295
+ using PointGrad = Eigen::Matrix<scalar_t,1,4>;
296
+ using Transformation = Eigen::Matrix<scalar_t,4,4>;
297
+
298
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
299
+ for (int64_t i=start; i<end; i++) {
300
+ Group X(X_ptr + i*Group::N);
301
+ Point p(p_ptr + i*4);
302
+ PointGrad dq(grad + i*4);
303
+
304
+ Eigen::Map<PointGrad>(dp + i*4) = dq * X.Matrix4x4();
305
+ const Point q = X.act4(p);
306
+ Eigen::Map<Grad>(dX + i*Group::N) = dq * Group::act4_jacobian(q);
307
+ }
308
+ });
309
+ }
310
+
311
+ template <typename Group, typename scalar_t>
312
+ void as_matrix_forward_kernel(const scalar_t* X_ptr, scalar_t* T_ptr, int batch_size) {
313
+ // convert to 4x4 matrix representation
314
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
315
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
316
+ using Matrix4 = Eigen::Matrix<scalar_t,4,4,Eigen::RowMajor>;
317
+
318
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
319
+ for (int64_t i=start; i<end; i++) {
320
+ Group X(X_ptr + i*Group::N);
321
+ Eigen::Map<Matrix4>(T_ptr + i*16) = X.Matrix4x4();
322
+ }
323
+ });
324
+ }
325
+
326
+ template <typename Group, typename scalar_t>
327
+ void orthogonal_projector_kernel(const scalar_t* X_ptr, scalar_t* P_ptr, int batch_size) {
328
+ // orthogonal projection matrix
329
+ using Proj = Eigen::Matrix<scalar_t,Group::N,Group::N,Eigen::RowMajor>;
330
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
331
+ for (int64_t i=start; i<end; i++) {
332
+ Group X(X_ptr + i*Group::N);
333
+ Eigen::Map<Proj>(P_ptr + i*Group::N*Group::N) = X.orthogonal_projector();
334
+ }
335
+ });
336
+ }
337
+
338
+ template <typename Group, typename scalar_t>
339
+ void jleft_forward_kernel(const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* b_ptr, int batch_size) {
340
+ // left-jacobian inverse action
341
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
342
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
343
+
344
+ at::parallel_for(0, batch_size, 1, [&](int64_t start, int64_t end) {
345
+ for (int64_t i=start; i<end; i++) {
346
+ Group X(X_ptr + i*Group::N);
347
+ Tangent a(a_ptr + i*Group::K);
348
+ Tangent b = Group::left_jacobian_inverse(X.Log()) * a;
349
+ Eigen::Map<Tangent>(b_ptr + i*Group::K) = b;
350
+ }
351
+ });
352
+ }
353
+
354
+ // unary operations
355
+
356
+ torch::Tensor exp_forward_cpu(int group_id, torch::Tensor a) {
357
+ int batch_size = a.size(0);
358
+ torch::Tensor X;
359
+
360
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.type(), "exp_forward_kernel", ([&] {
361
+ X = torch::zeros({batch_size, group_t::N}, a.options());
362
+ exp_forward_kernel<group_t, scalar_t>(
363
+ a.data_ptr<scalar_t>(),
364
+ X.data_ptr<scalar_t>(),
365
+ batch_size);
366
+ }));
367
+
368
+ return X;
369
+ }
370
+
371
+ std::vector<torch::Tensor> exp_backward_cpu(int group_id, torch::Tensor grad, torch::Tensor a) {
372
+ int batch_size = a.size(0);
373
+ torch::Tensor da = torch::zeros(a.sizes(), grad.options());
374
+
375
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.type(), "exp_backward_kernel", ([&] {
376
+ exp_backward_kernel<group_t, scalar_t>(
377
+ grad.data_ptr<scalar_t>(),
378
+ a.data_ptr<scalar_t>(),
379
+ da.data_ptr<scalar_t>(),
380
+ batch_size);
381
+ }));
382
+
383
+ return {da};
384
+ }
385
+
386
+ torch::Tensor log_forward_cpu(int group_id, torch::Tensor X) {
387
+ int batch_size = X.size(0);
388
+ torch::Tensor a;
389
+
390
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "log_forward_kernel", ([&] {
391
+ a = torch::zeros({batch_size, group_t::K}, X.options());
392
+ log_forward_kernel<group_t, scalar_t>(
393
+ X.data_ptr<scalar_t>(),
394
+ a.data_ptr<scalar_t>(),
395
+ batch_size);
396
+ }));
397
+
398
+ return a;
399
+ }
400
+
401
+ std::vector<torch::Tensor> log_backward_cpu(int group_id, torch::Tensor grad, torch::Tensor X) {
402
+ int batch_size = X.size(0);
403
+ torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
404
+
405
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "log_backward_kernel", ([&] {
406
+ log_backward_kernel<group_t, scalar_t>(
407
+ grad.data_ptr<scalar_t>(),
408
+ X.data_ptr<scalar_t>(),
409
+ dX.data_ptr<scalar_t>(),
410
+ batch_size);
411
+ }));
412
+
413
+ return {dX};
414
+ }
415
+
416
+ torch::Tensor inv_forward_cpu(int group_id, torch::Tensor X) {
417
+ int batch_size = X.size(0);
418
+ torch::Tensor Y = torch::zeros_like(X);
419
+
420
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "inv_forward_kernel", ([&] {
421
+ inv_forward_kernel<group_t, scalar_t>(
422
+ X.data_ptr<scalar_t>(),
423
+ Y.data_ptr<scalar_t>(),
424
+ batch_size);
425
+ }));
426
+
427
+ return Y;
428
+ }
429
+
430
+ std::vector<torch::Tensor> inv_backward_cpu(int group_id, torch::Tensor grad, torch::Tensor X) {
431
+ int batch_size = X.size(0);
432
+ torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
433
+
434
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "inv_backward_kernel", ([&] {
435
+ inv_backward_kernel<group_t, scalar_t>(
436
+ grad.data_ptr<scalar_t>(),
437
+ X.data_ptr<scalar_t>(),
438
+ dX.data_ptr<scalar_t>(),
439
+ batch_size);
440
+ }));
441
+
442
+ return {dX};
443
+ }
444
+
445
+ // binary operations
446
+ torch::Tensor mul_forward_cpu(int group_id, torch::Tensor X, torch::Tensor Y) {
447
+ int batch_size = X.size(0);
448
+ torch::Tensor Z = torch::zeros_like(X);
449
+
450
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "mul_forward_kernel", ([&] {
451
+ mul_forward_kernel<group_t, scalar_t>(
452
+ X.data_ptr<scalar_t>(),
453
+ Y.data_ptr<scalar_t>(),
454
+ Z.data_ptr<scalar_t>(),
455
+ batch_size);
456
+ }));
457
+
458
+ return Z;
459
+ }
460
+
461
+ std::vector<torch::Tensor> mul_backward_cpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor Y) {
462
+ int batch_size = X.size(0);
463
+ torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
464
+ torch::Tensor dY = torch::zeros(Y.sizes(), grad.options());
465
+
466
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "mul_backward_kernel", ([&] {
467
+ mul_backward_kernel<group_t, scalar_t>(
468
+ grad.data_ptr<scalar_t>(),
469
+ X.data_ptr<scalar_t>(),
470
+ Y.data_ptr<scalar_t>(),
471
+ dX.data_ptr<scalar_t>(),
472
+ dY.data_ptr<scalar_t>(),
473
+ batch_size);
474
+ }));
475
+
476
+ return {dX, dY};
477
+ }
478
+
479
+ torch::Tensor adj_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a) {
480
+ int batch_size = X.size(0);
481
+ torch::Tensor b = torch::zeros(a.sizes(), a.options());
482
+
483
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adj_forward_kernel", ([&] {
484
+ adj_forward_kernel<group_t, scalar_t>(
485
+ X.data_ptr<scalar_t>(),
486
+ a.data_ptr<scalar_t>(),
487
+ b.data_ptr<scalar_t>(),
488
+ batch_size);
489
+ }));
490
+
491
+ return b;
492
+ }
493
+
494
+ std::vector<torch::Tensor> adj_backward_cpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor a) {
495
+ int batch_size = X.size(0);
496
+ torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
497
+ torch::Tensor da = torch::zeros(a.sizes(), grad.options());
498
+
499
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adj_backward_kernel", ([&] {
500
+ adj_backward_kernel<group_t, scalar_t>(
501
+ grad.data_ptr<scalar_t>(),
502
+ X.data_ptr<scalar_t>(),
503
+ a.data_ptr<scalar_t>(),
504
+ dX.data_ptr<scalar_t>(),
505
+ da.data_ptr<scalar_t>(),
506
+ batch_size);
507
+ }));
508
+
509
+ return {dX, da};
510
+ }
511
+
512
+
513
+ torch::Tensor adjT_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a) {
514
+ int batch_size = X.size(0);
515
+ torch::Tensor b = torch::zeros(a.sizes(), a.options());
516
+
517
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adjT_forward_kernel", ([&] {
518
+ adjT_forward_kernel<group_t, scalar_t>(
519
+ X.data_ptr<scalar_t>(),
520
+ a.data_ptr<scalar_t>(),
521
+ b.data_ptr<scalar_t>(),
522
+ batch_size);
523
+ }));
524
+
525
+ return b;
526
+ }
527
+
528
+ std::vector<torch::Tensor> adjT_backward_cpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor a) {
529
+ int batch_size = X.size(0);
530
+ torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
531
+ torch::Tensor da = torch::zeros(a.sizes(), grad.options());
532
+
533
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adjT_backward_kernel", ([&] {
534
+ adjT_backward_kernel<group_t, scalar_t>(
535
+ grad.data_ptr<scalar_t>(),
536
+ X.data_ptr<scalar_t>(),
537
+ a.data_ptr<scalar_t>(),
538
+ dX.data_ptr<scalar_t>(),
539
+ da.data_ptr<scalar_t>(),
540
+ batch_size);
541
+ }));
542
+
543
+ return {dX, da};
544
+ }
545
+
546
+
547
+ torch::Tensor act_forward_cpu(int group_id, torch::Tensor X, torch::Tensor p) {
548
+ int batch_size = X.size(0);
549
+ torch::Tensor q = torch::zeros(p.sizes(), p.options());
550
+
551
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act_forward_kernel", ([&] {
552
+ act_forward_kernel<group_t, scalar_t>(
553
+ X.data_ptr<scalar_t>(),
554
+ p.data_ptr<scalar_t>(),
555
+ q.data_ptr<scalar_t>(),
556
+ batch_size);
557
+ }));
558
+
559
+ return q;
560
+ }
561
+
562
+ std::vector<torch::Tensor> act_backward_cpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor p) {
563
+ int batch_size = X.size(0);
564
+ torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
565
+ torch::Tensor dp = torch::zeros(p.sizes(), grad.options());
566
+
567
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act_backward_kernel", ([&] {
568
+ act_backward_kernel<group_t, scalar_t>(
569
+ grad.data_ptr<scalar_t>(),
570
+ X.data_ptr<scalar_t>(),
571
+ p.data_ptr<scalar_t>(),
572
+ dX.data_ptr<scalar_t>(),
573
+ dp.data_ptr<scalar_t>(),
574
+ batch_size);
575
+ }));
576
+
577
+ return {dX, dp};
578
+ }
579
+
580
+
581
+ torch::Tensor act4_forward_cpu(int group_id, torch::Tensor X, torch::Tensor p) {
582
+ int batch_size = X.size(0);
583
+ torch::Tensor q = torch::zeros(p.sizes(), p.options());
584
+
585
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act4_forward_kernel", ([&] {
586
+ act4_forward_kernel<group_t, scalar_t>(
587
+ X.data_ptr<scalar_t>(),
588
+ p.data_ptr<scalar_t>(),
589
+ q.data_ptr<scalar_t>(),
590
+ batch_size);
591
+ }));
592
+
593
+ return q;
594
+ }
595
+
596
+ std::vector<torch::Tensor> act4_backward_cpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor p) {
597
+ int batch_size = X.size(0);
598
+ torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
599
+ torch::Tensor dp = torch::zeros(p.sizes(), grad.options());
600
+
601
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act4_backward_kernel", ([&] {
602
+ act4_backward_kernel<group_t, scalar_t>(
603
+ grad.data_ptr<scalar_t>(),
604
+ X.data_ptr<scalar_t>(),
605
+ p.data_ptr<scalar_t>(),
606
+ dX.data_ptr<scalar_t>(),
607
+ dp.data_ptr<scalar_t>(),
608
+ batch_size);
609
+ }));
610
+
611
+ return {dX, dp};
612
+ }
613
+
614
+
615
+ torch::Tensor as_matrix_forward_cpu(int group_id, torch::Tensor X) {
616
+ int batch_size = X.size(0);
617
+ torch::Tensor T4x4 = torch::zeros({X.size(0), 4, 4}, X.options());
618
+
619
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "as_matrix_forward_kernel", ([&] {
620
+ as_matrix_forward_kernel<group_t, scalar_t>(
621
+ X.data_ptr<scalar_t>(),
622
+ T4x4.data_ptr<scalar_t>(),
623
+ batch_size);
624
+ }));
625
+
626
+ return T4x4;
627
+ }
628
+
629
+
630
+ torch::Tensor orthogonal_projector_cpu(int group_id, torch::Tensor X) {
631
+ int batch_size = X.size(0);
632
+ torch::Tensor P;
633
+
634
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "orthogonal_projector_kernel", ([&] {
635
+ P = torch::zeros({X.size(0), group_t::N, group_t::N}, X.options());
636
+ orthogonal_projector_kernel<group_t, scalar_t>(X.data_ptr<scalar_t>(), P.data_ptr<scalar_t>(), batch_size);
637
+ }));
638
+
639
+ return P;
640
+ }
641
+
642
+
643
+
644
+ torch::Tensor jleft_forward_cpu(int group_id, torch::Tensor X, torch::Tensor a) {
645
+ int batch_size = X.size(0);
646
+ torch::Tensor b = torch::zeros(a.sizes(), a.options());
647
+
648
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "jleft_forward_kernel", ([&] {
649
+ jleft_forward_kernel<group_t, scalar_t>(
650
+ X.data_ptr<scalar_t>(),
651
+ a.data_ptr<scalar_t>(),
652
+ b.data_ptr<scalar_t>(),
653
+ batch_size);
654
+ }));
655
+
656
+ return b;
657
+ }
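
These CPU bindings are what the Python wrappers in mini_dpvo/lietorch/groups.py dispatch to when tensors live on the CPU. The sketch below round-trips exp/log through the functions defined above; it assumes the SE3 wrapper follows the upstream lietorch API (SE3.exp, .log) and is illustrative only, not part of this commit:

    import torch
    from mini_dpvo.lietorch import SE3

    # random twists (tx, ty, tz, rx, ry, rz); float64 keeps the round trip tight
    xi = 0.1 * torch.randn(8, 6, dtype=torch.float64, requires_grad=True)

    X = SE3.exp(xi)        # CPU tensors dispatch to exp_forward_cpu above
    xi_back = X.log()      # log_forward_cpu
    assert torch.allclose(xi, xi_back, atol=1e-6)

    # gradients flow through the *_backward_cpu functions defined above
    xi_back.sum().backward()
    print(xi.grad.shape)   # torch.Size([8, 6])
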
mini_dpvo/lietorch/src/lietorch_gpu.cu ADDED
@@ -0,0 +1,601 @@
1
+
2
+ #include "lietorch_gpu.h"
3
+ #include <Eigen/Dense>
4
+
5
+ #include "common.h"
6
+ #include "dispatch.h"
7
+
8
+ #include "so3.h"
9
+ #include "rxso3.h"
10
+ #include "se3.h"
11
+ #include "sim3.h"
12
+
13
+ #define GPU_1D_KERNEL_LOOP(i, n) \
14
+ for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i<n; i += blockDim.x * gridDim.x)
15
+
16
+ #define NUM_THREADS 256
17
+ #define NUM_BLOCKS(batch_size) ((batch_size + NUM_THREADS - 1) / NUM_THREADS)
18
+
19
+
20
+ template <typename Group, typename scalar_t>
21
+ __global__ void exp_forward_kernel(const scalar_t* a_ptr, scalar_t* X_ptr, int num_threads) {
22
+ // exponential map forward kernel
23
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
24
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
25
+
26
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
27
+ Tangent a(a_ptr + i*Group::K);
28
+ Eigen::Map<Data>(X_ptr + i*Group::N) = Group::Exp(a).data();
29
+ }
30
+ }
31
+
32
+ template <typename Group, typename scalar_t>
33
+ __global__ void exp_backward_kernel(const scalar_t* grad, const scalar_t* a_ptr, scalar_t* da, int num_threads) {
34
+ // exponential map backward kernel
35
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
36
+ using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
37
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
38
+
39
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
40
+ Tangent a(a_ptr + i*Group::K);
41
+ Grad dX(grad + i*Group::N);
42
+ Eigen::Map<Grad>(da + i*Group::K) = dX * Group::left_jacobian(a);
43
+ }
44
+ }
45
+
46
+ template <typename Group, typename scalar_t>
47
+ __global__ void log_forward_kernel(const scalar_t* X_ptr, scalar_t* a_ptr, int num_threads) {
48
+ // logarithm map forward kernel
49
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
50
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
51
+
52
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
53
+ Tangent a = Group(X_ptr + i*Group::N).Log();
54
+ Eigen::Map<Tangent>(a_ptr + i*Group::K) = a;
55
+ }
56
+ }
57
+
58
+ template <typename Group, typename scalar_t>
59
+ __global__ void log_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, scalar_t* dX, int num_threads) {
60
+ // logarithm map backward kernel
61
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
62
+ using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
63
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
64
+
65
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
66
+ Tangent a = Group(X_ptr + i*Group::N).Log();
67
+ Grad da(grad + i*Group::K);
68
+ Eigen::Map<Grad>(dX + i*Group::N) = da * Group::left_jacobian_inverse(a);
69
+ }
70
+ }
71
+
72
+ template <typename Group, typename scalar_t>
73
+ __global__ void inv_forward_kernel(const scalar_t* X_ptr, scalar_t* Y_ptr, int num_threads) {
74
+ // group inverse forward kernel
75
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
76
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
77
+
78
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
79
+ Group X(X_ptr + i*Group::N);
80
+ Eigen::Map<Data>(Y_ptr + i*Group::N) = X.inv().data();
81
+ }
82
+ }
83
+
84
+
85
+ template <typename Group, typename scalar_t>
86
+ __global__ void inv_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, scalar_t *dX, int num_threads) {
87
+ // group inverse backward kernel
88
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
89
+ using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
90
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
91
+
92
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
93
+ Group Y = Group(X_ptr + i*Group::N).inv();
94
+ Grad dY(grad + i*Group::N);
95
+ Eigen::Map<Grad>(dX + i*Group::N) = -dY * Y.Adj();
96
+ }
97
+ }
98
+
99
+
100
+ template <typename Group, typename scalar_t>
101
+ __global__ void mul_forward_kernel(const scalar_t* X_ptr, const scalar_t* Y_ptr, scalar_t* Z_ptr, int num_threads) {
102
+ // group multiplication forward kernel
103
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
104
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
105
+
106
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
107
+ Group Z = Group(X_ptr + i*Group::N) * Group(Y_ptr + i*Group::N);
108
+ Eigen::Map<Data>(Z_ptr + i*Group::N) = Z.data();
109
+ }
110
+ }
111
+
112
+ template <class Group, typename scalar_t>
113
+ __global__ void mul_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* Y_ptr, scalar_t* dX, scalar_t* dY, int num_threads) {
114
+ // group multiplication backward kernel
115
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
116
+ using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
117
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
118
+
119
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
120
+ Grad dZ(grad + i*Group::N);
121
+ Group X(X_ptr + i*Group::N);
122
+ Eigen::Map<Grad>(dX + i*Group::N) = dZ;
123
+ Eigen::Map<Grad>(dY + i*Group::N) = dZ * X.Adj();
124
+ }
125
+ }
126
+
127
+ template <typename Group, typename scalar_t>
128
+ __global__ void adj_forward_kernel(const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* b_ptr, int num_threads) {
129
+ // adjoint forward kernel
130
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
131
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
132
+
133
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
134
+ Group X(X_ptr + i*Group::N);
135
+ Tangent a(a_ptr + i*Group::K);
136
+ Eigen::Map<Tangent>(b_ptr + i*Group::K) = X.Adj(a);
137
+ }
138
+ }
139
+
140
+ template <typename Group, typename scalar_t>
141
+ __global__ void adj_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* dX, scalar_t* da, int num_threads) {
142
+ // adjoint backward kernel
143
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
144
+ using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
145
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
146
+
147
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
148
+ Group X(X_ptr + i*Group::N);
149
+ Grad db(grad + i*Group::K);
150
+
151
+ Tangent a(a_ptr + i*Group::K);
152
+ Tangent b = X.Adj() * a;
153
+
154
+ Eigen::Map<Grad>(da + i*Group::K) = db * X.Adj();
155
+ Eigen::Map<Grad>(dX + i*Group::N) = -db * Group::adj(b);
156
+ }
157
+ }
158
+
159
+
160
+ template <typename Group, typename scalar_t>
161
+ __global__ void adjT_forward_kernel(const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* b_ptr, int num_threads) {
162
+ // adjoint forward kernel
163
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
164
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
165
+
166
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
167
+ Group X(X_ptr + i*Group::N);
168
+ Tangent a(a_ptr + i*Group::K);
169
+ Eigen::Map<Tangent>(b_ptr + i*Group::K) = X.AdjT(a);
170
+ }
171
+ }
172
+
173
+ template <typename Group, typename scalar_t>
174
+ __global__ void adjT_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* dX, scalar_t* da, int num_threads) {
175
+ // adjoint backward kernel
176
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
177
+ using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
178
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
179
+
180
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
181
+ Group X(X_ptr + i*Group::N);
182
+ Tangent db(grad + i*Group::K);
183
+ Grad a(a_ptr + i*Group::K);
184
+
185
+ Eigen::Map<Tangent>(da + i*Group::K) = X.Adj(db);
186
+ Eigen::Map<Grad>(dX + i*Group::N) = -a * Group::adj(X.Adj(db));
187
+ }
188
+ }
189
+
190
+ template <typename Group, typename scalar_t>
191
+ __global__ void act_forward_kernel(const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* q_ptr, int num_threads) {
192
+ // action on point forward kernel
193
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
194
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
195
+ using Point = Eigen::Matrix<scalar_t,3,1>;
196
+
197
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
198
+ Group X(X_ptr + i*Group::N);
199
+ Point p(p_ptr + i*3);
200
+ Eigen::Map<Point>(q_ptr + i*3) = X * p;
201
+ }
202
+ }
203
+
204
+ template <typename Group, typename scalar_t>
205
+ __global__ void act_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* dX, scalar_t* dp, int num_threads) {
206
+ // action on point backward kernel
207
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
208
+ using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
209
+ using Point = Eigen::Matrix<scalar_t,3,1>;
210
+ using PointGrad = Eigen::Matrix<scalar_t,1,3>;
211
+ using Transformation = Eigen::Matrix<scalar_t,4,4>;
212
+
213
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
214
+ Group X(X_ptr + i*Group::N);
215
+ Point p(p_ptr + i*3);
216
+ PointGrad dq(grad + i*3);
217
+
218
+ Eigen::Map<PointGrad>(dp + i*3) = dq * X.Matrix4x4().block<3,3>(0,0);
219
+ Eigen::Map<Grad>(dX + i*Group::N) = dq * Group::act_jacobian(X*p);
220
+ }
221
+ }
222
+
223
+
224
+ template <typename Group, typename scalar_t>
225
+ __global__ void act4_forward_kernel(const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* q_ptr, int num_threads) {
226
+ // action on homogeneous point forward kernel
227
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
228
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
229
+ using Point = Eigen::Matrix<scalar_t,4,1>;
230
+
231
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
232
+ Group X(X_ptr + i*Group::N);
233
+ Point p(p_ptr + i*4);
234
+ Eigen::Map<Point>(q_ptr + i*4) = X.act4(p);
235
+ }
236
+ }
237
+
238
+ template <typename Group, typename scalar_t>
239
+ __global__ void act4_backward_kernel(const scalar_t* grad, const scalar_t* X_ptr, const scalar_t* p_ptr, scalar_t* dX, scalar_t* dp, int num_threads) {
240
+ // action on homogeneous point backward kernel
241
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
242
+ using Grad = Eigen::Matrix<scalar_t,1,Group::K>;
243
+ using Point = Eigen::Matrix<scalar_t,4,1>;
244
+ using PointGrad = Eigen::Matrix<scalar_t,1,4>;
245
+ using Transformation = Eigen::Matrix<scalar_t,4,4>;
246
+
247
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
248
+ Group X(X_ptr + i*Group::N);
249
+ Point p(p_ptr + i*4);
250
+ PointGrad dq(grad + i*4);
251
+
252
+ Eigen::Map<PointGrad>(dp + i*4) = dq * X.Matrix4x4();
253
+ const Point q = X.act4(p);
254
+ Eigen::Map<Grad>(dX + i*Group::N) = dq * Group::act4_jacobian(q);
255
+ }
256
+ }
257
+
258
+ template <typename Group, typename scalar_t>
259
+ __global__ void as_matrix_forward_kernel(const scalar_t* X_ptr, scalar_t* T_ptr, int num_threads) {
260
+ // convert to 4x4 matrix representation
261
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
262
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
263
+ using Matrix4 = Eigen::Matrix<scalar_t,4,4,Eigen::RowMajor>;
264
+
265
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
266
+ Group X(X_ptr + i*Group::N);
267
+ Eigen::Map<Matrix4>(T_ptr + i*16) = X.Matrix4x4();
268
+ }
269
+ }
270
+
271
+ template <typename Group, typename scalar_t>
272
+ __global__ void orthogonal_projector_kernel(const scalar_t* X_ptr, scalar_t* P_ptr, int num_threads) {
273
+ // orthogonal projection matrix
274
+ using Proj = Eigen::Matrix<scalar_t,Group::N,Group::N,Eigen::RowMajor>;
275
+
276
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
277
+ Group X(X_ptr + i*Group::N);
278
+ Eigen::Map<Proj>(P_ptr + i*Group::N*Group::N) = X.orthogonal_projector();
279
+ }
280
+ }
281
+
282
+ template <typename Group, typename scalar_t>
283
+ __global__ void jleft_forward_kernel(const scalar_t* X_ptr, const scalar_t* a_ptr, scalar_t* b_ptr, int num_threads) {
284
+ // left jacobian inverse action
285
+ using Tangent = Eigen::Matrix<scalar_t,Group::K,1>;
286
+ using Data = Eigen::Matrix<scalar_t,Group::N,1>;
287
+
288
+ GPU_1D_KERNEL_LOOP(i, num_threads) {
289
+ Group X(X_ptr + i*Group::N);
290
+ Tangent a(a_ptr + i*Group::K);
291
+ Tangent b = Group::left_jacobian_inverse(X.Log()) * a;
292
+ Eigen::Map<Tangent>(b_ptr + i*Group::K) = b;
293
+ }
294
+ }
295
+
296
+ // unary operations
297
+
298
+ torch::Tensor exp_forward_gpu(int group_id, torch::Tensor a) {
299
+ int batch_size = a.size(0);
300
+ torch::Tensor X;
301
+
302
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.type(), "exp_forward_kernel", ([&] {
303
+ X = torch::zeros({batch_size, group_t::N}, a.options());
304
+ exp_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
305
+ a.data_ptr<scalar_t>(),
306
+ X.data_ptr<scalar_t>(),
307
+ batch_size);
308
+ }));
309
+
310
+ return X;
311
+ }
312
+
313
+ std::vector<torch::Tensor> exp_backward_gpu(int group_id, torch::Tensor grad, torch::Tensor a) {
314
+ int batch_size = a.size(0);
315
+ torch::Tensor da = torch::zeros(a.sizes(), grad.options());
316
+
317
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, a.type(), "exp_backward_kernel", ([&] {
318
+ exp_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
319
+ grad.data_ptr<scalar_t>(),
320
+ a.data_ptr<scalar_t>(),
321
+ da.data_ptr<scalar_t>(),
322
+ batch_size);
323
+ }));
324
+
325
+ return {da};
326
+ }
327
+
328
+ torch::Tensor log_forward_gpu(int group_id, torch::Tensor X) {
329
+ int batch_size = X.size(0);
330
+ torch::Tensor a;
331
+
332
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "log_forward_kernel", ([&] {
333
+ a = torch::zeros({batch_size, group_t::K}, X.options());
334
+ log_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
335
+ X.data_ptr<scalar_t>(),
336
+ a.data_ptr<scalar_t>(),
337
+ batch_size);
338
+ }));
339
+
340
+ return a;
341
+ }
342
+
343
+ std::vector<torch::Tensor> log_backward_gpu(int group_id, torch::Tensor grad, torch::Tensor X) {
344
+ int batch_size = X.size(0);
345
+ torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
346
+
347
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "log_backward_kernel", ([&] {
348
+ log_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
349
+ grad.data_ptr<scalar_t>(),
350
+ X.data_ptr<scalar_t>(),
351
+ dX.data_ptr<scalar_t>(),
352
+ batch_size);
353
+ }));
354
+
355
+ return {dX};
356
+ }
357
+
358
+ torch::Tensor inv_forward_gpu(int group_id, torch::Tensor X) {
359
+ int batch_size = X.size(0);
360
+ torch::Tensor Y = torch::zeros_like(X);
361
+
362
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "inv_forward_kernel", ([&] {
363
+ inv_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
364
+ X.data_ptr<scalar_t>(),
365
+ Y.data_ptr<scalar_t>(),
366
+ batch_size);
367
+ }));
368
+
369
+ return Y;
370
+ }
371
+
372
+ std::vector<torch::Tensor> inv_backward_gpu(int group_id, torch::Tensor grad, torch::Tensor X) {
373
+ int batch_size = X.size(0);
374
+ torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
375
+
376
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "inv_backward_kernel", ([&] {
377
+ inv_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
378
+ grad.data_ptr<scalar_t>(),
379
+ X.data_ptr<scalar_t>(),
380
+ dX.data_ptr<scalar_t>(),
381
+ batch_size);
382
+ }));
383
+
384
+ return {dX};
385
+ }
386
+
387
+ // binary operations
388
+ torch::Tensor mul_forward_gpu(int group_id, torch::Tensor X, torch::Tensor Y) {
389
+ int batch_size = X.size(0);
390
+ torch::Tensor Z = torch::zeros_like(X);
391
+
392
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "mul_forward_kernel", ([&] {
393
+ mul_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
394
+ X.data_ptr<scalar_t>(),
395
+ Y.data_ptr<scalar_t>(),
396
+ Z.data_ptr<scalar_t>(),
397
+ batch_size);
398
+ }));
399
+
400
+ return Z;
401
+ }
402
+
403
+ std::vector<torch::Tensor> mul_backward_gpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor Y) {
404
+ int batch_size = X.size(0);
405
+ torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
406
+ torch::Tensor dY = torch::zeros(Y.sizes(), grad.options());
407
+
408
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "mul_backward_kernel", ([&] {
409
+ mul_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
410
+ grad.data_ptr<scalar_t>(),
411
+ X.data_ptr<scalar_t>(),
412
+ Y.data_ptr<scalar_t>(),
413
+ dX.data_ptr<scalar_t>(),
414
+ dY.data_ptr<scalar_t>(),
415
+ batch_size);
416
+ }));
417
+
418
+ return {dX, dY};
419
+ }
420
+
421
+ torch::Tensor adj_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a) {
422
+ int batch_size = X.size(0);
423
+ torch::Tensor b = torch::zeros(a.sizes(), a.options());
424
+
425
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adj_forward_kernel", ([&] {
426
+ adj_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
427
+ X.data_ptr<scalar_t>(),
428
+ a.data_ptr<scalar_t>(),
429
+ b.data_ptr<scalar_t>(),
430
+ batch_size);
431
+ }));
432
+
433
+ return b;
434
+ }
435
+
436
+ std::vector<torch::Tensor> adj_backward_gpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor a) {
437
+ int batch_size = X.size(0);
438
+ torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
439
+ torch::Tensor da = torch::zeros(a.sizes(), grad.options());
440
+
441
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adj_backward_kernel", ([&] {
442
+ adj_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
443
+ grad.data_ptr<scalar_t>(),
444
+ X.data_ptr<scalar_t>(),
445
+ a.data_ptr<scalar_t>(),
446
+ dX.data_ptr<scalar_t>(),
447
+ da.data_ptr<scalar_t>(),
448
+ batch_size);
449
+ }));
450
+
451
+ return {dX, da};
452
+ }
453
+
454
+
455
+ torch::Tensor adjT_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a) {
456
+ int batch_size = X.size(0);
457
+ torch::Tensor b = torch::zeros(a.sizes(), a.options());
458
+
459
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adjT_forward_kernel", ([&] {
460
+ adjT_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
461
+ X.data_ptr<scalar_t>(),
462
+ a.data_ptr<scalar_t>(),
463
+ b.data_ptr<scalar_t>(),
464
+ batch_size);
465
+ }));
466
+
467
+ return b;
468
+ }
469
+
470
+ std::vector<torch::Tensor> adjT_backward_gpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor a) {
471
+ int batch_size = X.size(0);
472
+ torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
473
+ torch::Tensor da = torch::zeros(a.sizes(), grad.options());
474
+
475
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "adjT_backward_kernel", ([&] {
476
+ adjT_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
477
+ grad.data_ptr<scalar_t>(),
478
+ X.data_ptr<scalar_t>(),
479
+ a.data_ptr<scalar_t>(),
480
+ dX.data_ptr<scalar_t>(),
481
+ da.data_ptr<scalar_t>(),
482
+ batch_size);
483
+ }));
484
+
485
+ return {dX, da};
486
+ }
487
+
488
+
489
+
490
+ torch::Tensor act_forward_gpu(int group_id, torch::Tensor X, torch::Tensor p) {
491
+ int batch_size = X.size(0);
492
+ torch::Tensor q = torch::zeros(p.sizes(), p.options());
493
+
494
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act_forward_kernel", ([&] {
495
+ act_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
496
+ X.data_ptr<scalar_t>(),
497
+ p.data_ptr<scalar_t>(),
498
+ q.data_ptr<scalar_t>(),
499
+ batch_size);
500
+ }));
501
+
502
+ return q;
503
+ }
504
+
505
+ std::vector<torch::Tensor> act_backward_gpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor p) {
506
+ int batch_size = X.size(0);
507
+ torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
508
+ torch::Tensor dp = torch::zeros(p.sizes(), grad.options());
509
+
510
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act_backward_kernel", ([&] {
511
+ act_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
512
+ grad.data_ptr<scalar_t>(),
513
+ X.data_ptr<scalar_t>(),
514
+ p.data_ptr<scalar_t>(),
515
+ dX.data_ptr<scalar_t>(),
516
+ dp.data_ptr<scalar_t>(),
517
+ batch_size);
518
+ }));
519
+
520
+ return {dX, dp};
521
+ }
522
+
523
+ torch::Tensor act4_forward_gpu(int group_id, torch::Tensor X, torch::Tensor p) {
524
+ int batch_size = X.size(0);
525
+ torch::Tensor q = torch::zeros(p.sizes(), p.options());
526
+
527
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act4_forward_kernel", ([&] {
528
+ act4_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
529
+ X.data_ptr<scalar_t>(),
530
+ p.data_ptr<scalar_t>(),
531
+ q.data_ptr<scalar_t>(),
532
+ batch_size);
533
+ }));
534
+
535
+ return q;
536
+ }
537
+
538
+ std::vector<torch::Tensor> act4_backward_gpu(int group_id, torch::Tensor grad, torch::Tensor X, torch::Tensor p) {
539
+ int batch_size = X.size(0);
540
+ torch::Tensor dX = torch::zeros(X.sizes(), grad.options());
541
+ torch::Tensor dp = torch::zeros(p.sizes(), grad.options());
542
+
543
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "act4_backward_kernel", ([&] {
544
+ act4_backward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
545
+ grad.data_ptr<scalar_t>(),
546
+ X.data_ptr<scalar_t>(),
547
+ p.data_ptr<scalar_t>(),
548
+ dX.data_ptr<scalar_t>(),
549
+ dp.data_ptr<scalar_t>(),
550
+ batch_size);
551
+ }));
552
+
553
+ return {dX, dp};
554
+ }
555
+
556
+
557
+ torch::Tensor as_matrix_forward_gpu(int group_id, torch::Tensor X) {
558
+ int batch_size = X.size(0);
559
+ torch::Tensor T4x4 = torch::zeros({X.size(0), 4, 4}, X.options());
560
+
561
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "as_matrix_forward_kernel", ([&] {
562
+ as_matrix_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
563
+ X.data_ptr<scalar_t>(),
564
+ T4x4.data_ptr<scalar_t>(),
565
+ batch_size);
566
+ }));
567
+
568
+ return T4x4;
569
+ }
570
+
571
+
572
+ torch::Tensor orthogonal_projector_gpu(int group_id, torch::Tensor X) {
573
+ int batch_size = X.size(0);
574
+ torch::Tensor P;
575
+
576
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "orthogonal_projector_kernel", ([&] {
577
+ P = torch::zeros({X.size(0), group_t::N, group_t::N}, X.options());
578
+ orthogonal_projector_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
579
+ X.data_ptr<scalar_t>(),
580
+ P.data_ptr<scalar_t>(),
581
+ batch_size);
582
+ }));
583
+
584
+ return P;
585
+ }
586
+
587
+
588
+ torch::Tensor jleft_forward_gpu(int group_id, torch::Tensor X, torch::Tensor a) {
589
+ int batch_size = X.size(0);
590
+ torch::Tensor b = torch::zeros(a.sizes(), a.options());
591
+
592
+ DISPATCH_GROUP_AND_FLOATING_TYPES(group_id, X.type(), "jleft_forward_kernel", ([&] {
593
+ jleft_forward_kernel<group_t, scalar_t><<<NUM_BLOCKS(batch_size), NUM_THREADS>>>(
594
+ X.data_ptr<scalar_t>(),
595
+ a.data_ptr<scalar_t>(),
596
+ b.data_ptr<scalar_t>(),
597
+ batch_size);
598
+ }));
599
+
600
+ return b;
601
+ }
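
The CUDA kernels above mirror the CPU bindings one-to-one: each op launches NUM_BLOCKS(batch_size) blocks of NUM_THREADS=256 threads and walks the batch with a grid-stride loop, so the backend is chosen purely by the device of the input tensors. A hypothetical agreement check between the two paths, assuming a CUDA build of the extension and the upstream-style SE3.exp/.act wrappers:

    import torch
    from mini_dpvo.lietorch import SE3

    if torch.cuda.is_available():
        xi = 0.1 * torch.randn(4096, 6)
        X_cpu = SE3.exp(xi)              # lietorch_cpu.cpp path
        X_gpu = SE3.exp(xi.cuda())       # exp_forward_kernel above
        assert torch.allclose(X_cpu.data, X_gpu.data.cpu(), atol=1e-4)

        # point action: act_forward_kernel, one grid-stride iteration per element
        p = torch.randn(4096, 3, device="cuda")
        q = X_gpu.act(p)
        print(q.shape)                   # torch.Size([4096, 3])
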
mini_dpvo/logger.py ADDED
@@ -0,0 +1,58 @@
1
+
2
+ import torch
3
+ from torch.utils.tensorboard import SummaryWriter
4
+
5
+
6
+ SUM_FREQ = 100
7
+
8
+ class Logger:
9
+ def __init__(self, name, scheduler):
10
+ self.total_steps = 0
11
+ self.running_loss = {}
12
+ self.writer = None
13
+ self.name = name
14
+ self.scheduler = scheduler
15
+
16
+ def _print_training_status(self):
17
+ if self.writer is None:
18
+ self.writer = SummaryWriter("runs/{}".format(self.name))
19
+ print([k for k in self.running_loss])
20
+
21
+ lr = self.scheduler.get_lr().pop()
22
+ metrics_data = [self.running_loss[k]/SUM_FREQ for k in self.running_loss.keys()]
23
+ training_str = "[{:6d}, {:10.7f}] ".format(self.total_steps+1, lr)
24
+ metrics_str = ("{:10.4f}, "*len(metrics_data)).format(*metrics_data)
25
+
26
+ # print the training status
27
+ print(training_str + metrics_str)
28
+
29
+ for key in self.running_loss:
30
+ val = self.running_loss[key] / SUM_FREQ
31
+ self.writer.add_scalar(key, val, self.total_steps)
32
+ self.running_loss[key] = 0.0
33
+
34
+ def push(self, metrics):
35
+
36
+ for key in metrics:
37
+ if key not in self.running_loss:
38
+ self.running_loss[key] = 0.0
39
+
40
+ self.running_loss[key] += metrics[key]
41
+
42
+ if self.total_steps % SUM_FREQ == SUM_FREQ-1:
43
+ self._print_training_status()
44
+ self.running_loss = {}
45
+
46
+ self.total_steps += 1
47
+
48
+ def write_dict(self, results):
49
+ if self.writer is None:
50
+ self.writer = SummaryWriter("runs/{}".format(self.name))
51
+ print([k for k in self.running_loss])
52
+
53
+ for key in results:
54
+ self.writer.add_scalar(key, results[key], self.total_steps)
55
+
56
+ def close(self):
57
+ self.writer.close()
58
+
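
Logger buffers scalar metrics and, every SUM_FREQ=100 calls to push(), prints their running averages and writes them to TensorBoard before resetting. A toy usage sketch; the model, optimizer, and run name are placeholders rather than anything defined in this repository:

    import torch
    from mini_dpvo.logger import Logger

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

    logger = Logger("debug_run", scheduler)
    for step in range(250):
        loss = model(torch.randn(8, 4)).pow(2).mean()
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()
        logger.push({"loss": loss.item()})   # flushed on the 100th and 200th push
    logger.close()
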
mini_dpvo/net.py ADDED
@@ -0,0 +1,270 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from collections import OrderedDict
6
+
7
+ import torch_scatter
8
+ from torch_scatter import scatter_sum
9
+
10
+ from . import fastba
11
+ from . import altcorr
12
+ from . import lietorch
13
+ from .lietorch import SE3
14
+
15
+ from .extractor import BasicEncoder, BasicEncoder4
16
+ from .blocks import GradientClip, GatedResidual, SoftAgg
17
+
18
+ from .utils import *
19
+ from .ba import BA
20
+ from . import projective_ops as pops
21
+
22
+ autocast = torch.cuda.amp.autocast
23
+ import matplotlib.pyplot as plt
24
+
25
+ DIM = 384
26
+
27
+ class Update(nn.Module):
28
+ def __init__(self, p):
29
+ super(Update, self).__init__()
30
+
31
+ self.c1 = nn.Sequential(
32
+ nn.Linear(DIM, DIM),
33
+ nn.ReLU(inplace=True),
34
+ nn.Linear(DIM, DIM))
35
+
36
+ self.c2 = nn.Sequential(
37
+ nn.Linear(DIM, DIM),
38
+ nn.ReLU(inplace=True),
39
+ nn.Linear(DIM, DIM))
40
+
41
+ self.norm = nn.LayerNorm(DIM, eps=1e-3)
42
+
43
+ self.agg_kk = SoftAgg(DIM)
44
+ self.agg_ij = SoftAgg(DIM)
45
+
46
+ self.gru = nn.Sequential(
47
+ nn.LayerNorm(DIM, eps=1e-3),
48
+ GatedResidual(DIM),
49
+ nn.LayerNorm(DIM, eps=1e-3),
50
+ GatedResidual(DIM),
51
+ )
52
+
53
+ self.corr = nn.Sequential(
54
+ nn.Linear(2*49*p*p, DIM),
55
+ nn.ReLU(inplace=True),
56
+ nn.Linear(DIM, DIM),
57
+ nn.LayerNorm(DIM, eps=1e-3),
58
+ nn.ReLU(inplace=True),
59
+ nn.Linear(DIM, DIM),
60
+ )
61
+
62
+ self.d = nn.Sequential(
63
+ nn.ReLU(inplace=False),
64
+ nn.Linear(DIM, 2),
65
+ GradientClip())
66
+
67
+ self.w = nn.Sequential(
68
+ nn.ReLU(inplace=False),
69
+ nn.Linear(DIM, 2),
70
+ GradientClip(),
71
+ nn.Sigmoid())
72
+
73
+
74
+ def forward(self, net, inp, corr, flow, ii, jj, kk):
75
+ """ update operator """
76
+
77
+ net = net + inp + self.corr(corr)
78
+ net = self.norm(net)
79
+
80
+ ix, jx = fastba.neighbors(kk, jj)
81
+ mask_ix = (ix >= 0).float().reshape(1, -1, 1)
82
+ mask_jx = (jx >= 0).float().reshape(1, -1, 1)
83
+
84
+ net = net + self.c1(mask_ix * net[:,ix])
85
+ net = net + self.c2(mask_jx * net[:,jx])
86
+
87
+ net = net + self.agg_kk(net, kk)
88
+ net = net + self.agg_ij(net, ii*12345 + jj)
89
+
90
+ net = self.gru(net)
91
+
92
+ return net, (self.d(net), self.w(net), None)
93
+
94
+
95
+ class Patchifier(nn.Module):
96
+ def __init__(self, patch_size=3):
97
+ super(Patchifier, self).__init__()
98
+ self.patch_size = patch_size
99
+ self.fnet = BasicEncoder4(output_dim=128, norm_fn='instance')
100
+ self.inet = BasicEncoder4(output_dim=DIM, norm_fn='none')
101
+
102
+ def __image_gradient(self, images):
103
+ gray = ((images + 0.5) * (255.0 / 2)).sum(dim=2)
104
+ dx = gray[...,:-1,1:] - gray[...,:-1,:-1]
105
+ dy = gray[...,1:,:-1] - gray[...,:-1,:-1]
106
+ g = torch.sqrt(dx**2 + dy**2)
107
+ g = F.avg_pool2d(g, 4, 4)
108
+ return g
109
+
110
+ def forward(self, images, patches_per_image=80, disps=None, gradient_bias=False, return_color=False):
111
+ """ extract patches from input images """
112
+ fmap = self.fnet(images) / 4.0
113
+ imap = self.inet(images) / 4.0
114
+
115
+ b, n, c, h, w = fmap.shape
116
+ P = self.patch_size
117
+
118
+ # bias patch selection towards regions with high gradient
119
+ if gradient_bias:
120
+ g = self.__image_gradient(images)
121
+ x = torch.randint(1, w-1, size=[n, 3*patches_per_image], device="cuda")
122
+ y = torch.randint(1, h-1, size=[n, 3*patches_per_image], device="cuda")
123
+
124
+ coords = torch.stack([x, y], dim=-1).float()
125
+ g = altcorr.patchify(g[0,:,None], coords, 0).view(n, 3 * patches_per_image)
126
+
127
+ ix = torch.argsort(g, dim=1)
128
+ x = torch.gather(x, 1, ix[:, -patches_per_image:])
129
+ y = torch.gather(y, 1, ix[:, -patches_per_image:])
130
+
131
+ else:
132
+ x = torch.randint(1, w-1, size=[n, patches_per_image], device="cuda")
133
+ y = torch.randint(1, h-1, size=[n, patches_per_image], device="cuda")
134
+
135
+ coords = torch.stack([x, y], dim=-1).float()
136
+ imap = altcorr.patchify(imap[0], coords, 0).view(b, -1, DIM, 1, 1)
137
+ gmap = altcorr.patchify(fmap[0], coords, P//2).view(b, -1, 128, P, P)
138
+
139
+ if return_color:
140
+ clr = altcorr.patchify(images[0], 4*(coords + 0.5), 0).view(b, -1, 3)
141
+
142
+ if disps is None:
143
+ disps = torch.ones(b, n, h, w, device="cuda")
144
+
145
+ grid, _ = coords_grid_with_index(disps, device=fmap.device)
146
+ patches = altcorr.patchify(grid[0], coords, P//2).view(b, -1, 3, P, P)
147
+
148
+ index = torch.arange(n, device="cuda").view(n, 1)
149
+ index = index.repeat(1, patches_per_image).reshape(-1)
150
+
151
+ if return_color:
152
+ return fmap, gmap, imap, patches, index, clr
153
+
154
+ return fmap, gmap, imap, patches, index
155
+
156
+
157
+ class CorrBlock:
158
+ def __init__(self, fmap, gmap, radius=3, dropout=0.2, levels=[1,4]):
159
+ self.dropout = dropout
160
+ self.radius = radius
161
+ self.levels = levels
162
+
163
+ self.gmap = gmap
164
+ self.pyramid = pyramidify(fmap, lvls=levels)
165
+
166
+ def __call__(self, ii, jj, coords):
167
+ corrs = []
168
+ for i in range(len(self.levels)):
169
+ corrs += [ altcorr.corr(self.gmap, self.pyramid[i], coords / self.levels[i], ii, jj, self.radius, self.dropout) ]
170
+ return torch.stack(corrs, -1).view(1, len(ii), -1)
171
+
172
+
173
+ class VONet(nn.Module):
174
+ def __init__(self, use_viewer=False):
175
+ super(VONet, self).__init__()
176
+ self.P = 3
177
+ self.patchify = Patchifier(self.P)
178
+ self.update = Update(self.P)
179
+
180
+ self.DIM = DIM
181
+ self.RES = 4
182
+
183
+
184
+ @autocast(enabled=False)
185
+ def forward(self, images, poses, disps, intrinsics, M=1024, STEPS=12, P=1, structure_only=False, rescale=False):
186
+ """ Estimates SE3 or Sim3 between pair of frames """
187
+
188
+ images = 2 * (images / 255.0) - 0.5
189
+ intrinsics = intrinsics / 4.0
190
+ disps = disps[:, :, 1::4, 1::4].float()
191
+
192
+ fmap, gmap, imap, patches, ix = self.patchify(images, disps=disps)
193
+
194
+ corr_fn = CorrBlock(fmap, gmap)
195
+
196
+ b, N, c, h, w = fmap.shape
197
+ p = self.P
198
+
199
+ patches_gt = patches.clone()
200
+ Ps = poses
201
+
202
+ d = patches[..., 2, p//2, p//2]
203
+ patches = set_depth(patches, torch.rand_like(d))
204
+
205
+ kk, jj = flatmeshgrid(torch.where(ix < 8)[0], torch.arange(0,8, device="cuda"))
206
+ ii = ix[kk]
207
+
208
+ imap = imap.view(b, -1, DIM)
209
+ net = torch.zeros(b, len(kk), DIM, device="cuda", dtype=torch.float)
210
+
211
+ Gs = SE3.IdentityLike(poses)
212
+
213
+ if structure_only:
214
+ Gs.data[:] = poses.data[:]
215
+
216
+ traj = []
217
+ bounds = [-64, -64, w + 64, h + 64]
218
+
219
+ while len(traj) < STEPS:
220
+ Gs = Gs.detach()
221
+ patches = patches.detach()
222
+
223
+ n = ii.max() + 1
224
+ if len(traj) >= 8 and n < images.shape[1]:
225
+ if not structure_only: Gs.data[:,n] = Gs.data[:,n-1]
226
+ kk1, jj1 = flatmeshgrid(torch.where(ix < n)[0], torch.arange(n, n+1, device="cuda"))
227
+ kk2, jj2 = flatmeshgrid(torch.where(ix == n)[0], torch.arange(0, n+1, device="cuda"))
228
+
229
+ ii = torch.cat([ix[kk1], ix[kk2], ii])
230
+ jj = torch.cat([jj1, jj2, jj])
231
+ kk = torch.cat([kk1, kk2, kk])
232
+
233
+ net1 = torch.zeros(b, len(kk1) + len(kk2), DIM, device="cuda")
234
+ net = torch.cat([net1, net], dim=1)
235
+
236
+ if np.random.rand() < 0.1:
237
+ k = (ii != (n - 4)) & (jj != (n - 4))
238
+ ii = ii[k]
239
+ jj = jj[k]
240
+ kk = kk[k]
241
+ net = net[:,k]
242
+
243
+ patches[:,ix==n,2] = torch.median(patches[:,(ix == n-1) | (ix == n-2),2])
244
+ n = ii.max() + 1
245
+
246
+ coords = pops.transform(Gs, patches, intrinsics, ii, jj, kk)
247
+ coords1 = coords.permute(0, 1, 4, 2, 3).contiguous()
248
+
249
+ corr = corr_fn(kk, jj, coords1)
250
+ net, (delta, weight, _) = self.update(net, imap[:,kk], corr, None, ii, jj, kk)
251
+
252
+ lmbda = 1e-4
253
+ target = coords[...,p//2,p//2,:] + delta
254
+
255
+ ep = 10
256
+ for itr in range(2):
257
+ Gs, patches = BA(Gs, patches, intrinsics, target, weight, lmbda, ii, jj, kk,
258
+ bounds, ep=ep, fixedp=1, structure_only=structure_only)
259
+
260
+ kl = torch.as_tensor(0)
261
+ dij = (ii - jj).abs()
262
+ k = (dij > 0) & (dij <= 2)
263
+
264
+ coords = pops.transform(Gs, patches, intrinsics, ii[k], jj[k], kk[k])
265
+ coords_gt, valid, _ = pops.transform(Ps, patches_gt, intrinsics, ii[k], jj[k], kk[k], jacobian=True)
266
+
267
+ traj.append((valid, coords, coords_gt, Gs[:,:n], Ps[:,:n], kl))
268
+
269
+ return traj
270
+
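
VONet.forward runs the training-time pipeline end to end: patchify the frames, build the two-level correlation volume, then alternate the Update GRU with two bundle-adjustment calls for STEPS iterations while new frames join the patch graph, appending one supervision tuple per iteration to traj. The shape-only sketch below is a dummy call, not a verified training script: it assumes the compiled altcorr/fastba extensions are importable, that SE3.Identity exists as in upstream lietorch, and it needs CUDA because Patchifier hard-codes the "cuda" device:

    import torch
    from mini_dpvo.net import VONet
    from mini_dpvo.lietorch import SE3

    B, N, H, W = 1, 15, 384, 512
    images = torch.randint(0, 256, (B, N, 3, H, W), device="cuda").float()
    disps = torch.ones(B, N, H, W, device="cuda")
    intrinsics = torch.tensor([320.0, 320.0, 256.0, 192.0], device="cuda").repeat(B, N, 1)
    poses = SE3.Identity(B, N, device="cuda")   # ground-truth poses would go here in training

    model = VONet().cuda()
    traj = model(images, poses, disps, intrinsics, STEPS=12)
    valid, coords, coords_gt, Gs, Ps, kl = traj[-1]   # per-step supervision tuple
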
mini_dpvo/plot_utils.py ADDED
@@ -0,0 +1,52 @@
1
+ from copy import deepcopy
2
+
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ from evo.core import sync
6
+ from evo.core.trajectory import PoseTrajectory3D
7
+ # from evo.tools import plot
8
+ from pathlib import Path
9
+
10
+
11
+ def make_traj(args) -> PoseTrajectory3D:
12
+ if isinstance(args, tuple):
13
+ traj, tstamps = args
14
+ return PoseTrajectory3D(positions_xyz=traj[:,:3], orientations_quat_wxyz=traj[:,3:], timestamps=tstamps)
15
+ assert isinstance(args, PoseTrajectory3D), type(args)
16
+ return deepcopy(args)
17
+
18
+ def best_plotmode(traj):
19
+ _, i1, i2 = np.argsort(np.var(traj.positions_xyz, axis=0))
20
+ plot_axes = "xyz"[i2] + "xyz"[i1]
21
+ return getattr(plot.PlotMode, plot_axes)
22
+
23
+ def plot_trajectory(pred_traj, gt_traj=None, title="", filename="", align=True, correct_scale=True):
24
+ pred_traj = make_traj(pred_traj)
25
+
26
+ if gt_traj is not None:
27
+ gt_traj = make_traj(gt_traj)
28
+ gt_traj, pred_traj = sync.associate_trajectories(gt_traj, pred_traj)
29
+
30
+ if align:
31
+ pred_traj.align(gt_traj, correct_scale=correct_scale)
32
+
33
+ plot_collection = plot.PlotCollection("PlotCol")
34
+ fig = plt.figure(figsize=(8, 8))
35
+ plot_mode = best_plotmode(gt_traj if (gt_traj is not None) else pred_traj)
36
+ ax = plot.prepare_axis(fig, plot_mode)
37
+ ax.set_title(title)
38
+ if gt_traj is not None:
39
+ plot.traj(ax, plot_mode, gt_traj, '--', 'gray', "Ground Truth")
40
+ plot.traj(ax, plot_mode, pred_traj, '-', 'blue', "Predicted")
41
+ plot_collection.add_figure("traj (error)", fig)
42
+ plot_collection.export(filename, confirm_overwrite=False)
43
+ plt.close(fig=fig)
44
+ print(f"Saved {filename}")
45
+
46
+ def save_trajectory_tum_format(traj, filename):
47
+ traj = make_traj(traj)
48
+ tostr = lambda a: ' '.join(map(str, a))
49
+ with Path(filename).open('w') as f:
50
+ for i in range(traj.num_poses):
51
+ f.write(f"{traj.timestamps[i]} {tostr(traj.positions_xyz[i])} {tostr(traj.orientations_quat_wxyz[i][[1,2,3,0]])}\n")
52
+ print(f"Saved {filename}")
mini_dpvo/projective_ops.py ADDED
@@ -0,0 +1,121 @@
+ import torch
+ import torch.nn.functional as F
+
+ from .lietorch import SE3, Sim3
+
+ MIN_DEPTH = 0.2
+
+ def extract_intrinsics(intrinsics):
+     return intrinsics[...,None,None,:].unbind(dim=-1)
+
+ def coords_grid(ht, wd, **kwargs):
+     y, x = torch.meshgrid(
+         torch.arange(ht).to(**kwargs).float(),
+         torch.arange(wd).to(**kwargs).float())
+
+     return torch.stack([x, y], dim=-1)
+
+
+ def iproj(patches, intrinsics):
+     """ inverse projection """
+     x, y, d = patches.unbind(dim=2)
+     fx, fy, cx, cy = intrinsics[...,None,None].unbind(dim=2)
+
+     i = torch.ones_like(d)
+     xn = (x - cx) / fx
+     yn = (y - cy) / fy
+
+     X = torch.stack([xn, yn, i, d], dim=-1)
+     return X
+
+
+ def proj(X, intrinsics, depth=False):
+     """ projection """
+
+     X, Y, Z, W = X.unbind(dim=-1)
+     fx, fy, cx, cy = intrinsics[...,None,None].unbind(dim=2)
+
+     # d = 0.01 * torch.ones_like(Z)
+     # d[Z > 0.01] = 1.0 / Z[Z > 0.01]
+     # d = torch.ones_like(Z)
+     # d[Z.abs() > 0.1] = 1.0 / Z[Z.abs() > 0.1]
+
+     d = 1.0 / Z.clamp(min=0.1)
+     x = fx * (d * X) + cx
+     y = fy * (d * Y) + cy
+
+     if depth:
+         return torch.stack([x, y, d], dim=-1)
+
+     return torch.stack([x, y], dim=-1)
+
+
+ def transform(poses, patches, intrinsics, ii, jj, kk, depth=False, valid=False, jacobian=False, tonly=False):
+     """ projective transform """
+
+     # backproject
+     X0 = iproj(patches[:,kk], intrinsics[:,ii])
+
+     # transform
+     Gij = poses[:, jj] * poses[:, ii].inv()
+
+     if tonly:
+         Gij[...,3:] = torch.as_tensor([0,0,0,1], device=Gij.device)
+
+     X1 = Gij[:,:,None,None] * X0
+
+     # project
+     x1 = proj(X1, intrinsics[:,jj], depth)
+
+
+     if jacobian:
+         p = X1.shape[2]
+         X, Y, Z, H = X1[...,p//2,p//2,:].unbind(dim=-1)
+         o = torch.zeros_like(H)
+         i = torch.zeros_like(H)
+
+         fx, fy, cx, cy = intrinsics[:,jj].unbind(dim=-1)
+
+         d = torch.zeros_like(Z)
+         d[Z.abs() > 0.2] = 1.0 / Z[Z.abs() > 0.2]
+
+         Ja = torch.stack([
+             H,  o,  o,  o,  Z, -Y,
+             o,  H,  o, -Z,  o,  X,
+             o,  o,  H,  Y, -X,  o,
+             o,  o,  o,  o,  o,  o,
+         ], dim=-1).view(1, len(ii), 4, 6)
+
+         Jp = torch.stack([
+             fx*d,    o, -fx*X*d*d,  o,
+                o, fy*d, -fy*Y*d*d,  o,
+         ], dim=-1).view(1, len(ii), 2, 4)
+
+         Jj = torch.matmul(Jp, Ja)
+         Ji = -Gij[:,:,None].adjT(Jj)
+
+         Jz = torch.matmul(Jp, Gij.matrix()[...,:,3:])
+
+         return x1, (Z > 0.2).float(), (Ji, Jj, Jz)
+
+     if valid:
+         return x1, (X1[...,2] > 0.2).float()
+
+     return x1
+
+ def point_cloud(poses, patches, intrinsics, ix):
+     """ generate point cloud from patches """
+     return poses[:,ix,None,None].inv() * iproj(patches, intrinsics[:,ix])
+
+
+ def flow_mag(poses, patches, intrinsics, ii, jj, kk, beta=0.3):
+     """ weighted flow magnitude, blending full motion and translation-only motion """
+
+     coords0 = transform(poses, patches, intrinsics, ii, ii, kk)
+     coords1 = transform(poses, patches, intrinsics, ii, jj, kk, tonly=False)
+     coords2 = transform(poses, patches, intrinsics, ii, jj, kk, tonly=True)
+
+     flow1 = (coords1 - coords0).norm(dim=-1)
+     flow2 = (coords2 - coords0).norm(dim=-1)
+
+     return beta * flow1 + (1-beta) * flow2
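A quick way to see what `iproj` and `proj` implement is the inverse-depth pinhole round trip: `iproj` maps a patch pixel `(x, y)` with inverse depth `d` to the homogeneous point `((x - cx)/fx, (y - cy)/fy, 1, d)`, and `proj` maps it back via `x = fx * X/Z + cx`, `y = fy * Y/Z + cy`. The sketch below checks that round trip under an identity transform; the tensor sizes, intrinsics values, and import path are assumptions for illustration (and importing the package requires its compiled extensions).

```python
# Hypothetical round-trip check: with no pose change, proj(iproj(p)) should
# reproduce the original pixel coordinates of each patch element.
import torch
from mini_dpvo import projective_ops as pops

B, M, P = 1, 8, 3                                    # batch, patches, patch size (assumed)
intrinsics = torch.tensor([320., 320., 304., 176.]).view(1, 1, 4).repeat(B, M, 1)

x = torch.rand(B, M, P, P) * 600                     # pixel x
y = torch.rand(B, M, P, P) * 350                     # pixel y
d = torch.rand(B, M, P, P) + 0.5                     # inverse depth
patches = torch.stack([x, y, d], dim=2)              # (B, M, 3, P, P)

X = pops.iproj(patches, intrinsics)                  # (B, M, P, P, 4) homogeneous points
uv = pops.proj(X, intrinsics)                        # (B, M, P, P, 2) reprojected pixels
print(torch.allclose(uv[..., 0], x, atol=1e-3),
      torch.allclose(uv[..., 1], y, atol=1e-3))
```
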
mini_dpvo/stream.py ADDED
@@ -0,0 +1,92 @@
+ import cv2
+ import numpy as np
+ from pathlib import Path
+ from itertools import chain
+ from multiprocessing import Queue
+
+
+ def image_stream(
+     queue: Queue, imagedir: str, calib: str, stride: int, skip: int = 0
+ ) -> None:
+     """image generator"""
+
+     calib = np.loadtxt(calib, delimiter=" ")
+     fx, fy, cx, cy = calib[:4]
+
+     K = np.eye(3)
+     K[0, 0] = fx
+     K[0, 2] = cx
+     K[1, 1] = fy
+     K[1, 2] = cy
+
+     img_exts = ["*.png", "*.jpeg", "*.jpg"]
+     image_list = sorted(chain.from_iterable(Path(imagedir).glob(e) for e in img_exts))[
+         skip::stride
+     ]
+
+     for t, imfile in enumerate(image_list):
+         image = cv2.imread(str(imfile))
+         if len(calib) > 4:
+             image = cv2.undistort(image, K, calib[4:])
+
+         if 0:
+             image = cv2.resize(image, None, fx=0.5, fy=0.5)
+             intrinsics = np.array([fx / 2, fy / 2, cx / 2, cy / 2])
+
+         else:
+             intrinsics = np.array([fx, fy, cx, cy])
+
+         h, w, _ = image.shape
+         image = image[: h - h % 16, : w - w % 16]
+
+         queue.put((t, image, intrinsics))
+
+     queue.put((-1, image, intrinsics))
+
+
+ def video_stream(
+     queue: Queue, imagedir: str, calib: str, stride: int, skip: int = 0
+ ) -> None:
+     """video generator"""
+
+     calib = np.loadtxt(calib, delimiter=" ")
+     fx, fy, cx, cy = calib[:4]
+
+     K = np.eye(3)
+     K[0, 0] = fx
+     K[0, 2] = cx
+     K[1, 1] = fy
+     K[1, 2] = cy
+
+     cap = cv2.VideoCapture(imagedir)
+
+     t = 0
+
+     for _ in range(skip):
+         ret, image = cap.read()
+
+     while True:
+         # Capture frame-by-frame
+         for _ in range(stride):
+             ret, image = cap.read()
+             # if frame is read correctly ret is True
+             if not ret:
+                 break
+
+         if not ret:
+             break
+
+         if len(calib) > 4:
+             image = cv2.undistort(image, K, calib[4:])
+
+         image = cv2.resize(image, None, fx=0.5, fy=0.5, interpolation=cv2.INTER_AREA)
+         h, w, _ = image.shape
+         image = image[: h - h % 16, : w - w % 16]
+
+         intrinsics = np.array([fx * 0.5, fy * 0.5, cx * 0.5, cy * 0.5])
+         queue.put((t, image, intrinsics))
+
+         t += 1
+
+     queue.put((-1, image, intrinsics))
+     cap.release()
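Both generators are written to be driven from a separate process and signal end-of-stream by putting `t == -1` on the queue. A minimal consumer sketch follows; the `data/images` and `data/calib.txt` paths are placeholders, and the calibration file is assumed to contain `fx fy cx cy` (optionally followed by distortion coefficients), as read by `np.loadtxt` above.

```python
# Hypothetical consumer sketch: run image_stream in a worker process and read
# frames off the queue until the t == -1 sentinel arrives.
from multiprocessing import Process, Queue
from mini_dpvo.stream import image_stream

if __name__ == "__main__":
    queue = Queue(maxsize=8)
    reader = Process(target=image_stream,
                     args=(queue, "data/images", "data/calib.txt", 2, 0))
    reader.start()

    while True:
        t, image, intrinsics = queue.get()
        if t < 0:                      # sentinel: stream exhausted
            break
        # feed (t, image, intrinsics) to the VO frontend here

    reader.join()
```
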
mini_dpvo/utils.py ADDED
@@ -0,0 +1,92 @@
+ import torch
+ import torch.nn.functional as F
+
+
+ all_times = []
+
+
+ class Timer:
+     def __init__(self, name: str, enabled: bool = True):
+         self.name = name
+         self.enabled = enabled
+
+         if self.enabled:
+             self.start = torch.cuda.Event(enable_timing=True)
+             self.end = torch.cuda.Event(enable_timing=True)
+
+     def __enter__(self):
+         if self.enabled:
+             self.start.record()
+
+     def __exit__(self, type, value, traceback):
+         global all_times
+         if self.enabled:
+             self.end.record()
+             torch.cuda.synchronize()
+
+             elapsed = self.start.elapsed_time(self.end)
+             all_times.append(elapsed)
+             print(f"{self.name}: {elapsed:.2f}ms")
+
+
+ def coords_grid(b, n, h, w, **kwargs):
+     """coordinate grid"""
+     x = torch.arange(0, w, dtype=torch.float, **kwargs)
+     y = torch.arange(0, h, dtype=torch.float, **kwargs)
+     coords = torch.stack(torch.meshgrid(y, x, indexing="ij"))
+     return coords[[1, 0]].view(1, 1, 2, h, w).repeat(b, n, 1, 1, 1)
+
+
+ def coords_grid_with_index(d, **kwargs):
+     """coordinate grid with frame index"""
+     b, n, h, w = d.shape
+     i = torch.ones_like(d)
+     x = torch.arange(0, w, dtype=torch.float, **kwargs)
+     y = torch.arange(0, h, dtype=torch.float, **kwargs)
+
+     y, x = torch.stack(torch.meshgrid(y, x, indexing="ij"))
+     y = y.view(1, 1, h, w).repeat(b, n, 1, 1)
+     x = x.view(1, 1, h, w).repeat(b, n, 1, 1)
+
+     coords = torch.stack([x, y, d], dim=2)
+     index = torch.arange(0, n, dtype=torch.float, **kwargs)
+     index = index.view(1, n, 1, 1, 1).repeat(b, 1, 1, h, w)
+
+     return coords, index
+
+
+ def patchify(x, patch_size=3):
+     """extract patches from video"""
+     b, n, c, h, w = x.shape
+     x = x.view(b * n, c, h, w)
+     y = F.unfold(x, patch_size)
+     y = y.transpose(1, 2)
+     return y.reshape(b, -1, c, patch_size, patch_size)
+
+
+ def pyramidify(fmap, lvls=[1]):
+     """turn fmap into a pyramid"""
+     b, n, c, h, w = fmap.shape
+
+     pyramid = []
+     for lvl in lvls:
+         gmap = F.avg_pool2d(fmap.view(b * n, c, h, w), lvl, stride=lvl)
+         pyramid += [gmap.view(b, n, c, h // lvl, w // lvl)]
+
+     return pyramid
+
+
+ def all_pairs_exclusive(n, **kwargs):
+     ii, jj = torch.meshgrid(torch.arange(n, **kwargs), torch.arange(n, **kwargs))
+     k = ii != jj
+     return ii[k].reshape(-1), jj[k].reshape(-1)
+
+
+ def set_depth(patches, depth):
+     patches[..., 2, :, :] = depth[..., None, None]
+     return patches
+
+
+ def flatmeshgrid(*args, **kwargs):
+     grid = torch.meshgrid(*args, **kwargs)
+     return (x.reshape(-1) for x in grid)
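To make the helpers above concrete, a small sketch timing `patchify` with `Timer` and building a two-level pyramid with `pyramidify`. The tensor shapes are arbitrary, and a CUDA device is assumed because `Timer` relies on `torch.cuda.Event`.

```python
# Hypothetical usage sketch -- shapes are illustrative, CUDA device assumed.
import torch
from mini_dpvo.utils import Timer, patchify, pyramidify

video = torch.randn(1, 4, 3, 64, 96, device="cuda")   # (b, n, c, h, w)

with Timer("patchify", enabled=True):                  # prints elapsed time in ms
    patches = patchify(video, patch_size=3)            # (b, n*(h-2)*(w-2), c, 3, 3)

pyramid = pyramidify(video, lvls=[1, 4])               # full-res and 4x average-pooled maps
print(patches.shape, [p.shape for p in pyramid])
```
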