xiaoyuxi commited on
Commit
c8d9d42
·
0 Parent(s):

Cleaned history, reset to current state

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +36 -0
  2. .gitignore +69 -0
  3. README.md +14 -0
  4. _viz/viz_template.html +1778 -0
  5. app.py +1118 -0
  6. app_3rd/README.md +12 -0
  7. app_3rd/sam_utils/hf_sam_predictor.py +129 -0
  8. app_3rd/sam_utils/inference.py +123 -0
  9. app_3rd/spatrack_utils/infer_track.py +194 -0
  10. config/__init__.py +0 -0
  11. config/magic_infer_moge.yaml +48 -0
  12. examples/backpack.mp4 +3 -0
  13. examples/ball.mp4 +3 -0
  14. examples/basketball.mp4 +3 -0
  15. examples/biker.mp4 +3 -0
  16. examples/cinema_0.mp4 +3 -0
  17. examples/cinema_1.mp4 +3 -0
  18. examples/drifting.mp4 +3 -0
  19. examples/ego_kc1.mp4 +3 -0
  20. examples/ego_teaser.mp4 +3 -0
  21. examples/handwave.mp4 +3 -0
  22. examples/hockey.mp4 +3 -0
  23. examples/ken_block_0.mp4 +3 -0
  24. examples/kiss.mp4 +3 -0
  25. examples/kitchen.mp4 +3 -0
  26. examples/kitchen_egocentric.mp4 +3 -0
  27. examples/pillow.mp4 +3 -0
  28. examples/protein.mp4 +3 -0
  29. examples/pusht.mp4 +3 -0
  30. examples/robot1.mp4 +3 -0
  31. examples/robot2.mp4 +3 -0
  32. examples/robot_3.mp4 +3 -0
  33. examples/robot_unitree.mp4 +3 -0
  34. examples/running.mp4 +3 -0
  35. examples/teleop2.mp4 +3 -0
  36. examples/vertical_place.mp4 +3 -0
  37. models/SpaTrackV2/models/SpaTrack.py +759 -0
  38. models/SpaTrackV2/models/__init__.py +0 -0
  39. models/SpaTrackV2/models/blocks.py +519 -0
  40. models/SpaTrackV2/models/camera_transform.py +248 -0
  41. models/SpaTrackV2/models/depth_refiner/backbone.py +472 -0
  42. models/SpaTrackV2/models/depth_refiner/decode_head.py +619 -0
  43. models/SpaTrackV2/models/depth_refiner/depth_refiner.py +115 -0
  44. models/SpaTrackV2/models/depth_refiner/network.py +429 -0
  45. models/SpaTrackV2/models/depth_refiner/stablilization_attention.py +1187 -0
  46. models/SpaTrackV2/models/depth_refiner/stablizer.py +342 -0
  47. models/SpaTrackV2/models/predictor.py +153 -0
  48. models/SpaTrackV2/models/tracker3D/TrackRefiner.py +1478 -0
  49. models/SpaTrackV2/models/tracker3D/co_tracker/cotracker_base.py +418 -0
  50. models/SpaTrackV2/models/tracker3D/co_tracker/utils.py +929 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ignore the multi media
2
+ checkpoints
3
+ **/checkpoints/
4
+ **/temp/
5
+ temp
6
+ assets_dev
7
+ assets/example0/results
8
+ assets/example0/snowboard.npz
9
+ assets/example1/results
10
+ assets/davis_eval
11
+ assets/*/results
12
+ *gradio*
13
+ #
14
+ models/monoD/zoeDepth/ckpts/*
15
+ models/monoD/depth_anything/ckpts/*
16
+ vis_results
17
+ dist_encrypted
18
+ # remove the dependencies
19
+ deps
20
+
21
+ # filter the __pycache__ files
22
+ __pycache__/
23
+ /**/**/__pycache__
24
+ /**/__pycache__
25
+
26
+ outputs
27
+ scripts/lauch_exp/config
28
+ scripts/lauch_exp/submit_job.log
29
+ scripts/lauch_exp/hydra_output
30
+ scripts/lauch_wulan
31
+ scripts/custom_video
32
+ # ignore the visualizer
33
+ viser
34
+ viser_result
35
+ benchmark/results
36
+ benchmark
37
+
38
+ ossutil_output
39
+
40
+ prev_version
41
+ spat_ceres
42
+ wandb
43
+ *.log
44
+ seg_target.py
45
+
46
+ eval_davis.py
47
+ eval_multiple_gpu.py
48
+ eval_pose_scan.py
49
+ eval_single_gpu.py
50
+
51
+ infer_cam.py
52
+ infer_stream.py
53
+
54
+ *.egg-info/
55
+ **/*.egg-info
56
+
57
+ eval_kinectics.py
58
+ models/SpaTrackV2/datasets
59
+
60
+ scripts
61
+ config/fix_2d.yaml
62
+
63
+ models/SpaTrackV2/datasets
64
+ scripts/
65
+
66
+ models/**/build
67
+ models/**/dist
68
+
69
+ temp_local
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SpatialTrackerV2
3
+ emoji: ⚡️
4
+ colorFrom: yellow
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 5.31.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: Official Space for SpatialTrackerV2
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
_viz/viz_template.html ADDED
@@ -0,0 +1,1778 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>3D Point Cloud Visualizer</title>
7
+ <style>
8
+ :root {
9
+ --primary: #9b59b6; /* Brighter purple for dark mode */
10
+ --primary-light: #3a2e4a;
11
+ --secondary: #a86add;
12
+ --accent: #ff6e6e;
13
+ --bg: #1a1a1a;
14
+ --surface: #2c2c2c;
15
+ --text: #e0e0e0;
16
+ --text-secondary: #a0a0a0;
17
+ --border: #444444;
18
+ --shadow: rgba(0, 0, 0, 0.2);
19
+ --shadow-hover: rgba(0, 0, 0, 0.3);
20
+
21
+ --space-sm: 16px;
22
+ --space-md: 24px;
23
+ --space-lg: 32px;
24
+ }
25
+
26
+ body {
27
+ margin: 0;
28
+ overflow: hidden;
29
+ background: var(--bg);
30
+ color: var(--text);
31
+ font-family: 'Inter', sans-serif;
32
+ -webkit-font-smoothing: antialiased;
33
+ }
34
+
35
+ #canvas-container {
36
+ position: absolute;
37
+ width: 100%;
38
+ height: 100%;
39
+ }
40
+
41
+ #ui-container {
42
+ position: absolute;
43
+ top: 0;
44
+ left: 0;
45
+ width: 100%;
46
+ height: 100%;
47
+ pointer-events: none;
48
+ z-index: 10;
49
+ }
50
+
51
+ #status-bar {
52
+ position: absolute;
53
+ top: 16px;
54
+ left: 16px;
55
+ background: rgba(30, 30, 30, 0.9);
56
+ padding: 8px 16px;
57
+ border-radius: 8px;
58
+ pointer-events: auto;
59
+ box-shadow: 0 4px 6px var(--shadow);
60
+ backdrop-filter: blur(4px);
61
+ border: 1px solid var(--border);
62
+ color: var(--text);
63
+ transition: opacity 0.5s ease, transform 0.5s ease;
64
+ font-weight: 500;
65
+ }
66
+
67
+ #status-bar.hidden {
68
+ opacity: 0;
69
+ transform: translateY(-20px);
70
+ pointer-events: none;
71
+ }
72
+
73
+ #control-panel {
74
+ position: absolute;
75
+ bottom: 16px;
76
+ left: 50%;
77
+ transform: translateX(-50%);
78
+ background: rgba(44, 44, 44, 0.95);
79
+ padding: 6px 8px;
80
+ border-radius: 6px;
81
+ display: flex;
82
+ gap: 8px;
83
+ align-items: center;
84
+ justify-content: space-between;
85
+ pointer-events: auto;
86
+ box-shadow: 0 4px 10px var(--shadow);
87
+ backdrop-filter: blur(4px);
88
+ border: 1px solid var(--border);
89
+ }
90
+
91
+ #timeline {
92
+ width: 150px;
93
+ height: 4px;
94
+ background: rgba(255, 255, 255, 0.1);
95
+ border-radius: 2px;
96
+ position: relative;
97
+ cursor: pointer;
98
+ }
99
+
100
+ #progress {
101
+ position: absolute;
102
+ height: 100%;
103
+ background: var(--primary);
104
+ border-radius: 2px;
105
+ width: 0%;
106
+ }
107
+
108
+ #playback-controls {
109
+ display: flex;
110
+ gap: 4px;
111
+ align-items: center;
112
+ }
113
+
114
+ button {
115
+ background: rgba(255, 255, 255, 0.08);
116
+ border: 1px solid var(--border);
117
+ color: var(--text);
118
+ padding: 4px 6px;
119
+ border-radius: 3px;
120
+ cursor: pointer;
121
+ display: flex;
122
+ align-items: center;
123
+ justify-content: center;
124
+ transition: background 0.2s, transform 0.2s;
125
+ font-family: 'Inter', sans-serif;
126
+ font-weight: 500;
127
+ font-size: 6px;
128
+ }
129
+
130
+ button:hover {
131
+ background: rgba(255, 255, 255, 0.15);
132
+ transform: translateY(-1px);
133
+ }
134
+
135
+ button.active {
136
+ background: var(--primary);
137
+ color: white;
138
+ box-shadow: 0 2px 8px rgba(155, 89, 182, 0.4);
139
+ }
140
+
141
+ select, input {
142
+ background: rgba(255, 255, 255, 0.08);
143
+ border: 1px solid var(--border);
144
+ color: var(--text);
145
+ padding: 4px 6px;
146
+ border-radius: 3px;
147
+ cursor: pointer;
148
+ font-family: 'Inter', sans-serif;
149
+ font-size: 6px;
150
+ }
151
+
152
+ .icon {
153
+ width: 10px;
154
+ height: 10px;
155
+ fill: currentColor;
156
+ }
157
+
158
+ .tooltip {
159
+ position: absolute;
160
+ bottom: 100%;
161
+ left: 50%;
162
+ transform: translateX(-50%);
163
+ background: var(--surface);
164
+ color: var(--text);
165
+ padding: 3px 6px;
166
+ border-radius: 3px;
167
+ font-size: 7px;
168
+ white-space: nowrap;
169
+ margin-bottom: 4px;
170
+ opacity: 0;
171
+ transition: opacity 0.2s;
172
+ pointer-events: none;
173
+ box-shadow: 0 2px 4px var(--shadow);
174
+ border: 1px solid var(--border);
175
+ }
176
+
177
+ button:hover .tooltip {
178
+ opacity: 1;
179
+ }
180
+
181
+ #settings-panel {
182
+ position: absolute;
183
+ top: 16px;
184
+ right: 16px;
185
+ background: rgba(44, 44, 44, 0.98);
186
+ padding: 10px;
187
+ border-radius: 6px;
188
+ width: 195px;
189
+ max-height: calc(100vh - 40px);
190
+ overflow-y: auto;
191
+ pointer-events: auto;
192
+ box-shadow: 0 4px 15px var(--shadow);
193
+ backdrop-filter: blur(4px);
194
+ border: 1px solid var(--border);
195
+ display: block;
196
+ opacity: 1;
197
+ scrollbar-width: thin;
198
+ scrollbar-color: var(--primary-light) transparent;
199
+ transition: transform 0.35s ease-in-out, opacity 0.3s ease-in-out;
200
+ }
201
+
202
+ #settings-panel.is-hidden {
203
+ transform: translateX(calc(100% + 20px));
204
+ opacity: 0;
205
+ pointer-events: none;
206
+ }
207
+
208
+ #settings-panel::-webkit-scrollbar {
209
+ width: 3px;
210
+ }
211
+
212
+ #settings-panel::-webkit-scrollbar-track {
213
+ background: transparent;
214
+ }
215
+
216
+ #settings-panel::-webkit-scrollbar-thumb {
217
+ background-color: var(--primary-light);
218
+ border-radius: 3px;
219
+ }
220
+
221
+ @media (max-height: 700px) {
222
+ #settings-panel {
223
+ max-height: calc(100vh - 40px);
224
+ }
225
+ }
226
+
227
+ @media (max-width: 768px) {
228
+ #control-panel {
229
+ width: 90%;
230
+ flex-wrap: wrap;
231
+ justify-content: center;
232
+ }
233
+
234
+ #timeline {
235
+ width: 100%;
236
+ order: 3;
237
+ margin-top: 10px;
238
+ }
239
+
240
+ #settings-panel {
241
+ width: 140px;
242
+ right: 10px;
243
+ top: 10px;
244
+ max-height: calc(100vh - 20px);
245
+ }
246
+ }
247
+
248
+ .settings-group {
249
+ margin-bottom: 8px;
250
+ }
251
+
252
+ .settings-group h3 {
253
+ margin: 0 0 6px 0;
254
+ font-size: 10px;
255
+ font-weight: 500;
256
+ color: var(--text-secondary);
257
+ }
258
+
259
+ .slider-container {
260
+ display: flex;
261
+ align-items: center;
262
+ gap: 6px;
263
+ width: 100%;
264
+ }
265
+
266
+ .slider-container label {
267
+ min-width: 60px;
268
+ font-size: 10px;
269
+ flex-shrink: 0;
270
+ }
271
+
272
+ input[type="range"] {
273
+ flex: 1;
274
+ height: 2px;
275
+ -webkit-appearance: none;
276
+ background: rgba(255, 255, 255, 0.1);
277
+ border-radius: 1px;
278
+ min-width: 0;
279
+ }
280
+
281
+ input[type="range"]::-webkit-slider-thumb {
282
+ -webkit-appearance: none;
283
+ width: 8px;
284
+ height: 8px;
285
+ border-radius: 50%;
286
+ background: var(--primary);
287
+ cursor: pointer;
288
+ }
289
+
290
+ .toggle-switch {
291
+ position: relative;
292
+ display: inline-block;
293
+ width: 20px;
294
+ height: 10px;
295
+ }
296
+
297
+ .toggle-switch input {
298
+ opacity: 0;
299
+ width: 0;
300
+ height: 0;
301
+ }
302
+
303
+ .toggle-slider {
304
+ position: absolute;
305
+ cursor: pointer;
306
+ top: 0;
307
+ left: 0;
308
+ right: 0;
309
+ bottom: 0;
310
+ background: rgba(255, 255, 255, 0.1);
311
+ transition: .4s;
312
+ border-radius: 10px;
313
+ }
314
+
315
+ .toggle-slider:before {
316
+ position: absolute;
317
+ content: "";
318
+ height: 8px;
319
+ width: 8px;
320
+ left: 1px;
321
+ bottom: 1px;
322
+ background: var(--surface);
323
+ border: 1px solid var(--border);
324
+ transition: .4s;
325
+ border-radius: 50%;
326
+ }
327
+
328
+ input:checked + .toggle-slider {
329
+ background: var(--primary);
330
+ }
331
+
332
+ input:checked + .toggle-slider:before {
333
+ transform: translateX(10px);
334
+ }
335
+
336
+ .checkbox-container {
337
+ display: flex;
338
+ align-items: center;
339
+ gap: 4px;
340
+ margin-bottom: 4px;
341
+ }
342
+
343
+ .checkbox-container label {
344
+ font-size: 10px;
345
+ cursor: pointer;
346
+ }
347
+
348
+ #loading-overlay {
349
+ position: absolute;
350
+ top: 0;
351
+ left: 0;
352
+ width: 100%;
353
+ height: 100%;
354
+ background: var(--bg);
355
+ display: flex;
356
+ flex-direction: column;
357
+ align-items: center;
358
+ justify-content: center;
359
+ z-index: 100;
360
+ transition: opacity 0.5s;
361
+ }
362
+
363
+ #loading-overlay.fade-out {
364
+ opacity: 0;
365
+ pointer-events: none;
366
+ }
367
+
368
+ .spinner {
369
+ width: 50px;
370
+ height: 50px;
371
+ border: 5px solid rgba(155, 89, 182, 0.2);
372
+ border-radius: 50%;
373
+ border-top-color: var(--primary);
374
+ animation: spin 1s ease-in-out infinite;
375
+ margin-bottom: 16px;
376
+ }
377
+
378
+ @keyframes spin {
379
+ to { transform: rotate(360deg); }
380
+ }
381
+
382
+ #loading-text {
383
+ margin-top: 16px;
384
+ font-size: 18px;
385
+ color: var(--text);
386
+ font-weight: 500;
387
+ }
388
+
389
+ #frame-counter {
390
+ color: var(--text-secondary);
391
+ font-size: 7px;
392
+ font-weight: 500;
393
+ min-width: 60px;
394
+ text-align: center;
395
+ padding: 0 4px;
396
+ }
397
+
398
+ .control-btn {
399
+ background: rgba(255, 255, 255, 0.08);
400
+ border: 1px solid var(--border);
401
+ padding: 4px 6px;
402
+ border-radius: 3px;
403
+ cursor: pointer;
404
+ display: flex;
405
+ align-items: center;
406
+ justify-content: center;
407
+ transition: all 0.2s ease;
408
+ font-size: 6px;
409
+ }
410
+
411
+ .control-btn:hover {
412
+ background: rgba(255, 255, 255, 0.15);
413
+ transform: translateY(-1px);
414
+ }
415
+
416
+ .control-btn.active {
417
+ background: var(--primary);
418
+ color: white;
419
+ }
420
+
421
+ .control-btn.active:hover {
422
+ background: var(--primary);
423
+ box-shadow: 0 2px 8px rgba(155, 89, 182, 0.4);
424
+ }
425
+
426
+ #settings-toggle-btn {
427
+ position: relative;
428
+ border-radius: 6px;
429
+ z-index: 20;
430
+ }
431
+
432
+ #settings-toggle-btn.active {
433
+ background: var(--primary);
434
+ color: white;
435
+ }
436
+
437
+ #status-bar,
438
+ #control-panel,
439
+ #settings-panel,
440
+ button,
441
+ input,
442
+ select,
443
+ .toggle-switch {
444
+ pointer-events: auto;
445
+ }
446
+
447
+ h2 {
448
+ font-size: 0.9rem;
449
+ font-weight: 600;
450
+ margin-top: 0;
451
+ margin-bottom: 12px;
452
+ color: var(--primary);
453
+ cursor: move;
454
+ user-select: none;
455
+ display: flex;
456
+ align-items: center;
457
+ }
458
+
459
+ .drag-handle {
460
+ font-size: 10px;
461
+ margin-right: 4px;
462
+ opacity: 0.6;
463
+ }
464
+
465
+ h2:hover .drag-handle {
466
+ opacity: 1;
467
+ }
468
+
469
+ .loading-subtitle {
470
+ font-size: 7px;
471
+ color: var(--text-secondary);
472
+ margin-top: 4px;
473
+ }
474
+
475
+ #reset-view-btn {
476
+ background: var(--primary-light);
477
+ color: var(--primary);
478
+ border: 1px solid rgba(155, 89, 182, 0.2);
479
+ font-weight: 600;
480
+ font-size: 9px;
481
+ padding: 4px 6px;
482
+ transition: all 0.2s;
483
+ }
484
+
485
+ #reset-view-btn:hover {
486
+ background: var(--primary);
487
+ color: white;
488
+ transform: translateY(-2px);
489
+ box-shadow: 0 4px 8px rgba(155, 89, 182, 0.3);
490
+ }
491
+
492
+ #show-settings-btn {
493
+ position: absolute;
494
+ top: 16px;
495
+ right: 16px;
496
+ z-index: 15;
497
+ display: none;
498
+ }
499
+
500
+ #settings-panel.visible {
501
+ display: block;
502
+ opacity: 1;
503
+ animation: slideIn 0.3s ease forwards;
504
+ }
505
+
506
+ @keyframes slideIn {
507
+ from {
508
+ transform: translateY(20px);
509
+ opacity: 0;
510
+ }
511
+ to {
512
+ transform: translateY(0);
513
+ opacity: 1;
514
+ }
515
+ }
516
+
517
+ .dragging {
518
+ opacity: 0.9;
519
+ box-shadow: 0 8px 20px rgba(0, 0, 0, 0.15) !important;
520
+ transition: none !important;
521
+ }
522
+
523
+ /* Tooltip for draggable element */
524
+ .tooltip-drag {
525
+ position: absolute;
526
+ left: 50%;
527
+ transform: translateX(-50%);
528
+ background: var(--primary);
529
+ color: white;
530
+ font-size: 9px;
531
+ padding: 2px 4px;
532
+ border-radius: 2px;
533
+ opacity: 0;
534
+ pointer-events: none;
535
+ transition: opacity 0.3s;
536
+ white-space: nowrap;
537
+ bottom: 100%;
538
+ margin-bottom: 4px;
539
+ }
540
+
541
+ h2:hover .tooltip-drag {
542
+ opacity: 1;
543
+ }
544
+
545
+ .btn-group {
546
+ display: flex;
547
+ margin-top: 8px;
548
+ }
549
+
550
+ #reset-settings-btn {
551
+ background: var(--primary-light);
552
+ color: var(--primary);
553
+ border: 1px solid rgba(155, 89, 182, 0.2);
554
+ font-weight: 600;
555
+ font-size: 9px;
556
+ padding: 4px 6px;
557
+ transition: all 0.2s;
558
+ }
559
+
560
+ #reset-settings-btn:hover {
561
+ background: var(--primary);
562
+ color: white;
563
+ transform: translateY(-2px);
564
+ box-shadow: 0 4px 8px rgba(155, 89, 182, 0.3);
565
+ }
566
+ </style>
567
+ </head>
568
+ <body>
569
+ <link rel="preconnect" href="https://fonts.googleapis.com">
570
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
571
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">
572
+
573
+ <div id="canvas-container"></div>
574
+
575
+ <div id="ui-container">
576
+ <div id="status-bar">Initializing...</div>
577
+
578
+ <div id="control-panel">
579
+ <button id="play-pause-btn" class="control-btn">
580
+ <svg class="icon" viewBox="0 0 24 24">
581
+ <path id="play-icon" d="M8 5v14l11-7z"/>
582
+ <path id="pause-icon" d="M6 19h4V5H6v14zm8-14v14h4V5h-4z" style="display: none;"/>
583
+ </svg>
584
+ <span class="tooltip">Play/Pause</span>
585
+ </button>
586
+
587
+ <div id="timeline">
588
+ <div id="progress"></div>
589
+ </div>
590
+
591
+ <div id="frame-counter">Frame 0 / 0</div>
592
+
593
+ <div id="playback-controls">
594
+ <button id="speed-btn" class="control-btn">1x</button>
595
+ </div>
596
+ </div>
597
+
598
+ <div id="settings-panel">
599
+ <h2>
600
+ <span class="drag-handle">☰</span>
601
+ Visualization Settings
602
+ <button id="hide-settings-btn" class="control-btn" style="margin-left: auto; padding: 2px;" title="Hide Panel">
603
+ <svg class="icon" viewBox="0 0 24 24" style="width: 9px; height: 9px;">
604
+ <path d="M14.59 7.41L18.17 11H4v2h14.17l-3.58 3.59L16 18l6-6-6-6-1.41 1.41z"/>
605
+ </svg>
606
+ </button>
607
+ </h2>
608
+
609
+ <div class="settings-group">
610
+ <h3>Point Cloud</h3>
611
+ <div class="slider-container">
612
+ <label for="point-size">Size</label>
613
+ <input type="range" id="point-size" min="0.005" max="0.1" step="0.005" value="0.03">
614
+ </div>
615
+ <div class="slider-container">
616
+ <label for="point-opacity">Opacity</label>
617
+ <input type="range" id="point-opacity" min="0.1" max="1" step="0.05" value="1">
618
+ </div>
619
+ <div class="slider-container">
620
+ <label for="max-depth">Max Depth</label>
621
+ <input type="range" id="max-depth" min="0.1" max="10" step="0.2" value="100">
622
+ </div>
623
+ </div>
624
+
625
+ <div class="settings-group">
626
+ <h3>Trajectory</h3>
627
+ <div class="checkbox-container">
628
+ <label class="toggle-switch">
629
+ <input type="checkbox" id="show-trajectory" checked>
630
+ <span class="toggle-slider"></span>
631
+ </label>
632
+ <label for="show-trajectory">Show Trajectory</label>
633
+ </div>
634
+ <div class="checkbox-container">
635
+ <label class="toggle-switch">
636
+ <input type="checkbox" id="enable-rich-trail">
637
+ <span class="toggle-slider"></span>
638
+ </label>
639
+ <label for="enable-rich-trail">Visual-Rich Trail</label>
640
+ </div>
641
+ <div class="slider-container">
642
+ <label for="trajectory-line-width">Line Width</label>
643
+ <input type="range" id="trajectory-line-width" min="0.5" max="5" step="0.5" value="1.5">
644
+ </div>
645
+ <div class="slider-container">
646
+ <label for="trajectory-ball-size">Ball Size</label>
647
+ <input type="range" id="trajectory-ball-size" min="0.005" max="0.05" step="0.001" value="0.02">
648
+ </div>
649
+ <div class="slider-container">
650
+ <label for="trajectory-history">History Frames</label>
651
+ <input type="range" id="trajectory-history" min="1" max="500" step="1" value="30">
652
+ </div>
653
+ <div class="slider-container" id="tail-opacity-container" style="display: none;">
654
+ <label for="trajectory-fade">Tail Opacity</label>
655
+ <input type="range" id="trajectory-fade" min="0" max="1" step="0.05" value="0.0">
656
+ </div>
657
+ </div>
658
+
659
+ <div class="settings-group">
660
+ <h3>Camera</h3>
661
+ <div class="checkbox-container">
662
+ <label class="toggle-switch">
663
+ <input type="checkbox" id="show-camera-frustum" checked>
664
+ <span class="toggle-slider"></span>
665
+ </label>
666
+ <label for="show-camera-frustum">Show Camera Frustum</label>
667
+ </div>
668
+ <div class="slider-container">
669
+ <label for="frustum-size">Size</label>
670
+ <input type="range" id="frustum-size" min="0.02" max="0.5" step="0.01" value="0.2">
671
+ </div>
672
+ </div>
673
+
674
+ <div class="settings-group">
675
+ <div class="btn-group">
676
+ <button id="reset-view-btn" style="flex: 1; margin-right: 5px;">Reset View</button>
677
+ <button id="reset-settings-btn" style="flex: 1; margin-left: 5px;">Reset Settings</button>
678
+ </div>
679
+ </div>
680
+ </div>
681
+
682
+ <button id="show-settings-btn" class="control-btn" title="Show Settings">
683
+ <svg class="icon" viewBox="0 0 24 24">
684
+ <path d="M19.14,12.94c0.04-0.3,0.06-0.61,0.06-0.94c0-0.32-0.02-0.64-0.07-0.94l2.03-1.58c0.18-0.14,0.23-0.41,0.12-0.61 l-1.92-3.32c-0.12-0.22-0.37-0.29-0.59-0.22l-2.39,0.96c-0.5-0.38-1.03-0.7-1.62-0.94L14.4,2.81c-0.04-0.24-0.24-0.41-0.48-0.41 h-3.84c-0.24,0-0.43,0.17-0.47,0.41L9.25,5.35C8.66,5.59,8.12,5.92,7.63,6.29L5.24,5.33c-0.22-0.08-0.47,0-0.59,0.22L2.74,8.87 C2.62,9.08,2.66,9.34,2.86,9.48l2.03,1.58C4.84,11.36,4.8,11.69,4.8,12s0.02,0.64,0.07,0.94l-2.03,1.58 c-0.18,0.14-0.23,0.41-0.12,0.61l1.92,3.32c0.12,0.22,0.37,0.29,0.59,0.22l2.39-0.96c0.5,0.38,1.03,0.7,1.62,0.94l0.36,2.54 c0.04,0.24,0.24,0.41,0.48,0.41h3.84c0.24,0,0.44-0.17,0.47-0.41l0.36-2.54c0.59-0.24,1.13-0.56,1.62-0.94l2.39,0.96 c0.22,0.08,0.47,0,0.59-0.22l1.92-3.32c0.12-0.22,0.07-0.47-0.12-0.61L19.14,12.94z M12,15.6c-1.98,0-3.6-1.62-3.6-3.6 s1.62-3.6,3.6-3.6s3.6,1.62,3.6,3.6S13.98,15.6,12,15.6z"/>
685
+ </svg>
686
+ </button>
687
+ </div>
688
+
689
+ <div id="loading-overlay">
690
+ <!-- <div class="spinner"></div> -->
691
+ <div id="loading-text"></div>
692
+ <div class="loading-subtitle" style="font-size: medium;">Interactive Viewer of 3D Tracking</div>
693
+ </div>
694
+
695
+ <!-- Libraries -->
696
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/pako/2.1.0/pako.min.js"></script>
697
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/build/three.min.js"></script>
698
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/controls/OrbitControls.js"></script>
699
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/build/dat.gui.min.js"></script>
700
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/LineSegmentsGeometry.js"></script>
701
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/LineGeometry.js"></script>
702
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/LineMaterial.js"></script>
703
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/LineSegments2.js"></script>
704
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/Line2.js"></script>
705
+
706
+ <script>
707
+ class PointCloudVisualizer {
708
+ constructor() {
709
+ this.data = null;
710
+ this.config = {};
711
+ this.currentFrame = 0;
712
+ this.isPlaying = false;
713
+ this.playbackSpeed = 1;
714
+ this.lastFrameTime = 0;
715
+ this.defaultSettings = null;
716
+
717
+ this.ui = {
718
+ statusBar: document.getElementById('status-bar'),
719
+ playPauseBtn: document.getElementById('play-pause-btn'),
720
+ speedBtn: document.getElementById('speed-btn'),
721
+ timeline: document.getElementById('timeline'),
722
+ progress: document.getElementById('progress'),
723
+ settingsPanel: document.getElementById('settings-panel'),
724
+ loadingOverlay: document.getElementById('loading-overlay'),
725
+ loadingText: document.getElementById('loading-text'),
726
+ settingsToggleBtn: document.getElementById('settings-toggle-btn'),
727
+ frameCounter: document.getElementById('frame-counter'),
728
+ pointSize: document.getElementById('point-size'),
729
+ pointOpacity: document.getElementById('point-opacity'),
730
+ maxDepth: document.getElementById('max-depth'),
731
+ showTrajectory: document.getElementById('show-trajectory'),
732
+ enableRichTrail: document.getElementById('enable-rich-trail'),
733
+ trajectoryLineWidth: document.getElementById('trajectory-line-width'),
734
+ trajectoryBallSize: document.getElementById('trajectory-ball-size'),
735
+ trajectoryHistory: document.getElementById('trajectory-history'),
736
+ trajectoryFade: document.getElementById('trajectory-fade'),
737
+ tailOpacityContainer: document.getElementById('tail-opacity-container'),
738
+ resetViewBtn: document.getElementById('reset-view-btn'),
739
+ showCameraFrustum: document.getElementById('show-camera-frustum'),
740
+ frustumSize: document.getElementById('frustum-size'),
741
+ hideSettingsBtn: document.getElementById('hide-settings-btn'),
742
+ showSettingsBtn: document.getElementById('show-settings-btn')
743
+ };
744
+
745
+ this.scene = null;
746
+ this.camera = null;
747
+ this.renderer = null;
748
+ this.controls = null;
749
+ this.pointCloud = null;
750
+ this.trajectories = [];
751
+ this.cameraFrustum = null;
752
+
753
+ this.initThreeJS();
754
+ this.loadDefaultSettings().then(() => {
755
+ this.initEventListeners();
756
+ this.loadData();
757
+ });
758
+ }
759
+
760
+ async loadDefaultSettings() {
761
+ try {
762
+ const urlParams = new URLSearchParams(window.location.search);
763
+ const dataPath = urlParams.get('data') || '';
764
+
765
+ const defaultSettings = {
766
+ pointSize: 0.03,
767
+ pointOpacity: 1.0,
768
+ showTrajectory: true,
769
+ trajectoryLineWidth: 2.5,
770
+ trajectoryBallSize: 0.015,
771
+ trajectoryHistory: 0,
772
+ showCameraFrustum: true,
773
+ frustumSize: 0.2
774
+ };
775
+
776
+ if (!dataPath) {
777
+ this.defaultSettings = defaultSettings;
778
+ this.applyDefaultSettings();
779
+ return;
780
+ }
781
+
782
+ // Try to extract dataset and videoId from the data path
783
+ // Expected format: demos/datasetname/videoid.bin
784
+ const pathParts = dataPath.split('/');
785
+ if (pathParts.length < 3) {
786
+ this.defaultSettings = defaultSettings;
787
+ this.applyDefaultSettings();
788
+ return;
789
+ }
790
+
791
+ const datasetName = pathParts[pathParts.length - 2];
792
+ let videoId = pathParts[pathParts.length - 1].replace('.bin', '');
793
+
794
+ // Load settings from data.json
795
+ const response = await fetch('./data.json');
796
+ if (!response.ok) {
797
+ this.defaultSettings = defaultSettings;
798
+ this.applyDefaultSettings();
799
+ return;
800
+ }
801
+
802
+ const settingsData = await response.json();
803
+
804
+ // Check if this dataset and video exist
805
+ if (settingsData[datasetName] && settingsData[datasetName][videoId]) {
806
+ this.defaultSettings = settingsData[datasetName][videoId];
807
+ } else {
808
+ this.defaultSettings = defaultSettings;
809
+ }
810
+
811
+ this.applyDefaultSettings();
812
+ } catch (error) {
813
+ console.error("Error loading default settings:", error);
814
+
815
+ this.defaultSettings = {
816
+ pointSize: 0.03,
817
+ pointOpacity: 1.0,
818
+ showTrajectory: true,
819
+ trajectoryLineWidth: 2.5,
820
+ trajectoryBallSize: 0.015,
821
+ trajectoryHistory: 0,
822
+ showCameraFrustum: true,
823
+ frustumSize: 0.2
824
+ };
825
+
826
+ this.applyDefaultSettings();
827
+ }
828
+ }
829
+
830
+ applyDefaultSettings() {
831
+ if (!this.defaultSettings) return;
832
+
833
+ if (this.ui.pointSize) {
834
+ this.ui.pointSize.value = this.defaultSettings.pointSize;
835
+ }
836
+
837
+ if (this.ui.pointOpacity) {
838
+ this.ui.pointOpacity.value = this.defaultSettings.pointOpacity;
839
+ }
840
+
841
+ if (this.ui.maxDepth) {
842
+ this.ui.maxDepth.value = this.defaultSettings.maxDepth || 100.0;
843
+ }
844
+
845
+ if (this.ui.showTrajectory) {
846
+ this.ui.showTrajectory.checked = this.defaultSettings.showTrajectory;
847
+ }
848
+
849
+ if (this.ui.trajectoryLineWidth) {
850
+ this.ui.trajectoryLineWidth.value = this.defaultSettings.trajectoryLineWidth;
851
+ }
852
+
853
+ if (this.ui.trajectoryBallSize) {
854
+ this.ui.trajectoryBallSize.value = this.defaultSettings.trajectoryBallSize;
855
+ }
856
+
857
+ if (this.ui.trajectoryHistory) {
858
+ this.ui.trajectoryHistory.value = this.defaultSettings.trajectoryHistory;
859
+ }
860
+
861
+ if (this.ui.showCameraFrustum) {
862
+ this.ui.showCameraFrustum.checked = this.defaultSettings.showCameraFrustum;
863
+ }
864
+
865
+ if (this.ui.frustumSize) {
866
+ this.ui.frustumSize.value = this.defaultSettings.frustumSize;
867
+ }
868
+ }
869
+
870
+ initThreeJS() {
871
+ this.scene = new THREE.Scene();
872
+ this.scene.background = new THREE.Color(0x1a1a1a);
873
+
874
+ this.camera = new THREE.PerspectiveCamera(60, window.innerWidth / window.innerHeight, 0.1, 10000);
875
+ this.camera.position.set(0, 0, 0);
876
+
877
+ this.renderer = new THREE.WebGLRenderer({ antialias: true });
878
+ this.renderer.setPixelRatio(window.devicePixelRatio);
879
+ this.renderer.setSize(window.innerWidth, window.innerHeight);
880
+ document.getElementById('canvas-container').appendChild(this.renderer.domElement);
881
+
882
+ this.controls = new THREE.OrbitControls(this.camera, this.renderer.domElement);
883
+ this.controls.enableDamping = true;
884
+ this.controls.dampingFactor = 0.05;
885
+ this.controls.target.set(0, 0, 0);
886
+ this.controls.minDistance = 0.1;
887
+ this.controls.maxDistance = 1000;
888
+ this.controls.update();
889
+
890
+ const ambientLight = new THREE.AmbientLight(0xffffff, 0.5);
891
+ this.scene.add(ambientLight);
892
+
893
+ const directionalLight = new THREE.DirectionalLight(0xffffff, 0.8);
894
+ directionalLight.position.set(1, 1, 1);
895
+ this.scene.add(directionalLight);
896
+ }
897
+
898
+ initEventListeners() {
899
+ window.addEventListener('resize', () => this.onWindowResize());
900
+
901
+ this.ui.playPauseBtn.addEventListener('click', () => this.togglePlayback());
902
+
903
+ this.ui.timeline.addEventListener('click', (e) => {
904
+ const rect = this.ui.timeline.getBoundingClientRect();
905
+ const pos = (e.clientX - rect.left) / rect.width;
906
+ this.seekTo(pos);
907
+ });
908
+
909
+ this.ui.speedBtn.addEventListener('click', () => this.cyclePlaybackSpeed());
910
+
911
+ this.ui.pointSize.addEventListener('input', () => this.updatePointCloudSettings());
912
+ this.ui.pointOpacity.addEventListener('input', () => this.updatePointCloudSettings());
913
+ this.ui.maxDepth.addEventListener('input', () => this.updatePointCloudSettings());
914
+ this.ui.showTrajectory.addEventListener('change', () => {
915
+ this.trajectories.forEach(trajectory => {
916
+ trajectory.visible = this.ui.showTrajectory.checked;
917
+ });
918
+ });
919
+
920
+ this.ui.enableRichTrail.addEventListener('change', () => {
921
+ this.ui.tailOpacityContainer.style.display = this.ui.enableRichTrail.checked ? 'flex' : 'none';
922
+ this.updateTrajectories(this.currentFrame);
923
+ });
924
+
925
+ this.ui.trajectoryLineWidth.addEventListener('input', () => this.updateTrajectorySettings());
926
+ this.ui.trajectoryBallSize.addEventListener('input', () => this.updateTrajectorySettings());
927
+ this.ui.trajectoryHistory.addEventListener('input', () => {
928
+ this.updateTrajectories(this.currentFrame);
929
+ });
930
+ this.ui.trajectoryFade.addEventListener('input', () => {
931
+ this.updateTrajectories(this.currentFrame);
932
+ });
933
+
934
+ this.ui.resetViewBtn.addEventListener('click', () => this.resetView());
935
+
936
+ const resetSettingsBtn = document.getElementById('reset-settings-btn');
937
+ if (resetSettingsBtn) {
938
+ resetSettingsBtn.addEventListener('click', () => this.resetSettings());
939
+ }
940
+
941
+ document.addEventListener('keydown', (e) => {
942
+ if (e.key === 'Escape' && this.ui.settingsPanel.classList.contains('visible')) {
943
+ this.ui.settingsPanel.classList.remove('visible');
944
+ this.ui.settingsToggleBtn.classList.remove('active');
945
+ }
946
+ });
947
+
948
+ if (this.ui.settingsToggleBtn) {
949
+ this.ui.settingsToggleBtn.addEventListener('click', () => {
950
+ const isVisible = this.ui.settingsPanel.classList.toggle('visible');
951
+ this.ui.settingsToggleBtn.classList.toggle('active', isVisible);
952
+
953
+ if (isVisible) {
954
+ const panelRect = this.ui.settingsPanel.getBoundingClientRect();
955
+ const viewportHeight = window.innerHeight;
956
+
957
+ if (panelRect.bottom > viewportHeight) {
958
+ this.ui.settingsPanel.style.bottom = 'auto';
959
+ this.ui.settingsPanel.style.top = '80px';
960
+ }
961
+ }
962
+ });
963
+ }
964
+
965
+ if (this.ui.frustumSize) {
966
+ this.ui.frustumSize.addEventListener('input', () => this.updateFrustumDimensions());
967
+ }
968
+
969
+ if (this.ui.hideSettingsBtn && this.ui.showSettingsBtn && this.ui.settingsPanel) {
970
+ this.ui.hideSettingsBtn.addEventListener('click', () => {
971
+ this.ui.settingsPanel.classList.add('is-hidden');
972
+ this.ui.showSettingsBtn.style.display = 'flex';
973
+ });
974
+
975
+ this.ui.showSettingsBtn.addEventListener('click', () => {
976
+ this.ui.settingsPanel.classList.remove('is-hidden');
977
+ this.ui.showSettingsBtn.style.display = 'none';
978
+ });
979
+ }
980
+ }
981
+
982
+ makeElementDraggable(element) {
983
+ let pos1 = 0, pos2 = 0, pos3 = 0, pos4 = 0;
984
+
985
+ const dragHandle = element.querySelector('h2');
986
+
987
+ if (dragHandle) {
988
+ dragHandle.onmousedown = dragMouseDown;
989
+ dragHandle.title = "Drag to move panel";
990
+ } else {
991
+ element.onmousedown = dragMouseDown;
992
+ }
993
+
994
+ function dragMouseDown(e) {
995
+ e = e || window.event;
996
+ e.preventDefault();
997
+ pos3 = e.clientX;
998
+ pos4 = e.clientY;
999
+ document.onmouseup = closeDragElement;
1000
+ document.onmousemove = elementDrag;
1001
+
1002
+ element.classList.add('dragging');
1003
+ }
1004
+
1005
+ function elementDrag(e) {
1006
+ e = e || window.event;
1007
+ e.preventDefault();
1008
+ pos1 = pos3 - e.clientX;
1009
+ pos2 = pos4 - e.clientY;
1010
+ pos3 = e.clientX;
1011
+ pos4 = e.clientY;
1012
+
1013
+ const newTop = element.offsetTop - pos2;
1014
+ const newLeft = element.offsetLeft - pos1;
1015
+
1016
+ const viewportWidth = window.innerWidth;
1017
+ const viewportHeight = window.innerHeight;
1018
+
1019
+ const panelRect = element.getBoundingClientRect();
1020
+
1021
+ const maxTop = viewportHeight - 50;
1022
+ const maxLeft = viewportWidth - 50;
1023
+
1024
+ element.style.top = Math.min(Math.max(newTop, 0), maxTop) + "px";
1025
+ element.style.left = Math.min(Math.max(newLeft, 0), maxLeft) + "px";
1026
+
1027
+ // Remove bottom/right settings when dragging
1028
+ element.style.bottom = 'auto';
1029
+ element.style.right = 'auto';
1030
+ }
1031
+
1032
+ function closeDragElement() {
1033
+ document.onmouseup = null;
1034
+ document.onmousemove = null;
1035
+
1036
+ element.classList.remove('dragging');
1037
+ }
1038
+ }
1039
+
1040
+ async loadData() {
1041
+ try {
1042
+ // this.ui.loadingText.textContent = "Loading binary data...";
1043
+
1044
+ let arrayBuffer;
1045
+
1046
+ if (window.embeddedBase64) {
1047
+ // Base64 embedded path
1048
+ const binaryString = atob(window.embeddedBase64);
1049
+ const len = binaryString.length;
1050
+ const bytes = new Uint8Array(len);
1051
+ for (let i = 0; i < len; i++) {
1052
+ bytes[i] = binaryString.charCodeAt(i);
1053
+ }
1054
+ arrayBuffer = bytes.buffer;
1055
+ } else {
1056
+ // Default fetch path (fallback)
1057
+ const urlParams = new URLSearchParams(window.location.search);
1058
+ const dataPath = urlParams.get('data') || 'data.bin';
1059
+
1060
+ const response = await fetch(dataPath);
1061
+ if (!response.ok) throw new Error(`Failed to load ${dataPath}`);
1062
+ arrayBuffer = await response.arrayBuffer();
1063
+ }
1064
+
1065
+ const dataView = new DataView(arrayBuffer);
1066
+ const headerLen = dataView.getUint32(0, true);
1067
+
1068
+ const headerText = new TextDecoder("utf-8").decode(arrayBuffer.slice(4, 4 + headerLen));
1069
+ const header = JSON.parse(headerText);
1070
+
1071
+ const compressedBlob = new Uint8Array(arrayBuffer, 4 + headerLen);
1072
+ const decompressed = pako.inflate(compressedBlob).buffer;
1073
+
1074
+ const arrays = {};
1075
+ for (const key in header) {
1076
+ if (key === "meta") continue;
1077
+
1078
+ const meta = header[key];
1079
+ const { dtype, shape, offset, length } = meta;
1080
+ const slice = decompressed.slice(offset, offset + length);
1081
+
1082
+ let typedArray;
1083
+ switch (dtype) {
1084
+ case "uint8": typedArray = new Uint8Array(slice); break;
1085
+ case "uint16": typedArray = new Uint16Array(slice); break;
1086
+ case "float32": typedArray = new Float32Array(slice); break;
1087
+ case "float64": typedArray = new Float64Array(slice); break;
1088
+ default: throw new Error(`Unknown dtype: ${dtype}`);
1089
+ }
1090
+
1091
+ arrays[key] = { data: typedArray, shape: shape };
1092
+ }
1093
+
1094
+ this.data = arrays;
1095
+ this.config = header.meta;
1096
+
1097
+ this.initCameraWithCorrectFOV();
1098
+ this.ui.loadingText.textContent = "Creating point cloud...";
1099
+
1100
+ this.initPointCloud();
1101
+ this.initTrajectories();
1102
+
1103
+ setTimeout(() => {
1104
+ this.ui.loadingOverlay.classList.add('fade-out');
1105
+ this.ui.statusBar.classList.add('hidden');
1106
+ this.startAnimation();
1107
+ }, 500);
1108
+ } catch (error) {
1109
+ console.error("Error loading data:", error);
1110
+ this.ui.statusBar.textContent = `Error: ${error.message}`;
1111
+ // this.ui.loadingText.textContent = `Error loading data: ${error.message}`;
1112
+ }
1113
+ }
1114
+
1115
+ initPointCloud() {
1116
+ const numPoints = this.config.resolution[0] * this.config.resolution[1];
1117
+ const positions = new Float32Array(numPoints * 3);
1118
+ const colors = new Float32Array(numPoints * 3);
1119
+
1120
+ const geometry = new THREE.BufferGeometry();
1121
+ geometry.setAttribute('position', new THREE.BufferAttribute(positions, 3).setUsage(THREE.DynamicDrawUsage));
1122
+ geometry.setAttribute('color', new THREE.BufferAttribute(colors, 3).setUsage(THREE.DynamicDrawUsage));
1123
+
1124
+ const pointSize = parseFloat(this.ui.pointSize.value) || this.defaultSettings.pointSize;
1125
+ const pointOpacity = parseFloat(this.ui.pointOpacity.value) || this.defaultSettings.pointOpacity;
1126
+
1127
+ const material = new THREE.PointsMaterial({
1128
+ size: pointSize,
1129
+ vertexColors: true,
1130
+ transparent: true,
1131
+ opacity: pointOpacity,
1132
+ sizeAttenuation: true
1133
+ });
1134
+
1135
+ this.pointCloud = new THREE.Points(geometry, material);
1136
+ this.scene.add(this.pointCloud);
1137
+ }
1138
+
1139
+ initTrajectories() {
1140
+ if (!this.data.trajectories) return;
1141
+
1142
+ this.trajectories.forEach(trajectory => {
1143
+ if (trajectory.userData.lineSegments) {
1144
+ trajectory.userData.lineSegments.forEach(segment => {
1145
+ segment.geometry.dispose();
1146
+ segment.material.dispose();
1147
+ });
1148
+ }
1149
+ this.scene.remove(trajectory);
1150
+ });
1151
+ this.trajectories = [];
1152
+
1153
+ const shape = this.data.trajectories.shape;
1154
+ if (!shape || shape.length < 2) return;
1155
+
1156
+ const [totalFrames, numTrajectories] = shape;
1157
+ const palette = this.createColorPalette(numTrajectories);
1158
+ const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
1159
+ const maxHistory = 500; // Max value of the history slider, for the object pool
1160
+
1161
+ for (let i = 0; i < numTrajectories; i++) {
1162
+ const trajectoryGroup = new THREE.Group();
1163
+
1164
+ const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
1165
+ const sphereGeometry = new THREE.SphereGeometry(ballSize, 16, 16);
1166
+ const sphereMaterial = new THREE.MeshBasicMaterial({ color: palette[i], transparent: true });
1167
+ const positionMarker = new THREE.Mesh(sphereGeometry, sphereMaterial);
1168
+ trajectoryGroup.add(positionMarker);
1169
+
1170
+ // High-Performance Line (default)
1171
+ const simpleLineGeometry = new THREE.BufferGeometry();
1172
+ const simpleLinePositions = new Float32Array(maxHistory * 3);
1173
+ simpleLineGeometry.setAttribute('position', new THREE.BufferAttribute(simpleLinePositions, 3).setUsage(THREE.DynamicDrawUsage));
1174
+ const simpleLine = new THREE.Line(simpleLineGeometry, new THREE.LineBasicMaterial({ color: palette[i] }));
1175
+ simpleLine.frustumCulled = false;
1176
+ trajectoryGroup.add(simpleLine);
1177
+
1178
+ // High-Quality Line Segments (for rich trail)
1179
+ const lineSegments = [];
1180
+ const lineWidth = parseFloat(this.ui.trajectoryLineWidth.value);
1181
+
1182
+ // Create a pool of line segment objects
1183
+ for (let j = 0; j < maxHistory - 1; j++) {
1184
+ const lineGeometry = new THREE.LineGeometry();
1185
+ lineGeometry.setPositions([0, 0, 0, 0, 0, 0]);
1186
+ const lineMaterial = new THREE.LineMaterial({
1187
+ color: palette[i],
1188
+ linewidth: lineWidth,
1189
+ resolution: resolution,
1190
+ transparent: true,
1191
+ depthWrite: false, // Correctly handle transparency
1192
+ opacity: 0
1193
+ });
1194
+ const segment = new THREE.Line2(lineGeometry, lineMaterial);
1195
+ segment.frustumCulled = false;
1196
+ segment.visible = false; // Start with all segments hidden
1197
+ trajectoryGroup.add(segment);
1198
+ lineSegments.push(segment);
1199
+ }
1200
+
1201
+ trajectoryGroup.userData = {
1202
+ marker: positionMarker,
1203
+ simpleLine: simpleLine,
1204
+ lineSegments: lineSegments,
1205
+ color: palette[i]
1206
+ };
1207
+
1208
+ this.scene.add(trajectoryGroup);
1209
+ this.trajectories.push(trajectoryGroup);
1210
+ }
1211
+
1212
+ const showTrajectory = this.ui.showTrajectory.checked;
1213
+ this.trajectories.forEach(trajectory => trajectory.visible = showTrajectory);
1214
+ }
1215
+
1216
+ createColorPalette(count) {
1217
+ const colors = [];
1218
+ const hueStep = 360 / count;
1219
+
1220
+ for (let i = 0; i < count; i++) {
1221
+ const hue = (i * hueStep) % 360;
1222
+ const color = new THREE.Color().setHSL(hue / 360, 0.8, 0.6);
1223
+ colors.push(color);
1224
+ }
1225
+
1226
+ return colors;
1227
+ }
1228
+
1229
+ updatePointCloud(frameIndex) {
1230
+ if (!this.data || !this.pointCloud) return;
1231
+
1232
+ const positions = this.pointCloud.geometry.attributes.position.array;
1233
+ const colors = this.pointCloud.geometry.attributes.color.array;
1234
+
1235
+ const rgbVideo = this.data.rgb_video;
1236
+ const depthsRgb = this.data.depths_rgb;
1237
+ const intrinsics = this.data.intrinsics;
1238
+ const invExtrinsics = this.data.inv_extrinsics;
1239
+
1240
+ const width = this.config.resolution[0];
1241
+ const height = this.config.resolution[1];
1242
+ const numPoints = width * height;
1243
+
1244
+ const K = this.get3x3Matrix(intrinsics.data, intrinsics.shape, frameIndex);
1245
+ const fx = K[0][0], fy = K[1][1], cx = K[0][2], cy = K[1][2];
1246
+
1247
+ const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);
1248
+ const transform = this.getTransformElements(invExtrMat);
1249
+
1250
+ const rgbFrame = this.getFrame(rgbVideo.data, rgbVideo.shape, frameIndex);
1251
+ const depthFrame = this.getFrame(depthsRgb.data, depthsRgb.shape, frameIndex);
1252
+
1253
+ const maxDepth = parseFloat(this.ui.maxDepth.value) || 10.0;
1254
+
1255
+ let validPointCount = 0;
1256
+
1257
+ for (let i = 0; i < numPoints; i++) {
1258
+ const xPix = i % width;
1259
+ const yPix = Math.floor(i / width);
1260
+
1261
+ const d0 = depthFrame[i * 3];
1262
+ const d1 = depthFrame[i * 3 + 1];
1263
+ const depthEncoded = d0 | (d1 << 8);
1264
+ const depthValue = (depthEncoded / ((1 << 16) - 1)) *
1265
+ (this.config.depthRange[1] - this.config.depthRange[0]) +
1266
+ this.config.depthRange[0];
1267
+
1268
+ if (depthValue === 0 || depthValue > maxDepth) {
1269
+ continue;
1270
+ }
1271
+
1272
+ const X = ((xPix - cx) * depthValue) / fx;
1273
+ const Y = ((yPix - cy) * depthValue) / fy;
1274
+ const Z = depthValue;
1275
+
1276
+ const tx = transform.m11 * X + transform.m12 * Y + transform.m13 * Z + transform.m14;
1277
+ const ty = transform.m21 * X + transform.m22 * Y + transform.m23 * Z + transform.m24;
1278
+ const tz = transform.m31 * X + transform.m32 * Y + transform.m33 * Z + transform.m34;
1279
+
1280
+ const index = validPointCount * 3;
1281
+ positions[index] = tx;
1282
+ positions[index + 1] = -ty;
1283
+ positions[index + 2] = -tz;
1284
+
1285
+ colors[index] = rgbFrame[i * 3] / 255;
1286
+ colors[index + 1] = rgbFrame[i * 3 + 1] / 255;
1287
+ colors[index + 2] = rgbFrame[i * 3 + 2] / 255;
1288
+
1289
+ validPointCount++;
1290
+ }
1291
+
1292
+ this.pointCloud.geometry.setDrawRange(0, validPointCount);
1293
+ this.pointCloud.geometry.attributes.position.needsUpdate = true;
1294
+ this.pointCloud.geometry.attributes.color.needsUpdate = true;
1295
+ this.pointCloud.geometry.computeBoundingSphere(); // Important for camera culling
1296
+
1297
+ this.updateTrajectories(frameIndex);
1298
+
1299
+ const progress = (frameIndex + 1) / this.config.totalFrames;
1300
+ this.ui.progress.style.width = `${progress * 100}%`;
1301
+
1302
+ if (this.ui.frameCounter && this.config.totalFrames) {
1303
+ this.ui.frameCounter.textContent = `Frame ${frameIndex} / ${this.config.totalFrames - 1}`;
1304
+ }
1305
+
1306
+ this.updateCameraFrustum(frameIndex);
1307
+ }
1308
+
1309
+ updateTrajectories(frameIndex) {
1310
+ if (!this.data.trajectories || this.trajectories.length === 0) return;
1311
+
1312
+ const trajectoryData = this.data.trajectories.data;
1313
+ const [totalFrames, numTrajectories] = this.data.trajectories.shape;
1314
+ const historyFrames = parseInt(this.ui.trajectoryHistory.value);
1315
+ const tailOpacity = parseFloat(this.ui.trajectoryFade.value);
1316
+
1317
+ const isRichMode = this.ui.enableRichTrail.checked;
1318
+
1319
+ for (let i = 0; i < numTrajectories; i++) {
1320
+ const trajectoryGroup = this.trajectories[i];
1321
+ const { marker, simpleLine, lineSegments } = trajectoryGroup.userData;
1322
+
1323
+ const currentPos = new THREE.Vector3();
1324
+ const currentOffset = (frameIndex * numTrajectories + i) * 3;
1325
+
1326
+ currentPos.x = trajectoryData[currentOffset];
1327
+ currentPos.y = -trajectoryData[currentOffset + 1];
1328
+ currentPos.z = -trajectoryData[currentOffset + 2];
1329
+
1330
+ marker.position.copy(currentPos);
1331
+ marker.material.opacity = 1.0;
1332
+
1333
+ const historyToShow = Math.min(historyFrames, frameIndex + 1);
1334
+
1335
+ if (isRichMode) {
1336
+ // --- High-Quality Mode ---
1337
+ simpleLine.visible = false;
1338
+
1339
+ for (let j = 0; j < lineSegments.length; j++) {
1340
+ const segment = lineSegments[j];
1341
+ if (j < historyToShow - 1) {
1342
+ const headFrame = frameIndex - j;
1343
+ const tailFrame = frameIndex - j - 1;
1344
+ const headOffset = (headFrame * numTrajectories + i) * 3;
1345
+ const tailOffset = (tailFrame * numTrajectories + i) * 3;
1346
+ const positions = [
1347
+ trajectoryData[headOffset], -trajectoryData[headOffset + 1], -trajectoryData[headOffset + 2],
1348
+ trajectoryData[tailOffset], -trajectoryData[tailOffset + 1], -trajectoryData[tailOffset + 2]
1349
+ ];
1350
+ segment.geometry.setPositions(positions);
1351
+ const headOpacity = 1.0;
1352
+ const normalizedAge = j / Math.max(1, historyToShow - 2);
1353
+ const alpha = headOpacity - (headOpacity - tailOpacity) * normalizedAge;
1354
+ segment.material.opacity = Math.max(0, alpha);
1355
+ segment.visible = true;
1356
+ } else {
1357
+ segment.visible = false;
1358
+ }
1359
+ }
1360
+ } else {
1361
+ // --- Performance Mode ---
1362
+ lineSegments.forEach(s => s.visible = false);
1363
+ simpleLine.visible = true;
1364
+
1365
+ const positions = simpleLine.geometry.attributes.position.array;
1366
+ for (let j = 0; j < historyToShow; j++) {
1367
+ const historyFrame = Math.max(0, frameIndex - j);
1368
+ const offset = (historyFrame * numTrajectories + i) * 3;
1369
+ positions[j * 3] = trajectoryData[offset];
1370
+ positions[j * 3 + 1] = -trajectoryData[offset + 1];
1371
+ positions[j * 3 + 2] = -trajectoryData[offset + 2];
1372
+ }
1373
+ simpleLine.geometry.setDrawRange(0, historyToShow);
1374
+ simpleLine.geometry.attributes.position.needsUpdate = true;
1375
+ }
1376
+ }
1377
+ }
1378
+
1379
+ updateTrajectorySettings() {
1380
+ if (!this.trajectories || this.trajectories.length === 0) return;
1381
+
1382
+ const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
1383
+ const lineWidth = parseFloat(this.ui.trajectoryLineWidth.value);
1384
+
1385
+ this.trajectories.forEach(trajectoryGroup => {
1386
+ const { marker, lineSegments } = trajectoryGroup.userData;
1387
+
1388
+ marker.geometry.dispose();
1389
+ marker.geometry = new THREE.SphereGeometry(ballSize, 16, 16);
1390
+
1391
+ // Line width only affects rich mode
1392
+ lineSegments.forEach(segment => {
1393
+ if (segment.material) {
1394
+ segment.material.linewidth = lineWidth;
1395
+ }
1396
+ });
1397
+ });
1398
+
1399
+ this.updateTrajectories(this.currentFrame);
1400
+ }
1401
+
1402
+ getDepthColor(normalizedDepth) {
1403
+ const hue = (1 - normalizedDepth) * 240 / 360;
1404
+ const color = new THREE.Color().setHSL(hue, 1.0, 0.5);
1405
+ return color;
1406
+ }
1407
+
1408
+ getFrame(typedArray, shape, frameIndex) {
1409
+ const [T, H, W, C] = shape;
1410
+ const frameSize = H * W * C;
1411
+ const offset = frameIndex * frameSize;
1412
+ return typedArray.subarray(offset, offset + frameSize);
1413
+ }
1414
+
1415
+ get3x3Matrix(typedArray, shape, frameIndex) {
1416
+ const frameSize = 9;
1417
+ const offset = frameIndex * frameSize;
1418
+ const K = [];
1419
+ for (let i = 0; i < 3; i++) {
1420
+ const row = [];
1421
+ for (let j = 0; j < 3; j++) {
1422
+ row.push(typedArray[offset + i * 3 + j]);
1423
+ }
1424
+ K.push(row);
1425
+ }
1426
+ return K;
1427
+ }
1428
+
1429
+ get4x4Matrix(typedArray, shape, frameIndex) {
1430
+ const frameSize = 16;
1431
+ const offset = frameIndex * frameSize;
1432
+ const M = [];
1433
+ for (let i = 0; i < 4; i++) {
1434
+ const row = [];
1435
+ for (let j = 0; j < 4; j++) {
1436
+ row.push(typedArray[offset + i * 4 + j]);
1437
+ }
1438
+ M.push(row);
1439
+ }
1440
+ return M;
1441
+ }
1442
+
1443
+ getTransformElements(matrix) {
1444
+ return {
1445
+ m11: matrix[0][0], m12: matrix[0][1], m13: matrix[0][2], m14: matrix[0][3],
1446
+ m21: matrix[1][0], m22: matrix[1][1], m23: matrix[1][2], m24: matrix[1][3],
1447
+ m31: matrix[2][0], m32: matrix[2][1], m33: matrix[2][2], m34: matrix[2][3]
1448
+ };
1449
+ }
1450
+
1451
+ togglePlayback() {
1452
+ this.isPlaying = !this.isPlaying;
1453
+
1454
+ const playIcon = document.getElementById('play-icon');
1455
+ const pauseIcon = document.getElementById('pause-icon');
1456
+
1457
+ if (this.isPlaying) {
1458
+ playIcon.style.display = 'none';
1459
+ pauseIcon.style.display = 'block';
1460
+ this.lastFrameTime = performance.now();
1461
+ } else {
1462
+ playIcon.style.display = 'block';
1463
+ pauseIcon.style.display = 'none';
1464
+ }
1465
+ }
1466
+
1467
+ cyclePlaybackSpeed() {
1468
+ const speeds = [0.5, 1, 2, 4, 8];
1469
+ const speedRates = speeds.map(s => s * this.config.baseFrameRate);
1470
+
1471
+ let currentIndex = 0;
1472
+ const normalizedSpeed = this.playbackSpeed / this.config.baseFrameRate;
1473
+
1474
+ for (let i = 0; i < speeds.length; i++) {
1475
+ if (Math.abs(normalizedSpeed - speeds[i]) < Math.abs(normalizedSpeed - speeds[currentIndex])) {
1476
+ currentIndex = i;
1477
+ }
1478
+ }
1479
+
1480
+ const nextIndex = (currentIndex + 1) % speeds.length;
1481
+ this.playbackSpeed = speedRates[nextIndex];
1482
+ this.ui.speedBtn.textContent = `${speeds[nextIndex]}x`;
1483
+
1484
+ if (speeds[nextIndex] === 1) {
1485
+ this.ui.speedBtn.classList.remove('active');
1486
+ } else {
1487
+ this.ui.speedBtn.classList.add('active');
1488
+ }
1489
+ }
1490
+
1491
+ seekTo(position) {
1492
+ const frameIndex = Math.floor(position * this.config.totalFrames);
1493
+ this.currentFrame = Math.max(0, Math.min(frameIndex, this.config.totalFrames - 1));
1494
+ this.updatePointCloud(this.currentFrame);
1495
+ }
1496
+
1497
+ updatePointCloudSettings() {
1498
+ if (!this.pointCloud) return;
1499
+
1500
+ const size = parseFloat(this.ui.pointSize.value);
1501
+ const opacity = parseFloat(this.ui.pointOpacity.value);
1502
+
1503
+ this.pointCloud.material.size = size;
1504
+ this.pointCloud.material.opacity = opacity;
1505
+ this.pointCloud.material.needsUpdate = true;
1506
+
1507
+ this.updatePointCloud(this.currentFrame);
1508
+ }
1509
+
1510
+ updateControls() {
1511
+ if (!this.controls) return;
1512
+ this.controls.update();
1513
+ }
1514
+
1515
+ resetView() {
1516
+ if (!this.camera || !this.controls) return;
1517
+
1518
+ // Reset camera position
1519
+ this.camera.position.set(0, 0, this.config.cameraZ || 0);
1520
+
1521
+ // Reset controls
1522
+ this.controls.reset();
1523
+
1524
+ // Set target slightly in front of camera
1525
+ this.controls.target.set(0, 0, -1);
1526
+ this.controls.update();
1527
+
1528
+ // Show status message
1529
+ this.ui.statusBar.textContent = "View reset";
1530
+ this.ui.statusBar.classList.remove('hidden');
1531
+
1532
+ // Hide status message after a few seconds
1533
+ setTimeout(() => {
1534
+ this.ui.statusBar.classList.add('hidden');
1535
+ }, 3000);
1536
+ }
1537
+
1538
+ onWindowResize() {
1539
+ if (!this.camera || !this.renderer) return;
1540
+
1541
+ const windowAspect = window.innerWidth / window.innerHeight;
1542
+ this.camera.aspect = windowAspect;
1543
+ this.camera.updateProjectionMatrix();
1544
+ this.renderer.setSize(window.innerWidth, window.innerHeight);
1545
+
1546
+ if (this.trajectories && this.trajectories.length > 0) {
1547
+ const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
1548
+ this.trajectories.forEach(trajectory => {
1549
+ const { lineSegments } = trajectory.userData;
1550
+ if (lineSegments && lineSegments.length > 0) {
1551
+ lineSegments.forEach(segment => {
1552
+ if (segment.material && segment.material.resolution) {
1553
+ segment.material.resolution.copy(resolution);
1554
+ }
1555
+ });
1556
+ }
1557
+ });
1558
+ }
1559
+
1560
+ if (this.cameraFrustum) {
1561
+ const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
1562
+ this.cameraFrustum.children.forEach(line => {
1563
+ if (line.material && line.material.resolution) {
1564
+ line.material.resolution.copy(resolution);
1565
+ }
1566
+ });
1567
+ }
1568
+ }
1569
+
1570
+ startAnimation() {
1571
+ this.isPlaying = true;
1572
+ this.lastFrameTime = performance.now();
1573
+
1574
+ this.camera.position.set(0, 0, this.config.cameraZ || 0);
1575
+ this.controls.target.set(0, 0, -1);
1576
+ this.controls.update();
1577
+
1578
+ this.playbackSpeed = this.config.baseFrameRate;
1579
+
1580
+ document.getElementById('play-icon').style.display = 'none';
1581
+ document.getElementById('pause-icon').style.display = 'block';
1582
+
1583
+ this.animate();
1584
+ }
1585
+
1586
+ animate() {
1587
+ requestAnimationFrame(() => this.animate());
1588
+
1589
+ if (this.controls) {
1590
+ this.controls.update();
1591
+ }
1592
+
1593
+ if (this.isPlaying && this.data) {
1594
+ const now = performance.now();
1595
+ const delta = (now - this.lastFrameTime) / 1000;
1596
+
1597
+ const framesToAdvance = Math.floor(delta * this.config.baseFrameRate * this.playbackSpeed);
1598
+ if (framesToAdvance > 0) {
1599
+ this.currentFrame = (this.currentFrame + framesToAdvance) % this.config.totalFrames;
1600
+ this.lastFrameTime = now;
1601
+ this.updatePointCloud(this.currentFrame);
1602
+ }
1603
+ }
1604
+
1605
+ if (this.renderer && this.scene && this.camera) {
1606
+ this.renderer.render(this.scene, this.camera);
1607
+ }
1608
+ }
1609
+
1610
+ initCameraWithCorrectFOV() {
1611
+ const fov = this.config.fov || 60;
1612
+
1613
+ const windowAspect = window.innerWidth / window.innerHeight;
1614
+
1615
+ this.camera = new THREE.PerspectiveCamera(
1616
+ fov,
1617
+ windowAspect,
1618
+ 0.1,
1619
+ 10000
1620
+ );
1621
+
1622
+ this.controls.object = this.camera;
1623
+ this.controls.update();
1624
+
1625
+ this.initCameraFrustum();
1626
+ }
1627
+
1628
+ initCameraFrustum() {
1629
+ this.cameraFrustum = new THREE.Group();
1630
+
1631
+ this.scene.add(this.cameraFrustum);
1632
+
1633
+ this.initCameraFrustumGeometry();
1634
+
1635
+ const showCameraFrustum = this.ui.showCameraFrustum ? this.ui.showCameraFrustum.checked : (this.defaultSettings ? this.defaultSettings.showCameraFrustum : false);
1636
+
1637
+ this.cameraFrustum.visible = showCameraFrustum;
1638
+ }
1639
+
1640
+ initCameraFrustumGeometry() {
1641
+ const fov = this.config.fov || 60;
1642
+ const originalAspect = this.config.original_aspect_ratio || 1.33;
1643
+
1644
+ const size = parseFloat(this.ui.frustumSize.value) || this.defaultSettings.frustumSize;
1645
+
1646
+ const halfHeight = Math.tan(THREE.MathUtils.degToRad(fov / 2)) * size;
1647
+ const halfWidth = halfHeight * originalAspect;
1648
+
1649
+ const vertices = [
1650
+ new THREE.Vector3(0, 0, 0),
1651
+ new THREE.Vector3(-halfWidth, -halfHeight, size),
1652
+ new THREE.Vector3(halfWidth, -halfHeight, size),
1653
+ new THREE.Vector3(halfWidth, halfHeight, size),
1654
+ new THREE.Vector3(-halfWidth, halfHeight, size)
1655
+ ];
1656
+
1657
+ const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
1658
+
1659
+ const linePairs = [
1660
+ [1, 2], [2, 3], [3, 4], [4, 1],
1661
+ [0, 1], [0, 2], [0, 3], [0, 4]
1662
+ ];
1663
+
1664
+ const colors = {
1665
+ edge: new THREE.Color(0x3366ff),
1666
+ ray: new THREE.Color(0x33cc66)
1667
+ };
1668
+
1669
+ linePairs.forEach((pair, index) => {
1670
+ const positions = [
1671
+ vertices[pair[0]].x, vertices[pair[0]].y, vertices[pair[0]].z,
1672
+ vertices[pair[1]].x, vertices[pair[1]].y, vertices[pair[1]].z
1673
+ ];
1674
+
1675
+ const lineGeometry = new THREE.LineGeometry();
1676
+ lineGeometry.setPositions(positions);
1677
+
1678
+ let color = index < 4 ? colors.edge : colors.ray;
1679
+
1680
+ const lineMaterial = new THREE.LineMaterial({
1681
+ color: color,
1682
+ linewidth: 2,
1683
+ resolution: resolution,
1684
+ dashed: false
1685
+ });
1686
+
1687
+ const line = new THREE.Line2(lineGeometry, lineMaterial);
1688
+ this.cameraFrustum.add(line);
1689
+ });
1690
+ }
1691
+
1692
+ updateCameraFrustum(frameIndex) {
1693
+ if (!this.cameraFrustum || !this.data) return;
1694
+
1695
+ const invExtrinsics = this.data.inv_extrinsics;
1696
+ if (!invExtrinsics) return;
1697
+
1698
+ const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);
1699
+
1700
+ const matrix = new THREE.Matrix4();
1701
+ matrix.set(
1702
+ invExtrMat[0][0], invExtrMat[0][1], invExtrMat[0][2], invExtrMat[0][3],
1703
+ invExtrMat[1][0], invExtrMat[1][1], invExtrMat[1][2], invExtrMat[1][3],
1704
+ invExtrMat[2][0], invExtrMat[2][1], invExtrMat[2][2], invExtrMat[2][3],
1705
+ invExtrMat[3][0], invExtrMat[3][1], invExtrMat[3][2], invExtrMat[3][3]
1706
+ );
1707
+
1708
+ const position = new THREE.Vector3();
1709
+ position.setFromMatrixPosition(matrix);
1710
+
1711
+ const rotMatrix = new THREE.Matrix4().extractRotation(matrix);
1712
+
1713
+ const coordinateCorrection = new THREE.Matrix4().makeRotationX(Math.PI);
1714
+
1715
+ const finalRotation = new THREE.Matrix4().multiplyMatrices(coordinateCorrection, rotMatrix);
1716
+
1717
+ const quaternion = new THREE.Quaternion();
1718
+ quaternion.setFromRotationMatrix(finalRotation);
1719
+
1720
+ position.y = -position.y;
1721
+ position.z = -position.z;
1722
+
1723
+ this.cameraFrustum.position.copy(position);
1724
+ this.cameraFrustum.quaternion.copy(quaternion);
1725
+
1726
+ const showCameraFrustum = this.ui.showCameraFrustum ? this.ui.showCameraFrustum.checked : this.defaultSettings.showCameraFrustum;
1727
+
1728
+ if (this.cameraFrustum.visible !== showCameraFrustum) {
1729
+ this.cameraFrustum.visible = showCameraFrustum;
1730
+ }
1731
+
1732
+ const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
1733
+ this.cameraFrustum.children.forEach(line => {
1734
+ if (line.material && line.material.resolution) {
1735
+ line.material.resolution.copy(resolution);
1736
+ }
1737
+ });
1738
+ }
1739
+
1740
+ updateFrustumDimensions() {
1741
+ if (!this.cameraFrustum) return;
1742
+
1743
+ while(this.cameraFrustum.children.length > 0) {
1744
+ const child = this.cameraFrustum.children[0];
1745
+ if (child.geometry) child.geometry.dispose();
1746
+ if (child.material) child.material.dispose();
1747
+ this.cameraFrustum.remove(child);
1748
+ }
1749
+
1750
+ this.initCameraFrustumGeometry();
1751
+
1752
+ this.updateCameraFrustum(this.currentFrame);
1753
+ }
1754
+
1755
+ resetSettings() {
1756
+ if (!this.defaultSettings) return;
1757
+
1758
+ this.applyDefaultSettings();
1759
+
1760
+ this.updatePointCloudSettings();
1761
+ this.updateTrajectorySettings();
1762
+ this.updateFrustumDimensions();
1763
+
1764
+ this.ui.statusBar.textContent = "Settings reset to defaults";
1765
+ this.ui.statusBar.classList.remove('hidden');
1766
+
1767
+ setTimeout(() => {
1768
+ this.ui.statusBar.classList.add('hidden');
1769
+ }, 3000);
1770
+ }
1771
+ }
1772
+
1773
+ window.addEventListener('DOMContentLoaded', () => {
1774
+ new PointCloudVisualizer();
1775
+ });
1776
+ </script>
1777
+ </body>
1778
+ </html>
app.py ADDED
@@ -0,0 +1,1118 @@
1
+ import gradio as gr
2
+ import os
3
+ import json
4
+ import numpy as np
5
+ import cv2
6
+ import base64
7
+ import time
8
+ import tempfile
9
+ import shutil
10
+ import glob
11
+ import threading
12
+ import subprocess
13
+ import struct
14
+ import zlib
15
+ from pathlib import Path
16
+ from einops import rearrange
17
+ from typing import List, Tuple, Union
18
+ try:
19
+ import spaces
20
+ except ImportError:
21
+ # Fallback for local development: provide a stub so the @spaces.GPU decorator below still works
22
+ class spaces:
23
+ GPU = staticmethod(lambda func: func)
24
+ import torch
25
+ import logging
26
+ from concurrent.futures import ThreadPoolExecutor
27
+ import atexit
28
+ import uuid
29
+
30
+ # Configure logging
31
+ logging.basicConfig(level=logging.INFO)
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # Import custom modules with error handling
35
+ try:
36
+ from app_3rd.sam_utils.inference import SamPredictor, get_sam_predictor, run_inference
37
+ from app_3rd.spatrack_utils.infer_track import get_tracker_predictor, run_tracker, get_points_on_a_grid
38
+ except ImportError as e:
39
+ logger.error(f"Failed to import custom modules: {e}")
40
+ raise
41
+
42
+ # Constants
43
+ MAX_FRAMES = 80
44
+ COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
45
+ MARKERS = [1, 5] # Cross for negative, Star for positive
46
+ MARKER_SIZE = 8
47
+
48
+ # Thread pool for delayed deletion
49
+ thread_pool_executor = ThreadPoolExecutor(max_workers=2)
50
+
51
+ def delete_later(path: Union[str, os.PathLike], delay: int = 600):
52
+ """Delete file or directory after specified delay (default 10 minutes)"""
53
+ def _delete():
54
+ try:
55
+ if os.path.isfile(path):
56
+ os.remove(path)
57
+ elif os.path.isdir(path):
58
+ shutil.rmtree(path)
59
+ except Exception as e:
60
+ logger.warning(f"Failed to delete {path}: {e}")
61
+
62
+ def _wait_and_delete():
63
+ time.sleep(delay)
64
+ _delete()
65
+
66
+ thread_pool_executor.submit(_wait_and_delete)
67
+ atexit.register(_delete)
68
+
69
+ def create_user_temp_dir():
70
+ """Create a unique temporary directory for each user session"""
71
+ session_id = str(uuid.uuid4())[:8] # Short unique ID
72
+ temp_dir = os.path.join("temp_local", f"session_{session_id}")
73
+ os.makedirs(temp_dir, exist_ok=True)
74
+
75
+ # Schedule deletion after 10 minutes
76
+ delete_later(temp_dir, delay=600)
77
+
78
+ return temp_dir
79
+
80
+ from huggingface_hub import hf_hub_download
81
+ # init the model
82
+ os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth") #, force_download=True)
83
+
84
+ if os.environ.get("VGGT_DIR", None) is not None:
85
+ from models.vggt.vggt.models.vggt_moe import VGGT_MoE
86
+ from models.vggt.vggt.utils.load_fn import preprocess_image
87
+ vggt_model = VGGT_MoE()
88
+ vggt_model.load_state_dict(torch.load(os.environ.get("VGGT_DIR")), strict=False)
89
+ vggt_model.eval()
90
+ vggt_model = vggt_model.to("cuda")
91
+
92
+ # Global model initialization
93
+ print("🚀 Initializing local models...")
94
+ tracker_model, _ = get_tracker_predictor(".", vo_points=756)
95
+ predictor = get_sam_predictor()
96
+ print("✅ Models loaded successfully!")
97
+
98
+ gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
99
+
100
+ @spaces.GPU
101
+ def gpu_run_inference(predictor_arg, image, points, boxes):
102
+ """GPU-accelerated SAM inference"""
103
+ if predictor_arg is None:
104
+ print("Initializing SAM predictor inside GPU function...")
105
+ predictor_arg = get_sam_predictor(predictor=predictor)
106
+
107
+ # Ensure predictor is on GPU
108
+ try:
109
+ if hasattr(predictor_arg, 'model'):
110
+ predictor_arg.model = predictor_arg.model.cuda()
111
+ elif hasattr(predictor_arg, 'sam'):
112
+ predictor_arg.sam = predictor_arg.sam.cuda()
113
+ elif hasattr(predictor_arg, 'to'):
114
+ predictor_arg = predictor_arg.to('cuda')
115
+
116
+ if hasattr(image, 'cuda'):
117
+ image = image.cuda()
118
+
119
+ except Exception as e:
120
+ print(f"Warning: Could not move predictor to GPU: {e}")
121
+
122
+ return run_inference(predictor_arg, image, points, boxes)
123
+
124
+ @spaces.GPU
125
+ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps, mode="offline"):
126
+ """GPU-accelerated tracking"""
127
+ import torchvision.transforms as T
128
+ import decord
129
+
130
+ if tracker_model_arg is None or tracker_viser_arg is None:
131
+ print("Initializing tracker models inside GPU function...")
132
+ out_dir = os.path.join(temp_dir, "results")
133
+ os.makedirs(out_dir, exist_ok=True)
134
+ tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points, tracker_model=tracker_model)
135
+
136
+ # Setup paths
137
+ video_path = os.path.join(temp_dir, f"{video_name}.mp4")
138
+ mask_path = os.path.join(temp_dir, f"{video_name}.png")
139
+ out_dir = os.path.join(temp_dir, "results")
140
+ os.makedirs(out_dir, exist_ok=True)
141
+
142
+ # Load video using decord
143
+ video_reader = decord.VideoReader(video_path)
144
+ video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2)
145
+
146
+ # Downscale so the shorter side is 224 (skipped if it is already smaller)
147
+ h, w = video_tensor.shape[2:]
148
+ scale = max(224 / h, 224 / w)
149
+ if scale < 1:
150
+ new_h, new_w = int(h * scale), int(w * scale)
151
+ video_tensor = T.Resize((new_h, new_w))(video_tensor)
152
+ video_tensor = video_tensor[::fps].float()[:MAX_FRAMES]
153
+
154
+ # Move to GPU
155
+ video_tensor = video_tensor.cuda()
156
+ print(f"Video tensor shape: {video_tensor.shape}, device: {video_tensor.device}")
157
+
158
+ depth_tensor = None
159
+ intrs = None
160
+ extrs = None
161
+ data_npz_load = {}
162
+
163
+ # run vggt
164
+ if os.environ.get("VGGT_DIR", None) is not None:
165
+ # process the image tensor
166
+ video_tensor = preprocess_image(video_tensor)[None]
167
+ with torch.no_grad():
168
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
169
+ # Predict attributes including cameras, depth maps, and point maps.
170
+ predictions = vggt_model(video_tensor.cuda()/255)
171
+ extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
172
+ depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
173
+
174
+ depth_tensor = depth_map.squeeze().cpu().numpy()
175
+ extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
176
+ extrs = extrinsic.squeeze().cpu().numpy()
177
+ intrs = intrinsic.squeeze().cpu().numpy()
178
+ video_tensor = video_tensor.squeeze()
179
+ # NOTE: mask out low-confidence depth (confidence <= 0.5 is treated as unreliable)
180
+ # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
181
+ unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
182
+
183
+ # Load and process mask
184
+ if os.path.exists(mask_path):
185
+ mask = cv2.imread(mask_path)
186
+ mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
187
+ mask = mask.sum(axis=-1)>0
188
+ else:
189
+ mask = np.ones_like(video_tensor[0,0].cpu().numpy())>0
190
+ grid_size = 10
191
+
192
+ # Get frame dimensions and create grid points
193
+ frame_H, frame_W = video_tensor.shape[2:]
194
+ grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cuda")
195
+
196
+ # Sample mask values at grid points and filter
197
+ if os.path.exists(mask_path):
198
+ grid_pts_int = grid_pts[0].long()
199
+ mask_values = mask[grid_pts_int.cpu()[...,1], grid_pts_int.cpu()[...,0]]
200
+ grid_pts = grid_pts[:, mask_values]
201
+
202
+ query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].cpu().numpy()
203
+ print(f"Query points shape: {query_xyt.shape}")
204
+
205
+ # Run model inference
206
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
207
+ (
208
+ c2w_traj, intrs, point_map, conf_depth,
209
+ track3d_pred, track2d_pred, vis_pred, conf_pred, video
210
+ ) = tracker_model_arg.forward(video_tensor, depth=depth_tensor,
211
+ intrs=intrs, extrs=extrs,
212
+ queries=query_xyt,
213
+ fps=1, full_point=False, iters_track=4,
214
+ query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
215
+ support_frame=len(video_tensor)-1, replace_ratio=0.2)
216
+
217
+ # Resize results to avoid large I/O
218
+ max_size = 224
219
+ h, w = video.shape[2:]
220
+ scale = min(max_size / h, max_size / w)
221
+ if scale < 1:
222
+ new_h, new_w = int(h * scale), int(w * scale)
223
+ video = T.Resize((new_h, new_w))(video)
224
+ video_tensor = T.Resize((new_h, new_w))(video_tensor)
225
+ point_map = T.Resize((new_h, new_w))(point_map)
226
+ track2d_pred[...,:2] = track2d_pred[...,:2] * scale
227
+ intrs[:,:2,:] = intrs[:,:2,:] * scale
228
+ conf_depth = T.Resize((new_h, new_w))(conf_depth)
229
+
230
+ # Visualize tracks
231
+ tracker_viser_arg.visualize(video=video[None],
232
+ tracks=track2d_pred[None][...,:2],
233
+ visibility=vis_pred[None],filename="test")
234
+
235
+ # Save in tapip3d format
236
+ data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
237
+ data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
238
+ data_npz_load["intrinsics"] = intrs.cpu().numpy()
239
+ data_npz_load["depths"] = point_map[:,2,...].cpu().numpy()
240
+ data_npz_load["video"] = (video_tensor).cpu().numpy()/255
241
+ data_npz_load["visibs"] = vis_pred.cpu().numpy()
242
+ data_npz_load["confs"] = conf_pred.cpu().numpy()
243
+ data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
244
+ np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
245
+
246
+ return None
247
+
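For readers unpacking the `data_npz_load["coords"]` line above: the einsum applies each frame's camera-to-world rotation and adds the translation, moving the predicted 3D tracks from camera coordinates into world coordinates. A small self-contained sketch of the same math (dummy shapes, not the project's tensors):

```python
import numpy as np

T, N = 4, 5                                  # frames and tracked points (dummy sizes)
c2w = np.tile(np.eye(4), (T, 1, 1))          # camera-to-world poses, shape (T, 4, 4)
tracks_cam = np.random.rand(T, N, 3)         # per-frame 3D tracks in camera coordinates

# world_point = R @ cam_point + t for every frame; same math as the torch.einsum above
world = np.einsum("tij,tnj->tni", c2w[:, :3, :3], tracks_cam) + c2w[:, :3, 3][:, None, :]
print(world.shape)                           # (4, 5, 3)
```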
248
+ def compress_and_write(filename, header, blob):
249
+ header_bytes = json.dumps(header).encode("utf-8")
250
+ header_len = struct.pack("<I", len(header_bytes))
251
+ with open(filename, "wb") as f:
252
+ f.write(header_len)
253
+ f.write(header_bytes)
254
+ f.write(blob)
255
+
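`compress_and_write` produces the `data.bin` payload consumed by the HTML viewer: a 4-byte little-endian header length, a UTF-8 JSON header, then the blob (which `process_point_cloud_data` below passes in already zlib-compressed). A minimal reader sketch under those assumptions; the function name is illustrative:

```python
import json, struct, zlib
import numpy as np

def read_viz_bin(path):
    """Inverse of compress_and_write: returns (header, dict of numpy arrays)."""
    with open(path, "rb") as f:
        header_len = struct.unpack("<I", f.read(4))[0]        # 4-byte little-endian length prefix
        header = json.loads(f.read(header_len).decode("utf-8"))
        blob = zlib.decompress(f.read())                       # remaining bytes: zlib-compressed array blob
    arrays = {}
    for key, meta in header.items():
        if key == "meta":                                      # "meta" carries playback info, not array data
            continue
        raw = blob[meta["offset"]: meta["offset"] + meta["length"]]
        arrays[key] = np.frombuffer(raw, dtype=meta["dtype"]).reshape(meta["shape"])
    return header, arrays
```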
256
+ def process_point_cloud_data(npz_file, width=256, height=192, fps=4):
257
+ fixed_size = (width, height)
258
+
259
+ data = np.load(npz_file)
260
+ extrinsics = data["extrinsics"]
261
+ intrinsics = data["intrinsics"]
262
+ trajs = data["coords"]
263
+ T, C, H, W = data["video"].shape
264
+
265
+ fx = intrinsics[0, 0, 0]
266
+ fy = intrinsics[0, 1, 1]
267
+ fov_y = 2 * np.arctan(H / (2 * fy)) * (180 / np.pi)
268
+ fov_x = 2 * np.arctan(W / (2 * fx)) * (180 / np.pi)
269
+ original_aspect_ratio = (W / fx) / (H / fy)
270
+
271
+ rgb_video = (rearrange(data["video"], "T C H W -> T H W C") * 255).astype(np.uint8)
272
+ rgb_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_AREA)
273
+ for frame in rgb_video])
274
+
275
+ depth_video = data["depths"].astype(np.float32)
276
+ if "confs_depth" in data.keys():
277
+ confs = (data["confs_depth"].astype(np.float32) > 0.5).astype(np.float32)
278
+ depth_video = depth_video * confs
279
+ depth_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_NEAREST)
280
+ for frame in depth_video])
281
+
282
+ scale_x = fixed_size[0] / W
283
+ scale_y = fixed_size[1] / H
284
+ intrinsics = intrinsics.copy()
285
+ intrinsics[:, 0, :] *= scale_x
286
+ intrinsics[:, 1, :] *= scale_y
287
+
288
+ min_depth = float(depth_video.min()) * 0.8
289
+ max_depth = float(depth_video.max()) * 1.5
290
+
291
+ depth_normalized = (depth_video - min_depth) / (max_depth - min_depth)
292
+ depth_int = (depth_normalized * ((1 << 16) - 1)).astype(np.uint16)
293
+
294
+ depths_rgb = np.zeros((T, fixed_size[1], fixed_size[0], 3), dtype=np.uint8)
295
+ depths_rgb[:, :, :, 0] = (depth_int & 0xFF).astype(np.uint8)
296
+ depths_rgb[:, :, :, 1] = ((depth_int >> 8) & 0xFF).astype(np.uint8)
297
+
298
+ first_frame_inv = np.linalg.inv(extrinsics[0])
299
+ normalized_extrinsics = np.array([first_frame_inv @ ext for ext in extrinsics])
300
+
301
+ normalized_trajs = np.zeros_like(trajs)
302
+ for t in range(T):
303
+ homogeneous_trajs = np.concatenate([trajs[t], np.ones((trajs.shape[1], 1))], axis=1)
304
+ transformed_trajs = (first_frame_inv @ homogeneous_trajs.T).T
305
+ normalized_trajs[t] = transformed_trajs[:, :3]
306
+
307
+ arrays = {
308
+ "rgb_video": rgb_video,
309
+ "depths_rgb": depths_rgb,
310
+ "intrinsics": intrinsics,
311
+ "extrinsics": normalized_extrinsics,
312
+ "inv_extrinsics": np.linalg.inv(normalized_extrinsics),
313
+ "trajectories": normalized_trajs.astype(np.float32),
314
+ "cameraZ": 0.0
315
+ }
316
+
317
+ header = {}
318
+ blob_parts = []
319
+ offset = 0
320
+ for key, arr in arrays.items():
321
+ arr = np.ascontiguousarray(arr)
322
+ arr_bytes = arr.tobytes()
323
+ header[key] = {
324
+ "dtype": str(arr.dtype),
325
+ "shape": arr.shape,
326
+ "offset": offset,
327
+ "length": len(arr_bytes)
328
+ }
329
+ blob_parts.append(arr_bytes)
330
+ offset += len(arr_bytes)
331
+
332
+ raw_blob = b"".join(blob_parts)
333
+ compressed_blob = zlib.compress(raw_blob, level=9)
334
+
335
+ header["meta"] = {
336
+ "depthRange": [min_depth, max_depth],
337
+ "totalFrames": int(T),
338
+ "resolution": fixed_size,
339
+ "baseFrameRate": fps,
340
+ "numTrajectoryPoints": normalized_trajs.shape[1],
341
+ "fov": float(fov_y),
342
+ "fov_x": float(fov_x),
343
+ "original_aspect_ratio": float(original_aspect_ratio),
344
+ "fixed_aspect_ratio": float(fixed_size[0]/fixed_size[1])
345
+ }
346
+
347
+ compress_and_write('./_viz/data.bin', header, compressed_blob)
348
+ with open('./_viz/data.bin', "rb") as f:
349
+ encoded_blob = base64.b64encode(f.read()).decode("ascii")
350
+ os.unlink('./_viz/data.bin')
351
+
352
+ random_path = f'./_viz/_{time.time()}.html'
353
+ with open('./_viz/viz_template.html') as f:
354
+ html_template = f.read()
355
+ html_out = html_template.replace(
356
+ "<head>",
357
+ f"<head>\n<script>window.embeddedBase64 = `{encoded_blob}`;</script>"
358
+ )
359
+ with open(random_path,'w') as f:
360
+ f.write(html_out)
361
+
362
+ return random_path
363
+
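The viewer's depth stream packs each normalized 16-bit depth value into the red (low byte) and green (high byte) channels of `depths_rgb`, with the normalization range stored in `header["meta"]["depthRange"]`. A sketch of the inverse mapping, assuming arrays recovered as in the reader above:

```python
import numpy as np

def unpack_depth(depths_rgb, depth_range):
    """Recover metric depth from the 16-bit value split across R (low byte) and G (high byte)."""
    min_depth, max_depth = depth_range  # header["meta"]["depthRange"]
    depth_int = depths_rgb[..., 0].astype(np.uint16) | (depths_rgb[..., 1].astype(np.uint16) << 8)
    return depth_int.astype(np.float32) / ((1 << 16) - 1) * (max_depth - min_depth) + min_depth
```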
364
+ def numpy_to_base64(arr):
365
+ """Convert numpy array to base64 string"""
366
+ return base64.b64encode(arr.tobytes()).decode('utf-8')
367
+
368
+ def base64_to_numpy(b64_str, shape, dtype):
369
+ """Convert base64 string back to numpy array"""
370
+ return np.frombuffer(base64.b64decode(b64_str), dtype=dtype).reshape(shape)
371
+
372
+ def get_video_name(video_path):
373
+ """Extract video name without extension"""
374
+ return os.path.splitext(os.path.basename(video_path))[0]
375
+
376
+ def extract_first_frame(video_path):
377
+ """Extract first frame from video file"""
378
+ try:
379
+ cap = cv2.VideoCapture(video_path)
380
+ ret, frame = cap.read()
381
+ cap.release()
382
+
383
+ if ret:
384
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
385
+ return frame_rgb
386
+ else:
387
+ return None
388
+ except Exception as e:
389
+ print(f"Error extracting first frame: {e}")
390
+ return None
391
+
392
+ def handle_video_upload(video):
393
+ """Handle video upload and extract first frame"""
394
+ if video is None:
395
+ return (None, None, [],
396
+ gr.update(value=50),
397
+ gr.update(value=756),
398
+ gr.update(value=3))
399
+
400
+ # Create user-specific temporary directory
401
+ user_temp_dir = create_user_temp_dir()
402
+
403
+ # Get original video name and copy to temp directory
404
+ if isinstance(video, str):
405
+ video_name = get_video_name(video)
406
+ video_path = os.path.join(user_temp_dir, f"{video_name}.mp4")
407
+ shutil.copy(video, video_path)
408
+ else:
409
+ video_name = get_video_name(video.name)
410
+ video_path = os.path.join(user_temp_dir, f"{video_name}.mp4")
411
+ with open(video_path, 'wb') as f:
412
+ f.write(video.read())
413
+
414
+ print(f"📁 Video saved to: {video_path}")
415
+
416
+ # Extract first frame
417
+ frame = extract_first_frame(video_path)
418
+ if frame is None:
419
+ return (None, None, [],
420
+ gr.update(value=50),
421
+ gr.update(value=756),
422
+ gr.update(value=3))
423
+
424
+ # Resize frame to have minimum side length of 336
425
+ h, w = frame.shape[:2]
426
+ scale = 336 / min(h, w)
427
+ new_h, new_w = int(h * scale)//2*2, int(w * scale)//2*2
428
+ frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
429
+
430
+ # Store frame data with temp directory info
431
+ frame_data = {
432
+ 'data': numpy_to_base64(frame),
433
+ 'shape': frame.shape,
434
+ 'dtype': str(frame.dtype),
435
+ 'temp_dir': user_temp_dir,
436
+ 'video_name': video_name,
437
+ 'video_path': video_path
438
+ }
439
+
440
+ # Get video-specific settings
441
+ print(f"🎬 Video path: '{video}' -> Video name: '{video_name}'")
442
+ grid_size_val, vo_points_val, fps_val = get_video_settings(video_name)
443
+ print(f"🎬 Video settings for '{video_name}': grid_size={grid_size_val}, vo_points={vo_points_val}, fps={fps_val}")
444
+
445
+ return (json.dumps(frame_data), frame, [],
446
+ gr.update(value=grid_size_val),
447
+ gr.update(value=vo_points_val),
448
+ gr.update(value=fps_val))
449
+
450
+ def save_masks(o_masks, video_name, temp_dir):
451
+ """Save binary masks to files in user-specific temp directory"""
452
+ o_files = []
453
+ for mask, _ in o_masks:
454
+ o_mask = np.uint8(mask.squeeze() * 255)
455
+ o_file = os.path.join(temp_dir, f"{video_name}.png")
456
+ cv2.imwrite(o_file, o_mask)
457
+ o_files.append(o_file)
458
+ return o_files
459
+
460
+ def select_point(original_img: str, sel_pix: list, point_type: str, evt: gr.SelectData):
461
+ """Handle point selection for SAM"""
462
+ if original_img is None:
463
+ return None, []
464
+
465
+ try:
466
+ # Convert stored image data back to numpy array
467
+ frame_data = json.loads(original_img)
468
+ original_img_array = base64_to_numpy(frame_data['data'], frame_data['shape'], frame_data['dtype'])
469
+ temp_dir = frame_data.get('temp_dir', 'temp_local')
470
+ video_name = frame_data.get('video_name', 'video')
471
+
472
+ # Create a display image for visualization
473
+ display_img = original_img_array.copy()
474
+ new_sel_pix = sel_pix.copy() if sel_pix else []
475
+ new_sel_pix.append((evt.index, 1 if point_type == 'positive_point' else 0))
476
+
477
+ print(f"🎯 Running SAM inference for point: {evt.index}, type: {point_type}")
478
+ # Run SAM inference
479
+ o_masks = gpu_run_inference(None, original_img_array, new_sel_pix, [])
480
+
481
+ # Draw points on display image
482
+ for point, label in new_sel_pix:
483
+ cv2.drawMarker(display_img, point, COLORS[label], markerType=MARKERS[label], markerSize=MARKER_SIZE, thickness=2)
484
+
485
+ # Draw mask overlay on display image
486
+ if o_masks:
487
+ mask = o_masks[0][0]
488
+ overlay = display_img.copy()
489
+ overlay[mask.squeeze()!=0] = [20, 60, 200] # Light blue
490
+ display_img = cv2.addWeighted(overlay, 0.6, display_img, 0.4, 0)
491
+
492
+ # Save mask for tracking
493
+ save_masks(o_masks, video_name, temp_dir)
494
+ print(f"✅ Mask saved for video: {video_name}")
495
+
496
+ return display_img, new_sel_pix
497
+
498
+ except Exception as e:
499
+ print(f"❌ Error in select_point: {e}")
500
+ return None, []
501
+
502
+ def reset_points(original_img: str, sel_pix):
503
+ """Reset all points and clear the mask"""
504
+ if original_img is None:
505
+ return None, []
506
+
507
+ try:
508
+ # Convert stored image data back to numpy array
509
+ frame_data = json.loads(original_img)
510
+ original_img_array = base64_to_numpy(frame_data['data'], frame_data['shape'], frame_data['dtype'])
511
+ temp_dir = frame_data.get('temp_dir', 'temp_local')
512
+
513
+ # Create a display image (just the original image)
514
+ display_img = original_img_array.copy()
515
+
516
+ # Clear all points
517
+ new_sel_pix = []
518
+
519
+ # Clear any existing masks
520
+ for mask_file in glob.glob(os.path.join(temp_dir, "*.png")):
521
+ try:
522
+ os.remove(mask_file)
523
+ except Exception as e:
524
+ logger.warning(f"Failed to remove mask file {mask_file}: {e}")
525
+
526
+ print("🔄 Points and masks reset")
527
+ return display_img, new_sel_pix
528
+
529
+ except Exception as e:
530
+ print(f"❌ Error in reset_points: {e}")
531
+ return None, []
532
+
533
+ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
534
+ """Launch visualization with user-specific temp directory"""
535
+ if original_image_state is None:
536
+ return None, None, None
537
+
538
+ try:
539
+ # Get user's temp directory from stored frame data
540
+ frame_data = json.loads(original_image_state)
541
+ temp_dir = frame_data.get('temp_dir', 'temp_local')
542
+ video_name = frame_data.get('video_name', 'video')
543
+
544
+ print(f"🚀 Starting tracking for video: {video_name}")
545
+ print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}")
546
+
547
+ # Check for mask files
548
+ mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
549
+ video_files = glob.glob(os.path.join(temp_dir, "*.mp4"))
550
+
551
+ if not video_files:
552
+ print("❌ No video file found")
553
+ return "❌ Error: No video file found", None, None
554
+
555
+ video_path = video_files[0]
556
+ mask_path = mask_files[0] if mask_files else None
557
+
558
+ # Run tracker
559
+ print("🎯 Running tracker...")
560
+ out_dir = os.path.join(temp_dir, "results")
561
+ os.makedirs(out_dir, exist_ok=True)
562
+
563
+ gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=mode)
564
+
565
+ # Process results
566
+ npz_path = os.path.join(out_dir, "result.npz")
567
+ track2d_video = os.path.join(out_dir, "test_pred_track.mp4")
568
+
569
+ if os.path.exists(npz_path):
570
+ print("📊 Processing 3D visualization...")
571
+ html_path = process_point_cloud_data(npz_path)
572
+
573
+ # Schedule deletion of generated files
574
+ delete_later(html_path, delay=600)
575
+ if os.path.exists(track2d_video):
576
+ delete_later(track2d_video, delay=600)
577
+ delete_later(npz_path, delay=600)
578
+
579
+ # Create iframe HTML
580
+ iframe_html = f"""
581
+ <div style='border: 3px solid #667eea; border-radius: 10px;
582
+ background: #f8f9ff; height: 650px; width: 100%;
583
+ box-shadow: 0 8px 32px rgba(102, 126, 234, 0.3);
584
+ margin: 0; padding: 0; box-sizing: border-box; overflow: hidden;'>
585
+ <iframe id="viz_iframe" src="/gradio_api/file={html_path}"
586
+ width="100%" height="650" frameborder="0"
587
+ style="border: none; display: block; width: 100%; height: 650px;
588
+ margin: 0; padding: 0; border-radius: 7px;">
589
+ </iframe>
590
+ </div>
591
+ """
592
+
593
+ print("✅ Tracking completed successfully!")
594
+ return iframe_html, track2d_video if os.path.exists(track2d_video) else None, html_path
595
+ else:
596
+ print("❌ Tracking failed - no results generated")
597
+ return "❌ Error: Tracking failed to generate results", None, None
598
+
599
+ except Exception as e:
600
+ print(f"❌ Error in launch_viz: {e}")
601
+ return f"❌ Error: {str(e)}", None, None
602
+
603
+ def clear_all():
604
+ """Clear all buffers and temporary files"""
605
+ return (None, None, [],
606
+ gr.update(value=50),
607
+ gr.update(value=756),
608
+ gr.update(value=3))
609
+
610
+ def clear_all_with_download():
611
+ """Clear all buffers including both download components"""
612
+ return (None, None, [],
613
+ gr.update(value=50),
614
+ gr.update(value=756),
615
+ gr.update(value=3),
616
+ None, # tracking_video_download
617
+ None) # HTML download component
618
+
619
+ def get_video_settings(video_name):
620
+ """Get video-specific settings based on video name"""
621
+ video_settings = {
622
+ "running": (50, 512, 2),
623
+ "backpack": (40, 600, 2),
624
+ "kitchen": (60, 800, 3),
625
+ "pillow": (35, 500, 2),
626
+ "handwave": (35, 500, 8),
627
+ "hockey": (45, 700, 2),
628
+ "drifting": (35, 1000, 6),
629
+ "basketball": (45, 1500, 5),
630
+ "ken_block_0": (45, 700, 2),
631
+ "ego_kc1": (45, 500, 4),
632
+ "vertical_place": (45, 500, 3),
633
+ "ego_teaser": (45, 1200, 10),
634
+ "robot_unitree": (45, 500, 4),
635
+ "robot_3": (35, 400, 5),
636
+ "teleop2": (45, 256, 7),
637
+ "pusht": (45, 256, 10),
638
+ "cinema_0": (45, 356, 5),
639
+ "cinema_1": (45, 756, 3),
640
+ "robot1": (45, 600, 2),
641
+ "robot2": (45, 600, 2),
642
+ "protein": (45, 600, 2),
643
+ "kitchen_egocentric": (45, 600, 2),
644
+ }
645
+
646
+ return video_settings.get(video_name, (50, 756, 3))
647
+
648
+ # Create the Gradio interface
649
+ print("🎨 Creating Gradio interface...")
650
+
651
+ with gr.Blocks(
652
+ theme=gr.themes.Soft(),
653
+ title="🎯 [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)",
654
+ css="""
655
+ .gradio-container {
656
+ max-width: 1200px !important;
657
+ margin: auto !important;
658
+ }
659
+ .gr-button {
660
+ margin: 5px;
661
+ }
662
+ .gr-form {
663
+ background: white;
664
+ border-radius: 10px;
665
+ padding: 20px;
666
+ box-shadow: 0 2px 10px rgba(0,0,0,0.1);
667
+ }
668
+ /* Remove the default gray background of gr.Group */
669
+ .gr-form {
670
+ background: transparent !important;
671
+ border: none !important;
672
+ box-shadow: none !important;
673
+ padding: 0 !important;
674
+ }
675
+ /* Fix the size of the 3D visualizer */
676
+ #viz_container {
677
+ height: 650px !important;
678
+ min-height: 650px !important;
679
+ max-height: 650px !important;
680
+ width: 100% !important;
681
+ margin: 0 !important;
682
+ padding: 0 !important;
683
+ overflow: hidden !important;
684
+ }
685
+ #viz_container > div {
686
+ height: 650px !important;
687
+ min-height: 650px !important;
688
+ max-height: 650px !important;
689
+ width: 100% !important;
690
+ margin: 0 !important;
691
+ padding: 0 !important;
692
+ box-sizing: border-box !important;
693
+ }
694
+ #viz_container iframe {
695
+ height: 650px !important;
696
+ min-height: 650px !important;
697
+ max-height: 650px !important;
698
+ width: 100% !important;
699
+ border: none !important;
700
+ display: block !important;
701
+ margin: 0 !important;
702
+ padding: 0 !important;
703
+ box-sizing: border-box !important;
704
+ }
705
+ /* Fix the height of the video upload component */
706
+ .gr-video {
707
+ height: 300px !important;
708
+ min-height: 300px !important;
709
+ max-height: 300px !important;
710
+ }
711
+ .gr-video video {
712
+ height: 260px !important;
713
+ max-height: 260px !important;
714
+ object-fit: contain !important;
715
+ background: #f8f9fa;
716
+ }
717
+ .gr-video .gr-video-player {
718
+ height: 260px !important;
719
+ max-height: 260px !important;
720
+ }
721
+ /* Forcefully remove the gray background of the examples area - use more generic selectors */
722
+ .horizontal-examples,
723
+ .horizontal-examples > *,
724
+ .horizontal-examples * {
725
+ background: transparent !important;
726
+ background-color: transparent !important;
727
+ border: none !important;
728
+ }
729
+
730
+ /* Horizontal scroll styles for the Examples component */
731
+ .horizontal-examples [data-testid="examples"] {
732
+ background: transparent !important;
733
+ background-color: transparent !important;
734
+ }
735
+
736
+ .horizontal-examples [data-testid="examples"] > div {
737
+ background: transparent !important;
738
+ background-color: transparent !important;
739
+ overflow-x: auto !important;
740
+ overflow-y: hidden !important;
741
+ scrollbar-width: thin;
742
+ scrollbar-color: #667eea transparent;
743
+ padding: 0 !important;
744
+ margin-top: 10px;
745
+ border: none !important;
746
+ }
747
+
748
+ .horizontal-examples [data-testid="examples"] table {
749
+ display: flex !important;
750
+ flex-wrap: nowrap !important;
751
+ min-width: max-content !important;
752
+ gap: 15px !important;
753
+ padding: 10px 0;
754
+ background: transparent !important;
755
+ border: none !important;
756
+ }
757
+
758
+ .horizontal-examples [data-testid="examples"] tbody {
759
+ display: flex !important;
760
+ flex-direction: row !important;
761
+ flex-wrap: nowrap !important;
762
+ gap: 15px !important;
763
+ background: transparent !important;
764
+ }
765
+
766
+ .horizontal-examples [data-testid="examples"] tr {
767
+ display: flex !important;
768
+ flex-direction: column !important;
769
+ min-width: 160px !important;
770
+ max-width: 160px !important;
771
+ margin: 0 !important;
772
+ background: white !important;
773
+ border-radius: 12px;
774
+ box-shadow: 0 3px 12px rgba(0,0,0,0.12);
775
+ transition: all 0.3s ease;
776
+ cursor: pointer;
777
+ overflow: hidden;
778
+ border: none !important;
779
+ }
780
+
781
+ .horizontal-examples [data-testid="examples"] tr:hover {
782
+ transform: translateY(-4px);
783
+ box-shadow: 0 8px 20px rgba(102, 126, 234, 0.25);
784
+ }
785
+
786
+ .horizontal-examples [data-testid="examples"] td {
787
+ text-align: center !important;
788
+ padding: 0 !important;
789
+ border: none !important;
790
+ background: transparent !important;
791
+ }
792
+
793
+ .horizontal-examples [data-testid="examples"] td:first-child {
794
+ padding: 0 !important;
795
+ background: transparent !important;
796
+ }
797
+
798
+ .horizontal-examples [data-testid="examples"] video {
799
+ border-radius: 8px 8px 0 0 !important;
800
+ width: 100% !important;
801
+ height: 90px !important;
802
+ object-fit: cover !important;
803
+ background: #f8f9fa !important;
804
+ }
805
+
806
+ .horizontal-examples [data-testid="examples"] td:last-child {
807
+ font-size: 11px !important;
808
+ font-weight: 600 !important;
809
+ color: #333 !important;
810
+ padding: 8px 12px !important;
811
+ background: linear-gradient(135deg, #f8f9ff 0%, #e6f3ff 100%) !important;
812
+ border-radius: 0 0 8px 8px;
813
+ }
814
+
815
+ /* Scrollbar styles */
816
+ .horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar {
817
+ height: 8px;
818
+ }
819
+ .horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-track {
820
+ background: transparent;
821
+ border-radius: 4px;
822
+ }
823
+ .horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-thumb {
824
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
825
+ border-radius: 4px;
826
+ }
827
+ .horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-thumb:hover {
828
+ background: linear-gradient(135deg, #5a6fd8 0%, #6a4190 100%);
829
+ }
830
+ """
831
+ ) as demo:
832
+
833
+ # Add prominent main title
834
+
835
+ gr.Markdown("""
836
+ # ✨ SpatialTrackerV2
837
+
838
+ Welcome to [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)! This interface allows you to track any pixels in 3D using our model.
839
+
840
+ **⚡ Quick Start:** Upload video → Click "Start Tracking Now!"
841
+
842
+ **🔬 Advanced Usage with SAM:**
843
+ 1. Upload a video file or select from examples below
844
+ 2. Expand "Manual Point Selection" to click on specific objects for SAM-guided tracking
845
+ 3. Adjust tracking parameters for optimal performance
846
+ 4. Click "Start Tracking Now!" to begin 3D tracking with SAM guidance
847
+
848
+ """)
849
+
850
+ # Status indicator
851
+ gr.Markdown("**Status:** 🟢 Local Processing Mode")
852
+
853
+ # Main content area - video upload left, 3D visualization right
854
+ with gr.Row():
855
+ with gr.Column(scale=1):
856
+ # Video upload section
857
+ gr.Markdown("### 📂 Select Video")
858
+
859
+ # Define video_input here so it can be referenced in examples
860
+ video_input = gr.Video(
861
+ label="Upload Video or Select Example",
862
+ format="mp4",
863
+ height=250 # Matched height with 3D viz
864
+ )
865
+
866
+
867
+ # Traditional examples but with horizontal scroll styling
868
+ gr.Markdown("🎨**Examples:** (scroll horizontally to see all videos)")
869
+ with gr.Row(elem_classes=["horizontal-examples"]):
870
+ # Horizontal video examples with slider
871
+ # gr.HTML("<div style='margin-top: 5px;'></div>")
872
+ gr.Examples(
873
+ examples=[
874
+ ["./examples/robot1.mp4"],
875
+ ["./examples/robot2.mp4"],
876
+ ["./examples/protein.mp4"],
877
+ ["./examples/kitchen_egocentric.mp4"],
878
+ ["./examples/hockey.mp4"],
879
+ ["./examples/running.mp4"],
880
+ ["./examples/robot_3.mp4"],
881
+ ["./examples/backpack.mp4"],
882
+ ["./examples/kitchen.mp4"],
883
+ ["./examples/pillow.mp4"],
884
+ ["./examples/handwave.mp4"],
885
+ ["./examples/drifting.mp4"],
886
+ ["./examples/basketball.mp4"],
887
+ ["./examples/ken_block_0.mp4"],
888
+ ["./examples/ego_kc1.mp4"],
889
+ ["./examples/vertical_place.mp4"],
890
+ ["./examples/ego_teaser.mp4"],
891
+ ["./examples/robot_unitree.mp4"],
892
+ ["./examples/teleop2.mp4"],
893
+ ["./examples/pusht.mp4"],
894
+ ["./examples/cinema_0.mp4"],
895
+ ["./examples/cinema_1.mp4"],
896
+ ],
897
+ inputs=[video_input],
898
+ outputs=[video_input],
899
+ fn=None,
900
+ cache_examples=False,
901
+ label="",
902
+ examples_per_page=6 # Show 6 examples per page so they can wrap to multiple rows
903
+ )
904
+
905
+ with gr.Column(scale=2):
906
+ # 3D Visualization - wider and taller to match left side
907
+ with gr.Group():
908
+ gr.Markdown("### 🌐 3D Trajectory Visualization")
909
+ viz_html = gr.HTML(
910
+ label="3D Trajectory Visualization",
911
+ value="""
912
+ <div style='border: 3px solid #667eea; border-radius: 10px;
913
+ background: linear-gradient(135deg, #f8f9ff 0%, #e6f3ff 100%);
914
+ text-align: center; height: 650px; display: flex;
915
+ flex-direction: column; justify-content: center; align-items: center;
916
+ box-shadow: 0 4px 16px rgba(102, 126, 234, 0.15);
917
+ margin: 0; padding: 20px; box-sizing: border-box;'>
918
+ <div style='font-size: 56px; margin-bottom: 25px;'>🌐</div>
919
+ <h3 style='color: #667eea; margin-bottom: 18px; font-size: 28px; font-weight: 600;'>
920
+ 3D Trajectory Visualization
921
+ </h3>
922
+ <p style='color: #666; font-size: 18px; line-height: 1.6; max-width: 550px; margin-bottom: 30px;'>
923
+ Track any pixels in 3D space with camera motion
924
+ </p>
925
+ <div style='background: rgba(102, 126, 234, 0.1); border-radius: 30px;
926
+ padding: 15px 30px; border: 1px solid rgba(102, 126, 234, 0.2);'>
927
+ <span style='color: #667eea; font-weight: 600; font-size: 16px;'>
928
+ ⚡ Powered by SpatialTracker V2
929
+ </span>
930
+ </div>
931
+ </div>
932
+ """,
933
+ elem_id="viz_container"
934
+ )
935
+
936
+ # Start button section - below video area
937
+ with gr.Row():
938
+ with gr.Column(scale=3):
939
+ launch_btn = gr.Button("🚀 Start Tracking Now!", variant="primary", size="lg")
940
+ with gr.Column(scale=1):
941
+ clear_all_btn = gr.Button("🗑️ Clear All", variant="secondary", size="sm")
942
+
943
+ # Tracking parameters section
944
+ with gr.Row():
945
+ gr.Markdown("### ⚙️ Tracking Parameters")
946
+ with gr.Row():
947
+ grid_size = gr.Slider(
948
+ minimum=10, maximum=100, step=10, value=50,
949
+ label="Grid Size", info="Tracking detail level"
950
+ )
951
+ vo_points = gr.Slider(
952
+ minimum=100, maximum=2000, step=50, value=756,
953
+ label="VO Points", info="Motion accuracy"
954
+ )
955
+ fps = gr.Slider(
956
+ minimum=1, maximum=20, step=1, value=3,
957
+ label="FPS", info="Processing speed"
958
+ )
959
+
960
+ # Advanced Point Selection with SAM - Collapsed by default
961
+ with gr.Row():
962
+ gr.Markdown("### 🎯 Advanced: Manual Point Selection with SAM")
963
+ with gr.Accordion("🔬 SAM Point Selection Controls", open=False):
964
+ gr.HTML("""
965
+ <div style='margin-bottom: 15px;'>
966
+ <ul style='color: #4a5568; font-size: 14px; line-height: 1.6; margin: 0; padding-left: 20px;'>
967
+ <li>Click on target objects in the image for SAM-guided segmentation</li>
968
+ <li>Positive points: include these areas | Negative points: exclude these areas</li>
969
+ <li>Get more accurate 3D tracking results with SAM's powerful segmentation</li>
970
+ </ul>
971
+ </div>
972
+ """)
973
+
974
+ with gr.Row():
975
+ with gr.Column():
976
+ interactive_frame = gr.Image(
977
+ label="Click to select tracking points with SAM guidance",
978
+ type="numpy",
979
+ interactive=True,
980
+ height=300
981
+ )
982
+
983
+ with gr.Row():
984
+ point_type = gr.Radio(
985
+ choices=["positive_point", "negative_point"],
986
+ value="positive_point",
987
+ label="Point Type",
988
+ info="Positive: track these areas | Negative: avoid these areas"
989
+ )
990
+
991
+ with gr.Row():
992
+ reset_points_btn = gr.Button("🔄 Reset Points", variant="secondary", size="sm")
993
+
994
+ # Downloads section - hidden but still functional for local processing
995
+ with gr.Row(visible=False):
996
+ with gr.Column(scale=1):
997
+ tracking_video_download = gr.File(
998
+ label="📹 Download 2D Tracking Video",
999
+ interactive=False,
1000
+ visible=False
1001
+ )
1002
+ with gr.Column(scale=1):
1003
+ html_download = gr.File(
1004
+ label="📄 Download 3D Visualization HTML",
1005
+ interactive=False,
1006
+ visible=False
1007
+ )
1008
+
1009
+ # GitHub Star Section
1010
+ gr.HTML("""
1011
+ <div style='background: linear-gradient(135deg, #e8eaff 0%, #f0f2ff 100%);
1012
+ border-radius: 8px; padding: 20px; margin: 15px 0;
1013
+ box-shadow: 0 2px 8px rgba(102, 126, 234, 0.1);
1014
+ border: 1px solid rgba(102, 126, 234, 0.15);'>
1015
+ <div style='text-align: center;'>
1016
+ <h3 style='color: #4a5568; margin: 0 0 10px 0; font-size: 18px; font-weight: 600;'>
1017
+ ⭐ Love SpatialTracker? Give us a Star! ⭐
1018
+ </h3>
1019
+ <p style='color: #666; margin: 0 0 15px 0; font-size: 14px; line-height: 1.5;'>
1020
+ Help us grow by starring our repository on GitHub! Your support means a lot to the community. 🚀
1021
+ </p>
1022
+ <a href="https://github.com/henry123-boy/SpaTrackerV2" target="_blank"
1023
+ style='display: inline-flex; align-items: center; gap: 8px;
1024
+ background: rgba(102, 126, 234, 0.1); color: #4a5568;
1025
+ padding: 10px 20px; border-radius: 25px; text-decoration: none;
1026
+ font-weight: bold; font-size: 14px; border: 1px solid rgba(102, 126, 234, 0.2);
1027
+ transition: all 0.3s ease;'
1028
+ onmouseover="this.style.background='rgba(102, 126, 234, 0.15)'; this.style.transform='translateY(-2px)'"
1029
+ onmouseout="this.style.background='rgba(102, 126, 234, 0.1)'; this.style.transform='translateY(0)'">
1030
+ <span style='font-size: 16px;'>⭐</span>
1031
+ Star SpatialTracker V2 on GitHub
1032
+ </a>
1033
+ </div>
1034
+ </div>
1035
+ """)
1036
+
1037
+ # Acknowledgments Section
1038
+ gr.HTML("""
1039
+ <div style='background: linear-gradient(135deg, #fff8e1 0%, #fffbf0 100%);
1040
+ border-radius: 8px; padding: 20px; margin: 15px 0;
1041
+ box-shadow: 0 2px 8px rgba(255, 193, 7, 0.1);
1042
+ border: 1px solid rgba(255, 193, 7, 0.2);'>
1043
+ <div style='text-align: center;'>
1044
+ <h3 style='color: #5d4037; margin: 0 0 10px 0; font-size: 18px; font-weight: 600;'>
1045
+ 📚 Acknowledgments
1046
+ </h3>
1047
+ <p style='color: #5d4037; margin: 0 0 15px 0; font-size: 14px; line-height: 1.5;'>
1048
+ Our 3D visualizer is adapted from <strong>TAPIP3D</strong>. We thank the authors for their excellent work and contribution to the computer vision community!
1049
+ </p>
1050
+ <a href="https://github.com/zbw001/TAPIP3D" target="_blank"
1051
+ style='display: inline-flex; align-items: center; gap: 8px;
1052
+ background: rgba(255, 193, 7, 0.15); color: #5d4037;
1053
+ padding: 10px 20px; border-radius: 25px; text-decoration: none;
1054
+ font-weight: bold; font-size: 14px; border: 1px solid rgba(255, 193, 7, 0.3);
1055
+ transition: all 0.3s ease;'
1056
+ onmouseover="this.style.background='rgba(255, 193, 7, 0.25)'; this.style.transform='translateY(-2px)'"
1057
+ onmouseout="this.style.background='rgba(255, 193, 7, 0.15)'; this.style.transform='translateY(0)'">
1058
+ 📚 Visit TAPIP3D Repository
1059
+ </a>
1060
+ </div>
1061
+ </div>
1062
+ """)
1063
+
1064
+ # Footer
1065
+ gr.HTML("""
1066
+ <div style='text-align: center; margin: 20px 0 10px 0;'>
1067
+ <span style='font-size: 12px; color: #888; font-style: italic;'>
1068
+ Powered by SpatialTracker V2 | Built with ❤️ for the Computer Vision Community
1069
+ </span>
1070
+ </div>
1071
+ """)
1072
+
1073
+ # Hidden state variables
1074
+ original_image_state = gr.State(None)
1075
+ selected_points = gr.State([])
1076
+
1077
+ # Event handlers
1078
+ video_input.change(
1079
+ fn=handle_video_upload,
1080
+ inputs=[video_input],
1081
+ outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
1082
+ )
1083
+
1084
+ interactive_frame.select(
1085
+ fn=select_point,
1086
+ inputs=[original_image_state, selected_points, point_type],
1087
+ outputs=[interactive_frame, selected_points]
1088
+ )
1089
+
1090
+ reset_points_btn.click(
1091
+ fn=reset_points,
1092
+ inputs=[original_image_state, selected_points],
1093
+ outputs=[interactive_frame, selected_points]
1094
+ )
1095
+
1096
+ clear_all_btn.click(
1097
+ fn=clear_all_with_download,
1098
+ outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, tracking_video_download, html_download]
1099
+ )
1100
+
1101
+ launch_btn.click(
1102
+ fn=launch_viz,
1103
+ inputs=[grid_size, vo_points, fps, original_image_state],
1104
+ outputs=[viz_html, tracking_video_download, html_download]
1105
+ )
1106
+
1107
+ # Launch the interface
1108
+ if __name__ == "__main__":
1109
+ print("🌟 Launching SpatialTracker V2 Local Version...")
1110
+ print("🔗 Running in Local Processing Mode")
1111
+
1112
+ demo.launch(
1113
+ server_name="0.0.0.0",
1114
+ server_port=7860,
1115
+ share=True,
1116
+ debug=True,
1117
+ show_error=True
1118
+ )
app_3rd/README.md ADDED
@@ -0,0 +1,12 @@
1
+ # 🌟 SpatialTrackerV2 Integrated with SAM 🌟
2
+ SAM receives a point prompt and generates a mask for the target object, facilitating easy interaction to obtain the object's 3D trajectories with SpaTrack2.
3
+
4
+ ## Installation
5
+ ```
6
+
7
+ python -m pip install git+https://github.com/facebookresearch/segment-anything.git
8
+ cd app_3rd/sam_utils
9
+ mkdir checkpoints
10
+ cd checkpoints
11
+ wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
12
+ ```
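
A minimal usage sketch (assuming the checkpoint above is downloaded; the file path and click coordinates are illustrative). The point format mirrors how `app.py` builds its selection list: `((x, y), label)` with label 1 for positive and 0 for negative clicks:

```python
import cv2
from app_3rd.sam_utils.inference import get_sam_predictor, run_inference

image = cv2.cvtColor(cv2.imread("frame.png"), cv2.COLOR_BGR2RGB)  # illustrative RGB frame
predictor = get_sam_predictor(model_type='vit_h', use_hf=False)   # use_hf=True switches to the HF weights
points = [((320, 240), 1)]                                        # one positive click at (x, y)
masks = run_inference(predictor, image, points)                   # app.py treats this as a list of (mask, _) pairs
```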
app_3rd/sam_utils/hf_sam_predictor.py ADDED
@@ -0,0 +1,129 @@
1
+ import gc
2
+ import numpy as np
3
+ import torch
4
+ from typing import Optional, Tuple, List, Union
5
+ import warnings
6
+ import cv2
7
+ try:
8
+ from transformers import SamModel, SamProcessor
9
+ from huggingface_hub import hf_hub_download
10
+ HF_AVAILABLE = True
11
+ except ImportError:
12
+ HF_AVAILABLE = False
13
+ warnings.warn("transformers or huggingface_hub not available. HF SAM models will not work.")
14
+
15
+ # Hugging Face model mapping
16
+ HF_MODELS = {
17
+ 'vit_b': 'facebook/sam-vit-base',
18
+ 'vit_l': 'facebook/sam-vit-large',
19
+ 'vit_h': 'facebook/sam-vit-huge'
20
+ }
21
+
22
+ class HFSamPredictor:
23
+ """
24
+ Hugging Face version of SamPredictor that wraps the transformers SAM models.
25
+ This class provides the same interface as the original SamPredictor for seamless integration.
26
+ """
27
+
28
+ def __init__(self, model: SamModel, processor: SamProcessor, device: Optional[str] = None):
29
+ """
30
+ Initialize the HF SAM predictor.
31
+
32
+ Args:
33
+ model: The SAM model from transformers
34
+ processor: The SAM processor from transformers
35
+ device: Device to run the model on ('cuda', 'cpu', etc.)
36
+ """
37
+ self.model = model
38
+ self.processor = processor
39
+ self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
40
+ self.model.to(self.device)
41
+ self.model.eval()
42
+
43
+ # Store the current image and its features
44
+ self.original_size = None
45
+ self.input_size = None
46
+ self.features = None
47
+ self.image = None
48
+
49
+ @classmethod
50
+ def from_pretrained(cls, model_name: str, device: Optional[str] = None) -> 'HFSamPredictor':
51
+ """
52
+ Load a SAM model from Hugging Face Hub.
53
+
54
+ Args:
55
+ model_name: Model name from HF_MODELS or direct HF model path
56
+ device: Device to load the model on
57
+
58
+ Returns:
59
+ HFSamPredictor instance
60
+ """
61
+ if not HF_AVAILABLE:
62
+ raise ImportError("transformers and huggingface_hub are required for HF SAM models")
63
+
64
+ # Map model type to HF model name if needed
65
+ if model_name in HF_MODELS:
66
+ model_name = HF_MODELS[model_name]
67
+
68
+ print(f"Loading SAM model from Hugging Face: {model_name}")
69
+
70
+ # Load model and processor
71
+ model = SamModel.from_pretrained(model_name)
72
+ processor = SamProcessor.from_pretrained(model_name)
73
+ return cls(model, processor, device)
74
+
75
+ def preprocess(self, image: np.ndarray,
76
+ input_points: List[List[float]], input_labels: List[int]) -> dict:
77
+ """
78
+ Set the image for prediction. This preprocesses the image and extracts features.
79
+
80
+ Args:
81
+ image: Input image as numpy array (H, W, C) in RGB format
82
+ """
83
+ if image.dtype != np.uint8:
84
+ image = (image * 255).astype(np.uint8)
85
+
86
+ self.image = image
87
+ self.original_size = image.shape[:2]
88
+
89
+ # Run the processor on the point prompts; it also returns original_sizes & reshaped_input_sizes needed for post-processing
90
+ inputs = self.processor(
91
+ images=image,
92
+ input_points=input_points,
93
+ input_labels=input_labels,
94
+ return_tensors="pt"
95
+ )
96
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
97
+
98
+ self.input_size = inputs['pixel_values'].shape[-2:]
99
+ self.features = inputs
100
+ return inputs
101
+
102
+
103
+ def get_hf_sam_predictor(model_type: str = 'vit_h', device: Optional[str] = None,
104
+ image: Optional[np.ndarray] = None) -> HFSamPredictor:
105
+ """
106
+ Get a Hugging Face SAM predictor with the same interface as the original get_sam_predictor.
107
+
108
+ Args:
109
+ model_type: Model type ('vit_b', 'vit_l', 'vit_h')
110
+ device: Device to run the model on
111
+ image: Optional image to set immediately
112
+
113
+ Returns:
114
+ HFSamPredictor instance
115
+ """
116
+ if not HF_AVAILABLE:
117
+ raise ImportError("transformers and huggingface_hub are required for HF SAM models")
118
+
119
+ if device is None:
120
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
121
+
122
+ # Load the predictor
123
+ predictor = HFSamPredictor.from_pretrained(model_type, device)
124
+
125
+ # Cache the image if provided; HFSamPredictor extracts features later, inside preprocess()
126
+ if image is not None:
127
+ predictor.image = image
128
+
129
+ return predictor
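
A small sketch of how this wrapper is driven; the frame here is a synthetic stand-in and the click coordinates are placeholders (the real app feeds video frames and user clicks):

```python
# Sketch: load the HF-hosted SAM weights and encode one frame plus a single
# positive click. The frame and the click coordinates are stand-ins.
import numpy as np
from app_3rd.sam_utils.hf_sam_predictor import HFSamPredictor

predictor = HFSamPredictor.from_pretrained('vit_h')   # resolves to facebook/sam-vit-huge
frame = np.zeros((480, 640, 3), dtype=np.uint8)       # stand-in RGB frame
inputs = predictor.preprocess(frame, [[[320.0, 240.0]]], [[1]])
# `inputs` now holds pixel_values plus original_sizes / reshaped_input_sizes,
# ready for predictor.model(**inputs) and post_process_masks() as in inference.py.
```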
app_3rd/sam_utils/inference.py ADDED
@@ -0,0 +1,123 @@
1
+ import gc
2
+
3
+ import numpy as np
4
+ import torch
5
+ from segment_anything import SamPredictor, sam_model_registry
6
+
7
+ # Try to import HF SAM support
8
+ try:
9
+ from app_3rd.sam_utils.hf_sam_predictor import get_hf_sam_predictor, HFSamPredictor
10
+ HF_AVAILABLE = True
11
+ except ImportError:
12
+ HF_AVAILABLE = False
13
+
14
+ models = {
15
+ 'vit_b': 'app_3rd/sam_utils/checkpoints/sam_vit_b_01ec64.pth',
16
+ 'vit_l': 'app_3rd/sam_utils/checkpoints/sam_vit_l_0b3195.pth',
17
+ 'vit_h': 'app_3rd/sam_utils/checkpoints/sam_vit_h_4b8939.pth'
18
+ }
19
+
20
+
21
+ def get_sam_predictor(model_type='vit_b', device=None, image=None, use_hf=True, predictor=None):
22
+ """
23
+ Get SAM predictor with option to use HuggingFace version
24
+
25
+ Args:
26
+ model_type: Model type ('vit_b', 'vit_l', 'vit_h')
27
+ device: Device to run on
28
+ image: Optional image to set immediately
29
+ use_hf: Whether to use HuggingFace SAM instead of original SAM
30
+ """
31
+ if predictor is not None:
32
+ return predictor
33
+ if use_hf:
34
+ if not HF_AVAILABLE:
35
+ raise ImportError("HuggingFace SAM not available. Install transformers and huggingface_hub.")
36
+ return get_hf_sam_predictor(model_type, device, image)
37
+
38
+ # Original SAM logic
39
+ if device is None and torch.cuda.is_available():
40
+ device = 'cuda'
41
+ elif device is None:
42
+ device = 'cpu'
43
+ # sam model
44
+ sam = sam_model_registry[model_type](checkpoint=models[model_type])
45
+ sam = sam.to(device)
46
+
47
+ predictor = SamPredictor(sam)
48
+ if image is not None:
49
+ predictor.set_image(image)
50
+ return predictor
51
+
52
+
53
+ def run_inference(predictor, input_x, selected_points, multi_object: bool = False):
54
+ """
55
+ Run inference with either original SAM or HF SAM predictor
56
+
57
+ Args:
58
+ predictor: SamPredictor or HFSamPredictor instance
59
+ input_x: Input image
60
+ selected_points: List of (point, label) tuples
61
+ multi_object: Whether to handle multiple objects
62
+ """
63
+ if len(selected_points) == 0:
64
+ return []
65
+
66
+ # Check if using HF SAM
67
+ if HF_AVAILABLE and isinstance(predictor, HFSamPredictor):
68
+ return _run_hf_inference(predictor, input_x, selected_points, multi_object)
69
+ else:
70
+ return _run_original_inference(predictor, input_x, selected_points, multi_object)
71
+
72
+
73
+ def _run_original_inference(predictor: SamPredictor, input_x, selected_points, multi_object: bool = False):
74
+ """Run inference with original SAM"""
75
+ points = torch.Tensor(
76
+ [p for p, _ in selected_points]
77
+ ).to(predictor.device).unsqueeze(1)
78
+
79
+ labels = torch.Tensor(
80
+ [int(l) for _, l in selected_points]
81
+ ).to(predictor.device).unsqueeze(1)
82
+
83
+ transformed_points = predictor.transform.apply_coords_torch(
84
+ points, input_x.shape[:2])
85
+
86
+ masks, scores, logits = predictor.predict_torch(
87
+ point_coords=transformed_points[:,0][None],
88
+ point_labels=labels[:,0][None],
89
+ multimask_output=False,
90
+ )
91
+ masks = masks[0].cpu().numpy() # (N, 1, H, W), where N is the number of prompt points
92
+
93
+ gc.collect()
94
+ torch.cuda.empty_cache()
95
+
96
+ return [(masks, 'final_mask')]
97
+
98
+
99
+ def _run_hf_inference(predictor: HFSamPredictor, input_x, selected_points, multi_object: bool = False):
100
+ """Run inference with HF SAM"""
101
+ # Prepare points and labels for HF SAM
102
+ select_pts = [[list(p) for p, _ in selected_points]]
103
+ select_lbls = [[int(l) for _, l in selected_points]]
104
+
105
+ # Preprocess inputs
106
+ inputs = predictor.preprocess(input_x, select_pts, select_lbls)
107
+
108
+ # Run inference
109
+ with torch.no_grad():
110
+ outputs = predictor.model(**inputs)
111
+
112
+ # Post-process masks
113
+ masks = predictor.processor.image_processor.post_process_masks(
114
+ outputs.pred_masks.cpu(),
115
+ inputs["original_sizes"].cpu(),
116
+ inputs["reshaped_input_sizes"].cpu(),
117
+ )
118
+ masks = masks[0][:,:1,...].cpu().numpy()
119
+
120
+ gc.collect()
121
+ torch.cuda.empty_cache()
122
+
123
+ return [(masks, 'final_mask')]
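
As a usage sketch, the HF code path above can be driven end-to-end and the resulting mask written out in the layout that `run_tracker()` in the next file expects, i.e. a PNG named after the video; the paths and the click below are placeholders:

```python
# Sketch: HF SAM on one frame, then save the union mask as temp/demo.png so it
# matches the clip temp/demo.mp4. Paths and the click location are placeholders.
import cv2
import numpy as np
from app_3rd.sam_utils.inference import get_sam_predictor, run_inference

frame = cv2.cvtColor(cv2.imread("temp/demo_frame.jpg"), cv2.COLOR_BGR2RGB)
predictor = get_sam_predictor(model_type='vit_h', use_hf=True)
masks = run_inference(predictor, frame, [((320, 240), 1)])[0][0]   # (N, 1, H, W) boolean
mask_png = (masks.any(axis=(0, 1)) * 255).astype(np.uint8)         # union over prompts -> (H, W)
cv2.imwrite("temp/demo.png", mask_png)   # read by run_tracker(..., temp_dir="temp", video_name="demo", ...)
```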
app_3rd/spatrack_utils/infer_track.py ADDED
@@ -0,0 +1,194 @@
1
+ from models.SpaTrackV2.models.predictor import Predictor
2
+ import yaml
3
+ import easydict
4
+ import os
5
+ import numpy as np
6
+ import cv2
7
+ import torch
8
+ import torchvision.transforms as T
9
+ from PIL import Image
10
+ import io
11
+ import moviepy.editor as mp
12
+ from models.SpaTrackV2.utils.visualizer import Visualizer
13
+ import tqdm
14
+ from models.SpaTrackV2.models.utils import get_points_on_a_grid
15
+ import glob
16
+ from rich import print
17
+ import argparse
18
+ import decord
19
+ from huggingface_hub import hf_hub_download
20
+
21
+ config = {
22
+ "ckpt_dir": "Yuxihenry/SpatialTrackerCkpts", # HuggingFace repo ID
23
+ "cfg_dir": "config/magic_infer_moge.yaml",
24
+ }
25
+
26
+ def get_tracker_predictor(output_dir: str, vo_points: int = 756, tracker_model=None):
27
+ """
28
+ Initialize and return the tracker predictor and visualizer
29
+ Args:
30
+ output_dir: Directory to save visualization results
31
+ vo_points: Number of points for visual odometry
32
+ Returns:
33
+ Tuple of (tracker_predictor, visualizer)
34
+ """
35
+ viz = True
36
+ os.makedirs(output_dir, exist_ok=True)
37
+
38
+ with open(config["cfg_dir"], "r") as f:
39
+ cfg = yaml.load(f, Loader=yaml.FullLoader)
40
+ cfg = easydict.EasyDict(cfg)
41
+ cfg.out_dir = output_dir
42
+ cfg.model.track_num = vo_points
43
+
44
+ # Check if it's a local path or HuggingFace repo
45
+ if tracker_model is not None:
46
+ model = tracker_model
47
+ model.spatrack.track_num = vo_points
48
+ else:
49
+ if os.path.exists(config["ckpt_dir"]):
50
+ # Local file
51
+ model = Predictor.from_pretrained(config["ckpt_dir"], model_cfg=cfg["model"])
52
+ else:
53
+ # HuggingFace repo - download the model
54
+ print(f"Downloading model from HuggingFace: {config['ckpt_dir']}")
55
+ checkpoint_path = hf_hub_download(
56
+ repo_id=config["ckpt_dir"],
57
+ repo_type="model",
58
+ filename="SpaTrack3_offline.pth"
59
+ )
60
+ model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
61
+ model.eval()
62
+ model.to("cuda")
63
+
64
+ viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
65
+ fps=10, pad_value=0, tracks_leave_trace=5)
66
+
67
+ return model, viser
68
+
69
+ def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3):
70
+ """
71
+ Run tracking on a video sequence
72
+ Args:
73
+ model: Tracker predictor instance
74
+ viser: Visualizer instance
75
+ temp_dir: Directory containing temporary files
76
+ video_name: Name of the video file (without extension)
77
+ grid_size: Size of the tracking grid
78
+ vo_points: Number of points for visual odometry
79
+ fps: Frames per second for visualization
80
+ """
81
+ # Setup paths
82
+ video_path = os.path.join(temp_dir, f"{video_name}.mp4")
83
+ mask_path = os.path.join(temp_dir, f"{video_name}.png")
84
+ out_dir = os.path.join(temp_dir, "results")
85
+ os.makedirs(out_dir, exist_ok=True)
86
+
87
+ # Load video using decord
88
+ video_reader = decord.VideoReader(video_path)
89
+ video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2) # Convert to tensor and permute to (N, C, H, W)
90
+
91
+ # resize make sure the shortest side is 336
92
+ h, w = video_tensor.shape[2:]
93
+ scale = max(336 / h, 336 / w)
94
+ if scale < 1:
95
+ new_h, new_w = int(h * scale), int(w * scale)
96
+ video_tensor = T.Resize((new_h, new_w))(video_tensor)
97
+ video_tensor = video_tensor[::fps].float()
98
+ depth_tensor = None
99
+ intrs = None
100
+ extrs = None
101
+ data_npz_load = {}
102
+
103
+ # Load and process mask
104
+ if os.path.exists(mask_path):
105
+ mask = cv2.imread(mask_path)
106
+ mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
107
+ mask = mask.sum(axis=-1)>0
108
+ else:
109
+ mask = np.ones_like(video_tensor[0,0].numpy())>0
110
+
111
+ # Get frame dimensions and create grid points
112
+ frame_H, frame_W = video_tensor.shape[2:]
113
+ grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
114
+
115
+ # Sample mask values at grid points and filter out points where mask=0
116
+ if os.path.exists(mask_path):
117
+ grid_pts_int = grid_pts[0].long()
118
+ mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
119
+ grid_pts = grid_pts[:, mask_values]
120
+
121
+ query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
122
+
123
+ # optional: bootstrap depth/pose with VGGT (this branch assumes VGGT, preprocess_image and pose_encoding_to_extri_intri are importable from the VGGT package)
124
+ if os.environ.get("VGGT_DIR", None) is not None:
125
+ vggt_model = VGGT()
126
+ vggt_model.load_state_dict(torch.load(os.environ["VGGT_DIR"]))
127
+ vggt_model.eval()
128
+ vggt_model = vggt_model.to("cuda")
129
+ # process the image tensor
130
+ video_tensor = preprocess_image(video_tensor)[None]
131
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
132
+ # Predict attributes including cameras, depth maps, and point maps.
133
+ aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_tensor.cuda()/255)
134
+ pose_enc = vggt_model.camera_head(aggregated_tokens_list)[-1]
135
+ # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
136
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, video_tensor.shape[-2:])
137
+ # Predict Depth Maps
138
+ depth_map, depth_conf = vggt_model.depth_head(aggregated_tokens_list, video_tensor.cuda()/255, ps_idx)
139
+ # clear the cache
140
+ del vggt_model, aggregated_tokens_list, ps_idx, pose_enc
141
+ torch.cuda.empty_cache()
142
+ depth_tensor = depth_map.squeeze().cpu().numpy()
143
+ extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
144
+ extrs[:, :3, :4] = extrinsic.squeeze().cpu().numpy()
145
+ intrs = intrinsic.squeeze().cpu().numpy()
146
+ video_tensor = video_tensor.squeeze()
147
+ # NOTE: treat depth with confidence below 0.5 as unreliable
148
+ # threshold = depth_conf.squeeze().view(-1).quantile(0.5)
149
+ unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
150
+
151
+ # Run model inference
152
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
153
+ (
154
+ c2w_traj, intrs, point_map, conf_depth,
155
+ track3d_pred, track2d_pred, vis_pred, conf_pred, video
156
+ ) = model.forward(video_tensor, depth=depth_tensor,
157
+ intrs=intrs, extrs=extrs,
158
+ queries=query_xyt,
159
+ fps=1, full_point=False, iters_track=4,
160
+ query_no_BA=True, fixed_cam=False, stage=1,
161
+ support_frame=len(video_tensor)-1, replace_ratio=0.2)
162
+
163
+ # Resize results to avoid too large I/O Burden
164
+ max_size = 336
165
+ h, w = video.shape[2:]
166
+ scale = min(max_size / h, max_size / w)
167
+ if scale < 1:
168
+ new_h, new_w = int(h * scale), int(w * scale)
169
+ video = T.Resize((new_h, new_w))(video)
170
+ video_tensor = T.Resize((new_h, new_w))(video_tensor)
171
+ point_map = T.Resize((new_h, new_w))(point_map)
172
+ track2d_pred[...,:2] = track2d_pred[...,:2] * scale
173
+ intrs[:,:2,:] = intrs[:,:2,:] * scale
174
+ if depth_tensor is not None:
175
+ depth_tensor = T.Resize((new_h, new_w))(depth_tensor)
176
+ conf_depth = T.Resize((new_h, new_w))(conf_depth)
177
+
178
+ # Visualize tracks
179
+ viser.visualize(video=video[None],
180
+ tracks=track2d_pred[None][...,:2],
181
+ visibility=vis_pred[None],filename="test")
182
+
183
+ # Save in tapip3d format
184
+ data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
185
+ data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
186
+ data_npz_load["intrinsics"] = intrs.cpu().numpy()
187
+ data_npz_load["depths"] = point_map[:,2,...].cpu().numpy()
188
+ data_npz_load["video"] = (video_tensor).cpu().numpy()/255
189
+ data_npz_load["visibs"] = vis_pred.cpu().numpy()
190
+ data_npz_load["confs"] = conf_pred.cpu().numpy()
191
+ data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
192
+ np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
193
+
194
+ print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/result.npz[/bold yellow]")
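
For context, a minimal driver for the two helpers above; it assumes a working directory `temp/` that already contains `demo.mp4` (and optionally the SAM mask `demo.png`), and the grid/point counts are just illustrative:

```python
# Sketch: run the SpaTrack2 pipeline on one clip. Directory layout and the
# grid_size / vo_points values below are illustrative, not prescribed.
from app_3rd.spatrack_utils.infer_track import get_tracker_predictor, run_tracker

model, viser = get_tracker_predictor(output_dir="temp/results", vo_points=756)
run_tracker(model, viser, temp_dir="temp", video_name="demo",
            grid_size=10, vo_points=756, fps=3)
# -> temp/results/result.npz plus a 2D track visualization written by the Visualizer
```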
config/__init__.py ADDED
File without changes
config/magic_infer_moge.yaml ADDED
@@ -0,0 +1,48 @@
1
+ seed: 0
2
+ # config the hydra logger, only in hydra `$` can be decoded as cite
3
+ data: ./assets/room
4
+ vis_track: false
5
+ hydra:
6
+ run:
7
+ dir: .
8
+ output_subdir: null
9
+ job_logging: {}
10
+ hydra_logging: {}
11
+ mixed_precision: bf16
12
+ visdom:
13
+ viz_ip: "localhost"
14
+ port: 6666
15
+ relax_load: false
16
+ res_all: 336
17
+ # config the ckpt path
18
+ # ckpts: "/mnt/bn/xyxdata/home/codes/my_projs/SpaTrack2/checkpoints/new_base.pth"
19
+ ckpts: "Yuxihenry/SpatialTracker_Files"
20
+ batch_size: 1
21
+ input:
22
+ type: image
23
+ fps: 1
24
+ model_wind_size: 32
25
+ model:
26
+ backbone_cfg:
27
+ ckpt_dir: "checkpoints/model.pt"
28
+ chunk_size: 24 # downsample factor for patchified features
29
+ ckpt_fwd: true
30
+ ft_cfg:
31
+ mode: "fix"
32
+ paras_name: []
33
+ resolution: 336
34
+ max_len: 512
35
+ Track_cfg:
36
+ base_ckpt: "checkpoints/scaled_offline.pth"
37
+ base:
38
+ stride: 4
39
+ corr_radius: 3
40
+ window_len: 60
41
+ stablizer: True
42
+ mode: "online"
43
+ s_wind: 200
44
+ overlap: 4
45
+ track_num: 0
46
+
47
+ dist_train:
48
+ num_nodes: 1
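
The tracker loads this file with PyYAML into an EasyDict and overrides a couple of fields at runtime; the sketch below mirrors what `get_tracker_predictor` in `app_3rd/spatrack_utils/infer_track.py` does (the output directory is a placeholder):

```python
# Sketch: load the config above and apply the runtime overrides used by the app.
import yaml
import easydict

with open("config/magic_infer_moge.yaml", "r") as f:
    cfg = easydict.EasyDict(yaml.load(f, Loader=yaml.FullLoader))

cfg.out_dir = "temp/results"           # visualization output directory (placeholder)
cfg.model.track_num = 756              # number of points used for visual odometry
print(cfg.model.Track_cfg.base_ckpt)   # -> checkpoints/scaled_offline.pth
```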
examples/backpack.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b5ac6b2285ffb48e3a740e419e38c781df9c963589a5fd894e5b4e13dd6a8b8
3
+ size 1208738
examples/ball.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31f6e3bf875a85284b376c05170b4c08b546b7d5e95106848b1e3818a9d0db91
3
+ size 3030268
examples/basketball.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0df3b429d5fd64c298f2d79b2d818a4044e7341a71d70b957f60b24e313c3760
3
+ size 2487837
examples/biker.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fba880c24bdb8fa3b84b1b491d52f2c1f426fb09e34c3013603e5a549cf3b22b
3
+ size 249196
examples/cinema_0.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a68a5643c14f61c05d48e25a98ddf5cf0344d3ffcda08ad4a0adc989d49d7a9c
3
+ size 1774022
examples/cinema_1.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99624e2d0fb2e9f994e46aefb904e884de37a6d78e7f6b6670e286eaa397e515
3
+ size 2370749
examples/drifting.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f3937871117d3cc5d7da3ef31d1edf5626fc8372126b73590f75f05713fe97c
3
+ size 4695804
examples/ego_kc1.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22fe64e458e329e8b3c3e20b3725ffd85c3a2e725fd03909cf883d3fd02c80b3
3
+ size 1365980
examples/ego_teaser.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8780b291b48046b1c7dea90712c1c3f59d60c03216df1c489f6f03e8d61fae5c
3
+ size 7365665
examples/handwave.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6dde7cf4ffa7c66b6861bb5abdedc49dfc4b5b4dd9dd46ee8415dd4953935b6
3
+ size 2099369
examples/hockey.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c3be095777b442dc401e7d1f489b749611ffade3563a01e4e3d1e511311bd86
3
+ size 1795810
examples/ken_block_0.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b788faeb4d3206fa604d622a05268f1321ad6a229178fe12319d20c9438deb1
3
+ size 196343
examples/kiss.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f78fffc5108d95d4e5837d7607226f3dd9796615ea3481f2629c69ccd2ccb12f
3
+ size 1073570
examples/kitchen.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3120e942a9b3d7b300928e43113b000fb5ccc209012a2c560ec26b8a04c2d5f9
3
+ size 543970
examples/kitchen_egocentric.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5468ab10d8d39b68b51fa616adc3d099dab7543e38dd221a0a7a20a2401824a2
3
+ size 2176685
examples/pillow.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f05818f586d7b0796fcd4714ea4be489c93701598cadc86ce7973fc24655fee
3
+ size 1407147
examples/protein.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2dc9cfceb0984b61ebc62fda4c826135ebe916c8c966a8123dcc3315d43b73f
3
+ size 2002300
examples/pusht.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:996d1923e36811a1069e4d6b5e8c0338d9068c0870ea09c4c04e13e9fbcd207a
3
+ size 5256495
examples/robot1.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a3b9e4449572129fdd96a751938e211241cdd86bcc56ffd33bfd23fc4d6e9c0
3
+ size 1178671
examples/robot2.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:188b2d8824ce345c86a603bff210639a6158d72cf6119cc1d3f79d409ac68bb3
3
+ size 867261
examples/robot_3.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:784a0f9c36a316d0da5745075dbc8cefd9ce60c25b067d3d80a1d52830df8a37
3
+ size 1153015
examples/robot_unitree.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99bc274f7613a665c6135085fe01691ebfaa9033101319071f37c550ab21d1ea
3
+ size 1964268
examples/running.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ceb96b287fefb1c090dcd2f5db7634f808d2079413500beeb7b33023dfae51b
3
+ size 7307897
examples/teleop2.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59ea006a18227da8cf5db1fa50cd48e71ec7eb66fef48ea2158c325088bd9fee
3
+ size 1077267
examples/vertical_place.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c8061ae449f986113c2ecb17aefc2c13f737aecbcd41d6c057c88e6d41ac3ee
3
+ size 719810
models/SpaTrackV2/models/SpaTrack.py ADDED
@@ -0,0 +1,759 @@
1
+ #python
2
+ """
3
+ SpaTrackerV2: a unified model that estimates 'intrinsics',
4
+ 'video depth', 'extrinsics' and '3D tracking' from casual video frames.
5
+
6
+ Contact: DM [email protected]
7
+ """
8
+
9
+ import os
10
+ import numpy as np
11
+ from typing import Literal, Union, List, Tuple, Dict
12
+ import cv2
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ # from depth anything v2
17
+ from huggingface_hub import PyTorchModelHubMixin # used for model hub
18
+ from einops import rearrange
19
+ from models.monoD.depth_anything_v2.dpt import DepthAnythingV2
20
+ from models.moge.model.v1 import MoGeModel
21
+ import copy
22
+ from functools import partial
23
+ from models.SpaTrackV2.models.tracker3D.TrackRefiner import TrackRefiner3D
24
+ import kornia
25
+ from models.SpaTrackV2.utils.model_utils import sample_features5d
26
+ import utils3d
27
+ from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
28
+ from models.SpaTrackV2.models.utils import pose_enc2mat, matrix_to_quaternion, get_track_points, normalize_rgb
29
+ import random
30
+
31
+ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
32
+ def __init__(
33
+ self,
34
+ loggers: list, # include [ viz, logger_tf, logger]
35
+ backbone_cfg,
36
+ Track_cfg=None,
37
+ chunk_size=24,
38
+ ckpt_fwd: bool = False,
39
+ ft_cfg=None,
40
+ resolution=518,
41
+ max_len=600, # the maximum video length we can preprocess,
42
+ track_num=768,
43
+ ):
44
+
45
+ self.chunk_size = chunk_size
46
+ self.max_len = max_len
47
+ self.resolution = resolution
48
+ # config the T-Lora Dinov2
49
+ #NOTE: initial the base model
50
+ base_cfg = copy.deepcopy(backbone_cfg)
51
+ backbone_ckpt_dir = base_cfg.pop('ckpt_dir', None)
52
+
53
+ super(SpaTrack2, self).__init__()
54
+ if os.path.exists(backbone_ckpt_dir)==False:
55
+ base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl')
56
+ else:
57
+ checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True)
58
+ base_model = MoGeModel(**checkpoint["model_config"])
59
+ base_model.load_state_dict(checkpoint['model'])
60
+ # avoid the base_model is a member of SpaTrack2
61
+ object.__setattr__(self, 'base_model', base_model)
62
+
63
+ # Tracker model
64
+ self.Track3D = TrackRefiner3D(Track_cfg)
65
+ track_base_ckpt_dir = Track_cfg.base_ckpt
66
+ if os.path.exists(track_base_ckpt_dir):
67
+ track_pretrain = torch.load(track_base_ckpt_dir)
68
+ self.Track3D.load_state_dict(track_pretrain, strict=False)
69
+
70
+ # wrap the function of make lora trainable
71
+ self.make_paras_trainable = partial(self.make_paras_trainable,
72
+ mode=ft_cfg.mode,
73
+ paras_name=ft_cfg.paras_name)
74
+ self.track_num = track_num
75
+
76
+ def make_paras_trainable(self, mode: str = 'fix', paras_name: List[str] = []):
77
+ # gradient required for the lora_experts and gate
78
+ for name, param in self.named_parameters():
79
+ if any(x in name for x in paras_name):
80
+ if mode == 'fix':
81
+ param.requires_grad = False
82
+ else:
83
+ param.requires_grad = True
84
+ else:
85
+ if mode == 'fix':
86
+ param.requires_grad = True
87
+ else:
88
+ param.requires_grad = False
89
+ total_params = sum(p.numel() for p in self.parameters())
90
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
91
+ print(f"Total parameters: {total_params}")
92
+ print(f"Trainable parameters: {trainable_params/total_params*100:.2f}%")
93
+
94
+ def ProcVid(self,
95
+ x: torch.Tensor):
96
+ """
97
+ split the video into several overlapped windows.
98
+
99
+ args:
100
+ x: the input video frames. [B, T, C, H, W]
101
+ outputs:
102
+ patch_size: the patch size of the video features
103
+ raises:
104
+ ValueError: if the input video is longer than `max_len`.
105
+
106
+ """
107
+ # normalize the input images
108
+ num_types = x.dtype
109
+ x = normalize_rgb(x, input_size=self.resolution)
110
+ x = x.to(num_types)
111
+ # get the video features
112
+ B, T, C, H, W = x.size()
113
+ if T > self.max_len:
114
+ raise ValueError(f"the video length should no more than {self.max_len}.")
115
+ # get the video features
116
+ patch_h, patch_w = H // 14, W // 14
117
+ patch_size = (patch_h, patch_w)
118
+ # resize and get the video features
119
+ x = x.view(B * T, C, H, W)
120
+ # operate the temporal encoding
121
+ return patch_size, x
122
+
123
+ def forward_stream(
124
+ self,
125
+ video: torch.Tensor,
126
+ queries: torch.Tensor = None,
127
+ T_org: int = None,
128
+ depth: torch.Tensor|np.ndarray|str=None,
129
+ unc_metric_in: torch.Tensor|np.ndarray|str=None,
130
+ intrs: torch.Tensor|np.ndarray|str=None,
131
+ extrs: torch.Tensor|np.ndarray|str=None,
132
+ queries_3d: torch.Tensor = None,
133
+ window_len: int = 16,
134
+ overlap_len: int = 4,
135
+ full_point: bool = False,
136
+ track2d_gt: torch.Tensor = None,
137
+ fixed_cam: bool = False,
138
+ query_no_BA: bool = False,
139
+ stage: int = 0,
140
+ support_frame: int = 0,
141
+ replace_ratio: float = 0.6,
142
+ annots_train: Dict = None,
143
+ iters_track=4,
144
+ **kwargs,
145
+ ):
146
+ # step 1 allocate the query points on the grid
147
+ T, C, H, W = video.shape
148
+
149
+ if annots_train is not None:
150
+ vis_gt = annots_train["vis"]
151
+ _, _, N = vis_gt.shape
152
+ number_visible = vis_gt.sum(dim=1)
153
+ ratio_rand = torch.rand(1, N, device=vis_gt.device)
154
+ first_positive_inds = get_nth_visible_time_index(vis_gt, (number_visible*ratio_rand).long().clamp(min=1, max=T))
155
+ assert (torch.gather(vis_gt, 1, first_positive_inds[:, None, :].repeat(1, T, 1)) < 0).sum() == 0
156
+
157
+ first_positive_inds = first_positive_inds.long()
158
+ gather = torch.gather(
159
+ annots_train["traj_3d"][...,:2], 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
160
+ )
161
+ xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)
162
+ queries = torch.cat([first_positive_inds[:, :, None], xys], dim=-1)[0].cpu().numpy()
163
+
164
+
165
+ # Unfold video into segments of window_len with overlap_len
166
+ step_slide = window_len - overlap_len
167
+ if T < window_len:
168
+ video_unf = video.unsqueeze(0)
169
+ if depth is not None:
170
+ depth_unf = depth.unsqueeze(0)
171
+ else:
172
+ depth_unf = None
173
+ if unc_metric_in is not None:
174
+ unc_metric_unf = unc_metric_in.unsqueeze(0)
175
+ else:
176
+ unc_metric_unf = None
177
+ if intrs is not None:
178
+ intrs_unf = intrs.unsqueeze(0)
179
+ else:
180
+ intrs_unf = None
181
+ if extrs is not None:
182
+ extrs_unf = extrs.unsqueeze(0)
183
+ else:
184
+ extrs_unf = None
185
+ else:
186
+ video_unf = video.unfold(0, window_len, step_slide).permute(0, 4, 1, 2, 3) # [B, S, C, H, W]
187
+ if depth is not None:
188
+ depth_unf = depth.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
189
+ intrs_unf = intrs.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
190
+ else:
191
+ depth_unf = None
192
+ intrs_unf = None
193
+ if extrs is not None:
194
+ extrs_unf = extrs.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
195
+ else:
196
+ extrs_unf = None
197
+ if unc_metric_in is not None:
198
+ unc_metric_unf = unc_metric_in.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
199
+ else:
200
+ unc_metric_unf = None
201
+
202
+ # parallel
203
+ # Get number of segments
204
+ B = video_unf.shape[0]
205
+ #TODO: Process each segment in parallel using torch.nn.DataParallel
206
+ c2w_traj = torch.eye(4, 4)[None].repeat(T, 1, 1)
207
+ intrs_out = torch.eye(3, 3)[None].repeat(T, 1, 1)
208
+ point_map = torch.zeros(T, 3, H, W).cuda()
209
+ unc_metric = torch.zeros(T, H, W).cuda()
210
+ # set the queries
211
+ N, _ = queries.shape
212
+ track3d_pred = torch.zeros(T, N, 6).cuda()
213
+ track2d_pred = torch.zeros(T, N, 3).cuda()
214
+ vis_pred = torch.zeros(T, N, 1).cuda()
215
+ conf_pred = torch.zeros(T, N, 1).cuda()
216
+ dyn_preds = torch.zeros(T, N, 1).cuda()
217
+ # sort the queries by time
218
+ sorted_indices = np.argsort(queries[...,0])
219
+ sorted_inv_indices = np.argsort(sorted_indices)
220
+ sort_query = queries[sorted_indices]
221
+ sort_query = torch.from_numpy(sort_query).cuda()
222
+ if queries_3d is not None:
223
+ sort_query_3d = queries_3d[sorted_indices]
224
+ sort_query_3d = torch.from_numpy(sort_query_3d).cuda()
225
+
226
+ queries_len = 0
227
+ overlap_d = None
228
+ cache = None
229
+ loss = 0.0
230
+
231
+ for i in range(B):
232
+ segment = video_unf[i:i+1].cuda()
233
+ # Forward pass through model
234
+ # detect the key points for each frames
235
+
236
+ queries_new_mask = (sort_query[...,0] < i * step_slide + window_len) * (sort_query[...,0] >= (i * step_slide + overlap_len if i > 0 else 0))
237
+ if queries_3d is not None:
238
+ queries_new_3d = sort_query_3d[queries_new_mask]
239
+ queries_new_3d = queries_new_3d.float()
240
+ else:
241
+ queries_new_3d = None
242
+ queries_new = sort_query[queries_new_mask.bool()]
243
+ queries_new = queries_new.float()
244
+ if i > 0:
245
+ overlap2d = track2d_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
246
+ overlapvis = vis_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
247
+ overlapconf = conf_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
248
+ overlap_query = (overlapvis * overlapconf).max(dim=0)[1][None, ...]
249
+ overlap_xy = torch.gather(overlap2d, 0, overlap_query.repeat(1,1,2))
250
+ overlap_d = torch.gather(overlap2d, 0, overlap_query.repeat(1,1,3))[...,2].detach()
251
+ overlap_query = torch.cat([overlap_query[...,:1], overlap_xy], dim=-1)[0]
252
+ queries_new[...,0] -= i*step_slide
253
+ queries_new = torch.cat([overlap_query, queries_new], dim=0).detach()
254
+
255
+ if annots_train is None:
256
+ annots = {}
257
+ else:
258
+ annots = copy.deepcopy(annots_train)
259
+ annots["traj_3d"] = annots["traj_3d"][:, i*step_slide:i*step_slide+window_len, sorted_indices,:][...,:len(queries_new),:]
260
+ annots["vis"] = annots["vis"][:, i*step_slide:i*step_slide+window_len, sorted_indices][...,:len(queries_new)]
261
+ annots["poses_gt"] = annots["poses_gt"][:, i*step_slide:i*step_slide+window_len]
262
+ annots["depth_gt"] = annots["depth_gt"][:, i*step_slide:i*step_slide+window_len]
263
+ annots["intrs"] = annots["intrs"][:, i*step_slide:i*step_slide+window_len]
264
+ annots["traj_mat"] = annots["traj_mat"][:,i*step_slide:i*step_slide+window_len]
265
+
266
+ if depth is not None:
267
+ annots["depth_gt"] = depth_unf[i:i+1].to(segment.device).to(segment.dtype)
268
+ if unc_metric_in is not None:
269
+ annots["unc_metric"] = unc_metric_unf[i:i+1].to(segment.device).to(segment.dtype)
270
+ if intrs is not None:
271
+ intr_seg = intrs_unf[i:i+1].to(segment.device).to(segment.dtype)[0].clone()
272
+ focal = (intr_seg[:,0,0] / segment.shape[-1] + intr_seg[:,1,1]/segment.shape[-2]) / 2
273
+ pose_fake = torch.zeros(1, 8).to(depth.device).to(depth.dtype).repeat(segment.shape[1], 1)
274
+ pose_fake[:, -1] = focal
275
+ pose_fake[:,3]=1
276
+ annots["intrs_gt"] = intr_seg
277
+ if extrs is not None:
278
+ extrs_unf_norm = extrs_unf[i:i+1][0].clone()
279
+ extrs_unf_norm = torch.inverse(extrs_unf_norm[:1,...]) @ extrs_unf[i:i+1][0]
280
+ rot_vec = matrix_to_quaternion(extrs_unf_norm[:,:3,:3])
281
+ annots["poses_gt"] = torch.zeros(1, rot_vec.shape[0], 7).to(segment.device).to(segment.dtype)
282
+ annots["poses_gt"][:, :, 3:7] = rot_vec.to(segment.device).to(segment.dtype)[None]
283
+ annots["poses_gt"][:, :, :3] = extrs_unf_norm[:,:3,3].to(segment.device).to(segment.dtype)[None]
284
+ annots["use_extr"] = True
285
+
286
+ kwargs.update({"stage": stage})
287
+
288
+ #TODO: DEBUG
289
+ out = self.forward(segment, pts_q=queries_new,
290
+ pts_q_3d=queries_new_3d, overlap_d=overlap_d,
291
+ full_point=full_point,
292
+ fixed_cam=fixed_cam, query_no_BA=query_no_BA,
293
+ support_frame=segment.shape[1]-1,
294
+ cache=cache, replace_ratio=replace_ratio,
295
+ iters_track=iters_track,
296
+ **kwargs, annots=annots)
297
+ if self.training:
298
+ loss += out["loss"].squeeze()
299
+ # from models.SpaTrackV2.utils.visualizer import Visualizer
300
+ # vis_track = Visualizer(grayscale=False,
301
+ # fps=10, pad_value=50, tracks_leave_trace=0)
302
+ # vis_track.visualize(video=segment,
303
+ # tracks=out["traj_est"][...,:2],
304
+ # visibility=out["vis_est"],
305
+ # save_video=True)
306
+ # # visualize 4d
307
+ # import os, json
308
+ # import os.path as osp
309
+ # viser4d_dir = os.path.join("viser_4d_results")
310
+ # os.makedirs(viser4d_dir, exist_ok=True)
311
+ # depth_est = annots["depth_gt"][0]
312
+ # unc_metric = out["unc_metric"]
313
+ # mask = (unc_metric > 0.5).squeeze(1)
314
+ # # pose_est = out["poses_pred"].squeeze(0)
315
+ # pose_est = annots["traj_mat"][0]
316
+ # rgb_tracks = out["rgb_tracks"].squeeze(0)
317
+ # intrinsics = out["intrs"].squeeze(0)
318
+ # for i_k in range(out["depth"].shape[0]):
319
+ # img_i = out["imgs_raw"][0][i_k].permute(1, 2, 0).cpu().numpy()
320
+ # img_i = cv2.cvtColor(img_i, cv2.COLOR_BGR2RGB)
321
+ # cv2.imwrite(osp.join(viser4d_dir, f'frame_{i_k:04d}.png'), img_i)
322
+ # if stage == 1:
323
+ # depth = depth_est[i_k].squeeze().cpu().numpy()
324
+ # np.save(osp.join(viser4d_dir, f'frame_{i_k:04d}.npy'), depth)
325
+ # else:
326
+ # point_map_vis = out["points_map"][i_k].cpu().numpy()
327
+ # np.save(osp.join(viser4d_dir, f'point_{i_k:04d}.npy'), point_map_vis)
328
+ # np.save(os.path.join(viser4d_dir, f'intrinsics.npy'), intrinsics.cpu().numpy())
329
+ # np.save(os.path.join(viser4d_dir, f'extrinsics.npy'), pose_est.cpu().numpy())
330
+ # np.save(os.path.join(viser4d_dir, f'conf.npy'), mask.float().cpu().numpy())
331
+ # np.save(os.path.join(viser4d_dir, f'colored_track3d.npy'), rgb_tracks.cpu().numpy())
332
+
333
+ queries_len = len(queries_new)
334
+ # update the track3d and track2d
335
+ left_len = len(track3d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :])
336
+ track3d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["rgb_tracks"][0,:left_len,:queries_len,:]
337
+ track2d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["traj_est"][0,:left_len,:queries_len,:3]
338
+ vis_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["vis_est"][0,:left_len,:queries_len,None]
339
+ conf_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["conf_pred"][0,:left_len,:queries_len,None]
340
+ dyn_preds[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["dyn_preds"][0,:left_len,:queries_len,None]
341
+
342
+ # process the output for each segment
343
+ seg_c2w = out["poses_pred"][0]
344
+ seg_intrs = out["intrs"][0]
345
+ seg_point_map = out["points_map"]
346
+ seg_conf_depth = out["unc_metric"]
347
+
348
+ # cache management
349
+ cache = out["cache"]
350
+ for k in cache.keys():
351
+ if "_pyramid" in k:
352
+ for j in range(len(cache[k])):
353
+ if len(cache[k][j].shape) == 5:
354
+ cache[k][j] = cache[k][j][:,:,:,:queries_len,:]
355
+ elif len(cache[k][j].shape) == 4:
356
+ cache[k][j] = cache[k][j][:,:1,:queries_len,:]
357
+ elif "_pred_cache" in k:
358
+ cache[k] = cache[k][-overlap_len:,:queries_len,:]
359
+ else:
360
+ cache[k] = cache[k][-overlap_len:]
361
+
362
+ # update the results
363
+ idx_glob = i * step_slide
364
+ # refine part
365
+ # mask_update = sort_query[..., 0] < i * step_slide + window_len
366
+ # sort_query_pick = sort_query[mask_update]
367
+ intrs_out[idx_glob:idx_glob+window_len] = seg_intrs
368
+ point_map[idx_glob:idx_glob+window_len] = seg_point_map
369
+ unc_metric[idx_glob:idx_glob+window_len] = seg_conf_depth
370
+ # update the camera poses
371
+
372
+ # if using the ground truth pose
373
+ # if extrs_unf is not None:
374
+ # c2w_traj[idx_glob:idx_glob+window_len] = extrs_unf[i:i+1][0].to(c2w_traj.device).to(c2w_traj.dtype)
375
+ # else:
376
+ prev_c2w = c2w_traj[idx_glob:idx_glob+window_len][:1]
377
+ c2w_traj[idx_glob:idx_glob+window_len] = prev_c2w@seg_c2w.to(c2w_traj.device).to(c2w_traj.dtype)
378
+
379
+ track2d_pred = track2d_pred[:T_org,sorted_inv_indices,:]
380
+ track3d_pred = track3d_pred[:T_org,sorted_inv_indices,:]
381
+ vis_pred = vis_pred[:T_org,sorted_inv_indices,:]
382
+ conf_pred = conf_pred[:T_org,sorted_inv_indices,:]
383
+ dyn_preds = dyn_preds[:T_org,sorted_inv_indices,:]
384
+ unc_metric = unc_metric[:T_org,:]
385
+ point_map = point_map[:T_org,:]
386
+ intrs_out = intrs_out[:T_org,:]
387
+ c2w_traj = c2w_traj[:T_org,:]
388
+ if self.training:
389
+ ret = {
390
+ "loss": loss,
391
+ "depth_loss": 0.0,
392
+ "ab_loss": 0.0,
393
+ "vis_loss": out["vis_loss"],
394
+ "track_loss": out["track_loss"],
395
+ "conf_loss": out["conf_loss"],
396
+ "dyn_loss": out["dyn_loss"],
397
+ "sync_loss": out["sync_loss"],
398
+ "poses_pred": c2w_traj[None],
399
+ "intrs": intrs_out[None],
400
+ "points_map": point_map,
401
+ "track3d_pred": track3d_pred[None],
402
+ "rgb_tracks": track3d_pred[None],
403
+ "track2d_pred": track2d_pred[None],
404
+ "traj_est": track2d_pred[None],
405
+ "vis_est": vis_pred[None], "conf_pred": conf_pred[None],
406
+ "dyn_preds": dyn_preds[None],
407
+ "imgs_raw": video[None],
408
+ "unc_metric": unc_metric,
409
+ }
410
+
411
+ return ret
412
+ else:
413
+ return c2w_traj, intrs_out, point_map, unc_metric, track3d_pred, track2d_pred, vis_pred, conf_pred
414
+ def forward(self,
415
+ x: torch.Tensor,
416
+ annots: Dict = {},
417
+ pts_q: torch.Tensor = None,
418
+ pts_q_3d: torch.Tensor = None,
419
+ overlap_d: torch.Tensor = None,
420
+ full_point = False,
421
+ fixed_cam = False,
422
+ support_frame = 0,
423
+ query_no_BA = False,
424
+ cache = None,
425
+ replace_ratio = 0.6,
426
+ iters_track=4,
427
+ **kwargs):
428
+ """
429
+ forward the video camera model, which predict (
430
+ `intr` `camera poses` `video depth`
431
+ )
432
+
433
+ args:
434
+ x: the input video frames. [B, T, C, H, W]
435
+ annots: the annotations for video frames.
436
+ {
437
+ "poses_gt": the pose encoding for the video frames. [B, T, 7]
438
+ "depth_gt": the ground truth depth for the video frames. [B, T, 1, H, W],
439
+ "metric": bool, whether to calculate the metric for the video frames.
440
+ }
441
+ """
442
+ self.support_frame = support_frame
443
+
444
+ #TODO: to adjust a little bit
445
+ track_loss = ab_loss = vis_loss = conf_loss = dyn_loss = 0.0
446
+ B, T, _, H, W = x.shape
447
+ imgs_raw = x.clone()
448
+ # get the video split and features for each segment
449
+ patch_size, x_resize = self.ProcVid(x)
450
+ x_resize = rearrange(x_resize, "(b t) c h w -> b t c h w", b=B)
451
+ H_resize, W_resize = x_resize.shape[-2:]
452
+
453
+ prec_fx = W / W_resize
454
+ prec_fy = H / H_resize
455
+ # get patch size
456
+ P_H, P_W = patch_size
457
+
458
+ # get the depth, pointmap and mask
459
+ #TODO: Release DepthAnything Version
460
+ points_map_gt = None
461
+ with torch.no_grad():
462
+ if_gt_depth = (("depth_gt" in annots.keys())) and (kwargs.get('stage', 0)==1 or kwargs.get('stage', 0)==3)
463
+ if if_gt_depth==False:
464
+ if cache is not None:
465
+ T_cache = cache["points_map"].shape[0]
466
+ T_new = T - T_cache
467
+ x_resize_new = x_resize[:, T_cache:]
468
+ else:
469
+ T_new = T
470
+ x_resize_new = x_resize
471
+ # infer with chunk
472
+ chunk_size = self.chunk_size
473
+ metric_depth = []
474
+ intrs = []
475
+ unc_metric = []
476
+ mask = []
477
+ points_map = []
478
+ normals = []
479
+ normals_mask = []
480
+ for i in range(0, B*T_new, chunk_size):
481
+ output = self.base_model.infer(x_resize_new.view(B*T_new, -1, H_resize, W_resize)[i:i+chunk_size])
482
+ metric_depth.append(output['depth'])
483
+ intrs.append(output['intrinsics'])
484
+ unc_metric.append(output['mask_prob'])
485
+ mask.append(output['mask'])
486
+ points_map.append(output['points'])
487
+ normals_i, normals_mask_i = utils3d.torch.points_to_normals(output['points'], mask=output['mask'])
488
+ normals.append(normals_i)
489
+ normals_mask.append(normals_mask_i)
490
+
491
+ metric_depth = torch.cat(metric_depth, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
492
+ intrs = torch.cat(intrs, dim=0).view(B, T_new, 3, 3).to(x.dtype)
493
+ intrs[:,:,0,:] *= W_resize
494
+ intrs[:,:,1,:] *= H_resize
495
+ # points_map = torch.cat(points_map, dim=0)
496
+ mask = torch.cat(mask, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
497
+ # cat the normals
498
+ normals = torch.cat(normals, dim=0)
499
+ normals_mask = torch.cat(normals_mask, dim=0)
500
+
501
+ metric_depth = metric_depth.clone()
502
+ metric_depth[metric_depth == torch.inf] = 0
503
+ _depths = metric_depth[metric_depth > 0].reshape(-1)
504
+ q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values
505
+ q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values
506
+ iqr = q75 - q25
507
+ upper_bound = (q75 + 0.8*iqr).clamp(min=1e-6, max=10*q25)
508
+ _depth_roi = torch.tensor(
509
+ [1e-1, upper_bound.item()],
510
+ dtype=metric_depth.dtype,
511
+ device=metric_depth.device
512
+ )
513
+ mask_roi = (metric_depth > _depth_roi[0]) & (metric_depth < _depth_roi[1])
514
+ mask = mask * mask_roi
515
+ mask = mask * (~(utils3d.torch.depth_edge(metric_depth, rtol=0.03, mask=mask.bool()))) * normals_mask[:,None,...]
516
+ points_map = depth_to_points_colmap(metric_depth.squeeze(1), intrs.view(B*T_new, 3, 3))
517
+ unc_metric = torch.cat(unc_metric, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
518
+ unc_metric *= mask
519
+ if full_point:
520
+ unc_metric = (~(utils3d.torch.depth_edge(metric_depth, rtol=0.1, mask=torch.ones_like(metric_depth).bool()))).float() * (metric_depth != 0)
521
+ if cache is not None:
522
+ assert B==1, "only support batch size 1 right now."
523
+ unc_metric = torch.cat([cache["unc_metric"], unc_metric], dim=0)
524
+ intrs = torch.cat([cache["intrs"][None], intrs], dim=1)
525
+ points_map = torch.cat([cache["points_map"].permute(0,2,3,1), points_map], dim=0)
526
+ metric_depth = torch.cat([cache["metric_depth"], metric_depth], dim=0)
527
+
528
+ if "poses_gt" in annots.keys():
529
+ intrs, c2w_traj_gt = pose_enc2mat(annots["poses_gt"],
530
+ H_resize, W_resize, self.resolution)
531
+ else:
532
+ c2w_traj_gt = None
533
+
534
+ if "intrs_gt" in annots.keys():
535
+ intrs = annots["intrs_gt"].view(B, T, 3, 3)
536
+ fx_factor = W_resize / W
537
+ fy_factor = H_resize / H
538
+ intrs[:,:,0,:] *= fx_factor
539
+ intrs[:,:,1,:] *= fy_factor
540
+
541
+ if "depth_gt" in annots.keys():
542
+
543
+ metric_depth_gt = annots['depth_gt'].view(B*T, 1, H, W)
544
+ metric_depth_gt = F.interpolate(metric_depth_gt,
545
+ size=(H_resize, W_resize), mode='nearest')
546
+
547
+ _depths = metric_depth_gt[metric_depth_gt > 0].reshape(-1)
548
+ q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values
549
+ q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values
550
+ iqr = q75 - q25
551
+ upper_bound = (q75 + 0.8*iqr).clamp(min=1e-6, max=10*q25)
552
+ _depth_roi = torch.tensor(
553
+ [1e-1, upper_bound.item()],
554
+ dtype=metric_depth_gt.dtype,
555
+ device=metric_depth_gt.device
556
+ )
557
+ mask_roi = (metric_depth_gt > _depth_roi[0]) & (metric_depth_gt < _depth_roi[1])
558
+ # if (upper_bound > 200).any():
559
+ # import pdb; pdb.set_trace()
560
+ if (kwargs.get('stage', 0) == 2):
561
+ unc_metric = ((metric_depth_gt > 0)*(mask_roi) * (unc_metric > 0.5)).float()
562
+ metric_depth_gt[metric_depth_gt > 10*q25] = 0
563
+ else:
564
+ unc_metric = ((metric_depth_gt > 0)*(mask_roi)).float()
565
+ unc_metric *= (~(utils3d.torch.depth_edge(metric_depth_gt, rtol=0.03, mask=mask_roi.bool()))).float()
566
+ # filter the sky
567
+ metric_depth_gt[metric_depth_gt > 10*q25] = 0
568
+ if "unc_metric" in annots.keys():
569
+ unc_metric_ = F.interpolate(annots["unc_metric"].permute(1,0,2,3),
570
+ size=(H_resize, W_resize), mode='nearest')
571
+ unc_metric = unc_metric * unc_metric_
572
+ if if_gt_depth:
573
+ points_map = depth_to_points_colmap(metric_depth_gt.squeeze(1), intrs.view(B*T, 3, 3))
574
+ metric_depth = metric_depth_gt
575
+ points_map_gt = points_map
576
+ else:
577
+ points_map_gt = depth_to_points_colmap(metric_depth_gt.squeeze(1), intrs.view(B*T, 3, 3))
578
+
579
+ # track the 3d points
580
+ ret_track = None
581
+ regular_track = True
582
+ dyn_preds, final_tracks = None, None
583
+
584
+ if "use_extr" in annots.keys():
585
+ init_pose = True
586
+ else:
587
+ init_pose = False
588
+ # set the custom vid and valid only
589
+ custom_vid = annots.get("custom_vid", False)
590
+ valid_only = annots.get("data_dir", [""])[0] == "stereo4d"
591
+ if self.training:
592
+ if (annots["vis"].sum() > 0) and (kwargs.get('stage', 0)==1 or kwargs.get('stage', 0)==3):
593
+ traj3d = annots['traj_3d']
594
+ if (kwargs.get('stage', 0)==1) and (annots.get("custom_vid", False)==False):
595
+ support_pts_q = get_track_points(H_resize, W_resize,
596
+ T, x.device, query_size=self.track_num // 2,
597
+ support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental")[None]
598
+ else:
599
+ support_pts_q = get_track_points(H_resize, W_resize,
600
+ T, x.device, query_size=random.randint(32, 256),
601
+ support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental")[None]
602
+ if pts_q is not None:
603
+ pts_q = pts_q[None,None]
604
+ ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
605
+ metric_depth,
606
+ unc_metric.detach(), points_map, pts_q,
607
+ intrs=intrs.clone(), cache=cache,
608
+ prec_fx=prec_fx, prec_fy=prec_fy, overlap_d=overlap_d,
609
+ vis_gt=annots['vis'], traj3d_gt=traj3d, iters=iters_track,
610
+ cam_gt=c2w_traj_gt, support_pts_q=support_pts_q, custom_vid=custom_vid,
611
+ init_pose=init_pose, fixed_cam=fixed_cam, stage=kwargs.get('stage', 0),
612
+ points_map_gt=points_map_gt, valid_only=valid_only, replace_ratio=replace_ratio)
613
+ else:
614
+ ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
615
+ metric_depth,
616
+ unc_metric.detach(), points_map, traj3d[..., :2],
617
+ intrs=intrs.clone(), cache=cache,
618
+ prec_fx=prec_fx, prec_fy=prec_fy, overlap_d=overlap_d,
619
+ vis_gt=annots['vis'], traj3d_gt=traj3d, iters=iters_track,
620
+ cam_gt=c2w_traj_gt, support_pts_q=support_pts_q, custom_vid=custom_vid,
621
+ init_pose=init_pose, fixed_cam=fixed_cam, stage=kwargs.get('stage', 0),
622
+ points_map_gt=points_map_gt, valid_only=valid_only, replace_ratio=replace_ratio)
623
+ regular_track = False
624
+
625
+
626
+ if regular_track:
627
+ if pts_q is None:
628
+ pts_q = get_track_points(H_resize, W_resize,
629
+ T, x.device, query_size=self.track_num,
630
+ support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental" if self.training else "incremental")[None]
631
+ support_pts_q = None
632
+ else:
633
+ pts_q = pts_q[None,None]
634
+ # resize the query points
635
+ pts_q[...,1] *= W_resize / W
636
+ pts_q[...,2] *= H_resize / H
637
+
638
+ if pts_q_3d is not None:
639
+ pts_q_3d = pts_q_3d[None,None]
640
+ # resize the query points
641
+ pts_q_3d[...,1] *= W_resize / W
642
+ pts_q_3d[...,2] *= H_resize / H
643
+ else:
644
+ # adjust the query with uncertainty
645
+ if (full_point==False) and (overlap_d is None):
646
+ pts_q_unc = sample_features5d(unc_metric[None], pts_q).squeeze()
647
+ pts_q = pts_q[:,:,pts_q_unc>0.5,:]
648
+ if (pts_q_unc<0.5).sum() > 0:
649
+ # pad the query points
650
+ pad_num = pts_q_unc.shape[0] - pts_q.shape[2]
651
+ # pick the random indices
652
+ indices = torch.randint(0, pts_q.shape[2], (pad_num,), device=pts_q.device)
653
+ pad_pts = indices
654
+ pts_q = torch.cat([pts_q, pts_q[:,:,pad_pts,:]], dim=-2)
655
+
656
+ support_pts_q = get_track_points(H_resize, W_resize,
657
+ T, x.device, query_size=self.track_num,
658
+ support_frame=self.support_frame,
659
+ unc_metric=unc_metric, mode="mixed")[None]
660
+
661
+ points_map[points_map>1e3] = 0
662
+ points_map = depth_to_points_colmap(metric_depth.squeeze(1), intrs.view(B*T, 3, 3))
663
+ ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
664
+ metric_depth,
665
+ unc_metric.detach(), points_map, pts_q,
666
+ pts_q_3d=pts_q_3d, intrs=intrs.clone(),cache=cache,
667
+ overlap_d=overlap_d, cam_gt=c2w_traj_gt if kwargs.get('stage', 0)==1 else None,
668
+ prec_fx=prec_fx, prec_fy=prec_fy, support_pts_q=support_pts_q, custom_vid=custom_vid, valid_only=valid_only,
669
+ fixed_cam=fixed_cam, query_no_BA=query_no_BA, init_pose=init_pose, iters=iters_track,
670
+ stage=kwargs.get('stage', 0), points_map_gt=points_map_gt, replace_ratio=replace_ratio)
671
+ intrs = intrs_org
672
+ points_map = point_map_org_refined
673
+ c2w_traj = ret_track["cam_pred"]
674
+
675
+ if ret_track is not None:
676
+ if ret_track["loss"] is not None:
677
+ track_loss, conf_loss, dyn_loss, vis_loss, point_map_loss, scale_loss, shift_loss, sync_loss= ret_track["loss"]
678
+
679
+ # update the cache
680
+ cache.update({"metric_depth": metric_depth, "unc_metric": unc_metric, "points_map": points_map, "intrs": intrs[0]})
681
+ # output
682
+ depth = F.interpolate(metric_depth,
683
+ size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
684
+ points_map = F.interpolate(points_map,
685
+ size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
686
+ unc_metric = F.interpolate(unc_metric,
687
+ size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
688
+
689
+ if self.training:
690
+
691
+ loss = track_loss + conf_loss + dyn_loss + sync_loss + vis_loss + point_map_loss + (scale_loss + shift_loss)*50
692
+ ret = {"loss": loss,
693
+ "depth_loss": point_map_loss,
694
+ "ab_loss": (scale_loss + shift_loss)*50,
695
+ "vis_loss": vis_loss, "track_loss": track_loss,
696
+ "poses_pred": c2w_traj, "dyn_preds": dyn_preds, "traj_est": final_tracks, "conf_loss": conf_loss,
697
+ "imgs_raw": imgs_raw, "rgb_tracks": rgb_tracks, "vis_est": ret_track['vis_pred'],
698
+ "depth": depth, "points_map": points_map, "unc_metric": unc_metric, "intrs": intrs, "dyn_loss": dyn_loss,
699
+ "sync_loss": sync_loss, "conf_pred": ret_track['conf_pred'], "cache": cache,
700
+ }
701
+
702
+ else:
703
+
704
+ if ret_track is not None:
705
+ traj_est = ret_track['preds']
706
+ traj_est[..., 0] *= W / W_resize
707
+ traj_est[..., 1] *= H / H_resize
708
+ vis_est = ret_track['vis_pred']
709
+ else:
710
+ traj_est = torch.zeros(B, self.track_num // 2, 3).to(x.device)
711
+ vis_est = torch.zeros(B, self.track_num // 2).to(x.device)
712
+
713
+ if intrs is not None:
714
+ intrs[..., 0, :] *= W / W_resize
715
+ intrs[..., 1, :] *= H / H_resize
716
+ ret = {"poses_pred": c2w_traj, "dyn_preds": dyn_preds,
717
+ "depth": depth, "traj_est": traj_est, "vis_est": vis_est, "imgs_raw": imgs_raw,
718
+ "rgb_tracks": rgb_tracks, "intrs": intrs, "unc_metric": unc_metric, "points_map": points_map,
719
+ "conf_pred": ret_track['conf_pred'], "cache": cache,
720
+ }
721
+
722
+ return ret
723
+
724
+
725
+
726
+
727
+ # three stages of training
728
+
729
+ # stage 1:
730
+ # gt depth and intrinsics, synthetic (includes Dynamic Replica, Kubric, PointOdyssey, VKITTI, TartanAir and Indoor()); motion pattern (tapvid3d)
731
+ # Tracking and Pose as well -> based on gt depth and intrinsics
732
+ # (Finished) -> (megasam + base model) vs. tapip3d. (use depth from megasam or pose, which keep the same setting as tapip3d.)
733
+
734
+ # stage 2: fixed 3D tracking
735
+ # Joint depth refiner
736
+ # input depth from whatever + rgb -> temporal module + scale and shift token -> coarse alignment -> scale and shift
737
+ # estimate the 3D tracks -> 3D tracks combine with pointmap -> update for pointmap (iteratively) -> residual map B T 3 H W
738
+ # ongoing two days
739
+
740
+ # stage 3: train multi windows by propagation
741
+ # 4 overlapping frames -> train on 64 -> frozen image encoder, fine-tune the transformer (learnable parameter count is small)
742
+
743
+ # types of scenarios:
744
+ # 1. auto driving (waymo open dataset)
745
+ # 2. robot
746
+ # 3. internet ego video
747
+
748
+
749
+
750
+ # Iterative Transformer -- Solver -- General Neural MegaSAM + Tracks
751
+ # Update Variables:
752
+ # 1. 3D tracks B T N 3 xyz.
753
+ # 2. 2D tracks B T N 2 x y.
754
+ # 3. Dynamic Mask B T H W.
755
+ # 4. Camera Pose B T 4 4.
756
+ # 5. Video Depth.
757
+
758
+ # (RGB, RGBD, RGBD+Pose) x (Static, Dynamic)
759
+ # Compatibility as a by-product.
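
A small sketch of the sliding-window split that `forward_stream` performs before per-window inference; tensor sizes here are toy values, purely illustrative:

```python
# Sketch: forward_stream() unfolds the video into overlapping windows; each new
# window advances by window_len - overlap_len frames.
import torch

T, C, H, W = 28, 3, 8, 8
video = torch.randn(T, C, H, W)
window_len, overlap_len = 16, 4
step_slide = window_len - overlap_len
windows = video.unfold(0, window_len, step_slide).permute(0, 4, 1, 2, 3)  # [B, S, C, H, W]
print(windows.shape)   # torch.Size([2, 16, 3, 8, 8]) -> two 16-frame windows sharing 4 frames
```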
models/SpaTrackV2/models/__init__.py ADDED
File without changes
models/SpaTrackV2/models/blocks.py ADDED
@@ -0,0 +1,519 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.cuda.amp import autocast
11
+ from einops import rearrange
12
+ import collections
13
+ from functools import partial
14
+ from itertools import repeat
15
+ import torchvision.models as tvm
16
+ from torch.utils.checkpoint import checkpoint
17
+ from models.monoD.depth_anything.dpt import DPTHeadEnc, DPTHead
18
+ from typing import Union, Tuple
19
+ from torch import Tensor
20
+
21
+ # From PyTorch internals
22
+ def _ntuple(n):
23
+ def parse(x):
24
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
25
+ return tuple(x)
26
+ return tuple(repeat(x, n))
27
+
28
+ return parse
29
+
30
+
31
+ def exists(val):
32
+ return val is not None
33
+
34
+
35
+ def default(val, d):
36
+ return val if exists(val) else d
37
+
38
+
39
+ to_2tuple = _ntuple(2)
40
+
41
+ class LayerScale(nn.Module):
42
+ def __init__(
43
+ self,
44
+ dim: int,
45
+ init_values: Union[float, Tensor] = 1e-5,
46
+ inplace: bool = False,
47
+ ) -> None:
48
+ super().__init__()
49
+ self.inplace = inplace
50
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
51
+
52
+ def forward(self, x: Tensor) -> Tensor:
53
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
54
+
55
+ class Mlp(nn.Module):
56
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
57
+
58
+ def __init__(
59
+ self,
60
+ in_features,
61
+ hidden_features=None,
62
+ out_features=None,
63
+ act_layer=nn.GELU,
64
+ norm_layer=None,
65
+ bias=True,
66
+ drop=0.0,
67
+ use_conv=False,
68
+ ):
69
+ super().__init__()
70
+ out_features = out_features or in_features
71
+ hidden_features = hidden_features or in_features
72
+ bias = to_2tuple(bias)
73
+ drop_probs = to_2tuple(drop)
74
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
75
+
76
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
77
+ self.act = act_layer()
78
+ self.drop1 = nn.Dropout(drop_probs[0])
79
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
80
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
81
+ self.drop2 = nn.Dropout(drop_probs[1])
82
+
83
+ def forward(self, x):
84
+ x = self.fc1(x)
85
+ x = self.act(x)
86
+ x = self.drop1(x)
87
+ x = self.fc2(x)
88
+ x = self.drop2(x)
89
+ return x
90
+
91
+ class Attention(nn.Module):
92
+ def __init__(self, query_dim, context_dim=None,
93
+ num_heads=8, dim_head=48, qkv_bias=False, flash=False):
94
+ super().__init__()
95
+ inner_dim = self.inner_dim = dim_head * num_heads
96
+ context_dim = default(context_dim, query_dim)
97
+ self.scale = dim_head**-0.5
98
+ self.heads = num_heads
99
+ self.flash = flash
100
+
101
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
102
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
103
+ self.to_out = nn.Linear(inner_dim, query_dim)
104
+
105
+ def forward(self, x, context=None, attn_bias=None):
106
+ B, N1, _ = x.shape
107
+ C = self.inner_dim
108
+ h = self.heads
109
+ q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
110
+ context = default(context, x)
111
+ k, v = self.to_kv(context).chunk(2, dim=-1)
112
+
113
+ N2 = context.shape[1]
114
+ k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
115
+ v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
116
+
117
+ with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
118
+ if self.flash==False:
119
+ sim = (q @ k.transpose(-2, -1)) * self.scale
120
+ if attn_bias is not None:
121
+ sim = sim + attn_bias
122
+ if sim.abs().max()>1e2:
123
+ import pdb; pdb.set_trace()
124
+ attn = sim.softmax(dim=-1)
125
+ x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
126
+ else:
127
+ input_args = [x.contiguous() for x in [q, k, v]]
128
+ x = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).reshape(B,N1,-1) # type: ignore
129
+
130
+ if self.to_out.bias.dtype != x.dtype:
131
+ x = x.to(self.to_out.bias.dtype)
132
+
133
+ return self.to_out(x)
134
+
135
+
136
+ class VGG19(nn.Module):
137
+ def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
138
+ super().__init__()
139
+ self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
140
+ self.amp = amp
141
+ self.amp_dtype = amp_dtype
142
+
143
+ def forward(self, x, **kwargs):
144
+ with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
145
+ feats = {}
146
+ scale = 1
147
+ for layer in self.layers:
148
+ if isinstance(layer, nn.MaxPool2d):
149
+ feats[scale] = x
150
+ scale = scale*2
151
+ x = layer(x)
152
+ return feats
153
+
154
+ class CNNandDinov2(nn.Module):
155
+ def __init__(self, cnn_kwargs = None, amp = True, amp_dtype = torch.float16):
156
+ super().__init__()
157
+ # in case the Internet connection is not stable, please load the DINOv2 locally
158
+ self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
159
+ 'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)
160
+
161
+ state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
162
+ self.dinov2_vitl14.load_state_dict(state_dict, strict=True)
163
+
164
+
165
+ cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
166
+ self.cnn = VGG19(**cnn_kwargs)
167
+ self.amp = amp
168
+ self.amp_dtype = amp_dtype
169
+ if self.amp:
170
+ dinov2_vitl14 = self.dinov2_vitl14.to(self.amp_dtype)  # fix: was an undefined local name
171
+ self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
172
+
173
+
174
+ def train(self, mode: bool = True):
175
+ return self.cnn.train(mode)
176
+
177
+ def forward(self, x, upsample = False):
178
+ B,C,H,W = x.shape
179
+ feature_pyramid = self.cnn(x)
180
+
181
+ if not upsample:
182
+ with torch.no_grad():
183
+ if self.dinov2_vitl14[0].device != x.device:
184
+ self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
185
+ dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
186
+ features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
187
+ del dinov2_features_16
188
+ feature_pyramid[16] = features_16
189
+ return feature_pyramid
190
+
191
+ class Dinov2(nn.Module):
192
+ def __init__(self, amp = True, amp_dtype = torch.float16):
193
+ super().__init__()
194
+ # in case the Internet connection is not stable, please load the DINOv2 locally
195
+ self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
196
+ 'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)
197
+
198
+ state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
199
+ self.dinov2_vitl14.load_state_dict(state_dict, strict=True)
200
+
201
+ self.amp = amp
202
+ self.amp_dtype = amp_dtype
203
+ if self.amp:
204
+ self.dinov2_vitl14 = self.dinov2_vitl14.to(self.amp_dtype)
205
+
206
+ def forward(self, x, upsample = False):
207
+ B,C,H,W = x.shape
208
+ mean_ = torch.tensor([0.485, 0.456, 0.406],
209
+ device=x.device).view(1, 3, 1, 1)
210
+ std_ = torch.tensor([0.229, 0.224, 0.225],
211
+ device=x.device).view(1, 3, 1, 1)
212
+ x = (x+1)/2
213
+ x = (x - mean_)/std_
214
+ h_re, w_re = 560, 560
215
+ x_resize = F.interpolate(x, size=(h_re, w_re),
216
+ mode='bilinear', align_corners=True)
217
+ if not upsample:
218
+ with torch.no_grad():
219
+ dinov2_features_16 = self.dinov2_vitl14.forward_features(x_resize.to(self.amp_dtype))
220
+ features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,h_re//14, w_re//14)
221
+ del dinov2_features_16
222
+ features_16 = F.interpolate(features_16, size=(H//8, W//8), mode="bilinear", align_corners=True)
223
+ return features_16
224
+
225
+ class AttnBlock(nn.Module):
226
+ """
227
+ A pre-norm self-attention block (DiT-style) with LayerScale applied to the MLP branch.
228
+ """
229
+
230
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0,
231
+ flash=False, ckpt_fwd=False, debug=False, **block_kwargs):
232
+ super().__init__()
233
+ self.debug=debug
234
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
235
+ self.flash=flash
236
+
237
+ self.attn = Attention(
238
+ hidden_size, num_heads=num_heads, qkv_bias=True, flash=flash,
239
+ **block_kwargs
240
+ )
241
+ self.ls = LayerScale(hidden_size, init_values=0.005)
242
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
243
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
244
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
245
+ self.mlp = Mlp(
246
+ in_features=hidden_size,
247
+ hidden_features=mlp_hidden_dim,
248
+ act_layer=approx_gelu,
249
+ )
250
+ self.ckpt_fwd = ckpt_fwd
251
+ def forward(self, x):
252
+ if self.debug:
253
+ print(x.max(), x.min(), x.mean())
254
+ if self.ckpt_fwd:
255
+ x = x + checkpoint(self.attn, self.norm1(x), use_reentrant=False)
256
+ else:
257
+ x = x + self.attn(self.norm1(x))
258
+
259
+ x = x + self.ls(self.mlp(self.norm2(x)))
260
+ return x
261
+
262
+ class CrossAttnBlock(nn.Module):
263
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, head_dim=48,
264
+ flash=False, ckpt_fwd=False, **block_kwargs):
265
+ super().__init__()
266
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
267
+ self.norm_context = nn.LayerNorm(hidden_size)
268
+
269
+ self.cross_attn = Attention(
270
+ hidden_size, context_dim=context_dim, dim_head=head_dim,
271
+ num_heads=num_heads, qkv_bias=True, **block_kwargs, flash=flash,
272
+ )
273
+
274
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
275
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
276
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
277
+ self.mlp = Mlp(
278
+ in_features=hidden_size,
279
+ hidden_features=mlp_hidden_dim,
280
+ act_layer=approx_gelu,
281
+ drop=0,
282
+ )
283
+ self.ckpt_fwd = ckpt_fwd
284
+ def forward(self, x, context):
285
+ if self.ckpt_fwd:
286
+ with autocast():
287
+ x = x + checkpoint(self.cross_attn,
288
+ self.norm1(x), self.norm_context(context), use_reentrant=False)
289
+ else:
290
+ with autocast():
291
+ x = x + self.cross_attn(
292
+ self.norm1(x), self.norm_context(context)
293
+ )
294
+ x = x + self.mlp(self.norm2(x))
295
+ return x
296
+
297
+
298
+ def bilinear_sampler(img, coords, mode="bilinear", mask=False):
299
+ """Wrapper for grid_sample, uses pixel coordinates"""
300
+ H, W = img.shape[-2:]
301
+ xgrid, ygrid = coords.split([1, 1], dim=-1)
302
+ # go to 0,1 then 0,2 then -1,1
303
+ xgrid = 2 * xgrid / (W - 1) - 1
304
+ ygrid = 2 * ygrid / (H - 1) - 1
305
+
306
+ grid = torch.cat([xgrid, ygrid], dim=-1)
307
+ img = F.grid_sample(img, grid, align_corners=True, mode=mode)
308
+
309
+ if mask:
310
+ mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
311
+ return img, mask.float()
312
+
313
+ return img
314
+
315
+
316
+ class CorrBlock:
317
+ def __init__(self, fmaps, num_levels=4, radius=4, depths_dnG=None):
318
+ B, S, C, H_prev, W_prev = fmaps.shape
319
+ self.S, self.C, self.H, self.W = S, C, H_prev, W_prev
320
+
321
+ self.num_levels = num_levels
322
+ self.radius = radius
323
+ self.fmaps_pyramid = []
324
+ self.depth_pyramid = []
325
+ self.fmaps_pyramid.append(fmaps)
326
+ if depths_dnG is not None:
327
+ self.depth_pyramid.append(depths_dnG)
328
+ for i in range(self.num_levels - 1):
329
+ if depths_dnG is not None:
330
+ depths_dnG_ = depths_dnG.reshape(B * S, 1, H_prev, W_prev)
331
+ depths_dnG_ = F.avg_pool2d(depths_dnG_, 2, stride=2)
332
+ _, _, H, W = depths_dnG_.shape
333
+ depths_dnG = depths_dnG_.reshape(B, S, 1, H, W)
334
+ self.depth_pyramid.append(depths_dnG)
335
+ fmaps_ = fmaps.reshape(B * S, C, H_prev, W_prev)
336
+ fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
337
+ _, _, H, W = fmaps_.shape
338
+ fmaps = fmaps_.reshape(B, S, C, H, W)
339
+ H_prev = H
340
+ W_prev = W
341
+ self.fmaps_pyramid.append(fmaps)
342
+
343
+ def sample(self, coords):
344
+ r = self.radius
345
+ B, S, N, D = coords.shape
346
+ assert D == 2
347
+
348
+ H, W = self.H, self.W
349
+ out_pyramid = []
350
+ for i in range(self.num_levels):
351
+ corrs = self.corrs_pyramid[i] # B, S, N, H, W
352
+ _, _, _, H, W = corrs.shape
353
+
354
+ dx = torch.linspace(-r, r, 2 * r + 1)
355
+ dy = torch.linspace(-r, r, 2 * r + 1)
356
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
357
+ coords.device
358
+ )
359
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
360
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
361
+ coords_lvl = centroid_lvl + delta_lvl
362
+ corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
363
+ corrs = corrs.view(B, S, N, -1)
364
+ out_pyramid.append(corrs)
365
+
366
+ out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
367
+ return out.contiguous().float()
368
+
369
+ def corr(self, targets):
370
+ B, S, N, C = targets.shape
371
+ assert C == self.C
372
+ assert S == self.S
373
+
374
+ fmap1 = targets
375
+
376
+ self.corrs_pyramid = []
377
+ for fmaps in self.fmaps_pyramid:
378
+ _, _, _, H, W = fmaps.shape
379
+ fmap2s = fmaps.view(B, S, C, H * W)
380
+ corrs = torch.matmul(fmap1, fmap2s)
381
+ corrs = corrs.view(B, S, N, H, W)
382
+ corrs = corrs / torch.sqrt(torch.tensor(C).float())
383
+ self.corrs_pyramid.append(corrs)
384
+
385
+ def corr_sample(self, targets, coords, coords_dp=None):
386
+ B, S, N, C = targets.shape
387
+ r = self.radius
388
+ Dim_c = (2*r+1)**2
389
+ assert C == self.C
390
+ assert S == self.S
391
+
392
+ out_pyramid = []
393
+ out_pyramid_dp = []
394
+ for i in range(self.num_levels):
395
+ dx = torch.linspace(-r, r, 2 * r + 1)
396
+ dy = torch.linspace(-r, r, 2 * r + 1)
397
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
398
+ coords.device
399
+ )
400
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
401
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
402
+ coords_lvl = centroid_lvl + delta_lvl
403
+ fmaps = self.fmaps_pyramid[i]
404
+ _, _, _, H, W = fmaps.shape
405
+ fmap2s = fmaps.view(B*S, C, H, W)
406
+ if len(self.depth_pyramid)>0:
407
+ depths_dnG_i = self.depth_pyramid[i]
408
+ depths_dnG_i = depths_dnG_i.view(B*S, 1, H, W)
409
+ dnG_sample = bilinear_sampler(depths_dnG_i, coords_lvl.view(B*S,1,N*Dim_c,2))
410
+ dp_corrs = (dnG_sample.view(B*S,N,-1) - coords_dp[0]).abs()/coords_dp[0]
411
+ out_pyramid_dp.append(dp_corrs)
412
+ fmap2s_sample = bilinear_sampler(fmap2s, coords_lvl.view(B*S,1,N*Dim_c,2))
413
+ fmap2s_sample = fmap2s_sample.permute(0, 3, 1, 2) # B*S, N*Dim_c, C, -1
414
+ corrs = torch.matmul(targets.reshape(B*S*N, 1, -1), fmap2s_sample.reshape(B*S*N, Dim_c, -1).permute(0, 2, 1))
415
+ corrs = corrs / torch.sqrt(torch.tensor(C).float())
416
+ corrs = corrs.view(B, S, N, -1)
417
+ out_pyramid.append(corrs)
418
+
419
+ out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
420
+ if len(self.depth_pyramid)>0:
421
+ out_dp = torch.cat(out_pyramid_dp, dim=-1)
422
+ self.fcorrD = out_dp.contiguous().float()
423
+ else:
424
+ self.fcorrD = torch.zeros_like(out).contiguous().float()
425
+ return out.contiguous().float()
426
+
427
+
428
+ class EUpdateFormer(nn.Module):
429
+ """
430
+ Sequence model (bidirectional xLSTM stack) that updates track estimates.
431
+ """
432
+
433
+ def __init__(
434
+ self,
435
+ space_depth=12,
436
+ time_depth=12,
437
+ input_dim=320,
438
+ hidden_size=384,
439
+ num_heads=8,
440
+ output_dim=130,
441
+ mlp_ratio=4.0,
442
+ vq_depth=3,
443
+ add_space_attn=True,
444
+ add_time_attn=True,
445
+ flash=True
446
+ ):
447
+ super().__init__()
448
+ self.out_channels = 2
449
+ self.num_heads = num_heads
450
+ self.hidden_size = hidden_size
451
+ self.add_space_attn = add_space_attn
452
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
453
+ self.flash = flash
454
+ self.flow_head = nn.Sequential(
455
+ nn.Linear(hidden_size, output_dim, bias=True),
456
+ nn.ReLU(inplace=True),
457
+ nn.Linear(output_dim, output_dim, bias=True),
458
+ nn.ReLU(inplace=True),
459
+ nn.Linear(output_dim, output_dim, bias=True)
460
+ )
461
+ self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
462
+ cfg = xLSTMBlockStackConfig(
463
+ mlstm_block=mLSTMBlockConfig(
464
+ mlstm=mLSTMLayerConfig(
465
+ conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
466
+ )
467
+ ),
468
+ slstm_block=sLSTMBlockConfig(
469
+ slstm=sLSTMLayerConfig(
470
+ backend="cuda",
471
+ num_heads=4,
472
+ conv1d_kernel_size=4,
473
+ bias_init="powerlaw_blockdependent",
474
+ ),
475
+ feedforward=FeedForwardConfig(proj_factor=1.3, act_fn="gelu"),
476
+ ),
477
+ context_length=50,
478
+ num_blocks=7,
479
+ embedding_dim=384,
480
+ slstm_at=[1],
481
+
482
+ )
483
+ self.xlstm_fwd = xLSTMBlockStack(cfg)
484
+ self.xlstm_bwd = xLSTMBlockStack(cfg)
485
+
486
+ self.initialize_weights()
487
+
488
+ def initialize_weights(self):
489
+ def _basic_init(module):
490
+ if isinstance(module, nn.Linear):
491
+ torch.nn.init.xavier_uniform_(module.weight)
492
+ if module.bias is not None:
493
+ nn.init.constant_(module.bias, 0)
494
+
495
+ self.apply(_basic_init)
496
+
497
+ def forward(self,
498
+ input_tensor,
499
+ track_mask=None):
500
+ """ Updating with Transformer
501
+
502
+ Args:
503
+ input_tensor: B, N, T, C
504
+ track_mask: B, T, N, 1 (per-frame validity mask)
505
+ """
506
+ B, N, T, C = input_tensor.shape
507
+ x = self.input_transform(input_tensor)
508
+
509
+ track_mask = track_mask.permute(0,2,1,3).float()
510
+ fwd_x = x*track_mask
511
+ bwd_x = x.flip(2)*track_mask.flip(2)
512
+ feat_fwd = self.xlstm_fwd(self.norm(fwd_x.view(B*N, T, -1)))
513
+ feat_bwd = self.xlstm_bwd(self.norm(bwd_x.view(B*N, T, -1)))
514
+ feat = (feat_bwd.flip(1) + feat_fwd).view(B, N, T, -1)
515
+
516
+ flow = self.flow_head(feat)
517
+
518
+ return flow[..., :2], flow[..., 2:]
519
+
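A minimal usage sketch (not part of the commit) for the CorrBlock defined above: build a correlation pyramid from per-frame feature maps and sample local correlation patches at track coordinates. The import path is assumed from this repo layout, running it requires the repository (including its models.monoD and xlstm dependencies) to be importable, and all sizes are illustrative.

import torch
from models.SpaTrackV2.models.blocks import CorrBlock  # assumed path from this diff

B, S, N, C, H, W = 1, 8, 64, 128, 48, 64
fmaps = torch.randn(B, S, C, H, W)                       # per-frame feature maps
targets = torch.randn(B, S, N, C)                        # per-track template features
coords = torch.rand(B, S, N, 2) * torch.tensor([W - 1.0, H - 1.0])  # pixel (x, y)

corr_block = CorrBlock(fmaps, num_levels=4, radius=3)
corr_feat = corr_block.corr_sample(targets, coords)      # (B, S, N, num_levels * (2r+1)^2)
print(corr_feat.shape)                                   # torch.Size([1, 8, 64, 196])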
models/SpaTrackV2/models/camera_transform.py ADDED
@@ -0,0 +1,248 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Adapted from https://github.com/amyxlase/relpose-plus-plus
9
+
10
+ import torch
11
+ import numpy as np
12
+ import math
+ # NOTE: import added for completeness; Translate/Rotate used below come from PyTorch3D.
+ from pytorch3d.transforms import Rotate, Translate
13
+
14
+
15
+
16
+
17
+ def bbox_xyxy_to_xywh(xyxy):
18
+ wh = xyxy[2:] - xyxy[:2]
19
+ xywh = np.concatenate([xyxy[:2], wh])
20
+ return xywh
21
+
22
+
23
+ def adjust_camera_to_bbox_crop_(fl, pp, image_size_wh: torch.Tensor, clamp_bbox_xywh: torch.Tensor):
24
+ focal_length_px, principal_point_px = _convert_ndc_to_pixels(fl, pp, image_size_wh)
25
+
26
+ principal_point_px_cropped = principal_point_px - clamp_bbox_xywh[:2]
27
+
28
+ focal_length, principal_point_cropped = _convert_pixels_to_ndc(
29
+ focal_length_px, principal_point_px_cropped, clamp_bbox_xywh[2:]
30
+ )
31
+
32
+ return focal_length, principal_point_cropped
33
+
34
+
35
+ def adjust_camera_to_image_scale_(fl, pp, original_size_wh: torch.Tensor, new_size_wh: torch.LongTensor):
36
+ focal_length_px, principal_point_px = _convert_ndc_to_pixels(fl, pp, original_size_wh)
37
+
38
+ # now scale and convert from pixels to NDC
39
+ image_size_wh_output = new_size_wh.float()
40
+ scale = (image_size_wh_output / original_size_wh).min(dim=-1, keepdim=True).values
41
+ focal_length_px_scaled = focal_length_px * scale
42
+ principal_point_px_scaled = principal_point_px * scale
43
+
44
+ focal_length_scaled, principal_point_scaled = _convert_pixels_to_ndc(
45
+ focal_length_px_scaled, principal_point_px_scaled, image_size_wh_output
46
+ )
47
+ return focal_length_scaled, principal_point_scaled
48
+
49
+
50
+ def _convert_ndc_to_pixels(focal_length: torch.Tensor, principal_point: torch.Tensor, image_size_wh: torch.Tensor):
51
+ half_image_size = image_size_wh / 2
52
+ rescale = half_image_size.min()
53
+ principal_point_px = half_image_size - principal_point * rescale
54
+ focal_length_px = focal_length * rescale
55
+ return focal_length_px, principal_point_px
56
+
57
+
58
+ def _convert_pixels_to_ndc(
59
+ focal_length_px: torch.Tensor, principal_point_px: torch.Tensor, image_size_wh: torch.Tensor
60
+ ):
61
+ half_image_size = image_size_wh / 2
62
+ rescale = half_image_size.min()
63
+ principal_point = (half_image_size - principal_point_px) / rescale
64
+ focal_length = focal_length_px / rescale
65
+ return focal_length, principal_point
66
+
67
+
68
+ def normalize_cameras(
69
+ cameras, compute_optical=True, first_camera=True, normalize_trans=True, scale=1.0, points=None, max_norm=False,
70
+ pose_mode="C2W"
71
+ ):
72
+ """
73
+ Normalizes cameras such that
74
+ (1) the optical axes point to the origin and the average distance to the origin is 1
75
+ (2) the first camera is the origin
76
+ (3) the translation vector is normalized
77
+
78
+ TODO: some transforms overlap with others. no need to do so many transforms
79
+ Args:
80
+ cameras (List[camera]).
81
+ """
82
+ # Let distance from first camera to origin be unit
83
+ new_cameras = cameras.clone()
84
+ scale = 1.0
85
+
86
+ if compute_optical:
87
+ new_cameras, points = compute_optical_transform(new_cameras, points=points)
88
+ if first_camera:
89
+ new_cameras, points = first_camera_transform(new_cameras, points=points, pose_mode=pose_mode)
90
+ if normalize_trans:
91
+ new_cameras, points, scale = normalize_translation(new_cameras,
92
+ points=points, max_norm=max_norm)
93
+ return new_cameras, points, scale
94
+
95
+
96
+ def compute_optical_transform(new_cameras, points=None):
97
+ """
98
+ adapted from https://github.com/amyxlase/relpose-plus-plus
99
+ """
100
+
101
+ new_transform = new_cameras.get_world_to_view_transform()
102
+ p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(new_cameras)
103
+ t = Translate(p_intersect)
104
+ scale = dist.squeeze()[0]
105
+
106
+ if points is not None:
107
+ points = t.inverse().transform_points(points)
108
+ points = points / scale
109
+
110
+ # Degenerate case
111
+ if scale == 0:
112
+ scale = torch.norm(new_cameras.T, dim=(0, 1))
113
+ scale = torch.sqrt(scale)
114
+ new_cameras.T = new_cameras.T / scale
115
+ else:
116
+ new_matrix = t.compose(new_transform).get_matrix()
117
+ new_cameras.R = new_matrix[:, :3, :3]
118
+ new_cameras.T = new_matrix[:, 3, :3] / scale
119
+
120
+ return new_cameras, points
121
+
122
+
123
+ def compute_optical_axis_intersection(cameras):
124
+ centers = cameras.get_camera_center()
125
+ principal_points = cameras.principal_point
126
+
127
+ one_vec = torch.ones((len(cameras), 1))
128
+ optical_axis = torch.cat((principal_points, one_vec), -1)
129
+
130
+ pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
131
+
132
+ pp2 = pp[torch.arange(pp.shape[0]), torch.arange(pp.shape[0])]
133
+
134
+ directions = pp2 - centers
135
+ centers = centers.unsqueeze(0).unsqueeze(0)
136
+ directions = directions.unsqueeze(0).unsqueeze(0)
137
+
138
+ p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(p=centers, r=directions, mask=None)
139
+
140
+ p_intersect = p_intersect.squeeze().unsqueeze(0)
141
+ dist = (p_intersect - centers).norm(dim=-1)
142
+
143
+ return p_intersect, dist, p_line_intersect, pp2, r
144
+
145
+
146
+ def intersect_skew_line_groups(p, r, mask):
147
+ # p, r both of shape (B, N, n_intersected_lines, 3)
148
+ # mask of shape (B, N, n_intersected_lines)
149
+ p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
150
+ _, p_line_intersect = _point_line_distance(p, r, p_intersect[..., None, :].expand_as(p))
151
+ intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(dim=-1)
152
+ return p_intersect, p_line_intersect, intersect_dist_squared, r
153
+
154
+
155
+ def intersect_skew_lines_high_dim(p, r, mask=None):
156
+ # Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions
157
+ dim = p.shape[-1]
158
+ # make sure the heading vectors are l2-normed
159
+ if mask is None:
160
+ mask = torch.ones_like(p[..., 0])
161
+ r = torch.nn.functional.normalize(r, dim=-1)
162
+
163
+ eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
164
+ I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
165
+ sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
166
+ p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
167
+
168
+ if torch.any(torch.isnan(p_intersect)):
169
+ print(p_intersect)
170
+ raise ValueError(f"p_intersect is NaN")
171
+
172
+ return p_intersect, r
173
+
174
+
175
+ def _point_line_distance(p1, r1, p2):
176
+ df = p2 - p1
177
+ proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
178
+ line_pt_nearest = p2 - proj_vector
179
+ d = (proj_vector).norm(dim=-1)
180
+ return d, line_pt_nearest
181
+
182
+
183
+ def first_camera_transform(cameras, rotation_only=False,
184
+ points=None, pose_mode="C2W"):
185
+ """
186
+ Transform so that the first camera is the origin
187
+ """
188
+
189
+ new_cameras = cameras.clone()
190
+ # new_transform = new_cameras.get_world_to_view_transform()
191
+
192
+ R = cameras.R
193
+ T = cameras.T
194
+ Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B, 3, 4]
195
+ Tran_M = torch.cat([Tran_M,
196
+ torch.tensor([[[0, 0, 0, 1]]], device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)], dim=1)
197
+ if pose_mode == "C2W":
198
+ Tran_M_new = (Tran_M[:1,...].inverse())@Tran_M
199
+ elif pose_mode == "W2C":
200
+ Tran_M_new = Tran_M@(Tran_M[:1,...].inverse())
201
+
202
+ if False:
203
+ tR = Rotate(new_cameras.R[0].unsqueeze(0))
204
+ if rotation_only:
205
+ t = tR.inverse()
206
+ else:
207
+ tT = Translate(new_cameras.T[0].unsqueeze(0))
208
+ t = tR.compose(tT).inverse()
209
+
210
+ if points is not None:
211
+ points = t.inverse().transform_points(points)
212
+
213
+ if pose_mode == "C2W":
214
+ new_matrix = new_transform.compose(t).get_matrix()
215
+ else:
216
+ import ipdb; ipdb.set_trace()
217
+ new_matrix = t.compose(new_transform).get_matrix()
218
+
219
+ new_cameras.R = Tran_M_new[:, :3, :3]
220
+ new_cameras.T = Tran_M_new[:, :3, 3]
221
+
222
+ return new_cameras, points
223
+
224
+
225
+ def normalize_translation(new_cameras, points=None, max_norm=False):
226
+ t_gt = new_cameras.T.clone()
227
+ t_gt = t_gt[1:, :]
228
+
229
+ if max_norm:
230
+ t_gt_norm = torch.norm(t_gt, dim=(-1))
231
+ t_gt_scale = t_gt_norm.max()
232
+ if t_gt_norm.max() < 0.001:
233
+ t_gt_scale = torch.ones_like(t_gt_scale)
234
+ t_gt_scale = t_gt_scale.clamp(min=0.01, max=1e5)
235
+ else:
236
+ t_gt_norm = torch.norm(t_gt, dim=(0, 1))
237
+ t_gt_scale = t_gt_norm / math.sqrt(len(t_gt))
238
+ t_gt_scale = t_gt_scale / 2
239
+ if t_gt_norm.max() < 0.001:
240
+ t_gt_scale = torch.ones_like(t_gt_scale)
241
+ t_gt_scale = t_gt_scale.clamp(min=0.01, max=1e5)
242
+
243
+ new_cameras.T = new_cameras.T / t_gt_scale
244
+
245
+ if points is not None:
246
+ points = points / t_gt_scale
247
+
248
+ return new_cameras, points, t_gt_scale
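A small round-trip check (not part of the commit) of the NDC/pixel convention used by _convert_ndc_to_pixels and _convert_pixels_to_ndc above, where the rescale factor is half of the smaller image side. The import path is assumed from this repo layout and importing the module needs PyTorch3D available; the numbers are illustrative.

import torch
from models.SpaTrackV2.models.camera_transform import (  # assumed path from this diff
    _convert_ndc_to_pixels, _convert_pixels_to_ndc)

image_size_wh = torch.tensor([640.0, 480.0])
fl_ndc = torch.tensor([2.0, 2.0])   # focal length in NDC units
pp_ndc = torch.tensor([0.0, 0.0])   # principal point at the image centre

fl_px, pp_px = _convert_ndc_to_pixels(fl_ndc, pp_ndc, image_size_wh)
# rescale = min(320, 240) = 240  ->  fl_px = (480, 480), pp_px = (320, 240)
fl_back, pp_back = _convert_pixels_to_ndc(fl_px, pp_px, image_size_wh)
assert torch.allclose(fl_back, fl_ndc) and torch.allclose(pp_back, pp_ndc)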
models/SpaTrackV2/models/depth_refiner/backbone.py ADDED
@@ -0,0 +1,472 @@
1
+ # ---------------------------------------------------------------
2
+ # Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
3
+ #
4
+ # This work is licensed under the NVIDIA Source Code License
5
+ # ---------------------------------------------------------------
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from functools import partial
10
+
11
+ from timm.layers import DropPath, to_2tuple, trunc_normal_
12
+ from timm.models import register_model
13
+ from timm.models.vision_transformer import _cfg
14
+ import math
15
+
16
+
17
+ class Mlp(nn.Module):
18
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
19
+ super().__init__()
20
+ out_features = out_features or in_features
21
+ hidden_features = hidden_features or in_features
22
+ self.fc1 = nn.Linear(in_features, hidden_features)
23
+ self.dwconv = DWConv(hidden_features)
24
+ self.act = act_layer()
25
+ self.fc2 = nn.Linear(hidden_features, out_features)
26
+ self.drop = nn.Dropout(drop)
27
+
28
+ self.apply(self._init_weights)
29
+
30
+ def _init_weights(self, m):
31
+ if isinstance(m, nn.Linear):
32
+ trunc_normal_(m.weight, std=.02)
33
+ if isinstance(m, nn.Linear) and m.bias is not None:
34
+ nn.init.constant_(m.bias, 0)
35
+ elif isinstance(m, nn.LayerNorm):
36
+ nn.init.constant_(m.bias, 0)
37
+ nn.init.constant_(m.weight, 1.0)
38
+ elif isinstance(m, nn.Conv2d):
39
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
40
+ fan_out //= m.groups
41
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
42
+ if m.bias is not None:
43
+ m.bias.data.zero_()
44
+
45
+ def forward(self, x, H, W):
46
+ x = self.fc1(x)
47
+ x = self.dwconv(x, H, W)
48
+ x = self.act(x)
49
+ x = self.drop(x)
50
+ x = self.fc2(x)
51
+ x = self.drop(x)
52
+ return x
53
+
54
+
55
+ class Attention(nn.Module):
56
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
57
+ super().__init__()
58
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
59
+
60
+ self.dim = dim
61
+ self.num_heads = num_heads
62
+ head_dim = dim // num_heads
63
+ self.scale = qk_scale or head_dim ** -0.5
64
+
65
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
66
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
67
+ self.attn_drop = nn.Dropout(attn_drop)
68
+ self.proj = nn.Linear(dim, dim)
69
+ self.proj_drop = nn.Dropout(proj_drop)
70
+
71
+ self.sr_ratio = sr_ratio
72
+ if sr_ratio > 1:
73
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
74
+ self.norm = nn.LayerNorm(dim)
75
+
76
+ self.apply(self._init_weights)
77
+
78
+ def _init_weights(self, m):
79
+ if isinstance(m, nn.Linear):
80
+ trunc_normal_(m.weight, std=.02)
81
+ if isinstance(m, nn.Linear) and m.bias is not None:
82
+ nn.init.constant_(m.bias, 0)
83
+ elif isinstance(m, nn.LayerNorm):
84
+ nn.init.constant_(m.bias, 0)
85
+ nn.init.constant_(m.weight, 1.0)
86
+ elif isinstance(m, nn.Conv2d):
87
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
88
+ fan_out //= m.groups
89
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
90
+ if m.bias is not None:
91
+ m.bias.data.zero_()
92
+
93
+ def forward(self, x, H, W):
94
+ B, N, C = x.shape
95
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
96
+
97
+ if self.sr_ratio > 1:
98
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
99
+ x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
100
+ x_ = self.norm(x_)
101
+ kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
102
+ else:
103
+ kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
104
+ k, v = kv[0], kv[1]
105
+
106
+ attn = (q @ k.transpose(-2, -1)) * self.scale
107
+ attn = attn.softmax(dim=-1)
108
+ attn = self.attn_drop(attn)
109
+
110
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
111
+ x = self.proj(x)
112
+ x = self.proj_drop(x)
113
+
114
+ return x
115
+
116
+
117
+ class Block(nn.Module):
118
+
119
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
120
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
121
+ super().__init__()
122
+ self.norm1 = norm_layer(dim)
123
+ self.attn = Attention(
124
+ dim,
125
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
126
+ attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
127
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
128
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
129
+ self.norm2 = norm_layer(dim)
130
+ mlp_hidden_dim = int(dim * mlp_ratio)
131
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
132
+
133
+ self.apply(self._init_weights)
134
+
135
+ def _init_weights(self, m):
136
+ if isinstance(m, nn.Linear):
137
+ trunc_normal_(m.weight, std=.02)
138
+ if isinstance(m, nn.Linear) and m.bias is not None:
139
+ nn.init.constant_(m.bias, 0)
140
+ elif isinstance(m, nn.LayerNorm):
141
+ nn.init.constant_(m.bias, 0)
142
+ nn.init.constant_(m.weight, 1.0)
143
+ elif isinstance(m, nn.Conv2d):
144
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
145
+ fan_out //= m.groups
146
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
147
+ if m.bias is not None:
148
+ m.bias.data.zero_()
149
+
150
+ def forward(self, x, H, W):
151
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
152
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
153
+
154
+ return x
155
+
156
+
157
+ class OverlapPatchEmbed(nn.Module):
158
+ """ Image to Patch Embedding
159
+ """
160
+
161
+ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
162
+ super().__init__()
163
+ img_size = to_2tuple(img_size)
164
+ patch_size = to_2tuple(patch_size)
165
+
166
+ self.img_size = img_size
167
+ self.patch_size = patch_size
168
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
169
+ self.num_patches = self.H * self.W
170
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
171
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
172
+ self.norm = nn.LayerNorm(embed_dim)
173
+
174
+ self.apply(self._init_weights)
175
+
176
+ def _init_weights(self, m):
177
+ if isinstance(m, nn.Linear):
178
+ trunc_normal_(m.weight, std=.02)
179
+ if isinstance(m, nn.Linear) and m.bias is not None:
180
+ nn.init.constant_(m.bias, 0)
181
+ elif isinstance(m, nn.LayerNorm):
182
+ nn.init.constant_(m.bias, 0)
183
+ nn.init.constant_(m.weight, 1.0)
184
+ elif isinstance(m, nn.Conv2d):
185
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
186
+ fan_out //= m.groups
187
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
188
+ if m.bias is not None:
189
+ m.bias.data.zero_()
190
+
191
+ def forward(self, x):
192
+ x = self.proj(x)
193
+ _, _, H, W = x.shape
194
+ x = x.flatten(2).transpose(1, 2)
195
+ x = self.norm(x)
196
+
197
+ return x, H, W
198
+
199
+
200
+
201
+
202
+ class OverlapPatchEmbed43(nn.Module):
203
+ """ Image to Patch Embedding
204
+ """
205
+
206
+ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
207
+ super().__init__()
208
+ img_size = to_2tuple(img_size)
209
+ patch_size = to_2tuple(patch_size)
210
+
211
+ self.img_size = img_size
212
+ self.patch_size = patch_size
213
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
214
+ self.num_patches = self.H * self.W
215
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
216
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
217
+ self.norm = nn.LayerNorm(embed_dim)
218
+
219
+ self.apply(self._init_weights)
220
+
221
+ def _init_weights(self, m):
222
+ if isinstance(m, nn.Linear):
223
+ trunc_normal_(m.weight, std=.02)
224
+ if isinstance(m, nn.Linear) and m.bias is not None:
225
+ nn.init.constant_(m.bias, 0)
226
+ elif isinstance(m, nn.LayerNorm):
227
+ nn.init.constant_(m.bias, 0)
228
+ nn.init.constant_(m.weight, 1.0)
229
+ elif isinstance(m, nn.Conv2d):
230
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
231
+ fan_out //= m.groups
232
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
233
+ if m.bias is not None:
234
+ m.bias.data.zero_()
235
+
236
+ def forward(self, x):
237
+ if x.shape[1]==4:
238
+ x = self.proj_4c(x)  # NOTE: proj_4c is never defined in __init__, so a 4-channel input will raise AttributeError here
239
+ else:
240
+ x = self.proj(x)
241
+ _, _, H, W = x.shape
242
+ x = x.flatten(2).transpose(1, 2)
243
+ x = self.norm(x)
244
+
245
+ return x, H, W
246
+
247
+ class MixVisionTransformer(nn.Module):
248
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
249
+ num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
250
+ attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
251
+ depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
252
+ super().__init__()
253
+ self.num_classes = num_classes
254
+ self.depths = depths
255
+
256
+ # patch_embed 43
257
+ self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
258
+ embed_dim=embed_dims[0])
259
+ self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
260
+ embed_dim=embed_dims[1])
261
+ self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
262
+ embed_dim=embed_dims[2])
263
+ self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
264
+ embed_dim=embed_dims[3])
265
+
266
+ # transformer encoder
267
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
268
+ cur = 0
269
+ self.block1 = nn.ModuleList([Block(
270
+ dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
271
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
272
+ sr_ratio=sr_ratios[0])
273
+ for i in range(depths[0])])
274
+ self.norm1 = norm_layer(embed_dims[0])
275
+
276
+ cur += depths[0]
277
+ self.block2 = nn.ModuleList([Block(
278
+ dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
279
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
280
+ sr_ratio=sr_ratios[1])
281
+ for i in range(depths[1])])
282
+ self.norm2 = norm_layer(embed_dims[1])
283
+
284
+ cur += depths[1]
285
+ self.block3 = nn.ModuleList([Block(
286
+ dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
287
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
288
+ sr_ratio=sr_ratios[2])
289
+ for i in range(depths[2])])
290
+ self.norm3 = norm_layer(embed_dims[2])
291
+
292
+ cur += depths[2]
293
+ self.block4 = nn.ModuleList([Block(
294
+ dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
295
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
296
+ sr_ratio=sr_ratios[3])
297
+ for i in range(depths[3])])
298
+ self.norm4 = norm_layer(embed_dims[3])
299
+
300
+ # classification head
301
+ # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
302
+
303
+ self.apply(self._init_weights)
304
+
305
+ def _init_weights(self, m):
306
+ if isinstance(m, nn.Linear):
307
+ trunc_normal_(m.weight, std=.02)
308
+ if isinstance(m, nn.Linear) and m.bias is not None:
309
+ nn.init.constant_(m.bias, 0)
310
+ elif isinstance(m, nn.LayerNorm):
311
+ nn.init.constant_(m.bias, 0)
312
+ nn.init.constant_(m.weight, 1.0)
313
+ elif isinstance(m, nn.Conv2d):
314
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
315
+ fan_out //= m.groups
316
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
317
+ if m.bias is not None:
318
+ m.bias.data.zero_()
319
+
320
+ def init_weights(self, pretrained=None):
321
+ if isinstance(pretrained, str):
322
+ logger = get_root_logger()
323
+ load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
324
+
325
+ def reset_drop_path(self, drop_path_rate):
326
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
327
+ cur = 0
328
+ for i in range(self.depths[0]):
329
+ self.block1[i].drop_path.drop_prob = dpr[cur + i]
330
+
331
+ cur += self.depths[0]
332
+ for i in range(self.depths[1]):
333
+ self.block2[i].drop_path.drop_prob = dpr[cur + i]
334
+
335
+ cur += self.depths[1]
336
+ for i in range(self.depths[2]):
337
+ self.block3[i].drop_path.drop_prob = dpr[cur + i]
338
+
339
+ cur += self.depths[2]
340
+ for i in range(self.depths[3]):
341
+ self.block4[i].drop_path.drop_prob = dpr[cur + i]
342
+
343
+ def freeze_patch_emb(self):
344
+ self.patch_embed1.requires_grad = False
345
+
346
+ @torch.jit.ignore
347
+ def no_weight_decay(self):
348
+ return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
349
+
350
+ def get_classifier(self):
351
+ return self.head
352
+
353
+ def reset_classifier(self, num_classes, global_pool=''):
354
+ self.num_classes = num_classes
355
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
356
+
357
+ def forward_features(self, x):
358
+ B = x.shape[0]
359
+ outs = []
360
+
361
+ # stage 1
362
+ x, H, W = self.patch_embed1(x)
363
+ for i, blk in enumerate(self.block1):
364
+ x = blk(x, H, W)
365
+ x = self.norm1(x)
366
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
367
+ outs.append(x)
368
+
369
+ # stage 2
370
+ x, H, W = self.patch_embed2(x)
371
+ for i, blk in enumerate(self.block2):
372
+ x = blk(x, H, W)
373
+ x = self.norm2(x)
374
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
375
+ outs.append(x)
376
+
377
+ # stage 3
378
+ x, H, W = self.patch_embed3(x)
379
+ for i, blk in enumerate(self.block3):
380
+ x = blk(x, H, W)
381
+ x = self.norm3(x)
382
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
383
+ outs.append(x)
384
+
385
+ # stage 4
386
+ x, H, W = self.patch_embed4(x)
387
+ for i, blk in enumerate(self.block4):
388
+ x = blk(x, H, W)
389
+ x = self.norm4(x)
390
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
391
+ outs.append(x)
392
+
393
+ return outs
394
+
395
+ def forward(self, x):
396
+ if x.dim() == 5:
397
+ x = x.reshape(x.shape[0]*x.shape[1],x.shape[2],x.shape[3],x.shape[4])
398
+ x = self.forward_features(x)
399
+ # x = self.head(x)
400
+
401
+ return x
402
+
403
+
404
+ class DWConv(nn.Module):
405
+ def __init__(self, dim=768):
406
+ super(DWConv, self).__init__()
407
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
408
+
409
+ def forward(self, x, H, W):
410
+ B, N, C = x.shape
411
+ x = x.transpose(1, 2).view(B, C, H, W)
412
+ x = self.dwconv(x)
413
+ x = x.flatten(2).transpose(1, 2)
414
+
415
+ return x
416
+
417
+
418
+
419
+ #@BACKBONES.register_module()
420
+ class mit_b0(MixVisionTransformer):
421
+ def __init__(self, **kwargs):
422
+ super(mit_b0, self).__init__(
423
+ patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
424
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
425
+ drop_rate=0.0, drop_path_rate=0.1)
426
+
427
+
428
+ #@BACKBONES.register_module()
429
+ class mit_b1(MixVisionTransformer):
430
+ def __init__(self, **kwargs):
431
+ super(mit_b1, self).__init__(
432
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
433
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
434
+ drop_rate=0.0, drop_path_rate=0.1)
435
+
436
+
437
+ #@BACKBONES.register_module()
438
+ class mit_b2(MixVisionTransformer):
439
+ def __init__(self, **kwargs):
440
+ super(mit_b2, self).__init__(
441
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
442
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
443
+ drop_rate=0.0, drop_path_rate=0.1)
444
+
445
+
446
+ #@BACKBONES.register_module()
447
+ class mit_b3(MixVisionTransformer):
448
+ def __init__(self, **kwargs):
449
+ super(mit_b3, self).__init__(
450
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
451
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
452
+ drop_rate=0.0, drop_path_rate=0.1)
453
+
454
+
455
+ #@BACKBONES.register_module()
456
+ class mit_b4(MixVisionTransformer):
457
+ def __init__(self, **kwargs):
458
+ super(mit_b4, self).__init__(
459
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
460
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
461
+ drop_rate=0.0, drop_path_rate=0.1)
462
+
463
+
464
+ #@BACKBONES.register_module()
465
+ class mit_b5(MixVisionTransformer):
466
+ def __init__(self, **kwargs):
467
+ super(mit_b5, self).__init__(
468
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
469
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
470
+ drop_rate=0.0, drop_path_rate=0.1)
471
+
472
+
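A quick shape check (not part of the commit) for the SegFormer-style backbone above: each mit_b* variant returns a 4-level feature pyramid at strides 4/8/16/32. The import path is assumed from this repo layout, timm must be installed, and the input size is illustrative.

import torch
from models.SpaTrackV2.models.depth_refiner.backbone import mit_b2  # assumed path

backbone = mit_b2().eval()
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = backbone(x)   # list of 4 feature maps
for f in feats:
    print(f.shape)
# Expected: (1, 64, 56, 56), (1, 128, 28, 28), (1, 320, 14, 14), (1, 512, 7, 7)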
models/SpaTrackV2/models/depth_refiner/decode_head.py ADDED
@@ -0,0 +1,619 @@
1
+ from abc import ABCMeta, abstractmethod
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ # from mmcv.cnn import normal_init
7
+ # from mmcv.runner import auto_fp16, force_fp32
8
+
9
+ # from mmseg.core import build_pixel_sampler
10
+ # from mmseg.ops import resize
11
+
12
+
13
+ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
14
+ """Base class for BaseDecodeHead.
15
+
16
+ Args:
17
+ in_channels (int|Sequence[int]): Input channels.
18
+ channels (int): Channels after modules, before conv_seg.
19
+ num_classes (int): Number of classes.
20
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
21
+ conv_cfg (dict|None): Config of conv layers. Default: None.
22
+ norm_cfg (dict|None): Config of norm layers. Default: None.
23
+ act_cfg (dict): Config of activation layers.
24
+ Default: dict(type='ReLU')
25
+ in_index (int|Sequence[int]): Input feature index. Default: -1
26
+ input_transform (str|None): Transformation type of input features.
27
+ Options: 'resize_concat', 'multiple_select', None.
28
+ 'resize_concat': Multiple feature maps will be resized to the
29
+ same size as the first one and then concatenated together.
30
+ Usually used in FCN head of HRNet.
31
+ 'multiple_select': Multiple feature maps will be bundled into
32
+ a list and passed into decode head.
33
+ None: Only one select feature map is allowed.
34
+ Default: None.
35
+ loss_decode (dict): Config of decode loss.
36
+ Default: dict(type='CrossEntropyLoss').
37
+ ignore_index (int | None): The label index to be ignored. When using
38
+ masked BCE loss, ignore_index should be set to None. Default: 255
39
+ sampler (dict|None): The config of segmentation map sampler.
40
+ Default: None.
41
+ align_corners (bool): align_corners argument of F.interpolate.
42
+ Default: False.
43
+ """
44
+
45
+ def __init__(self,
46
+ in_channels,
47
+ channels,
48
+ *,
49
+ num_classes,
50
+ dropout_ratio=0.1,
51
+ conv_cfg=None,
52
+ norm_cfg=None,
53
+ act_cfg=dict(type='ReLU'),
54
+ in_index=-1,
55
+ input_transform=None,
56
+ loss_decode=dict(
57
+ type='CrossEntropyLoss',
58
+ use_sigmoid=False,
59
+ loss_weight=1.0),
60
+ decoder_params=None,
61
+ ignore_index=255,
62
+ sampler=None,
63
+ align_corners=False):
64
+ super(BaseDecodeHead, self).__init__()
65
+ self._init_inputs(in_channels, in_index, input_transform)
66
+ self.channels = channels
67
+ self.num_classes = num_classes
68
+ self.dropout_ratio = dropout_ratio
69
+ self.conv_cfg = conv_cfg
70
+ self.norm_cfg = norm_cfg
71
+ self.act_cfg = act_cfg
72
+ self.in_index = in_index
73
+ self.ignore_index = ignore_index
74
+ self.align_corners = align_corners
75
+
76
+ if sampler is not None:
77
+ self.sampler = build_pixel_sampler(sampler, context=self)
78
+ else:
79
+ self.sampler = None
80
+
81
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
82
+ if dropout_ratio > 0:
83
+ self.dropout = nn.Dropout2d(dropout_ratio)
84
+ else:
85
+ self.dropout = None
86
+ self.fp16_enabled = False
87
+
88
+ def extra_repr(self):
89
+ """Extra repr."""
90
+ s = f'input_transform={self.input_transform}, ' \
91
+ f'ignore_index={self.ignore_index}, ' \
92
+ f'align_corners={self.align_corners}'
93
+ return s
94
+
95
+ def _init_inputs(self, in_channels, in_index, input_transform):
96
+ """Check and initialize input transforms.
97
+
98
+ The in_channels, in_index and input_transform must match.
99
+ Specifically, when input_transform is None, only single feature map
100
+ will be selected. So in_channels and in_index must be of type int.
101
+ When input_transform is not None, in_channels and in_index must be list or tuple, with the same length.
102
+
103
+ Args:
104
+ in_channels (int|Sequence[int]): Input channels.
105
+ in_index (int|Sequence[int]): Input feature index.
106
+ input_transform (str|None): Transformation type of input features.
107
+ Options: 'resize_concat', 'multiple_select', None.
108
+ 'resize_concat': Multiple feature maps will be resized to the
109
+ same size as the first one and then concatenated together.
110
+ Usually used in FCN head of HRNet.
111
+ 'multiple_select': Multiple feature maps will be bundled into
112
+ a list and passed into decode head.
113
+ None: Only one select feature map is allowed.
114
+ """
115
+
116
+ if input_transform is not None:
117
+ assert input_transform in ['resize_concat', 'multiple_select']
118
+ self.input_transform = input_transform
119
+ self.in_index = in_index
120
+ if input_transform is not None:
121
+ assert isinstance(in_channels, (list, tuple))
122
+ assert isinstance(in_index, (list, tuple))
123
+ assert len(in_channels) == len(in_index)
124
+ if input_transform == 'resize_concat':
125
+ self.in_channels = sum(in_channels)
126
+ else:
127
+ self.in_channels = in_channels
128
+ else:
129
+ assert isinstance(in_channels, int)
130
+ assert isinstance(in_index, int)
131
+ self.in_channels = in_channels
132
+
133
+ def init_weights(self):
134
+ """Initialize weights of classification layer."""
135
+ normal_init(self.conv_seg, mean=0, std=0.01)
136
+
137
+ def _transform_inputs(self, inputs):
138
+ """Transform inputs for decoder.
139
+
140
+ Args:
141
+ inputs (list[Tensor]): List of multi-level img features.
142
+
143
+ Returns:
144
+ Tensor: The transformed inputs
145
+ """
146
+
147
+ if self.input_transform == 'resize_concat':
148
+ inputs = [inputs[i] for i in self.in_index]
149
+ upsampled_inputs = [
150
+ resize(
151
+ input=x,
152
+ size=inputs[0].shape[2:],
153
+ mode='bilinear',
154
+ align_corners=self.align_corners) for x in inputs
155
+ ]
156
+ inputs = torch.cat(upsampled_inputs, dim=1)
157
+ elif self.input_transform == 'multiple_select':
158
+ inputs = [inputs[i] for i in self.in_index]
159
+ else:
160
+ inputs = inputs[self.in_index]
161
+
162
+ return inputs
163
+
164
+ # @auto_fp16()
165
+ @abstractmethod
166
+ def forward(self, inputs):
167
+ """Placeholder of forward function."""
168
+ pass
169
+
170
+ def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
171
+ """Forward function for training.
172
+ Args:
173
+ inputs (list[Tensor]): List of multi-level img features.
174
+ img_metas (list[dict]): List of image info dict where each dict
175
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
176
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
177
+ For details on the values of these keys see
178
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
179
+ gt_semantic_seg (Tensor): Semantic segmentation masks
180
+ used if the architecture supports semantic segmentation task.
181
+ train_cfg (dict): The training config.
182
+
183
+ Returns:
184
+ dict[str, Tensor]: a dictionary of loss components
185
+ """
186
+ seg_logits = self.forward(inputs)
187
+ losses = self.losses(seg_logits, gt_semantic_seg)
188
+ return losses
189
+
190
+ def forward_test(self, inputs, img_metas, test_cfg):
191
+ """Forward function for testing.
192
+
193
+ Args:
194
+ inputs (list[Tensor]): List of multi-level img features.
195
+ img_metas (list[dict]): List of image info dict where each dict
196
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
197
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
198
+ For details on the values of these keys see
199
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
200
+ test_cfg (dict): The testing config.
201
+
202
+ Returns:
203
+ Tensor: Output segmentation map.
204
+ """
205
+ return self.forward(inputs)
206
+
207
+ def cls_seg(self, feat):
208
+ """Classify each pixel."""
209
+ if self.dropout is not None:
210
+ feat = self.dropout(feat)
211
+ output = self.conv_seg(feat)
212
+ return output
213
+
214
+
215
+ class BaseDecodeHead_clips(nn.Module, metaclass=ABCMeta):
216
+ """Base class for BaseDecodeHead_clips.
217
+
218
+ Args:
219
+ in_channels (int|Sequence[int]): Input channels.
220
+ channels (int): Channels after modules, before conv_seg.
221
+ num_classes (int): Number of classes.
222
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
223
+ conv_cfg (dict|None): Config of conv layers. Default: None.
224
+ norm_cfg (dict|None): Config of norm layers. Default: None.
225
+ act_cfg (dict): Config of activation layers.
226
+ Default: dict(type='ReLU')
227
+ in_index (int|Sequence[int]): Input feature index. Default: -1
228
+ input_transform (str|None): Transformation type of input features.
229
+ Options: 'resize_concat', 'multiple_select', None.
230
+ 'resize_concat': Multiple feature maps will be resized to the
231
+ same size as the first one and then concatenated together.
232
+ Usually used in FCN head of HRNet.
233
+ 'multiple_select': Multiple feature maps will be bundled into
234
+ a list and passed into decode head.
235
+ None: Only one select feature map is allowed.
236
+ Default: None.
237
+ loss_decode (dict): Config of decode loss.
238
+ Default: dict(type='CrossEntropyLoss').
239
+ ignore_index (int | None): The label index to be ignored. When using
240
+ masked BCE loss, ignore_index should be set to None. Default: 255
241
+ sampler (dict|None): The config of segmentation map sampler.
242
+ Default: None.
243
+ align_corners (bool): align_corners argument of F.interpolate.
244
+ Default: False.
245
+ """
246
+
247
+ def __init__(self,
248
+ in_channels,
249
+ channels,
250
+ *,
251
+ num_classes,
252
+ dropout_ratio=0.1,
253
+ conv_cfg=None,
254
+ norm_cfg=None,
255
+ act_cfg=dict(type='ReLU'),
256
+ in_index=-1,
257
+ input_transform=None,
258
+ loss_decode=dict(
259
+ type='CrossEntropyLoss',
260
+ use_sigmoid=False,
261
+ loss_weight=1.0),
262
+ decoder_params=None,
263
+ ignore_index=255,
264
+ sampler=None,
265
+ align_corners=False,
266
+ num_clips=5):
267
+ super(BaseDecodeHead_clips, self).__init__()
268
+ self._init_inputs(in_channels, in_index, input_transform)
269
+ self.channels = channels
270
+ self.num_classes = num_classes
271
+ self.dropout_ratio = dropout_ratio
272
+ self.conv_cfg = conv_cfg
273
+ self.norm_cfg = norm_cfg
274
+ self.act_cfg = act_cfg
275
+ self.in_index = in_index
276
+ self.ignore_index = ignore_index
277
+ self.align_corners = align_corners
278
+ self.num_clips=num_clips
279
+
280
+ if sampler is not None:
281
+ self.sampler = build_pixel_sampler(sampler, context=self)
282
+ else:
283
+ self.sampler = None
284
+
285
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
286
+ if dropout_ratio > 0:
287
+ self.dropout = nn.Dropout2d(dropout_ratio)
288
+ else:
289
+ self.dropout = None
290
+ self.fp16_enabled = False
291
+
292
+ def extra_repr(self):
293
+ """Extra repr."""
294
+ s = f'input_transform={self.input_transform}, ' \
295
+ f'ignore_index={self.ignore_index}, ' \
296
+ f'align_corners={self.align_corners}'
297
+ return s
298
+
299
+ def _init_inputs(self, in_channels, in_index, input_transform):
300
+ """Check and initialize input transforms.
301
+
302
+ The in_channels, in_index and input_transform must match.
303
+ Specifically, when input_transform is None, only a single feature map
304
+ will be selected. So in_channels and in_index must be of type int.
305
+ When input_transform is not None, in_channels and in_index must be list or tuple, with the same length.
306
+
307
+ Args:
308
+ in_channels (int|Sequence[int]): Input channels.
309
+ in_index (int|Sequence[int]): Input feature index.
310
+ input_transform (str|None): Transformation type of input features.
311
+ Options: 'resize_concat', 'multiple_select', None.
312
+ 'resize_concat': Multiple feature maps will be resized to the
313
+ same size as the first one and then concatenated together.
314
+ Usually used in the FCN head of HRNet.
315
+ 'multiple_select': Multiple feature maps will be bundled into
316
+ a list and passed into the decode head.
317
+ None: Only one selected feature map is allowed.
318
+ """
319
+
320
+ if input_transform is not None:
321
+ assert input_transform in ['resize_concat', 'multiple_select']
322
+ self.input_transform = input_transform
323
+ self.in_index = in_index
324
+ if input_transform is not None:
325
+ assert isinstance(in_channels, (list, tuple))
326
+ assert isinstance(in_index, (list, tuple))
327
+ assert len(in_channels) == len(in_index)
328
+ if input_transform == 'resize_concat':
329
+ self.in_channels = sum(in_channels)
330
+ else:
331
+ self.in_channels = in_channels
332
+ else:
333
+ assert isinstance(in_channels, int)
334
+ assert isinstance(in_index, int)
335
+ self.in_channels = in_channels
336
+
337
+ def init_weights(self):
338
+ """Initialize weights of classification layer."""
339
+ normal_init(self.conv_seg, mean=0, std=0.01)
340
+
341
+ def _transform_inputs(self, inputs):
342
+ """Transform inputs for decoder.
343
+
344
+ Args:
345
+ inputs (list[Tensor]): List of multi-level img features.
346
+
347
+ Returns:
348
+ Tensor: The transformed inputs
349
+ """
350
+
351
+ if self.input_transform == 'resize_concat':
352
+ inputs = [inputs[i] for i in self.in_index]
353
+ upsampled_inputs = [
354
+ resize(
355
+ input=x,
356
+ size=inputs[0].shape[2:],
357
+ mode='bilinear',
358
+ align_corners=self.align_corners) for x in inputs
359
+ ]
360
+ inputs = torch.cat(upsampled_inputs, dim=1)
361
+ elif self.input_transform == 'multiple_select':
362
+ inputs = [inputs[i] for i in self.in_index]
363
+ else:
364
+ inputs = inputs[self.in_index]
365
+
366
+ return inputs
367
+
368
+ # @auto_fp16()
369
+ @abstractmethod
370
+ def forward(self, inputs):
371
+ """Placeholder of forward function."""
372
+ pass
373
+
374
+ def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg, batch_size, num_clips):
375
+ """Forward function for training.
376
+ Args:
377
+ inputs (list[Tensor]): List of multi-level img features.
378
+ img_metas (list[dict]): List of image info dict where each dict
379
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
380
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
381
+ For details on the values of these keys see
382
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
383
+ gt_semantic_seg (Tensor): Semantic segmentation masks
384
+ used if the architecture supports semantic segmentation task.
385
+ train_cfg (dict): The training config.
386
+
387
+ Returns:
388
+ dict[str, Tensor]: a dictionary of loss components
389
+ """
390
+ seg_logits = self.forward(inputs, batch_size, num_clips)
391
+ losses = self.losses(seg_logits, gt_semantic_seg)
392
+ return losses
393
+
394
+ def forward_test(self, inputs, img_metas, test_cfg, batch_size, num_clips):
395
+ """Forward function for testing.
396
+
397
+ Args:
398
+ inputs (list[Tensor]): List of multi-level img features.
399
+ img_metas (list[dict]): List of image info dict where each dict
400
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
401
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
402
+ For details on the values of these keys see
403
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
404
+ test_cfg (dict): The testing config.
405
+
406
+ Returns:
407
+ Tensor: Output segmentation map.
408
+ """
409
+ return self.forward(inputs, batch_size, num_clips)
410
+
411
+ def cls_seg(self, feat):
412
+ """Classify each pixel."""
413
+ if self.dropout is not None:
414
+ feat = self.dropout(feat)
415
+ output = self.conv_seg(feat)
416
+ return output
417
+
418
+ class BaseDecodeHead_clips_flow(nn.Module, metaclass=ABCMeta):
419
+ """Base class for BaseDecodeHead_clips_flow.
420
+
421
+ Args:
422
+ in_channels (int|Sequence[int]): Input channels.
423
+ channels (int): Channels after modules, before conv_seg.
424
+ num_classes (int): Number of classes.
425
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
426
+ conv_cfg (dict|None): Config of conv layers. Default: None.
427
+ norm_cfg (dict|None): Config of norm layers. Default: None.
428
+ act_cfg (dict): Config of activation layers.
429
+ Default: dict(type='ReLU')
430
+ in_index (int|Sequence[int]): Input feature index. Default: -1
431
+ input_transform (str|None): Transformation type of input features.
432
+ Options: 'resize_concat', 'multiple_select', None.
433
+ 'resize_concat': Multiple feature maps will be resized to the
434
+ same size as the first one and then concatenated together.
435
+ Usually used in the FCN head of HRNet.
436
+ 'multiple_select': Multiple feature maps will be bundled into
437
+ a list and passed into the decode head.
438
+ None: Only one selected feature map is allowed.
439
+ Default: None.
440
+ loss_decode (dict): Config of decode loss.
441
+ Default: dict(type='CrossEntropyLoss').
442
+ ignore_index (int | None): The label index to be ignored. When using
443
+ masked BCE loss, ignore_index should be set to None. Default: 255
444
+ sampler (dict|None): The config of segmentation map sampler.
445
+ Default: None.
446
+ align_corners (bool): align_corners argument of F.interpolate.
447
+ Default: False.
448
+ """
449
+
450
+ def __init__(self,
451
+ in_channels,
452
+ channels,
453
+ *,
454
+ num_classes,
455
+ dropout_ratio=0.1,
456
+ conv_cfg=None,
457
+ norm_cfg=None,
458
+ act_cfg=dict(type='ReLU'),
459
+ in_index=-1,
460
+ input_transform=None,
461
+ loss_decode=dict(
462
+ type='CrossEntropyLoss',
463
+ use_sigmoid=False,
464
+ loss_weight=1.0),
465
+ decoder_params=None,
466
+ ignore_index=255,
467
+ sampler=None,
468
+ align_corners=False,
469
+ num_clips=5):
470
+ super(BaseDecodeHead_clips_flow, self).__init__()
471
+ self._init_inputs(in_channels, in_index, input_transform)
472
+ self.channels = channels
473
+ self.num_classes = num_classes
474
+ self.dropout_ratio = dropout_ratio
475
+ self.conv_cfg = conv_cfg
476
+ self.norm_cfg = norm_cfg
477
+ self.act_cfg = act_cfg
478
+ self.in_index = in_index
479
+ self.ignore_index = ignore_index
480
+ self.align_corners = align_corners
481
+ self.num_clips=num_clips
482
+
483
+ if sampler is not None:
484
+ self.sampler = build_pixel_sampler(sampler, context=self)
485
+ else:
486
+ self.sampler = None
487
+
488
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
489
+ if dropout_ratio > 0:
490
+ self.dropout = nn.Dropout2d(dropout_ratio)
491
+ else:
492
+ self.dropout = None
493
+ self.fp16_enabled = False
494
+
495
+ def extra_repr(self):
496
+ """Extra repr."""
497
+ s = f'input_transform={self.input_transform}, ' \
498
+ f'ignore_index={self.ignore_index}, ' \
499
+ f'align_corners={self.align_corners}'
500
+ return s
501
+
502
+ def _init_inputs(self, in_channels, in_index, input_transform):
503
+ """Check and initialize input transforms.
504
+
505
+ The in_channels, in_index and input_transform must match.
506
+ Specifically, when input_transform is None, only a single feature map
507
+ will be selected. So in_channels and in_index must be of type int.
508
+ When input_transform is not None, in_channels and in_index must be list or tuple, with the same length.
509
+
510
+ Args:
511
+ in_channels (int|Sequence[int]): Input channels.
512
+ in_index (int|Sequence[int]): Input feature index.
513
+ input_transform (str|None): Transformation type of input features.
514
+ Options: 'resize_concat', 'multiple_select', None.
515
+ 'resize_concat': Multiple feature maps will be resized to the
516
+ same size as the first one and then concatenated together.
517
+ Usually used in the FCN head of HRNet.
518
+ 'multiple_select': Multiple feature maps will be bundled into
519
+ a list and passed into the decode head.
520
+ None: Only one selected feature map is allowed.
521
+ """
522
+
523
+ if input_transform is not None:
524
+ assert input_transform in ['resize_concat', 'multiple_select']
525
+ self.input_transform = input_transform
526
+ self.in_index = in_index
527
+ if input_transform is not None:
528
+ assert isinstance(in_channels, (list, tuple))
529
+ assert isinstance(in_index, (list, tuple))
530
+ assert len(in_channels) == len(in_index)
531
+ if input_transform == 'resize_concat':
532
+ self.in_channels = sum(in_channels)
533
+ else:
534
+ self.in_channels = in_channels
535
+ else:
536
+ assert isinstance(in_channels, int)
537
+ assert isinstance(in_index, int)
538
+ self.in_channels = in_channels
539
+
540
+ def init_weights(self):
541
+ """Initialize weights of classification layer."""
542
+ normal_init(self.conv_seg, mean=0, std=0.01)
543
+
544
+ def _transform_inputs(self, inputs):
545
+ """Transform inputs for decoder.
546
+
547
+ Args:
548
+ inputs (list[Tensor]): List of multi-level img features.
549
+
550
+ Returns:
551
+ Tensor: The transformed inputs
552
+ """
553
+
554
+ if self.input_transform == 'resize_concat':
555
+ inputs = [inputs[i] for i in self.in_index]
556
+ upsampled_inputs = [
557
+ resize(
558
+ input=x,
559
+ size=inputs[0].shape[2:],
560
+ mode='bilinear',
561
+ align_corners=self.align_corners) for x in inputs
562
+ ]
563
+ inputs = torch.cat(upsampled_inputs, dim=1)
564
+ elif self.input_transform == 'multiple_select':
565
+ inputs = [inputs[i] for i in self.in_index]
566
+ else:
567
+ inputs = inputs[self.in_index]
568
+
569
+ return inputs
570
+
571
+ # @auto_fp16()
572
+ @abstractmethod
573
+ def forward(self, inputs):
574
+ """Placeholder of forward function."""
575
+ pass
576
+
577
+ def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg, batch_size, num_clips, img=None):
578
+ """Forward function for training.
579
+ Args:
580
+ inputs (list[Tensor]): List of multi-level img features.
581
+ img_metas (list[dict]): List of image info dict where each dict
582
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
583
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
584
+ For details on the values of these keys see
585
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
586
+ gt_semantic_seg (Tensor): Semantic segmentation masks
587
+ used if the architecture supports semantic segmentation task.
588
+ train_cfg (dict): The training config.
589
+
590
+ Returns:
591
+ dict[str, Tensor]: a dictionary of loss components
592
+ """
593
+ seg_logits = self.forward(inputs, batch_size, num_clips, img)
594
+ losses = self.losses(seg_logits, gt_semantic_seg)
595
+ return losses
596
+
597
+ def forward_test(self, inputs, img_metas, test_cfg, batch_size=None, num_clips=None, img=None):
598
+ """Forward function for testing.
599
+
600
+ Args:
601
+ inputs (list[Tensor]): List of multi-level img features.
602
+ img_metas (list[dict]): List of image info dict where each dict
603
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
604
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
605
+ For details on the values of these keys see
606
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
607
+ test_cfg (dict): The testing config.
608
+
609
+ Returns:
610
+ Tensor: Output segmentation map.
611
+ """
612
+ return self.forward(inputs, batch_size, num_clips, img)
613
+
614
+ def cls_seg(self, feat):
615
+ """Classify each pixel."""
616
+ if self.dropout is not None:
617
+ feat = self.dropout(feat)
618
+ output = self.conv_seg(feat)
619
+ return output
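
A minimal standalone sketch of the three `input_transform` modes documented above ('resize_concat', 'multiple_select', None), mirroring `_transform_inputs` in plain PyTorch with `F.interpolate` standing in for mmseg's `resize` helper; the feature shapes are illustrative only.

import torch
import torch.nn.functional as F

def transform_inputs(inputs, in_index, input_transform, align_corners=False):
    # Select and combine multi-level features the way _transform_inputs does.
    if input_transform == 'resize_concat':
        selected = [inputs[i] for i in in_index]
        upsampled = [F.interpolate(x, size=selected[0].shape[2:], mode='bilinear',
                                   align_corners=align_corners) for x in selected]
        return torch.cat(upsampled, dim=1)       # effective in_channels = sum(in_channels)
    elif input_transform == 'multiple_select':
        return [inputs[i] for i in in_index]     # kept as a list of feature maps
    else:
        return inputs[in_index]                  # a single selected feature map

feats = [torch.randn(1, c, s, s) for c, s in [(64, 64), (128, 32), (320, 16), (512, 8)]]
print(transform_inputs(feats, [0, 1, 2, 3], 'resize_concat').shape)  # torch.Size([1, 1024, 64, 64])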
models/SpaTrackV2/models/depth_refiner/depth_refiner.py ADDED
@@ -0,0 +1,115 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from models.monoD.depth_anything_v2.dinov2_layers.patch_embed import PatchEmbed
5
+ from models.SpaTrackV2.models.depth_refiner.backbone import mit_b3
6
+ from models.SpaTrackV2.models.depth_refiner.stablizer import Stabilization_Network_Cross_Attention
7
+ from einops import rearrange
8
+ class TrackStablizer(nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ self.backbone = mit_b3()
13
+
14
+ old_conv = self.backbone.patch_embed1.proj
15
+ new_conv = nn.Conv2d(old_conv.in_channels + 4, old_conv.out_channels, kernel_size=old_conv.kernel_size, stride=old_conv.stride, padding=old_conv.padding)
16
+
17
+ new_conv.weight[:, :3, :, :].data.copy_(old_conv.weight.clone())
18
+ self.backbone.patch_embed1.proj = new_conv
19
+
20
+ self.Track_Stabilizer = Stabilization_Network_Cross_Attention(in_channels=[64, 128, 320, 512],
21
+ in_index=[0, 1, 2, 3],
22
+ feature_strides=[4, 8, 16, 32],
23
+ channels=128,
24
+ dropout_ratio=0.1,
25
+ num_classes=1,
26
+ align_corners=False,
27
+ decoder_params=dict(embed_dim=256, depths=4),
28
+ num_clips=16,
29
+ norm_cfg = dict(type='SyncBN', requires_grad=True))
30
+
31
+ self.edge_conv = nn.Sequential(nn.Conv2d(in_channels=4, out_channels=64, kernel_size=3, padding=1, stride=1, bias=True),\
32
+ nn.ReLU(inplace=True))
33
+ self.edge_conv1 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=2, bias=True),\
34
+ nn.ReLU(inplace=True))
35
+ self.success = False
36
+ self.x = None
37
+
38
+ def buffer_forward(self, inputs, num_clips=16):
39
+ """
40
+ Run the backbone once on the clip, cache its features in self.x, and return the scale/shift estimated by the stabilizer.
41
+ """
42
+ B, T, C, H, W = inputs.shape
43
+ self.x = self.backbone(inputs)
44
+ scale, shift = self.Track_Stabilizer.buffer_forward(self.x, num_clips=num_clips)
45
+ self.success = True
46
+ return scale, shift
47
+
48
+ def forward(self, inputs, tracks, tracks_uvd, num_clips=16, imgs=None, vis_track=None):
49
+
50
+ """
51
+ Args:
52
+ inputs: [B, T, C, H, W], RGB + PointMap + Mask
53
+ tracks: [B, T, N, 4], 3D tracks in camera coordinate + visibility
54
+ num_clips: int, number of clips to use
55
+ """
56
+ B, T, C, H, W = inputs.shape
57
+ edge_feat = self.edge_conv(inputs.view(B*T,4,H,W))
58
+ edge_feat1 = self.edge_conv1(edge_feat)
59
+
60
+ if not self.success:
61
+ scale, shift = self.Track_Stabilizer.buffer_forward(self.x,num_clips=num_clips)
62
+ self.success = True
63
+ update = self.Track_Stabilizer(self.x,edge_feat,edge_feat1,tracks,tracks_uvd,num_clips=num_clips, imgs=imgs, vis_track=vis_track)
64
+ else:
65
+ update = self.Track_Stabilizer(self.x,edge_feat,edge_feat1,tracks,tracks_uvd,num_clips=num_clips, imgs=imgs, vis_track=vis_track)
66
+
67
+ return update
68
+
69
+ def reset_success(self):
70
+ self.success = False
71
+ self.x = None
72
+ self.Track_Stabilizer.reset_success()
73
+
74
+
75
+ if __name__ == "__main__":
76
+ # Create test input tensors
77
+ batch_size = 1
78
+ seq_len = 16
79
+ channels = 7 # 3 for RGB + 3 for PointMap + 1 for Mask
80
+ height = 384
81
+ width = 512
82
+
83
+ # Create random input tensor with shape [B, T, C, H, W]
84
+ inputs = torch.randn(batch_size, seq_len, channels, height, width)
85
+
86
+ # Create random tracks
87
+ tracks = torch.randn(batch_size, seq_len, 1024, 4)
88
+
89
+ # Create random test images
90
+ test_imgs = torch.randn(batch_size, seq_len, 3, height, width)
91
+
92
+ # Initialize model and move to GPU
93
+ model = TrackStablizer().cuda()
94
+
95
+ # Move inputs to GPU and run forward pass
96
+ inputs = inputs.cuda()
97
+ tracks = tracks.cuda()
98
+ outputs = model.buffer_forward(inputs, num_clips=seq_len)
99
+ import time
100
+ start_time = time.time()
101
+ outputs = model(inputs, tracks, tracks[..., :3], num_clips=seq_len)  # forward() also requires tracks_uvd; the first three track channels are used as a stand-in here (shape assumed for this smoke test)
102
+ end_time = time.time()
103
+ print(f"Time taken: {end_time - start_time} seconds")
104
+ import pdb; pdb.set_trace()
105
+ # # Print shapes for verification
106
+ # print(f"Input shape: {inputs.shape}")
107
+ # print(f"Output shape: {outputs.shape}")
108
+
109
+ # # Basic tests
110
+ # assert outputs.shape[0] == batch_size, "Batch size mismatch"
111
+ # assert len(outputs.shape) == 4, "Output should be 4D: [B,C,H,W]"
112
+ # assert torch.all(outputs >= 0), "Output should be non-negative after ReLU"
113
+
114
+ # print("All tests passed!")
115
+
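
A toy sketch of the caching pattern that `TrackStablizer` above uses: `buffer_forward` runs the backbone once per clip and stores the features, `forward` reuses them while `self.success` is set, and `reset_success` clears the cache between videos. `TinyRefiner` and its two layers are stand-ins for illustration, not the real backbone or stabilization network.

import torch
import torch.nn as nn

class TinyRefiner(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Conv2d(7, 8, 3, padding=1)   # stand-in for mit_b3
        self.head = nn.Conv2d(8, 1, 1)                  # stand-in for the stabilization network
        self.success = False
        self.x = None

    def buffer_forward(self, inputs):
        B, T, C, H, W = inputs.shape
        self.x = self.backbone(inputs.view(B * T, C, H, W))  # run the backbone once, cache features
        self.success = True
        return self.x.mean(), self.x.std()                   # stand-in for the predicted scale/shift

    def forward(self, inputs):
        if not self.success:              # build the cache lazily if buffer_forward was never called
            self.buffer_forward(inputs)
        return self.head(self.x)          # reuse the cached backbone features on every later call

    def reset_success(self):              # call between videos, as TrackStablizer.reset_success does
        self.success = False
        self.x = None

model = TinyRefiner()
clip = torch.randn(1, 4, 7, 32, 32)       # B, T, C (RGB + PointMap + Mask), H, W
model.buffer_forward(clip)
out = model(clip)                          # no second backbone pass
model.reset_success()                      # the next clip recomputes the cache
print(out.shape)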
models/SpaTrackV2/models/depth_refiner/network.py ADDED
@@ -0,0 +1,429 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ '''
4
+ Author: Ke Xian
5
6
+ Date: 2020/07/20
7
+ '''
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.init as init
12
+
13
+ # ==============================================================================================================
14
+
15
+ class FTB(nn.Module):
16
+ def __init__(self, inchannels, midchannels=512):
17
+ super(FTB, self).__init__()
18
+ self.in1 = inchannels
19
+ self.mid = midchannels
20
+
21
+ self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True)
22
+ self.conv_branch = nn.Sequential(nn.ReLU(inplace=True),\
23
+ nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True),\
24
+ #nn.BatchNorm2d(num_features=self.mid),\
25
+ nn.ReLU(inplace=True),\
26
+ nn.Conv2d(in_channels=self.mid, out_channels= self.mid, kernel_size=3, padding=1, stride=1, bias=True))
27
+ self.relu = nn.ReLU(inplace=True)
28
+
29
+ self.init_params()
30
+
31
+ def forward(self, x):
32
+ x = self.conv1(x)
33
+ x = x + self.conv_branch(x)
34
+ x = self.relu(x)
35
+
36
+ return x
37
+
38
+ def init_params(self):
39
+ for m in self.modules():
40
+ if isinstance(m, nn.Conv2d):
41
+ #init.kaiming_normal_(m.weight, mode='fan_out')
42
+ init.normal_(m.weight, std=0.01)
43
+ # init.xavier_normal_(m.weight)
44
+ if m.bias is not None:
45
+ init.constant_(m.bias, 0)
46
+ elif isinstance(m, nn.ConvTranspose2d):
47
+ #init.kaiming_normal_(m.weight, mode='fan_out')
48
+ init.normal_(m.weight, std=0.01)
49
+ # init.xavier_normal_(m.weight)
50
+ if m.bias is not None:
51
+ init.constant_(m.bias, 0)
52
+ elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
53
+ init.constant_(m.weight, 1)
54
+ init.constant_(m.bias, 0)
55
+ elif isinstance(m, nn.Linear):
56
+ init.normal_(m.weight, std=0.01)
57
+ if m.bias is not None:
58
+ init.constant_(m.bias, 0)
59
+
60
+ class ATA(nn.Module):
61
+ def __init__(self, inchannels, reduction = 8):
62
+ super(ATA, self).__init__()
63
+ self.inchannels = inchannels
64
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
65
+ self.fc = nn.Sequential(nn.Linear(self.inchannels*2, self.inchannels // reduction),
66
+ nn.ReLU(inplace=True),
67
+ nn.Linear(self.inchannels // reduction, self.inchannels),
68
+ nn.Sigmoid())
69
+ self.init_params()
70
+
71
+ def forward(self, low_x, high_x):
72
+ n, c, _, _ = low_x.size()
73
+ x = torch.cat([low_x, high_x], 1)
74
+ x = self.avg_pool(x)
75
+ x = x.view(n, -1)
76
+ x = self.fc(x).view(n,c,1,1)
77
+ x = low_x * x + high_x
78
+
79
+ return x
80
+
81
+ def init_params(self):
82
+ for m in self.modules():
83
+ if isinstance(m, nn.Conv2d):
84
+ #init.kaiming_normal_(m.weight, mode='fan_out')
85
+ #init.normal(m.weight, std=0.01)
86
+ init.xavier_normal_(m.weight)
87
+ if m.bias is not None:
88
+ init.constant_(m.bias, 0)
89
+ elif isinstance(m, nn.ConvTranspose2d):
90
+ #init.kaiming_normal_(m.weight, mode='fan_out')
91
+ #init.normal_(m.weight, std=0.01)
92
+ init.xavier_normal_(m.weight)
93
+ if m.bias is not None:
94
+ init.constant_(m.bias, 0)
95
+ elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
96
+ init.constant_(m.weight, 1)
97
+ init.constant_(m.bias, 0)
98
+ elif isinstance(m, nn.Linear):
99
+ init.normal_(m.weight, std=0.01)
100
+ if m.bias is not None:
101
+ init.constant_(m.bias, 0)
102
+
103
+
104
+ class FFM(nn.Module):
105
+ def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
106
+ super(FFM, self).__init__()
107
+ self.inchannels = inchannels
108
+ self.midchannels = midchannels
109
+ self.outchannels = outchannels
110
+ self.upfactor = upfactor
111
+
112
+ self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels)
113
+ self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
114
+
115
+ self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
116
+
117
+ self.init_params()
118
+ #self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
119
+ #self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
120
+ #self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
121
+
122
+ def forward(self, low_x, high_x):
123
+ x = self.ftb1(low_x)
124
+
125
+ '''
126
+ x = torch.cat((x,high_x),1)
127
+ if x.shape[2] == 12:
128
+ x = self.p1(x)
129
+ elif x.shape[2] == 24:
130
+ x = self.p2(x)
131
+ elif x.shape[2] == 48:
132
+ x = self.p3(x)
133
+ '''
134
+ x = x + high_x ###high_x
135
+ x = self.ftb2(x)
136
+ x = self.upsample(x)
137
+
138
+ return x
139
+
140
+ def init_params(self):
141
+ for m in self.modules():
142
+ if isinstance(m, nn.Conv2d):
143
+ #init.kaiming_normal_(m.weight, mode='fan_out')
144
+ init.normal_(m.weight, std=0.01)
145
+ #init.xavier_normal_(m.weight)
146
+ if m.bias is not None:
147
+ init.constant_(m.bias, 0)
148
+ elif isinstance(m, nn.ConvTranspose2d):
149
+ #init.kaiming_normal_(m.weight, mode='fan_out')
150
+ init.normal_(m.weight, std=0.01)
151
+ #init.xavier_normal_(m.weight)
152
+ if m.bias is not None:
153
+ init.constant_(m.bias, 0)
154
+ elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
155
+ init.constant_(m.weight, 1)
156
+ init.constant_(m.bias, 0)
157
+ elif isinstance(m, nn.Linear):
158
+ init.normal_(m.weight, std=0.01)
159
+ if m.bias is not None:
160
+ init.constant_(m.bias, 0)
161
+
162
+
163
+
164
+ class noFFM(nn.Module):
165
+ def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
166
+ super(noFFM, self).__init__()
167
+ self.inchannels = inchannels
168
+ self.midchannels = midchannels
169
+ self.outchannels = outchannels
170
+ self.upfactor = upfactor
171
+
172
+ self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
173
+
174
+ self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
175
+
176
+ self.init_params()
177
+ #self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
178
+ #self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
179
+ #self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
180
+
181
+ def forward(self, low_x, high_x):
182
+
183
+ #x = self.ftb1(low_x)
184
+ x = high_x ###high_x
185
+ x = self.ftb2(x)
186
+ x = self.upsample(x)
187
+
188
+ return x
189
+
190
+ def init_params(self):
191
+ for m in self.modules():
192
+ if isinstance(m, nn.Conv2d):
193
+ #init.kaiming_normal_(m.weight, mode='fan_out')
194
+ init.normal_(m.weight, std=0.01)
195
+ #init.xavier_normal_(m.weight)
196
+ if m.bias is not None:
197
+ init.constant_(m.bias, 0)
198
+ elif isinstance(m, nn.ConvTranspose2d):
199
+ #init.kaiming_normal_(m.weight, mode='fan_out')
200
+ init.normal_(m.weight, std=0.01)
201
+ #init.xavier_normal_(m.weight)
202
+ if m.bias is not None:
203
+ init.constant_(m.bias, 0)
204
+ elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
205
+ init.constant_(m.weight, 1)
206
+ init.constant_(m.bias, 0)
207
+ elif isinstance(m, nn.Linear):
208
+ init.normal_(m.weight, std=0.01)
209
+ if m.bias is not None:
210
+ init.constant_(m.bias, 0)
211
+
212
+
213
+
214
+
215
+ class AO(nn.Module):
216
+ # Adaptive output module
217
+ def __init__(self, inchannels, outchannels, upfactor=2):
218
+ super(AO, self).__init__()
219
+ self.inchannels = inchannels
220
+ self.outchannels = outchannels
221
+ self.upfactor = upfactor
222
+
223
+ """
224
+ self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\
225
+ nn.BatchNorm2d(num_features=self.inchannels//2),\
226
+ nn.ReLU(inplace=True),\
227
+ nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=3, padding=1, stride=1, bias=True),\
228
+ nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True) )#,\
229
+ #nn.ReLU(inplace=True)) ## get positive values
230
+ """
231
+ self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\
232
+ #nn.BatchNorm2d(num_features=self.inchannels//2),\
233
+ nn.ReLU(inplace=True),\
234
+ nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True), \
235
+ nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=1, padding=0, stride=1))
236
+
237
+ #nn.ReLU(inplace=True)) ## get positive values
238
+
239
+ self.init_params()
240
+
241
+ def forward(self, x):
242
+ x = self.adapt_conv(x)
243
+ return x
244
+
245
+ def init_params(self):
246
+ for m in self.modules():
247
+ if isinstance(m, nn.Conv2d):
248
+ #init.kaiming_normal_(m.weight, mode='fan_out')
249
+ init.normal_(m.weight, std=0.01)
250
+ #init.xavier_normal_(m.weight)
251
+ if m.bias is not None:
252
+ init.constant_(m.bias, 0)
253
+ elif isinstance(m, nn.ConvTranspose2d):
254
+ #init.kaiming_normal_(m.weight, mode='fan_out')
255
+ init.normal_(m.weight, std=0.01)
256
+ #init.xavier_normal_(m.weight)
257
+ if m.bias is not None:
258
+ init.constant_(m.bias, 0)
259
+ elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
260
+ init.constant_(m.weight, 1)
261
+ init.constant_(m.bias, 0)
262
+ elif isinstance(m, nn.Linear):
263
+ init.normal_(m.weight, std=0.01)
264
+ if m.bias is not None:
265
+ init.constant_(m.bias, 0)
266
+
267
+ class ASPP(nn.Module):
268
+ def __init__(self, inchannels=256, planes=128, rates = [1, 6, 12, 18]):
269
+ super(ASPP, self).__init__()
270
+ self.inchannels = inchannels
271
+ self.planes = planes
272
+ self.rates = rates
273
+ self.kernel_sizes = []
274
+ self.paddings = []
275
+ for rate in self.rates:
276
+ if rate == 1:
277
+ self.kernel_sizes.append(1)
278
+ self.paddings.append(0)
279
+ else:
280
+ self.kernel_sizes.append(3)
281
+ self.paddings.append(rate)
282
+ self.atrous_0 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[0],
283
+ stride=1, padding=self.paddings[0], dilation=self.rates[0], bias=True),
284
+ nn.ReLU(inplace=True),
285
+ nn.BatchNorm2d(num_features=self.planes)
286
+ )
287
+ self.atrous_1 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[1],
288
+ stride=1, padding=self.paddings[1], dilation=self.rates[1], bias=True),
289
+ nn.ReLU(inplace=True),
290
+ nn.BatchNorm2d(num_features=self.planes),
291
+ )
292
+ self.atrous_2 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[2],
293
+ stride=1, padding=self.paddings[2], dilation=self.rates[2], bias=True),
294
+ nn.ReLU(inplace=True),
295
+ nn.BatchNorm2d(num_features=self.planes),
296
+ )
297
+ self.atrous_3 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[3],
298
+ stride=1, padding=self.paddings[3], dilation=self.rates[3], bias=True),
299
+ nn.ReLU(inplace=True),
300
+ nn.BatchNorm2d(num_features=self.planes),
301
+ )
302
+
303
+ #self.conv = nn.Conv2d(in_channels=self.planes * 4, out_channels=self.inchannels, kernel_size=3, padding=1, stride=1, bias=True)
304
+ def forward(self, x):
305
+ x = torch.cat([self.atrous_0(x), self.atrous_1(x), self.atrous_2(x), self.atrous_3(x)],1)
306
+ #x = self.conv(x)
307
+
308
+ return x
309
+
310
+ # ==============================================================================================================
311
+
312
+
313
+ class ResidualConv(nn.Module):
314
+ def __init__(self, inchannels):
315
+ super(ResidualConv, self).__init__()
316
+ #nn.BatchNorm2d
317
+ self.conv = nn.Sequential(
318
+ #nn.BatchNorm2d(num_features=inchannels),
319
+ nn.ReLU(inplace=False),
320
+ #nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True),
321
+ #nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True)
322
+ nn.Conv2d(in_channels=inchannels, out_channels=inchannels//2, kernel_size=3, padding=1, stride=1, bias=False),
323
+ nn.BatchNorm2d(num_features=inchannels//2),
324
+ nn.ReLU(inplace=False),
325
+ nn.Conv2d(in_channels=inchannels//2, out_channels=inchannels, kernel_size=3, padding=1, stride=1, bias=False)
326
+ )
327
+ self.init_params()
328
+
329
+ def forward(self, x):
330
+ x = self.conv(x)+x
331
+ return x
332
+
333
+ def init_params(self):
334
+ for m in self.modules():
335
+ if isinstance(m, nn.Conv2d):
336
+ #init.kaiming_normal_(m.weight, mode='fan_out')
337
+ init.normal_(m.weight, std=0.01)
338
+ #init.xavier_normal_(m.weight)
339
+ if m.bias is not None:
340
+ init.constant_(m.bias, 0)
341
+ elif isinstance(m, nn.ConvTranspose2d):
342
+ #init.kaiming_normal_(m.weight, mode='fan_out')
343
+ init.normal_(m.weight, std=0.01)
344
+ #init.xavier_normal_(m.weight)
345
+ if m.bias is not None:
346
+ init.constant_(m.bias, 0)
347
+ elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
348
+ init.constant_(m.weight, 1)
349
+ init.constant_(m.bias, 0)
350
+ elif isinstance(m, nn.Linear):
351
+ init.normal_(m.weight, std=0.01)
352
+ if m.bias is not None:
353
+ init.constant_(m.bias, 0)
354
+
355
+
356
+ class FeatureFusion(nn.Module):
357
+ def __init__(self, inchannels, outchannels):
358
+ super(FeatureFusion, self).__init__()
359
+ self.conv = ResidualConv(inchannels=inchannels)
360
+ #nn.BatchNorm2d
361
+ self.up = nn.Sequential(ResidualConv(inchannels=inchannels),
362
+ nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3,stride=2, padding=1, output_padding=1),
363
+ nn.BatchNorm2d(num_features=outchannels),
364
+ nn.ReLU(inplace=True))
365
+
366
+ def forward(self, lowfeat, highfeat):
367
+ return self.up(highfeat + self.conv(lowfeat))
368
+
369
+ def init_params(self):
370
+ for m in self.modules():
371
+ if isinstance(m, nn.Conv2d):
372
+ #init.kaiming_normal_(m.weight, mode='fan_out')
373
+ init.normal_(m.weight, std=0.01)
374
+ #init.xavier_normal_(m.weight)
375
+ if m.bias is not None:
376
+ init.constant_(m.bias, 0)
377
+ elif isinstance(m, nn.ConvTranspose2d):
378
+ #init.kaiming_normal_(m.weight, mode='fan_out')
379
+ init.normal_(m.weight, std=0.01)
380
+ #init.xavier_normal_(m.weight)
381
+ if m.bias is not None:
382
+ init.constant_(m.bias, 0)
383
+ elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
384
+ init.constant_(m.weight, 1)
385
+ init.constant_(m.bias, 0)
386
+ elif isinstance(m, nn.Linear):
387
+ init.normal_(m.weight, std=0.01)
388
+ if m.bias is not None:
389
+ init.constant_(m.bias, 0)
390
+
391
+
392
+ class SenceUnderstand(nn.Module):
393
+ def __init__(self, channels):
394
+ super(SenceUnderstand, self).__init__()
395
+ self.channels = channels
396
+ self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
397
+ nn.ReLU(inplace = True))
398
+ self.pool = nn.AdaptiveAvgPool2d(8)
399
+ self.fc = nn.Sequential(nn.Linear(512*8*8, self.channels),
400
+ nn.ReLU(inplace = True))
401
+ self.conv2 = nn.Sequential(nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0),
402
+ nn.ReLU(inplace=True))
403
+ self.initial_params()
404
+
405
+ def forward(self, x):
406
+ n,c,h,w = x.size()
407
+ x = self.conv1(x)
408
+ x = self.pool(x)
409
+ x = x.view(n,-1)
410
+ x = self.fc(x)
411
+ x = x.view(n, self.channels, 1, 1)
412
+ x = self.conv2(x)
413
+ x = x.repeat(1,1,h,w)
414
+ return x
415
+
416
+ def initial_params(self, dev=0.01):
417
+ for m in self.modules():
418
+ if isinstance(m, nn.Conv2d):
419
+ #print torch.sum(m.weight)
420
+ m.weight.data.normal_(0, dev)
421
+ if m.bias is not None:
422
+ m.bias.data.fill_(0)
423
+ elif isinstance(m, nn.ConvTranspose2d):
424
+ #print torch.sum(m.weight)
425
+ m.weight.data.normal_(0, dev)
426
+ if m.bias is not None:
427
+ m.bias.data.fill_(0)
428
+ elif isinstance(m, nn.Linear):
429
+ m.weight.data.normal_(0, dev)
models/SpaTrackV2/models/depth_refiner/stablilization_attention.py ADDED
@@ -0,0 +1,1187 @@
1
+ import math
2
+ import time
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.utils.checkpoint as checkpoint
7
+ from timm.layers import DropPath, to_2tuple, trunc_normal_
8
+ from einops import rearrange
9
+
10
+ class Mlp(nn.Module):
11
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
12
+ super().__init__()
13
+ out_features = out_features or in_features
14
+ hidden_features = hidden_features or in_features
15
+ self.fc1 = nn.Linear(in_features, hidden_features)
16
+ self.act = act_layer()
17
+ self.fc2 = nn.Linear(hidden_features, out_features)
18
+ self.drop = nn.Dropout(drop)
19
+
20
+ def forward(self, x):
21
+ x = self.fc1(x)
22
+ x = self.act(x)
23
+ x = self.drop(x)
24
+ x = self.fc2(x)
25
+ x = self.drop(x)
26
+ return x
27
+
28
+
29
+ def window_partition(x, window_size):
30
+ """
31
+ Args:
32
+ x: (B, H, W, C)
33
+ window_size (int): window size
34
+
35
+ Returns:
36
+ windows: (num_windows*B, window_size, window_size, C)
37
+ """
38
+ B, H, W, C = x.shape
39
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
40
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
41
+ return windows
42
+
43
+ def window_partition_noreshape(x, window_size):
44
+ """
45
+ Args:
46
+ x: (B, H, W, C)
47
+ window_size (int): window size
48
+
49
+ Returns:
50
+ windows: (B, num_windows_h, num_windows_w, window_size, window_size, C)
51
+ """
52
+ B, H, W, C = x.shape
53
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
54
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
55
+ return windows
56
+
57
+ def window_reverse(windows, window_size, H, W):
58
+ """
59
+ Args:
60
+ windows: (num_windows*B, window_size, window_size, C)
61
+ window_size (int): Window size
62
+ H (int): Height of image
63
+ W (int): Width of image
64
+
65
+ Returns:
66
+ x: (B, H, W, C)
67
+ """
68
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
69
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
70
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
71
+ return x
72
+
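
A minimal round-trip check of `window_partition` / `window_reverse` (the two functions are repeated verbatim so the snippet runs on its own): when H and W are divisible by `window_size`, `window_reverse` exactly inverts `window_partition`, which is what the window attention below relies on.

import torch

def window_partition(x, window_size):
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)

def window_reverse(windows, window_size, H, W):
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)

x = torch.randn(2, 8, 8, 16)     # B, H, W, C with H, W multiples of window_size
w = window_partition(x, 4)       # -> (2 * 2 * 2, 4, 4, 16): four windows per image
y = window_reverse(w, 4, 8, 8)   # -> (2, 8, 8, 16)
assert w.shape == (8, 4, 4, 16) and torch.equal(x, y)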
73
+ def get_roll_masks(H, W, window_size, shift_size):
74
+ #####################################
75
+ # move to top-left
76
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
77
+ h_slices = (slice(0, H-window_size),
78
+ slice(H-window_size, H-shift_size),
79
+ slice(H-shift_size, H))
80
+ w_slices = (slice(0, W-window_size),
81
+ slice(W-window_size, W-shift_size),
82
+ slice(W-shift_size, W))
83
+ cnt = 0
84
+ for h in h_slices:
85
+ for w in w_slices:
86
+ img_mask[:, h, w, :] = cnt
87
+ cnt += 1
88
+
89
+ mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
90
+ mask_windows = mask_windows.view(-1, window_size * window_size)
91
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
92
+ attn_mask_tl = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
93
+
94
+ ####################################
95
+ # move to top right
96
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
97
+ h_slices = (slice(0, H-window_size),
98
+ slice(H-window_size, H-shift_size),
99
+ slice(H-shift_size, H))
100
+ w_slices = (slice(0, shift_size),
101
+ slice(shift_size, window_size),
102
+ slice(window_size, W))
103
+ cnt = 0
104
+ for h in h_slices:
105
+ for w in w_slices:
106
+ img_mask[:, h, w, :] = cnt
107
+ cnt += 1
108
+
109
+ mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
110
+ mask_windows = mask_windows.view(-1, window_size * window_size)
111
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
112
+ attn_mask_tr = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
113
+
114
+ ####################################
115
+ # move to bottom left
116
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
117
+ h_slices = (slice(0, shift_size),
118
+ slice(shift_size, window_size),
119
+ slice(window_size, H))
120
+ w_slices = (slice(0, W-window_size),
121
+ slice(W-window_size, W-shift_size),
122
+ slice(W-shift_size, W))
123
+ cnt = 0
124
+ for h in h_slices:
125
+ for w in w_slices:
126
+ img_mask[:, h, w, :] = cnt
127
+ cnt += 1
128
+
129
+ mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
130
+ mask_windows = mask_windows.view(-1, window_size * window_size)
131
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
132
+ attn_mask_bl = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
133
+
134
+ ####################################
135
+ # move to bottom right
136
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
137
+ h_slices = (slice(0, shift_size),
138
+ slice(shift_size, window_size),
139
+ slice(window_size, H))
140
+ w_slices = (slice(0, shift_size),
141
+ slice(shift_size, window_size),
142
+ slice(window_size, W))
143
+ cnt = 0
144
+ for h in h_slices:
145
+ for w in w_slices:
146
+ img_mask[:, h, w, :] = cnt
147
+ cnt += 1
148
+
149
+ mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
150
+ mask_windows = mask_windows.view(-1, window_size * window_size)
151
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
152
+ attn_mask_br = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
153
+
154
+ # append all
155
+ attn_mask_all = torch.cat((attn_mask_tl, attn_mask_tr, attn_mask_bl, attn_mask_br), -1)
156
+ return attn_mask_all
157
+
158
+ def get_relative_position_index(q_windows, k_windows):
159
+ """
160
+ Args:
161
+ q_windows: tuple (query_window_height, query_window_width)
162
+ k_windows: tuple (key_window_height, key_window_width)
163
+
164
+ Returns:
165
+ relative_position_index: query_window_height*query_window_width, key_window_height*key_window_width
166
+ """
167
+ # get pair-wise relative position index for each token inside the window
168
+ coords_h_q = torch.arange(q_windows[0])
169
+ coords_w_q = torch.arange(q_windows[1])
170
+ coords_q = torch.stack(torch.meshgrid([coords_h_q, coords_w_q])) # 2, Wh_q, Ww_q
171
+
172
+ coords_h_k = torch.arange(k_windows[0])
173
+ coords_w_k = torch.arange(k_windows[1])
174
+ coords_k = torch.stack(torch.meshgrid([coords_h_k, coords_w_k])) # 2, Wh, Ww
175
+
176
+ coords_flatten_q = torch.flatten(coords_q, 1) # 2, Wh_q*Ww_q
177
+ coords_flatten_k = torch.flatten(coords_k, 1) # 2, Wh_k*Ww_k
178
+
179
+ relative_coords = coords_flatten_q[:, :, None] - coords_flatten_k[:, None, :] # 2, Wh_q*Ww_q, Wh_k*Ww_k
180
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh_q*Ww_q, Wh_k*Ww_k, 2
181
+ relative_coords[:, :, 0] += k_windows[0] - 1 # shift to start from 0
182
+ relative_coords[:, :, 1] += k_windows[1] - 1
183
+ relative_coords[:, :, 0] *= (q_windows[1] + k_windows[1]) - 1
184
+ relative_position_index = relative_coords.sum(-1) # Wh_q*Ww_q, Wh_k*Ww_k
185
+ return relative_position_index
186
+
187
+ def get_relative_position_index3d(q_windows, k_windows, num_clips):
188
+ """
189
+ Args:
190
+ q_windows: tuple (query_window_height, query_window_width)
191
+ k_windows: tuple (key_window_height, key_window_width)
192
+
193
+ Returns:
194
+ relative_position_index: query_window_height*query_window_width, key_window_height*key_window_width
195
+ """
196
+ # get pair-wise relative position index for each token inside the window
197
+ coords_d_q = torch.arange(num_clips)
198
+ coords_h_q = torch.arange(q_windows[0])
199
+ coords_w_q = torch.arange(q_windows[1])
200
+ coords_q = torch.stack(torch.meshgrid([coords_d_q, coords_h_q, coords_w_q])) # 3, num_clips, Wh_q, Ww_q
201
+
202
+ coords_d_k = torch.arange(num_clips)
203
+ coords_h_k = torch.arange(k_windows[0])
204
+ coords_w_k = torch.arange(k_windows[1])
205
+ coords_k = torch.stack(torch.meshgrid([coords_d_k, coords_h_k, coords_w_k])) # 3, num_clips, Wh_k, Ww_k
206
+
207
+ coords_flatten_q = torch.flatten(coords_q, 1) # 3, num_clips*Wh_q*Ww_q
208
+ coords_flatten_k = torch.flatten(coords_k, 1) # 3, num_clips*Wh_k*Ww_k
209
+
210
+ relative_coords = coords_flatten_q[:, :, None] - coords_flatten_k[:, None, :] # 3, N_q, N_k
211
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # N_q, N_k, 3
212
+ relative_coords[:, :, 0] += num_clips - 1 # shift to start from 0
213
+ relative_coords[:, :, 1] += k_windows[0] - 1
214
+ relative_coords[:, :, 2] += k_windows[1] - 1
215
+ relative_coords[:, :, 0] *= (q_windows[0] + k_windows[0] - 1)*(q_windows[1] + k_windows[1] - 1)
216
+ relative_coords[:, :, 1] *= (q_windows[1] + k_windows[1] - 1)
217
+ relative_position_index = relative_coords.sum(-1) # Wh_q*Ww_q, Wh_k*Ww_k
218
+ return relative_position_index
219
+
220
+
221
+ class WindowAttention3d3(nn.Module):
222
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
223
+
224
+ Args:
225
+ dim (int): Number of input channels.
226
+ expand_size (int): The expand size at focal level 1.
227
+ window_size (tuple[int]): The height and width of the window.
228
+ focal_window (int): Focal region size.
229
+ focal_level (int): Focal attention level.
230
+ num_heads (int): Number of attention heads.
231
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
232
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
233
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
234
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
235
+ pool_method (str): window pooling method. Default: none
236
+ """
237
+
238
+ def __init__(self, dim, expand_size, window_size, focal_window, focal_level, num_heads,
239
+ qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., pool_method="none", focal_l_clips=[7,1,2], focal_kernel_clips=[7,5,3]):
240
+
241
+ super().__init__()
242
+ self.dim = dim
243
+ self.expand_size = expand_size
244
+ self.window_size = window_size # Wh, Ww
245
+ self.pool_method = pool_method
246
+ self.num_heads = num_heads
247
+ head_dim = dim // num_heads
248
+ self.scale = qk_scale or head_dim ** -0.5
249
+ self.focal_level = focal_level
250
+ self.focal_window = focal_window
251
+
252
+ # define a parameter table of relative position bias for each window
253
+ self.relative_position_bias_table = nn.Parameter(
254
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
255
+
256
+ # get pair-wise relative position index for each token inside the window
257
+ coords_h = torch.arange(self.window_size[0])
258
+ coords_w = torch.arange(self.window_size[1])
259
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
260
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
261
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
262
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
263
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
264
+ relative_coords[:, :, 1] += self.window_size[1] - 1
265
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
266
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
267
+ self.register_buffer("relative_position_index", relative_position_index)
268
+
269
+ num_clips=4
270
+ # # define a parameter table of relative position bias
271
+ # self.relative_position_bias_table = nn.Parameter(
272
+ # torch.zeros((2 * num_clips - 1) * (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH
273
+
274
+ # # get pair-wise relative position index for each token inside the window
275
+ # coords_d = torch.arange(num_clips)
276
+ # coords_h = torch.arange(self.window_size[0])
277
+ # coords_w = torch.arange(self.window_size[1])
278
+ # coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w)) # 3, Wd, Wh, Ww
279
+ # coords_flatten = torch.flatten(coords, 1) # 3, Wd*Wh*Ww
280
+ # relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 3, Wd*Wh*Ww, Wd*Wh*Ww
281
+ # relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wd*Wh*Ww, Wd*Wh*Ww, 3
282
+ # relative_coords[:, :, 0] += num_clips - 1 # shift to start from 0
283
+ # relative_coords[:, :, 1] += self.window_size[0] - 1
284
+ # relative_coords[:, :, 2] += self.window_size[1] - 1
285
+
286
+ # relative_coords[:, :, 0] *= (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1)
287
+ # relative_coords[:, :, 1] *= (2 * self.window_size[1] - 1)
288
+ # relative_position_index = relative_coords.sum(-1) # Wd*Wh*Ww, Wd*Wh*Ww
289
+ # self.register_buffer("relative_position_index", relative_position_index)
290
+
291
+
292
+ if self.expand_size > 0 and focal_level > 0:
293
+ # define a parameter table of position bias between window and its fine-grained surroundings
294
+ self.window_size_of_key = self.window_size[0] * self.window_size[1] if self.expand_size == 0 else \
295
+ (4 * self.window_size[0] * self.window_size[1] - 4 * (self.window_size[0] - self.expand_size) * (self.window_size[0] - self.expand_size))
296
+ self.relative_position_bias_table_to_neighbors = nn.Parameter(
297
+ torch.zeros(1, num_heads, self.window_size[0] * self.window_size[1], self.window_size_of_key)) # Wh*Ww, nH, nSurrounding
298
+ trunc_normal_(self.relative_position_bias_table_to_neighbors, std=.02)
299
+
300
+ # get mask for rolled k and rolled v
301
+ mask_tl = torch.ones(self.window_size[0], self.window_size[1]); mask_tl[:-self.expand_size, :-self.expand_size] = 0
302
+ mask_tr = torch.ones(self.window_size[0], self.window_size[1]); mask_tr[:-self.expand_size, self.expand_size:] = 0
303
+ mask_bl = torch.ones(self.window_size[0], self.window_size[1]); mask_bl[self.expand_size:, :-self.expand_size] = 0
304
+ mask_br = torch.ones(self.window_size[0], self.window_size[1]); mask_br[self.expand_size:, self.expand_size:] = 0
305
+ mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), 0).flatten(0)
306
+ self.register_buffer("valid_ind_rolled", mask_rolled.nonzero().view(-1))
307
+
308
+ if pool_method != "none" and focal_level > 1:
309
+ #self.relative_position_bias_table_to_windows = nn.ParameterList()
310
+ #self.relative_position_bias_table_to_windows_clips = nn.ParameterList()
311
+ #self.register_parameter('relative_position_bias_table_to_windows',[])
312
+ #self.register_parameter('relative_position_bias_table_to_windows_clips',[])
313
+ self.unfolds = nn.ModuleList()
314
+ self.unfolds_clips=nn.ModuleList()
315
+
316
+ # build relative position bias between local patch and pooled windows
317
+ for k in range(focal_level-1):
318
+ stride = 2**k
319
+ kernel_size = 2*(self.focal_window // 2) + 2**k + (2**k-1)
320
+ # define unfolding operations
321
+ self.unfolds += [nn.Unfold(
322
+ kernel_size=(kernel_size, kernel_size),
323
+ stride=stride, padding=kernel_size // 2)
324
+ ]
325
+
326
+ # define relative position bias table
327
+ relative_position_bias_table_to_windows = nn.Parameter(
328
+ torch.zeros(
329
+ self.num_heads,
330
+ (self.window_size[0] + self.focal_window + 2**k - 2) * (self.window_size[1] + self.focal_window + 2**k - 2),
331
+ )
332
+ )
333
+ trunc_normal_(relative_position_bias_table_to_windows, std=.02)
334
+ #self.relative_position_bias_table_to_windows.append(relative_position_bias_table_to_windows)
335
+ self.register_parameter('relative_position_bias_table_to_windows_{}'.format(k),relative_position_bias_table_to_windows)
336
+
337
+ # define relative position bias index
338
+ relative_position_index_k = get_relative_position_index(self.window_size, to_2tuple(self.focal_window + 2**k - 1))
339
+ # relative_position_index_k = get_relative_position_index3d(self.window_size, to_2tuple(self.focal_window + 2**k - 1), num_clips)
340
+ self.register_buffer("relative_position_index_{}".format(k), relative_position_index_k)
341
+
342
+ # define unfolding index for focal_level > 0
343
+ if k > 0:
344
+ mask = torch.zeros(kernel_size, kernel_size); mask[(2**k)-1:, (2**k)-1:] = 1
345
+ self.register_buffer("valid_ind_unfold_{}".format(k), mask.flatten(0).nonzero().view(-1))
346
+
347
+ for k in range(len(focal_l_clips)):
348
+ # kernel_size=focal_kernel_clips[k]
349
+ focal_l_big_flag=False
350
+ if focal_l_clips[k]>self.window_size[0]:
351
+ stride=1
352
+ padding=0
353
+ kernel_size=focal_kernel_clips[k]
354
+ kernel_size_true=kernel_size
355
+ focal_l_big_flag=True
356
+ # stride=math.ceil(self.window_size/focal_l_clips[k])
357
+ # padding=(kernel_size-stride)/2
358
+ else:
359
+ stride = focal_l_clips[k]
360
+ # kernel_size
361
+ # kernel_size = 2*(focal_kernel_clips[k]// 2) + 2**focal_l_clips[k] + (2**focal_l_clips[k]-1)
362
+ kernel_size = focal_kernel_clips[k] ## kernel_size must be odd
363
+ assert kernel_size%2==1
364
+ padding=kernel_size // 2
365
+ # kernel_size_true=focal_kernel_clips[k]+2**focal_l_clips[k]-1
366
+ kernel_size_true=kernel_size
367
+ # stride=math.ceil(self.window_size/focal_l_clips[k])
368
+
369
+ self.unfolds_clips += [nn.Unfold(
370
+ kernel_size=(kernel_size, kernel_size),
371
+ stride=stride,
372
+ padding=padding)
373
+ ]
374
+ relative_position_bias_table_to_windows = nn.Parameter(
375
+ torch.zeros(
376
+ self.num_heads,
377
+ (self.window_size[0] + kernel_size_true - 1) * (self.window_size[0] + kernel_size_true - 1),
378
+ )
379
+ )
380
+ trunc_normal_(relative_position_bias_table_to_windows, std=.02)
381
+ #self.relative_position_bias_table_to_windows_clips.append(relative_position_bias_table_to_windows)
382
+ self.register_parameter('relative_position_bias_table_to_windows_clips_{}'.format(k),relative_position_bias_table_to_windows)
383
+ relative_position_index_k = get_relative_position_index(self.window_size, to_2tuple(kernel_size_true))
384
+ self.register_buffer("relative_position_index_clips_{}".format(k), relative_position_index_k)
385
+ # if (not focal_l_big_flag) and focal_l_clips[k]>0:
386
+ # mask = torch.zeros(kernel_size, kernel_size); mask[(2**focal_l_clips[k])-1:, (2**focal_l_clips[k])-1:] = 1
387
+ # self.register_buffer("valid_ind_unfold_clips_{}".format(k), mask.flatten(0).nonzero().view(-1))
388
+
389
+
390
+
391
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
392
+ self.attn_drop = nn.Dropout(attn_drop)
393
+ self.proj = nn.Linear(dim, dim)
394
+ self.proj_drop = nn.Dropout(proj_drop)
395
+
396
+ self.softmax = nn.Softmax(dim=-1)
397
+ self.focal_l_clips=focal_l_clips
398
+ self.focal_kernel_clips=focal_kernel_clips
399
+
400
+ def forward(self, x_all, mask_all=None, batch_size=None, num_clips=None):
401
+ """
402
+ Args:
403
+ x_all (list[Tensors]): input features at different granularity
404
+ mask_all (list[Tensors/None]): masks for input features at different granularity
405
+ """
406
+ x = x_all[0][0] #
407
+
408
+ B0, nH, nW, C = x.shape
409
+ # assert B==batch_size*num_clips
410
+ assert B0==batch_size
411
+ qkv = self.qkv(x).reshape(B0, nH, nW, 3, C).permute(3, 0, 1, 2, 4).contiguous()
412
+ q, k, v = qkv[0], qkv[1], qkv[2] # B0, nH, nW, C
413
+
414
+ # partition q map
415
+ # print("x.shape: ", x.shape)
416
+ # print("q.shape: ", q.shape) # [4, 126, 126, 256]
417
+ (q_windows, k_windows, v_windows) = map(
418
+ lambda t: window_partition(t, self.window_size[0]).view(
419
+ -1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads
420
+ ).transpose(1, 2),
421
+ (q, k, v)
422
+ )
423
+
424
+ # q_dim0, q_dim1, q_dim2, q_dim3=q_windows.shape
425
+ # q_windows=q_windows.view(batch_size, num_clips, (nH//self.window_size[0])*(nW//self.window_size[1]), q_dim1, q_dim2, q_dim3)
426
+ # q_windows=q_windows[:,-1].contiguous().view(-1, q_dim1, q_dim2, q_dim3) # query for the last frame (target frame)
427
+
428
+ # k_windows.shape [1296, 8, 49, 32]
429
+
430
+ if self.expand_size > 0 and self.focal_level > 0:
431
+ (k_tl, v_tl) = map(
432
+ lambda t: torch.roll(t, shifts=(-self.expand_size, -self.expand_size), dims=(1, 2)), (k, v)
433
+ )
434
+ (k_tr, v_tr) = map(
435
+ lambda t: torch.roll(t, shifts=(-self.expand_size, self.expand_size), dims=(1, 2)), (k, v)
436
+ )
437
+ (k_bl, v_bl) = map(
438
+ lambda t: torch.roll(t, shifts=(self.expand_size, -self.expand_size), dims=(1, 2)), (k, v)
439
+ )
440
+ (k_br, v_br) = map(
441
+ lambda t: torch.roll(t, shifts=(self.expand_size, self.expand_size), dims=(1, 2)), (k, v)
442
+ )
443
+
444
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
445
+ lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads),
446
+ (k_tl, k_tr, k_bl, k_br)
447
+ )
448
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
449
+ lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads),
450
+ (v_tl, v_tr, v_bl, v_br)
451
+ )
452
+ k_rolled = torch.cat((k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows), 1).transpose(1, 2)
453
+ v_rolled = torch.cat((v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows), 1).transpose(1, 2)
454
+
455
+ # mask out tokens in current window
456
+ # print("self.valid_ind_rolled.shape: ", self.valid_ind_rolled.shape) # [132]
457
+ # print("k_rolled.shape: ", k_rolled.shape) # [1296, 8, 196, 32]
458
+ k_rolled = k_rolled[:, :, self.valid_ind_rolled]
459
+ v_rolled = v_rolled[:, :, self.valid_ind_rolled]
460
+ k_rolled = torch.cat((k_windows, k_rolled), 2)
461
+ v_rolled = torch.cat((v_windows, v_rolled), 2)
462
+ else:
463
+            k_rolled = k_windows; v_rolled = v_windows
464
+
465
+ # print("k_rolled.shape: ", k_rolled.shape) # [1296, 8, 181, 32]
466
+
467
+ if self.pool_method != "none" and self.focal_level > 1:
468
+ k_pooled = []
469
+ v_pooled = []
470
+ for k in range(self.focal_level-1):
471
+ stride = 2**k
472
+ x_window_pooled = x_all[0][k+1] # B0, nWh, nWw, C
473
+ nWh, nWw = x_window_pooled.shape[1:3]
474
+
475
+ # generate mask for pooled windows
476
+ # print("x_window_pooled.shape: ", x_window_pooled.shape)
477
+ mask = x_window_pooled.new(nWh, nWw).fill_(1)
478
+ # print("here: ",x_window_pooled.shape, self.unfolds[k].kernel_size, self.unfolds[k](mask.unsqueeze(0).unsqueeze(1)).shape)
479
+ # print(mask.unique())
480
+ unfolded_mask = self.unfolds[k](mask.unsqueeze(0).unsqueeze(1)).view(
481
+ 1, 1, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
482
+ view(nWh*nWw // stride // stride, -1, 1)
483
+
484
+ if k > 0:
485
+ valid_ind_unfold_k = getattr(self, "valid_ind_unfold_{}".format(k))
486
+ unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
487
+
488
+ # print("unfolded_mask.shape: ", unfolded_mask.shape, unfolded_mask.unique())
489
+ x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
490
+ # print((x_window_masks == 0).sum(), (x_window_masks > 0).sum(), x_window_masks.unique())
491
+ x_window_masks = x_window_masks.masked_fill(x_window_masks == 0, float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
492
+ # print(x_window_masks.shape)
493
+ mask_all[0][k+1] = x_window_masks
494
+
495
+ # generate k and v for pooled windows
496
+ qkv_pooled = self.qkv(x_window_pooled).reshape(B0, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous()
497
+ k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2] # B0, C, nWh, nWw
498
+
499
+
500
+ (k_pooled_k, v_pooled_k) = map(
501
+ lambda t: self.unfolds[k](t).view(
502
+ B0, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
503
+ view(-1, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2),
504
+ (k_pooled_k, v_pooled_k) # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
505
+ )
506
+
507
+ # print("k_pooled_k.shape: ", k_pooled_k.shape)
508
+ # print("valid_ind_unfold_k.shape: ", valid_ind_unfold_k.shape)
509
+
510
+ if k > 0:
511
+ (k_pooled_k, v_pooled_k) = map(
512
+ lambda t: t[:, :, valid_ind_unfold_k], (k_pooled_k, v_pooled_k)
513
+ )
514
+
515
+ # print("k_pooled_k.shape: ", k_pooled_k.shape)
516
+
517
+ k_pooled += [k_pooled_k]
518
+ v_pooled += [v_pooled_k]
519
+
520
+ for k in range(len(self.focal_l_clips)):
521
+ focal_l_big_flag=False
522
+ if self.focal_l_clips[k]>self.window_size[0]:
523
+ stride=1
524
+ focal_l_big_flag=True
525
+ else:
526
+ stride = self.focal_l_clips[k]
527
+ # if self.window_size>=focal_l_clips[k]:
528
+ # stride=math.ceil(self.window_size/focal_l_clips[k])
529
+ # # padding=(kernel_size-stride)/2
530
+ # else:
531
+ # stride=1
532
+ # padding=0
533
+ x_window_pooled = x_all[k+1]
534
+ nWh, nWw = x_window_pooled.shape[1:3]
535
+ mask = x_window_pooled.new(nWh, nWw).fill_(1)
536
+
537
+ # import pdb; pdb.set_trace()
538
+ # print(x_window_pooled.shape, self.unfolds_clips[k].kernel_size, self.unfolds_clips[k](mask.unsqueeze(0).unsqueeze(1)).shape)
539
+
540
+ unfolded_mask = self.unfolds_clips[k](mask.unsqueeze(0).unsqueeze(1)).view(
541
+ 1, 1, self.unfolds_clips[k].kernel_size[0], self.unfolds_clips[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
542
+ view(nWh*nWw // stride // stride, -1, 1)
543
+
544
+ # if (not focal_l_big_flag) and self.focal_l_clips[k]>0:
545
+ # valid_ind_unfold_k = getattr(self, "valid_ind_unfold_clips_{}".format(k))
546
+ # unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
547
+
548
+ # print("unfolded_mask.shape: ", unfolded_mask.shape, unfolded_mask.unique())
549
+ x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
550
+ # print((x_window_masks == 0).sum(), (x_window_masks > 0).sum(), x_window_masks.unique())
551
+ x_window_masks = x_window_masks.masked_fill(x_window_masks == 0, float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
552
+ # print(x_window_masks.shape)
553
+ mask_all[k+1] = x_window_masks
554
+
555
+ # generate k and v for pooled windows
556
+ qkv_pooled = self.qkv(x_window_pooled).reshape(B0, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous()
557
+ k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2] # B0, C, nWh, nWw
558
+
559
+ if (not focal_l_big_flag):
560
+ (k_pooled_k, v_pooled_k) = map(
561
+ lambda t: self.unfolds_clips[k](t).view(
562
+ B0, C, self.unfolds_clips[k].kernel_size[0], self.unfolds_clips[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
563
+ view(-1, self.unfolds_clips[k].kernel_size[0]*self.unfolds_clips[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2),
564
+ (k_pooled_k, v_pooled_k) # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
565
+ )
566
+ else:
567
+
568
+ (k_pooled_k, v_pooled_k) = map(
569
+ lambda t: self.unfolds_clips[k](t),
570
+ (k_pooled_k, v_pooled_k) # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
571
+ )
572
+ LLL=k_pooled_k.size(2)
573
+ LLL_h=int(LLL**0.5)
574
+ assert LLL_h**2==LLL
575
+ k_pooled_k=k_pooled_k.reshape(B0, -1, LLL_h, LLL_h)
576
+ v_pooled_k=v_pooled_k.reshape(B0, -1, LLL_h, LLL_h)
577
+
578
+
579
+
580
+ # print("k_pooled_k.shape: ", k_pooled_k.shape)
581
+ # print("valid_ind_unfold_k.shape: ", valid_ind_unfold_k.shape)
582
+ # if (not focal_l_big_flag) and self.focal_l_clips[k]:
583
+ # (k_pooled_k, v_pooled_k) = map(
584
+ # lambda t: t[:, :, valid_ind_unfold_k], (k_pooled_k, v_pooled_k)
585
+ # )
586
+
587
+ # print("k_pooled_k.shape: ", k_pooled_k.shape)
588
+
589
+ k_pooled += [k_pooled_k]
590
+ v_pooled += [v_pooled_k]
591
+
592
+ # qkv_pooled = self.qkv(x_window_pooled).reshape(B0, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous()
593
+ # k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2] # B0, C, nWh, nWw
594
+ # (k_pooled_k, v_pooled_k) = map(
595
+ # lambda t: self.unfolds[k](t).view(
596
+ # B0, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
597
+ # view(-1, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2),
598
+ # (k_pooled_k, v_pooled_k) # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
599
+ # )
600
+ # k_pooled += [k_pooled_k]
601
+ # v_pooled += [v_pooled_k]
602
+
603
+
604
+ k_all = torch.cat([k_rolled] + k_pooled, 2)
605
+ v_all = torch.cat([v_rolled] + v_pooled, 2)
606
+ else:
607
+ k_all = k_rolled
608
+ v_all = v_rolled
609
+
610
+ N = k_all.shape[-2]
611
+ q_windows = q_windows * self.scale
612
+ # print(q_windows.shape, k_all.shape, v_all.shape)
613
+ # exit()
614
+ # k_all_dim0, k_all_dim1, k_all_dim2, k_all_dim3=k_all.shape
615
+ # k_all=k_all.contiguous().view(batch_size, num_clips, (nH//self.window_size[0])*(nW//self.window_size[1]),
616
+ # k_all_dim1, k_all_dim2, k_all_dim3).permute(0,2,3,4,1,5).contiguous().view(-1, k_all_dim1, k_all_dim2*num_clips, k_all_dim3)
617
+ # v_all=v_all.contiguous().view(batch_size, num_clips, (nH//self.window_size[0])*(nW//self.window_size[1]),
618
+ # k_all_dim1, k_all_dim2, k_all_dim3).permute(0,2,3,4,1,5).contiguous().view(-1, k_all_dim1, k_all_dim2*num_clips, k_all_dim3)
619
+
620
+ # print(q_windows.shape, k_all.shape, v_all.shape, k_rolled.shape)
621
+ # exit()
622
+ attn = (q_windows @ k_all.transpose(-2, -1)) # B0*nW, nHead, window_size*window_size, focal_window_size*focal_window_size
623
+
624
+ window_area = self.window_size[0] * self.window_size[1]
625
+ # window_area_clips= num_clips*self.window_size[0] * self.window_size[1]
626
+ window_area_rolled = k_rolled.shape[2]
627
+
628
+ # add relative position bias for tokens inside window
629
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
630
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
631
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
632
+ # print(relative_position_bias.shape, attn.shape)
633
+ attn[:, :, :window_area, :window_area] = attn[:, :, :window_area, :window_area] + relative_position_bias.unsqueeze(0)
634
+
635
+ # relative_position_bias = self.relative_position_bias_table[self.relative_position_index[-window_area:, :window_area_clips].reshape(-1)].view(
636
+ # window_area, window_area_clips, -1) # Wh*Ww,Wd*Wh*Ww,nH
637
+ # relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().view(self.num_heads,window_area,num_clips,window_area
638
+ # ).permute(0,1,3,2).contiguous().view(self.num_heads,window_area,window_area_clips).contiguous() # nH, Wh*Ww, Wh*Ww*Wd
639
+ # # attn_dim0, attn_dim1, attn_dim2, attn_dim3=attn.shape
640
+ # # attn=attn.view(attn_dim0,attn_dim1,attn_dim2,num_clips,-1)
641
+ # # print(attn.shape, relative_position_bias.shape)
642
+ # attn[:,:,:window_area, :window_area_clips]=attn[:,:,:window_area, :window_area_clips] + relative_position_bias.unsqueeze(0)
643
+ # attn = attn + relative_position_bias.unsqueeze(0) # B_, nH, N, N
644
+
645
+ # add relative position bias for patches inside a window
646
+ if self.expand_size > 0 and self.focal_level > 0:
647
+ attn[:, :, :window_area, window_area:window_area_rolled] = attn[:, :, :window_area, window_area:window_area_rolled] + self.relative_position_bias_table_to_neighbors
648
+
649
+ if self.pool_method != "none" and self.focal_level > 1:
650
+ # add relative position bias for different windows in an image
651
+ offset = window_area_rolled
652
+ # print(offset)
653
+ for k in range(self.focal_level-1):
654
+ # add relative position bias
655
+ relative_position_index_k = getattr(self, 'relative_position_index_{}'.format(k))
656
+ relative_position_bias_to_windows = getattr(self,'relative_position_bias_table_to_windows_{}'.format(k))[:, relative_position_index_k.view(-1)].view(
657
+ -1, self.window_size[0] * self.window_size[1], (self.focal_window+2**k-1)**2,
658
+ ) # nH, NWh*NWw,focal_region*focal_region
659
+ attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \
660
+ attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + relative_position_bias_to_windows.unsqueeze(0)
661
+ # add attentional mask
662
+ if mask_all[0][k+1] is not None:
663
+ attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \
664
+ attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + \
665
+ mask_all[0][k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[0][k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[0][k+1].shape[-1])
666
+
667
+ offset += (self.focal_window+2**k-1)**2
668
+ # print(offset)
669
+ for k in range(len(self.focal_l_clips)):
670
+ focal_l_big_flag=False
671
+ if self.focal_l_clips[k]>self.window_size[0]:
672
+ stride=1
673
+ padding=0
674
+ kernel_size=self.focal_kernel_clips[k]
675
+ kernel_size_true=kernel_size
676
+ focal_l_big_flag=True
677
+ # stride=math.ceil(self.window_size/focal_l_clips[k])
678
+ # padding=(kernel_size-stride)/2
679
+ else:
680
+ stride = self.focal_l_clips[k]
681
+ # kernel_size
682
+ # kernel_size = 2*(self.focal_kernel_clips[k]// 2) + 2**self.focal_l_clips[k] + (2**self.focal_l_clips[k]-1)
683
+ kernel_size = self.focal_kernel_clips[k]
684
+ padding=kernel_size // 2
685
+ # kernel_size_true=self.focal_kernel_clips[k]+2**self.focal_l_clips[k]-1
686
+ kernel_size_true=kernel_size
687
+ relative_position_index_k = getattr(self, 'relative_position_index_clips_{}'.format(k))
688
+ relative_position_bias_to_windows = getattr(self,'relative_position_bias_table_to_windows_clips_{}'.format(k))[:, relative_position_index_k.view(-1)].view(
689
+ -1, self.window_size[0] * self.window_size[1], (kernel_size_true)**2,
690
+ )
691
+ attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] = \
692
+ attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] + relative_position_bias_to_windows.unsqueeze(0)
693
+ if mask_all[k+1] is not None:
694
+ attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] = \
695
+ attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] + \
696
+ mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
697
+ offset += (kernel_size_true)**2
698
+ # print(offset)
699
+ # relative_position_index_k = getattr(self, 'relative_position_index_{}'.format(k))
700
+ # # relative_position_bias_to_windows = self.relative_position_bias_table_to_windows[k][:, relative_position_index_k.view(-1)].view(
701
+ # # -1, self.window_size[0] * self.window_size[1], (self.focal_window+2**k-1)**2,
702
+ # # ) # nH, NWh*NWw,focal_region*focal_region
703
+ # # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \
704
+ # # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + relative_position_bias_to_windows.unsqueeze(0)
705
+ # relative_position_bias_to_windows = self.relative_position_bias_table_to_windows[k][:, relative_position_index_k[-window_area:, :].view(-1)].view(
706
+ # -1, self.window_size[0] * self.window_size[1], num_clips*(self.focal_window+2**k-1)**2,
707
+ # ).contiguous() # nH, NWh*NWw, num_clips*focal_region*focal_region
708
+ # relative_position_bias_to_windows = relative_position_bias_to_windows.view(self.num_heads,
709
+ # window_area,num_clips,-1).permute(0,1,3,2).contiguous().view(self.num_heads,window_area,-1)
710
+ # attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] = \
711
+ # attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] + relative_position_bias_to_windows.unsqueeze(0)
712
+ # # add attentional mask
713
+ # if mask_all[k+1] is not None:
714
+ # # print("inside the mask, be careful 1")
715
+ # # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \
716
+ # # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + \
717
+ # # mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
718
+ # # print("here: ", mask_all[k+1].shape, mask_all[k+1][:, :, None, None, :].shape)
719
+
720
+ # attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] = \
721
+ # attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] + \
722
+ # mask_all[k+1][:, :, None, None, :,None].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1, num_clips).view(-1, 1, 1, mask_all[k+1].shape[-1]*num_clips)
723
+ # # print()
724
+
725
+ # offset += (self.focal_window+2**k-1)**2
726
+
727
+ # print("mask_all[0]: ", mask_all[0])
728
+ # exit()
729
+ if mask_all[0][0] is not None:
730
+ print("inside the mask, be careful 0")
731
+ nW = mask_all[0].shape[0]
732
+ attn = attn.view(attn.shape[0] // nW, nW, self.num_heads, window_area, N)
733
+ attn[:, :, :, :, :window_area] = attn[:, :, :, :, :window_area] + mask_all[0][None, :, None, :, :]
734
+ attn = attn.view(-1, self.num_heads, window_area, N)
735
+ attn = self.softmax(attn)
736
+ else:
737
+ attn = self.softmax(attn)
738
+
739
+ attn = self.attn_drop(attn)
740
+
741
+ x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area, C)
742
+ x = self.proj(x)
743
+ x = self.proj_drop(x)
744
+ # print(x.shape)
745
+ # x = x.view(B/num_clips, nH, nW, C )
746
+ # exit()
747
+ return x
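A minimal, self-contained sketch of the expanded-window trick used near the top of this forward pass: the key/value maps are cyclically rolled toward the four diagonal neighbours, window-partitioned, and concatenated, after which the module drops the entries that fall back inside the central window via its precomputed valid_ind_rolled buffer. window_partition is re-implemented locally and all sizes are toy values, so treat this as an illustration rather than the repository code.

import torch

def window_partition(x, ws):
    # (B, H, W, C) -> (num_windows*B, ws, ws, C)
    B, H, W, C = x.shape
    x = x.view(B, H // ws, ws, W // ws, ws, C)
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws, ws, C)

B, H, W, C, ws, expand = 2, 14, 14, 8, 7, 3
k = torch.randn(B, H, W, C)

# roll the key map towards the four diagonal directions (top-left, top-right, bottom-left, bottom-right)
rolled_maps = [torch.roll(k, shifts=(sh, sw), dims=(1, 2))
               for sh in (-expand, expand) for sw in (-expand, expand)]

# partition each rolled map into ws x ws windows and flatten the spatial axis
rolled_windows = [window_partition(t, ws).reshape(-1, ws * ws, C) for t in rolled_maps]
k_expanded = torch.cat(rolled_windows, dim=1)      # (num_windows*B, 4*ws*ws, C)
print(k_expanded.shape)
# the attention module then keeps only the subset indexed by valid_ind_rolled,
# i.e. the surrounding tokens that are not already inside the central window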
748
+
749
+ def extra_repr(self) -> str:
750
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
751
+
752
+ def flops(self, N, window_size, unfold_size):
753
+ # calculate flops for 1 window with token length of N
754
+ flops = 0
755
+ # qkv = self.qkv(x)
756
+ flops += N * self.dim * 3 * self.dim
757
+ # attn = (q @ k.transpose(-2, -1))
758
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
759
+ if self.pool_method != "none" and self.focal_level > 1:
760
+ flops += self.num_heads * N * (self.dim // self.num_heads) * (unfold_size * unfold_size)
761
+ if self.expand_size > 0 and self.focal_level > 0:
762
+ flops += self.num_heads * N * (self.dim // self.num_heads) * ((window_size + 2*self.expand_size)**2-window_size**2)
763
+
764
+ # x = (attn @ v)
765
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
766
+ if self.pool_method != "none" and self.focal_level > 1:
767
+ flops += self.num_heads * N * (self.dim // self.num_heads) * (unfold_size * unfold_size)
768
+ if self.expand_size > 0 and self.focal_level > 0:
769
+ flops += self.num_heads * N * (self.dim // self.num_heads) * ((window_size + 2*self.expand_size)**2-window_size**2)
770
+
771
+ # x = self.proj(x)
772
+ flops += N * self.dim * self.dim
773
+ return flops
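The pooled-level keys and values in the forward pass above are gathered with nn.Unfold: every query window attends to a kernel x kernel neighbourhood of pooled window tokens. A hedged, standalone sketch of that reshape chain on toy shapes (not the repository module itself):

import torch
import torch.nn as nn

B, C, nWh, nWw, kernel = 2, 16, 12, 12, 7
unfold = nn.Unfold(kernel_size=(kernel, kernel), stride=1, padding=kernel // 2)

k_pooled = torch.randn(B, C, nWh, nWw)                 # one pooled token per window
patches = unfold(k_pooled)                             # (B, C*kernel*kernel, nWh*nWw)
patches = patches.view(B, C, kernel, kernel, -1)       # split channel / kernel dims
patches = patches.permute(0, 4, 2, 3, 1).reshape(-1, kernel * kernel, C)
print(patches.shape)                                   # (B*nWh*nWw, kernel*kernel, C)
# each row holds the kernel*kernel pooled neighbours one query window attends to;
# splitting C into (num_heads, C // num_heads) afterwards mirrors the code above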
774
+
775
+
776
+ class CffmTransformerBlock3d3(nn.Module):
777
+ r""" Focal Transformer Block.
778
+
779
+ Args:
780
+ dim (int): Number of input channels.
781
+        input_resolution (tuple[int]): Input resolution.
782
+ num_heads (int): Number of attention heads.
783
+ window_size (int): Window size.
784
+ expand_size (int): expand size at first focal level (finest level).
785
+ shift_size (int): Shift size for SW-MSA.
786
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
787
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
788
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
789
+ drop (float, optional): Dropout rate. Default: 0.0
790
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
791
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
792
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
793
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
794
+ pool_method (str): window pooling method. Default: none, options: [none|fc|conv]
795
+ focal_level (int): number of focal levels. Default: 1.
796
+ focal_window (int): region size of focal attention. Default: 1
797
+        use_layerscale (bool): Whether to use layer scale for training stability. Default: False
798
+ layerscale_value (float): scaling value for layer scale. Default: 1e-4
799
+ """
800
+
801
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, expand_size=0, shift_size=0,
802
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
803
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, pool_method="none",
804
+ focal_level=1, focal_window=1, use_layerscale=False, layerscale_value=1e-4, focal_l_clips=[7,2,4], focal_kernel_clips=[7,5,3]):
805
+ super().__init__()
806
+ self.dim = dim
807
+ self.input_resolution = input_resolution
808
+ self.num_heads = num_heads
809
+ self.window_size = window_size
810
+ self.shift_size = shift_size
811
+ self.expand_size = expand_size
812
+ self.mlp_ratio = mlp_ratio
813
+ self.pool_method = pool_method
814
+ self.focal_level = focal_level
815
+ self.focal_window = focal_window
816
+ self.use_layerscale = use_layerscale
817
+ self.focal_l_clips=focal_l_clips
818
+ self.focal_kernel_clips=focal_kernel_clips
819
+
820
+ if min(self.input_resolution) <= self.window_size:
821
+ # if window size is larger than input resolution, we don't partition windows
822
+ self.expand_size = 0
823
+ self.shift_size = 0
824
+ self.window_size = min(self.input_resolution)
825
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in 0-window_size"
826
+
827
+ self.window_size_glo = self.window_size
828
+
829
+ self.pool_layers = nn.ModuleList()
830
+ self.pool_layers_clips = nn.ModuleList()
831
+ if self.pool_method != "none":
832
+ for k in range(self.focal_level-1):
833
+ window_size_glo = math.floor(self.window_size_glo / (2 ** k))
834
+ if self.pool_method == "fc":
835
+ self.pool_layers.append(nn.Linear(window_size_glo * window_size_glo, 1))
836
+ self.pool_layers[-1].weight.data.fill_(1./(window_size_glo * window_size_glo))
837
+ self.pool_layers[-1].bias.data.fill_(0)
838
+ elif self.pool_method == "conv":
839
+ self.pool_layers.append(nn.Conv2d(dim, dim, kernel_size=window_size_glo, stride=window_size_glo, groups=dim))
840
+ for k in range(len(focal_l_clips)):
841
+ # window_size_glo = math.floor(self.window_size_glo / (2 ** k))
842
+ if focal_l_clips[k]>self.window_size:
843
+ window_size_glo = focal_l_clips[k]
844
+ else:
845
+ window_size_glo = math.floor(self.window_size_glo / (focal_l_clips[k]))
846
+ # window_size_glo = focal_l_clips[k]
847
+ if self.pool_method == "fc":
848
+ self.pool_layers_clips.append(nn.Linear(window_size_glo * window_size_glo, 1))
849
+ self.pool_layers_clips[-1].weight.data.fill_(1./(window_size_glo * window_size_glo))
850
+ self.pool_layers_clips[-1].bias.data.fill_(0)
851
+ elif self.pool_method == "conv":
852
+ self.pool_layers_clips.append(nn.Conv2d(dim, dim, kernel_size=window_size_glo, stride=window_size_glo, groups=dim))
853
+
854
+ self.norm1 = norm_layer(dim)
855
+
856
+ self.attn = WindowAttention3d3(
857
+ dim, expand_size=self.expand_size, window_size=to_2tuple(self.window_size),
858
+ focal_window=focal_window, focal_level=focal_level, num_heads=num_heads,
859
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, pool_method=pool_method, focal_l_clips=focal_l_clips, focal_kernel_clips=focal_kernel_clips)
860
+
861
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
862
+ self.norm2 = norm_layer(dim)
863
+ mlp_hidden_dim = int(dim * mlp_ratio)
864
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
865
+
866
+ # print("******self.shift_size: ", self.shift_size)
867
+
868
+ if self.shift_size > 0:
869
+ # calculate attention mask for SW-MSA
870
+ H, W = self.input_resolution
871
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
872
+ h_slices = (slice(0, -self.window_size),
873
+ slice(-self.window_size, -self.shift_size),
874
+ slice(-self.shift_size, None))
875
+ w_slices = (slice(0, -self.window_size),
876
+ slice(-self.window_size, -self.shift_size),
877
+ slice(-self.shift_size, None))
878
+ cnt = 0
879
+ for h in h_slices:
880
+ for w in w_slices:
881
+ img_mask[:, h, w, :] = cnt
882
+ cnt += 1
883
+
884
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
885
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
886
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
887
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
888
+ else:
889
+ # print("here mask none")
890
+ attn_mask = None
891
+ self.register_buffer("attn_mask", attn_mask)
892
+
893
+ if self.use_layerscale:
894
+ self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
895
+ self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
896
+
897
+ def forward(self, x):
898
+ H0, W0 = self.input_resolution
899
+ # B, L, C = x.shape
900
+ B0, D0, H0, W0, C = x.shape
901
+ shortcut = x
902
+ # assert L == H * W, "input feature has wrong size"
903
+ x=x.reshape(B0*D0,H0,W0,C).reshape(B0*D0,H0*W0,C)
904
+
905
+
906
+ x = self.norm1(x)
907
+ x = x.reshape(B0*D0, H0, W0, C)
908
+ # print("here")
909
+ # exit()
910
+
911
+ # pad feature maps to multiples of window size
912
+ pad_l = pad_t = 0
913
+ pad_r = (self.window_size - W0 % self.window_size) % self.window_size
914
+ pad_b = (self.window_size - H0 % self.window_size) % self.window_size
915
+ if pad_r > 0 or pad_b > 0:
916
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
917
+
918
+ B, H, W, C = x.shape ## B=B0*D0
919
+
920
+ if self.shift_size > 0:
921
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
922
+ else:
923
+ shifted_x = x
924
+
925
+ # print("shifted_x.shape: ", shifted_x.shape)
926
+ shifted_x=shifted_x.view(B0,D0,H,W,C)
927
+ x_windows_all = [shifted_x[:,-1]]
928
+ x_windows_all_clips=[]
929
+ x_window_masks_all = [self.attn_mask]
930
+ x_window_masks_all_clips=[]
931
+
932
+ if self.focal_level > 1 and self.pool_method != "none":
933
+ # if we add coarser granularity and the pool method is not none
934
+ # pooling_index=0
935
+ for k in range(self.focal_level-1):
936
+ window_size_glo = math.floor(self.window_size_glo / (2 ** k))
937
+ pooled_h = math.ceil(H / self.window_size) * (2 ** k)
938
+ pooled_w = math.ceil(W / self.window_size) * (2 ** k)
939
+ H_pool = pooled_h * window_size_glo
940
+ W_pool = pooled_w * window_size_glo
941
+
942
+ x_level_k = shifted_x[:,-1]
943
+ # trim or pad shifted_x depending on the required size
944
+ if H > H_pool:
945
+ trim_t = (H - H_pool) // 2
946
+ trim_b = H - H_pool - trim_t
947
+ x_level_k = x_level_k[:, trim_t:-trim_b]
948
+ elif H < H_pool:
949
+ pad_t = (H_pool - H) // 2
950
+ pad_b = H_pool - H - pad_t
951
+ x_level_k = F.pad(x_level_k, (0,0,0,0,pad_t,pad_b))
952
+
953
+ if W > W_pool:
954
+ trim_l = (W - W_pool) // 2
955
+ trim_r = W - W_pool - trim_l
956
+ x_level_k = x_level_k[:, :, trim_l:-trim_r]
957
+ elif W < W_pool:
958
+ pad_l = (W_pool - W) // 2
959
+ pad_r = W_pool - W - pad_l
960
+ x_level_k = F.pad(x_level_k, (0,0,pad_l,pad_r))
961
+
962
+ x_windows_noreshape = window_partition_noreshape(x_level_k.contiguous(), window_size_glo) # B0, nw, nw, window_size, window_size, C
963
+ nWh, nWw = x_windows_noreshape.shape[1:3]
964
+ if self.pool_method == "mean":
965
+ x_windows_pooled = x_windows_noreshape.mean([3, 4]) # B0, nWh, nWw, C
966
+ elif self.pool_method == "max":
967
+ x_windows_pooled = x_windows_noreshape.max(-2)[0].max(-2)[0].view(B0, nWh, nWw, C) # B0, nWh, nWw, C
968
+ elif self.pool_method == "fc":
969
+ x_windows_noreshape = x_windows_noreshape.view(B0, nWh, nWw, window_size_glo*window_size_glo, C).transpose(3, 4) # B0, nWh, nWw, C, wsize**2
970
+ x_windows_pooled = self.pool_layers[k](x_windows_noreshape).flatten(-2) # B0, nWh, nWw, C
971
+ elif self.pool_method == "conv":
972
+ x_windows_noreshape = x_windows_noreshape.view(-1, window_size_glo, window_size_glo, C).permute(0, 3, 1, 2).contiguous() # B0 * nw * nw, C, wsize, wsize
973
+ x_windows_pooled = self.pool_layers[k](x_windows_noreshape).view(B0, nWh, nWw, C) # B0, nWh, nWw, C
974
+
975
+ x_windows_all += [x_windows_pooled]
976
+ # print(x_windows_pooled.shape)
977
+ x_window_masks_all += [None]
978
+ # pooling_index=pooling_index+1
979
+
980
+ x_windows_all_clips += [x_windows_all]
981
+ x_window_masks_all_clips += [x_window_masks_all]
982
+ for k in range(len(self.focal_l_clips)):
983
+ if self.focal_l_clips[k]>self.window_size:
984
+ window_size_glo = self.focal_l_clips[k]
985
+ else:
986
+ window_size_glo = math.floor(self.window_size_glo / (self.focal_l_clips[k]))
987
+
988
+ pooled_h = math.ceil(H / self.window_size) * (self.focal_l_clips[k])
989
+ pooled_w = math.ceil(W / self.window_size) * (self.focal_l_clips[k])
990
+
991
+ H_pool = pooled_h * window_size_glo
992
+ W_pool = pooled_w * window_size_glo
993
+
994
+ x_level_k = shifted_x[:,k]
995
+ if H!=H_pool or W!=W_pool:
996
+ x_level_k=F.interpolate(x_level_k.permute(0,3,1,2), size=(H_pool, W_pool), mode='bilinear').permute(0,2,3,1)
997
+
998
+ # print(x_level_k.shape)
999
+ x_windows_noreshape = window_partition_noreshape(x_level_k.contiguous(), window_size_glo) # B0, nw, nw, window_size, window_size, C
1000
+ nWh, nWw = x_windows_noreshape.shape[1:3]
1001
+ if self.pool_method == "mean":
1002
+ x_windows_pooled = x_windows_noreshape.mean([3, 4]) # B0, nWh, nWw, C
1003
+ elif self.pool_method == "max":
1004
+ x_windows_pooled = x_windows_noreshape.max(-2)[0].max(-2)[0].view(B0, nWh, nWw, C) # B0, nWh, nWw, C
1005
+ elif self.pool_method == "fc":
1006
+ x_windows_noreshape = x_windows_noreshape.view(B0, nWh, nWw, window_size_glo*window_size_glo, C).transpose(3, 4) # B0, nWh, nWw, C, wsize**2
1007
+ x_windows_pooled = self.pool_layers_clips[k](x_windows_noreshape).flatten(-2) # B0, nWh, nWw, C
1008
+ elif self.pool_method == "conv":
1009
+ x_windows_noreshape = x_windows_noreshape.view(-1, window_size_glo, window_size_glo, C).permute(0, 3, 1, 2).contiguous() # B0 * nw * nw, C, wsize, wsize
1010
+ x_windows_pooled = self.pool_layers_clips[k](x_windows_noreshape).view(B0, nWh, nWw, C) # B0, nWh, nWw, C
1011
+
1012
+ x_windows_all_clips += [x_windows_pooled]
1013
+ # print(x_windows_pooled.shape)
1014
+ x_window_masks_all_clips += [None]
1015
+ # pooling_index=pooling_index+1
1016
+ # exit()
1017
+
1018
+ attn_windows = self.attn(x_windows_all_clips, mask_all=x_window_masks_all_clips, batch_size=B0, num_clips=D0) # nW*B0, window_size*window_size, C
1019
+
1020
+ attn_windows = attn_windows[:, :self.window_size ** 2]
1021
+
1022
+ # merge windows
1023
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
1024
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H(padded) W(padded) C
1025
+
1026
+ # reverse cyclic shift
1027
+ if self.shift_size > 0:
1028
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
1029
+ else:
1030
+ x = shifted_x
1031
+ # x = x[:, :self.input_resolution[0], :self.input_resolution[1]].contiguous().view(B, -1, C)
1032
+ x = x[:, :H0, :W0].contiguous().view(B0, -1, C)
1033
+
1034
+ # FFN
1035
+ # x = shortcut + self.drop_path(x if (not self.use_layerscale) else (self.gamma_1 * x))
1036
+ # x = x + self.drop_path(self.mlp(self.norm2(x)) if (not self.use_layerscale) else (self.gamma_2 * self.mlp(self.norm2(x))))
1037
+
1038
+ # print(x.shape, shortcut[:,-1].view(B0, -1, C).shape)
1039
+ x = shortcut[:,-1].view(B0, -1, C) + self.drop_path(x if (not self.use_layerscale) else (self.gamma_1 * x))
1040
+ x = x + self.drop_path(self.mlp(self.norm2(x)) if (not self.use_layerscale) else (self.gamma_2 * self.mlp(self.norm2(x))))
1041
+
1042
+ # x=torch.cat([shortcut[:,:-1],x.view(B0,self.input_resolution[0],self.input_resolution[1],C).unsqueeze(1)],1)
1043
+ x=torch.cat([shortcut[:,:-1],x.view(B0,H0,W0,C).unsqueeze(1)],1)
1044
+
1045
+ assert x.shape==shortcut.shape
1046
+
1047
+ # exit()
1048
+
1049
+ return x
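The pool_method == "fc" branch in the forward pass above compresses every window into a single token with a Linear layer over the window_size**2 spatial positions, initialised so that it starts out as mean pooling. A standalone sketch with a locally re-implemented window_partition_noreshape and toy sizes (an illustration, not the repository helper):

import torch
import torch.nn as nn

def window_partition_noreshape(x, ws):
    # (B, H, W, C) -> (B, H//ws, W//ws, ws, ws, C)
    B, H, W, C = x.shape
    x = x.view(B, H // ws, ws, W // ws, ws, C)
    return x.permute(0, 1, 3, 2, 4, 5).contiguous()

B, H, W, C, ws = 2, 28, 28, 32, 7
pool_fc = nn.Linear(ws * ws, 1)
pool_fc.weight.data.fill_(1.0 / (ws * ws))    # starts as mean pooling, as in __init__ above
pool_fc.bias.data.fill_(0)

x = torch.randn(B, H, W, C)
win = window_partition_noreshape(x, ws)                   # B, nWh, nWw, ws, ws, C
nWh, nWw = win.shape[1:3]
win = win.view(B, nWh, nWw, ws * ws, C).transpose(3, 4)   # B, nWh, nWw, C, ws*ws
pooled = pool_fc(win).flatten(-2)                         # B, nWh, nWw, C
print(pooled.shape)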
1050
+
1051
+ def extra_repr(self) -> str:
1052
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
1053
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
1054
+
1055
+ def flops(self):
1056
+ flops = 0
1057
+ H, W = self.input_resolution
1058
+ # norm1
1059
+ flops += self.dim * H * W
1060
+
1061
+ # W-MSA/SW-MSA
1062
+ nW = H * W / self.window_size / self.window_size
1063
+ flops += nW * self.attn.flops(self.window_size * self.window_size, self.window_size, self.focal_window)
1064
+
1065
+ if self.pool_method != "none" and self.focal_level > 1:
1066
+ for k in range(self.focal_level-1):
1067
+ window_size_glo = math.floor(self.window_size_glo / (2 ** k))
1068
+ nW_glo = nW * (2**k)
1069
+ # (sub)-window pooling
1070
+ flops += nW_glo * self.dim * window_size_glo * window_size_glo
1071
+ # qkv for global levels
1072
+ # NOTE: in our implementation, we pass the pooled window embedding to qkv embedding layer,
1073
+            # but theoretically, we only need to compute k and v.
1074
+ flops += nW_glo * self.dim * 3 * self.dim
1075
+
1076
+ # mlp
1077
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
1078
+ # norm2
1079
+ flops += self.dim * H * W
1080
+ return flops
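The attn_mask registered in __init__ above is the standard Swin-style shifted-window mask: pixels are labelled by the shifted region they came from, and attention between different regions inside the same window is suppressed with a large negative bias. A standalone sketch on toy sizes:

import torch

H = W = 14
window_size, shift_size = 7, 3

img_mask = torch.zeros((1, H, W, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

# partition the label map into windows and compare labels pairwise
mask_windows = img_mask.view(1, H // window_size, window_size, W // window_size, window_size, 1)
mask_windows = mask_windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, -100.0).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)    # (num_windows, ws*ws, ws*ws); -100 entries block cross-region attention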
1081
+
1082
+
1083
+ class BasicLayer3d3(nn.Module):
1084
+ """ A basic Focal Transformer layer for one stage.
1085
+
1086
+ Args:
1087
+ dim (int): Number of input channels.
1088
+ input_resolution (tuple[int]): Input resolution.
1089
+ depth (int): Number of blocks.
1090
+ num_heads (int): Number of attention heads.
1091
+ window_size (int): Local window size.
1092
+ expand_size (int): expand size for focal level 1.
1093
+ expand_layer (str): expand layer. Default: all
1094
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
1095
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
1096
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
1097
+ drop (float, optional): Dropout rate. Default: 0.0
1098
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
1099
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
1100
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
1101
+ pool_method (str): Window pooling method. Default: none.
1102
+ focal_level (int): Number of focal levels. Default: 1.
1103
+ focal_window (int): region size at each focal level. Default: 1.
1104
+        use_conv_embed (bool): Whether to use an overlapped convolutional patch embedding layer. Default: False
1105
+        use_shift (bool): Whether to use window shift as in Swin Transformer. Default: False
1106
+        use_pre_norm (bool): Whether to use pre-norm before patch embedding projection for stability. Default: False
1107
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
1108
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
1109
+        use_layerscale (bool): Whether to use layer scale for stability. Default: False.
1110
+ layerscale_value (float): Layerscale value. Default: 1e-4.
1111
+ """
1112
+
1113
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size, expand_size, expand_layer="all",
1114
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
1115
+ drop_path=0., norm_layer=nn.LayerNorm, pool_method="none",
1116
+ focal_level=1, focal_window=1, use_conv_embed=False, use_shift=False, use_pre_norm=False,
1117
+ downsample=None, use_checkpoint=False, use_layerscale=False, layerscale_value=1e-4, focal_l_clips=[16,8,2], focal_kernel_clips=[7,5,3]):
1118
+
1119
+ super().__init__()
1120
+ self.dim = dim
1121
+ self.input_resolution = input_resolution
1122
+ self.depth = depth
1123
+ self.use_checkpoint = use_checkpoint
1124
+
1125
+ if expand_layer == "even":
1126
+ expand_factor = 0
1127
+ elif expand_layer == "odd":
1128
+ expand_factor = 1
1129
+ elif expand_layer == "all":
1130
+ expand_factor = -1
1131
+
1132
+ # build blocks
1133
+ self.blocks = nn.ModuleList([
1134
+ CffmTransformerBlock3d3(dim=dim, input_resolution=input_resolution,
1135
+ num_heads=num_heads, window_size=window_size,
1136
+ shift_size=(0 if (i % 2 == 0) else window_size // 2) if use_shift else 0,
1137
+ expand_size=0 if (i % 2 == expand_factor) else expand_size,
1138
+ mlp_ratio=mlp_ratio,
1139
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
1140
+ drop=drop,
1141
+ attn_drop=attn_drop,
1142
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
1143
+ norm_layer=norm_layer,
1144
+ pool_method=pool_method,
1145
+ focal_level=focal_level,
1146
+ focal_window=focal_window,
1147
+ use_layerscale=use_layerscale,
1148
+ layerscale_value=layerscale_value,
1149
+ focal_l_clips=focal_l_clips,
1150
+ focal_kernel_clips=focal_kernel_clips)
1151
+ for i in range(depth)])
1152
+
1153
+ # patch merging layer
1154
+ if downsample is not None:
1155
+ self.downsample = downsample(
1156
+ img_size=input_resolution, patch_size=2, in_chans=dim, embed_dim=2*dim,
1157
+ use_conv_embed=use_conv_embed, norm_layer=norm_layer, use_pre_norm=use_pre_norm,
1158
+ is_stem=False
1159
+ )
1160
+ else:
1161
+ self.downsample = None
1162
+
1163
+ def forward(self, x, batch_size=None, num_clips=None, reg_tokens=None):
1164
+ B, D, C, H, W = x.shape
1165
+ x = rearrange(x, 'b d c h w -> b d h w c')
1166
+ for blk in self.blocks:
1167
+ if self.use_checkpoint:
1168
+ x = checkpoint.checkpoint(blk, x)
1169
+ else:
1170
+ x = blk(x)
1171
+
1172
+ if self.downsample is not None:
1173
+ x = x.view(x.shape[0], self.input_resolution[0], self.input_resolution[1], -1).permute(0, 3, 1, 2).contiguous()
1174
+ x = self.downsample(x)
1175
+ x = rearrange(x, 'b d h w c -> b d c h w')
1176
+ return x
1177
+
1178
+ def extra_repr(self) -> str:
1179
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
1180
+
1181
+ def flops(self):
1182
+ flops = 0
1183
+ for blk in self.blocks:
1184
+ flops += blk.flops()
1185
+ if self.downsample is not None:
1186
+ flops += self.downsample.flops()
1187
+ return flops
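When use_checkpoint is set, BasicLayer3d3.forward above wraps each block in torch.utils.checkpoint so activations are recomputed during backward instead of being stored. A hedged sketch of the same pattern with toy blocks standing in for CffmTransformerBlock3d3:

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint

blocks = nn.ModuleList([nn.Sequential(nn.Linear(64, 64), nn.GELU()) for _ in range(4)])
x = torch.randn(8, 64, requires_grad=True)

use_checkpoint = True
for blk in blocks:
    if use_checkpoint:
        x = checkpoint.checkpoint(blk, x)   # blk is re-run during backward
    else:
        x = blk(x)
x.sum().backward()                          # memory is traded for an extra forward pass per block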
models/SpaTrackV2/models/depth_refiner/stablizer.py ADDED
@@ -0,0 +1,342 @@
1
+ import numpy as np
2
+ import torch.nn as nn
3
+ import torch
4
+ # from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
5
+ from collections import OrderedDict
6
+ # from mmseg.ops import resize
7
+ from torch.nn.functional import interpolate as resize
8
+ # from builder import HEADS
9
+ from models.SpaTrackV2.models.depth_refiner.decode_head import BaseDecodeHead, BaseDecodeHead_clips, BaseDecodeHead_clips_flow
10
+ # from mmseg.models.utils import *
11
+ import attr
12
+ from IPython import embed
13
+ from models.SpaTrackV2.models.depth_refiner.stablilization_attention import BasicLayer3d3
14
+ import cv2
15
+ from models.SpaTrackV2.models.depth_refiner.network import *
16
+ import warnings
17
+ # from mmcv.utils import Registry, build_from_cfg
18
+ from torch import nn
19
+ from einops import rearrange
20
+ import torch.nn.functional as F
21
+ from models.SpaTrackV2.models.blocks import (
22
+ AttnBlock, CrossAttnBlock, Mlp
23
+ )
24
+
25
+ class MLP(nn.Module):
26
+ """
27
+ Linear Embedding
28
+ """
29
+ def __init__(self, input_dim=2048, embed_dim=768):
30
+ super().__init__()
31
+ self.proj = nn.Linear(input_dim, embed_dim)
32
+
33
+ def forward(self, x):
34
+ x = x.flatten(2).transpose(1, 2)
35
+ x = self.proj(x)
36
+ return x
37
+
38
+
39
+ def scatter_multiscale_fast(
40
+ track2d: torch.Tensor,
41
+ trackfeature: torch.Tensor,
42
+ H: int,
43
+ W: int,
44
+ kernel_sizes = [1]
45
+ ) -> torch.Tensor:
46
+ """
47
+ Scatter sparse track features onto a dense image grid with weighted multi-scale pooling to handle zero-value gaps.
48
+
49
+ This function scatters sparse track features into a dense image grid and applies multi-scale average pooling
50
+ while excluding zero-value holes. The weight mask ensures that only valid feature regions contribute to the pooling,
51
+ avoiding dilution by empty pixels.
52
+
53
+ Args:
54
+ track2d (torch.Tensor): Float tensor of shape (B, T, N, 2) containing (x, y) pixel coordinates
55
+ for each track point across batches, frames, and points.
56
+ trackfeature (torch.Tensor): Float tensor of shape (B, T, N, C) with C-dimensional features
57
+ for each track point.
58
+ H (int): Height of the target output image.
59
+ W (int): Width of the target output image.
60
+        kernel_sizes (List[int]): Odd average-pooling kernel sizes used for the multi-scale blur. Default: [1].
61
+
62
+ Returns:
63
+ torch.Tensor: Multi-scale fused feature map of shape (B, T, C, H, W) with hole-resistant pooling.
64
+ """
65
+ B, T, N, C = trackfeature.shape
66
+ device = trackfeature.device
67
+
68
+ # 1. Flatten coordinates and filter valid points within image bounds
69
+ coords_flat = track2d.round().long().reshape(-1, 2) # (B*T*N, 2)
70
+ x = coords_flat[:, 0] # x coordinates
71
+ y = coords_flat[:, 1] # y coordinates
72
+ feat_flat = trackfeature.reshape(-1, C) # Flatten features
73
+
74
+ valid_mask = (x >= 0) & (x < W) & (y >= 0) & (y < H)
75
+ x = x[valid_mask]
76
+ y = y[valid_mask]
77
+ feat_flat = feat_flat[valid_mask]
78
+ valid_count = x.shape[0]
79
+
80
+ if valid_count == 0:
81
+ return torch.zeros(B, T, C, H, W, device=device) # Handle no-valid-point case
82
+
83
+ # 2. Calculate linear indices and batch-frame indices for scattering
84
+ lin_idx = y * W + x # Linear index within a single frame (H*W range)
85
+
86
+ # Generate batch-frame indices (e.g., 0~B*T-1 for each frame in batch)
87
+ bt_idx_raw = (
88
+ torch.arange(B * T, device=device)
89
+ .view(B, T, 1)
90
+ .expand(B, T, N)
91
+ .reshape(-1)
92
+ )
93
+ bt_idx = bt_idx_raw[valid_mask] # Indices for valid points across batch and frames
94
+
95
+ # 3. Create accumulation buffers for features and weights
96
+ total_space = B * T * H * W
97
+ img_accum_flat = torch.zeros(total_space, C, device=device) # Feature accumulator
98
+ weight_accum_flat = torch.zeros(total_space, 1, device=device) # Weight accumulator (counts)
99
+
100
+ # 4. Scatter features and weights into accumulation buffers
101
+ idx_in_accum = bt_idx * (H * W) + lin_idx # Global index: batch_frame * H*W + pixel_index
102
+
103
+ # Add features to corresponding indices (index_add_ is efficient for sparse updates)
104
+ img_accum_flat.index_add_(0, idx_in_accum, feat_flat)
105
+ weight_accum_flat.index_add_(0, idx_in_accum, torch.ones((valid_count, 1), device=device))
106
+
107
+ # 5. Normalize features by valid weights, keep zeros for invalid regions
108
+ valid_mask_flat = weight_accum_flat > 0 # Binary mask for valid pixels
109
+ img_accum_flat = img_accum_flat / (weight_accum_flat + 1e-6) # Avoid division by zero
110
+ img_accum_flat = img_accum_flat * valid_mask_flat.float() # Mask out invalid regions
111
+
112
+ # 6. Reshape to (B, T, C, H, W) for further processing
113
+ img = (
114
+ img_accum_flat.view(B, T, H, W, C)
115
+ .permute(0, 1, 4, 2, 3)
116
+ .contiguous()
117
+ ) # Shape: (B, T, C, H, W)
118
+
119
+ # 7. Multi-scale pooling with weight masking to exclude zero holes
120
+ blurred_outputs = []
121
+ for k in kernel_sizes:
122
+ pad = k // 2
123
+ img_bt = img.view(B*T, C, H, W) # Flatten batch and time for pooling
124
+
125
+ # Create weight mask for valid regions (1 where features exist, 0 otherwise)
126
+ weight_mask = (
127
+ weight_accum_flat.view(B, T, 1, H, W) > 0
128
+ ).float().view(B*T, 1, H, W) # Shape: (B*T, 1, H, W)
129
+
130
+ # Calculate number of valid neighbors in each pooling window
131
+ weight_sum = F.conv2d(
132
+ weight_mask,
133
+ torch.ones((1, 1, k, k), device=device),
134
+ stride=1,
135
+ padding=pad
136
+ ) # Shape: (B*T, 1, H, W)
137
+
138
+ # Sum features only in valid regions
139
+ feat_sum = F.conv2d(
140
+ img_bt * weight_mask, # Mask out invalid regions before summing
141
+ torch.ones((1, 1, k, k), device=device).expand(C, 1, k, k),
142
+ stride=1,
143
+ padding=pad,
144
+ groups=C
145
+ ) # Shape: (B*T, C, H, W)
146
+
147
+ # Compute average only over valid neighbors
148
+ feat_avg = feat_sum / (weight_sum + 1e-6)
149
+ blurred_outputs.append(feat_avg)
150
+
151
+ # 8. Fuse multi-scale results by averaging across kernel sizes
152
+ fused = torch.stack(blurred_outputs).mean(dim=0) # Average over kernel sizes
153
+ return fused.view(B, T, C, H, W) # Restore original shape
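A hedged, toy-sized sketch of the two ideas used in this function: index_add_ accumulates per-point features plus a hit count for every pixel, and the multi-scale blur divides a summed feature map by a convolved validity mask so empty pixels do not dilute the average.

import torch
import torch.nn.functional as F

H, W, C, N, k = 6, 8, 3, 5, 3
coords = torch.tensor([[1, 1], [1, 1], [4, 2], [6, 5], [0, 3]])        # (x, y) per point
feats = torch.randn(N, C)

# scatter: accumulate features and hit counts at each pixel, then average
lin_idx = coords[:, 1] * W + coords[:, 0]
feat_accum = torch.zeros(H * W, C).index_add_(0, lin_idx, feats)
count = torch.zeros(H * W, 1).index_add_(0, lin_idx, torch.ones(N, 1))
dense = (feat_accum / (count + 1e-6) * (count > 0).float()).view(H, W, C)
dense = dense.permute(2, 0, 1)[None]                                    # (1, C, H, W)

# hole-aware average: sum features and valid-pixel counts with an all-ones kernel
mask = (count > 0).float().view(1, 1, H, W)
valid = F.conv2d(mask, torch.ones(1, 1, k, k), padding=k // 2)
summed = F.conv2d(dense * mask, torch.ones(C, 1, k, k), padding=k // 2, groups=C)
blurred = summed / (valid + 1e-6)                                       # (1, C, H, W)
print(blurred.shape)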
154
+
155
+ #@HEADS.register_module()
156
+ class Stabilization_Network_Cross_Attention(BaseDecodeHead_clips_flow):
157
+
158
+ def __init__(self, feature_strides, **kwargs):
159
+ super(Stabilization_Network_Cross_Attention, self).__init__(input_transform='multiple_select', **kwargs)
160
+ self.training = False
161
+ assert len(feature_strides) == len(self.in_channels)
162
+ assert min(feature_strides) == feature_strides[0]
163
+ self.feature_strides = feature_strides
164
+
165
+ c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
166
+
167
+ decoder_params = kwargs['decoder_params']
168
+ embedding_dim = decoder_params['embed_dim']
169
+
170
+ self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
171
+ self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
172
+ self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
173
+ self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)
174
+
175
+ self.linear_fuse = nn.Sequential(nn.Conv2d(embedding_dim*4, embedding_dim, kernel_size=(1, 1), stride=(1, 1), bias=False),\
176
+ nn.ReLU(inplace=True))
177
+
178
+ self.proj_track = nn.Conv2d(100, 128, kernel_size=(1, 1), stride=(1, 1), bias=True)
179
+
180
+ depths = decoder_params['depths']
181
+
182
+ self.reg_tokens = nn.Parameter(torch.zeros(1, 2, embedding_dim))
183
+ self.global_patch = nn.Conv2d(embedding_dim, embedding_dim, kernel_size=(8, 8), stride=(8, 8), bias=True)
184
+
185
+ self.att_temporal = nn.ModuleList(
186
+ [
187
+ AttnBlock(embedding_dim, 8,
188
+ mlp_ratio=4, flash=True, ckpt_fwd=True)
189
+ for _ in range(8)
190
+ ]
191
+ )
192
+ self.att_spatial = nn.ModuleList(
193
+ [
194
+ AttnBlock(embedding_dim, 8,
195
+ mlp_ratio=4, flash=True, ckpt_fwd=True)
196
+ for _ in range(8)
197
+ ]
198
+ )
199
+ self.scale_shift_head = nn.Sequential(nn.Linear(embedding_dim, embedding_dim), nn.GELU(), nn.Linear(embedding_dim, 4))
200
+
201
+
202
+ # Initialize reg tokens
203
+ nn.init.trunc_normal_(self.reg_tokens, std=0.02)
204
+
205
+ self.decoder_focal=BasicLayer3d3(dim=embedding_dim,
206
+ input_resolution=(96,
207
+ 96),
208
+ depth=depths,
209
+ num_heads=8,
210
+ window_size=7,
211
+ mlp_ratio=4.,
212
+ qkv_bias=True,
213
+ qk_scale=None,
214
+ drop=0.,
215
+ attn_drop=0.,
216
+ drop_path=0.,
217
+ norm_layer=nn.LayerNorm,
218
+ pool_method='fc',
219
+ downsample=None,
220
+ focal_level=2,
221
+ focal_window=5,
222
+ expand_size=3,
223
+ expand_layer="all",
224
+ use_conv_embed=False,
225
+ use_shift=False,
226
+ use_pre_norm=False,
227
+ use_checkpoint=False,
228
+ use_layerscale=False,
229
+ layerscale_value=1e-4,
230
+ focal_l_clips=[7,4,2],
231
+ focal_kernel_clips=[7,5,3])
232
+
233
+ self.ffm2 = FFM(inchannels= 256, midchannels= 256, outchannels = 128)
234
+ self.ffm1 = FFM(inchannels= 128, midchannels= 128, outchannels = 64)
235
+ self.ffm0 = FFM(inchannels= 64, midchannels= 64, outchannels = 32,upfactor=1)
236
+ self.AO = AO(32, outchannels=3, upfactor=1)
237
+ self._c2 = None
238
+ self._c_further = None
239
+
240
+ def buffer_forward(self, inputs, num_clips=None, imgs=None):#,infermode=1):
241
+
242
+ # input: B T 7 H W (7 means 3 rgb + 3 pointmap + 1 uncertainty) normalized
243
+ if self.training:
244
+ assert self.num_clips==num_clips
245
+
246
+ x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32
247
+ c1, c2, c3, c4 = x
248
+
249
+ ############## MLP decoder on C1-C4 ###########
250
+ n, _, h, w = c4.shape
251
+ batch_size = n // num_clips
252
+
253
+ _c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
254
+ _c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False)
255
+
256
+ _c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
257
+ _c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False)
258
+
259
+ _c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
260
+ _c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False)
261
+
262
+ _c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])
263
+ _c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
264
+
265
+ _, _, h, w=_c.shape
266
+ _c_further=_c.reshape(batch_size, num_clips, -1, h, w) #h2w2
267
+
268
+ # Expand reg_tokens to match batch size
269
+ reg_tokens = self.reg_tokens.expand(batch_size*num_clips, -1, -1) # [B, 2, C]
270
+
271
+ _c2=self.decoder_focal(_c_further, batch_size=batch_size, num_clips=num_clips, reg_tokens=reg_tokens)
272
+
273
+ assert _c_further.shape==_c2.shape
274
+ self._c2 = _c2
275
+ self._c_further = _c_further
276
+
277
+ # compute the scale and shift of the global patch
278
+ global_patch = self.global_patch(_c2.view(batch_size*num_clips, -1, h, w)).view(batch_size*num_clips, _c2.shape[2], -1).permute(0,2,1)
279
+ global_patch = torch.cat([global_patch, reg_tokens], dim=1)
280
+ for i in range(8):
281
+ global_patch = self.att_temporal[i](global_patch)
282
+ global_patch = rearrange(global_patch, '(b t) n c -> (b n) t c', b=batch_size, t=num_clips, c=_c2.shape[2])
283
+ global_patch = self.att_spatial[i](global_patch)
284
+ global_patch = rearrange(global_patch, '(b n) t c -> (b t) n c', b=batch_size, t=num_clips, c=_c2.shape[2])
285
+
286
+ reg_tokens = global_patch[:, -2:, :]
287
+ s_ = self.scale_shift_head(reg_tokens)
288
+ scale = 1 + s_[:, 0, :1].view(batch_size, num_clips, 1, 1, 1)
289
+ shift = s_[:, 1, 1:].view(batch_size, num_clips, 3, 1, 1)
290
+ shift[:,:,:2,...] = 0
291
+ return scale, shift
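buffer_forward returns a per-frame scale of shape (B, T, 1, 1, 1) and a shift of shape (B, T, 3, 1, 1) whose x/y channels are zeroed, so only the depth channel is shifted. The consumer of these values sits outside this file, so the following is only an assumed illustration of how such a per-frame affine correction would typically be applied to a pointmap:

import torch

B, T, H, W = 2, 4, 48, 64
pointmap = torch.randn(B, T, 3, H, W)             # per-pixel (x, y, z)
scale = 1 + 0.05 * torch.randn(B, T, 1, 1, 1)     # stand-ins for the head outputs
shift = torch.zeros(B, T, 3, 1, 1)
shift[:, :, 2:] = 0.1 * torch.randn(B, T, 1, 1, 1)

refined = scale * pointmap + shift                # broadcast over channels and pixels
print(refined.shape)                              # (B, T, 3, H, W)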
292
+
293
+ def forward(self, inputs, edge_feat, edge_feat1, tracks, tracks_uvd, num_clips=None, imgs=None, vis_track=None):#,infermode=1):
294
+
295
+ if self._c2 is None:
296
+ scale, shift = self.buffer_forward(inputs,num_clips,imgs)
297
+
298
+ B, T, N, _ = tracks.shape
299
+
300
+ _c2 = self._c2
301
+ _c_further = self._c_further
302
+
303
+ # skip and head
304
+ _c_further = rearrange(_c_further, 'b t c h w -> (b t) c h w', b=B, t=T)
305
+ _c2 = rearrange(_c2, 'b t c h w -> (b t) c h w', b=B, t=T)
306
+
307
+ outframe = self.ffm2(_c_further, _c2)
308
+
309
+ tracks_uv = tracks_uvd[...,:2].clone()
310
+ track_feature = scatter_multiscale_fast(tracks_uv/2, tracks, outframe.shape[-2], outframe.shape[-1], kernel_sizes=[1, 3, 5])
311
+ # visualize track_feature as video
312
+ # import cv2
313
+ # import imageio
314
+ # import os
315
+ # BT, C, H, W = outframe.shape
316
+ # track_feature_vis = track_feature.view(B, T, 3, H, W).float().detach().cpu().numpy()
317
+ # track_feature_vis = track_feature_vis.transpose(0,1,3,4,2)
318
+ # track_feature_vis = (track_feature_vis - track_feature_vis.min()) / (track_feature_vis.max() - track_feature_vis.min() + 1e-6)
319
+ # track_feature_vis = (track_feature_vis * 255).astype(np.uint8)
320
+ # imgs =(imgs.detach() + 1) * 127.5
321
+ # vis_track.visualize(video=imgs, tracks=tracks_uv, filename="test")
322
+ # for b in range(B):
323
+ # frames = []
324
+ # for t in range(T):
325
+ # frame = track_feature_vis[b,t]
326
+ # frame = cv2.applyColorMap(frame[...,0], cv2.COLORMAP_JET)
327
+ # frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
328
+ # frames.append(frame)
329
+ # # Save as gif
330
+ # imageio.mimsave(f'track_feature_b{b}.gif', frames, duration=0.1)
331
+ # import pdb; pdb.set_trace()
332
+ track_feature = rearrange(track_feature, 'b t c h w -> (b t) c h w')
333
+ track_feature = self.proj_track(track_feature)
334
+ outframe = self.ffm1(edge_feat1 + track_feature,outframe)
335
+ outframe = self.ffm0(edge_feat,outframe)
336
+ outframe = self.AO(outframe)
337
+
338
+ return outframe
339
+
340
+ def reset_success(self):
341
+ self._c2 = None
342
+ self._c_further = None
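The att_temporal / att_spatial loop in buffer_forward alternates attention over the tokens of each frame with attention across frames for each token, using einops to regroup the batch axis. A standalone sketch of that regrouping with a vanilla nn.MultiheadAttention standing in for the repo's AttnBlock (an illustration, not the repository block):

import torch
import torch.nn as nn
from einops import rearrange

B, T, N, C = 2, 4, 146, 64                            # e.g. pooled patches + 2 reg tokens per frame
attn = nn.MultiheadAttention(C, num_heads=8, batch_first=True)

x = torch.randn(B * T, N, C)                          # (b t) n c: tokens of each frame
x, _ = attn(x, x, x)                                  # attend within a frame
x = rearrange(x, '(b t) n c -> (b n) t c', b=B, t=T)  # regroup by token index
x, _ = attn(x, x, x)                                  # attend across the T frames
x = rearrange(x, '(b n) t c -> (b t) n c', b=B, t=T)
print(x.shape)                                        # back to (B*T, N, C)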
models/SpaTrackV2/models/predictor.py ADDED
@@ -0,0 +1,153 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+
10
+ from tqdm import tqdm
11
+ from models.SpaTrackV2.models.SpaTrack import SpaTrack2
12
+ from typing import Literal
13
+ import numpy as np
14
+ from pathlib import Path
15
+ from typing import Union, Optional
16
+ import cv2
17
+ import os
18
+ import decord
19
+
20
+ class Predictor(torch.nn.Module):
21
+ def __init__(self, args=None):
22
+ super().__init__()
23
+ self.args = args
24
+ self.spatrack = SpaTrack2(loggers=[None, None, None], **args)
25
+ self.S_wind = args.Track_cfg.s_wind
26
+ self.overlap = args.Track_cfg.overlap
27
+
28
+ def to(self, device: Union[str, torch.device]):
29
+ self.spatrack.to(device)
30
+ self.spatrack.base_model.to(device)
31
+
32
+ @classmethod
33
+ def from_pretrained(
34
+ cls,
35
+ pretrained_model_name_or_path: Union[str, Path],
36
+ *,
37
+ force_download: bool = False,
38
+ cache_dir: Optional[str] = None,
39
+ device: Optional[Union[str, torch.device]] = None,
40
+ model_cfg: Optional[dict] = None,
41
+ **kwargs,
42
+    ) -> "Predictor":
43
+ """
44
+ Load a pretrained model from a local file or a remote repository.
45
+
46
+ Args:
47
+ pretrained_model_name_or_path (str or Path):
48
+ - Path to a local model file (e.g., `./model.pth`).
49
+ - HuggingFace Hub model ID (e.g., `username/model-name`).
50
+ force_download (bool, optional):
51
+ Whether to force re-download even if cached. Default: False.
52
+ cache_dir (str, optional):
53
+ Custom cache directory. Default: None (use default cache).
54
+ device (str or torch.device, optional):
55
+ Target device (e.g., "cuda", "cpu"). Default: None (keep original).
56
+ **kwargs:
57
+ Additional config overrides.
58
+
59
+ Returns:
60
+            Predictor: Loaded predictor wrapping the pretrained SpaTrack2 weights.
61
+ """
62
+ # (1) check the path is local or remote
63
+ if isinstance(pretrained_model_name_or_path, Path):
64
+ model_path = str(pretrained_model_name_or_path)
65
+ else:
66
+ model_path = pretrained_model_name_or_path
67
+ # (2) if the path is remote, download it
68
+ if not os.path.exists(model_path):
69
+ raise NotImplementedError("Remote download not implemented yet. Use a local path.")
70
+ # (3) load the model weights
71
+
72
+ state_dict = torch.load(model_path, map_location="cpu")
73
+ # (4) initialize the model (can load config.json if exists)
74
+ config_path = os.path.join(os.path.dirname(model_path), "config.json")
75
+ config = {}
76
+ if os.path.exists(config_path):
77
+ import json
78
+ with open(config_path, "r") as f:
79
+ config.update(json.load(f))
80
+ config.update(kwargs) # allow override the config
81
+ if model_cfg is not None:
82
+ config = model_cfg
83
+ model = cls(config)
84
+ if "model" in state_dict:
85
+ model.spatrack.load_state_dict(state_dict["model"], strict=False)
86
+ else:
87
+ model.spatrack.load_state_dict(state_dict, strict=False)
88
+ # (5) device management
89
+ if device is not None:
90
+ model.to(device)
91
+
92
+ return model
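A hedged usage sketch for the loader above. The checkpoint path is a placeholder, remote paths raise NotImplementedError, and model_cfg is assumed to be the OmegaConf config the repo builds elsewhere (it must expose the attributes read in __init__, e.g. Track_cfg.s_wind and Track_cfg.overlap):

import torch
from omegaconf import OmegaConf
from models.SpaTrackV2.models.predictor import Predictor

cfg = OmegaConf.load("config/magic_infer_moge.yaml")          # assumed config file from this repo
model = Predictor.from_pretrained(
    "./checkpoints/SpaTrack2.pth",                            # hypothetical local checkpoint path
    model_cfg=cfg.model if "model" in cfg else cfg,           # whichever node holds the model args
    device="cuda" if torch.cuda.is_available() else "cpu",
)
model.spatrack.eval()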
93
+
94
+ def forward(self, video: str|torch.Tensor|np.ndarray,
95
+ depth: str|torch.Tensor|np.ndarray=None,
96
+ unc_metric: str|torch.Tensor|np.ndarray=None,
97
+ intrs: str|torch.Tensor|np.ndarray=None,
98
+ extrs: str|torch.Tensor|np.ndarray=None,
99
+ queries=None, queries_3d=None, iters_track=4,
100
+ full_point=False, fps=30, track2d_gt=None,
101
+ fixed_cam=False, query_no_BA=False, stage=0,
102
+ support_frame=0, replace_ratio=0.6):
103
+ """
104
+ video: this could be a path to a video, a tensor of shape (T, C, H, W) or a numpy array of shape (T, C, H, W)
105
+ queries: (B, N, 2)
106
+ """
107
+
108
+ if isinstance(video, str):
109
+ video = decord.VideoReader(video)
110
+ video = video[::fps].asnumpy() # Convert to numpy array
111
+ video = np.array(video) # Ensure numpy array
112
+ video = torch.from_numpy(video).permute(0, 3, 1, 2).float()
113
+ elif isinstance(video, np.ndarray):
114
+ video = torch.from_numpy(video).float()
115
+
116
+ if isinstance(depth, np.ndarray):
117
+ depth = torch.from_numpy(depth).float()
118
+ if isinstance(intrs, np.ndarray):
119
+ intrs = torch.from_numpy(intrs).float()
120
+ if isinstance(extrs, np.ndarray):
121
+ extrs = torch.from_numpy(extrs).float()
122
+ if isinstance(unc_metric, np.ndarray):
123
+ unc_metric = torch.from_numpy(unc_metric).float()
124
+
125
+ T_, C, H, W = video.shape
126
+ step_slide = self.S_wind - self.overlap
127
+ if T_ > self.S_wind:
128
+
129
+ num_windows = (T_ - self.S_wind + step_slide) // step_slide
130
+ T = num_windows * step_slide + self.S_wind
131
+ pad_len = T - T_
132
+
133
+ video = torch.cat([video, video[-1:].repeat(T-video.shape[0], 1, 1, 1)], dim=0)
134
+ if depth is not None:
135
+ depth = torch.cat([depth, depth[-1:].repeat(T-depth.shape[0], 1, 1)], dim=0)
136
+ if intrs is not None:
137
+ intrs = torch.cat([intrs, intrs[-1:].repeat(T-intrs.shape[0], 1, 1)], dim=0)
138
+ if extrs is not None:
139
+ extrs = torch.cat([extrs, extrs[-1:].repeat(T-extrs.shape[0], 1, 1)], dim=0)
140
+ if unc_metric is not None:
141
+ unc_metric = torch.cat([unc_metric, unc_metric[-1:].repeat(T-unc_metric.shape[0], 1, 1)], dim=0)
142
+ with torch.no_grad():
143
+ ret = self.spatrack.forward_stream(video, queries, T_org=T_,
144
+ depth=depth, intrs=intrs, unc_metric_in=unc_metric, extrs=extrs, queries_3d=queries_3d,
145
+ window_len=self.S_wind, overlap_len=self.overlap, track2d_gt=track2d_gt, full_point=full_point, iters_track=iters_track,
146
+ fixed_cam=fixed_cam, query_no_BA=query_no_BA, stage=stage, support_frame=support_frame, replace_ratio=replace_ratio) + (video[:T_],)
147
+
148
+
149
+ return ret
150
+
151
+
152
+
153
+
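A minimal call sketch for the forward pass above, reusing the `model` instance from the loader sketch earlier (the video path and query pixels are placeholders; note that `fps` acts as a frame stride when a path is given):

import torch
queries = torch.tensor([[[320.0, 180.0], [100.0, 200.0]]])    # (B, N, 2) pixel coordinates
with torch.no_grad():
    ret = model("path/to/video.mp4", queries=queries, fps=1)  # fps=1 keeps every decoded frame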
models/SpaTrackV2/models/tracker3D/TrackRefiner.py ADDED
@@ -0,0 +1,1478 @@
1
+ import os, sys
2
+ import torch
3
+ import torch.amp
4
+ from models.SpaTrackV2.models.tracker3D.co_tracker.cotracker_base import CoTrackerThreeOffline, get_1d_sincos_pos_embed_from_grid
5
+ import torch.nn.functional as F
6
+ from models.SpaTrackV2.utils.visualizer import Visualizer
7
+ from models.SpaTrackV2.utils.model_utils import sample_features5d
8
+ from models.SpaTrackV2.models.blocks import bilinear_sampler
9
+ import torch.nn as nn
10
+ from models.SpaTrackV2.models.tracker3D.co_tracker.utils import (
11
+ EfficientUpdateFormer, AttnBlock, Attention, CrossAttnBlock,
12
+ sequence_BCE_loss, sequence_loss, sequence_prob_loss, sequence_dyn_prob_loss, sequence_loss_xyz, balanced_binary_cross_entropy
13
+ )
14
+ from torchvision.io import write_video
15
+ import math
16
+ from models.SpaTrackV2.models.tracker3D.co_tracker.utils import (
17
+ Mlp, BasicEncoder, EfficientUpdateFormer, GeometryEncoder, NeighborTransformer, CorrPointformer
18
+ )
19
+ from models.SpaTrackV2.utils.embeddings import get_3d_sincos_pos_embed_from_grid
20
+ from einops import rearrange, repeat
21
+ from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import (
22
+ EfficientUpdateFormer3D, weighted_procrustes_torch, posenc, key_fr_wprocrustes, get_topo_mask,
23
+ TrackFusion, get_nth_visible_time_index
24
+ )
25
+ from models.SpaTrackV2.models.tracker3D.spatrack_modules.ba import extract_static_from_3DTracks, ba_pycolmap
26
+ from models.SpaTrackV2.models.tracker3D.spatrack_modules.pointmap_updator import PointMapUpdator
27
+ from models.SpaTrackV2.models.depth_refiner.depth_refiner import TrackStablizer
28
+ from models.SpaTrackV2.models.tracker3D.spatrack_modules.alignment import affine_invariant_global_loss
29
+ from models.SpaTrackV2.models.tracker3D.delta_utils.upsample_transformer import UpsampleTransformerAlibi
30
+
31
+ class TrackRefiner3D(CoTrackerThreeOffline):
32
+
33
+ def __init__(self, args=None):
34
+ super().__init__(**args.base)
35
+
36
+ """
37
+ This is a 3D wrapper around CoTracker: it loads the pretrained CoTracker weights and
38
+ jointly refines the `camera pose`, `3D tracks`, `video depth`, `visibility` and `conf`.
39
+ """
40
+ self.updateformer3D = EfficientUpdateFormer3D(self.updateformer)
41
+ self.corr_depth_mlp = Mlp(in_features=256, hidden_features=256, out_features=256)
42
+ self.rel_pos_mlp = Mlp(in_features=75, hidden_features=128, out_features=128)
43
+ self.rel_pos_glob_mlp = Mlp(in_features=75, hidden_features=128, out_features=256)
44
+ self.corr_xyz_mlp = Mlp(in_features=256, hidden_features=128, out_features=128)
45
+ self.xyz_mlp = Mlp(in_features=126, hidden_features=128, out_features=84)
46
+ # self.track_feat_mlp = Mlp(in_features=1110, hidden_features=128, out_features=128)
47
+ self.proj_xyz_embed = Mlp(in_features=1210+50, hidden_features=1110, out_features=1110)
48
+ # get the anchor point's embedding, and init the pts refiner
49
+ update_pts = True
50
+ # self.corr_transformer = nn.ModuleList([
51
+ # CorrPointformer(
52
+ # dim=128,
53
+ # num_heads=8,
54
+ # head_dim=128 // 8,
55
+ # mlp_ratio=4.0,
56
+ # )
57
+ # for _ in range(self.corr_levels)
58
+ # ])
59
+ self.corr_transformer = nn.ModuleList([
60
+ CorrPointformer(
61
+ dim=128,
62
+ num_heads=8,
63
+ head_dim=128 // 8,
64
+ mlp_ratio=4.0,
65
+ )
66
+ ]
67
+ )
68
+ self.fnet = BasicEncoder(input_dim=3,
69
+ output_dim=self.latent_dim, stride=self.stride)
70
+ self.corr3d_radius = 3
71
+
72
+ if args.stablizer:
73
+ self.scale_shift_tokens = nn.Parameter(torch.randn(1, 2, self.latent_dim, requires_grad=True))
74
+ self.upsample_kernel_size = 5
75
+ self.residual_embedding = nn.Parameter(torch.randn(
76
+ self.latent_dim, self.model_resolution[0]//16,
77
+ self.model_resolution[1]//16, requires_grad=True))
78
+ self.dense_mlp = nn.Conv2d(2*self.latent_dim+63, self.latent_dim, kernel_size=1, stride=1, padding=0)
79
+ self.upsample_factor = 4
80
+ self.upsample_transformer = UpsampleTransformerAlibi(
81
+ kernel_size=self.upsample_kernel_size, # kernel_size=3, #
82
+ stride=self.stride,
83
+ latent_dim=self.latent_dim,
84
+ num_attn_blocks=2,
85
+ upsample_factor=4,
86
+ )
87
+ else:
88
+ self.update_pointmap = None
89
+
90
+ self.mode = args.mode
91
+ if self.mode == "online":
92
+ self.s_wind = args.s_wind
93
+ self.overlap = args.overlap
94
+
95
+ def upsample_with_mask(
96
+ self, inp: torch.Tensor, mask: torch.Tensor
97
+ ) -> torch.Tensor:
98
+ """Upsample a dense field (B, C, H, W) -> (B, C, factor*H, factor*W) using a learned convex combination over k x k neighborhoods"""
99
+ H, W = inp.shape[-2:]
100
+ up_inp = F.unfold(
101
+ inp, [self.upsample_kernel_size, self.upsample_kernel_size], padding=(self.upsample_kernel_size - 1) // 2
102
+ )
103
+ up_inp = rearrange(up_inp, "b c (h w) -> b c h w", h=H, w=W)
104
+ up_inp = F.interpolate(up_inp, scale_factor=self.upsample_factor, mode="nearest")
105
+ up_inp = rearrange(
106
+ up_inp, "b (c i j) h w -> b c (i j) h w", i=self.upsample_kernel_size, j=self.upsample_kernel_size
107
+ )
108
+
109
+ up_inp = torch.sum(mask * up_inp, dim=2)
110
+ return up_inp
111
+
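Shape walk-through for upsample_with_mask above (k = upsample_kernel_size, f = upsample_factor); for the weighted sum to be a convex combination, the mask should be non-negative and normalized (e.g. softmaxed) over its k*k dimension:

#   inp  : (B, C, H, W)  --unfold-->  (B, C*k*k, H, W)  --interpolate x f-->  (B, C*k*k, f*H, f*W)
#   then reshaped to (B, C, k*k, f*H, f*W), multiplied by a mask broadcasting as (B, 1, k*k, f*H, f*W)
#   and summed over dim=2, giving an output of shape (B, C, f*H, f*W)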
112
+ def track_from_cam(self, queries, c2w_traj, intrs,
113
+ rgbs=None, visualize=False):
114
+ """
115
+ Initialize tracks by transforming the query points with the camera trajectory.
116
+
117
+ Args:
118
+ queries: B T N 4
119
+ c2w_traj: B T 4 4
120
+ intrs: B T 3 3
121
+ """
122
+ B, T, N, _ = queries.shape
123
+ query_t = queries[:,0,:,0].to(torch.int64) # B N
124
+ query_c2w = torch.gather(c2w_traj,
125
+ dim=1, index=query_t[..., None, None].expand(-1, -1, 4, 4)) # B N 4 4
126
+ query_intr = torch.gather(intrs,
127
+ dim=1, index=query_t[..., None, None].expand(-1, -1, 3, 3)) # B N 3 3
128
+ query_pts = queries[:,0,:,1:4].clone() # B N 3
129
+ query_d = queries[:,0,:,3:4] # B N 3
130
+ query_pts[...,2] = 1
131
+
132
+ cam_pts = torch.einsum("bnij,bnj->bni", torch.inverse(query_intr), query_pts)*query_d # B N 3
133
+ # convert to world
134
+ cam_pts_h = torch.zeros(B, N, 4, device=cam_pts.device)
135
+ cam_pts_h[..., :3] = cam_pts
136
+ cam_pts_h[..., 3] = 1
137
+ world_pts = torch.einsum("bnij,bnj->bni", query_c2w, cam_pts_h)
138
+ # convert to other frames
139
+ cam_other_pts_ = torch.einsum("btnij,btnj->btni",
140
+ torch.inverse(c2w_traj[:,:,None].float().repeat(1,1,N,1,1)),
141
+ world_pts[:,None].repeat(1,T,1,1))
142
+ cam_depth = cam_other_pts_[...,2:3]
143
+ cam_other_pts = cam_other_pts_[...,:3] / (cam_other_pts_[...,2:3].abs()+1e-6)
144
+ cam_other_pts = torch.einsum("btnij,btnj->btni", intrs[:,:,None].repeat(1,1,N,1,1), cam_other_pts[...,:3])
145
+ cam_other_pts[..., 2:] = cam_depth
146
+
147
+ if visualize:
148
+ viser = Visualizer(save_dir=".", grayscale=True,
149
+ fps=10, pad_value=50, tracks_leave_trace=0)
150
+ cam_other_pts[..., 0] /= self.factor_x
151
+ cam_other_pts[..., 1] /= self.factor_y
152
+ viser.visualize(video=rgbs, tracks=cam_other_pts[..., :2], filename="test")
153
+
154
+
155
+ init_xyzs = cam_other_pts
156
+
157
+ return init_xyzs, world_pts[..., :3], cam_other_pts_[..., :3]
158
+
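The chain implemented above, reduced to a single query point (the intrinsics K and the two camera-to-world poses are placeholder values, not model tensors): unproject the pixel with its depth, lift it to world space with the pose of its query frame, then express it in another frame and reproject:

import torch
K = torch.tensor([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]])
c2w_src, c2w_dst = torch.eye(4), torch.eye(4)               # camera-to-world poses (placeholders)
u, v, d = 100.0, 150.0, 2.0                                 # query pixel and its depth
cam_src = torch.inverse(K) @ torch.tensor([u, v, 1.0]) * d  # unproject into the source camera
world = c2w_src[:3, :3] @ cam_src + c2w_src[:3, 3]          # source camera -> world
w2c_dst = torch.inverse(c2w_dst)
cam_dst = w2c_dst[:3, :3] @ world + w2c_dst[:3, 3]          # world -> target camera
uv_dst = (K @ (cam_dst / cam_dst[2].abs()))[:2]             # reproject; abs() mirrors the sign handling above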
159
+ def cam_from_track(self, tracks, intrs,
160
+ dyn_prob=None, metric_unc=None,
161
+ vis_est=None, only_cam_pts=False,
162
+ track_feat_concat=None,
163
+ tracks_xyz=None,
164
+ query_pts=None,
165
+ fixed_cam=False,
166
+ depth_unproj=None,
167
+ cam_gt=None,
168
+ init_pose=False,
169
+ ):
170
+ """
171
+ Estimate the camera trajectory from the predicted tracks (optionally refining the 3D tracks and intrinsics).
172
+
173
+ Args:
174
+ tracks: B T N 3
175
+ scale_est: 1 1
176
+ shift_est: 1 1
177
+ intrs: B T 3 3
178
+ dyn_prob: B T N
179
+ metric_unc: B N 1
180
+ query_pts: B T N 3
181
+ """
182
+ if tracks_xyz is not None:
183
+ B, T, N, _ = tracks.shape
184
+ cam_pts = tracks_xyz
185
+ intr_repeat = intrs[:,:,None].repeat(1,1,N,1,1)
186
+ else:
187
+ B, T, N, _ = tracks.shape
188
+ # get the pts in cam coordinate
189
+ tracks_xy = tracks[...,:3].clone().detach() # B T N 3
190
+ # tracks_z = 1/(tracks[...,2:] * scale_est + shift_est) # B T N 1
191
+ tracks_z = tracks[...,2:].detach() # B T N 1
192
+ tracks_xy[...,2] = 1
193
+ intr_repeat = intrs[:,:,None].repeat(1,1,N,1,1)
194
+ cam_pts = torch.einsum("bnij,bnj->bni",
195
+ torch.inverse(intr_repeat.view(B*T,N,3,3)).float(),
196
+ tracks_xy.view(B*T, N, 3))*(tracks_z.view(B*T,N,1).abs()) # B*T N 3
197
+ cam_pts[...,2] *= torch.sign(tracks_z.view(B*T,N))
198
+ # get the normalized cam pts, and pts refiner
199
+ mask_z = (tracks_z.max(dim=1)[0]<200).squeeze()
200
+ cam_pts = cam_pts.view(B, T, N, 3)
201
+
202
+ if only_cam_pts:
203
+ return cam_pts
204
+ dyn_prob = dyn_prob.mean(dim=1)[..., None]
205
+ # B T N 3 -> local frames coordinates. transformer static points B T N 3 -> B T N 3 static (B T N 3) -> same -> dynamic points @ C2T.inverse()
206
+ # get the cam pose
207
+ vis_est_ = vis_est[:,:,None,:]
208
+ graph_matrix = (vis_est_*vis_est_.permute(0, 2,1,3)).detach()
209
+ # find the max connected component
210
+ key_fr_idx = [0]
211
+ weight_final = (metric_unc) # * vis_est
212
+
213
+
214
+ with torch.amp.autocast(enabled=False, device_type='cuda'):
215
+ if fixed_cam:
216
+ c2w_traj_init = self.c2w_est_curr
217
+ c2w_traj_glob = c2w_traj_init
218
+ cam_pts_refine = cam_pts
219
+ intrs_refine = intrs
220
+ xy_refine = query_pts[...,1:3]
221
+ world_tracks_init = torch.einsum("btij,btnj->btni", c2w_traj_init[:,:,:3,:3], cam_pts) + c2w_traj_init[:,:,None,:3,3]
222
+ world_tracks_refined = world_tracks_init
223
+ # extract the stable static points for refine the camera pose
224
+ intrs_dn = intrs.clone()
225
+ intrs_dn[...,0,:] *= self.factor_x
226
+ intrs_dn[...,1,:] *= self.factor_y
227
+ _, query_world_pts, _ = self.track_from_cam(query_pts, c2w_traj_init, intrs_dn)
228
+ world_tracks_static, mask_static, mask_topk, vis_mask_static, tracks2d_static = extract_static_from_3DTracks(world_tracks_init,
229
+ dyn_prob, query_world_pts,
230
+ vis_est, tracks, img_size=self.image_size,
231
+ K=0)
232
+ world_static_refine = world_tracks_static
233
+
234
+ else:
235
+
236
+ if (not self.training):
237
+ # if (self.c2w_est_curr==torch.eye(4, device=cam_pts.device).repeat(B, T, 1, 1)).all():
238
+ campts_update = torch.einsum("btij,btnj->btni", self.c2w_est_curr[...,:3,:3], cam_pts) + self.c2w_est_curr[...,None,:3,3]
239
+ # campts_update = cam_pts
240
+ c2w_traj_init_update = key_fr_wprocrustes(campts_update, graph_matrix,
241
+ (weight_final*(1-dyn_prob)).permute(0,2,1), vis_est_.permute(0,1,3,2))
242
+ c2w_traj_init = [email protected]_est_curr
243
+ # else:
244
+ # c2w_traj_init = self.c2w_est_curr # extract the stable static points for refine the camera pose
245
+ else:
246
+ # if (self.c2w_est_curr==torch.eye(4, device=cam_pts.device).repeat(B, T, 1, 1)).all():
247
+ campts_update = torch.einsum("btij,btnj->btni", self.c2w_est_curr[...,:3,:3], cam_pts) + self.c2w_est_curr[...,None,:3,3]
248
+ # campts_update = cam_pts
249
+ c2w_traj_init_update = key_fr_wprocrustes(campts_update, graph_matrix,
250
+ (weight_final*(1-dyn_prob)).permute(0,2,1), vis_est_.permute(0,1,3,2))
251
+ c2w_traj_init = [email protected]_est_curr
252
+ # else:
253
+ # c2w_traj_init = self.c2w_est_curr # extract the stable static points for refine the camera pose
254
+
255
+ intrs_dn = intrs.clone()
256
+ intrs_dn[...,0,:] *= self.factor_x
257
+ intrs_dn[...,1,:] *= self.factor_y
258
+ _, query_world_pts, _ = self.track_from_cam(query_pts, c2w_traj_init, intrs_dn)
259
+ # refine the world tracks
260
+ world_tracks_init = torch.einsum("btij,btnj->btni", c2w_traj_init[:,:,:3,:3], cam_pts) + c2w_traj_init[:,:,None,:3,3]
261
+ world_tracks_static, mask_static, mask_topk, vis_mask_static, tracks2d_static = extract_static_from_3DTracks(world_tracks_init,
262
+ dyn_prob, query_world_pts,
263
+ vis_est, tracks, img_size=self.image_size,
264
+ K=150 if self.training else 1500)
265
+ # calculate the efficient ba
266
+ cam_tracks_static = cam_pts[:,:,mask_static.squeeze(),:][:,:,mask_topk.squeeze(),:]
267
+ cam_tracks_static[...,2] = depth_unproj.view(B, T, N)[:,:,mask_static.squeeze()][:,:,mask_topk.squeeze()]
268
+
269
+ c2w_traj_glob, world_static_refine, intrs_refine = ba_pycolmap(world_tracks_static, intrs,
270
+ c2w_traj_init, vis_mask_static,
271
+ tracks2d_static, self.image_size,
272
+ cam_tracks_static=cam_tracks_static,
273
+ training=self.training, query_pts=query_pts)
274
+ c2w_traj_glob = c2w_traj_glob.view(B, T, 4, 4)
275
+ world_tracks_refined = world_tracks_init
276
+
277
+ #NOTE: merge the index of static points and topk points
278
+ # merge_idx = torch.where(mask_static.squeeze()>0)[0][mask_topk.squeeze()]
279
+ # world_tracks_refined[:,:,merge_idx] = world_static_refine
280
+
281
+ # test the procrustes
282
+ w2c_traj_glob = torch.inverse(c2w_traj_init.detach())
283
+ cam_pts_refine = torch.einsum("btij,btnj->btni", w2c_traj_glob[:,:,:3,:3], world_tracks_refined) + w2c_traj_glob[:,:,None,:3,3]
284
+ # get the xyz_refine
285
+ #TODO: refiner
286
+ cam_pts4_proj = cam_pts_refine.clone()
287
+ cam_pts4_proj[...,2] *= torch.sign(cam_pts4_proj[...,2:3].view(B*T,N))
288
+ xy_refine = torch.einsum("btnij,btnj->btni", intrs_refine.view(B,T,1,3,3).repeat(1,1,N,1,1), cam_pts4_proj/cam_pts4_proj[...,2:3].abs())
289
+ xy_refine[..., 2] = cam_pts4_proj[...,2:3].view(B*T,N)
290
+ # xy_refine = torch.zeros_like(cam_pts_refine)[...,:2]
291
+ return c2w_traj_glob, cam_pts_refine, intrs_refine, xy_refine, world_tracks_init, world_tracks_refined, c2w_traj_init
292
+
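The pose initialization above leans on a weighted Procrustes alignment of (mostly static) camera-frame points via `key_fr_wprocrustes`; as a reference for that sub-step, a standalone weighted Kabsch solver for one pair of point sets looks roughly like this (P, Q, w are placeholders, not the batched tensors used above):

import torch

def weighted_rigid_align(P, Q, w):
    """R, t minimizing sum_i w_i * ||R @ P_i + t - Q_i||^2, with P, Q: (N, 3) and w: (N,)."""
    w = w / w.sum()
    mu_p, mu_q = (w[:, None] * P).sum(0), (w[:, None] * Q).sum(0)
    X, Y = P - mu_p, Q - mu_q
    H = (w[:, None] * X).T @ Y                   # 3x3 weighted cross-covariance
    U, S, Vt = torch.linalg.svd(H)
    D = torch.eye(3, dtype=P.dtype)
    D[2, 2] = torch.sign(torch.det(Vt.T @ U.T))  # guard against reflections
    R = Vt.T @ D @ U.T
    t = mu_q - R @ mu_p
    return R, t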
293
+ def extract_img_feat(self, video, fmaps_chunk_size=200):
294
+ B, T, C, H, W = video.shape
295
+ dtype = video.dtype
296
+ H4, W4 = H // self.stride, W // self.stride
297
+ # Compute convolutional features for the video or for the current chunk in case of online mode
298
+ if T > fmaps_chunk_size:
299
+ fmaps = []
300
+ for t in range(0, T, fmaps_chunk_size):
301
+ video_chunk = video[:, t : t + fmaps_chunk_size]
302
+ fmaps_chunk = self.fnet(video_chunk.reshape(-1, C, H, W))
303
+ T_chunk = video_chunk.shape[1]
304
+ C_chunk, H_chunk, W_chunk = fmaps_chunk.shape[1:]
305
+ fmaps.append(fmaps_chunk.reshape(B, T_chunk, C_chunk, H_chunk, W_chunk))
306
+ fmaps = torch.cat(fmaps, dim=1).reshape(-1, C_chunk, H_chunk, W_chunk)
307
+ else:
308
+ fmaps = self.fnet(video.reshape(-1, C, H, W))
309
+ fmaps = fmaps.permute(0, 2, 3, 1)
310
+ fmaps = fmaps / torch.sqrt(
311
+ torch.maximum(
312
+ torch.sum(torch.square(fmaps), axis=-1, keepdims=True),
313
+ torch.tensor(1e-12, device=fmaps.device),
314
+ )
315
+ )
316
+ fmaps = fmaps.permute(0, 3, 1, 2).reshape(
317
+ B, -1, self.latent_dim, H // self.stride, W // self.stride
318
+ )
319
+ fmaps = fmaps.to(dtype)
320
+
321
+ return fmaps
322
+
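# Note on extract_img_feat above: the division by sqrt(max(sum(fmaps**2), 1e-12)) L2-normalizes each
# feature vector along the channel dimension; since sqrt(max(s, 1e-12)) == max(sqrt(s), 1e-6) for s >= 0,
# it is equivalent to F.normalize(fmaps, dim=-1, eps=1e-6) before the channels are permuted back.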
323
+ def norm_xyz(self, xyz):
324
+ """
325
+ xyz can be (B T N 3) or (B T 3 H W) or (B N 3)
326
+ """
327
+ if xyz.ndim == 3:
328
+ min_pts = self.min_pts
329
+ max_pts = self.max_pts
330
+ return (xyz - min_pts[None,None,:]) / (max_pts - min_pts)[None,None,:] * 2 - 1
331
+ elif xyz.ndim == 4:
332
+ min_pts = self.min_pts
333
+ max_pts = self.max_pts
334
+ return (xyz - min_pts[None,None,None,:]) / (max_pts - min_pts)[None,None,None,:] * 2 - 1
335
+ elif xyz.ndim == 5:
336
+ if xyz.shape[2] == 3:
337
+ min_pts = self.min_pts
338
+ max_pts = self.max_pts
339
+ return (xyz - min_pts[None,None,:,None,None]) / (max_pts - min_pts)[None,None,:,None,None] * 2 - 1
340
+ elif xyz.shape[-1] == 3:
341
+ min_pts = self.min_pts
342
+ max_pts = self.max_pts
343
+ return (xyz - min_pts[None,None,None,None,:]) / (max_pts - min_pts)[None,None,None,None,:] * 2 - 1
344
+
345
+ def denorm_xyz(self, xyz):
346
+ """
347
+ xyz can be (B T N 3) or (B T 3 H W) or (B N 3)
348
+ """
349
+ if xyz.ndim == 3:
350
+ min_pts = self.min_pts
351
+ max_pts = self.max_pts
352
+ return (xyz + 1) / 2 * (max_pts - min_pts)[None,None,:] + min_pts[None,None,:]
353
+ elif xyz.ndim == 4:
354
+ min_pts = self.min_pts
355
+ max_pts = self.max_pts
356
+ return (xyz + 1) / 2 * (max_pts - min_pts)[None,None,None,:] + min_pts[None,None,None,:]
357
+ elif xyz.ndim == 5:
358
+ if xyz.shape[2] == 3:
359
+ min_pts = self.min_pts
360
+ max_pts = self.max_pts
361
+ return (xyz + 1) / 2 * (max_pts - min_pts)[None,None,:,None,None] + min_pts[None,None,:,None,None]
362
+ elif xyz.shape[-1] == 3:
363
+ min_pts = self.min_pts
364
+ max_pts = self.max_pts
365
+ return (xyz + 1) / 2 * (max_pts - min_pts)[None,None,None,None,:] + min_pts[None,None,None,None,:]
366
+
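norm_xyz / denorm_xyz above map coordinates into [-1, 1] per axis using the min_pts / max_pts statistics cached in forward (mean ± 3·std of the query point map) and are exact inverses; a standalone mirror of the (..., 3) case, relying on broadcasting instead of the explicit ndim branches:

import torch

def norm_xyz(xyz, min_pts, max_pts):      # (..., 3) -> [-1, 1] per axis
    return (xyz - min_pts) / (max_pts - min_pts) * 2 - 1

def denorm_xyz(xyz, min_pts, max_pts):    # exact inverse of norm_xyz
    return (xyz + 1) / 2 * (max_pts - min_pts) + min_pts

pts = torch.randn(2, 8, 64, 3)            # dummy (B, T, N, 3) points
mn = pts.mean((0, 1, 2)) - 3 * pts.std((0, 1, 2))
mx = pts.mean((0, 1, 2)) + 3 * pts.std((0, 1, 2))
assert torch.allclose(denorm_xyz(norm_xyz(pts, mn, mx), mn, mx), pts, atol=1e-5)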
367
+ def forward(
368
+ self,
369
+ video,
370
+ metric_depth,
371
+ metric_unc,
372
+ point_map,
373
+ queries,
374
+ pts_q_3d=None,
375
+ overlap_d=None,
376
+ iters=4,
377
+ add_space_attn=True,
378
+ fmaps_chunk_size=200,
379
+ intrs=None,
380
+ traj3d_gt=None,
381
+ custom_vid=False,
382
+ vis_gt=None,
383
+ prec_fx=None,
384
+ prec_fy=None,
385
+ cam_gt=None,
386
+ init_pose=False,
387
+ support_pts_q=None,
388
+ update_pointmap=True,
389
+ fixed_cam=False,
390
+ query_no_BA=False,
391
+ stage=0,
392
+ cache=None,
393
+ points_map_gt=None,
394
+ valid_only=False,
395
+ replace_ratio=0.6,
396
+ ):
397
+ """Predict tracks
398
+
399
+ Args:
400
+ video (FloatTensor[B, T, 3, H, W]): input videos.
401
+ queries (FloatTensor[B, N, 3]): point queries.
402
+ iters (int, optional): number of updates. Defaults to 4.
403
+ vdp_feats_cache: last layer's feature of depth
404
+ tracks_init: B T N 3 the initialization of 3D tracks computed by cam pose
405
+ Returns:
406
+ - coords_predicted (FloatTensor[B, T, N, 2]):
407
+ - vis_predicted (FloatTensor[B, T, N]):
408
+ - train_data: `None` if `is_train` is false, otherwise:
409
+ - all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
410
+ - all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
411
+ - mask (BoolTensor[B, T, N]):
412
+ """
413
+ self.stage = stage
414
+
415
+ if cam_gt is not None:
416
+ cam_gt = cam_gt.clone()
417
+ cam_gt = torch.inverse(cam_gt[:,:1,...])@cam_gt
418
+ B, T, C, _, _ = video.shape
419
+ _, _, H_, W_ = metric_depth.shape
420
+ _, _, N, __ = queries.shape
421
+ if (vis_gt is not None)&(queries.shape[1] == T):
422
+ aug_visb = True
423
+ if aug_visb:
424
+ number_visible = vis_gt.sum(dim=1)
425
+ ratio_rand = torch.rand(B, N, device=vis_gt.device)
426
+ # first_positive_inds = get_nth_visible_time_index(vis_gt, 1)
427
+ first_positive_inds = get_nth_visible_time_index(vis_gt, (number_visible*ratio_rand).long().clamp(min=1, max=T))
428
+
429
+ assert (torch.gather(vis_gt, 1, first_positive_inds[:, None, :].repeat(1, T, 1)) < 0).sum() == 0
430
+ else:
431
+ __, first_positive_inds = torch.max(vis_gt, dim=1)
432
+ first_positive_inds = first_positive_inds.long()
433
+ gather = torch.gather(
434
+ queries, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
435
+ )
436
+ xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)
437
+ gather_xyz = torch.gather(
438
+ traj3d_gt, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 3)
439
+ )
440
+ z_gt_query = torch.diagonal(gather_xyz, dim1=1, dim2=2).permute(0, 2, 1)[...,2]
441
+ queries = torch.cat([first_positive_inds[:, :, None], xys], dim=-1)
442
+ queries = torch.cat([queries, support_pts_q[:,0]], dim=1)
443
+ else:
444
+ # Generate the 768 points randomly in the whole video
445
+ queries = queries.squeeze(1)
446
+ ba_len = queries.shape[1]
447
+ z_gt_query = None
448
+ if support_pts_q is not None:
449
+ queries = torch.cat([queries, support_pts_q[:,0]], dim=1)
450
+
451
+ if (abs(prec_fx-1.0) > 1e-4) & (self.training) & (traj3d_gt is not None):
452
+ traj3d_gt[..., 0] /= prec_fx
453
+ traj3d_gt[..., 1] /= prec_fy
454
+ queries[...,1] /= prec_fx
455
+ queries[...,2] /= prec_fy
456
+
457
+ video_vis = F.interpolate(video.clone().view(B*T, 3, video.shape[-2], video.shape[-1]), (H_, W_), mode="bilinear", align_corners=False).view(B, T, 3, H_, W_)
458
+
459
+ self.image_size = torch.tensor([H_, W_])
460
+ # self.model_resolution = (H_, W_)
461
+ # resize the queries and intrs
462
+ self.factor_x = self.model_resolution[1]/W_
463
+ self.factor_y = self.model_resolution[0]/H_
464
+ queries[...,1] *= self.factor_x
465
+ queries[...,2] *= self.factor_y
466
+ intrs_org = intrs.clone()
467
+ intrs[...,0,:] *= self.factor_x
468
+ intrs[...,1,:] *= self.factor_y
469
+
470
+ # get the fmaps and color features
471
+ video = F.interpolate(video.view(B*T, 3, video.shape[-2], video.shape[-1]),
472
+ (self.model_resolution[0], self.model_resolution[1])).view(B, T, 3, self.model_resolution[0], self.model_resolution[1])
473
+ _, _, _, H, W = video.shape
474
+ if cache is not None:
475
+ T_cache = cache["fmaps"].shape[0]
476
+ fmaps = self.extract_img_feat(video[:,T_cache:], fmaps_chunk_size=fmaps_chunk_size)
477
+ fmaps = torch.cat([cache["fmaps"][None], fmaps], dim=1)
478
+ else:
479
+ fmaps = self.extract_img_feat(video, fmaps_chunk_size=fmaps_chunk_size)
480
+ fmaps_org = fmaps.clone()
481
+
482
+ metric_depth = F.interpolate(metric_depth.view(B*T, 1, H_, W_),
483
+ (self.model_resolution[0], self.model_resolution[1]),mode="nearest").view(B*T, 1, self.model_resolution[0], self.model_resolution[1]).clamp(0.01, 200)
484
+ self.metric_unc_org = metric_unc.clone()
485
+ metric_unc = F.interpolate(metric_unc.view(B*T, 1, H_, W_),
486
+ (self.model_resolution[0], self.model_resolution[1]),mode="nearest").view(B*T, 1, self.model_resolution[0], self.model_resolution[1])
487
+ if (self.stage == 2) & (self.training):
488
+ scale_rand = (torch.rand(B, T, device=video.device) - 0.5) + 1
489
+ point_map = scale_rand.view(B*T,1,1,1) * point_map
490
+
491
+ point_map_org = point_map.permute(0,3,1,2).view(B*T, 3, H_, W_).clone()
492
+ point_map = F.interpolate(point_map_org.clone(),
493
+ (self.model_resolution[0], self.model_resolution[1]),mode="nearest").view(B*T, 3, self.model_resolution[0], self.model_resolution[1])
494
+ # align the point map
495
+ point_map_org_train = point_map_org.view(B*T, 3, H_, W_).clone()
496
+
497
+ if (stage == 2):
498
+ # align the point map
499
+ try:
500
+ self.pred_points, scale_gt, shift_gt = affine_invariant_global_loss(
501
+ point_map_org_train.permute(0,2,3,1),
502
+ points_map_gt,
503
+ mask=self.metric_unc_org[:,0]>0.5,
504
+ align_resolution=32,
505
+ only_align=True
506
+ )
507
+ except:
508
+ scale_gt, shift_gt = torch.ones(B*T).to(video.device), torch.zeros(B*T,3).to(video.device)
509
+ self.scale_gt, self.shift_gt = scale_gt, shift_gt
510
+ else:
511
+ scale_est, shift_est = None, None
512
+
513
+ # extract the pts features
514
+ device = queries.device
515
+ assert H % self.stride == 0 and W % self.stride == 0
516
+
517
+ B, N, __ = queries.shape
518
+ queries_z = sample_features5d(metric_depth.view(B, T, 1, H, W),
519
+ queries[:,None], interp_mode="nearest").squeeze(1)
520
+ queries_z_unc = sample_features5d(metric_unc.view(B, T, 1, H, W),
521
+ queries[:,None], interp_mode="nearest").squeeze(1)
522
+
523
+ queries_rgb = sample_features5d(video.view(B, T, C, H, W),
524
+ queries[:,None], interp_mode="nearest").squeeze(1)
525
+ queries_point_map = sample_features5d(point_map.view(B, T, 3, H, W),
526
+ queries[:,None], interp_mode="nearest").squeeze(1)
527
+ if ((queries_z > 100)*(queries_z == 0)).sum() > 0:
528
+ import pdb; pdb.set_trace()
529
+
530
+ if overlap_d is not None:
531
+ queries_z[:,:overlap_d.shape[1],:] = overlap_d[...,None]
532
+ queries_point_map[:,:overlap_d.shape[1],2:] = overlap_d[...,None]
533
+
534
+ if pts_q_3d is not None:
535
+ scale_factor = (pts_q_3d[...,-1].permute(0,2,1) / queries_z[:,:pts_q_3d.shape[2],:]).squeeze().median()
536
+ queries_z[:,:pts_q_3d.shape[2],:] = pts_q_3d[...,-1].permute(0,2,1) / scale_factor
537
+ queries_point_map[:,:pts_q_3d.shape[2],2:] = pts_q_3d[...,-1].permute(0,2,1) / scale_factor
538
+
539
+ # normalize the points
540
+ self.min_pts, self.max_pts = queries_point_map.mean(dim=(0,1)) - 3*queries_point_map.std(dim=(0,1)), queries_point_map.mean(dim=(0,1)) + 3*queries_point_map.std(dim=(0,1))
541
+ queries_point_map = self.norm_xyz(queries_point_map)
542
+ queries_point_map_ = queries_point_map.reshape(B, 1, N, 3).expand(B, T, N, 3).clone()
543
+ point_map = self.norm_xyz(point_map.view(B, T, 3, H, W)).view(B*T, 3, H, W)
544
+
545
+ if z_gt_query is not None:
546
+ queries_z[:,:z_gt_query.shape[1],:] = z_gt_query[:,:,None]
547
+ mask_traj_gt = ((queries_z[:,:z_gt_query.shape[1],:] - z_gt_query[:,:,None])).abs() < 0.1
548
+ else:
549
+ if traj3d_gt is not None:
550
+ mask_traj_gt = torch.ones_like(queries_z[:, :traj3d_gt.shape[2]]).bool()
551
+ else:
552
+ mask_traj_gt = torch.ones_like(queries_z).bool()
553
+
554
+ queries_xyz = torch.cat([queries, queries_z], dim=-1)[:,None].repeat(1, T, 1, 1)
555
+ if cache is not None:
556
+ cache_T, cache_N = cache["track2d_pred_cache"].shape[0], cache["track2d_pred_cache"].shape[1]
557
+ cachexy = cache["track2d_pred_cache"].clone()
558
+ cachexy[...,0] = cachexy[...,0] * self.factor_x
559
+ cachexy[...,1] = cachexy[...,1] * self.factor_y
560
+ # initialize the 2d points with cache
561
+ queries_xyz[:,:cache_T,:cache_N,1:] = cachexy
562
+ queries_xyz[:,cache_T:,:cache_N,1:] = cachexy[-1:]
563
+ # initialize the 3d points with cache
564
+ queries_point_map_[:,:cache_T,:cache_N,:] = self.norm_xyz(cache["track3d_pred_cache"][None])
565
+ queries_point_map_[:,cache_T:,:cache_N,:] = self.norm_xyz(cache["track3d_pred_cache"][-1:][None])
566
+
567
+ if cam_gt is not None:
568
+ q_static_proj, q_xyz_world, q_xyz_cam = self.track_from_cam(queries_xyz, cam_gt,
569
+ intrs, rgbs=video_vis, visualize=False)
570
+ q_static_proj[..., 0] /= self.factor_x
571
+ q_static_proj[..., 1] /= self.factor_y
572
+
573
+
574
+ assert T >= 1 # at least one frame is required (two or more are needed for actual tracking)
575
+ video = 2 * (video / 255.0) - 1.0
576
+ dtype = video.dtype
577
+ queried_frames = queries[:, :, 0].long()
578
+
579
+ queried_coords = queries[..., 1:3]
580
+ queried_coords = queried_coords / self.stride
581
+
582
+ # We store our predictions here
583
+ (all_coords_predictions, all_coords_xyz_predictions,all_vis_predictions,
584
+ all_confidence_predictions, all_cam_predictions, all_dynamic_prob_predictions,
585
+ all_cam_pts_predictions, all_world_tracks_predictions, all_world_tracks_refined_predictions,
586
+ all_scale_est, all_shift_est) = (
587
+ [],
588
+ [],
589
+ [],
590
+ [],
591
+ [],
592
+ [],
593
+ [],
594
+ [],
595
+ [],
596
+ [],
597
+ []
598
+ )
599
+
600
+ # We compute track features
601
+ fmaps_pyramid = []
602
+ point_map_pyramid = []
603
+ track_feat_pyramid = []
604
+ track_feat_support_pyramid = []
605
+ track_feat3d_pyramid = []
606
+ track_feat_support3d_pyramid = []
607
+ track_depth_support_pyramid = []
608
+ track_point_map_pyramid = []
609
+ track_point_map_support_pyramid = []
610
+ fmaps_pyramid.append(fmaps)
611
+ metric_depth = metric_depth
612
+ point_map = point_map
613
+ metric_depth_align = F.interpolate(metric_depth, scale_factor=0.25, mode='nearest')
614
+ point_map_align = F.interpolate(point_map, scale_factor=0.25, mode='nearest')
615
+ point_map_pyramid.append(point_map_align.view(B, T, 3, point_map_align.shape[-2], point_map_align.shape[-1]))
616
+ for i in range(self.corr_levels - 1):
617
+ fmaps_ = fmaps.reshape(
618
+ B * T, self.latent_dim, fmaps.shape[-2], fmaps.shape[-1]
619
+ )
620
+ fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
621
+ fmaps = fmaps_.reshape(
622
+ B, T, self.latent_dim, fmaps_.shape[-2], fmaps_.shape[-1]
623
+ )
624
+ fmaps_pyramid.append(fmaps)
625
+ # downsample the depth
626
+ metric_depth_ = metric_depth_align.reshape(B*T,1,metric_depth_align.shape[-2],metric_depth_align.shape[-1])
627
+ metric_depth_ = F.interpolate(metric_depth_, scale_factor=0.5, mode='nearest')
628
+ metric_depth_align = metric_depth_.reshape(B,T,1,metric_depth_.shape[-2], metric_depth_.shape[-1])
629
+ # downsample the point map
630
+ point_map_ = point_map_align.reshape(B*T,3,point_map_align.shape[-2],point_map_align.shape[-1])
631
+ point_map_ = F.interpolate(point_map_, scale_factor=0.5, mode='nearest')
632
+ point_map_align = point_map_.reshape(B,T,3,point_map_.shape[-2], point_map_.shape[-1])
633
+ point_map_pyramid.append(point_map_align)
634
+
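# Pyramid layout (illustrative, assuming stride=4, corr_levels=3 and a 384x512 model resolution):
# level 0 is 96x128, level 1 is 48x64, level 2 is 24x32; the feature maps and the point maps are
# pooled in lockstep so the correlation lookups below see matching resolutions at every level.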
635
+ for i in range(self.corr_levels):
636
+ if cache is not None:
637
+ cache_N = cache["track_feat_pyramid"][i].shape[2]
638
+ track_feat_cached, track_feat_support_cached = cache["track_feat_pyramid"][i], cache["track_feat_support_pyramid"][i]
639
+ track_feat3d_cached, track_feat_support3d_cached = cache["track_feat3d_pyramid"][i], cache["track_feat_support3d_pyramid"][i]
640
+ track_point_map_cached, track_point_map_support_cached = self.norm_xyz(cache["track_point_map_pyramid"][i]), self.norm_xyz(cache["track_point_map_support_pyramid"][i])
641
+ queried_coords_new = queried_coords[:,cache_N:,:] / 2**i
642
+ queried_frames_new = queried_frames[:,cache_N:]
643
+ else:
644
+ queried_coords_new = queried_coords / 2**i
645
+ queried_frames_new = queried_frames
646
+ track_feat, track_feat_support = self.get_track_feat(
647
+ fmaps_pyramid[i],
648
+ queried_frames_new,
649
+ queried_coords_new,
650
+ support_radius=self.corr_radius,
651
+ )
652
+ # get 3d track feat
653
+ track_point_map, track_point_map_support = self.get_track_feat(
654
+ point_map_pyramid[i],
655
+ queried_frames_new,
656
+ queried_coords_new,
657
+ support_radius=self.corr3d_radius,
658
+ )
659
+ track_feat3d, track_feat_support3d = self.get_track_feat(
660
+ fmaps_pyramid[i],
661
+ queried_frames_new,
662
+ queried_coords_new,
663
+ support_radius=self.corr3d_radius,
664
+ )
665
+ if cache is not None:
666
+ track_feat = torch.cat([track_feat_cached, track_feat], dim=2)
667
+ track_point_map = torch.cat([track_point_map_cached, track_point_map], dim=2)
668
+ track_feat_support = torch.cat([track_feat_support_cached[:,0], track_feat_support], dim=2)
669
+ track_point_map_support = torch.cat([track_point_map_support_cached[:,0], track_point_map_support], dim=2)
670
+ track_feat3d = torch.cat([track_feat3d_cached, track_feat3d], dim=2)
671
+ track_feat_support3d = torch.cat([track_feat_support3d_cached[:,0], track_feat_support3d], dim=2)
672
+ track_feat_pyramid.append(track_feat.repeat(1, T, 1, 1))
673
+ track_feat_support_pyramid.append(track_feat_support.unsqueeze(1))
674
+ track_feat3d_pyramid.append(track_feat3d.repeat(1, T, 1, 1))
675
+ track_feat_support3d_pyramid.append(track_feat_support3d.unsqueeze(1))
676
+ track_point_map_pyramid.append(track_point_map.repeat(1, T, 1, 1))
677
+ track_point_map_support_pyramid.append(track_point_map_support.unsqueeze(1))
678
+
679
+
680
+ D_coords = 2
681
+ (coord_preds, coords_xyz_preds, vis_preds, confidence_preds,
682
+ dynamic_prob_preds, cam_preds, pts3d_cam_pred, world_tracks_pred,
683
+ world_tracks_refined_pred, point_map_preds, scale_ests, shift_ests) = (
684
+ [], [], [], [], [], [], [], [], [], [], [], []
685
+ )
686
+
687
+ c2w_ests = []
688
+ vis = torch.zeros((B, T, N), device=device).float()
689
+ confidence = torch.zeros((B, T, N), device=device).float()
690
+ dynamic_prob = torch.zeros((B, T, N), device=device).float()
691
+ pro_analysis_w = torch.zeros((B, T, N), device=device).float()
692
+
693
+ coords = queries_xyz[...,1:].clone()
694
+ coords[...,:2] /= self.stride
695
+ # coords[...,:2] = queried_coords.reshape(B, 1, N, 2).expand(B, T, N, 2).float()[...,:2]
696
+ # initialize the 3d points
697
+ coords_xyz = queries_point_map_.clone()
698
+
699
+ # if cache is not None:
700
+ # viser = Visualizer(save_dir=".", grayscale=True,
701
+ # fps=10, pad_value=50, tracks_leave_trace=0)
702
+ # coords_clone = coords.clone()
703
+ # coords_clone[...,:2] *= self.stride
704
+ # coords_clone[..., 0] /= self.factor_x
705
+ # coords_clone[..., 1] /= self.factor_y
706
+ # viser.visualize(video=video_vis, tracks=coords_clone[..., :2], filename="test")
707
+ # import pdb; pdb.set_trace()
708
+
709
+ if init_pose:
710
+ q_init_proj, q_xyz_world, q_xyz_cam = self.track_from_cam(queries_xyz, cam_gt,
711
+ intrs, rgbs=video_vis, visualize=False)
712
+ q_init_proj[..., 0] /= self.stride
713
+ q_init_proj[..., 1] /= self.stride
714
+
715
+ r = 2 * self.corr_radius + 1
716
+ r_depth = 2 * self.corr3d_radius + 1
717
+ anchor_loss = 0
718
+ # two current states
719
+ self.c2w_est_curr = torch.eye(4, device=device).repeat(B, T , 1, 1)
720
+ coords_proj_curr = coords.view(B * T, N, 3)[...,:2]
721
+ if init_pose:
722
+ self.c2w_est_curr = cam_gt.to(coords_proj_curr.device).to(coords_proj_curr.dtype)
723
+ sync_loss = 0
724
+ if stage == 2:
725
+ extra_sparse_tokens = self.scale_shift_tokens[:,:,None,:].repeat(B, 1, T, 1)
726
+ extra_dense_tokens = self.residual_embedding[None,None].repeat(B, T, 1, 1, 1)
727
+ xyz_pos_enc = posenc(point_map_pyramid[-2].permute(0,1,3,4,2), min_deg=0, max_deg=10).permute(0,1,4,2,3)
728
+ extra_dense_tokens = torch.cat([xyz_pos_enc, extra_dense_tokens, fmaps_pyramid[-2]], dim=2)
729
+ extra_dense_tokens = rearrange(extra_dense_tokens, 'b t c h w -> (b t) c h w')
730
+ extra_dense_tokens = self.dense_mlp(extra_dense_tokens)
731
+ extra_dense_tokens = rearrange(extra_dense_tokens, '(b t) c h w -> b t c h w', b=B, t=T)
732
+ else:
733
+ extra_sparse_tokens = None
734
+ extra_dense_tokens = None
735
+
736
+ scale_est, shift_est = torch.ones(B, T, 1, 1, device=device), torch.zeros(B, T, 1, 3, device=device)
737
+ residual_point = torch.zeros(B, T, 3, self.model_resolution[0]//self.stride,
738
+ self.model_resolution[1]//self.stride, device=device)
739
+
740
+ for it in range(iters):
741
+ # query points scale and shift
742
+ scale_est_query = torch.gather(scale_est, dim=1, index=queries[:,:,None,:1].long())
743
+ shift_est_query = torch.gather(shift_est, dim=1, index=queries[:,:,None,:1].long().repeat(1, 1, 1, 3))
744
+
745
+ coords = coords.detach() # B T N 3
746
+ coords_xyz = coords_xyz.detach()
747
+ vis = vis.detach()
748
+ confidence = confidence.detach()
749
+ dynamic_prob = dynamic_prob.detach()
750
+ pro_analysis_w = pro_analysis_w.detach()
751
+ coords_init = coords.view(B * T, N, 3)
752
+ coords_xyz_init = coords_xyz.view(B * T, N, 3)
753
+ corr_embs = []
754
+ corr_depth_embs = []
755
+ corr_feats = []
756
+ for i in range(self.corr_levels):
757
+ # K_level = int(32*0.8**(i))
758
+ K_level = 16
759
+ corr_feat = self.get_correlation_feat(
760
+ fmaps_pyramid[i], coords_init[...,:2] / 2**i
761
+ )
762
+ #NOTE: update the point map
763
+ residual_point_i = F.interpolate(residual_point.view(B*T,3,residual_point.shape[-2],residual_point.shape[-1]),
764
+ size=(point_map_pyramid[i].shape[-2], point_map_pyramid[i].shape[-1]), mode='nearest')
765
+ point_map_pyramid_i = (self.denorm_xyz(point_map_pyramid[i]) * scale_est[...,None]
766
+ + shift_est.permute(0,1,3,2)[...,None] + residual_point_i.view(B,T,3,point_map_pyramid[i].shape[-2], point_map_pyramid[i].shape[-1])).clone().detach()
767
+
768
+ corr_point_map = self.get_correlation_feat(
769
+ self.norm_xyz(point_map_pyramid_i), coords_proj_curr / 2**i, radius=self.corr3d_radius
770
+ )
771
+
772
+ corr_point_feat = self.get_correlation_feat(
773
+ fmaps_pyramid[i], coords_proj_curr / 2**i, radius=self.corr3d_radius
774
+ )
775
+ track_feat_support = (
776
+ track_feat_support_pyramid[i]
777
+ .view(B, 1, r, r, N, self.latent_dim)
778
+ .squeeze(1)
779
+ .permute(0, 3, 1, 2, 4)
780
+ )
781
+ track_feat_support3d = (
782
+ track_feat_support3d_pyramid[i]
783
+ .view(B, 1, r_depth, r_depth, N, self.latent_dim)
784
+ .squeeze(1)
785
+ .permute(0, 3, 1, 2, 4)
786
+ )
787
+ #NOTE: update the point map
788
+ track_point_map_support_pyramid_i = (self.denorm_xyz(track_point_map_support_pyramid[i]) * scale_est_query.view(B,1,1,N,1)
789
+ + shift_est_query.view(B,1,1,N,3)).clone().detach()
790
+
791
+ track_point_map_support = (
792
+ self.norm_xyz(track_point_map_support_pyramid_i)
793
+ .view(B, 1, r_depth, r_depth, N, 3)
794
+ .squeeze(1)
795
+ .permute(0, 3, 1, 2, 4)
796
+ )
797
+ corr_volume = torch.einsum(
798
+ "btnhwc,bnijc->btnhwij", corr_feat, track_feat_support
799
+ )
800
+ corr_emb = self.corr_mlp(corr_volume.reshape(B, T, N, r * r * r * r))
801
+
802
+ with torch.no_grad():
803
+ rel_pos_query_ = track_point_map_support - track_point_map_support[:,:,self.corr3d_radius,self.corr3d_radius,:][...,None,None,:]
804
+ rel_pos_target_ = corr_point_map - coords_xyz_init.view(B, T, N, 1, 1, 3)
805
+ # select the top 9 points
806
+ rel_pos_query_idx = rel_pos_query_.norm(dim=-1).view(B, N, -1).topk(K_level+1, dim=-1, largest=False)[1][...,1:,None]
807
+ rel_pos_target_idx = rel_pos_target_.norm(dim=-1).view(B, T, N, -1).topk(K_level+1, dim=-1, largest=False)[1][...,1:,None]
808
+ rel_pos_query_ = torch.gather(rel_pos_query_.view(B, N, -1, 3), dim=-2, index=rel_pos_query_idx.expand(B, N, K_level, 3))
809
+ rel_pos_target_ = torch.gather(rel_pos_target_.view(B, T, N, -1, 3), dim=-2, index=rel_pos_target_idx.expand(B, T, N, K_level, 3))
810
+ rel_pos_query = rel_pos_query_
811
+ rel_pos_target = rel_pos_target_
812
+ rel_pos_query = posenc(rel_pos_query, min_deg=0, max_deg=12)
813
+ rel_pos_target = posenc(rel_pos_target, min_deg=0, max_deg=12)
814
+ rel_pos_target = self.rel_pos_mlp(rel_pos_target)
815
+ rel_pos_query = self.rel_pos_mlp(rel_pos_query)
816
+ with torch.no_grad():
817
+ # integrate with feature
818
+ track_feat_support_ = rearrange(track_feat_support3d, 'b n r k c -> b n (r k) c', r=r_depth, k=r_depth, n=N, b=B)
819
+ track_feat_support_ = torch.gather(track_feat_support_, dim=-2, index=rel_pos_query_idx.expand(B, N, K_level, 128))
820
+ queried_feat = torch.cat([rel_pos_query, track_feat_support_], dim=-1)
821
+ corr_feat_ = rearrange(corr_point_feat, 'b t n r k c -> b t n (r k) c', t=T, n=N, b=B)
822
+ corr_feat_ = torch.gather(corr_feat_, dim=-2, index=rel_pos_target_idx.expand(B, T, N, K_level, 128))
823
+ target_feat = torch.cat([rel_pos_target, corr_feat_], dim=-1)
824
+
825
+ # 3d attention
826
+ queried_feat = self.corr_xyz_mlp(queried_feat)
827
+ target_feat = self.corr_xyz_mlp(target_feat)
828
+ queried_feat = repeat(queried_feat, 'b n k c -> b t n k c', k=K_level, t=T, n=N, b=B)
829
+ corr_depth_emb = self.corr_transformer[0](queried_feat.reshape(B*T*N,-1,128),
830
+ target_feat.reshape(B*T*N,-1,128),
831
+ target_rel_pos=rel_pos_target.reshape(B*T*N,-1,128))
832
+ corr_depth_emb = rearrange(corr_depth_emb, '(b t n) 1 c -> b t n c', t=T, n=N, b=B)
833
+ corr_depth_emb = self.corr_depth_mlp(corr_depth_emb)
834
+ valid_mask = self.denorm_xyz(coords_xyz_init).view(B, T, N, -1)[...,2:3] > 0
835
+ corr_depth_embs.append(corr_depth_emb*valid_mask)
836
+
837
+ corr_embs.append(corr_emb)
838
+ corr_embs = torch.cat(corr_embs, dim=-1)
839
+ corr_embs = corr_embs.view(B, T, N, corr_embs.shape[-1])
840
+ corr_depth_embs = torch.cat(corr_depth_embs, dim=-1)
841
+ corr_depth_embs = corr_depth_embs.view(B, T, N, corr_depth_embs.shape[-1])
842
+ transformer_input = [vis[..., None], confidence[..., None], corr_embs]
843
+ transformer_input_depth = [vis[..., None], confidence[..., None], corr_depth_embs]
844
+
845
+ rel_coords_forward = coords[:,:-1,...,:2] - coords[:,1:,...,:2]
846
+ rel_coords_backward = coords[:, 1:,...,:2] - coords[:, :-1,...,:2]
847
+
848
+ rel_xyz_forward = coords_xyz[:,:-1,...,:3] - coords_xyz[:,1:,...,:3]
849
+ rel_xyz_backward = coords_xyz[:, 1:,...,:3] - coords_xyz[:, :-1,...,:3]
850
+
851
+ rel_coords_forward = torch.nn.functional.pad(
852
+ rel_coords_forward, (0, 0, 0, 0, 0, 1)
853
+ )
854
+ rel_coords_backward = torch.nn.functional.pad(
855
+ rel_coords_backward, (0, 0, 0, 0, 1, 0)
856
+ )
857
+ rel_xyz_forward = torch.nn.functional.pad(
858
+ rel_xyz_forward, (0, 0, 0, 0, 0, 1)
859
+ )
860
+ rel_xyz_backward = torch.nn.functional.pad(
861
+ rel_xyz_backward, (0, 0, 0, 0, 1, 0)
862
+ )
863
+
864
+ scale = (
865
+ torch.tensor(
866
+ [self.model_resolution[1], self.model_resolution[0]],
867
+ device=coords.device,
868
+ )
869
+ / self.stride
870
+ )
871
+ rel_coords_forward = rel_coords_forward / scale
872
+ rel_coords_backward = rel_coords_backward / scale
873
+
874
+ rel_pos_emb_input = posenc(
875
+ torch.cat([rel_coords_forward, rel_coords_backward], dim=-1),
876
+ min_deg=0,
877
+ max_deg=10,
878
+ ) # batch, num_points, num_frames, 84
879
+ rel_xyz_emb_input = posenc(
880
+ torch.cat([rel_xyz_forward, rel_xyz_backward], dim=-1),
881
+ min_deg=0,
882
+ max_deg=10,
883
+ ) # batch, num_points, num_frames, 126
884
+ rel_xyz_emb_input = self.xyz_mlp(rel_xyz_emb_input)
885
+ transformer_input.append(rel_pos_emb_input)
886
+ transformer_input_depth.append(rel_xyz_emb_input)
887
+ # get the queries world
888
+ with torch.no_grad():
889
+ # update the query points with scale and shift
890
+ queries_xyz_i = queries_xyz.clone().detach()
891
+ queries_xyz_i[..., -1] = queries_xyz_i[..., -1] * scale_est_query.view(B,1,N) + shift_est_query.view(B,1,N,3)[...,2]
892
+ _, _, q_xyz_cam = self.track_from_cam(queries_xyz_i, self.c2w_est_curr,
893
+ intrs, rgbs=None, visualize=False)
894
+ q_xyz_cam = self.norm_xyz(q_xyz_cam)
895
+
896
+ query_t = queries[:,None,:,:1].repeat(B, T, 1, 1)
897
+ q_xyz_cam = torch.cat([query_t/T, q_xyz_cam], dim=-1)
898
+ T_all = torch.arange(T, device=device)[None,:,None,None].repeat(B, 1, N, 1)
899
+ current_xyzt = torch.cat([T_all/T, coords_xyz_init.view(B, T, N, -1)], dim=-1)
900
+ rel_pos_query_glob = q_xyz_cam - current_xyzt
901
+ # embed the confidence and dynamic probability
902
+ confidence_curr = torch.sigmoid(confidence[...,None])
903
+ dynamic_prob_curr = torch.sigmoid(dynamic_prob[...,None]).mean(dim=1, keepdim=True).repeat(1,T,1,1)
904
+ # embed the confidence and dynamic probability
905
+ rel_pos_query_glob = torch.cat([rel_pos_query_glob, confidence_curr, dynamic_prob_curr], dim=-1)
906
+ rel_pos_query_glob = posenc(rel_pos_query_glob, min_deg=0, max_deg=12)
907
+ transformer_input_depth.append(rel_pos_query_glob)
908
+
909
+ x = (
910
+ torch.cat(transformer_input, dim=-1)
911
+ .permute(0, 2, 1, 3)
912
+ .reshape(B * N, T, -1)
913
+ )
914
+ x_depth = (
915
+ torch.cat(transformer_input_depth, dim=-1)
916
+ .permute(0, 2, 1, 3)
917
+ .reshape(B * N, T, -1)
918
+ )
919
+ x_depth = self.proj_xyz_embed(x_depth)
920
+
921
+ x = x + self.interpolate_time_embed(x, T)
922
+ x = x.view(B, N, T, -1) # (B N) T D -> B N T D
923
+ x_depth = x_depth + self.interpolate_time_embed(x_depth, T)
924
+ x_depth = x_depth.view(B, N, T, -1) # (B N) T D -> B N T D
925
+ delta, delta_depth, delta_dynamic_prob, delta_pro_analysis_w, scale_shift_out, dense_res_out = self.updateformer3D(
926
+ x,
927
+ x_depth,
928
+ self.updateformer,
929
+ add_space_attn=add_space_attn,
930
+ extra_sparse_tokens=extra_sparse_tokens,
931
+ extra_dense_tokens=extra_dense_tokens,
932
+ )
933
+ # update the scale and shift
934
+ if scale_shift_out is not None:
935
+ extra_sparse_tokens = extra_sparse_tokens + scale_shift_out[...,:128]
936
+ scale_update = scale_shift_out[:,:1,:,-1].permute(0,2,1)[...,None]
937
+ shift_update = scale_shift_out[:,1:,:,-1].permute(0,2,1)[...,None]
938
+ scale_est = scale_est + scale_update
939
+ shift_est[...,2:] = shift_est[...,2:] + shift_update / 10
940
+ # dense tokens update
941
+ extra_dense_tokens = extra_dense_tokens + dense_res_out[:,:,-128:]
942
+ res_low = dense_res_out[:,:,:3]
943
+ up_mask = self.upsample_transformer(extra_dense_tokens.mean(dim=1), res_low)
944
+ up_mask = repeat(up_mask, "b k h w -> b s k h w", s=T)
945
+ up_mask = rearrange(up_mask, "b s c h w -> (b s) 1 c h w")
946
+ res_up = self.upsample_with_mask(
947
+ rearrange(res_low, 'b t c h w -> (b t) c h w'),
948
+ up_mask,
949
+ )
950
+ res_up = rearrange(res_up, "(b t) c h w -> b t c h w", b=B, t=T)
951
+ # residual_point = residual_point + res_up
952
+
953
+ delta_coords = delta[..., :D_coords].permute(0, 2, 1, 3)
954
+ delta_vis = delta[..., D_coords].permute(0, 2, 1)
955
+ delta_confidence = delta[..., D_coords + 1].permute(0, 2, 1)
956
+
957
+ vis = vis + delta_vis
958
+ confidence = confidence + delta_confidence
959
+ dynamic_prob = dynamic_prob + delta_dynamic_prob[...,0].permute(0, 2, 1)
960
+ pro_analysis_w = pro_analysis_w + delta_pro_analysis_w[...,0].permute(0, 2, 1)
961
+ # update the depth
962
+ vis_est = torch.sigmoid(vis.detach())
963
+
964
+ delta_xyz = delta_depth[...,:3].permute(0,2,1,3)
965
+ denorm_delta_depth = (self.denorm_xyz(coords_xyz+delta_xyz)-self.denorm_xyz(coords_xyz))[...,2:3]
966
+
967
+
968
+ delta_depth_ = denorm_delta_depth.detach()
969
+ delta_coords = torch.cat([delta_coords, delta_depth_],dim=-1)
970
+ coords = coords + delta_coords
971
+ coords_append = coords.clone()
972
+ coords_xyz_append = self.denorm_xyz(coords_xyz + delta_xyz).clone()
973
+
974
+ coords_append[..., :2] = coords_append[..., :2] * float(self.stride)
975
+ coords_append[..., 0] /= self.factor_x
976
+ coords_append[..., 1] /= self.factor_y
977
+
978
+ # get the camera pose from tracks
979
+ dynamic_prob_curr = torch.sigmoid(dynamic_prob.detach())*torch.sigmoid(pro_analysis_w)
980
+ mask_out = (coords_append[...,0]<W_)&(coords_append[...,0]>0)&(coords_append[...,1]<H_)&(coords_append[...,1]>0)
981
+ if query_no_BA:
982
+ dynamic_prob_curr[:,:,:ba_len] = torch.ones_like(dynamic_prob_curr[:,:,:ba_len])
983
+ point_map_org_i = scale_est.view(B*T,1,1,1)*point_map_org.clone().detach() + shift_est.view(B*T,3,1,1)
984
+ # depth_unproj = bilinear_sampler(point_map_org_i, coords_append[...,:2].view(B*T, N, 1, 2), mode="nearest")[:,2,:,0].detach()
985
+
986
+ depth_unproj_neg = self.get_correlation_feat(
987
+ point_map_org_i.view(B,T,3,point_map_org_i.shape[-2], point_map_org_i.shape[-1]),
988
+ coords_append[...,:2].view(B*T, N, 2), radius=self.corr3d_radius
989
+ )[..., 2]
990
+ depth_diff = (depth_unproj_neg.view(B,T,N,-1) - coords_append[...,2:]).abs()
991
+ idx_neg = torch.argmin(depth_diff, dim=-1)
992
+ depth_unproj = depth_unproj_neg.view(B,T,N,-1)[torch.arange(B)[:, None, None, None],
993
+ torch.arange(T)[None, :, None, None],
994
+ torch.arange(N)[None, None, :, None],
995
+ idx_neg.view(B,T,N,1)].view(B*T, N)
996
+
997
+ unc_unproj = bilinear_sampler(self.metric_unc_org, coords_append[...,:2].view(B*T, N, 1, 2), mode="nearest")[:,0,:,0].detach()
998
+ depth_unproj[unc_unproj<0.5] = 0.0
999
+
1000
+ # replace the depth for visible and solid points
1001
+ conf_est = torch.sigmoid(confidence.detach())
1002
+ replace_mask = (depth_unproj.view(B,T,N)>0.0) * (vis_est>0.5) # * (conf_est>0.5)
1003
+ #NOTE: way1: find the jitter points
1004
+ depth_rel = (depth_unproj.view(B, T, N) - queries_z.permute(0, 2, 1))
1005
+ depth_ddt1 = depth_rel[:, 1:, :] - depth_rel[:, :-1, :]
1006
+ depth_ddt2 = depth_rel[:, 2:, :] - 2 * depth_rel[:, 1:-1, :] + depth_rel[:, :-2, :]
1007
+ jitter_mask = torch.zeros_like(depth_rel, dtype=torch.bool)
1008
+ if depth_ddt2.abs().max()>0:
1009
+ thre2 = torch.quantile(depth_ddt2.abs()[depth_ddt2.abs()>0], replace_ratio)
1010
+ jitter_mask[:, 1:-1, :] = (depth_ddt2.abs() < thre2)
1011
+ thre1 = torch.quantile(depth_ddt1.abs()[depth_ddt1.abs()>0], replace_ratio)
1012
+ jitter_mask[:, :-1, :] *= (depth_ddt1.abs() < thre1)
1013
+ replace_mask = replace_mask * jitter_mask
1014
+
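# In words: depth_rel compares the unprojected depth against the query depth, depth_ddt1/depth_ddt2 are
# its first/second temporal differences, and only points whose differences fall below the replace_ratio
# quantile (i.e. temporally smooth, non-jittery tracks) keep replace_mask set, so only those have their
# predicted depth replaced by the unprojected depth below.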
1015
+ #NOTE: way2: top k topological change detection
1016
+ # coords_2d_lift = coords_append.clone()
1017
+ # coords_2d_lift[...,2][replace_mask] = depth_unproj.view(B,T,N)[replace_mask]
1018
+ # coords_2d_lift = self.cam_from_track(coords_2d_lift.clone(), intrs_org, only_cam_pts=True)
1019
+ # coords_2d_lift[~replace_mask] = coords_xyz_append[~replace_mask]
1020
+ # import pdb; pdb.set_trace()
1021
+ # jitter_mask = get_topo_mask(coords_xyz_append, coords_2d_lift, replace_ratio)
1022
+ # replace_mask = replace_mask * jitter_mask
1023
+
1024
+ # replace the depth
1025
+ if self.training:
1026
+ replace_mask = torch.zeros_like(replace_mask)
1027
+ coords_append[...,2][replace_mask] = depth_unproj.view(B,T,N)[replace_mask]
1028
+ coords_xyz_unproj = self.cam_from_track(coords_append.clone(), intrs_org, only_cam_pts=True)
1029
+ coords[...,2][replace_mask] = depth_unproj.view(B,T,N)[replace_mask]
1030
+ # coords_xyz_append[replace_mask] = coords_xyz_unproj[replace_mask]
1031
+ coords_xyz_append_refine = coords_xyz_append.clone()
1032
+ coords_xyz_append_refine[replace_mask] = coords_xyz_unproj[replace_mask]
1033
+
1034
+ c2w_traj_est, cam_pts_est, intrs_refine, coords_refine, world_tracks, world_tracks_refined, c2w_traj_init = self.cam_from_track(coords_append.clone(),
1035
+ intrs_org, dynamic_prob_curr, queries_z_unc, conf_est*vis_est*mask_out.float(),
1036
+ track_feat_concat=x_depth, tracks_xyz=coords_xyz_append_refine, init_pose=init_pose,
1037
+ query_pts=queries_xyz_i, fixed_cam=fixed_cam, depth_unproj=depth_unproj, cam_gt=cam_gt)
1038
+ intrs_org = intrs_refine.view(B, T, 3, 3).to(intrs_org.dtype)
1039
+
1040
+ # get the queries world
1041
+ self.c2w_est_curr = c2w_traj_est.detach()
1042
+
1043
+ # update coords and coords_append
1044
+ coords[..., 2] = (cam_pts_est)[...,2]
1045
+ coords_append[..., 2] = (cam_pts_est)[...,2]
1046
+
1047
+ # update coords_xyz_append
1048
+ # coords_xyz_append = cam_pts_est
1049
+ coords_xyz = self.norm_xyz(cam_pts_est)
1050
+
1051
+
1052
+ # proj
1053
+ coords_xyz_de = coords_xyz_append.clone()
1054
+ coords_xyz_de[coords_xyz_de[...,2].abs()<1e-6] = -1e-4
1055
+ mask_nan = coords_xyz_de[...,2].abs()<1e-2
1056
+ coords_proj = torch.einsum("btij,btnj->btni", intrs_org, coords_xyz_de/coords_xyz_de[...,2:3].abs())[...,:2]
1057
+ coords_proj[...,0] *= self.factor_x
1058
+ coords_proj[...,1] *= self.factor_y
1059
+ coords_proj[...,:2] /= float(self.stride)
1060
+ # make sure it is aligned with 2d tracking
1061
+ coords_proj_curr = coords[...,:2].view(B*T, N, 2).detach()
1062
+ vis_est = (vis_est>0.5).float()
1063
+ sync_loss += (vis_est.detach()[...,None]*(coords_proj_curr - coords_proj).norm(dim=-1, keepdim=True)*(1-mask_nan[...,None].float())).mean()
1064
+ # coords_proj_curr[~mask_nan.view(B*T, N)] = coords_proj.view(B*T, N, 2)[~mask_nan.view(B*T, N)].to(coords_proj_curr.dtype)
1065
+ # if torch.isnan(coords_proj_curr).sum()>0:
1066
+ # import pdb; pdb.set_trace()
1067
+
1068
+ if False:
1069
+ point_map_resize = point_map.clone().view(B, T, 3, H, W)
1070
+ update_input = torch.cat([point_map_resize, metric_unc.view(B,T,1,H,W)], dim=2)
1071
+ coords_append_resize = coords.clone().detach()
1072
+ coords_append_resize[..., :2] = coords_append_resize[..., :2] * float(self.stride)
1073
+ update_track_input = self.norm_xyz(cam_pts_est)*5
1074
+ update_track_input = torch.cat([update_track_input, vis_est[...,None]], dim=-1)
1075
+ update_track_input = posenc(update_track_input, min_deg=0, max_deg=12)
1076
+ update = self.update_pointmap.stablizer(update_input,
1077
+ update_track_input, coords_append_resize)#, imgs=video, vis_track=viser)
1078
+ #NOTE: update the point map
1079
+ point_map_resize += update
1080
+ point_map_refine_out = F.interpolate(point_map_resize.view(B*T, -1, H, W),
1081
+ size=(self.image_size[0].item(), self.image_size[1].item()), mode='nearest')
1082
+ point_map_refine_out = rearrange(point_map_refine_out, '(b t) c h w -> b t c h w', t=T, b=B)
1083
+ point_map_preds.append(self.denorm_xyz(point_map_refine_out))
1084
+ point_map_org = self.denorm_xyz(point_map_refine_out).view(B*T, 3, H_, W_)
1085
+
1086
+ # if torch.isnan(coords).sum()>0:
1087
+ # import pdb; pdb.set_trace()
1088
+ #NOTE: the 2d tracking + unproject depth
1089
+ fix_cam_est = coords_append.clone()
1090
+ fix_cam_est[...,2] = depth_unproj
1091
+ fix_cam_pts = self.cam_from_track(
1092
+ fix_cam_est, intrs_org, only_cam_pts=True
1093
+ )
1094
+
1095
+ coord_preds.append(coords_append)
1096
+ coords_xyz_preds.append(coords_xyz_append)
1097
+ vis_preds.append(vis)
1098
+ cam_preds.append(c2w_traj_init)
1099
+ pts3d_cam_pred.append(cam_pts_est)
1100
+ world_tracks_pred.append(world_tracks)
1101
+ world_tracks_refined_pred.append(world_tracks_refined)
1102
+ confidence_preds.append(confidence)
1103
+ dynamic_prob_preds.append(dynamic_prob)
1104
+ scale_ests.append(scale_est)
1105
+ shift_ests.append(shift_est)
1106
+
1107
+ if stage!=0:
1108
+ all_coords_predictions.append([coord for coord in coord_preds])
1109
+ all_coords_xyz_predictions.append([coord_xyz for coord_xyz in coords_xyz_preds])
1110
+ all_vis_predictions.append(vis_preds)
1111
+ all_confidence_predictions.append(confidence_preds)
1112
+ all_dynamic_prob_predictions.append(dynamic_prob_preds)
1113
+ all_cam_predictions.append([cam for cam in cam_preds])
1114
+ all_cam_pts_predictions.append([pts for pts in pts3d_cam_pred])
1115
+ all_world_tracks_predictions.append([world_tracks for world_tracks in world_tracks_pred])
1116
+ all_world_tracks_refined_predictions.append([world_tracks_refined for world_tracks_refined in world_tracks_refined_pred])
1117
+ all_scale_est.append(scale_ests)
1118
+ all_shift_est.append(shift_ests)
1119
+ if stage!=0:
1120
+ train_data = (
1121
+ all_coords_predictions,
1122
+ all_coords_xyz_predictions,
1123
+ all_vis_predictions,
1124
+ all_confidence_predictions,
1125
+ all_dynamic_prob_predictions,
1126
+ all_cam_predictions,
1127
+ all_cam_pts_predictions,
1128
+ all_world_tracks_predictions,
1129
+ all_world_tracks_refined_predictions,
1130
+ all_scale_est,
1131
+ all_shift_est,
1132
+ torch.ones_like(vis_preds[-1], device=vis_preds[-1].device),
1133
+ )
1134
+ else:
1135
+ train_data = None
1136
+ # resize back
1137
+ # init the trajectories by camera motion
1138
+
1139
+ # if cache is not None:
1140
+ # viser = Visualizer(save_dir=".", grayscale=True,
1141
+ # fps=10, pad_value=50, tracks_leave_trace=0)
1142
+ # coords_clone = coords.clone()
1143
+ # coords_clone[...,:2] *= self.stride
1144
+ # coords_clone[..., 0] /= self.factor_x
1145
+ # coords_clone[..., 1] /= self.factor_y
1146
+ # viser.visualize(video=video_vis, tracks=coords_clone[..., :2], filename="test_refine")
1147
+ # import pdb; pdb.set_trace()
1148
+
1149
+ if train_data is not None:
1150
+ # get the gt pts in the world coordinate
1151
+ self_supervised = False
1152
+ if (traj3d_gt is not None):
1153
+ if traj3d_gt[...,2].abs().max()>0:
1154
+ gt_cam_pts = self.cam_from_track(
1155
+ traj3d_gt, intrs_org, only_cam_pts=True
1156
+ )
1157
+ else:
1158
+ self_supervised = True
1159
+ else:
1160
+ self_supervised = True
1161
+
1162
+ if self_supervised:
1163
+ gt_cam_pts = self.cam_from_track(
1164
+ coord_preds[-1].detach(), intrs_org, only_cam_pts=True
1165
+ )
1166
+
1167
+ if cam_gt is not None:
1168
+ gt_world_pts = torch.einsum(
1169
+ "btij,btnj->btni",
1170
+ cam_gt[...,:3,:3],
1171
+ gt_cam_pts
1172
+ ) + cam_gt[...,None, :3,3] # B T N 3
1173
+ else:
1174
+ gt_world_pts = torch.einsum(
1175
+ "btij,btnj->btni",
1176
+ self.c2w_est_curr[...,:3,:3],
1177
+ gt_cam_pts
1178
+ ) + self.c2w_est_curr[...,None, :3,3] # B T N 3
1179
+ # update the query points with scale and shift
1180
+ queries_xyz_i = queries_xyz.clone().detach()
1181
+ queries_xyz_i[..., -1] = queries_xyz_i[..., -1] * scale_est_query.view(B,1,N) + shift_est_query.view(B,1,N,3)[...,2]
1182
+ q_static_proj, q_xyz_world, q_xyz_cam = self.track_from_cam(queries_xyz_i,
1183
+ self.c2w_est_curr,
1184
+ intrs, rgbs=video_vis, visualize=False)
1185
+
1186
+ q_static_proj[..., 0] /= self.factor_x
1187
+ q_static_proj[..., 1] /= self.factor_y
1188
+ cam_gt = self.c2w_est_curr[:,:,:3,:]
1189
+
1190
+ if traj3d_gt is not None:
1191
+ ret_loss = self.loss(train_data, traj3d_gt,
1192
+ vis_gt, None, cam_gt, queries_z_unc,
1193
+ q_xyz_world, q_static_proj, anchor_loss=anchor_loss, fix_cam_pts=fix_cam_pts, video_vis=video_vis, stage=stage,
1194
+ gt_world_pts=gt_world_pts, mask_traj_gt=mask_traj_gt, intrs=intrs_org, custom_vid=custom_vid, valid_only=valid_only,
1195
+ c2w_ests=c2w_ests, point_map_preds=point_map_preds, points_map_gt=points_map_gt, metric_unc=metric_unc, scale_est=scale_est,
1196
+ shift_est=shift_est, point_map_org_train=point_map_org_train)
1197
+ else:
1198
+ ret_loss = self.loss(train_data, traj3d_gt,
1199
+ vis_gt, None, cam_gt, queries_z_unc,
1200
+ q_xyz_world, q_static_proj, anchor_loss=anchor_loss, fix_cam_pts=fix_cam_pts, video_vis=video_vis, stage=stage,
1201
+ gt_world_pts=gt_world_pts, mask_traj_gt=mask_traj_gt, intrs=intrs_org, custom_vid=custom_vid, valid_only=valid_only,
1202
+ c2w_ests=c2w_ests, point_map_preds=point_map_preds, points_map_gt=points_map_gt, metric_unc=metric_unc, scale_est=scale_est,
1203
+ shift_est=shift_est, point_map_org_train=point_map_org_train)
1204
+ if custom_vid:
1205
+ sync_loss = 0*sync_loss
1206
+ if (sync_loss > 50) and (stage==1):
1207
+ ret_loss = (0*sync_loss, 0*sync_loss, 0*sync_loss, 0*sync_loss, 0*sync_loss, 0*sync_loss, 0*sync_loss) + (0*sync_loss,)
1208
+ else:
1209
+ ret_loss = ret_loss+(10*sync_loss,)
1210
+
1211
+ else:
1212
+ ret_loss = None
1213
+
1214
+ color_pts = torch.cat([pts3d_cam_pred[-1], queries_rgb[:,None].repeat(1, T, 1, 1)], dim=-1)
1215
+
1216
+ #TODO: For evaluation. We found our model has some bias on invisible points after training (to be fixed).
1217
+ vis_pred_out = torch.sigmoid(vis_preds[-1]) + 0.2
1218
+
1219
+ ret = {"preds": coord_preds[-1], "vis_pred": vis_pred_out,
1220
+ "conf_pred": torch.sigmoid(confidence_preds[-1]),
1221
+ "cam_pred": self.c2w_est_curr,"loss": ret_loss}
1222
+
1223
+ cache = {
1224
+ "fmaps": fmaps_org[0].detach(),
1225
+ "track_feat_support3d_pyramid": [track_feat_support3d_pyramid[i].detach() for i in range(len(track_feat_support3d_pyramid))],
1226
+ "track_point_map_support_pyramid": [self.denorm_xyz(track_point_map_support_pyramid[i].detach()) for i in range(len(track_point_map_support_pyramid))],
1227
+ "track_feat3d_pyramid": [track_feat3d_pyramid[i].detach() for i in range(len(track_feat3d_pyramid))],
1228
+ "track_point_map_pyramid": [self.denorm_xyz(track_point_map_pyramid[i].detach()) for i in range(len(track_point_map_pyramid))],
1229
+ "track_feat_pyramid": [track_feat_pyramid[i].detach() for i in range(len(track_feat_pyramid))],
1230
+ "track_feat_support_pyramid": [track_feat_support_pyramid[i].detach() for i in range(len(track_feat_support_pyramid))],
1231
+ "track2d_pred_cache": coord_preds[-1][0].clone().detach(),
1232
+ "track3d_pred_cache": pts3d_cam_pred[-1][0].clone().detach(),
1233
+ }
1234
+ #NOTE: update the point map
1235
+ point_map_org = scale_est.view(B*T,1,1,1)*point_map_org + shift_est.view(B*T,3,1,1)
1236
+ point_map_org_refined = point_map_org
1237
+ return ret, torch.sigmoid(dynamic_prob_preds[-1])*queries_z_unc[:,None,:,0], coord_preds[-1], color_pts, intrs_org, point_map_org_refined, cache
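A minimal sketch (not from the commit, toy shapes assumed) of the final point-map update just above: the predicted per-frame scalar scale and per-frame 3-vector shift are simply broadcast over the dense point map.

import torch

B, T, H, W = 1, 2, 4, 4
point_map = torch.randn(B * T, 3, H, W)        # per-frame camera-space point map
scale_est = torch.rand(B * T)                  # one scalar scale per frame
shift_est = torch.randn(B * T, 3)              # one XYZ shift per frame
aligned = scale_est.view(B * T, 1, 1, 1) * point_map + shift_est.view(B * T, 3, 1, 1)
assert aligned.shape == point_map.shape        # (B*T, 3, H, W), metrically re-aligned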
1238
+
1239
+ def track_d2_loss(self, tracks3d, stride=[1,2,3], dyn_prob=None, mask=None):
1240
+ """
1241
+ tracks3d: B T N 3
1242
+ dyn_prob: B T N 1
1243
+ """
1244
+ r = 0.8
1245
+ t_diff_total = 0.0
1246
+ for i, s_ in enumerate(stride):
1247
+ w_ = r**i
1248
+ tracks3d_stride = tracks3d[:, ::s_, :, :] # B T//s_ N 3
1249
+ t_diff_tracks3d = (tracks3d_stride[:, 1:, :, :] - tracks3d_stride[:, :-1, :, :])
1250
+ t_diff2 = (t_diff_tracks3d[:, 1:, :, :] - t_diff_tracks3d[:, :-1, :, :])
1251
+ t_diff_total += w_*(t_diff2.norm(dim=-1).mean())
1252
+
1253
+ return 1e2*t_diff_total
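A sketch (editorial, toy shapes assumed) of what track_d2_loss penalizes: the second temporal difference of each 3D track, i.e. its acceleration, so a constant-velocity track incurs zero penalty.

import torch

T = 6
t = torch.arange(T, dtype=torch.float32).view(1, T, 1, 1)
tracks = t * torch.tensor([1.0, 0.0, 0.0])     # B T N 3, x advances 1 unit per frame
d1 = tracks[:, 1:] - tracks[:, :-1]            # first difference (velocity)
d2 = d1[:, 1:] - d1[:, :-1]                    # second difference (acceleration)
print(d2.norm(dim=-1).mean())                  # tensor(0.) -> no smoothness penalty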
1254
+
1255
+ def loss(self, train_data, traj3d_gt=None,
1256
+ vis_gt=None, static_tracks_gt=None, cam_gt=None,
1257
+ z_unc=None, q_xyz_world=None, q_static_proj=None, anchor_loss=0, valid_only=False,
1258
+ gt_world_pts=None, mask_traj_gt=None, intrs=None, c2w_ests=None, custom_vid=False, video_vis=None, stage=0,
1259
+ fix_cam_pts=None, point_map_preds=None, points_map_gt=None, metric_unc=None, scale_est=None, shift_est=None, point_map_org_train=None):
1260
+ """
1261
+ Compute the losses for the 3D tracking problem.
1262
+
1263
+ """
1264
+
1265
+ (
1266
+ coord_predictions, coords_xyz_predictions, vis_predictions, confidence_predictions,
1267
+ dynamic_prob_predictions, camera_predictions, cam_pts_predictions, world_tracks_predictions,
1268
+ world_tracks_refined_predictions, scale_ests, shift_ests, valid_mask
1269
+ ) = train_data
1270
+ B, T, _, _ = cam_gt.shape
1271
+ if (stage == 2) and self.training:
1272
+ # get the scale and shift gt
1273
+ self.metric_unc_org[:,0] = self.metric_unc_org[:,0] * (points_map_gt.norm(dim=-1)>0).float() * (self.metric_unc_org[:,0]>0.5).float()
1274
+ if not (self.scale_gt==torch.ones(B*T).to(self.scale_gt.device)).all():
1275
+ scale_gt, shift_gt = self.scale_gt, self.shift_gt
1276
+ scale_re = scale_gt[:4].mean()
1277
+ scale_loss = 0.0
1278
+ shift_loss = 0.0
1279
+ for i_scale in range(len(scale_ests[0])):
1280
+ scale_loss += 0.8**(len(scale_ests[0])-i_scale-1)*10*(scale_gt - scale_re*scale_ests[0][i_scale].view(-1)).abs().mean()
1281
+ shift_loss += 0.8**(len(shift_ests[0])-i_scale-1)*10*(shift_gt - scale_re*shift_ests[0][i_scale].view(-1,3)).abs().mean()
1282
+ else:
1283
+ scale_loss = 0.0 * scale_ests[0][0].mean()
1284
+ shift_loss = 0.0 * shift_ests[0][0].mean()
1285
+ scale_re = 1.0
1286
+ else:
1287
+ scale_loss = 0.0
1288
+ shift_loss = 0.0
1289
+
1290
+ if len(point_map_preds)>0:
1291
+ point_map_loss = 0.0
1292
+ for i in range(len(point_map_preds)):
1293
+ point_map_preds_i = point_map_preds[i]
1294
+ point_map_preds_i = rearrange(point_map_preds_i, 'b t c h w -> (b t) c h w', b=B, t=T)
1295
+ base_loss = ((self.pred_points - points_map_gt).norm(dim=-1) * self.metric_unc_org[:,0]).mean()
1296
+ point_map_loss_i = ((point_map_preds_i - points_map_gt.permute(0,3,1,2)).norm(dim=1) * self.metric_unc_org[:,0]).mean()
1297
+ point_map_loss += point_map_loss_i
1298
+ # point_map_loss += ((point_map_org_train - points_map_gt.permute(0,3,1,2)).norm(dim=1) * self.metric_unc_org[:,0]).mean()
1299
+ if scale_loss == 0.0:
1300
+ point_map_loss = 0*point_map_preds_i.sum()
1301
+ else:
1302
+ point_map_loss = 0.0
1303
+
1304
+ # camera loss
1305
+ cam_loss = 0.0
1306
+ dyn_loss = 0.0
1307
+ N_gt = gt_world_pts.shape[2]
1308
+
1309
+ # self-supervised dynamic mask: points whose 2D tracks deviate from the static-point projection under the estimated camera motion
1310
+ H_org, W_org = self.image_size[0], self.image_size[1]
1311
+ q_static_proj[torch.isnan(q_static_proj)] = -200
1312
+ in_view_mask = (q_static_proj[...,0]>0) & (q_static_proj[...,0]<W_org) & (q_static_proj[...,1]>0) & (q_static_proj[...,1]<H_org)
1313
+ dyn_mask_final = (((coord_predictions[0][-1] - q_static_proj))[...,:2].norm(dim=-1) * in_view_mask)
1314
+ dyn_mask_final = dyn_mask_final.sum(dim=1) / (in_view_mask.sum(dim=1) + 1e-2)
1315
+ dyn_mask_final = dyn_mask_final > 6
1316
+
1317
+ for iter_, cam_pred_i in enumerate(camera_predictions[0]):
1318
+ # points loss
1319
+ pts_i_world = world_tracks_predictions[0][iter_].view(B, T, -1, 3)
1320
+
1321
+ coords_xyz_i_world = coords_xyz_predictions[0][iter_].view(B, T, -1, 3)
1322
+ coords_i = coord_predictions[0][iter_].view(B, T, -1, 3)[..., :2]
1323
+ pts_i_world_refined = torch.einsum(
1324
+ "btij,btnj->btni",
1325
+ cam_gt[...,:3,:3],
1326
+ coords_xyz_i_world
1327
+ ) + cam_gt[...,None, :3,3] # B T N 3
1328
+
1329
+ # pts_i_world_refined = world_tracks_refined_predictions[0][iter_].view(B, T, -1, 3)
1330
+ pts_world = pts_i_world
1331
+ dyn_prob_i_logits = dynamic_prob_predictions[0][iter_].mean(dim=1)
1332
+ dyn_prob_i = torch.sigmoid(dyn_prob_i_logits).detach()
1333
+ mask = pts_world.norm(dim=-1) < 200
1334
+
1335
+ # general
1336
+ vis_i_logits = vis_predictions[0][iter_]
1337
+ vis_i = torch.sigmoid(vis_i_logits).detach()
1338
+ if mask_traj_gt is not None:
1339
+ try:
1340
+ N_gt_mask = mask_traj_gt.shape[1]
1341
+ align_loss = (gt_world_pts - q_xyz_world[:,None,:N_gt,:,]).norm(dim=-1)[...,:N_gt_mask] * (mask_traj_gt.permute(0,2,1))
1342
+ visb_traj = (align_loss * vis_i[:,:,:N_gt_mask]).sum(dim=1)/vis_i[:,:,:N_gt_mask].sum(dim=1)
1343
+ except:
1344
+ import pdb; pdb.set_trace()
1345
+ else:
1346
+ visb_traj = ((gt_world_pts - q_xyz_world[:,None,:N_gt,:,]).norm(dim=-1) * vis_i[:,:,:N_gt]).sum(dim=1)/vis_i[:,:,:N_gt].sum(dim=1)
1347
+
1348
+ # pts_loss = ((q_xyz_world[:,None,...] - pts_world)[:,:,:N_gt,:].norm(dim=-1)*(1-dyn_prob_i[:,None,:N_gt])) # - 0.1*(1-dyn_prob_i[:,None,:N_gt]).log()
1349
+ pts_loss = 0
1350
+ static_mask = ~dyn_mask_final # more strict for static points
1351
+ dyn_mask = dyn_mask_final
1352
+ pts_loss_refined = ((q_xyz_world[:,None,...] - pts_i_world_refined).norm(dim=-1)*static_mask[:,None,:]).sum()/static_mask.sum() # - 0.1*(1-dyn_prob_i[:,None,:N_gt]).log()
1353
+ vis_logits_final = vis_predictions[0][-1].detach()
1354
+ vis_final = torch.sigmoid(vis_logits_final)+0.2 > 0.5 # lenient visibility test (effective threshold 0.3), matching the +0.2 evaluation bias above
1355
+ dyn_vis_mask = dyn_mask*vis_final * (fix_cam_pts[...,2] > 0.1)
1356
+ pts_loss_dynamic = ((fix_cam_pts - coords_xyz_i_world).norm(dim=-1)*dyn_vis_mask[:,None,:]).sum()/dyn_vis_mask.sum()
1357
+
1358
+ # pts_loss_refined = 0
1359
+ if traj3d_gt is not None:
1360
+ tap_traj = (gt_world_pts[:,:-1,...] - gt_world_pts[:,1:,...]).norm(dim=-1).sum(dim=1)[...,:N_gt_mask]
1361
+ mask_dyn = tap_traj>0.5
1362
+ if mask_traj_gt.sum() > 0:
1363
+ dyn_loss_i = 20*balanced_binary_cross_entropy(dyn_prob_i_logits[:,:N_gt_mask][mask_traj_gt.squeeze(-1)],
1364
+ mask_dyn.float()[mask_traj_gt.squeeze(-1)])
1365
+ else:
1366
+ dyn_loss_i = 0
1367
+ else:
1368
+ dyn_loss_i = 10*balanced_binary_cross_entropy(dyn_prob_i_logits, dyn_mask_final.float())
1369
+
1370
+ dyn_loss += dyn_loss_i
1371
+
1372
+ # visible loss for out of view points
1373
+ vis_i_train = torch.sigmoid(vis_i_logits)
1374
+ out_of_view_mask = (coords_i[...,0]<0)|(coords_i[...,0]>self.image_size[1])|(coords_i[...,1]<0)|(coords_i[...,1]>self.image_size[0])
1375
+ vis_loss_out_of_view = vis_i_train[out_of_view_mask].sum() / out_of_view_mask.sum()
1376
+
1377
+
1378
+ if traj3d_gt is not None:
1379
+ world_pts_loss = (((gt_world_pts - pts_i_world_refined[:,:,:gt_world_pts.shape[2],...]).norm(dim=-1))[...,:N_gt_mask] * mask_traj_gt.permute(0,2,1)).sum() / mask_traj_gt.sum()
1380
+ # world_pts_init_loss = (((gt_world_pts - pts_i_world[:,:,:gt_world_pts.shape[2],...]).norm(dim=-1))[...,:N_gt_mask] * mask_traj_gt.permute(0,2,1)).sum() / mask_traj_gt.sum()
1381
+ else:
1382
+ world_pts_loss = 0
1383
+
1384
+ # cam regress
1385
+ t_err = (cam_pred_i[...,:3,3] - cam_gt[...,:3,3]).norm(dim=-1).sum()
1386
+
1387
+ # xyz loss
1388
+ in_view_mask_large = (q_static_proj[...,0]>-50) & (q_static_proj[...,0]<W_org+50) & (q_static_proj[...,1]>-50) & (q_static_proj[...,1]<H_org+50)
1389
+ static_vis_mask = (q_static_proj[...,2]>0.05).float() * static_mask[:,None,:] * in_view_mask_large
1390
+ xyz_loss = ((coord_predictions[0][iter_] - q_static_proj)).abs()[...,:2].norm(dim=-1)*static_vis_mask
1391
+ xyz_loss = xyz_loss.sum()/static_vis_mask.sum()
1392
+
1393
+ # visualize the q_static_proj
1394
+ # viser = Visualizer(save_dir=".", grayscale=True,
1395
+ # fps=10, pad_value=50, tracks_leave_trace=0)
1396
+ # video_vis_ = F.interpolate(video_vis.view(B*T,3,video_vis.shape[-2],video_vis.shape[-1]), (H_org, W_org), mode='bilinear', align_corners=False)
1397
+ # viser.visualize(video=video_vis_, tracks=q_static_proj[:,:,dyn_mask_final.squeeze(), :2], filename="test")
1398
+ # viser.visualize(video=video_vis_, tracks=coord_predictions[0][-1][:,:,dyn_mask_final.squeeze(), :2], filename="test_pred")
1399
+ # import pdb; pdb.set_trace()
1400
+
1401
+ # temporal loss
1402
+ t_loss = self.track_d2_loss(pts_i_world_refined, [1,2,3], dyn_prob=dyn_prob_i, mask=mask)
1403
+ R_err = (cam_pred_i[...,:3,:3] - cam_gt[...,:3,:3]).abs().sum(dim=-1).mean()
1404
+ if self.stage == 1:
1405
+ cam_loss += 0.8**(len(camera_predictions[0])-iter_-1)*(10*t_err + 500*R_err + 20*pts_loss_refined + 10*xyz_loss + 20*pts_loss_dynamic + 10*vis_loss_out_of_view) #+ 5*(pts_loss + pts_loss_refined + world_pts_loss) + t_loss)
1406
+ elif self.stage == 3:
1407
+ cam_loss += 0.8**(len(camera_predictions[0])-iter_-1)*(10*t_err + 500*R_err + 10*vis_loss_out_of_view) #+ 5*(pts_loss + pts_loss_refined + world_pts_loss) + t_loss)
1408
+ else:
1409
+ cam_loss += 0*vis_loss_out_of_view
1410
+
1411
+ if (cam_loss > 20000)|(torch.isnan(cam_loss)):
1412
+ cam_loss = torch.zeros_like(cam_loss)
1413
+
1414
+
1415
+ if traj3d_gt is None:
1416
+ # ================ Condition 1: The self-supervised signals from the self-consistency ===================
1417
+ return cam_loss, train_data[0][0][0].mean()*0, dyn_loss, train_data[0][0][0].mean()*0, point_map_loss, scale_loss, shift_loss
1418
+
1419
+
1420
+ # ================ Condition 2: The supervision signal given by the ground truth trajectories ===================
1421
+ if (
1422
+ (torch.isnan(traj3d_gt).any()
1423
+ or traj3d_gt.abs().max() > 2000) and (custom_vid==False)
1424
+ ):
1425
+ return cam_loss, train_data[0][0][0].mean()*0, dyn_loss, train_data[0][0][0].mean()*0, point_map_loss, scale_loss, shift_loss
1426
+
1427
+
1428
+ vis_gts = [vis_gt.float()]
1429
+ invis_gts = [1-vis_gt.float()]
1430
+ traj_gts = [traj3d_gt]
1431
+ valids_gts = [valid_mask]
1432
+ seq_loss_all = sequence_loss(
1433
+ coord_predictions,
1434
+ traj_gts,
1435
+ valids_gts,
1436
+ vis=vis_gts,
1437
+ gamma=0.8,
1438
+ add_huber_loss=False,
1439
+ loss_only_for_visible=False if custom_vid==False else True,
1440
+ z_unc=z_unc,
1441
+ mask_traj_gt=mask_traj_gt
1442
+ )
1443
+
1444
+ confidence_loss = sequence_prob_loss(
1445
+ coord_predictions, confidence_predictions, traj_gts, vis_gts
1446
+ )
1447
+
1448
+ seq_loss_xyz = sequence_loss_xyz(
1449
+ coords_xyz_predictions,
1450
+ traj_gts,
1451
+ valids_gts,
1452
+ intrs=intrs,
1453
+ vis=vis_gts,
1454
+ gamma=0.8,
1455
+ add_huber_loss=False,
1456
+ loss_only_for_visible=False,
1457
+ mask_traj_gt=mask_traj_gt
1458
+ )
1459
+
1460
+ # filter the blinking points
1461
+ mask_vis = vis_gts[0].clone() # B T N
1462
+ mask_vis[mask_vis==0] = -1
1463
+ blink_mask = mask_vis[:,:-1,:] * mask_vis[:,1:,:] # +1 where visibility is unchanged across consecutive frames, -1 at a flip ('blink'); B (T-1) N
1464
+ mask_vis[:,:-1,:], mask_vis[:,-1,:] = (blink_mask == 1), 0
1465
+
1466
+ vis_loss = sequence_BCE_loss(vis_predictions, vis_gts, mask=[mask_vis])
1467
+
1468
+ track_loss_out = (seq_loss_all+2*seq_loss_xyz + cam_loss)
1469
+ if valid_only:
1470
+ vis_loss = 0.0*vis_loss
1471
+ if custom_vid:
1472
+ return seq_loss_all, 0.0*seq_loss_all, 0.0*seq_loss_all, 10*vis_loss, 0.0*seq_loss_all, 0.0*seq_loss_all, 0.0*seq_loss_all
1473
+
1474
+ return track_loss_out, confidence_loss, dyn_loss, 10*vis_loss, point_map_loss, scale_loss, shift_loss
1475
+
1476
+
1477
+
1478
+
models/SpaTrackV2/models/tracker3D/co_tracker/cotracker_base.py ADDED
@@ -0,0 +1,418 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from models.SpaTrackV2.utils.model_utils import sample_features5d, bilinear_sampler
11
+
12
+ from models.SpaTrackV2.models.tracker3D.co_tracker.utils import (
13
+ Mlp, BasicEncoder, EfficientUpdateFormer
14
+ )
15
+
16
+ torch.manual_seed(0)
17
+
18
+
19
+ def get_1d_sincos_pos_embed_from_grid(
20
+ embed_dim: int, pos: torch.Tensor
21
+ ) -> torch.Tensor:
22
+ """
23
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
24
+
25
+ Args:
26
+ - embed_dim: The embedding dimension.
27
+ - pos: The position to generate the embedding from.
28
+
29
+ Returns:
30
+ - emb: The generated 1D positional embedding.
31
+ """
32
+ assert embed_dim % 2 == 0
33
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
34
+ omega /= embed_dim / 2.0
35
+ omega = 1.0 / 10000**omega # (D/2,)
36
+
37
+ pos = pos.reshape(-1) # (M,)
38
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
39
+
40
+ emb_sin = torch.sin(out) # (M, D/2)
41
+ emb_cos = torch.cos(out) # (M, D/2)
42
+
43
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
44
+ return emb[None].float()
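A quick shape check (editorial sketch) for the sinusoidal embedding above: an embedding dimension D and M positions yield a (1, M, D) tensor, half sine and half cosine frequencies.

import torch
emb = get_1d_sincos_pos_embed_from_grid(64, torch.arange(8).float())
print(emb.shape)   # torch.Size([1, 8, 64])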
45
+
46
+ def posenc(x, min_deg, max_deg):
47
+ """Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
48
+ Instead of computing [sin(x), cos(x)], we use the trig identity
49
+ cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
50
+ Args:
51
+ x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
52
+ min_deg: int, the minimum (inclusive) degree of the encoding.
53
+ max_deg: int, the maximum (exclusive) degree of the encoding.
54
+ legacy_posenc_order: bool, keep the same ordering as the original tf code.
55
+ Returns:
56
+ encoded: torch.Tensor, encoded variables.
57
+ """
58
+ if min_deg == max_deg:
59
+ return x
60
+ scales = torch.tensor(
61
+ [2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device
62
+ )
63
+
64
+ xb = (x[..., None, :] * scales[:, None]).reshape(list(x.shape[:-1]) + [-1])
65
+ four_feat = torch.sin(torch.cat([xb, xb + 0.5 * torch.pi], dim=-1))
66
+ return torch.cat([x] + [four_feat], dim=-1)
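Editorial note: posenc concatenates the raw input with sin/cos features at scales 2^[min_deg, max_deg-1], so the output width is d * (1 + 2 * (max_deg - min_deg)). For the 4-channel relative-coordinate input used later with min_deg=0, max_deg=10 this gives 4 * 21 = 84, matching the "84" noted in the transformer input. A minimal check with toy shapes:

import torch
x = torch.zeros(2, 5, 4)            # (batch, tokens, 4 channels)
print(posenc(x, 0, 10).shape)       # torch.Size([2, 5, 84])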
67
+
68
+ class CoTrackerThreeBase(nn.Module):
69
+ def __init__(
70
+ self,
71
+ window_len=8,
72
+ stride=4,
73
+ corr_radius=3,
74
+ corr_levels=4,
75
+ num_virtual_tracks=64,
76
+ model_resolution=(384, 512),
77
+ add_space_attn=True,
78
+ linear_layer_for_vis_conf=True,
79
+ ):
80
+ super(CoTrackerThreeBase, self).__init__()
81
+ self.window_len = window_len
82
+ self.stride = stride
83
+ self.corr_radius = corr_radius
84
+ self.corr_levels = corr_levels
85
+ self.hidden_dim = 256
86
+ self.latent_dim = 128
87
+
88
+ self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
89
+ self.fnet = BasicEncoder(input_dim=3, output_dim=self.latent_dim, stride=stride)
90
+
91
+ highres_dim = 128
92
+ lowres_dim = 256
93
+
94
+ self.num_virtual_tracks = num_virtual_tracks
95
+ self.model_resolution = model_resolution
96
+
97
+ self.input_dim = 1110
98
+
99
+ self.updateformer = EfficientUpdateFormer(
100
+ space_depth=3,
101
+ time_depth=3,
102
+ input_dim=self.input_dim,
103
+ hidden_size=384,
104
+ output_dim=4,
105
+ mlp_ratio=4.0,
106
+ num_virtual_tracks=num_virtual_tracks,
107
+ add_space_attn=add_space_attn,
108
+ linear_layer_for_vis_conf=linear_layer_for_vis_conf,
109
+ )
110
+ self.corr_mlp = Mlp(in_features=49 * 49, hidden_features=384, out_features=256)
111
+
112
+ time_grid = torch.linspace(0, window_len - 1, window_len).reshape(
113
+ 1, window_len, 1
114
+ )
115
+
116
+ self.register_buffer(
117
+ "time_emb", get_1d_sincos_pos_embed_from_grid(self.input_dim, time_grid[0])
118
+ )
119
+
120
+ def get_support_points(self, coords, r, reshape_back=True):
121
+ B, _, N, _ = coords.shape
122
+ device = coords.device
123
+ centroid_lvl = coords.reshape(B, N, 1, 1, 3)
124
+
125
+ dx = torch.linspace(-r, r, 2 * r + 1, device=device)
126
+ dy = torch.linspace(-r, r, 2 * r + 1, device=device)
127
+
128
+ xgrid, ygrid = torch.meshgrid(dy, dx, indexing="ij")
129
+ zgrid = torch.zeros_like(xgrid, device=device)
130
+ delta = torch.stack([zgrid, xgrid, ygrid], axis=-1)
131
+ delta_lvl = delta.view(1, 1, 2 * r + 1, 2 * r + 1, 3)
132
+ coords_lvl = centroid_lvl + delta_lvl
133
+
134
+ if reshape_back:
135
+ return coords_lvl.reshape(B, N, (2 * r + 1) ** 2, 3).permute(0, 2, 1, 3)
136
+ else:
137
+ return coords_lvl
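Editorial sketch of the support grid built above: each query is surrounded by a (2r+1) x (2r+1) patch of (t, x, y) offsets with zero temporal offset (r=3, the default corr_radius, is assumed here).

import torch
r = 3
dx = torch.linspace(-r, r, 2 * r + 1)
xgrid, ygrid = torch.meshgrid(dx, dx, indexing="ij")
delta = torch.stack([torch.zeros_like(xgrid), xgrid, ygrid], dim=-1)
print(delta.view(-1, 3).shape)      # torch.Size([49, 3]) -> 49 support offsets per query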
138
+
139
+ def get_track_feat(self, fmaps, queried_frames, queried_coords, support_radius=0):
140
+
141
+ sample_frames = queried_frames[:, None, :, None]
142
+ sample_coords = torch.cat(
143
+ [
144
+ sample_frames,
145
+ queried_coords[:, None],
146
+ ],
147
+ dim=-1,
148
+ )
149
+ support_points = self.get_support_points(sample_coords, support_radius)
150
+ support_track_feats = sample_features5d(fmaps, support_points)
151
+ return (
152
+ support_track_feats[:, None, support_track_feats.shape[1] // 2],
153
+ support_track_feats,
154
+ )
155
+
156
+ def get_correlation_feat(self, fmaps, queried_coords, radius=None, padding_mode="border"):
157
+ B, T, D, H_, W_ = fmaps.shape
158
+ N = queried_coords.shape[1]
159
+ if radius is None:
160
+ r = self.corr_radius
161
+ else:
162
+ r = radius
163
+ sample_coords = torch.cat(
164
+ [torch.zeros_like(queried_coords[..., :1]), queried_coords], dim=-1
165
+ )[:, None]
166
+ support_points = self.get_support_points(sample_coords, r, reshape_back=False)
167
+ correlation_feat = bilinear_sampler(
168
+ fmaps.reshape(B * T, D, 1, H_, W_), support_points, padding_mode=padding_mode
169
+ )
170
+ return correlation_feat.view(B, T, D, N, (2 * r + 1), (2 * r + 1)).permute(
171
+ 0, 1, 3, 4, 5, 2
172
+ )
173
+
174
+ def interpolate_time_embed(self, x, t):
175
+ previous_dtype = x.dtype
176
+ T = self.time_emb.shape[1]
177
+
178
+ if t == T:
179
+ return self.time_emb
180
+
181
+ time_emb = self.time_emb.float()
182
+ time_emb = F.interpolate(
183
+ time_emb.permute(0, 2, 1), size=t, mode="linear"
184
+ ).permute(0, 2, 1)
185
+ return time_emb.to(previous_dtype)
186
+
187
+ class CoTrackerThreeOffline(CoTrackerThreeBase):
188
+ def __init__(self, **args):
189
+ super(CoTrackerThreeOffline, self).__init__(**args)
190
+
191
+ def forward(
192
+ self,
193
+ video,
194
+ queries,
195
+ iters=4,
196
+ is_train=False,
197
+ add_space_attn=True,
198
+ fmaps_chunk_size=200,
199
+ ):
200
+ """Predict tracks
201
+
202
+ Args:
203
+ video (FloatTensor[B, T, 3, H, W]): input videos.
204
+ queries (FloatTensor[B, N, 3]): point queries.
205
+ iters (int, optional): number of updates. Defaults to 4.
206
+ is_train (bool, optional): enables training mode. Defaults to False.
207
+ Returns:
208
+ - coords_predicted (FloatTensor[B, T, N, 2]):
209
+ - vis_predicted (FloatTensor[B, T, N]):
210
+ - train_data: `None` if `is_train` is false, otherwise:
211
+ - all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
212
+ - all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
213
+ - mask (BoolTensor[B, T, N]):
214
+ """
215
+
216
+ B, T, C, H, W = video.shape
217
+ device = queries.device
218
+ assert H % self.stride == 0 and W % self.stride == 0
219
+
220
+ B, N, __ = queries.shape
221
+ # B = batch size
222
+ # S_trimmed = actual number of frames in the window
223
+ # N = number of tracks
224
+ # C = color channels (3 for RGB)
225
+ # E = positional embedding size
226
+ # LRR = local receptive field radius
227
+ # D = dimension of the transformer input tokens
228
+
229
+ # video = B T C H W
230
+ # queries = B N 3
231
+ # coords_init = B T N 2
232
+ # vis_init = B T N 1
233
+
234
+ assert T >= 1 # at least one frame is required here; meaningful tracking needs two or more
235
+
236
+ video = 2 * (video / 255.0) - 1.0
237
+ dtype = video.dtype
238
+ queried_frames = queries[:, :, 0].long()
239
+
240
+ queried_coords = queries[..., 1:3]
241
+ queried_coords = queried_coords / self.stride
242
+
243
+ # We store our predictions here
244
+ all_coords_predictions, all_vis_predictions, all_confidence_predictions = (
245
+ [],
246
+ [],
247
+ [],
248
+ )
249
+ C_ = C
250
+ H4, W4 = H // self.stride, W // self.stride
251
+ # Compute convolutional features for the video or for the current chunk in case of online mode
252
+
253
+ if T > fmaps_chunk_size:
254
+ fmaps = []
255
+ for t in range(0, T, fmaps_chunk_size):
256
+ video_chunk = video[:, t : t + fmaps_chunk_size]
257
+ fmaps_chunk = self.fnet(video_chunk.reshape(-1, C_, H, W))
258
+ T_chunk = video_chunk.shape[1]
259
+ C_chunk, H_chunk, W_chunk = fmaps_chunk.shape[1:]
260
+ fmaps.append(fmaps_chunk.reshape(B, T_chunk, C_chunk, H_chunk, W_chunk))
261
+ fmaps = torch.cat(fmaps, dim=1).reshape(-1, C_chunk, H_chunk, W_chunk)
262
+ else:
263
+ fmaps = self.fnet(video.reshape(-1, C_, H, W))
264
+ fmaps = fmaps.permute(0, 2, 3, 1)
265
+ fmaps = fmaps / torch.sqrt(
266
+ torch.maximum(
267
+ torch.sum(torch.square(fmaps), axis=-1, keepdims=True),
268
+ torch.tensor(1e-12, device=fmaps.device),
269
+ )
270
+ )
271
+ fmaps = fmaps.permute(0, 3, 1, 2).reshape(
272
+ B, -1, self.latent_dim, H // self.stride, W // self.stride
273
+ )
274
+ fmaps = fmaps.to(dtype)
275
+
276
+ # We compute track features
277
+ fmaps_pyramid = []
278
+ track_feat_pyramid = []
279
+ track_feat_support_pyramid = []
280
+ fmaps_pyramid.append(fmaps)
281
+ for i in range(self.corr_levels - 1):
282
+ fmaps_ = fmaps.reshape(
283
+ B * T, self.latent_dim, fmaps.shape[-2], fmaps.shape[-1]
284
+ )
285
+ fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
286
+ fmaps = fmaps_.reshape(
287
+ B, T, self.latent_dim, fmaps_.shape[-2], fmaps_.shape[-1]
288
+ )
289
+ fmaps_pyramid.append(fmaps)
290
+
291
+ for i in range(self.corr_levels):
292
+ track_feat, track_feat_support = self.get_track_feat(
293
+ fmaps_pyramid[i],
294
+ queried_frames,
295
+ queried_coords / 2**i,
296
+ support_radius=self.corr_radius,
297
+ )
298
+ track_feat_pyramid.append(track_feat.repeat(1, T, 1, 1))
299
+ track_feat_support_pyramid.append(track_feat_support.unsqueeze(1))
300
+
301
+ D_coords = 2
302
+
303
+ coord_preds, vis_preds, confidence_preds = [], [], []
304
+
305
+ vis = torch.zeros((B, T, N), device=device).float()
306
+ confidence = torch.zeros((B, T, N), device=device).float()
307
+ coords = queried_coords.reshape(B, 1, N, 2).expand(B, T, N, 2).float()
308
+
309
+ r = 2 * self.corr_radius + 1
310
+
311
+ for it in range(iters):
312
+ coords = coords.detach() # B T N 2
313
+ coords_init = coords.view(B * T, N, 2)
314
+ corr_embs = []
315
+ corr_feats = []
316
+ for i in range(self.corr_levels):
317
+ corr_feat = self.get_correlation_feat(
318
+ fmaps_pyramid[i], coords_init / 2**i
319
+ )
320
+ track_feat_support = (
321
+ track_feat_support_pyramid[i]
322
+ .view(B, 1, r, r, N, self.latent_dim)
323
+ .squeeze(1)
324
+ .permute(0, 3, 1, 2, 4)
325
+ )
326
+ corr_volume = torch.einsum(
327
+ "btnhwc,bnijc->btnhwij", corr_feat, track_feat_support
328
+ )
329
+ corr_emb = self.corr_mlp(corr_volume.reshape(B * T * N, r * r * r * r))
330
+ corr_embs.append(corr_emb)
331
+ corr_embs = torch.cat(corr_embs, dim=-1)
332
+ corr_embs = corr_embs.view(B, T, N, corr_embs.shape[-1])
333
+
334
+ transformer_input = [vis[..., None], confidence[..., None], corr_embs]
335
+
336
+ rel_coords_forward = coords[:, :-1] - coords[:, 1:]
337
+ rel_coords_backward = coords[:, 1:] - coords[:, :-1]
338
+
339
+ rel_coords_forward = torch.nn.functional.pad(
340
+ rel_coords_forward, (0, 0, 0, 0, 0, 1)
341
+ )
342
+ rel_coords_backward = torch.nn.functional.pad(
343
+ rel_coords_backward, (0, 0, 0, 0, 1, 0)
344
+ )
345
+ scale = (
346
+ torch.tensor(
347
+ [self.model_resolution[1], self.model_resolution[0]],
348
+ device=coords.device,
349
+ )
350
+ / self.stride
351
+ )
352
+ rel_coords_forward = rel_coords_forward / scale
353
+ rel_coords_backward = rel_coords_backward / scale
354
+
355
+ rel_pos_emb_input = posenc(
356
+ torch.cat([rel_coords_forward, rel_coords_backward], dim=-1),
357
+ min_deg=0,
358
+ max_deg=10,
359
+ ) # batch, num_points, num_frames, 84
360
+ transformer_input.append(rel_pos_emb_input)
361
+
362
+ x = (
363
+ torch.cat(transformer_input, dim=-1)
364
+ .permute(0, 2, 1, 3)
365
+ .reshape(B * N, T, -1)
366
+ )
367
+
368
+ x = x + self.interpolate_time_embed(x, T)
369
+ x = x.view(B, N, T, -1) # (B N) T D -> B N T D
370
+
371
+ delta = self.updateformer(
372
+ x,
373
+ add_space_attn=add_space_attn,
374
+ )
375
+
376
+ delta_coords = delta[..., :D_coords].permute(0, 2, 1, 3)
377
+ delta_vis = delta[..., D_coords].permute(0, 2, 1)
378
+ delta_confidence = delta[..., D_coords + 1].permute(0, 2, 1)
379
+
380
+ vis = vis + delta_vis
381
+ confidence = confidence + delta_confidence
382
+
383
+ coords = coords + delta_coords
384
+ coords_append = coords.clone()
385
+ coords_append[..., :2] = coords_append[..., :2] * float(self.stride)
386
+ coord_preds.append(coords_append)
387
+ vis_preds.append(torch.sigmoid(vis))
388
+ confidence_preds.append(torch.sigmoid(confidence))
389
+
390
+ if is_train:
391
+ all_coords_predictions.append([coord[..., :2] for coord in coord_preds])
392
+ all_vis_predictions.append(vis_preds)
393
+ all_confidence_predictions.append(confidence_preds)
394
+
395
+ if is_train:
396
+ train_data = (
397
+ all_coords_predictions,
398
+ all_vis_predictions,
399
+ all_confidence_predictions,
400
+ torch.ones_like(vis_preds[-1], device=vis_preds[-1].device),
401
+ )
402
+ else:
403
+ train_data = None
404
+
405
+ return coord_preds[-1][..., :2], vis_preds[-1], confidence_preds[-1], train_data
406
+
407
+
408
+ if __name__ == "__main__":
409
+ cotrack_cktp = "/data0/xyx/scaled_offline.pth"
410
+ cotracker = CoTrackerThreeOffline(
411
+ stride=4, corr_radius=3, window_len=60
412
+ )
413
+ with open(cotrack_cktp, "rb") as f:
414
+ state_dict = torch.load(f, map_location="cpu")
415
+ if "model" in state_dict:
416
+ state_dict = state_dict["model"]
417
+ cotracker.load_state_dict(state_dict)
418
+ import pdb; pdb.set_trace()
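Beyond the checkpoint-loading snippet above, a minimal editorial usage sketch (not part of the commit; random weights, shapes only, and it assumes the repo's model_utils import cleanly): queries follow the (start_frame, x, y) pixel format documented in forward(), and H, W must be divisible by the stride. The exact attention kernels used at runtime depend on the installed PyTorch version.

import torch

model = CoTrackerThreeOffline(stride=4, corr_radius=3, window_len=60).eval()
video = torch.zeros(1, 8, 3, 384, 512)                          # B, T, C, H, W in [0, 255]
queries = torch.tensor([[[0.0, 100.0, 150.0],
                         [0.0, 200.0, 50.0]]])                  # B, N, (t, x, y)
with torch.no_grad():
    coords, vis, conf, _ = model(video, queries, iters=1)
print(coords.shape, vis.shape)                                  # (1, 8, 2, 2) (1, 8, 2)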
models/SpaTrackV2/models/tracker3D/co_tracker/utils.py ADDED
@@ -0,0 +1,929 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from functools import partial
5
+ from typing import Callable, List
6
+ import collections
7
+ from torch import Tensor
8
+ from itertools import repeat
9
+ from models.SpaTrackV2.utils.model_utils import bilinear_sampler
10
+ from models.SpaTrackV2.models.blocks import CrossAttnBlock as CrossAttnBlock_F
11
+ from torch.nn.functional import scaled_dot_product_attention
12
+ from torch.nn.attention import sdpa_kernel, SDPBackend
13
+ # import flash_attn
14
+ EPS = 1e-6
15
+
16
+
17
+ class ResidualBlock(nn.Module):
18
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1):
19
+ super(ResidualBlock, self).__init__()
20
+
21
+ self.conv1 = nn.Conv2d(
22
+ in_planes,
23
+ planes,
24
+ kernel_size=3,
25
+ padding=1,
26
+ stride=stride,
27
+ padding_mode="zeros",
28
+ )
29
+ self.conv2 = nn.Conv2d(
30
+ planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
31
+ )
32
+ self.relu = nn.ReLU(inplace=True)
33
+
34
+ num_groups = planes // 8
35
+
36
+ if norm_fn == "group":
37
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
38
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
39
+ if not stride == 1:
40
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
41
+
42
+ elif norm_fn == "batch":
43
+ self.norm1 = nn.BatchNorm2d(planes)
44
+ self.norm2 = nn.BatchNorm2d(planes)
45
+ if not stride == 1:
46
+ self.norm3 = nn.BatchNorm2d(planes)
47
+
48
+ elif norm_fn == "instance":
49
+ self.norm1 = nn.InstanceNorm2d(planes)
50
+ self.norm2 = nn.InstanceNorm2d(planes)
51
+ if not stride == 1:
52
+ self.norm3 = nn.InstanceNorm2d(planes)
53
+
54
+ elif norm_fn == "none":
55
+ self.norm1 = nn.Sequential()
56
+ self.norm2 = nn.Sequential()
57
+ if not stride == 1:
58
+ self.norm3 = nn.Sequential()
59
+
60
+ if stride == 1:
61
+ self.downsample = None
62
+
63
+ else:
64
+ self.downsample = nn.Sequential(
65
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
66
+ )
67
+
68
+ def forward(self, x):
69
+ y = x
70
+ y = self.relu(self.norm1(self.conv1(y)))
71
+ y = self.relu(self.norm2(self.conv2(y)))
72
+
73
+ if self.downsample is not None:
74
+ x = self.downsample(x)
75
+
76
+ return self.relu(x + y)
77
+
78
+ def reduce_masked_mean(input, mask, dim=None, keepdim=False):
79
+ r"""Masked mean
80
+
81
+ `reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
82
+ over a mask :attr:`mask`, returning
83
+
84
+ .. math::
85
+ \text{output} =
86
+ \frac
87
+ {\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
88
+ {\epsilon + \sum_{i=1}^N \text{mask}_i}
89
+
90
+ where :math:`N` is the number of elements in :attr:`input` and
91
+ :attr:`mask`, and :math:`\epsilon` is a small constant to avoid
92
+ division by zero.
93
+
94
+ `reduced_masked_mean(x, mask, dim)` computes the mean of a tensor
95
+ :attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
96
+ Optionally, the dimension can be kept in the output by setting
97
+ :attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
98
+ the same dimension as :attr:`input`.
99
+
100
+ The interface is similar to `torch.mean()`.
101
+
102
+ Args:
103
+ input (Tensor): input tensor.
104
+ mask (Tensor): mask.
105
+ dim (int, optional): Dimension to sum over. Defaults to None.
106
+ keepdim (bool, optional): Keep the summed dimension. Defaults to False.
107
+
108
+ Returns:
109
+ Tensor: mean tensor.
110
+ """
111
+
112
+ mask = mask.expand_as(input)
113
+
114
+ prod = input * mask
115
+
116
+ if dim is None:
117
+ numer = torch.sum(prod)
118
+ denom = torch.sum(mask)
119
+ else:
120
+ numer = torch.sum(prod, dim=dim, keepdim=keepdim)
121
+ denom = torch.sum(mask, dim=dim, keepdim=keepdim)
122
+
123
+ mean = numer / (EPS + denom)
124
+ return mean
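A tiny editorial example of reduce_masked_mean: only unmasked entries contribute to the mean.

import torch
x = torch.tensor([1.0, 2.0, 3.0, 4.0])
m = torch.tensor([1.0, 1.0, 0.0, 0.0])
print(reduce_masked_mean(x, m))     # tensor(1.5000) -- mean over the first two entries only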
125
+
126
+ class GeometryEncoder(nn.Module):
127
+ def __init__(self, input_dim=3, output_dim=128, stride=4):
128
+ super(GeometryEncoder, self).__init__()
129
+ self.stride = stride
130
+ self.norm_fn = "instance"
131
+ self.in_planes = output_dim // 2
132
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
133
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
134
+ self.conv1 = nn.Conv2d(
135
+ input_dim,
136
+ self.in_planes,
137
+ kernel_size=7,
138
+ stride=2,
139
+ padding=3,
140
+ padding_mode="zeros",
141
+ )
142
+ self.relu1 = nn.ReLU(inplace=True)
143
+ self.layer1 = self._make_layer(output_dim // 2, stride=1)
144
+ self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
145
+
146
+ self.conv2 = nn.Conv2d(
147
+ output_dim * 5 // 4,
148
+ output_dim,
149
+ kernel_size=3,
150
+ padding=1,
151
+ padding_mode="zeros",
152
+ )
153
+ self.relu2 = nn.ReLU(inplace=True)
154
+ self.conv3 = nn.Conv2d(output_dim, output_dim, kernel_size=1)
155
+ for m in self.modules():
156
+ if isinstance(m, nn.Conv2d):
157
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
158
+ elif isinstance(m, (nn.InstanceNorm2d)):
159
+ if m.weight is not None:
160
+ nn.init.constant_(m.weight, 1)
161
+ if m.bias is not None:
162
+ nn.init.constant_(m.bias, 0)
163
+
164
+ def _make_layer(self, dim, stride=1):
165
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
166
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
167
+ layers = (layer1, layer2)
168
+
169
+ self.in_planes = dim
170
+ return nn.Sequential(*layers)
171
+
172
+ def forward(self, x):
173
+ _, _, H, W = x.shape
174
+ x = self.conv1(x)
175
+ x = self.norm1(x)
176
+ x = self.relu1(x)
177
+ a = self.layer1(x)
178
+ b = self.layer2(a)
179
+ def _bilinear_intepolate(x):
180
+ return F.interpolate(
181
+ x,
182
+ (H // self.stride, W // self.stride),
183
+ mode="bilinear",
184
+ align_corners=True,
185
+ )
186
+ a = _bilinear_intepolate(a)
187
+ b = _bilinear_intepolate(b)
188
+ x = self.conv2(torch.cat([a, b], dim=1))
189
+ x = self.norm2(x)
190
+ x = self.relu2(x)
191
+ x = self.conv3(x)
192
+ return x
193
+
194
+ class BasicEncoder(nn.Module):
195
+ def __init__(self, input_dim=3, output_dim=128, stride=4):
196
+ super(BasicEncoder, self).__init__()
197
+ self.stride = stride
198
+ self.norm_fn = "instance"
199
+ self.in_planes = output_dim // 2
200
+ self.norm1 = nn.InstanceNorm2d(self.in_planes)
201
+ self.norm2 = nn.InstanceNorm2d(output_dim * 2)
202
+
203
+ self.conv1 = nn.Conv2d(
204
+ input_dim,
205
+ self.in_planes,
206
+ kernel_size=7,
207
+ stride=2,
208
+ padding=3,
209
+ padding_mode="zeros",
210
+ )
211
+ self.relu1 = nn.ReLU(inplace=True)
212
+ self.layer1 = self._make_layer(output_dim // 2, stride=1)
213
+ self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
214
+ self.layer3 = self._make_layer(output_dim, stride=2)
215
+ self.layer4 = self._make_layer(output_dim, stride=2)
216
+
217
+ self.conv2 = nn.Conv2d(
218
+ output_dim * 3 + output_dim // 4,
219
+ output_dim * 2,
220
+ kernel_size=3,
221
+ padding=1,
222
+ padding_mode="zeros",
223
+ )
224
+ self.relu2 = nn.ReLU(inplace=True)
225
+ self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
226
+ for m in self.modules():
227
+ if isinstance(m, nn.Conv2d):
228
+ nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
229
+ elif isinstance(m, (nn.InstanceNorm2d)):
230
+ if m.weight is not None:
231
+ nn.init.constant_(m.weight, 1)
232
+ if m.bias is not None:
233
+ nn.init.constant_(m.bias, 0)
234
+
235
+ def _make_layer(self, dim, stride=1):
236
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
237
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
238
+ layers = (layer1, layer2)
239
+
240
+ self.in_planes = dim
241
+ return nn.Sequential(*layers)
242
+
243
+ def forward(self, x):
244
+ _, _, H, W = x.shape
245
+
246
+ x = self.conv1(x)
247
+ x = self.norm1(x)
248
+ x = self.relu1(x)
249
+
250
+ a = self.layer1(x)
251
+ b = self.layer2(a)
252
+ c = self.layer3(b)
253
+ d = self.layer4(c)
254
+
255
+ def _bilinear_intepolate(x):
256
+ return F.interpolate(
257
+ x,
258
+ (H // self.stride, W // self.stride),
259
+ mode="bilinear",
260
+ align_corners=True,
261
+ )
262
+
263
+ a = _bilinear_intepolate(a)
264
+ b = _bilinear_intepolate(b)
265
+ c = _bilinear_intepolate(c)
266
+ d = _bilinear_intepolate(d)
267
+
268
+ x = self.conv2(torch.cat([a, b, c, d], dim=1))
269
+ x = self.norm2(x)
270
+ x = self.relu2(x)
271
+ x = self.conv3(x)
272
+ return x
273
+
274
+ # From PyTorch internals
275
+ def _ntuple(n):
276
+ def parse(x):
277
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
278
+ return tuple(x)
279
+ return tuple(repeat(x, n))
280
+
281
+ return parse
282
+
283
+
284
+ def exists(val):
285
+ return val is not None
286
+
287
+
288
+ def default(val, d):
289
+ return val if exists(val) else d
290
+
291
+
292
+ to_2tuple = _ntuple(2)
293
+
294
+
295
+ class Mlp(nn.Module):
296
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
297
+
298
+ def __init__(
299
+ self,
300
+ in_features,
301
+ hidden_features=None,
302
+ out_features=None,
303
+ act_layer=nn.GELU,
304
+ norm_layer=None,
305
+ bias=True,
306
+ drop=0.0,
307
+ use_conv=False,
308
+ ):
309
+ super().__init__()
310
+ out_features = out_features or in_features
311
+ hidden_features = hidden_features or in_features
312
+ bias = to_2tuple(bias)
313
+ drop_probs = to_2tuple(drop)
314
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
315
+
316
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
317
+ self.act = act_layer()
318
+ self.drop1 = nn.Dropout(drop_probs[0])
319
+ self.norm = (
320
+ norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
321
+ )
322
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
323
+ self.drop2 = nn.Dropout(drop_probs[1])
324
+
325
+ def forward(self, x):
326
+ x = self.fc1(x)
327
+ x = self.act(x)
328
+ x = self.drop1(x)
329
+ x = self.fc2(x)
330
+ x = self.drop2(x)
331
+ return x
332
+
333
+
334
+ class Attention(nn.Module):
335
+ def __init__(
336
+ self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False
337
+ ):
338
+ super().__init__()
339
+ inner_dim = dim_head * num_heads
340
+ self.inner_dim = inner_dim
341
+ context_dim = default(context_dim, query_dim)
342
+ self.scale = dim_head**-0.5
343
+ self.heads = num_heads
344
+
345
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
346
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
347
+ self.to_out = nn.Linear(inner_dim, query_dim)
348
+
349
+ def forward(self, x, context=None, attn_bias=None, flash=True):
350
+ B, N1, C = x.shape
351
+ h = self.heads
352
+
353
+ q = self.to_q(x).reshape(B, N1, h, self.inner_dim // h).permute(0, 2, 1, 3)
354
+ context = default(context, x)
355
+ k, v = self.to_kv(context).chunk(2, dim=-1)
356
+
357
+ N2 = context.shape[1]
358
+ k = k.reshape(B, N2, h, self.inner_dim // h).permute(0, 2, 1, 3)
359
+ v = v.reshape(B, N2, h, self.inner_dim // h).permute(0, 2, 1, 3)
360
+
361
+ if (
362
+ (N1 < 64 and N2 < 64) or
363
+ (B > 1e4) or
364
+ (q.shape[1] != k.shape[1]) or
365
+ (q.shape[1] % k.shape[1] != 0)
366
+ ):
367
+ flash = False
368
+
369
+
370
+ if flash == False:
371
+ sim = (q @ k.transpose(-2, -1)) * self.scale
372
+ if attn_bias is not None:
373
+ sim = sim + attn_bias
374
+ if sim.abs().max() > 1e2:
375
+ import pdb; pdb.set_trace()
376
+ attn = sim.softmax(dim=-1)
377
+ x = (attn @ v).transpose(1, 2).reshape(B, N1, self.inner_dim)
378
+ else:
379
+
380
+ input_args = [x.contiguous() for x in [q, k, v]]
381
+ try:
382
+ # print(f"q.shape: {q.shape}, dtype: {q.dtype}, device: {q.device}")
383
+ # print(f"Flash SDP available: {torch.backends.cuda.flash_sdp_enabled()}")
384
+ # print(f"Flash SDP allowed: {torch.backends.cuda.enable_flash_sdp}")
385
+ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
386
+ x = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).reshape(B,N1,-1) # type: ignore
387
+ except Exception as e:
388
+ print(e)
389
+
390
+ if self.to_out.bias.dtype != x.dtype:
391
+ x = x.to(self.to_out.bias.dtype)
392
+
393
+ return self.to_out(x)
394
+
395
+ class CrossAttnBlock(nn.Module):
396
+ def __init__(
397
+ self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs
398
+ ):
399
+ super().__init__()
400
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
401
+ self.norm_context = nn.LayerNorm(context_dim)
402
+ self.cross_attn = Attention(
403
+ hidden_size,
404
+ context_dim=context_dim,
405
+ num_heads=num_heads,
406
+ qkv_bias=True,
407
+ **block_kwargs
408
+ )
409
+
410
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
411
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
412
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
413
+ self.mlp = Mlp(
414
+ in_features=hidden_size,
415
+ hidden_features=mlp_hidden_dim,
416
+ act_layer=approx_gelu,
417
+ drop=0,
418
+ )
419
+
420
+ def forward(self, x, context, mask=None):
421
+ attn_bias = None
422
+ if mask is not None:
423
+ if mask.shape[1] == x.shape[1]:
424
+ mask = mask[:, None, :, None].expand(
425
+ -1, self.cross_attn.heads, -1, context.shape[1]
426
+ )
427
+ else:
428
+ mask = mask[:, None, None].expand(
429
+ -1, self.cross_attn.heads, x.shape[1], -1
430
+ )
431
+
432
+ max_neg_value = -torch.finfo(x.dtype).max
433
+ attn_bias = (~mask) * max_neg_value
434
+ x = x + self.cross_attn(
435
+ self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
436
+ )
437
+ x = x + self.mlp(self.norm2(x))
438
+ return x
439
+
440
+ class AttnBlock(nn.Module):
441
+ def __init__(
442
+ self,
443
+ hidden_size,
444
+ num_heads,
445
+ attn_class: Callable[..., nn.Module] = Attention,
446
+ mlp_ratio=4.0,
447
+ **block_kwargs
448
+ ):
449
+ super().__init__()
450
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
451
+ self.attn = attn_class(
452
+ hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
453
+ )
454
+
455
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
456
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
457
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
458
+ self.mlp = Mlp(
459
+ in_features=hidden_size,
460
+ hidden_features=mlp_hidden_dim,
461
+ act_layer=approx_gelu,
462
+ drop=0,
463
+ )
464
+
465
+ def forward(self, x, mask=None):
466
+ attn_bias = mask
467
+ if mask is not None:
468
+ mask = (
469
+ (mask[:, None] * mask[:, :, None])
470
+ .unsqueeze(1)
471
+ .expand(-1, self.attn.num_heads, -1, -1)
472
+ )
473
+ max_neg_value = -torch.finfo(x.dtype).max
474
+ attn_bias = (~mask) * max_neg_value
475
+ x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
476
+ x = x + self.mlp(self.norm2(x))
477
+ return x
478
+
479
+ class EfficientUpdateFormer(nn.Module):
480
+ """
481
+ Transformer model that updates track estimates.
482
+ """
483
+
484
+ def __init__(
485
+ self,
486
+ space_depth=6,
487
+ time_depth=6,
488
+ input_dim=320,
489
+ hidden_size=384,
490
+ num_heads=8,
491
+ output_dim=130,
492
+ mlp_ratio=4.0,
493
+ num_virtual_tracks=64,
494
+ add_space_attn=True,
495
+ linear_layer_for_vis_conf=False,
496
+ patch_feat=False,
497
+ patch_dim=128,
498
+ ):
499
+ super().__init__()
500
+ self.out_channels = 2
501
+ self.num_heads = num_heads
502
+ self.hidden_size = hidden_size
503
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
504
+ if linear_layer_for_vis_conf:
505
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
506
+ self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
507
+ else:
508
+ self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)
509
+
510
+ if patch_feat==False:
511
+ self.virual_tracks = nn.Parameter(
512
+ torch.randn(1, num_virtual_tracks, 1, hidden_size)
513
+ )
514
+ self.num_virtual_tracks = num_virtual_tracks
515
+ else:
516
+ self.patch_proj = nn.Linear(patch_dim, hidden_size, bias=True)
517
+
518
+ self.add_space_attn = add_space_attn
519
+ self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
520
+ self.time_blocks = nn.ModuleList(
521
+ [
522
+ AttnBlock(
523
+ hidden_size,
524
+ num_heads,
525
+ mlp_ratio=mlp_ratio,
526
+ attn_class=Attention,
527
+ )
528
+ for _ in range(time_depth)
529
+ ]
530
+ )
531
+
532
+ if add_space_attn:
533
+ self.space_virtual_blocks = nn.ModuleList(
534
+ [
535
+ AttnBlock(
536
+ hidden_size,
537
+ num_heads,
538
+ mlp_ratio=mlp_ratio,
539
+ attn_class=Attention,
540
+ )
541
+ for _ in range(space_depth)
542
+ ]
543
+ )
544
+ self.space_point2virtual_blocks = nn.ModuleList(
545
+ [
546
+ CrossAttnBlock(
547
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
548
+ )
549
+ for _ in range(space_depth)
550
+ ]
551
+ )
552
+ self.space_virtual2point_blocks = nn.ModuleList(
553
+ [
554
+ CrossAttnBlock(
555
+ hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
556
+ )
557
+ for _ in range(space_depth)
558
+ ]
559
+ )
560
+ assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
561
+ self.initialize_weights()
562
+
563
+ def initialize_weights(self):
564
+ def _basic_init(module):
565
+ if isinstance(module, nn.Linear):
566
+ torch.nn.init.xavier_uniform_(module.weight)
567
+ if module.bias is not None:
568
+ nn.init.constant_(module.bias, 0)
569
+ torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
570
+ if self.linear_layer_for_vis_conf:
571
+ torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)
572
+
573
+ def _trunc_init(module):
574
+ """ViT weight initialization, original timm impl (for reproducibility)"""
575
+ if isinstance(module, nn.Linear):
576
+ torch.nn.init.trunc_normal_(module.weight, std=0.02)
577
+ if module.bias is not None:
578
+ nn.init.zeros_(module.bias)
579
+
580
+ self.apply(_basic_init)
581
+
582
+ def forward(self, input_tensor, mask=None, add_space_attn=True, patch_feat=None):
583
+ tokens = self.input_transform(input_tensor)
584
+
585
+ B, _, T, _ = tokens.shape
586
+ if patch_feat is None:
587
+ virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
588
+ tokens = torch.cat([tokens, virtual_tokens], dim=1)
589
+ else:
590
+ patch_feat = self.patch_proj(patch_feat.detach())
591
+ tokens = torch.cat([tokens, patch_feat], dim=1)
592
+ self.num_virtual_tracks = patch_feat.shape[1]
593
+
594
+ _, N, _, _ = tokens.shape
595
+ j = 0
596
+ layers = []
597
+ for i in range(len(self.time_blocks)):
598
+ time_tokens = tokens.contiguous().view(B * N, T, -1) # B N T C -> (B N) T C
599
+ time_tokens = torch.utils.checkpoint.checkpoint(
600
+ self.time_blocks[i],
601
+ time_tokens
602
+ )
603
+
604
+ tokens = time_tokens.view(B, N, T, -1) # (B N) T C -> B N T C
605
+ if (
606
+ add_space_attn
607
+ and hasattr(self, "space_virtual_blocks")
608
+ and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
609
+ ):
610
+ space_tokens = (
611
+ tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
612
+ ) # B N T C -> (B T) N C
613
+
614
+ point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
615
+ virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]
616
+
617
+ virtual_tokens = torch.utils.checkpoint.checkpoint(
618
+ self.space_virtual2point_blocks[j],
619
+ virtual_tokens, point_tokens, mask
620
+ )
621
+
622
+ virtual_tokens = torch.utils.checkpoint.checkpoint(
623
+ self.space_virtual_blocks[j],
624
+ virtual_tokens
625
+ )
626
+
627
+ point_tokens = torch.utils.checkpoint.checkpoint(
628
+ self.space_point2virtual_blocks[j],
629
+ point_tokens, virtual_tokens, mask
630
+ )
631
+
632
+ space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
633
+ tokens = space_tokens.view(B, T, N, -1).permute(
634
+ 0, 2, 1, 3
635
+ ) # (B T) N C -> B N T C
636
+ j += 1
637
+ tokens = tokens[:, : N - self.num_virtual_tracks]
638
+
639
+ flow = self.flow_head(tokens)
640
+ if self.linear_layer_for_vis_conf:
641
+ vis_conf = self.vis_conf_head(tokens)
642
+ flow = torch.cat([flow, vis_conf], dim=-1)
643
+
644
+ return flow
645
+
646
+ def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
647
+ probs = torch.sigmoid(logits)
648
+ ce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
649
+ p_t = probs * targets + (1 - probs) * (1 - targets)
650
+ loss = alpha * (1 - p_t) ** gamma * ce_loss
651
+ return loss.mean()
652
+
653
+ def balanced_binary_cross_entropy(logits, targets, balance_weight=1.0, eps=1e-6, reduction="mean", pos_bias=0.0, mask=None):
654
+ """
655
+ logits: Tensor of arbitrary shape
656
+ targets: same shape as logits
657
+ balance_weight: scaling the loss
658
+ reduction: 'mean', 'sum', or 'none'
659
+ """
660
+ targets = targets.float()
661
+ positive = (targets == 1).float().sum()
662
+ total = targets.numel()
663
+ positive_ratio = positive / (total + eps)
664
+
665
+ pos_weight = (1 - positive_ratio) / (positive_ratio + eps)
666
+ pos_weight = pos_weight.clamp(min=0.1, max=10.0)
667
+ loss = F.binary_cross_entropy_with_logits(
668
+ logits,
669
+ targets,
670
+ pos_weight=pos_weight+pos_bias,
671
+ reduction=reduction
672
+ )
673
+ if mask is not None:
674
+ loss = (loss * mask).sum() / (mask.sum() + eps)
675
+ return balance_weight * loss
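Editorial sketch of the class balancing above: with 10% positives, pos_weight becomes 0.9/0.1 = 9 (clamped to [0.1, 10]), so missed positives are penalized roughly nine times as hard as missed negatives. Toy values assumed.

import torch
import torch.nn.functional as F

logits = torch.zeros(100)
targets = torch.zeros(100)
targets[:10] = 1.0                                   # 10% positives
balanced = balanced_binary_cross_entropy(logits, targets)
plain = F.binary_cross_entropy_with_logits(logits, targets)
print(balanced.item() > plain.item())                # True: positives are up-weighted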
676
+
677
+ def sequence_loss(
+     flow_preds,
+     flow_gt,
+     valids,
+     vis=None,
+     gamma=0.8,
+     add_huber_loss=False,
+     loss_only_for_visible=False,
+     depth_sample=None,
+     z_unc=None,
+     mask_traj_gt=None
+ ):
+     """Loss function defined over sequence of flow predictions"""
+     total_flow_loss = 0.0
+     for j in range(len(flow_gt)):
+         B, S, N, D = flow_gt[j].shape
+         B, S2, N = valids[j].shape
+         assert S == S2
+         n_predictions = len(flow_preds[j])
+         flow_loss = 0.0
+         for i in range(n_predictions):
+             i_weight = gamma ** (n_predictions - i - 1)
+             flow_pred = flow_preds[j][i][:, :, :flow_gt[j].shape[2]]
+             if flow_pred.shape[-1] == 3:
+                 flow_pred[..., 2] = flow_pred[..., 2]
+             track_z_loss = 0.0  # default, so the huber branch below leaves it defined
+             if add_huber_loss:
+                 i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)
+             else:
+                 if flow_gt[j][..., 2].abs().max() != 0:
+                     track_z_loss = (flow_pred - flow_gt[j])[..., 2].abs().mean()
+                     if mask_traj_gt is not None:
+                         track_z_loss = ((flow_pred - flow_gt[j])[..., 2].abs() * mask_traj_gt.permute(0, 2, 1)).sum() / (mask_traj_gt.sum(dim=1) + 1e-6)
+                 else:
+                     track_z_loss = 0
+                 i_loss = (flow_pred[..., :2] - flow_gt[j][..., :2]).abs()  # B, S, N, 2
+             # print((flow_pred - flow_gt[j])[...,2].abs()[vis[j].bool()].mean())
+             i_loss = torch.mean(i_loss, dim=3)  # B, S, N
+             valid_ = valids[j].clone()[:, :, :flow_gt[j].shape[2]]  # ensure valid_ has the same shape as i_loss
+             valid_ = valid_ * (flow_gt[j][..., :2].norm(dim=-1) > 0).float()
+             if loss_only_for_visible:
+                 valid_ = valid_ * vis[j]
+             # print(reduce_masked_mean(i_loss, valid_).item(), track_z_loss.item()/16)
+             flow_loss += i_weight * (reduce_masked_mean(i_loss, valid_) + track_z_loss + 10 * reduce_masked_mean(i_loss, valid_ * vis[j]))
+             # if flow_loss > 5e2:
+             #     import pdb; pdb.set_trace()
+         flow_loss = flow_loss / n_predictions
+         total_flow_loss += flow_loss
+     return total_flow_loss / len(flow_gt)
+
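+ # A minimal shape sketch for sequence_loss. The sizes and the helper name are
+ # illustrative assumptions: flow_preds is a list (one entry per supervision group,
+ # as indexed by j above) of lists of iterative predictions, and the call relies on
+ # reduce_masked_mean being available in this module.
+ def _sequence_loss_example():
+     B, S, N = 2, 8, 64
+     flow_gt = [torch.randn(B, S, N, 3)]
+     flow_preds = [[torch.randn(B, S, N, 3) for _ in range(4)]]   # 4 refinement iterations
+     valids = [torch.ones(B, S, N)]
+     vis = [torch.ones(B, S, N)]
+     return sequence_loss(flow_preds, flow_gt, valids, vis=vis, gamma=0.8)
+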
+ def sequence_loss_xyz(
+     flow_preds,
+     flow_gt,
+     valids,
+     intrs,
+     vis=None,
+     gamma=0.8,
+     add_huber_loss=False,
+     loss_only_for_visible=False,
+     mask_traj_gt=None
+ ):
+     """Loss over a sequence of flow predictions, with the targets unprojected to camera coordinates using the intrinsics"""
+     total_flow_loss = 0.0
+     for j in range(len(flow_gt)):
+         B, S, N, D = flow_gt[j].shape
+         B, S2, N = valids[j].shape
+         assert S == S2
+         n_predictions = len(flow_preds[j])
+         flow_loss = 0.0
+         for i in range(n_predictions):
+             i_weight = gamma ** (n_predictions - i - 1)
+             flow_pred = flow_preds[j][i][:, :, :flow_gt[j].shape[2]]
+             flow_gt_ = flow_gt[j]
+             flow_gt_one = torch.cat([flow_gt_[..., :2], torch.ones_like(flow_gt_[:, :, :, :1])], dim=-1)
+             flow_gt_cam = torch.einsum('btsc,btnc->btns', torch.inverse(intrs), flow_gt_one)
+             flow_gt_cam *= flow_gt_[..., 2:3].abs()
+             flow_gt_cam[..., 2] *= torch.sign(flow_gt_cam[..., 2])
+
+             if add_huber_loss:
+                 i_loss = huber_loss(flow_pred, flow_gt_cam, delta=6.0)
+             else:
+                 i_loss = (flow_pred - flow_gt_cam).norm(dim=-1, keepdim=True)  # B, S, N, 1
+
+             # print((flow_pred - flow_gt[j])[...,2].abs()[vis[j].bool()].mean())
+             i_loss = torch.mean(i_loss, dim=3)  # B, S, N
+             valid_ = valids[j].clone()[:, :, :flow_gt[j].shape[2]]  # ensure valid_ has the same shape as i_loss
+             if loss_only_for_visible:
+                 valid_ = valid_ * vis[j]
+             # print(reduce_masked_mean(i_loss, valid_).item(), track_z_loss.item()/16)
+             flow_loss += i_weight * reduce_masked_mean(i_loss, valid_) * 1000
+             # if flow_loss > 5e2:
+             #     import pdb; pdb.set_trace()
+         flow_loss = flow_loss / n_predictions
+         total_flow_loss += flow_loss
+     return total_flow_loss / len(flow_gt)
+
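+ # A minimal shape sketch for sequence_loss_xyz. The sizes and the helper name are
+ # illustrative assumptions; the einsum above implies per-frame intrinsics of shape
+ # (B, S, 3, 3), the ground truth stores (u, v, depth) per point, and the call relies
+ # on reduce_masked_mean being available in this module.
+ def _sequence_loss_xyz_example():
+     B, S, N = 2, 8, 64
+     flow_gt = [torch.rand(B, S, N, 3) * torch.tensor([512.0, 512.0, 5.0])]
+     flow_preds = [[torch.randn(B, S, N, 3) for _ in range(4)]]
+     valids = [torch.ones(B, S, N)]
+     intrs = torch.tensor([[500.0, 0.0, 256.0],
+                           [0.0, 500.0, 256.0],
+                           [0.0, 0.0, 1.0]]).repeat(B, S, 1, 1)
+     return sequence_loss_xyz(flow_preds, flow_gt, valids, intrs, gamma=0.8)
+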
+ def huber_loss(x, y, delta=1.0):
+     """Calculate element-wise Huber loss between x and y"""
+     diff = x - y
+     abs_diff = diff.abs()
+     flag = (abs_diff <= delta).float()
+     return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)
+
+
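+ # For reference, this is the standard element-wise Huber loss; on recent PyTorch
+ # versions the same values should be obtainable with
+ # F.huber_loss(x, y, reduction="none", delta=delta).
+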
+ def sequence_BCE_loss(vis_preds, vis_gts, mask=None):
+     total_bce_loss = 0.0
+     for j in range(len(vis_preds)):
+         n_predictions = len(vis_preds[j])
+         bce_loss = 0.0
+         for i in range(n_predictions):
+             N_gt = vis_gts[j].shape[-1]
+             if mask is not None:
+                 vis_loss = balanced_binary_cross_entropy(vis_preds[j][i][..., :N_gt], vis_gts[j], mask=mask[j], reduction="none")
+             else:
+                 vis_loss = balanced_binary_cross_entropy(vis_preds[j][i][..., :N_gt], vis_gts[j]) + focal_loss(vis_preds[j][i][..., :N_gt], vis_gts[j])
+             # print(vis_loss, ((torch.sigmoid(vis_preds[j][i][...,:N_gt])>0.5).float() - vis_gts[j]).abs().sum())
+             bce_loss += vis_loss
+         bce_loss = bce_loss / n_predictions
+         total_bce_loss += bce_loss
+     return total_bce_loss / len(vis_preds)
+
+
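+ # A minimal shape sketch for sequence_BCE_loss. The helper name and sizes are
+ # illustrative assumptions: vis_preds mirrors the nested list layout used above,
+ # and its logits may carry extra (padded) tracks that get cropped to N_gt.
+ def _sequence_bce_example():
+     B, S, N = 2, 8, 64
+     vis_gts = [(torch.rand(B, S, N) > 0.5).float()]
+     vis_preds = [[torch.randn(B, S, N + 16) for _ in range(4)]]  # 16 padded tracks are cropped
+     return sequence_BCE_loss(vis_preds, vis_gts)
+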
+ def sequence_prob_loss(
+     tracks: torch.Tensor,
+     confidence: torch.Tensor,
+     target_points: torch.Tensor,
+     visibility: torch.Tensor,
+     expected_dist_thresh: float = 12.0,
+ ):
+     """Loss for classifying if a point is within pixel threshold of its target."""
+     # Points with an error larger than 12 pixels are likely to be useless; marking
+     # them as occluded will actually improve Jaccard metrics and give
+     # qualitatively better results.
+     total_logprob_loss = 0.0
+     for j in range(len(tracks)):
+         n_predictions = len(tracks[j])
+         logprob_loss = 0.0
+         for i in range(n_predictions):
+             N_gt = target_points[j].shape[2]
+             err = torch.sum((tracks[j][i].detach()[:, :, :N_gt, :2] - target_points[j][..., :2]) ** 2, dim=-1)
+             valid = (err <= expected_dist_thresh**2).float()
+             logprob = balanced_binary_cross_entropy(confidence[j][i][..., :N_gt], valid, reduction="none")
+             logprob *= visibility[j]
+             logprob = torch.mean(logprob, dim=[1, 2])
+             logprob_loss += logprob
+         logprob_loss = logprob_loss / n_predictions
+         total_logprob_loss += logprob_loss
+     return total_logprob_loss / len(tracks)
+
+
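+ # A minimal shape sketch for sequence_prob_loss. Names and sizes are illustrative
+ # assumptions: the confidence logits are supervised to predict whether the
+ # (detached) track lands within expected_dist_thresh pixels of its target.
+ def _sequence_prob_loss_example():
+     B, S, N = 2, 8, 64
+     tracks = [[torch.randn(B, S, N, 2) * 100 for _ in range(4)]]
+     confidence = [[torch.randn(B, S, N) for _ in range(4)]]
+     target_points = [torch.randn(B, S, N, 2) * 100]
+     visibility = [torch.ones(B, S, N)]
+     return sequence_prob_loss(tracks, confidence, target_points, visibility)
+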
+ def sequence_dyn_prob_loss(
+     tracks: torch.Tensor,
+     confidence: torch.Tensor,
+     target_points: torch.Tensor,
+     visibility: torch.Tensor,
+     expected_dist_thresh: float = 6.0,
+ ):
+     """Loss for classifying if a point is within pixel threshold of its target."""
+     # Points with an error larger than the threshold are likely to be useless; marking
+     # them as occluded will actually improve Jaccard metrics and give
+     # qualitatively better results.
+     total_logprob_loss = 0.0
+     for j in range(len(tracks)):
+         n_predictions = len(tracks[j])
+         logprob_loss = 0.0
+         for i in range(n_predictions):
+             err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=-1)
+             valid = (err <= expected_dist_thresh**2).float()
+             valid = (valid.sum(dim=1) > 0).float()
+             logprob = balanced_binary_cross_entropy(confidence[j][i].mean(dim=1), valid, reduction="none")
+             # logprob *= visibility[j]
+             logprob = torch.mean(logprob, dim=[0, 1])
+             logprob_loss += logprob
+         logprob_loss = logprob_loss / n_predictions
+         total_logprob_loss += logprob_loss
+     return total_logprob_loss / len(tracks)
+
+
+ def masked_mean(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
+     if mask is None:
+         return data.mean(dim=dim, keepdim=True)
+     mask = mask.float()
+     mask_sum = torch.sum(mask, dim=dim, keepdim=True)
+     mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
+         mask_sum, min=1.0
+     )
+     return mask_mean
+
+
+ def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
+     if mask is None:
+         return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
+     mask = mask.float()
+     mask_sum = torch.sum(mask, dim=dim, keepdim=True)
+     mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
+         mask_sum, min=1.0
+     )
+     mask_var = torch.sum(
+         mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
+     ) / torch.clamp(mask_sum, min=1.0)
+     return mask_mean.squeeze(dim), mask_var.squeeze(dim)
+
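+ # A minimal usage sketch for masked_mean / masked_mean_var. The helper name and
+ # sizes are illustrative assumptions; masked_mean keeps the reduced dims, while
+ # masked_mean_var squeezes them away (squeeze with a list of dims needs PyTorch >= 2.0).
+ def _masked_stats_example():
+     data = torch.randn(2, 8, 64)
+     mask = (torch.rand(2, 8, 64) > 0.5).float()
+     mean_keepdim = masked_mean(data, mask, dim=[1, 2])   # (2, 1, 1)
+     mean, var = masked_mean_var(data, mask, dim=[1, 2])  # (2,), (2,)
+     return mean_keepdim, mean, var
+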
+ class NeighborTransformer(nn.Module):
+     def __init__(self, dim: int, num_heads: int, head_dim: int, mlp_ratio: float):
+         super().__init__()
+         self.dim = dim
+         self.output_token_1 = nn.Parameter(torch.randn(1, dim))
+         self.output_token_2 = nn.Parameter(torch.randn(1, dim))
+         self.xblock1_2 = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
+         self.xblock2_1 = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
+         self.aggr1 = Attention(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim)
+         self.aggr2 = Attention(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim)
+
+     def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+         from einops import rearrange, repeat
+         import torch.utils.checkpoint as checkpoint
+
+         assert len(x.shape) == 3, "x should be of shape (B, N, D)"
+         assert len(y.shape) == 3, "y should be of shape (B, N, D)"
+
+         # note: this block does not work so well in practice ...
+
+         def forward_chunk(x, y):
+             new_x = self.xblock1_2(x, y)
+             new_y = self.xblock2_1(y, x)
+             out1 = self.aggr1(repeat(self.output_token_1, 'n d -> b n d', b=x.shape[0]), context=new_x)
+             out2 = self.aggr2(repeat(self.output_token_2, 'n d -> b n d', b=x.shape[0]), context=new_y)
+             return out1 + out2
+
+         return checkpoint.checkpoint(forward_chunk, x, y)
+
+
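+ # A minimal usage sketch for NeighborTransformer. The sizes are illustrative
+ # assumptions, and it relies on the CrossAttnBlock / Attention modules defined
+ # elsewhere in this repo; each input is a set of (B, N, D) tokens and the output
+ # is a single aggregated (B, 1, D) token.
+ def _neighbor_transformer_example():
+     model = NeighborTransformer(dim=128, num_heads=4, head_dim=32, mlp_ratio=4.0)
+     x = torch.randn(2, 16, 128)
+     y = torch.randn(2, 16, 128)
+     return model(x, y)  # (2, 1, 128)
+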
+ class CorrPointformer(nn.Module):
+     def __init__(self, dim: int, num_heads: int, head_dim: int, mlp_ratio: float):
+         super().__init__()
+         self.dim = dim
+         self.xblock1_2 = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
+         # self.xblock2_1 = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
+         self.aggr = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
+         self.out_proj = nn.Linear(dim, 2 * dim)
+
+     def forward(self, query: torch.Tensor, target: torch.Tensor, target_rel_pos: torch.Tensor) -> torch.Tensor:
+         from einops import rearrange, repeat
+         import torch.utils.checkpoint as checkpoint
+
+         def forward_chunk(query, target, target_rel_pos):
+             new_query = self.xblock1_2(query, target).mean(dim=1, keepdim=True)
+             # new_target = self.xblock2_1(target, query).mean(dim=1, keepdim=True)
+             # new_aggr = new_query + new_target
+             out = self.aggr(new_query, target + target_rel_pos)  # (potential delta xyz) (target - center)
+             out = self.out_proj(out)
+             return out
+
+         return checkpoint.checkpoint(forward_chunk, query, target, target_rel_pos)
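+
+ # A minimal usage sketch for CorrPointformer. Sizes are illustrative assumptions:
+ # the query tokens attend into the target (correlation) tokens plus their relative
+ # positional encoding, and the pooled result is projected to 2 * dim channels.
+ def _corr_pointformer_example():
+     model = CorrPointformer(dim=128, num_heads=4, head_dim=32, mlp_ratio=4.0)
+     query = torch.randn(2, 9, 128)            # e.g. a 3x3 neighborhood of query features
+     target = torch.randn(2, 49, 128)          # e.g. a 7x7 correlation neighborhood
+     target_rel_pos = torch.randn(2, 49, 128)  # positional encoding of target offsets
+     return model(query, target, target_rel_pos)  # (2, 1, 256)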