Spaces:
Running
on
Zero
Running
on
Zero
xiaoyuxi
commited on
Commit
·
c8d9d42
0
Parent(s):
Cleaned history, reset to current state
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +36 -0
- .gitignore +69 -0
- README.md +14 -0
- _viz/viz_template.html +1778 -0
- app.py +1118 -0
- app_3rd/README.md +12 -0
- app_3rd/sam_utils/hf_sam_predictor.py +129 -0
- app_3rd/sam_utils/inference.py +123 -0
- app_3rd/spatrack_utils/infer_track.py +194 -0
- config/__init__.py +0 -0
- config/magic_infer_moge.yaml +48 -0
- examples/backpack.mp4 +3 -0
- examples/ball.mp4 +3 -0
- examples/basketball.mp4 +3 -0
- examples/biker.mp4 +3 -0
- examples/cinema_0.mp4 +3 -0
- examples/cinema_1.mp4 +3 -0
- examples/drifting.mp4 +3 -0
- examples/ego_kc1.mp4 +3 -0
- examples/ego_teaser.mp4 +3 -0
- examples/handwave.mp4 +3 -0
- examples/hockey.mp4 +3 -0
- examples/ken_block_0.mp4 +3 -0
- examples/kiss.mp4 +3 -0
- examples/kitchen.mp4 +3 -0
- examples/kitchen_egocentric.mp4 +3 -0
- examples/pillow.mp4 +3 -0
- examples/protein.mp4 +3 -0
- examples/pusht.mp4 +3 -0
- examples/robot1.mp4 +3 -0
- examples/robot2.mp4 +3 -0
- examples/robot_3.mp4 +3 -0
- examples/robot_unitree.mp4 +3 -0
- examples/running.mp4 +3 -0
- examples/teleop2.mp4 +3 -0
- examples/vertical_place.mp4 +3 -0
- models/SpaTrackV2/models/SpaTrack.py +759 -0
- models/SpaTrackV2/models/__init__.py +0 -0
- models/SpaTrackV2/models/blocks.py +519 -0
- models/SpaTrackV2/models/camera_transform.py +248 -0
- models/SpaTrackV2/models/depth_refiner/backbone.py +472 -0
- models/SpaTrackV2/models/depth_refiner/decode_head.py +619 -0
- models/SpaTrackV2/models/depth_refiner/depth_refiner.py +115 -0
- models/SpaTrackV2/models/depth_refiner/network.py +429 -0
- models/SpaTrackV2/models/depth_refiner/stablilization_attention.py +1187 -0
- models/SpaTrackV2/models/depth_refiner/stablizer.py +342 -0
- models/SpaTrackV2/models/predictor.py +153 -0
- models/SpaTrackV2/models/tracker3D/TrackRefiner.py +1478 -0
- models/SpaTrackV2/models/tracker3D/co_tracker/cotracker_base.py +418 -0
- models/SpaTrackV2/models/tracker3D/co_tracker/utils.py +929 -0
.gitattributes
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ignore the multi media
|
2 |
+
checkpoints
|
3 |
+
**/checkpoints/
|
4 |
+
**/temp/
|
5 |
+
temp
|
6 |
+
assets_dev
|
7 |
+
assets/example0/results
|
8 |
+
assets/example0/snowboard.npz
|
9 |
+
assets/example1/results
|
10 |
+
assets/davis_eval
|
11 |
+
assets/*/results
|
12 |
+
*gradio*
|
13 |
+
#
|
14 |
+
models/monoD/zoeDepth/ckpts/*
|
15 |
+
models/monoD/depth_anything/ckpts/*
|
16 |
+
vis_results
|
17 |
+
dist_encrypted
|
18 |
+
# remove the dependencies
|
19 |
+
deps
|
20 |
+
|
21 |
+
# filter the __pycache__ files
|
22 |
+
__pycache__/
|
23 |
+
/**/**/__pycache__
|
24 |
+
/**/__pycache__
|
25 |
+
|
26 |
+
outputs
|
27 |
+
scripts/lauch_exp/config
|
28 |
+
scripts/lauch_exp/submit_job.log
|
29 |
+
scripts/lauch_exp/hydra_output
|
30 |
+
scripts/lauch_wulan
|
31 |
+
scripts/custom_video
|
32 |
+
# ignore the visualizer
|
33 |
+
viser
|
34 |
+
viser_result
|
35 |
+
benchmark/results
|
36 |
+
benchmark
|
37 |
+
|
38 |
+
ossutil_output
|
39 |
+
|
40 |
+
prev_version
|
41 |
+
spat_ceres
|
42 |
+
wandb
|
43 |
+
*.log
|
44 |
+
seg_target.py
|
45 |
+
|
46 |
+
eval_davis.py
|
47 |
+
eval_multiple_gpu.py
|
48 |
+
eval_pose_scan.py
|
49 |
+
eval_single_gpu.py
|
50 |
+
|
51 |
+
infer_cam.py
|
52 |
+
infer_stream.py
|
53 |
+
|
54 |
+
*.egg-info/
|
55 |
+
**/*.egg-info
|
56 |
+
|
57 |
+
eval_kinectics.py
|
58 |
+
models/SpaTrackV2/datasets
|
59 |
+
|
60 |
+
scripts
|
61 |
+
config/fix_2d.yaml
|
62 |
+
|
63 |
+
models/SpaTrackV2/datasets
|
64 |
+
scripts/
|
65 |
+
|
66 |
+
models/**/build
|
67 |
+
models/**/dist
|
68 |
+
|
69 |
+
temp_local
|
README.md
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: SpatialTrackerV2
|
3 |
+
emoji: ⚡️
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: red
|
6 |
+
sdk: gradio
|
7 |
+
sdk_version: 5.31.0
|
8 |
+
app_file: app.py
|
9 |
+
pinned: false
|
10 |
+
license: mit
|
11 |
+
short_description: Official Space for SpatialTrackerV2
|
12 |
+
---
|
13 |
+
|
14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
_viz/viz_template.html
ADDED
@@ -0,0 +1,1778 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
6 |
+
<title>3D Point Cloud Visualizer</title>
|
7 |
+
<style>
|
8 |
+
:root {
|
9 |
+
--primary: #9b59b6; /* Brighter purple for dark mode */
|
10 |
+
--primary-light: #3a2e4a;
|
11 |
+
--secondary: #a86add;
|
12 |
+
--accent: #ff6e6e;
|
13 |
+
--bg: #1a1a1a;
|
14 |
+
--surface: #2c2c2c;
|
15 |
+
--text: #e0e0e0;
|
16 |
+
--text-secondary: #a0a0a0;
|
17 |
+
--border: #444444;
|
18 |
+
--shadow: rgba(0, 0, 0, 0.2);
|
19 |
+
--shadow-hover: rgba(0, 0, 0, 0.3);
|
20 |
+
|
21 |
+
--space-sm: 16px;
|
22 |
+
--space-md: 24px;
|
23 |
+
--space-lg: 32px;
|
24 |
+
}
|
25 |
+
|
26 |
+
body {
|
27 |
+
margin: 0;
|
28 |
+
overflow: hidden;
|
29 |
+
background: var(--bg);
|
30 |
+
color: var(--text);
|
31 |
+
font-family: 'Inter', sans-serif;
|
32 |
+
-webkit-font-smoothing: antialiased;
|
33 |
+
}
|
34 |
+
|
35 |
+
#canvas-container {
|
36 |
+
position: absolute;
|
37 |
+
width: 100%;
|
38 |
+
height: 100%;
|
39 |
+
}
|
40 |
+
|
41 |
+
#ui-container {
|
42 |
+
position: absolute;
|
43 |
+
top: 0;
|
44 |
+
left: 0;
|
45 |
+
width: 100%;
|
46 |
+
height: 100%;
|
47 |
+
pointer-events: none;
|
48 |
+
z-index: 10;
|
49 |
+
}
|
50 |
+
|
51 |
+
#status-bar {
|
52 |
+
position: absolute;
|
53 |
+
top: 16px;
|
54 |
+
left: 16px;
|
55 |
+
background: rgba(30, 30, 30, 0.9);
|
56 |
+
padding: 8px 16px;
|
57 |
+
border-radius: 8px;
|
58 |
+
pointer-events: auto;
|
59 |
+
box-shadow: 0 4px 6px var(--shadow);
|
60 |
+
backdrop-filter: blur(4px);
|
61 |
+
border: 1px solid var(--border);
|
62 |
+
color: var(--text);
|
63 |
+
transition: opacity 0.5s ease, transform 0.5s ease;
|
64 |
+
font-weight: 500;
|
65 |
+
}
|
66 |
+
|
67 |
+
#status-bar.hidden {
|
68 |
+
opacity: 0;
|
69 |
+
transform: translateY(-20px);
|
70 |
+
pointer-events: none;
|
71 |
+
}
|
72 |
+
|
73 |
+
#control-panel {
|
74 |
+
position: absolute;
|
75 |
+
bottom: 16px;
|
76 |
+
left: 50%;
|
77 |
+
transform: translateX(-50%);
|
78 |
+
background: rgba(44, 44, 44, 0.95);
|
79 |
+
padding: 6px 8px;
|
80 |
+
border-radius: 6px;
|
81 |
+
display: flex;
|
82 |
+
gap: 8px;
|
83 |
+
align-items: center;
|
84 |
+
justify-content: space-between;
|
85 |
+
pointer-events: auto;
|
86 |
+
box-shadow: 0 4px 10px var(--shadow);
|
87 |
+
backdrop-filter: blur(4px);
|
88 |
+
border: 1px solid var(--border);
|
89 |
+
}
|
90 |
+
|
91 |
+
#timeline {
|
92 |
+
width: 150px;
|
93 |
+
height: 4px;
|
94 |
+
background: rgba(255, 255, 255, 0.1);
|
95 |
+
border-radius: 2px;
|
96 |
+
position: relative;
|
97 |
+
cursor: pointer;
|
98 |
+
}
|
99 |
+
|
100 |
+
#progress {
|
101 |
+
position: absolute;
|
102 |
+
height: 100%;
|
103 |
+
background: var(--primary);
|
104 |
+
border-radius: 2px;
|
105 |
+
width: 0%;
|
106 |
+
}
|
107 |
+
|
108 |
+
#playback-controls {
|
109 |
+
display: flex;
|
110 |
+
gap: 4px;
|
111 |
+
align-items: center;
|
112 |
+
}
|
113 |
+
|
114 |
+
button {
|
115 |
+
background: rgba(255, 255, 255, 0.08);
|
116 |
+
border: 1px solid var(--border);
|
117 |
+
color: var(--text);
|
118 |
+
padding: 4px 6px;
|
119 |
+
border-radius: 3px;
|
120 |
+
cursor: pointer;
|
121 |
+
display: flex;
|
122 |
+
align-items: center;
|
123 |
+
justify-content: center;
|
124 |
+
transition: background 0.2s, transform 0.2s;
|
125 |
+
font-family: 'Inter', sans-serif;
|
126 |
+
font-weight: 500;
|
127 |
+
font-size: 6px;
|
128 |
+
}
|
129 |
+
|
130 |
+
button:hover {
|
131 |
+
background: rgba(255, 255, 255, 0.15);
|
132 |
+
transform: translateY(-1px);
|
133 |
+
}
|
134 |
+
|
135 |
+
button.active {
|
136 |
+
background: var(--primary);
|
137 |
+
color: white;
|
138 |
+
box-shadow: 0 2px 8px rgba(155, 89, 182, 0.4);
|
139 |
+
}
|
140 |
+
|
141 |
+
select, input {
|
142 |
+
background: rgba(255, 255, 255, 0.08);
|
143 |
+
border: 1px solid var(--border);
|
144 |
+
color: var(--text);
|
145 |
+
padding: 4px 6px;
|
146 |
+
border-radius: 3px;
|
147 |
+
cursor: pointer;
|
148 |
+
font-family: 'Inter', sans-serif;
|
149 |
+
font-size: 6px;
|
150 |
+
}
|
151 |
+
|
152 |
+
.icon {
|
153 |
+
width: 10px;
|
154 |
+
height: 10px;
|
155 |
+
fill: currentColor;
|
156 |
+
}
|
157 |
+
|
158 |
+
.tooltip {
|
159 |
+
position: absolute;
|
160 |
+
bottom: 100%;
|
161 |
+
left: 50%;
|
162 |
+
transform: translateX(-50%);
|
163 |
+
background: var(--surface);
|
164 |
+
color: var(--text);
|
165 |
+
padding: 3px 6px;
|
166 |
+
border-radius: 3px;
|
167 |
+
font-size: 7px;
|
168 |
+
white-space: nowrap;
|
169 |
+
margin-bottom: 4px;
|
170 |
+
opacity: 0;
|
171 |
+
transition: opacity 0.2s;
|
172 |
+
pointer-events: none;
|
173 |
+
box-shadow: 0 2px 4px var(--shadow);
|
174 |
+
border: 1px solid var(--border);
|
175 |
+
}
|
176 |
+
|
177 |
+
button:hover .tooltip {
|
178 |
+
opacity: 1;
|
179 |
+
}
|
180 |
+
|
181 |
+
#settings-panel {
|
182 |
+
position: absolute;
|
183 |
+
top: 16px;
|
184 |
+
right: 16px;
|
185 |
+
background: rgba(44, 44, 44, 0.98);
|
186 |
+
padding: 10px;
|
187 |
+
border-radius: 6px;
|
188 |
+
width: 195px;
|
189 |
+
max-height: calc(100vh - 40px);
|
190 |
+
overflow-y: auto;
|
191 |
+
pointer-events: auto;
|
192 |
+
box-shadow: 0 4px 15px var(--shadow);
|
193 |
+
backdrop-filter: blur(4px);
|
194 |
+
border: 1px solid var(--border);
|
195 |
+
display: block;
|
196 |
+
opacity: 1;
|
197 |
+
scrollbar-width: thin;
|
198 |
+
scrollbar-color: var(--primary-light) transparent;
|
199 |
+
transition: transform 0.35s ease-in-out, opacity 0.3s ease-in-out;
|
200 |
+
}
|
201 |
+
|
202 |
+
#settings-panel.is-hidden {
|
203 |
+
transform: translateX(calc(100% + 20px));
|
204 |
+
opacity: 0;
|
205 |
+
pointer-events: none;
|
206 |
+
}
|
207 |
+
|
208 |
+
#settings-panel::-webkit-scrollbar {
|
209 |
+
width: 3px;
|
210 |
+
}
|
211 |
+
|
212 |
+
#settings-panel::-webkit-scrollbar-track {
|
213 |
+
background: transparent;
|
214 |
+
}
|
215 |
+
|
216 |
+
#settings-panel::-webkit-scrollbar-thumb {
|
217 |
+
background-color: var(--primary-light);
|
218 |
+
border-radius: 3px;
|
219 |
+
}
|
220 |
+
|
221 |
+
@media (max-height: 700px) {
|
222 |
+
#settings-panel {
|
223 |
+
max-height: calc(100vh - 40px);
|
224 |
+
}
|
225 |
+
}
|
226 |
+
|
227 |
+
@media (max-width: 768px) {
|
228 |
+
#control-panel {
|
229 |
+
width: 90%;
|
230 |
+
flex-wrap: wrap;
|
231 |
+
justify-content: center;
|
232 |
+
}
|
233 |
+
|
234 |
+
#timeline {
|
235 |
+
width: 100%;
|
236 |
+
order: 3;
|
237 |
+
margin-top: 10px;
|
238 |
+
}
|
239 |
+
|
240 |
+
#settings-panel {
|
241 |
+
width: 140px;
|
242 |
+
right: 10px;
|
243 |
+
top: 10px;
|
244 |
+
max-height: calc(100vh - 20px);
|
245 |
+
}
|
246 |
+
}
|
247 |
+
|
248 |
+
.settings-group {
|
249 |
+
margin-bottom: 8px;
|
250 |
+
}
|
251 |
+
|
252 |
+
.settings-group h3 {
|
253 |
+
margin: 0 0 6px 0;
|
254 |
+
font-size: 10px;
|
255 |
+
font-weight: 500;
|
256 |
+
color: var(--text-secondary);
|
257 |
+
}
|
258 |
+
|
259 |
+
.slider-container {
|
260 |
+
display: flex;
|
261 |
+
align-items: center;
|
262 |
+
gap: 6px;
|
263 |
+
width: 100%;
|
264 |
+
}
|
265 |
+
|
266 |
+
.slider-container label {
|
267 |
+
min-width: 60px;
|
268 |
+
font-size: 10px;
|
269 |
+
flex-shrink: 0;
|
270 |
+
}
|
271 |
+
|
272 |
+
input[type="range"] {
|
273 |
+
flex: 1;
|
274 |
+
height: 2px;
|
275 |
+
-webkit-appearance: none;
|
276 |
+
background: rgba(255, 255, 255, 0.1);
|
277 |
+
border-radius: 1px;
|
278 |
+
min-width: 0;
|
279 |
+
}
|
280 |
+
|
281 |
+
input[type="range"]::-webkit-slider-thumb {
|
282 |
+
-webkit-appearance: none;
|
283 |
+
width: 8px;
|
284 |
+
height: 8px;
|
285 |
+
border-radius: 50%;
|
286 |
+
background: var(--primary);
|
287 |
+
cursor: pointer;
|
288 |
+
}
|
289 |
+
|
290 |
+
.toggle-switch {
|
291 |
+
position: relative;
|
292 |
+
display: inline-block;
|
293 |
+
width: 20px;
|
294 |
+
height: 10px;
|
295 |
+
}
|
296 |
+
|
297 |
+
.toggle-switch input {
|
298 |
+
opacity: 0;
|
299 |
+
width: 0;
|
300 |
+
height: 0;
|
301 |
+
}
|
302 |
+
|
303 |
+
.toggle-slider {
|
304 |
+
position: absolute;
|
305 |
+
cursor: pointer;
|
306 |
+
top: 0;
|
307 |
+
left: 0;
|
308 |
+
right: 0;
|
309 |
+
bottom: 0;
|
310 |
+
background: rgba(255, 255, 255, 0.1);
|
311 |
+
transition: .4s;
|
312 |
+
border-radius: 10px;
|
313 |
+
}
|
314 |
+
|
315 |
+
.toggle-slider:before {
|
316 |
+
position: absolute;
|
317 |
+
content: "";
|
318 |
+
height: 8px;
|
319 |
+
width: 8px;
|
320 |
+
left: 1px;
|
321 |
+
bottom: 1px;
|
322 |
+
background: var(--surface);
|
323 |
+
border: 1px solid var(--border);
|
324 |
+
transition: .4s;
|
325 |
+
border-radius: 50%;
|
326 |
+
}
|
327 |
+
|
328 |
+
input:checked + .toggle-slider {
|
329 |
+
background: var(--primary);
|
330 |
+
}
|
331 |
+
|
332 |
+
input:checked + .toggle-slider:before {
|
333 |
+
transform: translateX(10px);
|
334 |
+
}
|
335 |
+
|
336 |
+
.checkbox-container {
|
337 |
+
display: flex;
|
338 |
+
align-items: center;
|
339 |
+
gap: 4px;
|
340 |
+
margin-bottom: 4px;
|
341 |
+
}
|
342 |
+
|
343 |
+
.checkbox-container label {
|
344 |
+
font-size: 10px;
|
345 |
+
cursor: pointer;
|
346 |
+
}
|
347 |
+
|
348 |
+
#loading-overlay {
|
349 |
+
position: absolute;
|
350 |
+
top: 0;
|
351 |
+
left: 0;
|
352 |
+
width: 100%;
|
353 |
+
height: 100%;
|
354 |
+
background: var(--bg);
|
355 |
+
display: flex;
|
356 |
+
flex-direction: column;
|
357 |
+
align-items: center;
|
358 |
+
justify-content: center;
|
359 |
+
z-index: 100;
|
360 |
+
transition: opacity 0.5s;
|
361 |
+
}
|
362 |
+
|
363 |
+
#loading-overlay.fade-out {
|
364 |
+
opacity: 0;
|
365 |
+
pointer-events: none;
|
366 |
+
}
|
367 |
+
|
368 |
+
.spinner {
|
369 |
+
width: 50px;
|
370 |
+
height: 50px;
|
371 |
+
border: 5px solid rgba(155, 89, 182, 0.2);
|
372 |
+
border-radius: 50%;
|
373 |
+
border-top-color: var(--primary);
|
374 |
+
animation: spin 1s ease-in-out infinite;
|
375 |
+
margin-bottom: 16px;
|
376 |
+
}
|
377 |
+
|
378 |
+
@keyframes spin {
|
379 |
+
to { transform: rotate(360deg); }
|
380 |
+
}
|
381 |
+
|
382 |
+
#loading-text {
|
383 |
+
margin-top: 16px;
|
384 |
+
font-size: 18px;
|
385 |
+
color: var(--text);
|
386 |
+
font-weight: 500;
|
387 |
+
}
|
388 |
+
|
389 |
+
#frame-counter {
|
390 |
+
color: var(--text-secondary);
|
391 |
+
font-size: 7px;
|
392 |
+
font-weight: 500;
|
393 |
+
min-width: 60px;
|
394 |
+
text-align: center;
|
395 |
+
padding: 0 4px;
|
396 |
+
}
|
397 |
+
|
398 |
+
.control-btn {
|
399 |
+
background: rgba(255, 255, 255, 0.08);
|
400 |
+
border: 1px solid var(--border);
|
401 |
+
padding: 4px 6px;
|
402 |
+
border-radius: 3px;
|
403 |
+
cursor: pointer;
|
404 |
+
display: flex;
|
405 |
+
align-items: center;
|
406 |
+
justify-content: center;
|
407 |
+
transition: all 0.2s ease;
|
408 |
+
font-size: 6px;
|
409 |
+
}
|
410 |
+
|
411 |
+
.control-btn:hover {
|
412 |
+
background: rgba(255, 255, 255, 0.15);
|
413 |
+
transform: translateY(-1px);
|
414 |
+
}
|
415 |
+
|
416 |
+
.control-btn.active {
|
417 |
+
background: var(--primary);
|
418 |
+
color: white;
|
419 |
+
}
|
420 |
+
|
421 |
+
.control-btn.active:hover {
|
422 |
+
background: var(--primary);
|
423 |
+
box-shadow: 0 2px 8px rgba(155, 89, 182, 0.4);
|
424 |
+
}
|
425 |
+
|
426 |
+
#settings-toggle-btn {
|
427 |
+
position: relative;
|
428 |
+
border-radius: 6px;
|
429 |
+
z-index: 20;
|
430 |
+
}
|
431 |
+
|
432 |
+
#settings-toggle-btn.active {
|
433 |
+
background: var(--primary);
|
434 |
+
color: white;
|
435 |
+
}
|
436 |
+
|
437 |
+
#status-bar,
|
438 |
+
#control-panel,
|
439 |
+
#settings-panel,
|
440 |
+
button,
|
441 |
+
input,
|
442 |
+
select,
|
443 |
+
.toggle-switch {
|
444 |
+
pointer-events: auto;
|
445 |
+
}
|
446 |
+
|
447 |
+
h2 {
|
448 |
+
font-size: 0.9rem;
|
449 |
+
font-weight: 600;
|
450 |
+
margin-top: 0;
|
451 |
+
margin-bottom: 12px;
|
452 |
+
color: var(--primary);
|
453 |
+
cursor: move;
|
454 |
+
user-select: none;
|
455 |
+
display: flex;
|
456 |
+
align-items: center;
|
457 |
+
}
|
458 |
+
|
459 |
+
.drag-handle {
|
460 |
+
font-size: 10px;
|
461 |
+
margin-right: 4px;
|
462 |
+
opacity: 0.6;
|
463 |
+
}
|
464 |
+
|
465 |
+
h2:hover .drag-handle {
|
466 |
+
opacity: 1;
|
467 |
+
}
|
468 |
+
|
469 |
+
.loading-subtitle {
|
470 |
+
font-size: 7px;
|
471 |
+
color: var(--text-secondary);
|
472 |
+
margin-top: 4px;
|
473 |
+
}
|
474 |
+
|
475 |
+
#reset-view-btn {
|
476 |
+
background: var(--primary-light);
|
477 |
+
color: var(--primary);
|
478 |
+
border: 1px solid rgba(155, 89, 182, 0.2);
|
479 |
+
font-weight: 600;
|
480 |
+
font-size: 9px;
|
481 |
+
padding: 4px 6px;
|
482 |
+
transition: all 0.2s;
|
483 |
+
}
|
484 |
+
|
485 |
+
#reset-view-btn:hover {
|
486 |
+
background: var(--primary);
|
487 |
+
color: white;
|
488 |
+
transform: translateY(-2px);
|
489 |
+
box-shadow: 0 4px 8px rgba(155, 89, 182, 0.3);
|
490 |
+
}
|
491 |
+
|
492 |
+
#show-settings-btn {
|
493 |
+
position: absolute;
|
494 |
+
top: 16px;
|
495 |
+
right: 16px;
|
496 |
+
z-index: 15;
|
497 |
+
display: none;
|
498 |
+
}
|
499 |
+
|
500 |
+
#settings-panel.visible {
|
501 |
+
display: block;
|
502 |
+
opacity: 1;
|
503 |
+
animation: slideIn 0.3s ease forwards;
|
504 |
+
}
|
505 |
+
|
506 |
+
@keyframes slideIn {
|
507 |
+
from {
|
508 |
+
transform: translateY(20px);
|
509 |
+
opacity: 0;
|
510 |
+
}
|
511 |
+
to {
|
512 |
+
transform: translateY(0);
|
513 |
+
opacity: 1;
|
514 |
+
}
|
515 |
+
}
|
516 |
+
|
517 |
+
.dragging {
|
518 |
+
opacity: 0.9;
|
519 |
+
box-shadow: 0 8px 20px rgba(0, 0, 0, 0.15) !important;
|
520 |
+
transition: none !important;
|
521 |
+
}
|
522 |
+
|
523 |
+
/* Tooltip for draggable element */
|
524 |
+
.tooltip-drag {
|
525 |
+
position: absolute;
|
526 |
+
left: 50%;
|
527 |
+
transform: translateX(-50%);
|
528 |
+
background: var(--primary);
|
529 |
+
color: white;
|
530 |
+
font-size: 9px;
|
531 |
+
padding: 2px 4px;
|
532 |
+
border-radius: 2px;
|
533 |
+
opacity: 0;
|
534 |
+
pointer-events: none;
|
535 |
+
transition: opacity 0.3s;
|
536 |
+
white-space: nowrap;
|
537 |
+
bottom: 100%;
|
538 |
+
margin-bottom: 4px;
|
539 |
+
}
|
540 |
+
|
541 |
+
h2:hover .tooltip-drag {
|
542 |
+
opacity: 1;
|
543 |
+
}
|
544 |
+
|
545 |
+
.btn-group {
|
546 |
+
display: flex;
|
547 |
+
margin-top: 8px;
|
548 |
+
}
|
549 |
+
|
550 |
+
#reset-settings-btn {
|
551 |
+
background: var(--primary-light);
|
552 |
+
color: var(--primary);
|
553 |
+
border: 1px solid rgba(155, 89, 182, 0.2);
|
554 |
+
font-weight: 600;
|
555 |
+
font-size: 9px;
|
556 |
+
padding: 4px 6px;
|
557 |
+
transition: all 0.2s;
|
558 |
+
}
|
559 |
+
|
560 |
+
#reset-settings-btn:hover {
|
561 |
+
background: var(--primary);
|
562 |
+
color: white;
|
563 |
+
transform: translateY(-2px);
|
564 |
+
box-shadow: 0 4px 8px rgba(155, 89, 182, 0.3);
|
565 |
+
}
|
566 |
+
</style>
|
567 |
+
</head>
|
568 |
+
<body>
|
569 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
570 |
+
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
571 |
+
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">
|
572 |
+
|
573 |
+
<div id="canvas-container"></div>
|
574 |
+
|
575 |
+
<div id="ui-container">
|
576 |
+
<div id="status-bar">Initializing...</div>
|
577 |
+
|
578 |
+
<div id="control-panel">
|
579 |
+
<button id="play-pause-btn" class="control-btn">
|
580 |
+
<svg class="icon" viewBox="0 0 24 24">
|
581 |
+
<path id="play-icon" d="M8 5v14l11-7z"/>
|
582 |
+
<path id="pause-icon" d="M6 19h4V5H6v14zm8-14v14h4V5h-4z" style="display: none;"/>
|
583 |
+
</svg>
|
584 |
+
<span class="tooltip">Play/Pause</span>
|
585 |
+
</button>
|
586 |
+
|
587 |
+
<div id="timeline">
|
588 |
+
<div id="progress"></div>
|
589 |
+
</div>
|
590 |
+
|
591 |
+
<div id="frame-counter">Frame 0 / 0</div>
|
592 |
+
|
593 |
+
<div id="playback-controls">
|
594 |
+
<button id="speed-btn" class="control-btn">1x</button>
|
595 |
+
</div>
|
596 |
+
</div>
|
597 |
+
|
598 |
+
<div id="settings-panel">
|
599 |
+
<h2>
|
600 |
+
<span class="drag-handle">☰</span>
|
601 |
+
Visualization Settings
|
602 |
+
<button id="hide-settings-btn" class="control-btn" style="margin-left: auto; padding: 2px;" title="Hide Panel">
|
603 |
+
<svg class="icon" viewBox="0 0 24 24" style="width: 9px; height: 9px;">
|
604 |
+
<path d="M14.59 7.41L18.17 11H4v2h14.17l-3.58 3.59L16 18l6-6-6-6-1.41 1.41z"/>
|
605 |
+
</svg>
|
606 |
+
</button>
|
607 |
+
</h2>
|
608 |
+
|
609 |
+
<div class="settings-group">
|
610 |
+
<h3>Point Cloud</h3>
|
611 |
+
<div class="slider-container">
|
612 |
+
<label for="point-size">Size</label>
|
613 |
+
<input type="range" id="point-size" min="0.005" max="0.1" step="0.005" value="0.03">
|
614 |
+
</div>
|
615 |
+
<div class="slider-container">
|
616 |
+
<label for="point-opacity">Opacity</label>
|
617 |
+
<input type="range" id="point-opacity" min="0.1" max="1" step="0.05" value="1">
|
618 |
+
</div>
|
619 |
+
<div class="slider-container">
|
620 |
+
<label for="max-depth">Max Depth</label>
|
621 |
+
<input type="range" id="max-depth" min="0.1" max="10" step="0.2" value="100">
|
622 |
+
</div>
|
623 |
+
</div>
|
624 |
+
|
625 |
+
<div class="settings-group">
|
626 |
+
<h3>Trajectory</h3>
|
627 |
+
<div class="checkbox-container">
|
628 |
+
<label class="toggle-switch">
|
629 |
+
<input type="checkbox" id="show-trajectory" checked>
|
630 |
+
<span class="toggle-slider"></span>
|
631 |
+
</label>
|
632 |
+
<label for="show-trajectory">Show Trajectory</label>
|
633 |
+
</div>
|
634 |
+
<div class="checkbox-container">
|
635 |
+
<label class="toggle-switch">
|
636 |
+
<input type="checkbox" id="enable-rich-trail">
|
637 |
+
<span class="toggle-slider"></span>
|
638 |
+
</label>
|
639 |
+
<label for="enable-rich-trail">Visual-Rich Trail</label>
|
640 |
+
</div>
|
641 |
+
<div class="slider-container">
|
642 |
+
<label for="trajectory-line-width">Line Width</label>
|
643 |
+
<input type="range" id="trajectory-line-width" min="0.5" max="5" step="0.5" value="1.5">
|
644 |
+
</div>
|
645 |
+
<div class="slider-container">
|
646 |
+
<label for="trajectory-ball-size">Ball Size</label>
|
647 |
+
<input type="range" id="trajectory-ball-size" min="0.005" max="0.05" step="0.001" value="0.02">
|
648 |
+
</div>
|
649 |
+
<div class="slider-container">
|
650 |
+
<label for="trajectory-history">History Frames</label>
|
651 |
+
<input type="range" id="trajectory-history" min="1" max="500" step="1" value="30">
|
652 |
+
</div>
|
653 |
+
<div class="slider-container" id="tail-opacity-container" style="display: none;">
|
654 |
+
<label for="trajectory-fade">Tail Opacity</label>
|
655 |
+
<input type="range" id="trajectory-fade" min="0" max="1" step="0.05" value="0.0">
|
656 |
+
</div>
|
657 |
+
</div>
|
658 |
+
|
659 |
+
<div class="settings-group">
|
660 |
+
<h3>Camera</h3>
|
661 |
+
<div class="checkbox-container">
|
662 |
+
<label class="toggle-switch">
|
663 |
+
<input type="checkbox" id="show-camera-frustum" checked>
|
664 |
+
<span class="toggle-slider"></span>
|
665 |
+
</label>
|
666 |
+
<label for="show-camera-frustum">Show Camera Frustum</label>
|
667 |
+
</div>
|
668 |
+
<div class="slider-container">
|
669 |
+
<label for="frustum-size">Size</label>
|
670 |
+
<input type="range" id="frustum-size" min="0.02" max="0.5" step="0.01" value="0.2">
|
671 |
+
</div>
|
672 |
+
</div>
|
673 |
+
|
674 |
+
<div class="settings-group">
|
675 |
+
<div class="btn-group">
|
676 |
+
<button id="reset-view-btn" style="flex: 1; margin-right: 5px;">Reset View</button>
|
677 |
+
<button id="reset-settings-btn" style="flex: 1; margin-left: 5px;">Reset Settings</button>
|
678 |
+
</div>
|
679 |
+
</div>
|
680 |
+
</div>
|
681 |
+
|
682 |
+
<button id="show-settings-btn" class="control-btn" title="Show Settings">
|
683 |
+
<svg class="icon" viewBox="0 0 24 24">
|
684 |
+
<path d="M19.14,12.94c0.04-0.3,0.06-0.61,0.06-0.94c0-0.32-0.02-0.64-0.07-0.94l2.03-1.58c0.18-0.14,0.23-0.41,0.12-0.61 l-1.92-3.32c-0.12-0.22-0.37-0.29-0.59-0.22l-2.39,0.96c-0.5-0.38-1.03-0.7-1.62-0.94L14.4,2.81c-0.04-0.24-0.24-0.41-0.48-0.41 h-3.84c-0.24,0-0.43,0.17-0.47,0.41L9.25,5.35C8.66,5.59,8.12,5.92,7.63,6.29L5.24,5.33c-0.22-0.08-0.47,0-0.59,0.22L2.74,8.87 C2.62,9.08,2.66,9.34,2.86,9.48l2.03,1.58C4.84,11.36,4.8,11.69,4.8,12s0.02,0.64,0.07,0.94l-2.03,1.58 c-0.18,0.14-0.23,0.41-0.12,0.61l1.92,3.32c0.12,0.22,0.37,0.29,0.59,0.22l2.39-0.96c0.5,0.38,1.03,0.7,1.62,0.94l0.36,2.54 c0.04,0.24,0.24,0.41,0.48,0.41h3.84c0.24,0,0.44-0.17,0.47-0.41l0.36-2.54c0.59-0.24,1.13-0.56,1.62-0.94l2.39,0.96 c0.22,0.08,0.47,0,0.59-0.22l1.92-3.32c0.12-0.22,0.07-0.47-0.12-0.61L19.14,12.94z M12,15.6c-1.98,0-3.6-1.62-3.6-3.6 s1.62-3.6,3.6-3.6s3.6,1.62,3.6,3.6S13.98,15.6,12,15.6z"/>
|
685 |
+
</svg>
|
686 |
+
</button>
|
687 |
+
</div>
|
688 |
+
|
689 |
+
<div id="loading-overlay">
|
690 |
+
<!-- <div class="spinner"></div> -->
|
691 |
+
<div id="loading-text"></div>
|
692 |
+
<div class="loading-subtitle" style="font-size: medium;">Interactive Viewer of 3D Tracking</div>
|
693 |
+
</div>
|
694 |
+
|
695 |
+
<!-- Libraries -->
|
696 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/pako/2.1.0/pako.min.js"></script>
|
697 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/build/three.min.js"></script>
|
698 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/controls/OrbitControls.js"></script>
|
699 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/build/dat.gui.min.js"></script>
|
700 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/LineSegmentsGeometry.js"></script>
|
701 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/LineGeometry.js"></script>
|
702 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/LineMaterial.js"></script>
|
703 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/LineSegments2.js"></script>
|
704 |
+
<script src="https://cdn.jsdelivr.net/npm/[email protected]/examples/js/lines/Line2.js"></script>
|
705 |
+
|
706 |
+
<script>
|
707 |
+
class PointCloudVisualizer {
|
708 |
+
constructor() {
|
709 |
+
this.data = null;
|
710 |
+
this.config = {};
|
711 |
+
this.currentFrame = 0;
|
712 |
+
this.isPlaying = false;
|
713 |
+
this.playbackSpeed = 1;
|
714 |
+
this.lastFrameTime = 0;
|
715 |
+
this.defaultSettings = null;
|
716 |
+
|
717 |
+
this.ui = {
|
718 |
+
statusBar: document.getElementById('status-bar'),
|
719 |
+
playPauseBtn: document.getElementById('play-pause-btn'),
|
720 |
+
speedBtn: document.getElementById('speed-btn'),
|
721 |
+
timeline: document.getElementById('timeline'),
|
722 |
+
progress: document.getElementById('progress'),
|
723 |
+
settingsPanel: document.getElementById('settings-panel'),
|
724 |
+
loadingOverlay: document.getElementById('loading-overlay'),
|
725 |
+
loadingText: document.getElementById('loading-text'),
|
726 |
+
settingsToggleBtn: document.getElementById('settings-toggle-btn'),
|
727 |
+
frameCounter: document.getElementById('frame-counter'),
|
728 |
+
pointSize: document.getElementById('point-size'),
|
729 |
+
pointOpacity: document.getElementById('point-opacity'),
|
730 |
+
maxDepth: document.getElementById('max-depth'),
|
731 |
+
showTrajectory: document.getElementById('show-trajectory'),
|
732 |
+
enableRichTrail: document.getElementById('enable-rich-trail'),
|
733 |
+
trajectoryLineWidth: document.getElementById('trajectory-line-width'),
|
734 |
+
trajectoryBallSize: document.getElementById('trajectory-ball-size'),
|
735 |
+
trajectoryHistory: document.getElementById('trajectory-history'),
|
736 |
+
trajectoryFade: document.getElementById('trajectory-fade'),
|
737 |
+
tailOpacityContainer: document.getElementById('tail-opacity-container'),
|
738 |
+
resetViewBtn: document.getElementById('reset-view-btn'),
|
739 |
+
showCameraFrustum: document.getElementById('show-camera-frustum'),
|
740 |
+
frustumSize: document.getElementById('frustum-size'),
|
741 |
+
hideSettingsBtn: document.getElementById('hide-settings-btn'),
|
742 |
+
showSettingsBtn: document.getElementById('show-settings-btn')
|
743 |
+
};
|
744 |
+
|
745 |
+
this.scene = null;
|
746 |
+
this.camera = null;
|
747 |
+
this.renderer = null;
|
748 |
+
this.controls = null;
|
749 |
+
this.pointCloud = null;
|
750 |
+
this.trajectories = [];
|
751 |
+
this.cameraFrustum = null;
|
752 |
+
|
753 |
+
this.initThreeJS();
|
754 |
+
this.loadDefaultSettings().then(() => {
|
755 |
+
this.initEventListeners();
|
756 |
+
this.loadData();
|
757 |
+
});
|
758 |
+
}
|
759 |
+
|
760 |
+
async loadDefaultSettings() {
|
761 |
+
try {
|
762 |
+
const urlParams = new URLSearchParams(window.location.search);
|
763 |
+
const dataPath = urlParams.get('data') || '';
|
764 |
+
|
765 |
+
const defaultSettings = {
|
766 |
+
pointSize: 0.03,
|
767 |
+
pointOpacity: 1.0,
|
768 |
+
showTrajectory: true,
|
769 |
+
trajectoryLineWidth: 2.5,
|
770 |
+
trajectoryBallSize: 0.015,
|
771 |
+
trajectoryHistory: 0,
|
772 |
+
showCameraFrustum: true,
|
773 |
+
frustumSize: 0.2
|
774 |
+
};
|
775 |
+
|
776 |
+
if (!dataPath) {
|
777 |
+
this.defaultSettings = defaultSettings;
|
778 |
+
this.applyDefaultSettings();
|
779 |
+
return;
|
780 |
+
}
|
781 |
+
|
782 |
+
// Try to extract dataset and videoId from the data path
|
783 |
+
// Expected format: demos/datasetname/videoid.bin
|
784 |
+
const pathParts = dataPath.split('/');
|
785 |
+
if (pathParts.length < 3) {
|
786 |
+
this.defaultSettings = defaultSettings;
|
787 |
+
this.applyDefaultSettings();
|
788 |
+
return;
|
789 |
+
}
|
790 |
+
|
791 |
+
const datasetName = pathParts[pathParts.length - 2];
|
792 |
+
let videoId = pathParts[pathParts.length - 1].replace('.bin', '');
|
793 |
+
|
794 |
+
// Load settings from data.json
|
795 |
+
const response = await fetch('./data.json');
|
796 |
+
if (!response.ok) {
|
797 |
+
this.defaultSettings = defaultSettings;
|
798 |
+
this.applyDefaultSettings();
|
799 |
+
return;
|
800 |
+
}
|
801 |
+
|
802 |
+
const settingsData = await response.json();
|
803 |
+
|
804 |
+
// Check if this dataset and video exist
|
805 |
+
if (settingsData[datasetName] && settingsData[datasetName][videoId]) {
|
806 |
+
this.defaultSettings = settingsData[datasetName][videoId];
|
807 |
+
} else {
|
808 |
+
this.defaultSettings = defaultSettings;
|
809 |
+
}
|
810 |
+
|
811 |
+
this.applyDefaultSettings();
|
812 |
+
} catch (error) {
|
813 |
+
console.error("Error loading default settings:", error);
|
814 |
+
|
815 |
+
this.defaultSettings = {
|
816 |
+
pointSize: 0.03,
|
817 |
+
pointOpacity: 1.0,
|
818 |
+
showTrajectory: true,
|
819 |
+
trajectoryLineWidth: 2.5,
|
820 |
+
trajectoryBallSize: 0.015,
|
821 |
+
trajectoryHistory: 0,
|
822 |
+
showCameraFrustum: true,
|
823 |
+
frustumSize: 0.2
|
824 |
+
};
|
825 |
+
|
826 |
+
this.applyDefaultSettings();
|
827 |
+
}
|
828 |
+
}
|
829 |
+
|
830 |
+
applyDefaultSettings() {
|
831 |
+
if (!this.defaultSettings) return;
|
832 |
+
|
833 |
+
if (this.ui.pointSize) {
|
834 |
+
this.ui.pointSize.value = this.defaultSettings.pointSize;
|
835 |
+
}
|
836 |
+
|
837 |
+
if (this.ui.pointOpacity) {
|
838 |
+
this.ui.pointOpacity.value = this.defaultSettings.pointOpacity;
|
839 |
+
}
|
840 |
+
|
841 |
+
if (this.ui.maxDepth) {
|
842 |
+
this.ui.maxDepth.value = this.defaultSettings.maxDepth || 100.0;
|
843 |
+
}
|
844 |
+
|
845 |
+
if (this.ui.showTrajectory) {
|
846 |
+
this.ui.showTrajectory.checked = this.defaultSettings.showTrajectory;
|
847 |
+
}
|
848 |
+
|
849 |
+
if (this.ui.trajectoryLineWidth) {
|
850 |
+
this.ui.trajectoryLineWidth.value = this.defaultSettings.trajectoryLineWidth;
|
851 |
+
}
|
852 |
+
|
853 |
+
if (this.ui.trajectoryBallSize) {
|
854 |
+
this.ui.trajectoryBallSize.value = this.defaultSettings.trajectoryBallSize;
|
855 |
+
}
|
856 |
+
|
857 |
+
if (this.ui.trajectoryHistory) {
|
858 |
+
this.ui.trajectoryHistory.value = this.defaultSettings.trajectoryHistory;
|
859 |
+
}
|
860 |
+
|
861 |
+
if (this.ui.showCameraFrustum) {
|
862 |
+
this.ui.showCameraFrustum.checked = this.defaultSettings.showCameraFrustum;
|
863 |
+
}
|
864 |
+
|
865 |
+
if (this.ui.frustumSize) {
|
866 |
+
this.ui.frustumSize.value = this.defaultSettings.frustumSize;
|
867 |
+
}
|
868 |
+
}
|
869 |
+
|
870 |
+
initThreeJS() {
|
871 |
+
this.scene = new THREE.Scene();
|
872 |
+
this.scene.background = new THREE.Color(0x1a1a1a);
|
873 |
+
|
874 |
+
this.camera = new THREE.PerspectiveCamera(60, window.innerWidth / window.innerHeight, 0.1, 10000);
|
875 |
+
this.camera.position.set(0, 0, 0);
|
876 |
+
|
877 |
+
this.renderer = new THREE.WebGLRenderer({ antialias: true });
|
878 |
+
this.renderer.setPixelRatio(window.devicePixelRatio);
|
879 |
+
this.renderer.setSize(window.innerWidth, window.innerHeight);
|
880 |
+
document.getElementById('canvas-container').appendChild(this.renderer.domElement);
|
881 |
+
|
882 |
+
this.controls = new THREE.OrbitControls(this.camera, this.renderer.domElement);
|
883 |
+
this.controls.enableDamping = true;
|
884 |
+
this.controls.dampingFactor = 0.05;
|
885 |
+
this.controls.target.set(0, 0, 0);
|
886 |
+
this.controls.minDistance = 0.1;
|
887 |
+
this.controls.maxDistance = 1000;
|
888 |
+
this.controls.update();
|
889 |
+
|
890 |
+
const ambientLight = new THREE.AmbientLight(0xffffff, 0.5);
|
891 |
+
this.scene.add(ambientLight);
|
892 |
+
|
893 |
+
const directionalLight = new THREE.DirectionalLight(0xffffff, 0.8);
|
894 |
+
directionalLight.position.set(1, 1, 1);
|
895 |
+
this.scene.add(directionalLight);
|
896 |
+
}
|
897 |
+
|
898 |
+
initEventListeners() {
|
899 |
+
window.addEventListener('resize', () => this.onWindowResize());
|
900 |
+
|
901 |
+
this.ui.playPauseBtn.addEventListener('click', () => this.togglePlayback());
|
902 |
+
|
903 |
+
this.ui.timeline.addEventListener('click', (e) => {
|
904 |
+
const rect = this.ui.timeline.getBoundingClientRect();
|
905 |
+
const pos = (e.clientX - rect.left) / rect.width;
|
906 |
+
this.seekTo(pos);
|
907 |
+
});
|
908 |
+
|
909 |
+
this.ui.speedBtn.addEventListener('click', () => this.cyclePlaybackSpeed());
|
910 |
+
|
911 |
+
this.ui.pointSize.addEventListener('input', () => this.updatePointCloudSettings());
|
912 |
+
this.ui.pointOpacity.addEventListener('input', () => this.updatePointCloudSettings());
|
913 |
+
this.ui.maxDepth.addEventListener('input', () => this.updatePointCloudSettings());
|
914 |
+
this.ui.showTrajectory.addEventListener('change', () => {
|
915 |
+
this.trajectories.forEach(trajectory => {
|
916 |
+
trajectory.visible = this.ui.showTrajectory.checked;
|
917 |
+
});
|
918 |
+
});
|
919 |
+
|
920 |
+
this.ui.enableRichTrail.addEventListener('change', () => {
|
921 |
+
this.ui.tailOpacityContainer.style.display = this.ui.enableRichTrail.checked ? 'flex' : 'none';
|
922 |
+
this.updateTrajectories(this.currentFrame);
|
923 |
+
});
|
924 |
+
|
925 |
+
this.ui.trajectoryLineWidth.addEventListener('input', () => this.updateTrajectorySettings());
|
926 |
+
this.ui.trajectoryBallSize.addEventListener('input', () => this.updateTrajectorySettings());
|
927 |
+
this.ui.trajectoryHistory.addEventListener('input', () => {
|
928 |
+
this.updateTrajectories(this.currentFrame);
|
929 |
+
});
|
930 |
+
this.ui.trajectoryFade.addEventListener('input', () => {
|
931 |
+
this.updateTrajectories(this.currentFrame);
|
932 |
+
});
|
933 |
+
|
934 |
+
this.ui.resetViewBtn.addEventListener('click', () => this.resetView());
|
935 |
+
|
936 |
+
const resetSettingsBtn = document.getElementById('reset-settings-btn');
|
937 |
+
if (resetSettingsBtn) {
|
938 |
+
resetSettingsBtn.addEventListener('click', () => this.resetSettings());
|
939 |
+
}
|
940 |
+
|
941 |
+
document.addEventListener('keydown', (e) => {
|
942 |
+
if (e.key === 'Escape' && this.ui.settingsPanel.classList.contains('visible')) {
|
943 |
+
this.ui.settingsPanel.classList.remove('visible');
|
944 |
+
this.ui.settingsToggleBtn.classList.remove('active');
|
945 |
+
}
|
946 |
+
});
|
947 |
+
|
948 |
+
if (this.ui.settingsToggleBtn) {
|
949 |
+
this.ui.settingsToggleBtn.addEventListener('click', () => {
|
950 |
+
const isVisible = this.ui.settingsPanel.classList.toggle('visible');
|
951 |
+
this.ui.settingsToggleBtn.classList.toggle('active', isVisible);
|
952 |
+
|
953 |
+
if (isVisible) {
|
954 |
+
const panelRect = this.ui.settingsPanel.getBoundingClientRect();
|
955 |
+
const viewportHeight = window.innerHeight;
|
956 |
+
|
957 |
+
if (panelRect.bottom > viewportHeight) {
|
958 |
+
this.ui.settingsPanel.style.bottom = 'auto';
|
959 |
+
this.ui.settingsPanel.style.top = '80px';
|
960 |
+
}
|
961 |
+
}
|
962 |
+
});
|
963 |
+
}
|
964 |
+
|
965 |
+
if (this.ui.frustumSize) {
|
966 |
+
this.ui.frustumSize.addEventListener('input', () => this.updateFrustumDimensions());
|
967 |
+
}
|
968 |
+
|
969 |
+
if (this.ui.hideSettingsBtn && this.ui.showSettingsBtn && this.ui.settingsPanel) {
|
970 |
+
this.ui.hideSettingsBtn.addEventListener('click', () => {
|
971 |
+
this.ui.settingsPanel.classList.add('is-hidden');
|
972 |
+
this.ui.showSettingsBtn.style.display = 'flex';
|
973 |
+
});
|
974 |
+
|
975 |
+
this.ui.showSettingsBtn.addEventListener('click', () => {
|
976 |
+
this.ui.settingsPanel.classList.remove('is-hidden');
|
977 |
+
this.ui.showSettingsBtn.style.display = 'none';
|
978 |
+
});
|
979 |
+
}
|
980 |
+
}
|
981 |
+
|
982 |
+
makeElementDraggable(element) {
|
983 |
+
let pos1 = 0, pos2 = 0, pos3 = 0, pos4 = 0;
|
984 |
+
|
985 |
+
const dragHandle = element.querySelector('h2');
|
986 |
+
|
987 |
+
if (dragHandle) {
|
988 |
+
dragHandle.onmousedown = dragMouseDown;
|
989 |
+
dragHandle.title = "Drag to move panel";
|
990 |
+
} else {
|
991 |
+
element.onmousedown = dragMouseDown;
|
992 |
+
}
|
993 |
+
|
994 |
+
function dragMouseDown(e) {
|
995 |
+
e = e || window.event;
|
996 |
+
e.preventDefault();
|
997 |
+
pos3 = e.clientX;
|
998 |
+
pos4 = e.clientY;
|
999 |
+
document.onmouseup = closeDragElement;
|
1000 |
+
document.onmousemove = elementDrag;
|
1001 |
+
|
1002 |
+
element.classList.add('dragging');
|
1003 |
+
}
|
1004 |
+
|
1005 |
+
function elementDrag(e) {
|
1006 |
+
e = e || window.event;
|
1007 |
+
e.preventDefault();
|
1008 |
+
pos1 = pos3 - e.clientX;
|
1009 |
+
pos2 = pos4 - e.clientY;
|
1010 |
+
pos3 = e.clientX;
|
1011 |
+
pos4 = e.clientY;
|
1012 |
+
|
1013 |
+
const newTop = element.offsetTop - pos2;
|
1014 |
+
const newLeft = element.offsetLeft - pos1;
|
1015 |
+
|
1016 |
+
const viewportWidth = window.innerWidth;
|
1017 |
+
const viewportHeight = window.innerHeight;
|
1018 |
+
|
1019 |
+
const panelRect = element.getBoundingClientRect();
|
1020 |
+
|
1021 |
+
const maxTop = viewportHeight - 50;
|
1022 |
+
const maxLeft = viewportWidth - 50;
|
1023 |
+
|
1024 |
+
element.style.top = Math.min(Math.max(newTop, 0), maxTop) + "px";
|
1025 |
+
element.style.left = Math.min(Math.max(newLeft, 0), maxLeft) + "px";
|
1026 |
+
|
1027 |
+
// Remove bottom/right settings when dragging
|
1028 |
+
element.style.bottom = 'auto';
|
1029 |
+
element.style.right = 'auto';
|
1030 |
+
}
|
1031 |
+
|
1032 |
+
function closeDragElement() {
|
1033 |
+
document.onmouseup = null;
|
1034 |
+
document.onmousemove = null;
|
1035 |
+
|
1036 |
+
element.classList.remove('dragging');
|
1037 |
+
}
|
1038 |
+
}
|
1039 |
+
|
1040 |
+
async loadData() {
|
1041 |
+
try {
|
1042 |
+
// this.ui.loadingText.textContent = "Loading binary data...";
|
1043 |
+
|
1044 |
+
let arrayBuffer;
|
1045 |
+
|
1046 |
+
if (window.embeddedBase64) {
|
1047 |
+
// Base64 embedded path
|
1048 |
+
const binaryString = atob(window.embeddedBase64);
|
1049 |
+
const len = binaryString.length;
|
1050 |
+
const bytes = new Uint8Array(len);
|
1051 |
+
for (let i = 0; i < len; i++) {
|
1052 |
+
bytes[i] = binaryString.charCodeAt(i);
|
1053 |
+
}
|
1054 |
+
arrayBuffer = bytes.buffer;
|
1055 |
+
} else {
|
1056 |
+
// Default fetch path (fallback)
|
1057 |
+
const urlParams = new URLSearchParams(window.location.search);
|
1058 |
+
const dataPath = urlParams.get('data') || 'data.bin';
|
1059 |
+
|
1060 |
+
const response = await fetch(dataPath);
|
1061 |
+
if (!response.ok) throw new Error(`Failed to load ${dataPath}`);
|
1062 |
+
arrayBuffer = await response.arrayBuffer();
|
1063 |
+
}
|
1064 |
+
|
1065 |
+
const dataView = new DataView(arrayBuffer);
|
1066 |
+
const headerLen = dataView.getUint32(0, true);
|
1067 |
+
|
1068 |
+
const headerText = new TextDecoder("utf-8").decode(arrayBuffer.slice(4, 4 + headerLen));
|
1069 |
+
const header = JSON.parse(headerText);
|
1070 |
+
|
1071 |
+
const compressedBlob = new Uint8Array(arrayBuffer, 4 + headerLen);
|
1072 |
+
const decompressed = pako.inflate(compressedBlob).buffer;
|
1073 |
+
|
1074 |
+
const arrays = {};
|
1075 |
+
for (const key in header) {
|
1076 |
+
if (key === "meta") continue;
|
1077 |
+
|
1078 |
+
const meta = header[key];
|
1079 |
+
const { dtype, shape, offset, length } = meta;
|
1080 |
+
const slice = decompressed.slice(offset, offset + length);
|
1081 |
+
|
1082 |
+
let typedArray;
|
1083 |
+
switch (dtype) {
|
1084 |
+
case "uint8": typedArray = new Uint8Array(slice); break;
|
1085 |
+
case "uint16": typedArray = new Uint16Array(slice); break;
|
1086 |
+
case "float32": typedArray = new Float32Array(slice); break;
|
1087 |
+
case "float64": typedArray = new Float64Array(slice); break;
|
1088 |
+
default: throw new Error(`Unknown dtype: ${dtype}`);
|
1089 |
+
}
|
1090 |
+
|
1091 |
+
arrays[key] = { data: typedArray, shape: shape };
|
1092 |
+
}
|
1093 |
+
|
1094 |
+
this.data = arrays;
|
1095 |
+
this.config = header.meta;
|
1096 |
+
|
1097 |
+
this.initCameraWithCorrectFOV();
|
1098 |
+
this.ui.loadingText.textContent = "Creating point cloud...";
|
1099 |
+
|
1100 |
+
this.initPointCloud();
|
1101 |
+
this.initTrajectories();
|
1102 |
+
|
1103 |
+
setTimeout(() => {
|
1104 |
+
this.ui.loadingOverlay.classList.add('fade-out');
|
1105 |
+
this.ui.statusBar.classList.add('hidden');
|
1106 |
+
this.startAnimation();
|
1107 |
+
}, 500);
|
1108 |
+
} catch (error) {
|
1109 |
+
console.error("Error loading data:", error);
|
1110 |
+
this.ui.statusBar.textContent = `Error: ${error.message}`;
|
1111 |
+
// this.ui.loadingText.textContent = `Error loading data: ${error.message}`;
|
1112 |
+
}
|
1113 |
+
}
|
1114 |
+
|
1115 |
+
initPointCloud() {
|
1116 |
+
const numPoints = this.config.resolution[0] * this.config.resolution[1];
|
1117 |
+
const positions = new Float32Array(numPoints * 3);
|
1118 |
+
const colors = new Float32Array(numPoints * 3);
|
1119 |
+
|
1120 |
+
const geometry = new THREE.BufferGeometry();
|
1121 |
+
geometry.setAttribute('position', new THREE.BufferAttribute(positions, 3).setUsage(THREE.DynamicDrawUsage));
|
1122 |
+
geometry.setAttribute('color', new THREE.BufferAttribute(colors, 3).setUsage(THREE.DynamicDrawUsage));
|
1123 |
+
|
1124 |
+
const pointSize = parseFloat(this.ui.pointSize.value) || this.defaultSettings.pointSize;
|
1125 |
+
const pointOpacity = parseFloat(this.ui.pointOpacity.value) || this.defaultSettings.pointOpacity;
|
1126 |
+
|
1127 |
+
const material = new THREE.PointsMaterial({
|
1128 |
+
size: pointSize,
|
1129 |
+
vertexColors: true,
|
1130 |
+
transparent: true,
|
1131 |
+
opacity: pointOpacity,
|
1132 |
+
sizeAttenuation: true
|
1133 |
+
});
|
1134 |
+
|
1135 |
+
this.pointCloud = new THREE.Points(geometry, material);
|
1136 |
+
this.scene.add(this.pointCloud);
|
1137 |
+
}
|
1138 |
+
|
1139 |
+
initTrajectories() {
|
1140 |
+
if (!this.data.trajectories) return;
|
1141 |
+
|
1142 |
+
this.trajectories.forEach(trajectory => {
|
1143 |
+
if (trajectory.userData.lineSegments) {
|
1144 |
+
trajectory.userData.lineSegments.forEach(segment => {
|
1145 |
+
segment.geometry.dispose();
|
1146 |
+
                    segment.material.dispose();
                });
            }
            this.scene.remove(trajectory);
        });
        this.trajectories = [];

        const shape = this.data.trajectories.shape;
        if (!shape || shape.length < 2) return;

        const [totalFrames, numTrajectories] = shape;
        const palette = this.createColorPalette(numTrajectories);
        const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
        const maxHistory = 500; // Max value of the history slider, for the object pool

        for (let i = 0; i < numTrajectories; i++) {
            const trajectoryGroup = new THREE.Group();

            const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
            const sphereGeometry = new THREE.SphereGeometry(ballSize, 16, 16);
            const sphereMaterial = new THREE.MeshBasicMaterial({ color: palette[i], transparent: true });
            const positionMarker = new THREE.Mesh(sphereGeometry, sphereMaterial);
            trajectoryGroup.add(positionMarker);

            // High-Performance Line (default)
            const simpleLineGeometry = new THREE.BufferGeometry();
            const simpleLinePositions = new Float32Array(maxHistory * 3);
            simpleLineGeometry.setAttribute('position', new THREE.BufferAttribute(simpleLinePositions, 3).setUsage(THREE.DynamicDrawUsage));
            const simpleLine = new THREE.Line(simpleLineGeometry, new THREE.LineBasicMaterial({ color: palette[i] }));
            simpleLine.frustumCulled = false;
            trajectoryGroup.add(simpleLine);

            // High-Quality Line Segments (for rich trail)
            const lineSegments = [];
            const lineWidth = parseFloat(this.ui.trajectoryLineWidth.value);

            // Create a pool of line segment objects
            for (let j = 0; j < maxHistory - 1; j++) {
                const lineGeometry = new THREE.LineGeometry();
                lineGeometry.setPositions([0, 0, 0, 0, 0, 0]);
                const lineMaterial = new THREE.LineMaterial({
                    color: palette[i],
                    linewidth: lineWidth,
                    resolution: resolution,
                    transparent: true,
                    depthWrite: false, // Correctly handle transparency
                    opacity: 0
                });
                const segment = new THREE.Line2(lineGeometry, lineMaterial);
                segment.frustumCulled = false;
                segment.visible = false; // Start with all segments hidden
                trajectoryGroup.add(segment);
                lineSegments.push(segment);
            }

            trajectoryGroup.userData = {
                marker: positionMarker,
                simpleLine: simpleLine,
                lineSegments: lineSegments,
                color: palette[i]
            };

            this.scene.add(trajectoryGroup);
            this.trajectories.push(trajectoryGroup);
        }

        const showTrajectory = this.ui.showTrajectory.checked;
        this.trajectories.forEach(trajectory => trajectory.visible = showTrajectory);
    }

    createColorPalette(count) {
        const colors = [];
        const hueStep = 360 / count;

        for (let i = 0; i < count; i++) {
            const hue = (i * hueStep) % 360;
            const color = new THREE.Color().setHSL(hue / 360, 0.8, 0.6);
            colors.push(color);
        }

        return colors;
    }

    updatePointCloud(frameIndex) {
        if (!this.data || !this.pointCloud) return;

        const positions = this.pointCloud.geometry.attributes.position.array;
        const colors = this.pointCloud.geometry.attributes.color.array;

        const rgbVideo = this.data.rgb_video;
        const depthsRgb = this.data.depths_rgb;
        const intrinsics = this.data.intrinsics;
        const invExtrinsics = this.data.inv_extrinsics;

        const width = this.config.resolution[0];
        const height = this.config.resolution[1];
        const numPoints = width * height;

        const K = this.get3x3Matrix(intrinsics.data, intrinsics.shape, frameIndex);
        const fx = K[0][0], fy = K[1][1], cx = K[0][2], cy = K[1][2];

        const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);
        const transform = this.getTransformElements(invExtrMat);

        const rgbFrame = this.getFrame(rgbVideo.data, rgbVideo.shape, frameIndex);
        const depthFrame = this.getFrame(depthsRgb.data, depthsRgb.shape, frameIndex);

        const maxDepth = parseFloat(this.ui.maxDepth.value) || 10.0;

        let validPointCount = 0;

        for (let i = 0; i < numPoints; i++) {
            const xPix = i % width;
            const yPix = Math.floor(i / width);

            const d0 = depthFrame[i * 3];
            const d1 = depthFrame[i * 3 + 1];
            const depthEncoded = d0 | (d1 << 8);
            const depthValue = (depthEncoded / ((1 << 16) - 1)) *
                (this.config.depthRange[1] - this.config.depthRange[0]) +
                this.config.depthRange[0];

            if (depthValue === 0 || depthValue > maxDepth) {
                continue;
            }

            const X = ((xPix - cx) * depthValue) / fx;
            const Y = ((yPix - cy) * depthValue) / fy;
            const Z = depthValue;

            const tx = transform.m11 * X + transform.m12 * Y + transform.m13 * Z + transform.m14;
            const ty = transform.m21 * X + transform.m22 * Y + transform.m23 * Z + transform.m24;
            const tz = transform.m31 * X + transform.m32 * Y + transform.m33 * Z + transform.m34;

            const index = validPointCount * 3;
            positions[index] = tx;
            positions[index + 1] = -ty;
            positions[index + 2] = -tz;

            colors[index] = rgbFrame[i * 3] / 255;
            colors[index + 1] = rgbFrame[i * 3 + 1] / 255;
            colors[index + 2] = rgbFrame[i * 3 + 2] / 255;

            validPointCount++;
        }

        this.pointCloud.geometry.setDrawRange(0, validPointCount);
        this.pointCloud.geometry.attributes.position.needsUpdate = true;
        this.pointCloud.geometry.attributes.color.needsUpdate = true;
        this.pointCloud.geometry.computeBoundingSphere(); // Important for camera culling

        this.updateTrajectories(frameIndex);

        const progress = (frameIndex + 1) / this.config.totalFrames;
        this.ui.progress.style.width = `${progress * 100}%`;

        if (this.ui.frameCounter && this.config.totalFrames) {
            this.ui.frameCounter.textContent = `Frame ${frameIndex} / ${this.config.totalFrames - 1}`;
        }

        this.updateCameraFrustum(frameIndex);
    }
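
    // Illustrative sketch, not called by the viewer: how a single pixel's depth is
    // recovered from the two packed 8-bit channels read above. `lo` and `hi` are the
    // R and G bytes of a depths_rgb pixel and `range` is this.config.depthRange; the
    // Python exporter performs the inverse of this mapping when it writes the data.
    static decodeDepthSample(lo, hi, range) {
        const encoded = lo | (hi << 8);              // reassemble the 16-bit value
        const t = encoded / ((1 << 16) - 1);         // normalize to [0, 1]
        return t * (range[1] - range[0]) + range[0]; // map back to metric depth
    }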

    updateTrajectories(frameIndex) {
        if (!this.data.trajectories || this.trajectories.length === 0) return;

        const trajectoryData = this.data.trajectories.data;
        const [totalFrames, numTrajectories] = this.data.trajectories.shape;
        const historyFrames = parseInt(this.ui.trajectoryHistory.value);
        const tailOpacity = parseFloat(this.ui.trajectoryFade.value);

        const isRichMode = this.ui.enableRichTrail.checked;

        for (let i = 0; i < numTrajectories; i++) {
            const trajectoryGroup = this.trajectories[i];
            const { marker, simpleLine, lineSegments } = trajectoryGroup.userData;

            const currentPos = new THREE.Vector3();
            const currentOffset = (frameIndex * numTrajectories + i) * 3;

            currentPos.x = trajectoryData[currentOffset];
            currentPos.y = -trajectoryData[currentOffset + 1];
            currentPos.z = -trajectoryData[currentOffset + 2];

            marker.position.copy(currentPos);
            marker.material.opacity = 1.0;

            const historyToShow = Math.min(historyFrames, frameIndex + 1);

            if (isRichMode) {
                // --- High-Quality Mode ---
                simpleLine.visible = false;

                for (let j = 0; j < lineSegments.length; j++) {
                    const segment = lineSegments[j];
                    if (j < historyToShow - 1) {
                        const headFrame = frameIndex - j;
                        const tailFrame = frameIndex - j - 1;
                        const headOffset = (headFrame * numTrajectories + i) * 3;
                        const tailOffset = (tailFrame * numTrajectories + i) * 3;
                        const positions = [
                            trajectoryData[headOffset], -trajectoryData[headOffset + 1], -trajectoryData[headOffset + 2],
                            trajectoryData[tailOffset], -trajectoryData[tailOffset + 1], -trajectoryData[tailOffset + 2]
                        ];
                        segment.geometry.setPositions(positions);
                        const headOpacity = 1.0;
                        const normalizedAge = j / Math.max(1, historyToShow - 2);
                        const alpha = headOpacity - (headOpacity - tailOpacity) * normalizedAge;
                        segment.material.opacity = Math.max(0, alpha);
                        segment.visible = true;
                    } else {
                        segment.visible = false;
                    }
                }
            } else {
                // --- Performance Mode ---
                lineSegments.forEach(s => s.visible = false);
                simpleLine.visible = true;

                const positions = simpleLine.geometry.attributes.position.array;
                for (let j = 0; j < historyToShow; j++) {
                    const historyFrame = Math.max(0, frameIndex - j);
                    const offset = (historyFrame * numTrajectories + i) * 3;
                    positions[j * 3] = trajectoryData[offset];
                    positions[j * 3 + 1] = -trajectoryData[offset + 1];
                    positions[j * 3 + 2] = -trajectoryData[offset + 2];
                }
                simpleLine.geometry.setDrawRange(0, historyToShow);
                simpleLine.geometry.attributes.position.needsUpdate = true;
            }
        }
    }
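
    // Illustrative sketch, not called by the viewer: the exported data appears to use a
    // camera-style convention (y down, z forward), while three.js uses y up and z toward
    // the viewer, which is why updatePointCloud() and updateTrajectories() negate Y and Z
    // after transforming. The same flip as a standalone helper (assumed interpretation):
    static toThreeSpace(x, y, z) {
        return new THREE.Vector3(x, -y, -z);
    }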

    updateTrajectorySettings() {
        if (!this.trajectories || this.trajectories.length === 0) return;

        const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
        const lineWidth = parseFloat(this.ui.trajectoryLineWidth.value);

        this.trajectories.forEach(trajectoryGroup => {
            const { marker, lineSegments } = trajectoryGroup.userData;

            marker.geometry.dispose();
            marker.geometry = new THREE.SphereGeometry(ballSize, 16, 16);

            // Line width only affects rich mode
            lineSegments.forEach(segment => {
                if (segment.material) {
                    segment.material.linewidth = lineWidth;
                }
            });
        });

        this.updateTrajectories(this.currentFrame);
    }

    getDepthColor(normalizedDepth) {
        const hue = (1 - normalizedDepth) * 240 / 360;
        const color = new THREE.Color().setHSL(hue, 1.0, 0.5);
        return color;
    }

    getFrame(typedArray, shape, frameIndex) {
        const [T, H, W, C] = shape;
        const frameSize = H * W * C;
        const offset = frameIndex * frameSize;
        return typedArray.subarray(offset, offset + frameSize);
    }

    get3x3Matrix(typedArray, shape, frameIndex) {
        const frameSize = 9;
        const offset = frameIndex * frameSize;
        const K = [];
        for (let i = 0; i < 3; i++) {
            const row = [];
            for (let j = 0; j < 3; j++) {
                row.push(typedArray[offset + i * 3 + j]);
            }
            K.push(row);
        }
        return K;
    }

    get4x4Matrix(typedArray, shape, frameIndex) {
        const frameSize = 16;
        const offset = frameIndex * frameSize;
        const M = [];
        for (let i = 0; i < 4; i++) {
            const row = [];
            for (let j = 0; j < 4; j++) {
                row.push(typedArray[offset + i * 4 + j]);
            }
            M.push(row);
        }
        return M;
    }

    getTransformElements(matrix) {
        return {
            m11: matrix[0][0], m12: matrix[0][1], m13: matrix[0][2], m14: matrix[0][3],
            m21: matrix[1][0], m22: matrix[1][1], m23: matrix[1][2], m24: matrix[1][3],
            m31: matrix[2][0], m32: matrix[2][1], m33: matrix[2][2], m34: matrix[2][3]
        };
    }

    togglePlayback() {
        this.isPlaying = !this.isPlaying;

        const playIcon = document.getElementById('play-icon');
        const pauseIcon = document.getElementById('pause-icon');

        if (this.isPlaying) {
            playIcon.style.display = 'none';
            pauseIcon.style.display = 'block';
            this.lastFrameTime = performance.now();
        } else {
            playIcon.style.display = 'block';
            pauseIcon.style.display = 'none';
        }
    }

    cyclePlaybackSpeed() {
        const speeds = [0.5, 1, 2, 4, 8];
        const speedRates = speeds.map(s => s * this.config.baseFrameRate);

        let currentIndex = 0;
        const normalizedSpeed = this.playbackSpeed / this.config.baseFrameRate;

        for (let i = 0; i < speeds.length; i++) {
            if (Math.abs(normalizedSpeed - speeds[i]) < Math.abs(normalizedSpeed - speeds[currentIndex])) {
                currentIndex = i;
            }
        }

        const nextIndex = (currentIndex + 1) % speeds.length;
        this.playbackSpeed = speedRates[nextIndex];
        this.ui.speedBtn.textContent = `${speeds[nextIndex]}x`;

        if (speeds[nextIndex] === 1) {
            this.ui.speedBtn.classList.remove('active');
        } else {
            this.ui.speedBtn.classList.add('active');
        }
    }

    seekTo(position) {
        const frameIndex = Math.floor(position * this.config.totalFrames);
        this.currentFrame = Math.max(0, Math.min(frameIndex, this.config.totalFrames - 1));
        this.updatePointCloud(this.currentFrame);
    }

    updatePointCloudSettings() {
        if (!this.pointCloud) return;

        const size = parseFloat(this.ui.pointSize.value);
        const opacity = parseFloat(this.ui.pointOpacity.value);

        this.pointCloud.material.size = size;
        this.pointCloud.material.opacity = opacity;
        this.pointCloud.material.needsUpdate = true;

        this.updatePointCloud(this.currentFrame);
    }

    updateControls() {
        if (!this.controls) return;
        this.controls.update();
    }

    resetView() {
        if (!this.camera || !this.controls) return;

        // Reset camera position
        this.camera.position.set(0, 0, this.config.cameraZ || 0);

        // Reset controls
        this.controls.reset();

        // Set target slightly in front of camera
        this.controls.target.set(0, 0, -1);
        this.controls.update();

        // Show status message
        this.ui.statusBar.textContent = "View reset";
        this.ui.statusBar.classList.remove('hidden');

        // Hide status message after a few seconds
        setTimeout(() => {
            this.ui.statusBar.classList.add('hidden');
        }, 3000);
    }

    onWindowResize() {
        if (!this.camera || !this.renderer) return;

        const windowAspect = window.innerWidth / window.innerHeight;
        this.camera.aspect = windowAspect;
        this.camera.updateProjectionMatrix();
        this.renderer.setSize(window.innerWidth, window.innerHeight);

        if (this.trajectories && this.trajectories.length > 0) {
            const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
            this.trajectories.forEach(trajectory => {
                const { lineSegments } = trajectory.userData;
                if (lineSegments && lineSegments.length > 0) {
                    lineSegments.forEach(segment => {
                        if (segment.material && segment.material.resolution) {
                            segment.material.resolution.copy(resolution);
                        }
                    });
                }
            });
        }

        if (this.cameraFrustum) {
            const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
            this.cameraFrustum.children.forEach(line => {
                if (line.material && line.material.resolution) {
                    line.material.resolution.copy(resolution);
                }
            });
        }
    }

    startAnimation() {
        this.isPlaying = true;
        this.lastFrameTime = performance.now();

        this.camera.position.set(0, 0, this.config.cameraZ || 0);
        this.controls.target.set(0, 0, -1);
        this.controls.update();

        this.playbackSpeed = this.config.baseFrameRate;

        document.getElementById('play-icon').style.display = 'none';
        document.getElementById('pause-icon').style.display = 'block';

        this.animate();
    }

    animate() {
        requestAnimationFrame(() => this.animate());

        if (this.controls) {
            this.controls.update();
        }

        if (this.isPlaying && this.data) {
            const now = performance.now();
            const delta = (now - this.lastFrameTime) / 1000;

            const framesToAdvance = Math.floor(delta * this.config.baseFrameRate * this.playbackSpeed);
            if (framesToAdvance > 0) {
                this.currentFrame = (this.currentFrame + framesToAdvance) % this.config.totalFrames;
                this.lastFrameTime = now;
                this.updatePointCloud(this.currentFrame);
            }
        }

        if (this.renderer && this.scene && this.camera) {
            this.renderer.render(this.scene, this.camera);
        }
    }

    initCameraWithCorrectFOV() {
        const fov = this.config.fov || 60;

        const windowAspect = window.innerWidth / window.innerHeight;

        this.camera = new THREE.PerspectiveCamera(
            fov,
            windowAspect,
            0.1,
            10000
        );

        this.controls.object = this.camera;
        this.controls.update();

        this.initCameraFrustum();
    }

    initCameraFrustum() {
        this.cameraFrustum = new THREE.Group();

        this.scene.add(this.cameraFrustum);

        this.initCameraFrustumGeometry();

        const showCameraFrustum = this.ui.showCameraFrustum ? this.ui.showCameraFrustum.checked : (this.defaultSettings ? this.defaultSettings.showCameraFrustum : false);

        this.cameraFrustum.visible = showCameraFrustum;
    }

    initCameraFrustumGeometry() {
        const fov = this.config.fov || 60;
        const originalAspect = this.config.original_aspect_ratio || 1.33;

        const size = parseFloat(this.ui.frustumSize.value) || this.defaultSettings.frustumSize;

        const halfHeight = Math.tan(THREE.MathUtils.degToRad(fov / 2)) * size;
        const halfWidth = halfHeight * originalAspect;

        const vertices = [
            new THREE.Vector3(0, 0, 0),
            new THREE.Vector3(-halfWidth, -halfHeight, size),
            new THREE.Vector3(halfWidth, -halfHeight, size),
            new THREE.Vector3(halfWidth, halfHeight, size),
            new THREE.Vector3(-halfWidth, halfHeight, size)
        ];

        const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);

        const linePairs = [
            [1, 2], [2, 3], [3, 4], [4, 1],
            [0, 1], [0, 2], [0, 3], [0, 4]
        ];

        const colors = {
            edge: new THREE.Color(0x3366ff),
            ray: new THREE.Color(0x33cc66)
        };

        linePairs.forEach((pair, index) => {
            const positions = [
                vertices[pair[0]].x, vertices[pair[0]].y, vertices[pair[0]].z,
                vertices[pair[1]].x, vertices[pair[1]].y, vertices[pair[1]].z
            ];

            const lineGeometry = new THREE.LineGeometry();
            lineGeometry.setPositions(positions);

            let color = index < 4 ? colors.edge : colors.ray;

            const lineMaterial = new THREE.LineMaterial({
                color: color,
                linewidth: 2,
                resolution: resolution,
                dashed: false
            });

            const line = new THREE.Line2(lineGeometry, lineMaterial);
            this.cameraFrustum.add(line);
        });
    }
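
    // Illustrative sketch, not used at runtime: the frustum corner offsets above follow
    // halfHeight = tan(fov / 2) * size and halfWidth = halfHeight * aspect, so for
    // fov = 60 the half height is tan(30 deg) * size, i.e. roughly 0.577 * size.
    static frustumHalfExtents(fovDeg, aspect, size) {
        const halfHeight = Math.tan(THREE.MathUtils.degToRad(fovDeg / 2)) * size;
        return { halfWidth: halfHeight * aspect, halfHeight: halfHeight };
    }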

    updateCameraFrustum(frameIndex) {
        if (!this.cameraFrustum || !this.data) return;

        const invExtrinsics = this.data.inv_extrinsics;
        if (!invExtrinsics) return;

        const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);

        const matrix = new THREE.Matrix4();
        matrix.set(
            invExtrMat[0][0], invExtrMat[0][1], invExtrMat[0][2], invExtrMat[0][3],
            invExtrMat[1][0], invExtrMat[1][1], invExtrMat[1][2], invExtrMat[1][3],
            invExtrMat[2][0], invExtrMat[2][1], invExtrMat[2][2], invExtrMat[2][3],
            invExtrMat[3][0], invExtrMat[3][1], invExtrMat[3][2], invExtrMat[3][3]
        );

        const position = new THREE.Vector3();
        position.setFromMatrixPosition(matrix);

        const rotMatrix = new THREE.Matrix4().extractRotation(matrix);

        const coordinateCorrection = new THREE.Matrix4().makeRotationX(Math.PI);

        const finalRotation = new THREE.Matrix4().multiplyMatrices(coordinateCorrection, rotMatrix);

        const quaternion = new THREE.Quaternion();
        quaternion.setFromRotationMatrix(finalRotation);

        position.y = -position.y;
        position.z = -position.z;

        this.cameraFrustum.position.copy(position);
        this.cameraFrustum.quaternion.copy(quaternion);

        const showCameraFrustum = this.ui.showCameraFrustum ? this.ui.showCameraFrustum.checked : this.defaultSettings.showCameraFrustum;

        if (this.cameraFrustum.visible !== showCameraFrustum) {
            this.cameraFrustum.visible = showCameraFrustum;
        }

        const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
        this.cameraFrustum.children.forEach(line => {
            if (line.material && line.material.resolution) {
                line.material.resolution.copy(resolution);
            }
        });
    }

    updateFrustumDimensions() {
        if (!this.cameraFrustum) return;

        while (this.cameraFrustum.children.length > 0) {
            const child = this.cameraFrustum.children[0];
            if (child.geometry) child.geometry.dispose();
            if (child.material) child.material.dispose();
            this.cameraFrustum.remove(child);
        }

        this.initCameraFrustumGeometry();

        this.updateCameraFrustum(this.currentFrame);
    }

    resetSettings() {
        if (!this.defaultSettings) return;

        this.applyDefaultSettings();

        this.updatePointCloudSettings();
        this.updateTrajectorySettings();
        this.updateFrustumDimensions();

        this.ui.statusBar.textContent = "Settings reset to defaults";
        this.ui.statusBar.classList.remove('hidden');

        setTimeout(() => {
            this.ui.statusBar.classList.add('hidden');
        }, 3000);
    }
}

window.addEventListener('DOMContentLoaded', () => {
    new PointCloudVisualizer();
});
</script>
</body>
</html>
app.py
ADDED
@@ -0,0 +1,1118 @@
import gradio as gr
import os
import json
import numpy as np
import cv2
import base64
import time
import tempfile
import shutil
import glob
import threading
import subprocess
import struct
import zlib
from pathlib import Path
from einops import rearrange
from typing import List, Tuple, Union
try:
    import spaces
except ImportError:
    # Fallback for local development: the functions below are decorated with
    # @spaces.GPU, so the stand-in must expose a no-op GPU attribute rather than
    # being a bare function.
    class _SpacesFallback:
        @staticmethod
        def GPU(func=None, **kwargs):
            return func if func is not None else (lambda f: f)
    spaces = _SpacesFallback()
import torch
import logging
from concurrent.futures import ThreadPoolExecutor
import atexit
import uuid

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Import custom modules with error handling
try:
    from app_3rd.sam_utils.inference import SamPredictor, get_sam_predictor, run_inference
    from app_3rd.spatrack_utils.infer_track import get_tracker_predictor, run_tracker, get_points_on_a_grid
except ImportError as e:
    logger.error(f"Failed to import custom modules: {e}")
    raise

# Constants
MAX_FRAMES = 80
COLORS = [(0, 0, 255), (0, 255, 255)]  # BGR: Red for negative, Yellow for positive
MARKERS = [1, 5]  # Cross for negative, Star for positive
MARKER_SIZE = 8

# Thread pool for delayed deletion
thread_pool_executor = ThreadPoolExecutor(max_workers=2)

def delete_later(path: Union[str, os.PathLike], delay: int = 600):
    """Delete file or directory after specified delay (default 10 minutes)"""
    def _delete():
        try:
            if os.path.isfile(path):
                os.remove(path)
            elif os.path.isdir(path):
                shutil.rmtree(path)
        except Exception as e:
            logger.warning(f"Failed to delete {path}: {e}")

    def _wait_and_delete():
        time.sleep(delay)
        _delete()

    thread_pool_executor.submit(_wait_and_delete)
    atexit.register(_delete)

def create_user_temp_dir():
    """Create a unique temporary directory for each user session"""
    session_id = str(uuid.uuid4())[:8]  # Short unique ID
    temp_dir = os.path.join("temp_local", f"session_{session_id}")
    os.makedirs(temp_dir, exist_ok=True)

    # Schedule deletion after 10 minutes
    delete_later(temp_dir, delay=600)

    return temp_dir

from huggingface_hub import hf_hub_download
# init the model
os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth")  #, force_download=True)

if os.environ.get("VGGT_DIR", None) is not None:
    from models.vggt.vggt.models.vggt_moe import VGGT_MoE
    from models.vggt.vggt.utils.load_fn import preprocess_image
    vggt_model = VGGT_MoE()
    vggt_model.load_state_dict(torch.load(os.environ.get("VGGT_DIR")), strict=False)
    vggt_model.eval()
    vggt_model = vggt_model.to("cuda")

# Global model initialization
print("🚀 Initializing local models...")
tracker_model, _ = get_tracker_predictor(".", vo_points=756)
predictor = get_sam_predictor()
print("✅ Models loaded successfully!")

gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])

@spaces.GPU
def gpu_run_inference(predictor_arg, image, points, boxes):
    """GPU-accelerated SAM inference"""
    if predictor_arg is None:
        print("Initializing SAM predictor inside GPU function...")
        predictor_arg = get_sam_predictor(predictor=predictor)

    # Ensure predictor is on GPU
    try:
        if hasattr(predictor_arg, 'model'):
            predictor_arg.model = predictor_arg.model.cuda()
        elif hasattr(predictor_arg, 'sam'):
            predictor_arg.sam = predictor_arg.sam.cuda()
        elif hasattr(predictor_arg, 'to'):
            predictor_arg = predictor_arg.to('cuda')

        if hasattr(image, 'cuda'):
            image = image.cuda()

    except Exception as e:
        print(f"Warning: Could not move predictor to GPU: {e}")

    return run_inference(predictor_arg, image, points, boxes)

@spaces.GPU
def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps, mode="offline"):
    """GPU-accelerated tracking"""
    import torchvision.transforms as T
    import decord

    if tracker_model_arg is None or tracker_viser_arg is None:
        print("Initializing tracker models inside GPU function...")
        out_dir = os.path.join(temp_dir, "results")
        os.makedirs(out_dir, exist_ok=True)
        tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points, tracker_model=tracker_model)

    # Setup paths
    video_path = os.path.join(temp_dir, f"{video_name}.mp4")
    mask_path = os.path.join(temp_dir, f"{video_name}.png")
    out_dir = os.path.join(temp_dir, "results")
    os.makedirs(out_dir, exist_ok=True)

    # Load video using decord
    video_reader = decord.VideoReader(video_path)
    video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2)

    # Downscale so the shorter side is 224 (smaller videos are left unchanged)
    h, w = video_tensor.shape[2:]
    scale = max(224 / h, 224 / w)
    if scale < 1:
        new_h, new_w = int(h * scale), int(w * scale)
        video_tensor = T.Resize((new_h, new_w))(video_tensor)
    video_tensor = video_tensor[::fps].float()[:MAX_FRAMES]

    # Move to GPU
    video_tensor = video_tensor.cuda()
    print(f"Video tensor shape: {video_tensor.shape}, device: {video_tensor.device}")

    depth_tensor = None
    intrs = None
    extrs = None
    data_npz_load = {}

    # run vggt
    if os.environ.get("VGGT_DIR", None) is not None:
        # process the image tensor
        video_tensor = preprocess_image(video_tensor)[None]
        with torch.no_grad():
            with torch.cuda.amp.autocast(dtype=torch.bfloat16):
                # Predict attributes including cameras, depth maps, and point maps.
                predictions = vggt_model(video_tensor.cuda()/255)
                extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
                depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]

        depth_tensor = depth_map.squeeze().cpu().numpy()
        extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
        extrs = extrinsic.squeeze().cpu().numpy()
        intrs = intrinsic.squeeze().cpu().numpy()
        video_tensor = video_tensor.squeeze()
        # NOTE: roughly 20% of the depth is not reliable; keep only confidence > 0.5
        # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
        unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5

    # Load and process mask
    if os.path.exists(mask_path):
        mask = cv2.imread(mask_path)
        mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
        mask = mask.sum(axis=-1)>0
    else:
        mask = np.ones_like(video_tensor[0,0].cpu().numpy())>0
        grid_size = 10

    # Get frame dimensions and create grid points
    frame_H, frame_W = video_tensor.shape[2:]
    grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cuda")

    # Sample mask values at grid points and filter
    if os.path.exists(mask_path):
        grid_pts_int = grid_pts[0].long()
        mask_values = mask[grid_pts_int.cpu()[...,1], grid_pts_int.cpu()[...,0]]
        grid_pts = grid_pts[:, mask_values]

    query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].cpu().numpy()
    print(f"Query points shape: {query_xyt.shape}")

    # Run model inference
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        (
            c2w_traj, intrs, point_map, conf_depth,
            track3d_pred, track2d_pred, vis_pred, conf_pred, video
        ) = tracker_model_arg.forward(video_tensor, depth=depth_tensor,
                                      intrs=intrs, extrs=extrs,
                                      queries=query_xyt,
                                      fps=1, full_point=False, iters_track=4,
                                      query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
                                      support_frame=len(video_tensor)-1, replace_ratio=0.2)

    # Resize results to avoid large I/O
    max_size = 224
    h, w = video.shape[2:]
    scale = min(max_size / h, max_size / w)
    if scale < 1:
        new_h, new_w = int(h * scale), int(w * scale)
        video = T.Resize((new_h, new_w))(video)
        video_tensor = T.Resize((new_h, new_w))(video_tensor)
        point_map = T.Resize((new_h, new_w))(point_map)
        track2d_pred[...,:2] = track2d_pred[...,:2] * scale
        intrs[:,:2,:] = intrs[:,:2,:] * scale
        conf_depth = T.Resize((new_h, new_w))(conf_depth)

    # Visualize tracks
    tracker_viser_arg.visualize(video=video[None],
                                tracks=track2d_pred[None][...,:2],
                                visibility=vis_pred[None], filename="test")

    # Save in tapip3d format
    data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
    data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
    data_npz_load["intrinsics"] = intrs.cpu().numpy()
    data_npz_load["depths"] = point_map[:,2,...].cpu().numpy()
    data_npz_load["video"] = (video_tensor).cpu().numpy()/255
    data_npz_load["visibs"] = vis_pred.cpu().numpy()
    data_npz_load["confs"] = conf_pred.cpu().numpy()
    data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
    np.savez(os.path.join(out_dir, 'result.npz'), **data_npz_load)

    return None

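# A minimal sketch (not called by the app) of how the result.npz written above can be
# inspected; the key names come from data_npz_load in gpu_run_tracker(), while the exact
# shapes depend on the video length, frame size and number of query points.
def _inspect_result_npz(npz_path):
    """Print the arrays saved by gpu_run_tracker() for quick debugging."""
    data = np.load(npz_path)
    for key in data.files:  # e.g. coords, extrinsics, intrinsics, depths, video, ...
        print(f"{key}: shape={data[key].shape}, dtype={data[key].dtype}")
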
def compress_and_write(filename, header, blob):
    header_bytes = json.dumps(header).encode("utf-8")
    header_len = struct.pack("<I", len(header_bytes))
    with open(filename, "wb") as f:
        f.write(header_len)
        f.write(header_bytes)
        f.write(blob)

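# A minimal counterpart sketch (not used by the app) showing how the container written by
# compress_and_write() could be read back: a 4-byte little-endian header length, a JSON
# header describing each array, then the concatenated raw buffers. Note that
# process_point_cloud_data() zlib-compresses the blob before writing, so the payload is
# decompressed before slicing. The helper name is ours, not part of the original code.
def _read_compressed_bin(filename):
    """Return (header, arrays) for a data.bin produced by this module (assumed layout)."""
    with open(filename, "rb") as f:
        header_len = struct.unpack("<I", f.read(4))[0]
        header = json.loads(f.read(header_len).decode("utf-8"))
        blob = zlib.decompress(f.read())
    arrays = {}
    for key, meta in header.items():
        if key == "meta":
            continue  # meta carries playback settings, not an array
        start, length = meta["offset"], meta["length"]
        arrays[key] = np.frombuffer(blob[start:start + length],
                                    dtype=np.dtype(meta["dtype"])).reshape(meta["shape"])
    return header, arrays
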
def process_point_cloud_data(npz_file, width=256, height=192, fps=4):
    fixed_size = (width, height)

    data = np.load(npz_file)
    extrinsics = data["extrinsics"]
    intrinsics = data["intrinsics"]
    trajs = data["coords"]
    T, C, H, W = data["video"].shape

    fx = intrinsics[0, 0, 0]
    fy = intrinsics[0, 1, 1]
    fov_y = 2 * np.arctan(H / (2 * fy)) * (180 / np.pi)
    fov_x = 2 * np.arctan(W / (2 * fx)) * (180 / np.pi)
    original_aspect_ratio = (W / fx) / (H / fy)

    rgb_video = (rearrange(data["video"], "T C H W -> T H W C") * 255).astype(np.uint8)
    rgb_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_AREA)
                          for frame in rgb_video])

    depth_video = data["depths"].astype(np.float32)
    if "confs_depth" in data.keys():
        confs = (data["confs_depth"].astype(np.float32) > 0.5).astype(np.float32)
        depth_video = depth_video * confs
    depth_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_NEAREST)
                            for frame in depth_video])

    scale_x = fixed_size[0] / W
    scale_y = fixed_size[1] / H
    intrinsics = intrinsics.copy()
    intrinsics[:, 0, :] *= scale_x
    intrinsics[:, 1, :] *= scale_y

    min_depth = float(depth_video.min()) * 0.8
    max_depth = float(depth_video.max()) * 1.5

    depth_normalized = (depth_video - min_depth) / (max_depth - min_depth)
    depth_int = (depth_normalized * ((1 << 16) - 1)).astype(np.uint16)

    depths_rgb = np.zeros((T, fixed_size[1], fixed_size[0], 3), dtype=np.uint8)
    depths_rgb[:, :, :, 0] = (depth_int & 0xFF).astype(np.uint8)
    depths_rgb[:, :, :, 1] = ((depth_int >> 8) & 0xFF).astype(np.uint8)

    first_frame_inv = np.linalg.inv(extrinsics[0])
    normalized_extrinsics = np.array([first_frame_inv @ ext for ext in extrinsics])

    normalized_trajs = np.zeros_like(trajs)
    for t in range(T):
        homogeneous_trajs = np.concatenate([trajs[t], np.ones((trajs.shape[1], 1))], axis=1)
        transformed_trajs = (first_frame_inv @ homogeneous_trajs.T).T
        normalized_trajs[t] = transformed_trajs[:, :3]

    arrays = {
        "rgb_video": rgb_video,
        "depths_rgb": depths_rgb,
        "intrinsics": intrinsics,
        "extrinsics": normalized_extrinsics,
        "inv_extrinsics": np.linalg.inv(normalized_extrinsics),
        "trajectories": normalized_trajs.astype(np.float32),
        "cameraZ": 0.0
    }

    header = {}
    blob_parts = []
    offset = 0
    for key, arr in arrays.items():
        arr = np.ascontiguousarray(arr)
        arr_bytes = arr.tobytes()
        header[key] = {
            "dtype": str(arr.dtype),
            "shape": arr.shape,
            "offset": offset,
            "length": len(arr_bytes)
        }
        blob_parts.append(arr_bytes)
        offset += len(arr_bytes)

    raw_blob = b"".join(blob_parts)
    compressed_blob = zlib.compress(raw_blob, level=9)

    header["meta"] = {
        "depthRange": [min_depth, max_depth],
        "totalFrames": int(T),
        "resolution": fixed_size,
        "baseFrameRate": fps,
        "numTrajectoryPoints": normalized_trajs.shape[1],
        "fov": float(fov_y),
        "fov_x": float(fov_x),
        "original_aspect_ratio": float(original_aspect_ratio),
        "fixed_aspect_ratio": float(fixed_size[0]/fixed_size[1])
    }

    compress_and_write('./_viz/data.bin', header, compressed_blob)
    with open('./_viz/data.bin', "rb") as f:
        encoded_blob = base64.b64encode(f.read()).decode("ascii")
    os.unlink('./_viz/data.bin')

    random_path = f'./_viz/_{time.time()}.html'
    with open('./_viz/viz_template.html') as f:
        html_template = f.read()
    html_out = html_template.replace(
        "<head>",
        f"<head>\n<script>window.embeddedBase64 = `{encoded_blob}`;</script>"
    )
    with open(random_path, 'w') as f:
        f.write(html_out)

    return random_path

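# A small sketch (not used by the app) of the 16-bit depth packing performed in
# process_point_cloud_data(): depth is normalized into [min_depth, max_depth], quantized
# to uint16 and split across the R (low byte) and G (high byte) channels; the JavaScript
# viewer reverses the same mapping. The quantization error is at most
# (max_depth - min_depth) / 65535. The function name is illustrative only.
def _depth_roundtrip_example(depth, min_depth, max_depth):
    """Encode a depth map the way the viewer expects, then decode it again."""
    t = (depth - min_depth) / (max_depth - min_depth)
    encoded = (t * 65535).astype(np.uint16)
    lo = (encoded & 0xFF).astype(np.uint8)
    hi = ((encoded >> 8) & 0xFF).astype(np.uint8)
    decoded = (lo.astype(np.uint32) | (hi.astype(np.uint32) << 8)) / 65535.0
    return decoded * (max_depth - min_depth) + min_depth
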
def numpy_to_base64(arr):
    """Convert numpy array to base64 string"""
    return base64.b64encode(arr.tobytes()).decode('utf-8')

def base64_to_numpy(b64_str, shape, dtype):
    """Convert base64 string back to numpy array"""
    return np.frombuffer(base64.b64decode(b64_str), dtype=dtype).reshape(shape)

def get_video_name(video_path):
    """Extract video name without extension"""
    return os.path.splitext(os.path.basename(video_path))[0]

def extract_first_frame(video_path):
    """Extract first frame from video file"""
    try:
        cap = cv2.VideoCapture(video_path)
        ret, frame = cap.read()
        cap.release()

        if ret:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            return frame_rgb
        else:
            return None
    except Exception as e:
        print(f"Error extracting first frame: {e}")
        return None

def handle_video_upload(video):
    """Handle video upload and extract first frame"""
    if video is None:
        return (None, None, [],
                gr.update(value=50),
                gr.update(value=756),
                gr.update(value=3))

    # Create user-specific temporary directory
    user_temp_dir = create_user_temp_dir()

    # Get original video name and copy to temp directory
    if isinstance(video, str):
        video_name = get_video_name(video)
        video_path = os.path.join(user_temp_dir, f"{video_name}.mp4")
        shutil.copy(video, video_path)
    else:
        video_name = get_video_name(video.name)
        video_path = os.path.join(user_temp_dir, f"{video_name}.mp4")
        with open(video_path, 'wb') as f:
            f.write(video.read())

    print(f"📁 Video saved to: {video_path}")

    # Extract first frame
    frame = extract_first_frame(video_path)
    if frame is None:
        return (None, None, [],
                gr.update(value=50),
                gr.update(value=756),
                gr.update(value=3))

    # Resize frame to have minimum side length of 336
    h, w = frame.shape[:2]
    scale = 336 / min(h, w)
    new_h, new_w = int(h * scale)//2*2, int(w * scale)//2*2
    frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

    # Store frame data with temp directory info
    frame_data = {
        'data': numpy_to_base64(frame),
        'shape': frame.shape,
        'dtype': str(frame.dtype),
        'temp_dir': user_temp_dir,
        'video_name': video_name,
        'video_path': video_path
    }

    # Get video-specific settings
    print(f"🎬 Video path: '{video}' -> Video name: '{video_name}'")
    grid_size_val, vo_points_val, fps_val = get_video_settings(video_name)
    print(f"🎬 Video settings for '{video_name}': grid_size={grid_size_val}, vo_points={vo_points_val}, fps={fps_val}")

    return (json.dumps(frame_data), frame, [],
            gr.update(value=grid_size_val),
            gr.update(value=vo_points_val),
            gr.update(value=fps_val))

def save_masks(o_masks, video_name, temp_dir):
    """Save binary masks to files in user-specific temp directory"""
    o_files = []
    for mask, _ in o_masks:
        o_mask = np.uint8(mask.squeeze() * 255)
        o_file = os.path.join(temp_dir, f"{video_name}.png")
        cv2.imwrite(o_file, o_mask)
        o_files.append(o_file)
    return o_files

def select_point(original_img: str, sel_pix: list, point_type: str, evt: gr.SelectData):
    """Handle point selection for SAM"""
    if original_img is None:
        return None, []

    try:
        # Convert stored image data back to numpy array
        frame_data = json.loads(original_img)
        original_img_array = base64_to_numpy(frame_data['data'], frame_data['shape'], frame_data['dtype'])
        temp_dir = frame_data.get('temp_dir', 'temp_local')
        video_name = frame_data.get('video_name', 'video')

        # Create a display image for visualization
        display_img = original_img_array.copy()
        new_sel_pix = sel_pix.copy() if sel_pix else []
        new_sel_pix.append((evt.index, 1 if point_type == 'positive_point' else 0))

        print(f"🎯 Running SAM inference for point: {evt.index}, type: {point_type}")
        # Run SAM inference
        o_masks = gpu_run_inference(None, original_img_array, new_sel_pix, [])

        # Draw points on display image
        for point, label in new_sel_pix:
            cv2.drawMarker(display_img, point, COLORS[label], markerType=MARKERS[label], markerSize=MARKER_SIZE, thickness=2)

        # Draw mask overlay on display image
        if o_masks:
            mask = o_masks[0][0]
            overlay = display_img.copy()
            overlay[mask.squeeze()!=0] = [20, 60, 200]  # Light blue
            display_img = cv2.addWeighted(overlay, 0.6, display_img, 0.4, 0)

            # Save mask for tracking
            save_masks(o_masks, video_name, temp_dir)
            print(f"✅ Mask saved for video: {video_name}")

        return display_img, new_sel_pix

    except Exception as e:
        print(f"❌ Error in select_point: {e}")
        return None, []

def reset_points(original_img: str, sel_pix):
    """Reset all points and clear the mask"""
    if original_img is None:
        return None, []

    try:
        # Convert stored image data back to numpy array
        frame_data = json.loads(original_img)
        original_img_array = base64_to_numpy(frame_data['data'], frame_data['shape'], frame_data['dtype'])
        temp_dir = frame_data.get('temp_dir', 'temp_local')

        # Create a display image (just the original image)
        display_img = original_img_array.copy()

        # Clear all points
        new_sel_pix = []

        # Clear any existing masks
        for mask_file in glob.glob(os.path.join(temp_dir, "*.png")):
            try:
                os.remove(mask_file)
            except Exception as e:
                logger.warning(f"Failed to remove mask file {mask_file}: {e}")

        print("🔄 Points and masks reset")
        return display_img, new_sel_pix

    except Exception as e:
        print(f"❌ Error in reset_points: {e}")
        return None, []

def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
    """Launch visualization with user-specific temp directory"""
    if original_image_state is None:
        return None, None, None

    try:
        # Get user's temp directory from stored frame data
        frame_data = json.loads(original_image_state)
        temp_dir = frame_data.get('temp_dir', 'temp_local')
        video_name = frame_data.get('video_name', 'video')

        print(f"🚀 Starting tracking for video: {video_name}")
        print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}")

        # Check for mask files
        mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
        video_files = glob.glob(os.path.join(temp_dir, "*.mp4"))

        if not video_files:
            print("❌ No video file found")
            return "❌ Error: No video file found", None, None

        video_path = video_files[0]
        mask_path = mask_files[0] if mask_files else None

        # Run tracker
        print("🎯 Running tracker...")
        out_dir = os.path.join(temp_dir, "results")
        os.makedirs(out_dir, exist_ok=True)

        gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=mode)

        # Process results
        npz_path = os.path.join(out_dir, "result.npz")
        track2d_video = os.path.join(out_dir, "test_pred_track.mp4")

        if os.path.exists(npz_path):
            print("📊 Processing 3D visualization...")
            html_path = process_point_cloud_data(npz_path)

            # Schedule deletion of generated files
            delete_later(html_path, delay=600)
            if os.path.exists(track2d_video):
                delete_later(track2d_video, delay=600)
            delete_later(npz_path, delay=600)

            # Create iframe HTML
            iframe_html = f"""
            <div style='border: 3px solid #667eea; border-radius: 10px;
                        background: #f8f9ff; height: 650px; width: 100%;
                        box-shadow: 0 8px 32px rgba(102, 126, 234, 0.3);
                        margin: 0; padding: 0; box-sizing: border-box; overflow: hidden;'>
                <iframe id="viz_iframe" src="/gradio_api/file={html_path}"
                        width="100%" height="650" frameborder="0"
                        style="border: none; display: block; width: 100%; height: 650px;
                               margin: 0; padding: 0; border-radius: 7px;">
                </iframe>
            </div>
            """

            print("✅ Tracking completed successfully!")
            return iframe_html, track2d_video if os.path.exists(track2d_video) else None, html_path
        else:
            print("❌ Tracking failed - no results generated")
            return "❌ Error: Tracking failed to generate results", None, None

    except Exception as e:
        print(f"❌ Error in launch_viz: {e}")
        return f"❌ Error: {str(e)}", None, None

def clear_all():
    """Clear all buffers and temporary files"""
    return (None, None, [],
            gr.update(value=50),
            gr.update(value=756),
            gr.update(value=3))

def clear_all_with_download():
    """Clear all buffers including both download components"""
    return (None, None, [],
            gr.update(value=50),
            gr.update(value=756),
            gr.update(value=3),
            None,  # tracking_video_download
            None)  # HTML download component

def get_video_settings(video_name):
    """Get video-specific settings based on video name"""
    video_settings = {
        "running": (50, 512, 2),
        "backpack": (40, 600, 2),
        "kitchen": (60, 800, 3),
        "pillow": (35, 500, 2),
        "handwave": (35, 500, 8),
        "hockey": (45, 700, 2),
        "drifting": (35, 1000, 6),
        "basketball": (45, 1500, 5),
        "ken_block_0": (45, 700, 2),
        "ego_kc1": (45, 500, 4),
        "vertical_place": (45, 500, 3),
        "ego_teaser": (45, 1200, 10),
        "robot_unitree": (45, 500, 4),
        "robot_3": (35, 400, 5),
        "teleop2": (45, 256, 7),
        "pusht": (45, 256, 10),
        "cinema_0": (45, 356, 5),
        "cinema_1": (45, 756, 3),
        "robot1": (45, 600, 2),
        "robot2": (45, 600, 2),
        "protein": (45, 600, 2),
        "kitchen_egocentric": (45, 600, 2),
    }

    return video_settings.get(video_name, (50, 756, 3))

648 |
+
# Create the Gradio interface
|
649 |
+
print("🎨 Creating Gradio interface...")
|
650 |
+
|
651 |
+
with gr.Blocks(
|
652 |
+
theme=gr.themes.Soft(),
|
653 |
+
title="🎯 [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)",
|
654 |
+
css="""
|
655 |
+
.gradio-container {
|
656 |
+
max-width: 1200px !important;
|
657 |
+
margin: auto !important;
|
658 |
+
}
|
659 |
+
.gr-button {
|
660 |
+
margin: 5px;
|
661 |
+
}
|
662 |
+
.gr-form {
|
663 |
+
background: white;
|
664 |
+
border-radius: 10px;
|
665 |
+
padding: 20px;
|
666 |
+
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
667 |
+
}
|
668 |
+
/* 移除 gr.Group 的默认灰色背景 */
|
669 |
+
.gr-form {
|
670 |
+
background: transparent !important;
|
671 |
+
border: none !important;
|
672 |
+
box-shadow: none !important;
|
673 |
+
padding: 0 !important;
|
674 |
+
}
|
675 |
+
/* 固定3D可视化器尺寸 */
|
676 |
+
#viz_container {
|
677 |
+
height: 650px !important;
|
678 |
+
min-height: 650px !important;
|
679 |
+
max-height: 650px !important;
|
680 |
+
width: 100% !important;
|
681 |
+
margin: 0 !important;
|
682 |
+
padding: 0 !important;
|
683 |
+
overflow: hidden !important;
|
684 |
+
}
|
685 |
+
#viz_container > div {
|
686 |
+
height: 650px !important;
|
687 |
+
min-height: 650px !important;
|
688 |
+
max-height: 650px !important;
|
689 |
+
width: 100% !important;
|
690 |
+
margin: 0 !important;
|
691 |
+
padding: 0 !important;
|
692 |
+
box-sizing: border-box !important;
|
693 |
+
}
|
694 |
+
#viz_container iframe {
|
695 |
+
height: 650px !important;
|
696 |
+
min-height: 650px !important;
|
697 |
+
max-height: 650px !important;
|
698 |
+
width: 100% !important;
|
699 |
+
border: none !important;
|
700 |
+
display: block !important;
|
701 |
+
margin: 0 !important;
|
702 |
+
padding: 0 !important;
|
703 |
+
box-sizing: border-box !important;
|
704 |
+
}
|
705 |
+
/* 固定视频上传组件高度 */
|
706 |
+
.gr-video {
|
707 |
+
height: 300px !important;
|
708 |
+
min-height: 300px !important;
|
709 |
+
max-height: 300px !important;
|
710 |
+
}
|
711 |
+
.gr-video video {
|
712 |
+
height: 260px !important;
|
713 |
+
max-height: 260px !important;
|
714 |
+
object-fit: contain !important;
|
715 |
+
background: #f8f9fa;
|
716 |
+
}
|
717 |
+
.gr-video .gr-video-player {
|
718 |
+
height: 260px !important;
|
719 |
+
max-height: 260px !important;
|
720 |
+
}
|
721 |
+
/* 强力移除examples的灰色背景 - 使用更通用的选择器 */
|
722 |
+
.horizontal-examples,
|
723 |
+
.horizontal-examples > *,
|
724 |
+
.horizontal-examples * {
|
725 |
+
background: transparent !important;
|
726 |
+
background-color: transparent !important;
|
727 |
+
border: none !important;
|
728 |
+
}
|
729 |
+
|
730 |
+
/* Examples组件水平滚动样式 */
|
731 |
+
.horizontal-examples [data-testid="examples"] {
|
732 |
+
background: transparent !important;
|
733 |
+
background-color: transparent !important;
|
734 |
+
}
|
735 |
+
|
736 |
+
.horizontal-examples [data-testid="examples"] > div {
|
737 |
+
background: transparent !important;
|
738 |
+
background-color: transparent !important;
|
739 |
+
overflow-x: auto !important;
|
740 |
+
overflow-y: hidden !important;
|
741 |
+
scrollbar-width: thin;
|
742 |
+
scrollbar-color: #667eea transparent;
|
743 |
+
padding: 0 !important;
|
744 |
+
margin-top: 10px;
|
745 |
+
border: none !important;
|
746 |
+
}
|
747 |
+
|
748 |
+
.horizontal-examples [data-testid="examples"] table {
|
749 |
+
display: flex !important;
|
750 |
+
flex-wrap: nowrap !important;
|
751 |
+
min-width: max-content !important;
|
752 |
+
gap: 15px !important;
|
753 |
+
padding: 10px 0;
|
754 |
+
background: transparent !important;
|
755 |
+
border: none !important;
|
756 |
+
}
|
757 |
+
|
758 |
+
.horizontal-examples [data-testid="examples"] tbody {
|
759 |
+
display: flex !important;
|
760 |
+
flex-direction: row !important;
|
761 |
+
flex-wrap: nowrap !important;
|
762 |
+
gap: 15px !important;
|
763 |
+
background: transparent !important;
|
764 |
+
}
|
765 |
+
|
766 |
+
.horizontal-examples [data-testid="examples"] tr {
|
767 |
+
display: flex !important;
|
768 |
+
flex-direction: column !important;
|
769 |
+
min-width: 160px !important;
|
770 |
+
max-width: 160px !important;
|
771 |
+
margin: 0 !important;
|
772 |
+
background: white !important;
|
773 |
+
border-radius: 12px;
|
774 |
+
box-shadow: 0 3px 12px rgba(0,0,0,0.12);
|
775 |
+
transition: all 0.3s ease;
|
776 |
+
cursor: pointer;
|
777 |
+
overflow: hidden;
|
778 |
+
border: none !important;
|
779 |
+
}
|
780 |
+
|
781 |
+
.horizontal-examples [data-testid="examples"] tr:hover {
|
782 |
+
transform: translateY(-4px);
|
783 |
+
box-shadow: 0 8px 20px rgba(102, 126, 234, 0.25);
|
784 |
+
}
|
785 |
+
|
786 |
+
.horizontal-examples [data-testid="examples"] td {
|
787 |
+
text-align: center !important;
|
788 |
+
padding: 0 !important;
|
789 |
+
border: none !important;
|
790 |
+
background: transparent !important;
|
791 |
+
}
|
792 |
+
|
793 |
+
.horizontal-examples [data-testid="examples"] td:first-child {
|
794 |
+
padding: 0 !important;
|
795 |
+
background: transparent !important;
|
796 |
+
}
|
797 |
+
|
798 |
+
.horizontal-examples [data-testid="examples"] video {
|
799 |
+
border-radius: 8px 8px 0 0 !important;
|
800 |
+
width: 100% !important;
|
801 |
+
height: 90px !important;
|
802 |
+
object-fit: cover !important;
|
803 |
+
background: #f8f9fa !important;
|
804 |
+
}
|
805 |
+
|
806 |
+
.horizontal-examples [data-testid="examples"] td:last-child {
|
807 |
+
font-size: 11px !important;
|
808 |
+
font-weight: 600 !important;
|
809 |
+
color: #333 !important;
|
810 |
+
padding: 8px 12px !important;
|
811 |
+
background: linear-gradient(135deg, #f8f9ff 0%, #e6f3ff 100%) !important;
|
812 |
+
border-radius: 0 0 8px 8px;
|
813 |
+
}
|
814 |
+
|
815 |
+
/* Scrollbar styling */
|
816 |
+
.horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar {
|
817 |
+
height: 8px;
|
818 |
+
}
|
819 |
+
.horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-track {
|
820 |
+
background: transparent;
|
821 |
+
border-radius: 4px;
|
822 |
+
}
|
823 |
+
.horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-thumb {
|
824 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
825 |
+
border-radius: 4px;
|
826 |
+
}
|
827 |
+
.horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-thumb:hover {
|
828 |
+
background: linear-gradient(135deg, #5a6fd8 0%, #6a4190 100%);
|
829 |
+
}
|
830 |
+
"""
|
831 |
+
) as demo:
|
832 |
+
|
833 |
+
# Add prominent main title
|
834 |
+
|
835 |
+
gr.Markdown("""
|
836 |
+
# ✨ SpatialTrackerV2
|
837 |
+
|
838 |
+
Welcome to [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)! This interface allows you to track any pixels in 3D using our model.
|
839 |
+
|
840 |
+
**⚡ Quick Start:** Upload video → Click "Start Tracking Now!"
|
841 |
+
|
842 |
+
**🔬 Advanced Usage with SAM:**
|
843 |
+
1. Upload a video file or select from examples below
|
844 |
+
2. Expand "Manual Point Selection" to click on specific objects for SAM-guided tracking
|
845 |
+
3. Adjust tracking parameters for optimal performance
|
846 |
+
4. Click "Start Tracking Now!" to begin 3D tracking with SAM guidance
|
847 |
+
|
848 |
+
""")
|
849 |
+
|
850 |
+
# Status indicator
|
851 |
+
gr.Markdown("**Status:** 🟢 Local Processing Mode")
|
852 |
+
|
853 |
+
# Main content area - video upload left, 3D visualization right
|
854 |
+
with gr.Row():
|
855 |
+
with gr.Column(scale=1):
|
856 |
+
# Video upload section
|
857 |
+
gr.Markdown("### 📂 Select Video")
|
858 |
+
|
859 |
+
# Define video_input here so it can be referenced in examples
|
860 |
+
video_input = gr.Video(
|
861 |
+
label="Upload Video or Select Example",
|
862 |
+
format="mp4",
|
863 |
+
height=250 # Matched height with 3D viz
|
864 |
+
)
|
865 |
+
|
866 |
+
|
867 |
+
# Traditional examples but with horizontal scroll styling
|
868 |
+
gr.Markdown("🎨**Examples:** (scroll horizontally to see all videos)")
|
869 |
+
with gr.Row(elem_classes=["horizontal-examples"]):
|
870 |
+
# Horizontal video examples with slider
|
871 |
+
# gr.HTML("<div style='margin-top: 5px;'></div>")
|
872 |
+
gr.Examples(
|
873 |
+
examples=[
|
874 |
+
["./examples/robot1.mp4"],
|
875 |
+
["./examples/robot2.mp4"],
|
876 |
+
["./examples/protein.mp4"],
|
877 |
+
["./examples/kitchen_egocentric.mp4"],
|
878 |
+
["./examples/hockey.mp4"],
|
879 |
+
["./examples/running.mp4"],
|
880 |
+
["./examples/robot_3.mp4"],
|
881 |
+
["./examples/backpack.mp4"],
|
882 |
+
["./examples/kitchen.mp4"],
|
883 |
+
["./examples/pillow.mp4"],
|
884 |
+
["./examples/handwave.mp4"],
|
885 |
+
["./examples/drifting.mp4"],
|
886 |
+
["./examples/basketball.mp4"],
|
887 |
+
["./examples/ken_block_0.mp4"],
|
888 |
+
["./examples/ego_kc1.mp4"],
|
889 |
+
["./examples/vertical_place.mp4"],
|
890 |
+
["./examples/ego_teaser.mp4"],
|
891 |
+
["./examples/robot_unitree.mp4"],
|
892 |
+
["./examples/teleop2.mp4"],
|
893 |
+
["./examples/pusht.mp4"],
|
894 |
+
["./examples/cinema_0.mp4"],
|
895 |
+
["./examples/cinema_1.mp4"],
|
896 |
+
],
|
897 |
+
inputs=[video_input],
|
898 |
+
outputs=[video_input],
|
899 |
+
fn=None,
|
900 |
+
cache_examples=False,
|
901 |
+
label="",
|
902 |
+
examples_per_page=6 # Show 6 examples per page so they can wrap to multiple rows
|
903 |
+
)
|
904 |
+
|
905 |
+
with gr.Column(scale=2):
|
906 |
+
# 3D Visualization - wider and taller to match left side
|
907 |
+
with gr.Group():
|
908 |
+
gr.Markdown("### 🌐 3D Trajectory Visualization")
|
909 |
+
viz_html = gr.HTML(
|
910 |
+
label="3D Trajectory Visualization",
|
911 |
+
value="""
|
912 |
+
<div style='border: 3px solid #667eea; border-radius: 10px;
|
913 |
+
background: linear-gradient(135deg, #f8f9ff 0%, #e6f3ff 100%);
|
914 |
+
text-align: center; height: 650px; display: flex;
|
915 |
+
flex-direction: column; justify-content: center; align-items: center;
|
916 |
+
box-shadow: 0 4px 16px rgba(102, 126, 234, 0.15);
|
917 |
+
margin: 0; padding: 20px; box-sizing: border-box;'>
|
918 |
+
<div style='font-size: 56px; margin-bottom: 25px;'>🌐</div>
|
919 |
+
<h3 style='color: #667eea; margin-bottom: 18px; font-size: 28px; font-weight: 600;'>
|
920 |
+
3D Trajectory Visualization
|
921 |
+
</h3>
|
922 |
+
<p style='color: #666; font-size: 18px; line-height: 1.6; max-width: 550px; margin-bottom: 30px;'>
|
923 |
+
Track any pixels in 3D space with camera motion
|
924 |
+
</p>
|
925 |
+
<div style='background: rgba(102, 126, 234, 0.1); border-radius: 30px;
|
926 |
+
padding: 15px 30px; border: 1px solid rgba(102, 126, 234, 0.2);'>
|
927 |
+
<span style='color: #667eea; font-weight: 600; font-size: 16px;'>
|
928 |
+
⚡ Powered by SpatialTracker V2
|
929 |
+
</span>
|
930 |
+
</div>
|
931 |
+
</div>
|
932 |
+
""",
|
933 |
+
elem_id="viz_container"
|
934 |
+
)
|
935 |
+
|
936 |
+
# Start button section - below video area
|
937 |
+
with gr.Row():
|
938 |
+
with gr.Column(scale=3):
|
939 |
+
launch_btn = gr.Button("🚀 Start Tracking Now!", variant="primary", size="lg")
|
940 |
+
with gr.Column(scale=1):
|
941 |
+
clear_all_btn = gr.Button("🗑️ Clear All", variant="secondary", size="sm")
|
942 |
+
|
943 |
+
# Tracking parameters section
|
944 |
+
with gr.Row():
|
945 |
+
gr.Markdown("### ⚙️ Tracking Parameters")
|
946 |
+
with gr.Row():
|
947 |
+
grid_size = gr.Slider(
|
948 |
+
minimum=10, maximum=100, step=10, value=50,
|
949 |
+
label="Grid Size", info="Tracking detail level"
|
950 |
+
)
|
951 |
+
vo_points = gr.Slider(
|
952 |
+
minimum=100, maximum=2000, step=50, value=756,
|
953 |
+
label="VO Points", info="Motion accuracy"
|
954 |
+
)
|
955 |
+
fps = gr.Slider(
|
956 |
+
minimum=1, maximum=20, step=1, value=3,
|
957 |
+
label="FPS", info="Processing speed"
|
958 |
+
)
|
959 |
+
|
960 |
+
# Advanced Point Selection with SAM - Collapsed by default
|
961 |
+
with gr.Row():
|
962 |
+
gr.Markdown("### 🎯 Advanced: Manual Point Selection with SAM")
|
963 |
+
with gr.Accordion("🔬 SAM Point Selection Controls", open=False):
|
964 |
+
gr.HTML("""
|
965 |
+
<div style='margin-bottom: 15px;'>
|
966 |
+
<ul style='color: #4a5568; font-size: 14px; line-height: 1.6; margin: 0; padding-left: 20px;'>
|
967 |
+
<li>Click on target objects in the image for SAM-guided segmentation</li>
|
968 |
+
<li>Positive points: include these areas | Negative points: exclude these areas</li>
|
969 |
+
<li>Get more accurate 3D tracking results with SAM's powerful segmentation</li>
|
970 |
+
</ul>
|
971 |
+
</div>
|
972 |
+
""")
|
973 |
+
|
974 |
+
with gr.Row():
|
975 |
+
with gr.Column():
|
976 |
+
interactive_frame = gr.Image(
|
977 |
+
label="Click to select tracking points with SAM guidance",
|
978 |
+
type="numpy",
|
979 |
+
interactive=True,
|
980 |
+
height=300
|
981 |
+
)
|
982 |
+
|
983 |
+
with gr.Row():
|
984 |
+
point_type = gr.Radio(
|
985 |
+
choices=["positive_point", "negative_point"],
|
986 |
+
value="positive_point",
|
987 |
+
label="Point Type",
|
988 |
+
info="Positive: track these areas | Negative: avoid these areas"
|
989 |
+
)
|
990 |
+
|
991 |
+
with gr.Row():
|
992 |
+
reset_points_btn = gr.Button("🔄 Reset Points", variant="secondary", size="sm")
|
993 |
+
|
994 |
+
# Downloads section - hidden but still functional for local processing
|
995 |
+
with gr.Row(visible=False):
|
996 |
+
with gr.Column(scale=1):
|
997 |
+
tracking_video_download = gr.File(
|
998 |
+
label="📹 Download 2D Tracking Video",
|
999 |
+
interactive=False,
|
1000 |
+
visible=False
|
1001 |
+
)
|
1002 |
+
with gr.Column(scale=1):
|
1003 |
+
html_download = gr.File(
|
1004 |
+
label="📄 Download 3D Visualization HTML",
|
1005 |
+
interactive=False,
|
1006 |
+
visible=False
|
1007 |
+
)
|
1008 |
+
|
1009 |
+
# GitHub Star Section
|
1010 |
+
gr.HTML("""
|
1011 |
+
<div style='background: linear-gradient(135deg, #e8eaff 0%, #f0f2ff 100%);
|
1012 |
+
border-radius: 8px; padding: 20px; margin: 15px 0;
|
1013 |
+
box-shadow: 0 2px 8px rgba(102, 126, 234, 0.1);
|
1014 |
+
border: 1px solid rgba(102, 126, 234, 0.15);'>
|
1015 |
+
<div style='text-align: center;'>
|
1016 |
+
<h3 style='color: #4a5568; margin: 0 0 10px 0; font-size: 18px; font-weight: 600;'>
|
1017 |
+
⭐ Love SpatialTracker? Give us a Star! ⭐
|
1018 |
+
</h3>
|
1019 |
+
<p style='color: #666; margin: 0 0 15px 0; font-size: 14px; line-height: 1.5;'>
|
1020 |
+
Help us grow by starring our repository on GitHub! Your support means a lot to the community. 🚀
|
1021 |
+
</p>
|
1022 |
+
<a href="https://github.com/henry123-boy/SpaTrackerV2" target="_blank"
|
1023 |
+
style='display: inline-flex; align-items: center; gap: 8px;
|
1024 |
+
background: rgba(102, 126, 234, 0.1); color: #4a5568;
|
1025 |
+
padding: 10px 20px; border-radius: 25px; text-decoration: none;
|
1026 |
+
font-weight: bold; font-size: 14px; border: 1px solid rgba(102, 126, 234, 0.2);
|
1027 |
+
transition: all 0.3s ease;'
|
1028 |
+
onmouseover="this.style.background='rgba(102, 126, 234, 0.15)'; this.style.transform='translateY(-2px)'"
|
1029 |
+
onmouseout="this.style.background='rgba(102, 126, 234, 0.1)'; this.style.transform='translateY(0)'">
|
1030 |
+
<span style='font-size: 16px;'>⭐</span>
|
1031 |
+
Star SpatialTracker V2 on GitHub
|
1032 |
+
</a>
|
1033 |
+
</div>
|
1034 |
+
</div>
|
1035 |
+
""")
|
1036 |
+
|
1037 |
+
# Acknowledgments Section
|
1038 |
+
gr.HTML("""
|
1039 |
+
<div style='background: linear-gradient(135deg, #fff8e1 0%, #fffbf0 100%);
|
1040 |
+
border-radius: 8px; padding: 20px; margin: 15px 0;
|
1041 |
+
box-shadow: 0 2px 8px rgba(255, 193, 7, 0.1);
|
1042 |
+
border: 1px solid rgba(255, 193, 7, 0.2);'>
|
1043 |
+
<div style='text-align: center;'>
|
1044 |
+
<h3 style='color: #5d4037; margin: 0 0 10px 0; font-size: 18px; font-weight: 600;'>
|
1045 |
+
📚 Acknowledgments
|
1046 |
+
</h3>
|
1047 |
+
<p style='color: #5d4037; margin: 0 0 15px 0; font-size: 14px; line-height: 1.5;'>
|
1048 |
+
Our 3D visualizer is adapted from <strong>TAPIP3D</strong>. We thank the authors for their excellent work and contribution to the computer vision community!
|
1049 |
+
</p>
|
1050 |
+
<a href="https://github.com/zbw001/TAPIP3D" target="_blank"
|
1051 |
+
style='display: inline-flex; align-items: center; gap: 8px;
|
1052 |
+
background: rgba(255, 193, 7, 0.15); color: #5d4037;
|
1053 |
+
padding: 10px 20px; border-radius: 25px; text-decoration: none;
|
1054 |
+
font-weight: bold; font-size: 14px; border: 1px solid rgba(255, 193, 7, 0.3);
|
1055 |
+
transition: all 0.3s ease;'
|
1056 |
+
onmouseover="this.style.background='rgba(255, 193, 7, 0.25)'; this.style.transform='translateY(-2px)'"
|
1057 |
+
onmouseout="this.style.background='rgba(255, 193, 7, 0.15)'; this.style.transform='translateY(0)'">
|
1058 |
+
📚 Visit TAPIP3D Repository
|
1059 |
+
</a>
|
1060 |
+
</div>
|
1061 |
+
</div>
|
1062 |
+
""")
|
1063 |
+
|
1064 |
+
# Footer
|
1065 |
+
gr.HTML("""
|
1066 |
+
<div style='text-align: center; margin: 20px 0 10px 0;'>
|
1067 |
+
<span style='font-size: 12px; color: #888; font-style: italic;'>
|
1068 |
+
Powered by SpatialTracker V2 | Built with ❤️ for the Computer Vision Community
|
1069 |
+
</span>
|
1070 |
+
</div>
|
1071 |
+
""")
|
1072 |
+
|
1073 |
+
# Hidden state variables
|
1074 |
+
original_image_state = gr.State(None)
|
1075 |
+
selected_points = gr.State([])
|
1076 |
+
|
1077 |
+
# Event handlers
|
1078 |
+
video_input.change(
|
1079 |
+
fn=handle_video_upload,
|
1080 |
+
inputs=[video_input],
|
1081 |
+
outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
|
1082 |
+
)
|
1083 |
+
|
1084 |
+
interactive_frame.select(
|
1085 |
+
fn=select_point,
|
1086 |
+
inputs=[original_image_state, selected_points, point_type],
|
1087 |
+
outputs=[interactive_frame, selected_points]
|
1088 |
+
)
|
1089 |
+
|
1090 |
+
reset_points_btn.click(
|
1091 |
+
fn=reset_points,
|
1092 |
+
inputs=[original_image_state, selected_points],
|
1093 |
+
outputs=[interactive_frame, selected_points]
|
1094 |
+
)
|
1095 |
+
|
1096 |
+
clear_all_btn.click(
|
1097 |
+
fn=clear_all_with_download,
|
1098 |
+
outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, tracking_video_download, html_download]
|
1099 |
+
)
|
1100 |
+
|
1101 |
+
launch_btn.click(
|
1102 |
+
fn=launch_viz,
|
1103 |
+
inputs=[grid_size, vo_points, fps, original_image_state],
|
1104 |
+
outputs=[viz_html, tracking_video_download, html_download]
|
1105 |
+
)
|
1106 |
+
|
1107 |
+
# Launch the interface
|
1108 |
+
if __name__ == "__main__":
|
1109 |
+
print("🌟 Launching SpatialTracker V2 Local Version...")
|
1110 |
+
print("🔗 Running in Local Processing Mode")
|
1111 |
+
|
1112 |
+
demo.launch(
|
1113 |
+
server_name="0.0.0.0",
|
1114 |
+
server_port=7860,
|
1115 |
+
share=True,
|
1116 |
+
debug=True,
|
1117 |
+
show_error=True
|
1118 |
+
)
|
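A note on the fixed-height styling above: the `#viz_container iframe` rules assume the 3D result is shown inside an iframe placed in the `gr.HTML` output. A minimal sketch of one way such a result page could be wrapped for that slot (the helper name and file path are hypothetical; the actual return format of `launch_viz` is defined elsewhere in `app.py`):

```python
import html

def wrap_viz_html(html_path: str, height: int = 650) -> str:
    """Embed a generated visualization HTML file in an iframe so it fills
    the fixed-height #viz_container element styled above."""
    with open(html_path, "r", encoding="utf-8") as f:
        content = f.read()
    # srcdoc keeps the viewer self-contained inside the gr.HTML component
    escaped = html.escape(content, quote=True)
    return (f'<iframe srcdoc="{escaped}" '
            f'style="width:100%; height:{height}px; border:none;"></iframe>')
```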
app_3rd/README.md
ADDED
@@ -0,0 +1,12 @@
1 |
+
# 🌟 SpatialTrackerV2 Integrated with SAM 🌟
|
2 |
+
SAM receives a point prompt and generates a mask for the target object, facilitating easy interaction to obtain the object's 3D trajectories with SpaTrack2.
|
3 |
+
|
4 |
+
## Installation
|
5 |
+
```
|
6 |
+
|
7 |
+
python -m pip install git+https://github.com/facebookresearch/segment-anything.git
|
8 |
+
cd app_3rd/sam_utils
|
9 |
+
mkdir checkpoints
|
10 |
+
cd checkpoints
|
11 |
+
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
12 |
+
```
|
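For reference, a minimal sketch of using the checkpoint downloaded above with the original `segment_anything` API (the frame path and click coordinates are placeholders):

```python
import cv2
import numpy as np
from segment_anything import SamPredictor, sam_model_registry

# Load the ViT-H checkpoint fetched by the install steps above
sam = sam_model_registry["vit_h"](
    checkpoint="app_3rd/sam_utils/checkpoints/sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam.to("cuda"))

# One positive click (label 1) on the target object
frame = cv2.cvtColor(cv2.imread("frame.png"), cv2.COLOR_BGR2RGB)  # placeholder frame
predictor.set_image(frame)
masks, scores, _ = predictor.predict(
    point_coords=np.array([[320, 240]]),
    point_labels=np.array([1]),
    multimask_output=False,
)
print(masks.shape)  # (1, H, W) boolean mask for the prompted object
```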
app_3rd/sam_utils/hf_sam_predictor.py
ADDED
@@ -0,0 +1,129 @@
1 |
+
import gc
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from typing import Optional, Tuple, List, Union
|
5 |
+
import warnings
|
6 |
+
import cv2
|
7 |
+
try:
|
8 |
+
from transformers import SamModel, SamProcessor
|
9 |
+
from huggingface_hub import hf_hub_download
|
10 |
+
HF_AVAILABLE = True
|
11 |
+
except ImportError:
|
12 |
+
HF_AVAILABLE = False
|
13 |
+
warnings.warn("transformers or huggingface_hub not available. HF SAM models will not work.")
|
14 |
+
|
15 |
+
# Hugging Face model mapping
|
16 |
+
HF_MODELS = {
|
17 |
+
'vit_b': 'facebook/sam-vit-base',
|
18 |
+
'vit_l': 'facebook/sam-vit-large',
|
19 |
+
'vit_h': 'facebook/sam-vit-huge'
|
20 |
+
}
|
21 |
+
|
22 |
+
class HFSamPredictor:
|
23 |
+
"""
|
24 |
+
Hugging Face version of SamPredictor that wraps the transformers SAM models.
|
25 |
+
This class provides the same interface as the original SamPredictor for seamless integration.
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self, model: SamModel, processor: SamProcessor, device: Optional[str] = None):
|
29 |
+
"""
|
30 |
+
Initialize the HF SAM predictor.
|
31 |
+
|
32 |
+
Args:
|
33 |
+
model: The SAM model from transformers
|
34 |
+
processor: The SAM processor from transformers
|
35 |
+
device: Device to run the model on ('cuda', 'cpu', etc.)
|
36 |
+
"""
|
37 |
+
self.model = model
|
38 |
+
self.processor = processor
|
39 |
+
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
|
40 |
+
self.model.to(self.device)
|
41 |
+
self.model.eval()
|
42 |
+
|
43 |
+
# Store the current image and its features
|
44 |
+
self.original_size = None
|
45 |
+
self.input_size = None
|
46 |
+
self.features = None
|
47 |
+
self.image = None
|
48 |
+
|
49 |
+
@classmethod
|
50 |
+
def from_pretrained(cls, model_name: str, device: Optional[str] = None) -> 'HFSamPredictor':
|
51 |
+
"""
|
52 |
+
Load a SAM model from Hugging Face Hub.
|
53 |
+
|
54 |
+
Args:
|
55 |
+
model_name: Model name from HF_MODELS or direct HF model path
|
56 |
+
device: Device to load the model on
|
57 |
+
|
58 |
+
Returns:
|
59 |
+
HFSamPredictor instance
|
60 |
+
"""
|
61 |
+
if not HF_AVAILABLE:
|
62 |
+
raise ImportError("transformers and huggingface_hub are required for HF SAM models")
|
63 |
+
|
64 |
+
# Map model type to HF model name if needed
|
65 |
+
if model_name in HF_MODELS:
|
66 |
+
model_name = HF_MODELS[model_name]
|
67 |
+
|
68 |
+
print(f"Loading SAM model from Hugging Face: {model_name}")
|
69 |
+
|
70 |
+
# Load model and processor
|
71 |
+
model = SamModel.from_pretrained(model_name)
|
72 |
+
processor = SamProcessor.from_pretrained(model_name)
|
73 |
+
return cls(model, processor, device)
|
74 |
+
|
75 |
+
def preprocess(self, image: np.ndarray,
|
76 |
+
input_points: List[List[float]], input_labels: List[int]) -> None:
|
77 |
+
"""
|
78 |
+
Set the image for prediction. This preprocesses the image and extracts features.
|
79 |
+
|
80 |
+
Args:
|
81 |
+
image: Input image as numpy array (H, W, C) in RGB format
|
82 |
+
"""
|
83 |
+
if image.dtype != np.uint8:
|
84 |
+
image = (image * 255).astype(np.uint8)
|
85 |
+
|
86 |
+
self.image = image
|
87 |
+
self.original_size = image.shape[:2]
|
88 |
+
|
89 |
+
# Use dummy point to ensure processor returns original_sizes & reshaped_input_sizes
|
90 |
+
inputs = self.processor(
|
91 |
+
images=image,
|
92 |
+
input_points=input_points,
|
93 |
+
input_labels=input_labels,
|
94 |
+
return_tensors="pt"
|
95 |
+
)
|
96 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
97 |
+
|
98 |
+
self.input_size = inputs['pixel_values'].shape[-2:]
|
99 |
+
self.features = inputs
|
100 |
+
return inputs
|
101 |
+
|
102 |
+
|
103 |
+
def get_hf_sam_predictor(model_type: str = 'vit_h', device: Optional[str] = None,
|
104 |
+
image: Optional[np.ndarray] = None) -> HFSamPredictor:
|
105 |
+
"""
|
106 |
+
Get a Hugging Face SAM predictor with the same interface as the original get_sam_predictor.
|
107 |
+
|
108 |
+
Args:
|
109 |
+
model_type: Model type ('vit_b', 'vit_l', 'vit_h')
|
110 |
+
device: Device to run the model on
|
111 |
+
image: Optional image to set immediately
|
112 |
+
|
113 |
+
Returns:
|
114 |
+
HFSamPredictor instance
|
115 |
+
"""
|
116 |
+
if not HF_AVAILABLE:
|
117 |
+
raise ImportError("transformers and huggingface_hub are required for HF SAM models")
|
118 |
+
|
119 |
+
if device is None:
|
120 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
121 |
+
|
122 |
+
# Load the predictor
|
123 |
+
predictor = HFSamPredictor.from_pretrained(model_type, device)
|
124 |
+
|
125 |
+
# Set image if provided
|
126 |
+
if image is not None:
|
127 |
+
predictor.set_image(image)
|
128 |
+
|
129 |
+
return predictor
|
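A minimal usage sketch for the wrapper above (the frame is a placeholder array; note that, as written, `HFSamPredictor` prepares the image and point prompts through `preprocess` rather than a `set_image` call, and the prompt nesting mirrors the caller in `inference.py`):

```python
import numpy as np
import torch
from app_3rd.sam_utils.hf_sam_predictor import get_hf_sam_predictor

frame = np.zeros((480, 640, 3), dtype=np.uint8)       # placeholder RGB frame
predictor = get_hf_sam_predictor(model_type="vit_h")  # resolves to facebook/sam-vit-huge

# One positive click at (x, y) = (320, 240)
inputs = predictor.preprocess(frame, [[[320.0, 240.0]]], [[1]])
with torch.no_grad():
    outputs = predictor.model(**inputs)

masks = predictor.processor.image_processor.post_process_masks(
    outputs.pred_masks.cpu(),
    inputs["original_sizes"].cpu(),
    inputs["reshaped_input_sizes"].cpu(),
)
print(masks[0].shape)  # masks for the single prompted object
```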
app_3rd/sam_utils/inference.py
ADDED
@@ -0,0 +1,123 @@
1 |
+
import gc
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from segment_anything import SamPredictor, sam_model_registry
|
6 |
+
|
7 |
+
# Try to import HF SAM support
|
8 |
+
try:
|
9 |
+
from app_3rd.sam_utils.hf_sam_predictor import get_hf_sam_predictor, HFSamPredictor
|
10 |
+
HF_AVAILABLE = True
|
11 |
+
except ImportError:
|
12 |
+
HF_AVAILABLE = False
|
13 |
+
|
14 |
+
models = {
|
15 |
+
'vit_b': 'app_3rd/sam_utils/checkpoints/sam_vit_b_01ec64.pth',
|
16 |
+
'vit_l': 'app_3rd/sam_utils/checkpoints/sam_vit_l_0b3195.pth',
|
17 |
+
'vit_h': 'app_3rd/sam_utils/checkpoints/sam_vit_h_4b8939.pth'
|
18 |
+
}
|
19 |
+
|
20 |
+
|
21 |
+
def get_sam_predictor(model_type='vit_b', device=None, image=None, use_hf=True, predictor=None):
|
22 |
+
"""
|
23 |
+
Get SAM predictor with option to use HuggingFace version
|
24 |
+
|
25 |
+
Args:
|
26 |
+
model_type: Model type ('vit_b', 'vit_l', 'vit_h')
|
27 |
+
device: Device to run on
|
28 |
+
image: Optional image to set immediately
|
29 |
+
use_hf: Whether to use HuggingFace SAM instead of original SAM
|
30 |
+
"""
|
31 |
+
if predictor is not None:
|
32 |
+
return predictor
|
33 |
+
if use_hf:
|
34 |
+
if not HF_AVAILABLE:
|
35 |
+
raise ImportError("HuggingFace SAM not available. Install transformers and huggingface_hub.")
|
36 |
+
return get_hf_sam_predictor(model_type, device, image)
|
37 |
+
|
38 |
+
# Original SAM logic
|
39 |
+
if device is None and torch.cuda.is_available():
|
40 |
+
device = 'cuda'
|
41 |
+
elif device is None:
|
42 |
+
device = 'cpu'
|
43 |
+
# sam model
|
44 |
+
sam = sam_model_registry[model_type](checkpoint=models[model_type])
|
45 |
+
sam = sam.to(device)
|
46 |
+
|
47 |
+
predictor = SamPredictor(sam)
|
48 |
+
if image is not None:
|
49 |
+
predictor.set_image(image)
|
50 |
+
return predictor
|
51 |
+
|
52 |
+
|
53 |
+
def run_inference(predictor, input_x, selected_points, multi_object: bool = False):
|
54 |
+
"""
|
55 |
+
Run inference with either original SAM or HF SAM predictor
|
56 |
+
|
57 |
+
Args:
|
58 |
+
predictor: SamPredictor or HFSamPredictor instance
|
59 |
+
input_x: Input image
|
60 |
+
selected_points: List of (point, label) tuples
|
61 |
+
multi_object: Whether to handle multiple objects
|
62 |
+
"""
|
63 |
+
if len(selected_points) == 0:
|
64 |
+
return []
|
65 |
+
|
66 |
+
# Check if using HF SAM
|
67 |
+
if isinstance(predictor, HFSamPredictor):
|
68 |
+
return _run_hf_inference(predictor, input_x, selected_points, multi_object)
|
69 |
+
else:
|
70 |
+
return _run_original_inference(predictor, input_x, selected_points, multi_object)
|
71 |
+
|
72 |
+
|
73 |
+
def _run_original_inference(predictor: SamPredictor, input_x, selected_points, multi_object: bool = False):
|
74 |
+
"""Run inference with original SAM"""
|
75 |
+
points = torch.Tensor(
|
76 |
+
[p for p, _ in selected_points]
|
77 |
+
).to(predictor.device).unsqueeze(1)
|
78 |
+
|
79 |
+
labels = torch.Tensor(
|
80 |
+
[int(l) for _, l in selected_points]
|
81 |
+
).to(predictor.device).unsqueeze(1)
|
82 |
+
|
83 |
+
transformed_points = predictor.transform.apply_coords_torch(
|
84 |
+
points, input_x.shape[:2])
|
85 |
+
|
86 |
+
masks, scores, logits = predictor.predict_torch(
|
87 |
+
point_coords=transformed_points[:,0][None],
|
88 |
+
point_labels=labels[:,0][None],
|
89 |
+
multimask_output=False,
|
90 |
+
)
|
91 |
+
masks = masks[0].cpu().numpy() # N 1 H W N is the number of points
|
92 |
+
|
93 |
+
gc.collect()
|
94 |
+
torch.cuda.empty_cache()
|
95 |
+
|
96 |
+
return [(masks, 'final_mask')]
|
97 |
+
|
98 |
+
|
99 |
+
def _run_hf_inference(predictor: HFSamPredictor, input_x, selected_points, multi_object: bool = False):
|
100 |
+
"""Run inference with HF SAM"""
|
101 |
+
# Prepare points and labels for HF SAM
|
102 |
+
select_pts = [[list(p) for p, _ in selected_points]]
|
103 |
+
select_lbls = [[int(l) for _, l in selected_points]]
|
104 |
+
|
105 |
+
# Preprocess inputs
|
106 |
+
inputs = predictor.preprocess(input_x, select_pts, select_lbls)
|
107 |
+
|
108 |
+
# Run inference
|
109 |
+
with torch.no_grad():
|
110 |
+
outputs = predictor.model(**inputs)
|
111 |
+
|
112 |
+
# Post-process masks
|
113 |
+
masks = predictor.processor.image_processor.post_process_masks(
|
114 |
+
outputs.pred_masks.cpu(),
|
115 |
+
inputs["original_sizes"].cpu(),
|
116 |
+
inputs["reshaped_input_sizes"].cpu(),
|
117 |
+
)
|
118 |
+
masks = masks[0][:,:1,...].cpu().numpy()
|
119 |
+
|
120 |
+
gc.collect()
|
121 |
+
torch.cuda.empty_cache()
|
122 |
+
|
123 |
+
return [(masks, 'final_mask')]
|
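A minimal sketch of how the two helpers above are typically combined (the frame is a placeholder; `selected_points` uses the `[(point, label), ...]` format expected by `run_inference`, with label 1 for positive and 0 for negative clicks):

```python
import numpy as np
from app_3rd.sam_utils.inference import get_sam_predictor, run_inference

frame = np.zeros((480, 640, 3), dtype=np.uint8)     # placeholder RGB frame
predictor = get_sam_predictor(model_type="vit_h")   # HF-backed SAM by default (use_hf=True)

selected_points = [([320, 240], 1), ([50, 50], 0)]  # one positive, one negative click
results = run_inference(predictor, frame, selected_points)
masks, tag = results[0]
print(masks.shape, tag)                             # per-point masks and the 'final_mask' tag
```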
app_3rd/spatrack_utils/infer_track.py
ADDED
@@ -0,0 +1,194 @@
1 |
+
from models.SpaTrackV2.models.predictor import Predictor
|
2 |
+
import yaml
|
3 |
+
import easydict
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import cv2
|
7 |
+
import torch
|
8 |
+
import torchvision.transforms as T
|
9 |
+
from PIL import Image
|
10 |
+
import io
|
11 |
+
import moviepy.editor as mp
|
12 |
+
from models.SpaTrackV2.utils.visualizer import Visualizer
|
13 |
+
import tqdm
|
14 |
+
from models.SpaTrackV2.models.utils import get_points_on_a_grid
|
15 |
+
import glob
|
16 |
+
from rich import print
|
17 |
+
import argparse
|
18 |
+
import decord
|
19 |
+
from huggingface_hub import hf_hub_download
|
20 |
+
|
21 |
+
config = {
|
22 |
+
"ckpt_dir": "Yuxihenry/SpatialTrackerCkpts", # HuggingFace repo ID
|
23 |
+
"cfg_dir": "config/magic_infer_moge.yaml",
|
24 |
+
}
|
25 |
+
|
26 |
+
def get_tracker_predictor(output_dir: str, vo_points: int = 756, tracker_model=None):
|
27 |
+
"""
|
28 |
+
Initialize and return the tracker predictor and visualizer
|
29 |
+
Args:
|
30 |
+
output_dir: Directory to save visualization results
|
31 |
+
vo_points: Number of points for visual odometry
|
32 |
+
Returns:
|
33 |
+
Tuple of (tracker_predictor, visualizer)
|
34 |
+
"""
|
35 |
+
viz = True
|
36 |
+
os.makedirs(output_dir, exist_ok=True)
|
37 |
+
|
38 |
+
with open(config["cfg_dir"], "r") as f:
|
39 |
+
cfg = yaml.load(f, Loader=yaml.FullLoader)
|
40 |
+
cfg = easydict.EasyDict(cfg)
|
41 |
+
cfg.out_dir = output_dir
|
42 |
+
cfg.model.track_num = vo_points
|
43 |
+
|
44 |
+
# Check if it's a local path or HuggingFace repo
|
45 |
+
if tracker_model is not None:
|
46 |
+
model = tracker_model
|
47 |
+
model.spatrack.track_num = vo_points
|
48 |
+
else:
|
49 |
+
if os.path.exists(config["ckpt_dir"]):
|
50 |
+
# Local file
|
51 |
+
model = Predictor.from_pretrained(config["ckpt_dir"], model_cfg=cfg["model"])
|
52 |
+
else:
|
53 |
+
# HuggingFace repo - download the model
|
54 |
+
print(f"Downloading model from HuggingFace: {config['ckpt_dir']}")
|
55 |
+
checkpoint_path = hf_hub_download(
|
56 |
+
repo_id=config["ckpt_dir"],
|
57 |
+
repo_type="model",
|
58 |
+
filename="SpaTrack3_offline.pth"
|
59 |
+
)
|
60 |
+
model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
|
61 |
+
model.eval()
|
62 |
+
model.to("cuda")
|
63 |
+
|
64 |
+
viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
|
65 |
+
fps=10, pad_value=0, tracks_leave_trace=5)
|
66 |
+
|
67 |
+
return model, viser
|
68 |
+
|
69 |
+
def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3):
|
70 |
+
"""
|
71 |
+
Run tracking on a video sequence
|
72 |
+
Args:
|
73 |
+
model: Tracker predictor instance
|
74 |
+
viser: Visualizer instance
|
75 |
+
temp_dir: Directory containing temporary files
|
76 |
+
video_name: Name of the video file (without extension)
|
77 |
+
grid_size: Size of the tracking grid
|
78 |
+
vo_points: Number of points for visual odometry
|
79 |
+
fps: Frames per second for visualization
|
80 |
+
"""
|
81 |
+
# Setup paths
|
82 |
+
video_path = os.path.join(temp_dir, f"{video_name}.mp4")
|
83 |
+
mask_path = os.path.join(temp_dir, f"{video_name}.png")
|
84 |
+
out_dir = os.path.join(temp_dir, "results")
|
85 |
+
os.makedirs(out_dir, exist_ok=True)
|
86 |
+
|
87 |
+
# Load video using decord
|
88 |
+
video_reader = decord.VideoReader(video_path)
|
89 |
+
video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2) # Convert to tensor and permute to (N, C, H, W)
|
90 |
+
|
91 |
+
# resize make sure the shortest side is 336
|
92 |
+
h, w = video_tensor.shape[2:]
|
93 |
+
scale = max(336 / h, 336 / w)
|
94 |
+
if scale < 1:
|
95 |
+
new_h, new_w = int(h * scale), int(w * scale)
|
96 |
+
video_tensor = T.Resize((new_h, new_w))(video_tensor)
|
97 |
+
video_tensor = video_tensor[::fps].float()
|
98 |
+
depth_tensor = None
|
99 |
+
intrs = None
|
100 |
+
extrs = None
|
101 |
+
data_npz_load = {}
|
102 |
+
|
103 |
+
# Load and process mask
|
104 |
+
if os.path.exists(mask_path):
|
105 |
+
mask = cv2.imread(mask_path)
|
106 |
+
mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
|
107 |
+
mask = mask.sum(axis=-1)>0
|
108 |
+
else:
|
109 |
+
mask = np.ones_like(video_tensor[0,0].numpy())>0
|
110 |
+
|
111 |
+
# Get frame dimensions and create grid points
|
112 |
+
frame_H, frame_W = video_tensor.shape[2:]
|
113 |
+
grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
|
114 |
+
|
115 |
+
# Sample mask values at grid points and filter out points where mask=0
|
116 |
+
if os.path.exists(mask_path):
|
117 |
+
grid_pts_int = grid_pts[0].long()
|
118 |
+
mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
|
119 |
+
grid_pts = grid_pts[:, mask_values]
|
120 |
+
|
121 |
+
query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
|
122 |
+
|
123 |
+
# run vggt
|
124 |
+
if os.environ.get("VGGT_DIR", None) is not None:
|
125 |
+
vggt_model = VGGT()
|
126 |
+
vggt_model.load_state_dict(torch.load(VGGT_DIR))
|
127 |
+
vggt_model.eval()
|
128 |
+
vggt_model = vggt_model.to("cuda")
|
129 |
+
# process the image tensor
|
130 |
+
video_tensor = preprocess_image(video_tensor)[None]
|
131 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
132 |
+
# Predict attributes including cameras, depth maps, and point maps.
|
133 |
+
aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_tensor.cuda()/255)
|
134 |
+
pose_enc = vggt_model.camera_head(aggregated_tokens_list)[-1]
|
135 |
+
# Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
|
136 |
+
extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, video_tensor.shape[-2:])
|
137 |
+
# Predict Depth Maps
|
138 |
+
depth_map, depth_conf = vggt_model.depth_head(aggregated_tokens_list, video_tensor.cuda()/255, ps_idx)
|
139 |
+
# clear the cache
|
140 |
+
del vggt_model, aggregated_tokens_list, ps_idx, pose_enc
|
141 |
+
torch.cuda.empty_cache()
|
142 |
+
depth_tensor = depth_map.squeeze().cpu().numpy()
|
143 |
+
extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
|
144 |
+
extrs[:, :3, :4] = extrinsic.squeeze().cpu().numpy()
|
145 |
+
intrs = intrinsic.squeeze().cpu().numpy()
|
146 |
+
video_tensor = video_tensor.squeeze()
|
147 |
+
#NOTE: 20% of the depth is not reliable
|
148 |
+
# threshold = depth_conf.squeeze().view(-1).quantile(0.5)
|
149 |
+
unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
|
150 |
+
|
151 |
+
# Run model inference
|
152 |
+
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
153 |
+
(
|
154 |
+
c2w_traj, intrs, point_map, conf_depth,
|
155 |
+
track3d_pred, track2d_pred, vis_pred, conf_pred, video
|
156 |
+
) = model.forward(video_tensor, depth=depth_tensor,
|
157 |
+
intrs=intrs, extrs=extrs,
|
158 |
+
queries=query_xyt,
|
159 |
+
fps=1, full_point=False, iters_track=4,
|
160 |
+
query_no_BA=True, fixed_cam=False, stage=1,
|
161 |
+
support_frame=len(video_tensor)-1, replace_ratio=0.2)
|
162 |
+
|
163 |
+
# Resize results to avoid too large I/O Burden
|
164 |
+
max_size = 336
|
165 |
+
h, w = video.shape[2:]
|
166 |
+
scale = min(max_size / h, max_size / w)
|
167 |
+
if scale < 1:
|
168 |
+
new_h, new_w = int(h * scale), int(w * scale)
|
169 |
+
video = T.Resize((new_h, new_w))(video)
|
170 |
+
video_tensor = T.Resize((new_h, new_w))(video_tensor)
|
171 |
+
point_map = T.Resize((new_h, new_w))(point_map)
|
172 |
+
track2d_pred[...,:2] = track2d_pred[...,:2] * scale
|
173 |
+
intrs[:,:2,:] = intrs[:,:2,:] * scale
|
174 |
+
if depth_tensor is not None:
|
175 |
+
depth_tensor = T.Resize((new_h, new_w))(depth_tensor)
|
176 |
+
conf_depth = T.Resize((new_h, new_w))(conf_depth)
|
177 |
+
|
178 |
+
# Visualize tracks
|
179 |
+
viser.visualize(video=video[None],
|
180 |
+
tracks=track2d_pred[None][...,:2],
|
181 |
+
visibility=vis_pred[None],filename="test")
|
182 |
+
|
183 |
+
# Save in tapip3d format
|
184 |
+
data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
|
185 |
+
data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
|
186 |
+
data_npz_load["intrinsics"] = intrs.cpu().numpy()
|
187 |
+
data_npz_load["depths"] = point_map[:,2,...].cpu().numpy()
|
188 |
+
data_npz_load["video"] = (video_tensor).cpu().numpy()/255
|
189 |
+
data_npz_load["visibs"] = vis_pred.cpu().numpy()
|
190 |
+
data_npz_load["confs"] = conf_pred.cpu().numpy()
|
191 |
+
data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
|
192 |
+
np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
|
193 |
+
|
194 |
+
print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/result.npz[/bold yellow]")
|
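A minimal sketch of driving the two helpers above end to end; the temp-directory layout (`<name>.mp4` plus an optional `<name>.png` SAM mask) matches what `run_tracker` expects, and the clip path is one of the examples shipped with the Space:

```python
import os
import shutil
from app_3rd.spatrack_utils.infer_track import get_tracker_predictor, run_tracker

temp_dir = "temp_run"
os.makedirs(temp_dir, exist_ok=True)
shutil.copy("examples/kitchen.mp4", os.path.join(temp_dir, "kitchen.mp4"))
# Optionally save a SAM mask to temp_run/kitchen.png to restrict the query grid.

model, viser = get_tracker_predictor(os.path.join(temp_dir, "results"), vo_points=756)
run_tracker(model, viser, temp_dir, "kitchen", grid_size=50, vo_points=756, fps=3)
# -> temp_run/results/result.npz, ready for the TAPIP3D-style visualizer
```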
config/__init__.py
ADDED
File without changes
|
config/magic_infer_moge.yaml
ADDED
@@ -0,0 +1,48 @@
1 |
+
seed: 0
|
2 |
+
# config the hydra logger, only in hydra `$` can be decoded as cite
|
3 |
+
data: ./assets/room
|
4 |
+
vis_track: false
|
5 |
+
hydra:
|
6 |
+
run:
|
7 |
+
dir: .
|
8 |
+
output_subdir: null
|
9 |
+
job_logging: {}
|
10 |
+
hydra_logging: {}
|
11 |
+
mixed_precision: bf16
|
12 |
+
visdom:
|
13 |
+
viz_ip: "localhost"
|
14 |
+
port: 6666
|
15 |
+
relax_load: false
|
16 |
+
res_all: 336
|
17 |
+
# config the ckpt path
|
18 |
+
# ckpts: "/mnt/bn/xyxdata/home/codes/my_projs/SpaTrack2/checkpoints/new_base.pth"
|
19 |
+
ckpts: "Yuxihenry/SpatialTracker_Files"
|
20 |
+
batch_size: 1
|
21 |
+
input:
|
22 |
+
type: image
|
23 |
+
fps: 1
|
24 |
+
model_wind_size: 32
|
25 |
+
model:
|
26 |
+
backbone_cfg:
|
27 |
+
ckpt_dir: "checkpoints/model.pt"
|
28 |
+
chunk_size: 24 # downsample factor for patchified features
|
29 |
+
ckpt_fwd: true
|
30 |
+
ft_cfg:
|
31 |
+
mode: "fix"
|
32 |
+
paras_name: []
|
33 |
+
resolution: 336
|
34 |
+
max_len: 512
|
35 |
+
Track_cfg:
|
36 |
+
base_ckpt: "checkpoints/scaled_offline.pth"
|
37 |
+
base:
|
38 |
+
stride: 4
|
39 |
+
corr_radius: 3
|
40 |
+
window_len: 60
|
41 |
+
stablizer: True
|
42 |
+
mode: "online"
|
43 |
+
s_wind: 200
|
44 |
+
overlap: 4
|
45 |
+
track_num: 0
|
46 |
+
|
47 |
+
dist_train:
|
48 |
+
num_nodes: 1
|
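A minimal sketch of how this config is loaded and overridden at runtime, mirroring `get_tracker_predictor` in `app_3rd/spatrack_utils/infer_track.py` (the output directory and point budget are example values):

```python
import yaml
import easydict

with open("config/magic_infer_moge.yaml", "r") as f:
    cfg = easydict.EasyDict(yaml.load(f, Loader=yaml.FullLoader))

cfg.out_dir = "results"       # where visualizations are written
cfg.model.track_num = 756     # overrides track_num: 0 above with the VO point budget
print(cfg.model.resolution)   # nested keys read as attributes -> 336
```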
examples/backpack.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4b5ac6b2285ffb48e3a740e419e38c781df9c963589a5fd894e5b4e13dd6a8b8
|
3 |
+
size 1208738
|
examples/ball.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:31f6e3bf875a85284b376c05170b4c08b546b7d5e95106848b1e3818a9d0db91
|
3 |
+
size 3030268
|
examples/basketball.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0df3b429d5fd64c298f2d79b2d818a4044e7341a71d70b957f60b24e313c3760
|
3 |
+
size 2487837
|
examples/biker.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:fba880c24bdb8fa3b84b1b491d52f2c1f426fb09e34c3013603e5a549cf3b22b
|
3 |
+
size 249196
|
examples/cinema_0.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a68a5643c14f61c05d48e25a98ddf5cf0344d3ffcda08ad4a0adc989d49d7a9c
|
3 |
+
size 1774022
|
examples/cinema_1.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:99624e2d0fb2e9f994e46aefb904e884de37a6d78e7f6b6670e286eaa397e515
|
3 |
+
size 2370749
|
examples/drifting.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4f3937871117d3cc5d7da3ef31d1edf5626fc8372126b73590f75f05713fe97c
|
3 |
+
size 4695804
|
examples/ego_kc1.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:22fe64e458e329e8b3c3e20b3725ffd85c3a2e725fd03909cf883d3fd02c80b3
|
3 |
+
size 1365980
|
examples/ego_teaser.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8780b291b48046b1c7dea90712c1c3f59d60c03216df1c489f6f03e8d61fae5c
|
3 |
+
size 7365665
|
examples/handwave.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e6dde7cf4ffa7c66b6861bb5abdedc49dfc4b5b4dd9dd46ee8415dd4953935b6
|
3 |
+
size 2099369
|
examples/hockey.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8c3be095777b442dc401e7d1f489b749611ffade3563a01e4e3d1e511311bd86
|
3 |
+
size 1795810
|
examples/ken_block_0.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7b788faeb4d3206fa604d622a05268f1321ad6a229178fe12319d20c9438deb1
|
3 |
+
size 196343
|
examples/kiss.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f78fffc5108d95d4e5837d7607226f3dd9796615ea3481f2629c69ccd2ccb12f
|
3 |
+
size 1073570
|
examples/kitchen.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3120e942a9b3d7b300928e43113b000fb5ccc209012a2c560ec26b8a04c2d5f9
|
3 |
+
size 543970
|
examples/kitchen_egocentric.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5468ab10d8d39b68b51fa616adc3d099dab7543e38dd221a0a7a20a2401824a2
|
3 |
+
size 2176685
|
examples/pillow.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:8f05818f586d7b0796fcd4714ea4be489c93701598cadc86ce7973fc24655fee
|
3 |
+
size 1407147
|
examples/protein.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b2dc9cfceb0984b61ebc62fda4c826135ebe916c8c966a8123dcc3315d43b73f
|
3 |
+
size 2002300
|
examples/pusht.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:996d1923e36811a1069e4d6b5e8c0338d9068c0870ea09c4c04e13e9fbcd207a
|
3 |
+
size 5256495
|
examples/robot1.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:7a3b9e4449572129fdd96a751938e211241cdd86bcc56ffd33bfd23fc4d6e9c0
|
3 |
+
size 1178671
|
examples/robot2.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:188b2d8824ce345c86a603bff210639a6158d72cf6119cc1d3f79d409ac68bb3
|
3 |
+
size 867261
|
examples/robot_3.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:784a0f9c36a316d0da5745075dbc8cefd9ce60c25b067d3d80a1d52830df8a37
|
3 |
+
size 1153015
|
examples/robot_unitree.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:99bc274f7613a665c6135085fe01691ebfaa9033101319071f37c550ab21d1ea
|
3 |
+
size 1964268
|
examples/running.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9ceb96b287fefb1c090dcd2f5db7634f808d2079413500beeb7b33023dfae51b
|
3 |
+
size 7307897
|
examples/teleop2.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:59ea006a18227da8cf5db1fa50cd48e71ec7eb66fef48ea2158c325088bd9fee
|
3 |
+
size 1077267
|
examples/vertical_place.mp4
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6c8061ae449f986113c2ecb17aefc2c13f737aecbcd41d6c057c88e6d41ac3ee
|
3 |
+
size 719810
|
models/SpaTrackV2/models/SpaTrack.py
ADDED
@@ -0,0 +1,759 @@
1 |
+
#python
|
2 |
+
"""
|
3 |
+
SpaTrackerV2, a unified model that estimates 'intrinsic',
|
4 |
+
'video depth', 'extrinsic' and '3D Tracking' from casual video frames.
|
5 |
+
|
6 |
+
Contact: DM [email protected]
|
7 |
+
"""
|
8 |
+
|
9 |
+
import os
|
10 |
+
import numpy as np
|
11 |
+
from typing import Literal, Union, List, Tuple, Dict
|
12 |
+
import cv2
|
13 |
+
import torch
|
14 |
+
import torch.nn as nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
# from depth anything v2
|
17 |
+
from huggingface_hub import PyTorchModelHubMixin # used for model hub
|
18 |
+
from einops import rearrange
|
19 |
+
from models.monoD.depth_anything_v2.dpt import DepthAnythingV2
|
20 |
+
from models.moge.model.v1 import MoGeModel
|
21 |
+
import copy
|
22 |
+
from functools import partial
|
23 |
+
from models.SpaTrackV2.models.tracker3D.TrackRefiner import TrackRefiner3D
|
24 |
+
import kornia
|
25 |
+
from models.SpaTrackV2.utils.model_utils import sample_features5d
|
26 |
+
import utils3d
|
27 |
+
from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
|
28 |
+
from models.SpaTrackV2.models.utils import pose_enc2mat, matrix_to_quaternion, get_track_points, normalize_rgb
|
29 |
+
import random
|
30 |
+
|
31 |
+
class SpaTrack2(nn.Module, PyTorchModelHubMixin):
|
32 |
+
def __init__(
|
33 |
+
self,
|
34 |
+
loggers: list, # include [ viz, logger_tf, logger]
|
35 |
+
backbone_cfg,
|
36 |
+
Track_cfg=None,
|
37 |
+
chunk_size=24,
|
38 |
+
ckpt_fwd: bool = False,
|
39 |
+
ft_cfg=None,
|
40 |
+
resolution=518,
|
41 |
+
max_len=600, # the maximum video length we can preprocess,
|
42 |
+
track_num=768,
|
43 |
+
):
|
44 |
+
|
45 |
+
self.chunk_size = chunk_size
|
46 |
+
self.max_len = max_len
|
47 |
+
self.resolution = resolution
|
48 |
+
# config the T-Lora Dinov2
|
49 |
+
#NOTE: initial the base model
|
50 |
+
base_cfg = copy.deepcopy(backbone_cfg)
|
51 |
+
backbone_ckpt_dir = base_cfg.pop('ckpt_dir', None)
|
52 |
+
|
53 |
+
super(SpaTrack2, self).__init__()
|
54 |
+
if os.path.exists(backbone_ckpt_dir)==False:
|
55 |
+
base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl')
|
56 |
+
else:
|
57 |
+
checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True)
|
58 |
+
base_model = MoGeModel(**checkpoint["model_config"])
|
59 |
+
base_model.load_state_dict(checkpoint['model'])
|
60 |
+
# avoid the base_model is a member of SpaTrack2
|
61 |
+
object.__setattr__(self, 'base_model', base_model)
|
62 |
+
|
63 |
+
# Tracker model
|
64 |
+
self.Track3D = TrackRefiner3D(Track_cfg)
|
65 |
+
track_base_ckpt_dir = Track_cfg.base_ckpt
|
66 |
+
if os.path.exists(track_base_ckpt_dir):
|
67 |
+
track_pretrain = torch.load(track_base_ckpt_dir)
|
68 |
+
self.Track3D.load_state_dict(track_pretrain, strict=False)
|
69 |
+
|
70 |
+
# wrap the function of make lora trainable
|
71 |
+
self.make_paras_trainable = partial(self.make_paras_trainable,
|
72 |
+
mode=ft_cfg.mode,
|
73 |
+
paras_name=ft_cfg.paras_name)
|
74 |
+
self.track_num = track_num
|
75 |
+
|
76 |
+
def make_paras_trainable(self, mode: str = 'fix', paras_name: List[str] = []):
|
77 |
+
# gradient required for the lora_experts and gate
|
78 |
+
for name, param in self.named_parameters():
|
79 |
+
if any(x in name for x in paras_name):
|
80 |
+
if mode == 'fix':
|
81 |
+
param.requires_grad = False
|
82 |
+
else:
|
83 |
+
param.requires_grad = True
|
84 |
+
else:
|
85 |
+
if mode == 'fix':
|
86 |
+
param.requires_grad = True
|
87 |
+
else:
|
88 |
+
param.requires_grad = False
|
89 |
+
total_params = sum(p.numel() for p in self.parameters())
|
90 |
+
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
91 |
+
print(f"Total parameters: {total_params}")
|
92 |
+
print(f"Trainable parameters: {trainable_params/total_params*100:.2f}%")
|
93 |
+
|
94 |
+
def ProcVid(self,
|
95 |
+
x: torch.Tensor):
|
96 |
+
"""
|
97 |
+
split the video into several overlapped windows.
|
98 |
+
|
99 |
+
args:
|
100 |
+
x: the input video frames. [B, T, C, H, W]
|
101 |
+
outputs:
|
102 |
+
patch_size: the patch size of the video features
|
103 |
+
raises:
|
104 |
+
ValueError: if the input video is longer than `max_len`.
|
105 |
+
|
106 |
+
"""
|
107 |
+
# normalize the input images
|
108 |
+
num_types = x.dtype
|
109 |
+
x = normalize_rgb(x, input_size=self.resolution)
|
110 |
+
x = x.to(num_types)
|
111 |
+
# get the video features
|
112 |
+
B, T, C, H, W = x.size()
|
113 |
+
if T > self.max_len:
|
114 |
+
raise ValueError(f"the video length should no more than {self.max_len}.")
|
115 |
+
# get the video features
|
116 |
+
patch_h, patch_w = H // 14, W // 14
|
117 |
+
patch_size = (patch_h, patch_w)
|
118 |
+
# resize and get the video features
|
119 |
+
x = x.view(B * T, C, H, W)
|
120 |
+
# operate the temporal encoding
|
121 |
+
return patch_size, x
|
122 |
+
|
123 |
+
def forward_stream(
|
124 |
+
self,
|
125 |
+
video: torch.Tensor,
|
126 |
+
queries: torch.Tensor = None,
|
127 |
+
T_org: int = None,
|
128 |
+
depth: torch.Tensor|np.ndarray|str=None,
|
129 |
+
unc_metric_in: torch.Tensor|np.ndarray|str=None,
|
130 |
+
intrs: torch.Tensor|np.ndarray|str=None,
|
131 |
+
extrs: torch.Tensor|np.ndarray|str=None,
|
132 |
+
queries_3d: torch.Tensor = None,
|
133 |
+
window_len: int = 16,
|
134 |
+
overlap_len: int = 4,
|
135 |
+
full_point: bool = False,
|
136 |
+
track2d_gt: torch.Tensor = None,
|
137 |
+
fixed_cam: bool = False,
|
138 |
+
query_no_BA: bool = False,
|
139 |
+
stage: int = 0,
|
140 |
+
support_frame: int = 0,
|
141 |
+
replace_ratio: float = 0.6,
|
142 |
+
annots_train: Dict = None,
|
143 |
+
iters_track=4,
|
144 |
+
**kwargs,
|
145 |
+
):
|
146 |
+
# step 1 allocate the query points on the grid
|
147 |
+
T, C, H, W = video.shape
|
148 |
+
|
149 |
+
if annots_train is not None:
|
150 |
+
vis_gt = annots_train["vis"]
|
151 |
+
_, _, N = vis_gt.shape
|
152 |
+
number_visible = vis_gt.sum(dim=1)
|
153 |
+
ratio_rand = torch.rand(1, N, device=vis_gt.device)
|
154 |
+
first_positive_inds = get_nth_visible_time_index(vis_gt, (number_visible*ratio_rand).long().clamp(min=1, max=T))
|
155 |
+
assert (torch.gather(vis_gt, 1, first_positive_inds[:, None, :].repeat(1, T, 1)) < 0).sum() == 0
|
156 |
+
|
157 |
+
first_positive_inds = first_positive_inds.long()
|
158 |
+
gather = torch.gather(
|
159 |
+
annots_train["traj_3d"][...,:2], 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
|
160 |
+
)
|
161 |
+
xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)
|
162 |
+
queries = torch.cat([first_positive_inds[:, :, None], xys], dim=-1)[0].cpu().numpy()
|
163 |
+
|
164 |
+
|
165 |
+
# Unfold video into segments of window_len with overlap_len
|
166 |
+
step_slide = window_len - overlap_len
|
167 |
+
if T < window_len:
|
168 |
+
video_unf = video.unsqueeze(0)
|
169 |
+
if depth is not None:
|
170 |
+
depth_unf = depth.unsqueeze(0)
|
171 |
+
else:
|
172 |
+
depth_unf = None
|
173 |
+
if unc_metric_in is not None:
|
174 |
+
unc_metric_unf = unc_metric_in.unsqueeze(0)
|
175 |
+
else:
|
176 |
+
unc_metric_unf = None
|
177 |
+
if intrs is not None:
|
178 |
+
intrs_unf = intrs.unsqueeze(0)
|
179 |
+
else:
|
180 |
+
intrs_unf = None
|
181 |
+
if extrs is not None:
|
182 |
+
extrs_unf = extrs.unsqueeze(0)
|
183 |
+
else:
|
184 |
+
extrs_unf = None
|
185 |
+
else:
|
186 |
+
video_unf = video.unfold(0, window_len, step_slide).permute(0, 4, 1, 2, 3) # [B, S, C, H, W]
|
187 |
+
if depth is not None:
|
188 |
+
depth_unf = depth.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
|
189 |
+
intrs_unf = intrs.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
|
190 |
+
else:
|
191 |
+
depth_unf = None
|
192 |
+
intrs_unf = None
|
193 |
+
if extrs is not None:
|
194 |
+
extrs_unf = extrs.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
|
195 |
+
else:
|
196 |
+
extrs_unf = None
|
197 |
+
if unc_metric_in is not None:
|
198 |
+
unc_metric_unf = unc_metric_in.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
|
199 |
+
else:
|
200 |
+
unc_metric_unf = None
|
201 |
+
|
202 |
+
# parallel
|
203 |
+
# Get number of segments
|
204 |
+
B = video_unf.shape[0]
|
205 |
+
#TODO: Process each segment in parallel using torch.nn.DataParallel
|
206 |
+
c2w_traj = torch.eye(4, 4)[None].repeat(T, 1, 1)
|
207 |
+
intrs_out = torch.eye(3, 3)[None].repeat(T, 1, 1)
|
208 |
+
point_map = torch.zeros(T, 3, H, W).cuda()
|
209 |
+
unc_metric = torch.zeros(T, H, W).cuda()
|
210 |
+
# set the queries
|
211 |
+
N, _ = queries.shape
|
212 |
+
track3d_pred = torch.zeros(T, N, 6).cuda()
|
213 |
+
track2d_pred = torch.zeros(T, N, 3).cuda()
|
214 |
+
vis_pred = torch.zeros(T, N, 1).cuda()
|
215 |
+
conf_pred = torch.zeros(T, N, 1).cuda()
|
216 |
+
dyn_preds = torch.zeros(T, N, 1).cuda()
|
217 |
+
# sort the queries by time
|
218 |
+
sorted_indices = np.argsort(queries[...,0])
|
219 |
+
sorted_inv_indices = np.argsort(sorted_indices)
|
220 |
+
sort_query = queries[sorted_indices]
|
221 |
+
sort_query = torch.from_numpy(sort_query).cuda()
|
222 |
+
if queries_3d is not None:
|
223 |
+
sort_query_3d = queries_3d[sorted_indices]
|
224 |
+
sort_query_3d = torch.from_numpy(sort_query_3d).cuda()
|
225 |
+
|
226 |
+
queries_len = 0
|
227 |
+
overlap_d = None
|
228 |
+
cache = None
|
229 |
+
loss = 0.0
|
230 |
+
|
231 |
+
for i in range(B):
|
232 |
+
segment = video_unf[i:i+1].cuda()
|
233 |
+
# Forward pass through model
|
234 |
+
# detect the key points for each frames
|
235 |
+
|
236 |
+
queries_new_mask = (sort_query[...,0] < i * step_slide + window_len) * (sort_query[...,0] >= (i * step_slide + overlap_len if i > 0 else 0))
|
237 |
+
if queries_3d is not None:
|
238 |
+
queries_new_3d = sort_query_3d[queries_new_mask]
|
239 |
+
queries_new_3d = queries_new_3d.float()
|
240 |
+
else:
|
241 |
+
queries_new_3d = None
|
242 |
+
queries_new = sort_query[queries_new_mask.bool()]
|
243 |
+
queries_new = queries_new.float()
|
244 |
+
if i > 0:
|
245 |
+
overlap2d = track2d_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
|
246 |
+
overlapvis = vis_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
|
247 |
+
overlapconf = conf_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
|
248 |
+
overlap_query = (overlapvis * overlapconf).max(dim=0)[1][None, ...]
|
249 |
+
overlap_xy = torch.gather(overlap2d, 0, overlap_query.repeat(1,1,2))
|
250 |
+
overlap_d = torch.gather(overlap2d, 0, overlap_query.repeat(1,1,3))[...,2].detach()
|
251 |
+
overlap_query = torch.cat([overlap_query[...,:1], overlap_xy], dim=-1)[0]
|
252 |
+
queries_new[...,0] -= i*step_slide
|
253 |
+
queries_new = torch.cat([overlap_query, queries_new], dim=0).detach()
|
254 |
+
|
255 |
+
if annots_train is None:
|
256 |
+
annots = {}
|
257 |
+
else:
|
258 |
+
annots = copy.deepcopy(annots_train)
|
259 |
+
annots["traj_3d"] = annots["traj_3d"][:, i*step_slide:i*step_slide+window_len, sorted_indices,:][...,:len(queries_new),:]
|
260 |
+
annots["vis"] = annots["vis"][:, i*step_slide:i*step_slide+window_len, sorted_indices][...,:len(queries_new)]
|
261 |
+
annots["poses_gt"] = annots["poses_gt"][:, i*step_slide:i*step_slide+window_len]
|
262 |
+
annots["depth_gt"] = annots["depth_gt"][:, i*step_slide:i*step_slide+window_len]
|
263 |
+
annots["intrs"] = annots["intrs"][:, i*step_slide:i*step_slide+window_len]
|
264 |
+
annots["traj_mat"] = annots["traj_mat"][:,i*step_slide:i*step_slide+window_len]
|
265 |
+
|
266 |
+
if depth is not None:
|
267 |
+
annots["depth_gt"] = depth_unf[i:i+1].to(segment.device).to(segment.dtype)
|
268 |
+
if unc_metric_in is not None:
|
269 |
+
annots["unc_metric"] = unc_metric_unf[i:i+1].to(segment.device).to(segment.dtype)
|
270 |
+
if intrs is not None:
|
271 |
+
intr_seg = intrs_unf[i:i+1].to(segment.device).to(segment.dtype)[0].clone()
|
272 |
+
focal = (intr_seg[:,0,0] / segment.shape[-1] + intr_seg[:,1,1]/segment.shape[-2]) / 2
|
273 |
+
pose_fake = torch.zeros(1, 8).to(depth.device).to(depth.dtype).repeat(segment.shape[1], 1)
|
274 |
+
pose_fake[:, -1] = focal
|
275 |
+
pose_fake[:,3]=1
|
276 |
+
annots["intrs_gt"] = intr_seg
|
277 |
+
if extrs is not None:
|
278 |
+
extrs_unf_norm = extrs_unf[i:i+1][0].clone()
|
279 |
+
extrs_unf_norm = torch.inverse(extrs_unf_norm[:1,...]) @ extrs_unf[i:i+1][0]
|
280 |
+
rot_vec = matrix_to_quaternion(extrs_unf_norm[:,:3,:3])
|
281 |
+
annots["poses_gt"] = torch.zeros(1, rot_vec.shape[0], 7).to(segment.device).to(segment.dtype)
|
282 |
+
annots["poses_gt"][:, :, 3:7] = rot_vec.to(segment.device).to(segment.dtype)[None]
|
283 |
+
annots["poses_gt"][:, :, :3] = extrs_unf_norm[:,:3,3].to(segment.device).to(segment.dtype)[None]
|
284 |
+
annots["use_extr"] = True
|
285 |
+
|
286 |
+
kwargs.update({"stage": stage})
|
287 |
+
|
288 |
+
#TODO: DEBUG
|
289 |
+
out = self.forward(segment, pts_q=queries_new,
|
290 |
+
pts_q_3d=queries_new_3d, overlap_d=overlap_d,
|
291 |
+
full_point=full_point,
|
292 |
+
fixed_cam=fixed_cam, query_no_BA=query_no_BA,
|
293 |
+
support_frame=segment.shape[1]-1,
|
294 |
+
cache=cache, replace_ratio=replace_ratio,
|
295 |
+
iters_track=iters_track,
|
296 |
+
**kwargs, annots=annots)
|
297 |
+
if self.training:
|
298 |
+
loss += out["loss"].squeeze()
|
299 |
+
# from models.SpaTrackV2.utils.visualizer import Visualizer
|
300 |
+
# vis_track = Visualizer(grayscale=False,
|
301 |
+
# fps=10, pad_value=50, tracks_leave_trace=0)
|
302 |
+
# vis_track.visualize(video=segment,
|
303 |
+
# tracks=out["traj_est"][...,:2],
|
304 |
+
# visibility=out["vis_est"],
|
305 |
+
# save_video=True)
|
306 |
+
# # visualize 4d
|
307 |
+
# import os, json
|
308 |
+
# import os.path as osp
|
309 |
+
# viser4d_dir = os.path.join("viser_4d_results")
|
310 |
+
# os.makedirs(viser4d_dir, exist_ok=True)
|
311 |
+
# depth_est = annots["depth_gt"][0]
|
312 |
+
# unc_metric = out["unc_metric"]
|
313 |
+
# mask = (unc_metric > 0.5).squeeze(1)
|
314 |
+
# # pose_est = out["poses_pred"].squeeze(0)
|
315 |
+
# pose_est = annots["traj_mat"][0]
|
316 |
+
# rgb_tracks = out["rgb_tracks"].squeeze(0)
|
317 |
+
# intrinsics = out["intrs"].squeeze(0)
|
318 |
+
# for i_k in range(out["depth"].shape[0]):
|
319 |
+
# img_i = out["imgs_raw"][0][i_k].permute(1, 2, 0).cpu().numpy()
|
320 |
+
# img_i = cv2.cvtColor(img_i, cv2.COLOR_BGR2RGB)
|
321 |
+
# cv2.imwrite(osp.join(viser4d_dir, f'frame_{i_k:04d}.png'), img_i)
|
322 |
+
# if stage == 1:
|
323 |
+
# depth = depth_est[i_k].squeeze().cpu().numpy()
|
324 |
+
# np.save(osp.join(viser4d_dir, f'frame_{i_k:04d}.npy'), depth)
|
325 |
+
# else:
|
326 |
+
# point_map_vis = out["points_map"][i_k].cpu().numpy()
|
327 |
+
# np.save(osp.join(viser4d_dir, f'point_{i_k:04d}.npy'), point_map_vis)
|
328 |
+
# np.save(os.path.join(viser4d_dir, f'intrinsics.npy'), intrinsics.cpu().numpy())
|
329 |
+
# np.save(os.path.join(viser4d_dir, f'extrinsics.npy'), pose_est.cpu().numpy())
|
330 |
+
# np.save(os.path.join(viser4d_dir, f'conf.npy'), mask.float().cpu().numpy())
|
331 |
+
# np.save(os.path.join(viser4d_dir, f'colored_track3d.npy'), rgb_tracks.cpu().numpy())
|
332 |
+
|
333 |
+
queries_len = len(queries_new)
|
334 |
+
# update the track3d and track2d
|
335 |
+
left_len = len(track3d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :])
|
336 |
+
track3d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["rgb_tracks"][0,:left_len,:queries_len,:]
|
337 |
+
track2d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["traj_est"][0,:left_len,:queries_len,:3]
|
338 |
+
vis_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["vis_est"][0,:left_len,:queries_len,None]
|
339 |
+
conf_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["conf_pred"][0,:left_len,:queries_len,None]
|
340 |
+
dyn_preds[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["dyn_preds"][0,:left_len,:queries_len,None]
|
341 |
+
|
342 |
+
# process the output for each segment
|
343 |
+
seg_c2w = out["poses_pred"][0]
|
344 |
+
seg_intrs = out["intrs"][0]
|
345 |
+
seg_point_map = out["points_map"]
|
346 |
+
seg_conf_depth = out["unc_metric"]
|
347 |
+
|
348 |
+
# cache management
|
349 |
+
cache = out["cache"]
|
350 |
+
for k in cache.keys():
|
351 |
+
if "_pyramid" in k:
|
352 |
+
for j in range(len(cache[k])):
|
353 |
+
if len(cache[k][j].shape) == 5:
|
354 |
+
cache[k][j] = cache[k][j][:,:,:,:queries_len,:]
|
355 |
+
elif len(cache[k][j].shape) == 4:
|
356 |
+
cache[k][j] = cache[k][j][:,:1,:queries_len,:]
|
357 |
+
elif "_pred_cache" in k:
|
358 |
+
cache[k] = cache[k][-overlap_len:,:queries_len,:]
|
359 |
+
else:
|
360 |
+
cache[k] = cache[k][-overlap_len:]
|
361 |
+
|
362 |
+
# update the results
|
363 |
+
idx_glob = i * step_slide
|
364 |
+
# refine part
|
365 |
+
# mask_update = sort_query[..., 0] < i * step_slide + window_len
|
366 |
+
# sort_query_pick = sort_query[mask_update]
|
367 |
+
intrs_out[idx_glob:idx_glob+window_len] = seg_intrs
|
368 |
+
point_map[idx_glob:idx_glob+window_len] = seg_point_map
|
369 |
+
unc_metric[idx_glob:idx_glob+window_len] = seg_conf_depth
|
370 |
+
# update the camera poses
|
371 |
+
|
372 |
+
# if using the ground truth pose
|
373 |
+
# if extrs_unf is not None:
|
374 |
+
# c2w_traj[idx_glob:idx_glob+window_len] = extrs_unf[i:i+1][0].to(c2w_traj.device).to(c2w_traj.dtype)
|
375 |
+
# else:
|
376 |
+
prev_c2w = c2w_traj[idx_glob:idx_glob+window_len][:1]
|
377 |
+
c2w_traj[idx_glob:idx_glob+window_len] = prev_c2w@seg_c2w.to(c2w_traj.device).to(c2w_traj.dtype)
|
378 |
+
|
379 |
+
track2d_pred = track2d_pred[:T_org,sorted_inv_indices,:]
|
380 |
+
track3d_pred = track3d_pred[:T_org,sorted_inv_indices,:]
|
381 |
+
vis_pred = vis_pred[:T_org,sorted_inv_indices,:]
|
382 |
+
conf_pred = conf_pred[:T_org,sorted_inv_indices,:]
|
383 |
+
dyn_preds = dyn_preds[:T_org,sorted_inv_indices,:]
|
384 |
+
unc_metric = unc_metric[:T_org,:]
|
385 |
+
point_map = point_map[:T_org,:]
|
386 |
+
intrs_out = intrs_out[:T_org,:]
|
387 |
+
c2w_traj = c2w_traj[:T_org,:]
|
388 |
+
if self.training:
|
389 |
+
ret = {
|
390 |
+
"loss": loss,
|
391 |
+
"depth_loss": 0.0,
|
392 |
+
"ab_loss": 0.0,
|
393 |
+
"vis_loss": out["vis_loss"],
|
394 |
+
"track_loss": out["track_loss"],
|
395 |
+
"conf_loss": out["conf_loss"],
|
396 |
+
"dyn_loss": out["dyn_loss"],
|
397 |
+
"sync_loss": out["sync_loss"],
|
398 |
+
"poses_pred": c2w_traj[None],
|
399 |
+
"intrs": intrs_out[None],
|
400 |
+
"points_map": point_map,
|
401 |
+
"track3d_pred": track3d_pred[None],
|
402 |
+
"rgb_tracks": track3d_pred[None],
|
403 |
+
"track2d_pred": track2d_pred[None],
|
404 |
+
"traj_est": track2d_pred[None],
|
405 |
+
"vis_est": vis_pred[None], "conf_pred": conf_pred[None],
|
406 |
+
"dyn_preds": dyn_preds[None],
|
407 |
+
"imgs_raw": video[None],
|
408 |
+
"unc_metric": unc_metric,
|
409 |
+
}
|
410 |
+
|
411 |
+
return ret
|
412 |
+
else:
|
413 |
+
return c2w_traj, intrs_out, point_map, unc_metric, track3d_pred, track2d_pred, vis_pred, conf_pred
|
414 |
+
def forward(self,
|
415 |
+
x: torch.Tensor,
|
416 |
+
annots: Dict = {},
|
417 |
+
pts_q: torch.Tensor = None,
|
418 |
+
pts_q_3d: torch.Tensor = None,
|
419 |
+
overlap_d: torch.Tensor = None,
|
420 |
+
full_point = False,
|
421 |
+
fixed_cam = False,
|
422 |
+
support_frame = 0,
|
423 |
+
query_no_BA = False,
|
424 |
+
cache = None,
|
425 |
+
replace_ratio = 0.6,
|
426 |
+
iters_track=4,
|
427 |
+
**kwargs):
|
428 |
+
"""
|
429 |
+
forward the video camera model, which predicts (
|
430 |
+
`intr` `camera poses` `video depth`
|
431 |
+
)
|
432 |
+
|
433 |
+
args:
|
434 |
+
x: the input video frames. [B, T, C, H, W]
|
435 |
+
annots: the annotations for video frames.
|
436 |
+
{
|
437 |
+
"poses_gt": the pose encoding for the video frames. [B, T, 7]
|
438 |
+
"depth_gt": the ground truth depth for the video frames. [B, T, 1, H, W],
|
439 |
+
"metric": bool, whether to calculate the metric for the video frames.
|
440 |
+
}
|
441 |
+
"""
|
442 |
+
self.support_frame = support_frame
|
443 |
+
|
444 |
+
#TODO: to adjust a little bit
|
445 |
+
track_loss = ab_loss = vis_loss = conf_loss = dyn_loss = sync_loss = point_map_loss = scale_loss = shift_loss = 0.0
|
446 |
+
B, T, _, H, W = x.shape
|
447 |
+
imgs_raw = x.clone()
|
448 |
+
# get the video split and features for each segment
|
449 |
+
patch_size, x_resize = self.ProcVid(x)
|
450 |
+
x_resize = rearrange(x_resize, "(b t) c h w -> b t c h w", b=B)
|
451 |
+
H_resize, W_resize = x_resize.shape[-2:]
|
452 |
+
|
453 |
+
prec_fx = W / W_resize
|
454 |
+
prec_fy = H / H_resize
|
455 |
+
# get patch size
|
456 |
+
P_H, P_W = patch_size
|
457 |
+
|
458 |
+
# get the depth, pointmap and mask
|
459 |
+
#TODO: Release DepthAnything Version
|
460 |
+
points_map_gt = None
|
461 |
+
with torch.no_grad():
|
462 |
+
if_gt_depth = (("depth_gt" in annots.keys())) and (kwargs.get('stage', 0)==1 or kwargs.get('stage', 0)==3)
|
463 |
+
if if_gt_depth==False:
|
464 |
+
if cache is not None:
|
465 |
+
T_cache = cache["points_map"].shape[0]
|
466 |
+
T_new = T - T_cache
|
467 |
+
x_resize_new = x_resize[:, T_cache:]
|
468 |
+
else:
|
469 |
+
T_new = T
|
470 |
+
x_resize_new = x_resize
|
471 |
+
# infer with chunk
|
472 |
+
chunk_size = self.chunk_size
|
473 |
+
metric_depth = []
|
474 |
+
intrs = []
|
475 |
+
unc_metric = []
|
476 |
+
mask = []
|
477 |
+
points_map = []
|
478 |
+
normals = []
|
479 |
+
normals_mask = []
|
480 |
+
for i in range(0, B*T_new, chunk_size):
|
481 |
+
output = self.base_model.infer(x_resize_new.view(B*T_new, -1, H_resize, W_resize)[i:i+chunk_size])
|
482 |
+
metric_depth.append(output['depth'])
|
483 |
+
intrs.append(output['intrinsics'])
|
484 |
+
unc_metric.append(output['mask_prob'])
|
485 |
+
mask.append(output['mask'])
|
486 |
+
points_map.append(output['points'])
|
487 |
+
normals_i, normals_mask_i = utils3d.torch.points_to_normals(output['points'], mask=output['mask'])
|
488 |
+
normals.append(normals_i)
|
489 |
+
normals_mask.append(normals_mask_i)
|
490 |
+
|
491 |
+
metric_depth = torch.cat(metric_depth, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
|
492 |
+
intrs = torch.cat(intrs, dim=0).view(B, T_new, 3, 3).to(x.dtype)
|
493 |
+
intrs[:,:,0,:] *= W_resize
|
494 |
+
intrs[:,:,1,:] *= H_resize
|
495 |
+
# points_map = torch.cat(points_map, dim=0)
|
496 |
+
mask = torch.cat(mask, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
|
497 |
+
# cat the normals
|
498 |
+
normals = torch.cat(normals, dim=0)
|
499 |
+
normals_mask = torch.cat(normals_mask, dim=0)
|
500 |
+
|
501 |
+
metric_depth = metric_depth.clone()
|
502 |
+
metric_depth[metric_depth == torch.inf] = 0
|
503 |
+
_depths = metric_depth[metric_depth > 0].reshape(-1)
|
504 |
+
q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values
|
505 |
+
q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values
|
506 |
+
iqr = q75 - q25
|
507 |
+
upper_bound = (q75 + 0.8*iqr).clamp(min=1e-6, max=10*q25)
|
508 |
+
_depth_roi = torch.tensor(
|
509 |
+
[1e-1, upper_bound.item()],
|
510 |
+
dtype=metric_depth.dtype,
|
511 |
+
device=metric_depth.device
|
512 |
+
)
|
513 |
+
mask_roi = (metric_depth > _depth_roi[0]) & (metric_depth < _depth_roi[1])
|
514 |
+
mask = mask * mask_roi
|
515 |
+
mask = mask * (~(utils3d.torch.depth_edge(metric_depth, rtol=0.03, mask=mask.bool()))) * normals_mask[:,None,...]
|
516 |
+
points_map = depth_to_points_colmap(metric_depth.squeeze(1), intrs.view(B*T_new, 3, 3))
|
517 |
+
unc_metric = torch.cat(unc_metric, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
|
518 |
+
unc_metric *= mask
|
519 |
+
if full_point:
|
520 |
+
unc_metric = (~(utils3d.torch.depth_edge(metric_depth, rtol=0.1, mask=torch.ones_like(metric_depth).bool()))).float() * (metric_depth != 0)
|
521 |
+
if cache is not None:
|
522 |
+
assert B==1, "only support batch size 1 right now."
|
523 |
+
unc_metric = torch.cat([cache["unc_metric"], unc_metric], dim=0)
|
524 |
+
intrs = torch.cat([cache["intrs"][None], intrs], dim=1)
|
525 |
+
points_map = torch.cat([cache["points_map"].permute(0,2,3,1), points_map], dim=0)
|
526 |
+
metric_depth = torch.cat([cache["metric_depth"], metric_depth], dim=0)
|
527 |
+
|
528 |
+
if "poses_gt" in annots.keys():
|
529 |
+
intrs, c2w_traj_gt = pose_enc2mat(annots["poses_gt"],
|
530 |
+
H_resize, W_resize, self.resolution)
|
531 |
+
else:
|
532 |
+
c2w_traj_gt = None
|
533 |
+
|
534 |
+
if "intrs_gt" in annots.keys():
|
535 |
+
intrs = annots["intrs_gt"].view(B, T, 3, 3)
|
536 |
+
fx_factor = W_resize / W
|
537 |
+
fy_factor = H_resize / H
|
538 |
+
intrs[:,:,0,:] *= fx_factor
|
539 |
+
intrs[:,:,1,:] *= fy_factor
|
540 |
+
|
541 |
+
if "depth_gt" in annots.keys():
|
542 |
+
|
543 |
+
metric_depth_gt = annots['depth_gt'].view(B*T, 1, H, W)
|
544 |
+
metric_depth_gt = F.interpolate(metric_depth_gt,
|
545 |
+
size=(H_resize, W_resize), mode='nearest')
|
546 |
+
|
547 |
+
_depths = metric_depth_gt[metric_depth_gt > 0].reshape(-1)
|
548 |
+
q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values
|
549 |
+
q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values
|
550 |
+
iqr = q75 - q25
|
551 |
+
upper_bound = (q75 + 0.8*iqr).clamp(min=1e-6, max=10*q25)
|
552 |
+
_depth_roi = torch.tensor(
|
553 |
+
[1e-1, upper_bound.item()],
|
554 |
+
dtype=metric_depth_gt.dtype,
|
555 |
+
device=metric_depth_gt.device
|
556 |
+
)
|
557 |
+
mask_roi = (metric_depth_gt > _depth_roi[0]) & (metric_depth_gt < _depth_roi[1])
|
558 |
+
# if (upper_bound > 200).any():
|
559 |
+
# import pdb; pdb.set_trace()
|
560 |
+
if (kwargs.get('stage', 0) == 2):
|
561 |
+
unc_metric = ((metric_depth_gt > 0)*(mask_roi) * (unc_metric > 0.5)).float()
|
562 |
+
metric_depth_gt[metric_depth_gt > 10*q25] = 0
|
563 |
+
else:
|
564 |
+
unc_metric = ((metric_depth_gt > 0)*(mask_roi)).float()
|
565 |
+
unc_metric *= (~(utils3d.torch.depth_edge(metric_depth_gt, rtol=0.03, mask=mask_roi.bool()))).float()
|
566 |
+
# filter the sky
|
567 |
+
metric_depth_gt[metric_depth_gt > 10*q25] = 0
|
568 |
+
if "unc_metric" in annots.keys():
|
569 |
+
unc_metric_ = F.interpolate(annots["unc_metric"].permute(1,0,2,3),
|
570 |
+
size=(H_resize, W_resize), mode='nearest')
|
571 |
+
unc_metric = unc_metric * unc_metric_
|
572 |
+
if if_gt_depth:
|
573 |
+
points_map = depth_to_points_colmap(metric_depth_gt.squeeze(1), intrs.view(B*T, 3, 3))
|
574 |
+
metric_depth = metric_depth_gt
|
575 |
+
points_map_gt = points_map
|
576 |
+
else:
|
577 |
+
points_map_gt = depth_to_points_colmap(metric_depth_gt.squeeze(1), intrs.view(B*T, 3, 3))
|
578 |
+
|
579 |
+
# track the 3d points
|
580 |
+
ret_track = None
|
581 |
+
regular_track = True
|
582 |
+
dyn_preds, final_tracks = None, None
|
583 |
+
|
584 |
+
if "use_extr" in annots.keys():
|
585 |
+
init_pose = True
|
586 |
+
else:
|
587 |
+
init_pose = False
|
588 |
+
# set the custom vid and valid only
|
589 |
+
custom_vid = annots.get("custom_vid", False)
|
590 |
+
valid_only = annots.get("data_dir", [""])[0] == "stereo4d"
|
591 |
+
if self.training:
|
592 |
+
if (annots["vis"].sum() > 0) and (kwargs.get('stage', 0)==1 or kwargs.get('stage', 0)==3):
|
593 |
+
traj3d = annots['traj_3d']
|
594 |
+
if (kwargs.get('stage', 0)==1) and (annots.get("custom_vid", False)==False):
|
595 |
+
support_pts_q = get_track_points(H_resize, W_resize,
|
596 |
+
T, x.device, query_size=self.track_num // 2,
|
597 |
+
support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental")[None]
|
598 |
+
else:
|
599 |
+
support_pts_q = get_track_points(H_resize, W_resize,
|
600 |
+
T, x.device, query_size=random.randint(32, 256),
|
601 |
+
support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental")[None]
|
602 |
+
if pts_q is not None:
|
603 |
+
pts_q = pts_q[None,None]
|
604 |
+
ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
|
605 |
+
metric_depth,
|
606 |
+
unc_metric.detach(), points_map, pts_q,
|
607 |
+
intrs=intrs.clone(), cache=cache,
|
608 |
+
prec_fx=prec_fx, prec_fy=prec_fy, overlap_d=overlap_d,
|
609 |
+
vis_gt=annots['vis'], traj3d_gt=traj3d, iters=iters_track,
|
610 |
+
cam_gt=c2w_traj_gt, support_pts_q=support_pts_q, custom_vid=custom_vid,
|
611 |
+
init_pose=init_pose, fixed_cam=fixed_cam, stage=kwargs.get('stage', 0),
|
612 |
+
points_map_gt=points_map_gt, valid_only=valid_only, replace_ratio=replace_ratio)
|
613 |
+
else:
|
614 |
+
ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
|
615 |
+
metric_depth,
|
616 |
+
unc_metric.detach(), points_map, traj3d[..., :2],
|
617 |
+
intrs=intrs.clone(), cache=cache,
|
618 |
+
prec_fx=prec_fx, prec_fy=prec_fy, overlap_d=overlap_d,
|
619 |
+
vis_gt=annots['vis'], traj3d_gt=traj3d, iters=iters_track,
|
620 |
+
cam_gt=c2w_traj_gt, support_pts_q=support_pts_q, custom_vid=custom_vid,
|
621 |
+
init_pose=init_pose, fixed_cam=fixed_cam, stage=kwargs.get('stage', 0),
|
622 |
+
points_map_gt=points_map_gt, valid_only=valid_only, replace_ratio=replace_ratio)
|
623 |
+
regular_track = False
|
624 |
+
|
625 |
+
|
626 |
+
if regular_track:
|
627 |
+
if pts_q is None:
|
628 |
+
pts_q = get_track_points(H_resize, W_resize,
|
629 |
+
T, x.device, query_size=self.track_num,
|
630 |
+
support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental")[None]
|
631 |
+
support_pts_q = None
|
632 |
+
else:
|
633 |
+
pts_q = pts_q[None,None]
|
634 |
+
# resize the query points
|
635 |
+
pts_q[...,1] *= W_resize / W
|
636 |
+
pts_q[...,2] *= H_resize / H
|
637 |
+
|
638 |
+
if pts_q_3d is not None:
|
639 |
+
pts_q_3d = pts_q_3d[None,None]
|
640 |
+
# resize the query points
|
641 |
+
pts_q_3d[...,1] *= W_resize / W
|
642 |
+
pts_q_3d[...,2] *= H_resize / H
|
643 |
+
else:
|
644 |
+
# adjust the query with uncertainty
|
645 |
+
if (full_point==False) and (overlap_d is None):
|
646 |
+
pts_q_unc = sample_features5d(unc_metric[None], pts_q).squeeze()
|
647 |
+
pts_q = pts_q[:,:,pts_q_unc>0.5,:]
|
648 |
+
if (pts_q_unc<0.5).sum() > 0:
|
649 |
+
# pad the query points
|
650 |
+
pad_num = pts_q_unc.shape[0] - pts_q.shape[2]
|
651 |
+
# pick the random indices
|
652 |
+
indices = torch.randint(0, pts_q.shape[2], (pad_num,), device=pts_q.device)
|
653 |
+
pad_pts = indices
|
654 |
+
pts_q = torch.cat([pts_q, pts_q[:,:,pad_pts,:]], dim=-2)
|
655 |
+
|
656 |
+
support_pts_q = get_track_points(H_resize, W_resize,
|
657 |
+
T, x.device, query_size=self.track_num,
|
658 |
+
support_frame=self.support_frame,
|
659 |
+
unc_metric=unc_metric, mode="mixed")[None]
|
660 |
+
|
661 |
+
points_map[points_map>1e3] = 0
|
662 |
+
points_map = depth_to_points_colmap(metric_depth.squeeze(1), intrs.view(B*T, 3, 3))
|
663 |
+
ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
|
664 |
+
metric_depth,
|
665 |
+
unc_metric.detach(), points_map, pts_q,
|
666 |
+
pts_q_3d=pts_q_3d, intrs=intrs.clone(),cache=cache,
|
667 |
+
overlap_d=overlap_d, cam_gt=c2w_traj_gt if kwargs.get('stage', 0)==1 else None,
|
668 |
+
prec_fx=prec_fx, prec_fy=prec_fy, support_pts_q=support_pts_q, custom_vid=custom_vid, valid_only=valid_only,
|
669 |
+
fixed_cam=fixed_cam, query_no_BA=query_no_BA, init_pose=init_pose, iters=iters_track,
|
670 |
+
stage=kwargs.get('stage', 0), points_map_gt=points_map_gt, replace_ratio=replace_ratio)
|
671 |
+
intrs = intrs_org
|
672 |
+
points_map = point_map_org_refined
|
673 |
+
c2w_traj = ret_track["cam_pred"]
|
674 |
+
|
675 |
+
if ret_track is not None:
|
676 |
+
if ret_track["loss"] is not None:
|
677 |
+
track_loss, conf_loss, dyn_loss, vis_loss, point_map_loss, scale_loss, shift_loss, sync_loss= ret_track["loss"]
|
678 |
+
|
679 |
+
# update the cache
|
680 |
+
cache.update({"metric_depth": metric_depth, "unc_metric": unc_metric, "points_map": points_map, "intrs": intrs[0]})
|
681 |
+
# output
|
682 |
+
depth = F.interpolate(metric_depth,
|
683 |
+
size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
|
684 |
+
points_map = F.interpolate(points_map,
|
685 |
+
size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
|
686 |
+
unc_metric = F.interpolate(unc_metric,
|
687 |
+
size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
|
688 |
+
|
689 |
+
if self.training:
|
690 |
+
|
691 |
+
loss = track_loss + conf_loss + dyn_loss + sync_loss + vis_loss + point_map_loss + (scale_loss + shift_loss)*50
|
692 |
+
ret = {"loss": loss,
|
693 |
+
"depth_loss": point_map_loss,
|
694 |
+
"ab_loss": (scale_loss + shift_loss)*50,
|
695 |
+
"vis_loss": vis_loss, "track_loss": track_loss,
|
696 |
+
"poses_pred": c2w_traj, "dyn_preds": dyn_preds, "traj_est": final_tracks, "conf_loss": conf_loss,
|
697 |
+
"imgs_raw": imgs_raw, "rgb_tracks": rgb_tracks, "vis_est": ret_track['vis_pred'],
|
698 |
+
"depth": depth, "points_map": points_map, "unc_metric": unc_metric, "intrs": intrs, "dyn_loss": dyn_loss,
|
699 |
+
"sync_loss": sync_loss, "conf_pred": ret_track['conf_pred'], "cache": cache,
|
700 |
+
}
|
701 |
+
|
702 |
+
else:
|
703 |
+
|
704 |
+
if ret_track is not None:
|
705 |
+
traj_est = ret_track['preds']
|
706 |
+
traj_est[..., 0] *= W / W_resize
|
707 |
+
traj_est[..., 1] *= H / H_resize
|
708 |
+
vis_est = ret_track['vis_pred']
|
709 |
+
else:
|
710 |
+
traj_est = torch.zeros(B, self.track_num // 2, 3).to(x.device)
|
711 |
+
vis_est = torch.zeros(B, self.track_num // 2).to(x.device)
|
712 |
+
|
713 |
+
if intrs is not None:
|
714 |
+
intrs[..., 0, :] *= W / W_resize
|
715 |
+
intrs[..., 1, :] *= H / H_resize
|
716 |
+
ret = {"poses_pred": c2w_traj, "dyn_preds": dyn_preds,
|
717 |
+
"depth": depth, "traj_est": traj_est, "vis_est": vis_est, "imgs_raw": imgs_raw,
|
718 |
+
"rgb_tracks": rgb_tracks, "intrs": intrs, "unc_metric": unc_metric, "points_map": points_map,
|
719 |
+
"conf_pred": ret_track['conf_pred'], "cache": cache,
|
720 |
+
}
|
721 |
+
|
722 |
+
return ret
|
723 |
+
|
724 |
+
|
725 |
+
|
726 |
+
|
727 |
+
# three stages of training
|
728 |
+
|
729 |
+
# stage 1:
|
730 |
+
# gt depth and intrinsics, synthetic (includes Dynamic Replica, Kubric, Pointodyssey, Vkitti, TartanAir and Indoor()) Motion Pattern (tapvid3d)
|
731 |
+
# Tracking and Pose as well -> based on gt depth and intrinsics
|
732 |
+
# (Finished) -> (megasam + base model) vs. tapip3d. (use depth from megasam or pose, which keep the same setting as tapip3d.)
|
733 |
+
|
734 |
+
# stage 2: fixed 3D tracking
|
735 |
+
# Joint depth refiner
|
736 |
+
# input depth from whatever + rgb -> temporal module + scale and shift token -> coarse alignment -> scale and shift
|
737 |
+
# estimate the 3D tracks -> 3D tracks combine with pointmap -> update for pointmap (iteratively) -> residual map B T 3 H W
|
738 |
+
# ongoing two days
|
739 |
+
|
740 |
+
# stage 3: train multi windows by propagation
|
741 |
+
# 4 frames overlapped -> train on 64 -> frozen image encoder and finetune the transformer (learnable parameters pretty small)
|
742 |
+
|
743 |
+
# types of scenarios:
|
744 |
+
# 1. auto driving (waymo open dataset)
|
745 |
+
# 2. robot
|
746 |
+
# 3. internet ego video
|
747 |
+
|
748 |
+
|
749 |
+
|
750 |
+
# Iterative Transformer -- Solver -- General Neural MegaSAM + Tracks
|
751 |
+
# Update Variables:
|
752 |
+
# 1. 3D tracks B T N 3 xyz.
|
753 |
+
# 2. 2D tracks B T N 2 x y.
|
754 |
+
# 3. Dynamic Mask B T H W.
|
755 |
+
# 4. Camera Pose B T 4 4.
|
756 |
+
# 5. Video Depth.
|
757 |
+
|
758 |
+
# (RGB, RGBD, RGBD+Pose) x (Static, Dynamic)
|
759 |
+
# Compatibility by product.
|
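The loop above stitches per-window predictions into full-length tracks by re-seeding each window with the points that are already being tracked: for every such point it picks the most confident visible frame inside the overlap with the previous window and feeds that (frame, x, y) triple back in as a query. Below is a minimal sketch of that bookkeeping, assuming illustrative values for window_len / overlap_len and dummy tensors in place of the model outputs; the reseed_queries helper is hypothetical and only mirrors the gather logic of the loop, it is not code from this repository.

import torch

# Illustrative only: mirrors the window/overlap bookkeeping of the sliding-window loop above.
T, N = 64, 8                      # total frames, tracked points
window_len, overlap_len = 16, 4   # assumed values; the real ones come from the model config
step_slide = window_len - overlap_len

# window i covers frames [i * step_slide, i * step_slide + window_len)
starts = list(range(0, max(T - overlap_len, 1), step_slide))

# dummy per-frame predictions for the points already being tracked
track2d = torch.rand(T, N, 3)     # (x, y, depth) per frame, as in track2d_pred
vis = torch.rand(T, N, 1)         # visibility scores
conf = torch.rand(T, N, 1)        # confidence scores

def reseed_queries(start: int) -> torch.Tensor:
    """Pick, for each point, its best frame inside the overlap with the previous
    window and return (t_local, x, y) queries for the next window."""
    ov = slice(start, start + overlap_len)
    score = vis[ov] * conf[ov]                       # [overlap_len, N, 1]
    best = score.max(dim=0).indices                  # [N, 1] frame index inside the overlap
    xy = torch.gather(track2d[ov, :, :2], 0, best[None].repeat(1, 1, 2))[0]
    return torch.cat([best.float(), xy], dim=-1)     # [N, 3] -> (t_local, x, y)

for s in starts[1:]:
    q = reseed_queries(s)
    # In the real loop, q is concatenated with the brand-new queries whose timestamps
    # fall inside this window before forward() is called on the segment.
    assert q.shape == (N, 3)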
models/SpaTrackV2/models/__init__.py
ADDED
File without changes
|
models/SpaTrackV2/models/blocks.py
ADDED
@@ -0,0 +1,519 @@
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from torch.cuda.amp import autocast
|
11 |
+
from einops import rearrange
|
12 |
+
import collections
|
13 |
+
from functools import partial
|
14 |
+
from itertools import repeat
|
15 |
+
import torchvision.models as tvm
|
16 |
+
from torch.utils.checkpoint import checkpoint
|
17 |
+
from models.monoD.depth_anything.dpt import DPTHeadEnc, DPTHead
|
18 |
+
from typing import Union, Tuple
|
19 |
+
from torch import Tensor
# NOTE: the xLSTM block-stack configs used by EUpdateFormer below are not imported anywhere
# in the original file; assuming they come from the `xlstm` package's public API.
from xlstm import xLSTMBlockStack, xLSTMBlockStackConfig, mLSTMBlockConfig, mLSTMLayerConfig, sLSTMBlockConfig, sLSTMLayerConfig, FeedForwardConfig
|
20 |
+
|
21 |
+
# From PyTorch internals
|
22 |
+
def _ntuple(n):
|
23 |
+
def parse(x):
|
24 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
25 |
+
return tuple(x)
|
26 |
+
return tuple(repeat(x, n))
|
27 |
+
|
28 |
+
return parse
|
29 |
+
|
30 |
+
|
31 |
+
def exists(val):
|
32 |
+
return val is not None
|
33 |
+
|
34 |
+
|
35 |
+
def default(val, d):
|
36 |
+
return val if exists(val) else d
|
37 |
+
|
38 |
+
|
39 |
+
to_2tuple = _ntuple(2)
|
40 |
+
|
41 |
+
class LayerScale(nn.Module):
|
42 |
+
def __init__(
|
43 |
+
self,
|
44 |
+
dim: int,
|
45 |
+
init_values: Union[float, Tensor] = 1e-5,
|
46 |
+
inplace: bool = False,
|
47 |
+
) -> None:
|
48 |
+
super().__init__()
|
49 |
+
self.inplace = inplace
|
50 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
51 |
+
|
52 |
+
def forward(self, x: Tensor) -> Tensor:
|
53 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
54 |
+
|
55 |
+
class Mlp(nn.Module):
|
56 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
57 |
+
|
58 |
+
def __init__(
|
59 |
+
self,
|
60 |
+
in_features,
|
61 |
+
hidden_features=None,
|
62 |
+
out_features=None,
|
63 |
+
act_layer=nn.GELU,
|
64 |
+
norm_layer=None,
|
65 |
+
bias=True,
|
66 |
+
drop=0.0,
|
67 |
+
use_conv=False,
|
68 |
+
):
|
69 |
+
super().__init__()
|
70 |
+
out_features = out_features or in_features
|
71 |
+
hidden_features = hidden_features or in_features
|
72 |
+
bias = to_2tuple(bias)
|
73 |
+
drop_probs = to_2tuple(drop)
|
74 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
75 |
+
|
76 |
+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
77 |
+
self.act = act_layer()
|
78 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
79 |
+
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
80 |
+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
81 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
82 |
+
|
83 |
+
def forward(self, x):
|
84 |
+
x = self.fc1(x)
|
85 |
+
x = self.act(x)
|
86 |
+
x = self.drop1(x)
|
87 |
+
x = self.fc2(x)
|
88 |
+
x = self.drop2(x)
|
89 |
+
return x
|
90 |
+
|
91 |
+
class Attention(nn.Module):
|
92 |
+
def __init__(self, query_dim, context_dim=None,
|
93 |
+
num_heads=8, dim_head=48, qkv_bias=False, flash=False):
|
94 |
+
super().__init__()
|
95 |
+
inner_dim = self.inner_dim = dim_head * num_heads
|
96 |
+
context_dim = default(context_dim, query_dim)
|
97 |
+
self.scale = dim_head**-0.5
|
98 |
+
self.heads = num_heads
|
99 |
+
self.flash = flash
|
100 |
+
|
101 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
|
102 |
+
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
|
103 |
+
self.to_out = nn.Linear(inner_dim, query_dim)
|
104 |
+
|
105 |
+
def forward(self, x, context=None, attn_bias=None):
|
106 |
+
B, N1, _ = x.shape
|
107 |
+
C = self.inner_dim
|
108 |
+
h = self.heads
|
109 |
+
q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
|
110 |
+
context = default(context, x)
|
111 |
+
k, v = self.to_kv(context).chunk(2, dim=-1)
|
112 |
+
|
113 |
+
N2 = context.shape[1]
|
114 |
+
k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
|
115 |
+
v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
|
116 |
+
|
117 |
+
with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
|
118 |
+
if self.flash==False:
|
119 |
+
sim = (q @ k.transpose(-2, -1)) * self.scale
|
120 |
+
if attn_bias is not None:
|
121 |
+
sim = sim + attn_bias
|
122 |
+
if sim.abs().max()>1e2:
|
123 |
+
import pdb; pdb.set_trace()
|
124 |
+
attn = sim.softmax(dim=-1)
|
125 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
|
126 |
+
else:
|
127 |
+
input_args = [x.contiguous() for x in [q, k, v]]
|
128 |
+
x = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).reshape(B,N1,-1) # type: ignore
|
129 |
+
|
130 |
+
if self.to_out.bias.dtype != x.dtype:
|
131 |
+
x = x.to(self.to_out.bias.dtype)
|
132 |
+
|
133 |
+
return self.to_out(x)
|
134 |
+
|
135 |
+
|
136 |
+
class VGG19(nn.Module):
|
137 |
+
def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
|
138 |
+
super().__init__()
|
139 |
+
self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
|
140 |
+
self.amp = amp
|
141 |
+
self.amp_dtype = amp_dtype
|
142 |
+
|
143 |
+
def forward(self, x, **kwargs):
|
144 |
+
with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
|
145 |
+
feats = {}
|
146 |
+
scale = 1
|
147 |
+
for layer in self.layers:
|
148 |
+
if isinstance(layer, nn.MaxPool2d):
|
149 |
+
feats[scale] = x
|
150 |
+
scale = scale*2
|
151 |
+
x = layer(x)
|
152 |
+
return feats
|
153 |
+
|
154 |
+
class CNNandDinov2(nn.Module):
|
155 |
+
def __init__(self, cnn_kwargs = None, amp = True, amp_dtype = torch.float16):
|
156 |
+
super().__init__()
|
157 |
+
# in case the Internet connection is not stable, please load the DINOv2 locally
|
158 |
+
self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
|
159 |
+
'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)
|
160 |
+
|
161 |
+
state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
|
162 |
+
self.dinov2_vitl14.load_state_dict(state_dict, strict=True)
|
163 |
+
|
164 |
+
|
165 |
+
cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
|
166 |
+
self.cnn = VGG19(**cnn_kwargs)
|
167 |
+
self.amp = amp
|
168 |
+
self.amp_dtype = amp_dtype
|
169 |
+
if self.amp:
|
170 |
+
dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
|
171 |
+
self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
|
172 |
+
|
173 |
+
|
174 |
+
def train(self, mode: bool = True):
|
175 |
+
return self.cnn.train(mode)
|
176 |
+
|
177 |
+
def forward(self, x, upsample = False):
|
178 |
+
B,C,H,W = x.shape
|
179 |
+
feature_pyramid = self.cnn(x)
|
180 |
+
|
181 |
+
if not upsample:
|
182 |
+
with torch.no_grad():
|
183 |
+
if self.dinov2_vitl14[0].device != x.device:
|
184 |
+
self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
|
185 |
+
dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
|
186 |
+
features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
|
187 |
+
del dinov2_features_16
|
188 |
+
feature_pyramid[16] = features_16
|
189 |
+
return feature_pyramid
|
190 |
+
|
191 |
+
class Dinov2(nn.Module):
|
192 |
+
def __init__(self, amp = True, amp_dtype = torch.float16):
|
193 |
+
super().__init__()
|
194 |
+
# in case the Internet connection is not stable, please load the DINOv2 locally
|
195 |
+
self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
|
196 |
+
'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)
|
197 |
+
|
198 |
+
state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
|
199 |
+
self.dinov2_vitl14.load_state_dict(state_dict, strict=True)
|
200 |
+
|
201 |
+
self.amp = amp
|
202 |
+
self.amp_dtype = amp_dtype
|
203 |
+
if self.amp:
|
204 |
+
self.dinov2_vitl14 = self.dinov2_vitl14.to(self.amp_dtype)
|
205 |
+
|
206 |
+
def forward(self, x, upsample = False):
|
207 |
+
B,C,H,W = x.shape
|
208 |
+
mean_ = torch.tensor([0.485, 0.456, 0.406],
|
209 |
+
device=x.device).view(1, 3, 1, 1)
|
210 |
+
std_ = torch.tensor([0.229, 0.224, 0.225],
|
211 |
+
device=x.device).view(1, 3, 1, 1)
|
212 |
+
x = (x+1)/2
|
213 |
+
x = (x - mean_)/std_
|
214 |
+
h_re, w_re = 560, 560
|
215 |
+
x_resize = F.interpolate(x, size=(h_re, w_re),
|
216 |
+
mode='bilinear', align_corners=True)
|
217 |
+
if not upsample:
|
218 |
+
with torch.no_grad():
|
219 |
+
dinov2_features_16 = self.dinov2_vitl14.forward_features(x_resize.to(self.amp_dtype))
|
220 |
+
features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,h_re//14, w_re//14)
|
221 |
+
del dinov2_features_16
|
222 |
+
features_16 = F.interpolate(features_16, size=(H//8, W//8), mode="bilinear", align_corners=True)
|
223 |
+
return features_16
|
224 |
+
|
225 |
+
class AttnBlock(nn.Module):
|
226 |
+
"""
|
227 |
+
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
228 |
+
"""
|
229 |
+
|
230 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0,
|
231 |
+
flash=False, ckpt_fwd=False, debug=False, **block_kwargs):
|
232 |
+
super().__init__()
|
233 |
+
self.debug=debug
|
234 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
235 |
+
self.flash=flash
|
236 |
+
|
237 |
+
self.attn = Attention(
|
238 |
+
hidden_size, num_heads=num_heads, qkv_bias=True, flash=flash,
|
239 |
+
**block_kwargs
|
240 |
+
)
|
241 |
+
self.ls = LayerScale(hidden_size, init_values=0.005)
|
242 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
243 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
244 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
245 |
+
self.mlp = Mlp(
|
246 |
+
in_features=hidden_size,
|
247 |
+
hidden_features=mlp_hidden_dim,
|
248 |
+
act_layer=approx_gelu,
|
249 |
+
)
|
250 |
+
self.ckpt_fwd = ckpt_fwd
|
251 |
+
def forward(self, x):
|
252 |
+
if self.debug:
|
253 |
+
print(x.max(), x.min(), x.mean())
|
254 |
+
if self.ckpt_fwd:
|
255 |
+
x = x + checkpoint(self.attn, self.norm1(x), use_reentrant=False)
|
256 |
+
else:
|
257 |
+
x = x + self.attn(self.norm1(x))
|
258 |
+
|
259 |
+
x = x + self.ls(self.mlp(self.norm2(x)))
|
260 |
+
return x
|
261 |
+
|
262 |
+
class CrossAttnBlock(nn.Module):
|
263 |
+
def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, head_dim=48,
|
264 |
+
flash=False, ckpt_fwd=False, **block_kwargs):
|
265 |
+
super().__init__()
|
266 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
267 |
+
self.norm_context = nn.LayerNorm(hidden_size)
|
268 |
+
|
269 |
+
self.cross_attn = Attention(
|
270 |
+
hidden_size, context_dim=context_dim, dim_head=head_dim,
|
271 |
+
num_heads=num_heads, qkv_bias=True, **block_kwargs, flash=flash,
|
272 |
+
)
|
273 |
+
|
274 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
275 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
276 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
277 |
+
self.mlp = Mlp(
|
278 |
+
in_features=hidden_size,
|
279 |
+
hidden_features=mlp_hidden_dim,
|
280 |
+
act_layer=approx_gelu,
|
281 |
+
drop=0,
|
282 |
+
)
|
283 |
+
self.ckpt_fwd = ckpt_fwd
|
284 |
+
def forward(self, x, context):
|
285 |
+
if self.ckpt_fwd:
|
286 |
+
with autocast():
|
287 |
+
x = x + checkpoint(self.cross_attn,
|
288 |
+
self.norm1(x), self.norm_context(context), use_reentrant=False)
|
289 |
+
else:
|
290 |
+
with autocast():
|
291 |
+
x = x + self.cross_attn(
|
292 |
+
self.norm1(x), self.norm_context(context)
|
293 |
+
)
|
294 |
+
x = x + self.mlp(self.norm2(x))
|
295 |
+
return x
|
296 |
+
|
297 |
+
|
298 |
+
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
|
299 |
+
"""Wrapper for grid_sample, uses pixel coordinates"""
|
300 |
+
H, W = img.shape[-2:]
|
301 |
+
xgrid, ygrid = coords.split([1, 1], dim=-1)
|
302 |
+
# go to 0,1 then 0,2 then -1,1
|
303 |
+
xgrid = 2 * xgrid / (W - 1) - 1
|
304 |
+
ygrid = 2 * ygrid / (H - 1) - 1
|
305 |
+
|
306 |
+
grid = torch.cat([xgrid, ygrid], dim=-1)
|
307 |
+
img = F.grid_sample(img, grid, align_corners=True, mode=mode)
|
308 |
+
|
309 |
+
if mask:
|
310 |
+
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
311 |
+
return img, mask.float()
|
312 |
+
|
313 |
+
return img
|
314 |
+
|
315 |
+
|
316 |
+
class CorrBlock:
|
317 |
+
def __init__(self, fmaps, num_levels=4, radius=4, depths_dnG=None):
|
318 |
+
B, S, C, H_prev, W_prev = fmaps.shape
|
319 |
+
self.S, self.C, self.H, self.W = S, C, H_prev, W_prev
|
320 |
+
|
321 |
+
self.num_levels = num_levels
|
322 |
+
self.radius = radius
|
323 |
+
self.fmaps_pyramid = []
|
324 |
+
self.depth_pyramid = []
|
325 |
+
self.fmaps_pyramid.append(fmaps)
|
326 |
+
if depths_dnG is not None:
|
327 |
+
self.depth_pyramid.append(depths_dnG)
|
328 |
+
for i in range(self.num_levels - 1):
|
329 |
+
if depths_dnG is not None:
|
330 |
+
depths_dnG_ = depths_dnG.reshape(B * S, 1, H_prev, W_prev)
|
331 |
+
depths_dnG_ = F.avg_pool2d(depths_dnG_, 2, stride=2)
|
332 |
+
_, _, H, W = depths_dnG_.shape
|
333 |
+
depths_dnG = depths_dnG_.reshape(B, S, 1, H, W)
|
334 |
+
self.depth_pyramid.append(depths_dnG)
|
335 |
+
fmaps_ = fmaps.reshape(B * S, C, H_prev, W_prev)
|
336 |
+
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
|
337 |
+
_, _, H, W = fmaps_.shape
|
338 |
+
fmaps = fmaps_.reshape(B, S, C, H, W)
|
339 |
+
H_prev = H
|
340 |
+
W_prev = W
|
341 |
+
self.fmaps_pyramid.append(fmaps)
|
342 |
+
|
343 |
+
def sample(self, coords):
|
344 |
+
r = self.radius
|
345 |
+
B, S, N, D = coords.shape
|
346 |
+
assert D == 2
|
347 |
+
|
348 |
+
H, W = self.H, self.W
|
349 |
+
out_pyramid = []
|
350 |
+
for i in range(self.num_levels):
|
351 |
+
corrs = self.corrs_pyramid[i] # B, S, N, H, W
|
352 |
+
_, _, _, H, W = corrs.shape
|
353 |
+
|
354 |
+
dx = torch.linspace(-r, r, 2 * r + 1)
|
355 |
+
dy = torch.linspace(-r, r, 2 * r + 1)
|
356 |
+
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
|
357 |
+
coords.device
|
358 |
+
)
|
359 |
+
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
|
360 |
+
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
361 |
+
coords_lvl = centroid_lvl + delta_lvl
|
362 |
+
corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
|
363 |
+
corrs = corrs.view(B, S, N, -1)
|
364 |
+
out_pyramid.append(corrs)
|
365 |
+
|
366 |
+
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
|
367 |
+
return out.contiguous().float()
|
368 |
+
|
369 |
+
def corr(self, targets):
|
370 |
+
B, S, N, C = targets.shape
|
371 |
+
assert C == self.C
|
372 |
+
assert S == self.S
|
373 |
+
|
374 |
+
fmap1 = targets
|
375 |
+
|
376 |
+
self.corrs_pyramid = []
|
377 |
+
for fmaps in self.fmaps_pyramid:
|
378 |
+
_, _, _, H, W = fmaps.shape
|
379 |
+
fmap2s = fmaps.view(B, S, C, H * W)
|
380 |
+
corrs = torch.matmul(fmap1, fmap2s)
|
381 |
+
corrs = corrs.view(B, S, N, H, W)
|
382 |
+
corrs = corrs / torch.sqrt(torch.tensor(C).float())
|
383 |
+
self.corrs_pyramid.append(corrs)
|
384 |
+
|
385 |
+
def corr_sample(self, targets, coords, coords_dp=None):
|
386 |
+
B, S, N, C = targets.shape
|
387 |
+
r = self.radius
|
388 |
+
Dim_c = (2*r+1)**2
|
389 |
+
assert C == self.C
|
390 |
+
assert S == self.S
|
391 |
+
|
392 |
+
out_pyramid = []
|
393 |
+
out_pyramid_dp = []
|
394 |
+
for i in range(self.num_levels):
|
395 |
+
dx = torch.linspace(-r, r, 2 * r + 1)
|
396 |
+
dy = torch.linspace(-r, r, 2 * r + 1)
|
397 |
+
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
|
398 |
+
coords.device
|
399 |
+
)
|
400 |
+
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
|
401 |
+
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
402 |
+
coords_lvl = centroid_lvl + delta_lvl
|
403 |
+
fmaps = self.fmaps_pyramid[i]
|
404 |
+
_, _, _, H, W = fmaps.shape
|
405 |
+
fmap2s = fmaps.view(B*S, C, H, W)
|
406 |
+
if len(self.depth_pyramid)>0:
|
407 |
+
depths_dnG_i = self.depth_pyramid[i]
|
408 |
+
depths_dnG_i = depths_dnG_i.view(B*S, 1, H, W)
|
409 |
+
dnG_sample = bilinear_sampler(depths_dnG_i, coords_lvl.view(B*S,1,N*Dim_c,2))
|
410 |
+
dp_corrs = (dnG_sample.view(B*S,N,-1) - coords_dp[0]).abs()/coords_dp[0]
|
411 |
+
out_pyramid_dp.append(dp_corrs)
|
412 |
+
fmap2s_sample = bilinear_sampler(fmap2s, coords_lvl.view(B*S,1,N*Dim_c,2))
|
413 |
+
fmap2s_sample = fmap2s_sample.permute(0, 3, 1, 2) # B*S, N*Dim_c, C, -1
|
414 |
+
corrs = torch.matmul(targets.reshape(B*S*N, 1, -1), fmap2s_sample.reshape(B*S*N, Dim_c, -1).permute(0, 2, 1))
|
415 |
+
corrs = corrs / torch.sqrt(torch.tensor(C).float())
|
416 |
+
corrs = corrs.view(B, S, N, -1)
|
417 |
+
out_pyramid.append(corrs)
|
418 |
+
|
419 |
+
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
|
420 |
+
if len(self.depth_pyramid)>0:
|
421 |
+
out_dp = torch.cat(out_pyramid_dp, dim=-1)
|
422 |
+
self.fcorrD = out_dp.contiguous().float()
|
423 |
+
else:
|
424 |
+
self.fcorrD = torch.zeros_like(out).contiguous().float()
|
425 |
+
return out.contiguous().float()
|
426 |
+
|
427 |
+
|
428 |
+
class EUpdateFormer(nn.Module):
|
429 |
+
"""
|
430 |
+
Transformer model that updates track estimates.
|
431 |
+
"""
|
432 |
+
|
433 |
+
def __init__(
|
434 |
+
self,
|
435 |
+
space_depth=12,
|
436 |
+
time_depth=12,
|
437 |
+
input_dim=320,
|
438 |
+
hidden_size=384,
|
439 |
+
num_heads=8,
|
440 |
+
output_dim=130,
|
441 |
+
mlp_ratio=4.0,
|
442 |
+
vq_depth=3,
|
443 |
+
add_space_attn=True,
|
444 |
+
add_time_attn=True,
|
445 |
+
flash=True
|
446 |
+
):
|
447 |
+
super().__init__()
|
448 |
+
self.out_channels = 2
|
449 |
+
self.num_heads = num_heads
|
450 |
+
self.hidden_size = hidden_size
|
451 |
+
self.add_space_attn = add_space_attn
|
452 |
+
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
453 |
+
self.flash = flash
|
454 |
+
self.flow_head = nn.Sequential(
|
455 |
+
nn.Linear(hidden_size, output_dim, bias=True),
|
456 |
+
nn.ReLU(inplace=True),
|
457 |
+
nn.Linear(output_dim, output_dim, bias=True),
|
458 |
+
nn.ReLU(inplace=True),
|
459 |
+
nn.Linear(output_dim, output_dim, bias=True)
|
460 |
+
)
|
461 |
+
self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
462 |
+
cfg = xLSTMBlockStackConfig(
|
463 |
+
mlstm_block=mLSTMBlockConfig(
|
464 |
+
mlstm=mLSTMLayerConfig(
|
465 |
+
conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
|
466 |
+
)
|
467 |
+
),
|
468 |
+
slstm_block=sLSTMBlockConfig(
|
469 |
+
slstm=sLSTMLayerConfig(
|
470 |
+
backend="cuda",
|
471 |
+
num_heads=4,
|
472 |
+
conv1d_kernel_size=4,
|
473 |
+
bias_init="powerlaw_blockdependent",
|
474 |
+
),
|
475 |
+
feedforward=FeedForwardConfig(proj_factor=1.3, act_fn="gelu"),
|
476 |
+
),
|
477 |
+
context_length=50,
|
478 |
+
num_blocks=7,
|
479 |
+
embedding_dim=384,
|
480 |
+
slstm_at=[1],
|
481 |
+
|
482 |
+
)
|
483 |
+
self.xlstm_fwd = xLSTMBlockStack(cfg)
|
484 |
+
self.xlstm_bwd = xLSTMBlockStack(cfg)
|
485 |
+
|
486 |
+
self.initialize_weights()
|
487 |
+
|
488 |
+
def initialize_weights(self):
|
489 |
+
def _basic_init(module):
|
490 |
+
if isinstance(module, nn.Linear):
|
491 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
492 |
+
if module.bias is not None:
|
493 |
+
nn.init.constant_(module.bias, 0)
|
494 |
+
|
495 |
+
self.apply(_basic_init)
|
496 |
+
|
497 |
+
def forward(self,
|
498 |
+
input_tensor,
|
499 |
+
track_mask=None):
|
500 |
+
""" Updating with Transformer
|
501 |
+
|
502 |
+
Args:
|
503 |
+
input_tensor: B, N, T, C
|
504 |
+
arap_embed: B, N, T, C
|
505 |
+
"""
|
506 |
+
B, N, T, C = input_tensor.shape
|
507 |
+
x = self.input_transform(input_tensor)
|
508 |
+
|
509 |
+
track_mask = track_mask.permute(0,2,1,3).float()
|
510 |
+
fwd_x = x*track_mask
|
511 |
+
bwd_x = x.flip(2)*track_mask.flip(2)
|
512 |
+
feat_fwd = self.xlstm_fwd(self.norm(fwd_x.view(B*N, T, -1)))
|
513 |
+
feat_bwd = self.xlstm_bwd(self.norm(bwd_x.view(B*N, T, -1)))
|
514 |
+
feat = (feat_bwd.flip(1) + feat_fwd).view(B, N, T, -1)
|
515 |
+
|
516 |
+
flow = self.flow_head(feat)
|
517 |
+
|
518 |
+
return flow[..., :2], flow[..., 2:]
|
519 |
+
|
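CorrBlock above is used in two steps: corr() builds a multi-scale correlation pyramid between per-point target features and the per-frame feature maps, and sample() then gathers a (2*radius+1)^2 correlation neighbourhood around each point at every pyramid level via bilinear_sampler. Below is a minimal usage sketch with random tensors; the import path simply follows this file's location in the repo, the shapes are assumptions chosen only to exercise the API, and the module's own dependencies must be installed for the import to resolve.

import torch
# assumes the repository root is on PYTHONPATH and blocks.py's own imports are available
from models.SpaTrackV2.models.blocks import CorrBlock

B, S, C, H, W = 1, 4, 32, 64, 64   # batch, frames, channels, spatial size (illustrative)
N = 16                              # number of tracked points
num_levels, radius = 4, 3

fmaps = torch.randn(B, S, C, H, W)          # per-frame feature maps
targets = torch.randn(B, S, N, C)           # per-point target features
coords = torch.rand(B, S, N, 2) * torch.tensor([W - 1, H - 1])  # pixel coords (x, y)

corr_block = CorrBlock(fmaps, num_levels=num_levels, radius=radius)
corr_block.corr(targets)                    # build the correlation pyramid
corr_feat = corr_block.sample(coords)       # gather local correlation neighbourhoods

# one (2*radius+1)^2 patch per pyramid level, concatenated along the last dim
assert corr_feat.shape == (B, S, N, num_levels * (2 * radius + 1) ** 2)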
models/SpaTrackV2/models/camera_transform.py
ADDED
@@ -0,0 +1,248 @@
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
# Adapted from https://github.com/amyxlase/relpose-plus-plus
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import numpy as np
|
12 |
+
import math
# NOTE: Rotate and Translate are used below but were never imported in the original file;
# assuming the PyTorch3D transforms API, consistent with the relpose-plus-plus code this adapts.
from pytorch3d.transforms import Rotate, Translate
|
13 |
+
|
14 |
+
|
15 |
+
|
16 |
+
|
17 |
+
def bbox_xyxy_to_xywh(xyxy):
|
18 |
+
wh = xyxy[2:] - xyxy[:2]
|
19 |
+
xywh = np.concatenate([xyxy[:2], wh])
|
20 |
+
return xywh
|
21 |
+
|
22 |
+
|
23 |
+
def adjust_camera_to_bbox_crop_(fl, pp, image_size_wh: torch.Tensor, clamp_bbox_xywh: torch.Tensor):
|
24 |
+
focal_length_px, principal_point_px = _convert_ndc_to_pixels(fl, pp, image_size_wh)
|
25 |
+
|
26 |
+
principal_point_px_cropped = principal_point_px - clamp_bbox_xywh[:2]
|
27 |
+
|
28 |
+
focal_length, principal_point_cropped = _convert_pixels_to_ndc(
|
29 |
+
focal_length_px, principal_point_px_cropped, clamp_bbox_xywh[2:]
|
30 |
+
)
|
31 |
+
|
32 |
+
return focal_length, principal_point_cropped
|
33 |
+
|
34 |
+
|
35 |
+
def adjust_camera_to_image_scale_(fl, pp, original_size_wh: torch.Tensor, new_size_wh: torch.LongTensor):
|
36 |
+
focal_length_px, principal_point_px = _convert_ndc_to_pixels(fl, pp, original_size_wh)
|
37 |
+
|
38 |
+
# now scale and convert from pixels to NDC
|
39 |
+
image_size_wh_output = new_size_wh.float()
|
40 |
+
scale = (image_size_wh_output / original_size_wh).min(dim=-1, keepdim=True).values
|
41 |
+
focal_length_px_scaled = focal_length_px * scale
|
42 |
+
principal_point_px_scaled = principal_point_px * scale
|
43 |
+
|
44 |
+
focal_length_scaled, principal_point_scaled = _convert_pixels_to_ndc(
|
45 |
+
focal_length_px_scaled, principal_point_px_scaled, image_size_wh_output
|
46 |
+
)
|
47 |
+
return focal_length_scaled, principal_point_scaled
|
48 |
+
|
49 |
+
|
50 |
+
def _convert_ndc_to_pixels(focal_length: torch.Tensor, principal_point: torch.Tensor, image_size_wh: torch.Tensor):
|
51 |
+
half_image_size = image_size_wh / 2
|
52 |
+
rescale = half_image_size.min()
|
53 |
+
principal_point_px = half_image_size - principal_point * rescale
|
54 |
+
focal_length_px = focal_length * rescale
|
55 |
+
return focal_length_px, principal_point_px
|
56 |
+
|
57 |
+
|
58 |
+
def _convert_pixels_to_ndc(
|
59 |
+
focal_length_px: torch.Tensor, principal_point_px: torch.Tensor, image_size_wh: torch.Tensor
|
60 |
+
):
|
61 |
+
half_image_size = image_size_wh / 2
|
62 |
+
rescale = half_image_size.min()
|
63 |
+
principal_point = (half_image_size - principal_point_px) / rescale
|
64 |
+
focal_length = focal_length_px / rescale
|
65 |
+
return focal_length, principal_point
|
66 |
+
|
67 |
+
|
68 |
+
def normalize_cameras(
|
69 |
+
cameras, compute_optical=True, first_camera=True, normalize_trans=True, scale=1.0, points=None, max_norm=False,
|
70 |
+
pose_mode="C2W"
|
71 |
+
):
|
72 |
+
"""
|
73 |
+
Normalizes cameras such that
|
74 |
+
(1) the optical axes point to the origin and the average distance to the origin is 1
|
75 |
+
(2) the first camera is the origin
|
76 |
+
(3) the translation vector is normalized
|
77 |
+
|
78 |
+
TODO: some transforms overlap with others. no need to do so many transforms
|
79 |
+
Args:
|
80 |
+
cameras (List[camera]).
|
81 |
+
"""
|
82 |
+
# Let distance from first camera to origin be unit
|
83 |
+
new_cameras = cameras.clone()
|
84 |
+
scale = 1.0
|
85 |
+
|
86 |
+
if compute_optical:
|
87 |
+
new_cameras, points = compute_optical_transform(new_cameras, points=points)
|
88 |
+
if first_camera:
|
89 |
+
new_cameras, points = first_camera_transform(new_cameras, points=points, pose_mode=pose_mode)
|
90 |
+
if normalize_trans:
|
91 |
+
new_cameras, points, scale = normalize_translation(new_cameras,
|
92 |
+
points=points, max_norm=max_norm)
|
93 |
+
return new_cameras, points, scale
|
94 |
+
|
95 |
+
|
96 |
+
def compute_optical_transform(new_cameras, points=None):
|
97 |
+
"""
|
98 |
+
adapted from https://github.com/amyxlase/relpose-plus-plus
|
99 |
+
"""
|
100 |
+
|
101 |
+
new_transform = new_cameras.get_world_to_view_transform()
|
102 |
+
p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(new_cameras)
|
103 |
+
t = Translate(p_intersect)
|
104 |
+
scale = dist.squeeze()[0]
|
105 |
+
|
106 |
+
if points is not None:
|
107 |
+
points = t.inverse().transform_points(points)
|
108 |
+
points = points / scale
|
109 |
+
|
110 |
+
# Degenerate case
|
111 |
+
if scale == 0:
|
112 |
+
scale = torch.norm(new_cameras.T, dim=(0, 1))
|
113 |
+
scale = torch.sqrt(scale)
|
114 |
+
new_cameras.T = new_cameras.T / scale
|
115 |
+
else:
|
116 |
+
new_matrix = t.compose(new_transform).get_matrix()
|
117 |
+
new_cameras.R = new_matrix[:, :3, :3]
|
118 |
+
new_cameras.T = new_matrix[:, 3, :3] / scale
|
119 |
+
|
120 |
+
return new_cameras, points
|
121 |
+
|
122 |
+
|
123 |
+
def compute_optical_axis_intersection(cameras):
|
124 |
+
centers = cameras.get_camera_center()
|
125 |
+
principal_points = cameras.principal_point
|
126 |
+
|
127 |
+
one_vec = torch.ones((len(cameras), 1))
|
128 |
+
optical_axis = torch.cat((principal_points, one_vec), -1)
|
129 |
+
|
130 |
+
pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
|
131 |
+
|
132 |
+
pp2 = pp[torch.arange(pp.shape[0]), torch.arange(pp.shape[0])]
|
133 |
+
|
134 |
+
directions = pp2 - centers
|
135 |
+
centers = centers.unsqueeze(0).unsqueeze(0)
|
136 |
+
directions = directions.unsqueeze(0).unsqueeze(0)
|
137 |
+
|
138 |
+
p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(p=centers, r=directions, mask=None)
|
139 |
+
|
140 |
+
p_intersect = p_intersect.squeeze().unsqueeze(0)
|
141 |
+
dist = (p_intersect - centers).norm(dim=-1)
|
142 |
+
|
143 |
+
return p_intersect, dist, p_line_intersect, pp2, r
|
144 |
+
|
145 |
+
|
146 |
+
def intersect_skew_line_groups(p, r, mask):
|
147 |
+
# p, r both of shape (B, N, n_intersected_lines, 3)
|
148 |
+
# mask of shape (B, N, n_intersected_lines)
|
149 |
+
p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
|
150 |
+
_, p_line_intersect = _point_line_distance(p, r, p_intersect[..., None, :].expand_as(p))
|
151 |
+
intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(dim=-1)
|
152 |
+
return p_intersect, p_line_intersect, intersect_dist_squared, r
|
153 |
+
|
154 |
+
|
155 |
+
def intersect_skew_lines_high_dim(p, r, mask=None):
|
156 |
+
# Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions
|
157 |
+
dim = p.shape[-1]
|
158 |
+
# make sure the heading vectors are l2-normed
|
159 |
+
if mask is None:
|
160 |
+
mask = torch.ones_like(p[..., 0])
|
161 |
+
r = torch.nn.functional.normalize(r, dim=-1)
|
162 |
+
|
163 |
+
eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
|
164 |
+
I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
|
165 |
+
sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
|
166 |
+
p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
|
167 |
+
|
168 |
+
if torch.any(torch.isnan(p_intersect)):
|
169 |
+
print(p_intersect)
|
170 |
+
raise ValueError(f"p_intersect is NaN")
|
171 |
+
|
172 |
+
return p_intersect, r
|
173 |
+
|
174 |
+
|
175 |
+
def _point_line_distance(p1, r1, p2):
|
176 |
+
df = p2 - p1
|
177 |
+
proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
|
178 |
+
line_pt_nearest = p2 - proj_vector
|
179 |
+
d = (proj_vector).norm(dim=-1)
|
180 |
+
return d, line_pt_nearest
|
181 |
+
|
182 |
+
|
183 |
+
def first_camera_transform(cameras, rotation_only=False,
|
184 |
+
points=None, pose_mode="C2W"):
|
185 |
+
"""
|
186 |
+
Transform so that the first camera is the origin
|
187 |
+
"""
|
188 |
+
|
189 |
+
new_cameras = cameras.clone()
|
190 |
+
# new_transform = new_cameras.get_world_to_view_transform()
|
191 |
+
|
192 |
+
R = cameras.R
|
193 |
+
T = cameras.T
|
194 |
+
Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B, 3, 4]
|
195 |
+
Tran_M = torch.cat([Tran_M,
|
196 |
+
torch.tensor([[[0, 0, 0, 1]]], device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)], dim=1)
|
197 |
+
if pose_mode == "C2W":
|
198 |
+
Tran_M_new = (Tran_M[:1,...].inverse())@Tran_M
|
199 |
+
elif pose_mode == "W2C":
|
200 |
+
Tran_M_new = Tran_M@(Tran_M[:1,...].inverse())
|
201 |
+
|
202 |
+
if False:
|
203 |
+
tR = Rotate(new_cameras.R[0].unsqueeze(0))
|
204 |
+
if rotation_only:
|
205 |
+
t = tR.inverse()
|
206 |
+
else:
|
207 |
+
tT = Translate(new_cameras.T[0].unsqueeze(0))
|
208 |
+
t = tR.compose(tT).inverse()
|
209 |
+
|
210 |
+
if points is not None:
|
211 |
+
points = t.inverse().transform_points(points)
|
212 |
+
|
213 |
+
if pose_mode == "C2W":
|
214 |
+
new_matrix = new_transform.compose(t).get_matrix()
|
215 |
+
else:
|
216 |
+
import ipdb; ipdb.set_trace()
|
217 |
+
new_matrix = t.compose(new_transform).get_matrix()
|
218 |
+
|
219 |
+
new_cameras.R = Tran_M_new[:, :3, :3]
|
220 |
+
new_cameras.T = Tran_M_new[:, :3, 3]
|
221 |
+
|
222 |
+
return new_cameras, points
|
223 |
+
|
224 |
+
|
225 |
+
def normalize_translation(new_cameras, points=None, max_norm=False):
|
226 |
+
t_gt = new_cameras.T.clone()
|
227 |
+
t_gt = t_gt[1:, :]
|
228 |
+
|
229 |
+
if max_norm:
|
230 |
+
t_gt_norm = torch.norm(t_gt, dim=(-1))
|
231 |
+
t_gt_scale = t_gt_norm.max()
|
232 |
+
if t_gt_norm.max() < 0.001:
|
233 |
+
t_gt_scale = torch.ones_like(t_gt_scale)
|
234 |
+
t_gt_scale = t_gt_scale.clamp(min=0.01, max=1e5)
|
235 |
+
else:
|
236 |
+
t_gt_norm = torch.norm(t_gt, dim=(0, 1))
|
237 |
+
t_gt_scale = t_gt_norm / math.sqrt(len(t_gt))
|
238 |
+
t_gt_scale = t_gt_scale / 2
|
239 |
+
if t_gt_norm.max() < 0.001:
|
240 |
+
t_gt_scale = torch.ones_like(t_gt_scale)
|
241 |
+
t_gt_scale = t_gt_scale.clamp(min=0.01, max=1e5)
|
242 |
+
|
243 |
+
new_cameras.T = new_cameras.T / t_gt_scale
|
244 |
+
|
245 |
+
if points is not None:
|
246 |
+
points = points / t_gt_scale
|
247 |
+
|
248 |
+
return new_cameras, points, t_gt_scale
|
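The core of first_camera_transform above is the re-expression of every pose relative to the first one: in C2W mode each homogeneous pose matrix is left-multiplied by the inverse of the first pose, so the first camera becomes the identity (the same trick the sliding-window code applies to extrs_unf_norm). The sketch below reproduces that arithmetic on plain [T, 4, 4] tensors rather than the PyTorch3D camera objects the function operates on; the random poses are purely illustrative.

import torch

def normalize_to_first_camera(c2w: torch.Tensor) -> torch.Tensor:
    """Re-express a [T, 4, 4] batch of camera-to-world poses relative to frame 0,
    mirroring Tran_M_new = Tran_M[:1].inverse() @ Tran_M in first_camera_transform."""
    return torch.inverse(c2w[:1]) @ c2w

# illustrative poses: random rotations (via QR) and translations
T = 5
q, _ = torch.linalg.qr(torch.randn(T, 3, 3))
c2w = torch.eye(4).repeat(T, 1, 1)
c2w[:, :3, :3] = q
c2w[:, :3, 3] = torch.randn(T, 3)

c2w_rel = normalize_to_first_camera(c2w)
# the first pose is now (numerically) the identity
assert torch.allclose(c2w_rel[0], torch.eye(4), atol=1e-5)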
models/SpaTrackV2/models/depth_refiner/backbone.py
ADDED
@@ -0,0 +1,472 @@
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# ---------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.layers import DropPath, to_2tuple, trunc_normal_
from timm.models import register_model
from timm.models.vision_transformer import _cfg
import math


def _init_weights(m):
    # Shared initializer used by every module in this file.
    if isinstance(m, nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        nn.init.constant_(m.bias, 0)
        nn.init.constant_(m.weight, 1.0)
    elif isinstance(m, nn.Conv2d):
        fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
        fan_out //= m.groups
        m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
        if m.bias is not None:
            m.bias.data.zero_()


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.apply(_init_weights)

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Spatial-reduction attention: keys/values are downsampled by sr_ratio.
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

        self.apply(_init_weights)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.apply(_init_weights)

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(_init_weights)

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W


class OverlapPatchEmbed43(OverlapPatchEmbed):
    """ Image to Patch Embedding that also accepts 4-channel inputs.

    Identical to OverlapPatchEmbed except that 4-channel inputs are routed
    through `self.proj_4c`, which is expected to be attached externally
    (it is not created in this class).
    """

    def forward(self, x):
        if x.shape[1] == 4:
            x = self.proj_4c(x)
        else:
            x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W


class MixVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths

        # patch_embed
        self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
                                              embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
                                              embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
                                              embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
                                              embed_dim=embed_dims[3])

        # transformer encoder
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0
        self.block1 = nn.ModuleList([Block(
            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])
            for i in range(depths[0])])
        self.norm1 = norm_layer(embed_dims[0])

        cur += depths[0]
        self.block2 = nn.ModuleList([Block(
            dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[1])
            for i in range(depths[1])])
        self.norm2 = norm_layer(embed_dims[1])

        cur += depths[1]
        self.block3 = nn.ModuleList([Block(
            dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[2])
            for i in range(depths[2])])
        self.norm3 = norm_layer(embed_dims[2])

        cur += depths[2]
        self.block4 = nn.ModuleList([Block(
            dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[3])
            for i in range(depths[3])])
        self.norm4 = norm_layer(embed_dims[3])

        # classification head
        # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()

        self.apply(_init_weights)

    def init_weights(self, pretrained=None):
        # Loading pretrained weights relies on mmcv's get_root_logger / load_checkpoint.
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)

    def reset_drop_path(self, drop_path_rate):
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for i in range(self.depths[0]):
            self.block1[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[0]
        for i in range(self.depths[1]):
            self.block2[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[1]
        for i in range(self.depths[2]):
            self.block3[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[2]
        for i in range(self.depths[3]):
            self.block4[i].drop_path.drop_prob = dpr[cur + i]

    def freeze_patch_emb(self):
        self.patch_embed1.requires_grad = False

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'}  # has pos_embed may be better

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        B = x.shape[0]
        outs = []

        # stage 1
        x, H, W = self.patch_embed1(x)
        for i, blk in enumerate(self.block1):
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 2
        x, H, W = self.patch_embed2(x)
        for i, blk in enumerate(self.block2):
            x = blk(x, H, W)
        x = self.norm2(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 3
        x, H, W = self.patch_embed3(x)
        for i, blk in enumerate(self.block3):
            x = blk(x, H, W)
        x = self.norm3(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 4
        x, H, W = self.patch_embed4(x)
        for i, blk in enumerate(self.block4):
            x = blk(x, H, W)
        x = self.norm4(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs

    def forward(self, x):
        if x.dim() == 5:
            x = x.reshape(x.shape[0] * x.shape[1], x.shape[2], x.shape[3], x.shape[4])
        x = self.forward_features(x)
        # x = self.head(x)

        return x


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x


# @BACKBONES.register_module()
class mit_b0(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b0, self).__init__(
            patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


# @BACKBONES.register_module()
class mit_b1(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b1, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


# @BACKBONES.register_module()
class mit_b2(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b2, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


# @BACKBONES.register_module()
class mit_b3(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b3, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


# @BACKBONES.register_module()
class mit_b4(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b4, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


# @BACKBONES.register_module()
class mit_b5(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b5, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)
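A minimal usage sketch (illustration only, assuming the class definitions above are in scope): the MiT backbones return a four-level feature pyramid at strides 4/8/16/32 rather than classification logits.

import torch

net = mit_b0()                                  # smallest variant defined above
x = torch.randn(1, 3, 224, 224)
feats = net(x)                                  # list of 4 feature maps
print([tuple(f.shape) for f in feats])
# expected: (1, 32, 56, 56), (1, 64, 28, 28), (1, 160, 14, 14), (1, 256, 7, 7)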
models/SpaTrackV2/models/depth_refiner/decode_head.py
ADDED
@@ -0,0 +1,619 @@
from abc import ABCMeta, abstractmethod

import torch
import torch.nn as nn

# from mmcv.cnn import normal_init
# from mmcv.runner import auto_fp16, force_fp32

# from mmseg.core import build_pixel_sampler
# from mmseg.ops import resize


class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
    """Base class for BaseDecodeHead.

    Args:
        in_channels (int|Sequence[int]): Input channels.
        channels (int): Channels after modules, before conv_seg.
        num_classes (int): Number of classes.
        dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
        conv_cfg (dict|None): Config of conv layers. Default: None.
        norm_cfg (dict|None): Config of norm layers. Default: None.
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU')
        in_index (int|Sequence[int]): Input feature index. Default: -1
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            'resize_concat': Multiple feature maps will be resized to the
                same size as the first one and then concatenated together.
                Usually used in FCN head of HRNet.
            'multiple_select': Multiple feature maps will be bundled into
                a list and passed into decode head.
            None: Only one select feature map is allowed.
            Default: None.
        loss_decode (dict): Config of decode loss.
            Default: dict(type='CrossEntropyLoss').
        ignore_index (int | None): The label index to be ignored. When using
            masked BCE loss, ignore_index should be set to None. Default: 255
        sampler (dict|None): The config of segmentation map sampler.
            Default: None.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
    """

    def __init__(self,
                 in_channels,
                 channels,
                 *,
                 num_classes,
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 in_index=-1,
                 input_transform=None,
                 loss_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=False,
                     loss_weight=1.0),
                 decoder_params=None,
                 ignore_index=255,
                 sampler=None,
                 align_corners=False):
        super(BaseDecodeHead, self).__init__()
        self._init_inputs(in_channels, in_index, input_transform)
        self.channels = channels
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.in_index = in_index
        self.ignore_index = ignore_index
        self.align_corners = align_corners

        if sampler is not None:
            self.sampler = build_pixel_sampler(sampler, context=self)
        else:
            self.sampler = None

        self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
        if dropout_ratio > 0:
            self.dropout = nn.Dropout2d(dropout_ratio)
        else:
            self.dropout = None
        self.fp16_enabled = False

    def extra_repr(self):
        """Extra repr."""
        s = f'input_transform={self.input_transform}, ' \
            f'ignore_index={self.ignore_index}, ' \
            f'align_corners={self.align_corners}'
        return s

    def _init_inputs(self, in_channels, in_index, input_transform):
        """Check and initialize input transforms.

        The in_channels, in_index and input_transform must match.
        Specifically, when input_transform is None, only a single feature map
        will be selected, so in_channels and in_index must be of type int.

        Args:
            in_channels (int|Sequence[int]): Input channels.
            in_index (int|Sequence[int]): Input feature index.
            input_transform (str|None): Transformation type of input features.
                Options: 'resize_concat', 'multiple_select', None.
                'resize_concat': Multiple feature maps will be resized to the
                    same size as the first one and then concatenated together.
                    Usually used in FCN head of HRNet.
                'multiple_select': Multiple feature maps will be bundled into
                    a list and passed into decode head.
                None: Only one select feature map is allowed.
        """

        if input_transform is not None:
            assert input_transform in ['resize_concat', 'multiple_select']
        self.input_transform = input_transform
        self.in_index = in_index
        if input_transform is not None:
            assert isinstance(in_channels, (list, tuple))
            assert isinstance(in_index, (list, tuple))
            assert len(in_channels) == len(in_index)
            if input_transform == 'resize_concat':
                self.in_channels = sum(in_channels)
            else:
                self.in_channels = in_channels
        else:
            assert isinstance(in_channels, int)
            assert isinstance(in_index, int)
            self.in_channels = in_channels

    def init_weights(self):
        """Initialize weights of classification layer."""
        normal_init(self.conv_seg, mean=0, std=0.01)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor]): List of multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    # @auto_fp16()
    @abstractmethod
    def forward(self, inputs):
        """Placeholder of forward function."""
        pass

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
        """Forward function for training.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict where each dict
                has: 'img_shape', 'scale_factor', 'flip', and may also contain
                'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
                For details on the values of these keys see
                `mmseg/datasets/pipelines/formatting.py:Collect`.
            gt_semantic_seg (Tensor): Semantic segmentation masks
                used if the architecture supports semantic segmentation task.
            train_cfg (dict): The training config.

        Returns:
            dict[str, Tensor]: a dictionary of loss components
        """
        seg_logits = self.forward(inputs)
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing.

        Args:
            inputs (list[Tensor]): List of multi-level img features.
            img_metas (list[dict]): List of image info dict (see forward_train).
            test_cfg (dict): The testing config.

        Returns:
            Tensor: Output segmentation map.
        """
        return self.forward(inputs)

    def cls_seg(self, feat):
        """Classify each pixel."""
        if self.dropout is not None:
            feat = self.dropout(feat)
        output = self.conv_seg(feat)
        return output


class BaseDecodeHead_clips(BaseDecodeHead):
    """Base class for clip-based decode heads.

    Identical to BaseDecodeHead except that it stores `num_clips`
    (Default: 5) and its forward_train/forward_test additionally pass the
    batch size and clip count through to forward.
    """

    def __init__(self, *args, num_clips=5, **kwargs):
        super(BaseDecodeHead_clips, self).__init__(*args, **kwargs)
        self.num_clips = num_clips

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg, batch_size, num_clips):
        """Forward function for training; returns a dict of loss components."""
        seg_logits = self.forward(inputs, batch_size, num_clips)
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg, batch_size, num_clips):
        """Forward function for testing; returns the output segmentation map."""
        return self.forward(inputs, batch_size, num_clips)


class BaseDecodeHead_clips_flow(BaseDecodeHead):
    """Base class for clip-based decode heads that also consume raw images.

    Identical to BaseDecodeHead except that it stores `num_clips`
    (Default: 5) and its forward_train/forward_test additionally hand the
    batch size, clip count and the raw images `img` through to forward.
    """

    def __init__(self, *args, num_clips=5, **kwargs):
        super(BaseDecodeHead_clips_flow, self).__init__(*args, **kwargs)
        self.num_clips = num_clips

    def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg, batch_size, num_clips, img=None):
        """Forward function for training; returns a dict of loss components."""
        seg_logits = self.forward(inputs, batch_size, num_clips, img)
        losses = self.losses(seg_logits, gt_semantic_seg)
        return losses

    def forward_test(self, inputs, img_metas, test_cfg, batch_size=None, num_clips=None, img=None):
        """Forward function for testing; returns the output segmentation map."""
        return self.forward(inputs, batch_size, num_clips, img)
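A sketch of how a concrete head plugs into these base classes (the TinyHead below is hypothetical and not part of the commit): _transform_inputs gathers the backbone levels selected by in_index, and cls_seg produces the per-pixel logits.

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyHead(BaseDecodeHead):
    def __init__(self, in_channels=(64, 128), in_index=(0, 1), channels=128, num_classes=19):
        super().__init__(in_channels, channels, num_classes=num_classes,
                         in_index=in_index, input_transform='multiple_select')
        self.fuse = nn.Conv2d(sum(in_channels), channels, kernel_size=1)

    def forward(self, inputs):
        feats = self._transform_inputs(inputs)          # levels picked by in_index
        size = feats[0].shape[2:]
        feats = [F.interpolate(f, size=size, mode='bilinear',
                               align_corners=self.align_corners) for f in feats]
        return self.cls_seg(self.fuse(torch.cat(feats, dim=1)))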
models/SpaTrackV2/models/depth_refiner/depth_refiner.py
ADDED
@@ -0,0 +1,115 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.monoD.depth_anything_v2.dinov2_layers.patch_embed import PatchEmbed
from models.SpaTrackV2.models.depth_refiner.backbone import mit_b3
from models.SpaTrackV2.models.depth_refiner.stablizer import Stabilization_Network_Cross_Attention
from einops import rearrange


class TrackStablizer(nn.Module):
    def __init__(self):
        super().__init__()

        self.backbone = mit_b3()

        # Widen the first patch embedding so the backbone accepts RGB plus 4 extra
        # channels (point map / mask); the original RGB filters are kept.
        old_conv = self.backbone.patch_embed1.proj
        new_conv = nn.Conv2d(old_conv.in_channels + 4, old_conv.out_channels, kernel_size=old_conv.kernel_size,
                             stride=old_conv.stride, padding=old_conv.padding)
        new_conv.weight[:, :3, :, :].data.copy_(old_conv.weight.clone())
        self.backbone.patch_embed1.proj = new_conv

        self.Track_Stabilizer = Stabilization_Network_Cross_Attention(in_channels=[64, 128, 320, 512],
                                                                      in_index=[0, 1, 2, 3],
                                                                      feature_strides=[4, 8, 16, 32],
                                                                      channels=128,
                                                                      dropout_ratio=0.1,
                                                                      num_classes=1,
                                                                      align_corners=False,
                                                                      decoder_params=dict(embed_dim=256, depths=4),
                                                                      num_clips=16,
                                                                      norm_cfg=dict(type='SyncBN', requires_grad=True))

        self.edge_conv = nn.Sequential(nn.Conv2d(in_channels=4, out_channels=64, kernel_size=3, padding=1, stride=1, bias=True),
                                       nn.ReLU(inplace=True))
        self.edge_conv1 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=2, bias=True),
                                        nn.ReLU(inplace=True))
        self.success = False
        self.x = None

    def buffer_forward(self, inputs, num_clips=16):
        """
        buffer forward for getting the pointmap and image features
        """
        B, T, C, H, W = inputs.shape
        self.x = self.backbone(inputs)
        scale, shift = self.Track_Stabilizer.buffer_forward(self.x, num_clips=num_clips)
        self.success = True
        return scale, shift

    def forward(self, inputs, tracks, tracks_uvd, num_clips=16, imgs=None, vis_track=None):
        """
        Args:
            inputs: [B, T, C, H, W], RGB + PointMap + Mask
            tracks: [B, T, N, 4], 3D tracks in camera coordinate + visibility
            num_clips: int, number of clips to use
        """
        B, T, C, H, W = inputs.shape
        edge_feat = self.edge_conv(inputs.view(B * T, 4, H, W))
        edge_feat1 = self.edge_conv1(edge_feat)

        if not self.success:
            scale, shift = self.Track_Stabilizer.buffer_forward(self.x, num_clips=num_clips)
            self.success = True
            update = self.Track_Stabilizer(self.x, edge_feat, edge_feat1, tracks, tracks_uvd,
                                           num_clips=num_clips, imgs=imgs, vis_track=vis_track)
        else:
            update = self.Track_Stabilizer(self.x, edge_feat, edge_feat1, tracks, tracks_uvd,
                                           num_clips=num_clips, imgs=imgs, vis_track=vis_track)

        return update

    def reset_success(self):
        self.success = False
        self.x = None
        self.Track_Stabilizer.reset_success()


if __name__ == "__main__":
    # Create test input tensors
    batch_size = 1
    seq_len = 16
    channels = 7  # 3 for RGB + 3 for PointMap + 1 for Mask
    height = 384
    width = 512

    # Create random input tensor with shape [B, T, C, H, W]
    inputs = torch.randn(batch_size, seq_len, channels, height, width)

    # Create random tracks
    tracks = torch.randn(batch_size, seq_len, 1024, 4)

    # Create random test images
    test_imgs = torch.randn(batch_size, seq_len, 3, height, width)

    # Initialize model and move to GPU
    model = TrackStablizer().cuda()

    # Move inputs to GPU and run forward pass
    inputs = inputs.cuda()
    tracks = tracks.cuda()
    outputs = model.buffer_forward(inputs, num_clips=seq_len)
    import time
    start_time = time.time()
    # NOTE: forward() requires tracks_uvd, so this call raises a TypeError as written.
    outputs = model(inputs, tracks, num_clips=seq_len)
    end_time = time.time()
    print(f"Time taken: {end_time - start_time} seconds")
    import pdb; pdb.set_trace()
    # # Print shapes for verification
    # print(f"Input shape: {inputs.shape}")
    # print(f"Output shape: {outputs.shape}")

    # # Basic tests
    # assert outputs.shape[0] == batch_size, "Batch size mismatch"
    # assert len(outputs.shape) == 4, "Output should be 4D: [B,C,H,W]"
    # assert torch.all(outputs >= 0), "Output should be non-negative after ReLU"

    # print("All tests passed!")
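The patch_embed1 replacement in TrackStablizer.__init__ above widens a 3-channel convolution to 7 input channels while reusing the pretrained RGB filters. A self-contained sketch of that surgery (zero-initialising the extra channels is an addition here; the file above leaves them at their random init):

import torch
import torch.nn as nn

old_conv = nn.Conv2d(3, 64, kernel_size=7, stride=4, padding=3)        # stands in for patch_embed1.proj
new_conv = nn.Conv2d(3 + 4, 64, kernel_size=old_conv.kernel_size,
                     stride=old_conv.stride, padding=old_conv.padding)
with torch.no_grad():
    new_conv.weight[:, :3].copy_(old_conv.weight)                      # keep the RGB filters
    new_conv.weight[:, 3:].zero_()                                     # extra channels start at zero
print(new_conv.weight.shape)                                           # torch.Size([64, 7, 7, 7])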
models/SpaTrackV2/models/depth_refiner/network.py
ADDED
@@ -0,0 +1,429 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
'''
Author: Ke Xian
Email: [email protected]
Date: 2020/07/20
'''

import torch
import torch.nn as nn
import torch.nn.init as init

# ==============================================================================================================

class FTB(nn.Module):
    def __init__(self, inchannels, midchannels=512):
        super(FTB, self).__init__()
        self.in1 = inchannels
        self.mid = midchannels

        self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True)
        self.conv_branch = nn.Sequential(nn.ReLU(inplace=True),
                                         nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True),
                                         # nn.BatchNorm2d(num_features=self.mid),
                                         nn.ReLU(inplace=True),
                                         nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True))
        self.relu = nn.ReLU(inplace=True)

        self.init_params()

    def forward(self, x):
        x = self.conv1(x)
        x = x + self.conv_branch(x)
        x = self.relu(x)

        return x

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):  # nn.BatchNorm2d
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
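
# --- Illustrative sketch (not part of the original file): FTB is a plain residual block,
# --- a 3x3 conv followed by a two-conv branch that is added back and passed through ReLU.
# --- The channel count and feature size below are arbitrary choices for the check.
def _ftb_shape_check():
    block = FTB(inchannels=64, midchannels=64)
    feat = torch.randn(1, 64, 32, 32)
    out = block(feat)        # spatial size preserved, channels become `midchannels`
    return out.shape         # torch.Size([1, 64, 32, 32])
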
class ATA(nn.Module):
    def __init__(self, inchannels, reduction=8):
        super(ATA, self).__init__()
        self.inchannels = inchannels
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(nn.Linear(self.inchannels * 2, self.inchannels // reduction),
                                nn.ReLU(inplace=True),
                                nn.Linear(self.inchannels // reduction, self.inchannels),
                                nn.Sigmoid())
        self.init_params()

    def forward(self, low_x, high_x):
        n, c, _, _ = low_x.size()
        x = torch.cat([low_x, high_x], 1)
        x = self.avg_pool(x)
        x = x.view(n, -1)
        x = self.fc(x).view(n, c, 1, 1)
        x = low_x * x + high_x

        return x

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                # init.normal(m.weight, std=0.01)
                init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                # init.normal_(m.weight, std=0.01)
                init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):  # nn.BatchNorm2d
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)


class FFM(nn.Module):
    def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
        super(FFM, self).__init__()
        self.inchannels = inchannels
        self.midchannels = midchannels
        self.outchannels = outchannels
        self.upfactor = upfactor

        self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels)
        self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)

        self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)

        self.init_params()
        # self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
        # self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
        # self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)

    def forward(self, low_x, high_x):
        x = self.ftb1(low_x)

        '''
        x = torch.cat((x,high_x),1)
        if x.shape[2] == 12:
            x = self.p1(x)
        elif x.shape[2] == 24:
            x = self.p2(x)
        elif x.shape[2] == 48:
            x = self.p3(x)
        '''
        x = x + high_x  ###high_x
        x = self.ftb2(x)
        x = self.upsample(x)

        return x

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):  # nn.Batchnorm2d
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)


class noFFM(nn.Module):
    def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
        super(noFFM, self).__init__()
        self.inchannels = inchannels
        self.midchannels = midchannels
        self.outchannels = outchannels
        self.upfactor = upfactor

        self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)

        self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)

        self.init_params()
        # self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
        # self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
        # self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)

    def forward(self, low_x, high_x):
        # x = self.ftb1(low_x)
        x = high_x  ###high_x
        x = self.ftb2(x)
        x = self.upsample(x)

        return x

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):  # nn.Batchnorm2d
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)


class AO(nn.Module):
    # Adaptive output module
    def __init__(self, inchannels, outchannels, upfactor=2):
        super(AO, self).__init__()
        self.inchannels = inchannels
        self.outchannels = outchannels
        self.upfactor = upfactor

        """
        self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),
                                        nn.BatchNorm2d(num_features=self.inchannels//2),
                                        nn.ReLU(inplace=True),
                                        nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=3, padding=1, stride=1, bias=True),
                                        nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True))  # ,
                                        # nn.ReLU(inplace=True))  ## get positive values
        """
        self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels // 2, kernel_size=3, padding=1, stride=1, bias=True),
                                        # nn.BatchNorm2d(num_features=self.inchannels//2),
                                        nn.ReLU(inplace=True),
                                        nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True),
                                        nn.Conv2d(in_channels=self.inchannels // 2, out_channels=self.outchannels, kernel_size=1, padding=0, stride=1))

        # nn.ReLU(inplace=True))  ## get positive values

        self.init_params()

    def forward(self, x):
        x = self.adapt_conv(x)
        return x

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):  # nn.Batchnorm2d
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

class ASPP(nn.Module):
    def __init__(self, inchannels=256, planes=128, rates=[1, 6, 12, 18]):
        super(ASPP, self).__init__()
        self.inchannels = inchannels
        self.planes = planes
        self.rates = rates
        self.kernel_sizes = []
        self.paddings = []
        for rate in self.rates:
            if rate == 1:
                self.kernel_sizes.append(1)
                self.paddings.append(0)
            else:
                self.kernel_sizes.append(3)
                self.paddings.append(rate)
        self.atrous_0 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[0],
                                                stride=1, padding=self.paddings[0], dilation=self.rates[0], bias=True),
                                      nn.ReLU(inplace=True),
                                      nn.BatchNorm2d(num_features=self.planes)
                                      )
        self.atrous_1 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[1],
                                                stride=1, padding=self.paddings[1], dilation=self.rates[1], bias=True),
                                      nn.ReLU(inplace=True),
                                      nn.BatchNorm2d(num_features=self.planes),
                                      )
        self.atrous_2 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[2],
                                                stride=1, padding=self.paddings[2], dilation=self.rates[2], bias=True),
                                      nn.ReLU(inplace=True),
                                      nn.BatchNorm2d(num_features=self.planes),
                                      )
        self.atrous_3 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[3],
                                                stride=1, padding=self.paddings[3], dilation=self.rates[3], bias=True),
                                      nn.ReLU(inplace=True),
                                      nn.BatchNorm2d(num_features=self.planes),
                                      )

    # self.conv = nn.Conv2d(in_channels=self.planes * 4, out_channels=self.inchannels, kernel_size=3, padding=1, stride=1, bias=True)
    def forward(self, x):
        x = torch.cat([self.atrous_0(x), self.atrous_1(x), self.atrous_2(x), self.atrous_3(x)], 1)
        # x = self.conv(x)

        return x

# ==============================================================================================================


class ResidualConv(nn.Module):
    def __init__(self, inchannels):
        super(ResidualConv, self).__init__()
        # nn.BatchNorm2d
        self.conv = nn.Sequential(
            # nn.BatchNorm2d(num_features=inchannels),
            nn.ReLU(inplace=False),
            # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True),
            # nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True)
            nn.Conv2d(in_channels=inchannels, out_channels=inchannels // 2, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(num_features=inchannels // 2),
            nn.ReLU(inplace=False),
            nn.Conv2d(in_channels=inchannels // 2, out_channels=inchannels, kernel_size=3, padding=1, stride=1, bias=False)
        )
        self.init_params()

    def forward(self, x):
        x = self.conv(x) + x
        return x

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):  # nn.BatchNorm2d
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)


class FeatureFusion(nn.Module):
    def __init__(self, inchannels, outchannels):
        super(FeatureFusion, self).__init__()
        self.conv = ResidualConv(inchannels=inchannels)
        # nn.BatchNorm2d
        self.up = nn.Sequential(ResidualConv(inchannels=inchannels),
                                nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3, stride=2, padding=1, output_padding=1),
                                nn.BatchNorm2d(num_features=outchannels),
                                nn.ReLU(inplace=True))

    def forward(self, lowfeat, highfeat):
        return self.up(highfeat + self.conv(lowfeat))

    def init_params(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.ConvTranspose2d):
                # init.kaiming_normal_(m.weight, mode='fan_out')
                init.normal_(m.weight, std=0.01)
                # init.xavier_normal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):  # nn.BatchNorm2d
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.01)
                if m.bias is not None:
                    init.constant_(m.bias, 0)


class SenceUnderstand(nn.Module):
    def __init__(self, channels):
        super(SenceUnderstand, self).__init__()
        self.channels = channels
        self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
                                   nn.ReLU(inplace=True))
        self.pool = nn.AdaptiveAvgPool2d(8)
        self.fc = nn.Sequential(nn.Linear(512 * 8 * 8, self.channels),
                                nn.ReLU(inplace=True))
        self.conv2 = nn.Sequential(nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0),
                                   nn.ReLU(inplace=True))
        self.initial_params()

    def forward(self, x):
        n, c, h, w = x.size()
        x = self.conv1(x)
        x = self.pool(x)
        x = x.view(n, -1)
        x = self.fc(x)
        x = x.view(n, self.channels, 1, 1)
        x = self.conv2(x)
        x = x.repeat(1, 1, h, w)
        return x

    def initial_params(self, dev=0.01):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # print torch.sum(m.weight)
                m.weight.data.normal_(0, dev)
                if m.bias is not None:
                    m.bias.data.fill_(0)
            elif isinstance(m, nn.ConvTranspose2d):
                # print torch.sum(m.weight)
                m.weight.data.normal_(0, dev)
                if m.bias is not None:
                    m.bias.data.fill_(0)
            elif isinstance(m, nn.Linear):
                m.weight.data.normal_(0, dev)
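
The file above only defines building blocks; the sketch below shows how a few of them compose in a typical decoder stage. It is an illustrative addition, not repository code: the channel counts, feature sizes, and the import path (inferred from the file location) are assumptions.

import torch
from models.SpaTrackV2.models.depth_refiner.network import FFM, AO, ASPP  # path assumed from the repo layout

ffm = FFM(inchannels=256, midchannels=256, outchannels=256, upfactor=2).eval()
ao = AO(inchannels=256, outchannels=1, upfactor=2).eval()
aspp = ASPP(inchannels=256, planes=128).eval()

low = torch.randn(1, 256, 24, 24)   # lateral (skip) feature
high = torch.randn(1, 256, 24, 24)  # feature coming up from the previous decoder stage

with torch.no_grad():
    fused = ffm(low, high)   # FTB -> add high_x -> FTB -> 2x bilinear upsample: (1, 256, 48, 48)
    out = ao(fused)          # conv -> ReLU -> 2x upsample -> 1x1 conv:          (1, 1, 96, 96)
    ctx = aspp(low)          # four dilated branches concatenated:               (1, 512, 24, 24)

print(fused.shape, out.shape, ctx.shape)
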
models/SpaTrackV2/models/depth_refiner/stablilization_attention.py
ADDED
@@ -0,0 +1,1187 @@
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from timm.layers import DropPath, to_2tuple, trunc_normal_
from einops import rearrange

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

def window_partition_noreshape(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (B, num_windows_h, num_windows_w, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    return windows

def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x

def get_roll_masks(H, W, window_size, shift_size):
    #####################################
    # move to top-left
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    h_slices = (slice(0, H-window_size),
                slice(H-window_size, H-shift_size),
                slice(H-shift_size, H))
    w_slices = (slice(0, W-window_size),
                slice(W-window_size, W-shift_size),
                slice(W-shift_size, W))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask_tl = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    ####################################
    # move to top right
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    h_slices = (slice(0, H-window_size),
                slice(H-window_size, H-shift_size),
                slice(H-shift_size, H))
    w_slices = (slice(0, shift_size),
                slice(shift_size, window_size),
                slice(window_size, W))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask_tr = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    ####################################
    # move to bottom left
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    h_slices = (slice(0, shift_size),
                slice(shift_size, window_size),
                slice(window_size, H))
    w_slices = (slice(0, W-window_size),
                slice(W-window_size, W-shift_size),
                slice(W-shift_size, W))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask_bl = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    ####################################
    # move to bottom right
    img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
    h_slices = (slice(0, shift_size),
                slice(shift_size, window_size),
                slice(window_size, H))
    w_slices = (slice(0, shift_size),
                slice(shift_size, window_size),
                slice(window_size, W))
    cnt = 0
    for h in h_slices:
        for w in w_slices:
            img_mask[:, h, w, :] = cnt
            cnt += 1

    mask_windows = window_partition(img_mask, window_size)  # nW, window_size, window_size, 1
    mask_windows = mask_windows.view(-1, window_size * window_size)
    attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
    attn_mask_br = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

    # append all
    attn_mask_all = torch.cat((attn_mask_tl, attn_mask_tr, attn_mask_bl, attn_mask_br), -1)
    return attn_mask_all

def get_relative_position_index(q_windows, k_windows):
    """
    Args:
        q_windows: tuple (query_window_height, query_window_width)
        k_windows: tuple (key_window_height, key_window_width)

    Returns:
        relative_position_index: query_window_height*query_window_width, key_window_height*key_window_width
    """
    # get pair-wise relative position index for each token inside the window
    coords_h_q = torch.arange(q_windows[0])
    coords_w_q = torch.arange(q_windows[1])
    coords_q = torch.stack(torch.meshgrid([coords_h_q, coords_w_q]))  # 2, Wh_q, Ww_q

    coords_h_k = torch.arange(k_windows[0])
    coords_w_k = torch.arange(k_windows[1])
    coords_k = torch.stack(torch.meshgrid([coords_h_k, coords_w_k]))  # 2, Wh, Ww

    coords_flatten_q = torch.flatten(coords_q, 1)  # 2, Wh_q*Ww_q
    coords_flatten_k = torch.flatten(coords_k, 1)  # 2, Wh_k*Ww_k

    relative_coords = coords_flatten_q[:, :, None] - coords_flatten_k[:, None, :]  # 2, Wh_q*Ww_q, Wh_k*Ww_k
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh_q*Ww_q, Wh_k*Ww_k, 2
    relative_coords[:, :, 0] += k_windows[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += k_windows[1] - 1
    relative_coords[:, :, 0] *= (q_windows[1] + k_windows[1]) - 1
    relative_position_index = relative_coords.sum(-1)  # Wh_q*Ww_q, Wh_k*Ww_k
    return relative_position_index

def get_relative_position_index3d(q_windows, k_windows, num_clips):
    """
    Args:
        q_windows: tuple (query_window_height, query_window_width)
        k_windows: tuple (key_window_height, key_window_width)

    Returns:
        relative_position_index: query_window_height*query_window_width, key_window_height*key_window_width
    """
    # get pair-wise relative position index for each token inside the window
    coords_d_q = torch.arange(num_clips)
    coords_h_q = torch.arange(q_windows[0])
    coords_w_q = torch.arange(q_windows[1])
    coords_q = torch.stack(torch.meshgrid([coords_d_q, coords_h_q, coords_w_q]))  # 2, Wh_q, Ww_q

    coords_d_k = torch.arange(num_clips)
    coords_h_k = torch.arange(k_windows[0])
    coords_w_k = torch.arange(k_windows[1])
    coords_k = torch.stack(torch.meshgrid([coords_d_k, coords_h_k, coords_w_k]))  # 2, Wh, Ww

    coords_flatten_q = torch.flatten(coords_q, 1)  # 2, Wh_q*Ww_q
    coords_flatten_k = torch.flatten(coords_k, 1)  # 2, Wh_k*Ww_k

    relative_coords = coords_flatten_q[:, :, None] - coords_flatten_k[:, None, :]  # 2, Wh_q*Ww_q, Wh_k*Ww_k
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh_q*Ww_q, Wh_k*Ww_k, 2
    relative_coords[:, :, 0] += num_clips - 1  # shift to start from 0
    relative_coords[:, :, 1] += k_windows[0] - 1
    relative_coords[:, :, 2] += k_windows[1] - 1
    relative_coords[:, :, 0] *= (q_windows[0] + k_windows[0] - 1) * (q_windows[1] + k_windows[1] - 1)
    relative_coords[:, :, 1] *= (q_windows[1] + k_windows[1] - 1)
    relative_position_index = relative_coords.sum(-1)  # Wh_q*Ww_q, Wh_k*Ww_k
    return relative_position_index

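
# --- Illustrative sketch (not part of the original file): a self-contained sanity check of the
# --- window helpers above, using toy sizes rather than the model's own configuration.
def _window_helper_check():
    x = torch.randn(2, 8, 8, 4)                           # (B, H, W, C) with a 4x4 window
    win = window_partition(x, 4)                          # -> (8, 4, 4, 4): 2x2 windows per image
    assert torch.equal(window_reverse(win, 4, 8, 8), x)   # partition followed by reverse is lossless
    idx = get_relative_position_index((4, 4), (4, 4))     # (16, 16): one bias index per (query, key) token pair
    return win.shape, idx.shape
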
class WindowAttention3d3(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.

    Args:
        dim (int): Number of input channels.
        expand_size (int): The expand size at focal level 1.
        window_size (tuple[int]): The height and width of the window.
        focal_window (int): Focal region size.
        focal_level (int): Focal attention level.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
        pool_method (str): window pooling method. Default: none
    """

    def __init__(self, dim, expand_size, window_size, focal_window, focal_level, num_heads,
                 qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., pool_method="none", focal_l_clips=[7, 1, 2], focal_kernel_clips=[7, 5, 3]):

        super().__init__()
        self.dim = dim
        self.expand_size = expand_size
        self.window_size = window_size  # Wh, Ww
        self.pool_method = pool_method
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        self.focal_level = focal_level
        self.focal_window = focal_window

        # define a parameter table of relative position bias for each window
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        num_clips = 4
        # # define a parameter table of relative position bias
        # self.relative_position_bias_table = nn.Parameter(
        #     torch.zeros((2 * num_clips - 1) * (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wd-1 * 2*Wh-1 * 2*Ww-1, nH

        # # get pair-wise relative position index for each token inside the window
        # coords_d = torch.arange(num_clips)
        # coords_h = torch.arange(self.window_size[0])
        # coords_w = torch.arange(self.window_size[1])
        # coords = torch.stack(torch.meshgrid(coords_d, coords_h, coords_w))  # 3, Wd, Wh, Ww
        # coords_flatten = torch.flatten(coords, 1)  # 3, Wd*Wh*Ww
        # relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 3, Wd*Wh*Ww, Wd*Wh*Ww
        # relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wd*Wh*Ww, Wd*Wh*Ww, 3
        # relative_coords[:, :, 0] += num_clips - 1  # shift to start from 0
        # relative_coords[:, :, 1] += self.window_size[0] - 1
        # relative_coords[:, :, 2] += self.window_size[1] - 1

        # relative_coords[:, :, 0] *= (2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1)
        # relative_coords[:, :, 1] *= (2 * self.window_size[1] - 1)
        # relative_position_index = relative_coords.sum(-1)  # Wd*Wh*Ww, Wd*Wh*Ww
        # self.register_buffer("relative_position_index", relative_position_index)

        if self.expand_size > 0 and focal_level > 0:
            # define a parameter table of position bias between window and its fine-grained surroundings
            self.window_size_of_key = self.window_size[0] * self.window_size[1] if self.expand_size == 0 else \
                (4 * self.window_size[0] * self.window_size[1] - 4 * (self.window_size[0] - self.expand_size) * (self.window_size[0] - self.expand_size))
            self.relative_position_bias_table_to_neighbors = nn.Parameter(
                torch.zeros(1, num_heads, self.window_size[0] * self.window_size[1], self.window_size_of_key))  # Wh*Ww, nH, nSurrounding
            trunc_normal_(self.relative_position_bias_table_to_neighbors, std=.02)

            # get mask for rolled k and rolled v
            mask_tl = torch.ones(self.window_size[0], self.window_size[1]); mask_tl[:-self.expand_size, :-self.expand_size] = 0
            mask_tr = torch.ones(self.window_size[0], self.window_size[1]); mask_tr[:-self.expand_size, self.expand_size:] = 0
            mask_bl = torch.ones(self.window_size[0], self.window_size[1]); mask_bl[self.expand_size:, :-self.expand_size] = 0
            mask_br = torch.ones(self.window_size[0], self.window_size[1]); mask_br[self.expand_size:, self.expand_size:] = 0
            mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), 0).flatten(0)
            self.register_buffer("valid_ind_rolled", mask_rolled.nonzero().view(-1))

        if pool_method != "none" and focal_level > 1:
            # self.relative_position_bias_table_to_windows = nn.ParameterList()
            # self.relative_position_bias_table_to_windows_clips = nn.ParameterList()
            # self.register_parameter('relative_position_bias_table_to_windows', [])
            # self.register_parameter('relative_position_bias_table_to_windows_clips', [])
            self.unfolds = nn.ModuleList()
            self.unfolds_clips = nn.ModuleList()

            # build relative position bias between local patch and pooled windows
            for k in range(focal_level - 1):
                stride = 2**k
                kernel_size = 2 * (self.focal_window // 2) + 2**k + (2**k - 1)
                # define unfolding operations
                self.unfolds += [nn.Unfold(
                    kernel_size=(kernel_size, kernel_size),
                    stride=stride, padding=kernel_size // 2)
                ]

                # define relative position bias table
                relative_position_bias_table_to_windows = nn.Parameter(
                    torch.zeros(
                        self.num_heads,
                        (self.window_size[0] + self.focal_window + 2**k - 2) * (self.window_size[1] + self.focal_window + 2**k - 2),
                    )
                )
                trunc_normal_(relative_position_bias_table_to_windows, std=.02)
                # self.relative_position_bias_table_to_windows.append(relative_position_bias_table_to_windows)
                self.register_parameter('relative_position_bias_table_to_windows_{}'.format(k), relative_position_bias_table_to_windows)

                # define relative position bias index
                relative_position_index_k = get_relative_position_index(self.window_size, to_2tuple(self.focal_window + 2**k - 1))
                # relative_position_index_k = get_relative_position_index3d(self.window_size, to_2tuple(self.focal_window + 2**k - 1), num_clips)
                self.register_buffer("relative_position_index_{}".format(k), relative_position_index_k)

                # define unfolding index for focal_level > 0
                if k > 0:
                    mask = torch.zeros(kernel_size, kernel_size); mask[(2**k)-1:, (2**k)-1:] = 1
                    self.register_buffer("valid_ind_unfold_{}".format(k), mask.flatten(0).nonzero().view(-1))

            for k in range(len(focal_l_clips)):
                # kernel_size=focal_kernel_clips[k]
                focal_l_big_flag = False
                if focal_l_clips[k] > self.window_size[0]:
                    stride = 1
                    padding = 0
                    kernel_size = focal_kernel_clips[k]
                    kernel_size_true = kernel_size
                    focal_l_big_flag = True
                    # stride=math.ceil(self.window_size/focal_l_clips[k])
                    # padding=(kernel_size-stride)/2
                else:
                    stride = focal_l_clips[k]
                    # kernel_size
                    # kernel_size = 2*(focal_kernel_clips[k]// 2) + 2**focal_l_clips[k] + (2**focal_l_clips[k]-1)
                    kernel_size = focal_kernel_clips[k]  ## kernel_size must be jishu (odd)
                    assert kernel_size % 2 == 1
                    padding = kernel_size // 2
                    # kernel_size_true=focal_kernel_clips[k]+2**focal_l_clips[k]-1
                    kernel_size_true = kernel_size
                    # stride=math.ceil(self.window_size/focal_l_clips[k])

                self.unfolds_clips += [nn.Unfold(
                    kernel_size=(kernel_size, kernel_size),
                    stride=stride,
                    padding=padding)
                ]
                relative_position_bias_table_to_windows = nn.Parameter(
                    torch.zeros(
                        self.num_heads,
                        (self.window_size[0] + kernel_size_true - 1) * (self.window_size[0] + kernel_size_true - 1),
                    )
                )
                trunc_normal_(relative_position_bias_table_to_windows, std=.02)
                # self.relative_position_bias_table_to_windows_clips.append(relative_position_bias_table_to_windows)
                self.register_parameter('relative_position_bias_table_to_windows_clips_{}'.format(k), relative_position_bias_table_to_windows)
                relative_position_index_k = get_relative_position_index(self.window_size, to_2tuple(kernel_size_true))
                self.register_buffer("relative_position_index_clips_{}".format(k), relative_position_index_k)
                # if (not focal_l_big_flag) and focal_l_clips[k]>0:
                #     mask = torch.zeros(kernel_size, kernel_size); mask[(2**focal_l_clips[k])-1:, (2**focal_l_clips[k])-1:] = 1
                #     self.register_buffer("valid_ind_unfold_clips_{}".format(k), mask.flatten(0).nonzero().view(-1))

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.softmax = nn.Softmax(dim=-1)
        self.focal_l_clips = focal_l_clips
        self.focal_kernel_clips = focal_kernel_clips

    def forward(self, x_all, mask_all=None, batch_size=None, num_clips=None):
        """
        Args:
            x_all (list[Tensors]): input features at different granularity
            mask_all (list[Tensors/None]): masks for input features at different granularity
        """
        x = x_all[0][0]

        B0, nH, nW, C = x.shape
        # assert B==batch_size*num_clips
        assert B0 == batch_size
        qkv = self.qkv(x).reshape(B0, nH, nW, 3, C).permute(3, 0, 1, 2, 4).contiguous()
        q, k, v = qkv[0], qkv[1], qkv[2]  # B0, nH, nW, C

        # partition q map
        # print("x.shape: ", x.shape)
        # print("q.shape: ", q.shape)  # [4, 126, 126, 256]
        (q_windows, k_windows, v_windows) = map(
            lambda t: window_partition(t, self.window_size[0]).view(
                -1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads
            ).transpose(1, 2),
            (q, k, v)
        )

        # q_dim0, q_dim1, q_dim2, q_dim3=q_windows.shape
        # q_windows=q_windows.view(batch_size, num_clips, (nH//self.window_size[0])*(nW//self.window_size[1]), q_dim1, q_dim2, q_dim3)
        # q_windows=q_windows[:,-1].contiguous().view(-1, q_dim1, q_dim2, q_dim3)  # query for the last frame (target frame)

        # k_windows.shape [1296, 8, 49, 32]

        if self.expand_size > 0 and self.focal_level > 0:
            (k_tl, v_tl) = map(
                lambda t: torch.roll(t, shifts=(-self.expand_size, -self.expand_size), dims=(1, 2)), (k, v)
            )
            (k_tr, v_tr) = map(
                lambda t: torch.roll(t, shifts=(-self.expand_size, self.expand_size), dims=(1, 2)), (k, v)
            )
            (k_bl, v_bl) = map(
                lambda t: torch.roll(t, shifts=(self.expand_size, -self.expand_size), dims=(1, 2)), (k, v)
            )
            (k_br, v_br) = map(
                lambda t: torch.roll(t, shifts=(self.expand_size, self.expand_size), dims=(1, 2)), (k, v)
            )

            (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
                lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads),
                (k_tl, k_tr, k_bl, k_br)
            )
            (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
                lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads),
                (v_tl, v_tr, v_bl, v_br)
            )
            k_rolled = torch.cat((k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows), 1).transpose(1, 2)
            v_rolled = torch.cat((v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows), 1).transpose(1, 2)

            # mask out tokens in current window
            # print("self.valid_ind_rolled.shape: ", self.valid_ind_rolled.shape)  # [132]
            # print("k_rolled.shape: ", k_rolled.shape)  # [1296, 8, 196, 32]
            k_rolled = k_rolled[:, :, self.valid_ind_rolled]
            v_rolled = v_rolled[:, :, self.valid_ind_rolled]
            k_rolled = torch.cat((k_windows, k_rolled), 2)
            v_rolled = torch.cat((v_windows, v_rolled), 2)
        else:
            k_rolled = k_windows; v_rolled = v_windows

        # print("k_rolled.shape: ", k_rolled.shape)  # [1296, 8, 181, 32]

        if self.pool_method != "none" and self.focal_level > 1:
            k_pooled = []
            v_pooled = []
            for k in range(self.focal_level - 1):
                stride = 2**k
                x_window_pooled = x_all[0][k+1]  # B0, nWh, nWw, C
                nWh, nWw = x_window_pooled.shape[1:3]

                # generate mask for pooled windows
                # print("x_window_pooled.shape: ", x_window_pooled.shape)
                mask = x_window_pooled.new(nWh, nWw).fill_(1)
                # print("here: ", x_window_pooled.shape, self.unfolds[k].kernel_size, self.unfolds[k](mask.unsqueeze(0).unsqueeze(1)).shape)
                # print(mask.unique())
                unfolded_mask = self.unfolds[k](mask.unsqueeze(0).unsqueeze(1)).view(
                    1, 1, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
                    view(nWh*nWw // stride // stride, -1, 1)

                if k > 0:
                    valid_ind_unfold_k = getattr(self, "valid_ind_unfold_{}".format(k))
                    unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]

                # print("unfolded_mask.shape: ", unfolded_mask.shape, unfolded_mask.unique())
                x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
                # print((x_window_masks == 0).sum(), (x_window_masks > 0).sum(), x_window_masks.unique())
                x_window_masks = x_window_masks.masked_fill(x_window_masks == 0, float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
                # print(x_window_masks.shape)
                mask_all[0][k+1] = x_window_masks

                # generate k and v for pooled windows
                qkv_pooled = self.qkv(x_window_pooled).reshape(B0, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous()
                k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2]  # B0, C, nWh, nWw

                (k_pooled_k, v_pooled_k) = map(
                    lambda t: self.unfolds[k](t).view(
                        B0, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
                        view(-1, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2),
                    (k_pooled_k, v_pooled_k)  # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
                )

                # print("k_pooled_k.shape: ", k_pooled_k.shape)
                # print("valid_ind_unfold_k.shape: ", valid_ind_unfold_k.shape)

                if k > 0:
                    (k_pooled_k, v_pooled_k) = map(
                        lambda t: t[:, :, valid_ind_unfold_k], (k_pooled_k, v_pooled_k)
                    )

                # print("k_pooled_k.shape: ", k_pooled_k.shape)

                k_pooled += [k_pooled_k]
                v_pooled += [v_pooled_k]

            for k in range(len(self.focal_l_clips)):
                focal_l_big_flag = False
                if self.focal_l_clips[k] > self.window_size[0]:
                    stride = 1
                    focal_l_big_flag = True
                else:
                    stride = self.focal_l_clips[k]
                    # if self.window_size>=focal_l_clips[k]:
                    #     stride=math.ceil(self.window_size/focal_l_clips[k])
                    #     # padding=(kernel_size-stride)/2
                    # else:
                    #     stride=1
                    #     padding=0
                x_window_pooled = x_all[k+1]
                nWh, nWw = x_window_pooled.shape[1:3]
                mask = x_window_pooled.new(nWh, nWw).fill_(1)

                # import pdb; pdb.set_trace()
                # print(x_window_pooled.shape, self.unfolds_clips[k].kernel_size, self.unfolds_clips[k](mask.unsqueeze(0).unsqueeze(1)).shape)

                unfolded_mask = self.unfolds_clips[k](mask.unsqueeze(0).unsqueeze(1)).view(
                    1, 1, self.unfolds_clips[k].kernel_size[0], self.unfolds_clips[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
                    view(nWh*nWw // stride // stride, -1, 1)

                # if (not focal_l_big_flag) and self.focal_l_clips[k]>0:
                #     valid_ind_unfold_k = getattr(self, "valid_ind_unfold_clips_{}".format(k))
                #     unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]

                # print("unfolded_mask.shape: ", unfolded_mask.shape, unfolded_mask.unique())
                x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
                # print((x_window_masks == 0).sum(), (x_window_masks > 0).sum(), x_window_masks.unique())
                x_window_masks = x_window_masks.masked_fill(x_window_masks == 0, float(-100.0)).masked_fill(x_window_masks > 0, float(0.0))
                # print(x_window_masks.shape)
                mask_all[k+1] = x_window_masks

                # generate k and v for pooled windows
                qkv_pooled = self.qkv(x_window_pooled).reshape(B0, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous()
                k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2]  # B0, C, nWh, nWw

                if (not focal_l_big_flag):
                    (k_pooled_k, v_pooled_k) = map(
                        lambda t: self.unfolds_clips[k](t).view(
                            B0, C, self.unfolds_clips[k].kernel_size[0], self.unfolds_clips[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
                            view(-1, self.unfolds_clips[k].kernel_size[0]*self.unfolds_clips[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2),
                        (k_pooled_k, v_pooled_k)  # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
                    )
                else:
                    (k_pooled_k, v_pooled_k) = map(
                        lambda t: self.unfolds_clips[k](t),
                        (k_pooled_k, v_pooled_k)  # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
                    )
                    LLL = k_pooled_k.size(2)
                    LLL_h = int(LLL**0.5)
                    assert LLL_h**2 == LLL
                    k_pooled_k = k_pooled_k.reshape(B0, -1, LLL_h, LLL_h)
                    v_pooled_k = v_pooled_k.reshape(B0, -1, LLL_h, LLL_h)

                # print("k_pooled_k.shape: ", k_pooled_k.shape)
                # print("valid_ind_unfold_k.shape: ", valid_ind_unfold_k.shape)
                # if (not focal_l_big_flag) and self.focal_l_clips[k]:
                #     (k_pooled_k, v_pooled_k) = map(
                #         lambda t: t[:, :, valid_ind_unfold_k], (k_pooled_k, v_pooled_k)
                #     )

                # print("k_pooled_k.shape: ", k_pooled_k.shape)

                k_pooled += [k_pooled_k]
                v_pooled += [v_pooled_k]

                # qkv_pooled = self.qkv(x_window_pooled).reshape(B0, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous()
                # k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2]  # B0, C, nWh, nWw
                # (k_pooled_k, v_pooled_k) = map(
                #     lambda t: self.unfolds[k](t).view(
                #         B0, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3, 1).contiguous().\
                #         view(-1, self.unfolds[k].kernel_size[0]*self.unfolds[k].kernel_size[1], self.num_heads, C // self.num_heads).transpose(1, 2),
                #     (k_pooled_k, v_pooled_k)  # (B0 x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
                # )
                # k_pooled += [k_pooled_k]
                # v_pooled += [v_pooled_k]

            k_all = torch.cat([k_rolled] + k_pooled, 2)
            v_all = torch.cat([v_rolled] + v_pooled, 2)
        else:
            k_all = k_rolled
            v_all = v_rolled

        N = k_all.shape[-2]
        q_windows = q_windows * self.scale
        # print(q_windows.shape, k_all.shape, v_all.shape)
        # exit()
        # k_all_dim0, k_all_dim1, k_all_dim2, k_all_dim3=k_all.shape
        # k_all=k_all.contiguous().view(batch_size, num_clips, (nH//self.window_size[0])*(nW//self.window_size[1]),
        #     k_all_dim1, k_all_dim2, k_all_dim3).permute(0,2,3,4,1,5).contiguous().view(-1, k_all_dim1, k_all_dim2*num_clips, k_all_dim3)
        # v_all=v_all.contiguous().view(batch_size, num_clips, (nH//self.window_size[0])*(nW//self.window_size[1]),
        #     k_all_dim1, k_all_dim2, k_all_dim3).permute(0,2,3,4,1,5).contiguous().view(-1, k_all_dim1, k_all_dim2*num_clips, k_all_dim3)

        # print(q_windows.shape, k_all.shape, v_all.shape, k_rolled.shape)
        # exit()
        attn = (q_windows @ k_all.transpose(-2, -1))  # B0*nW, nHead, window_size*window_size, focal_window_size*focal_window_size

        window_area = self.window_size[0] * self.window_size[1]
        # window_area_clips= num_clips*self.window_size[0] * self.window_size[1]
        window_area_rolled = k_rolled.shape[2]

        # add relative position bias for tokens inside window
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        # print(relative_position_bias.shape, attn.shape)
        attn[:, :, :window_area, :window_area] = attn[:, :, :window_area, :window_area] + relative_position_bias.unsqueeze(0)

        # relative_position_bias = self.relative_position_bias_table[self.relative_position_index[-window_area:, :window_area_clips].reshape(-1)].view(
        #     window_area, window_area_clips, -1)  # Wh*Ww,Wd*Wh*Ww,nH
        # relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous().view(self.num_heads,window_area,num_clips,window_area
        #     ).permute(0,1,3,2).contiguous().view(self.num_heads,window_area,window_area_clips).contiguous()  # nH, Wh*Ww, Wh*Ww*Wd
        # # attn_dim0, attn_dim1, attn_dim2, attn_dim3=attn.shape
        # # attn=attn.view(attn_dim0,attn_dim1,attn_dim2,num_clips,-1)
        # # print(attn.shape, relative_position_bias.shape)
        # attn[:,:,:window_area, :window_area_clips]=attn[:,:,:window_area, :window_area_clips] + relative_position_bias.unsqueeze(0)
        # attn = attn + relative_position_bias.unsqueeze(0)  # B_, nH, N, N

        # add relative position bias for patches inside a window
        if self.expand_size > 0 and self.focal_level > 0:
            attn[:, :, :window_area, window_area:window_area_rolled] = attn[:, :, :window_area, window_area:window_area_rolled] + self.relative_position_bias_table_to_neighbors

        if self.pool_method != "none" and self.focal_level > 1:
            # add relative position bias for different windows in an image
            offset = window_area_rolled
            # print(offset)
            for k in range(self.focal_level - 1):
                # add relative position bias
                relative_position_index_k = getattr(self, 'relative_position_index_{}'.format(k))
                relative_position_bias_to_windows = getattr(self, 'relative_position_bias_table_to_windows_{}'.format(k))[:, relative_position_index_k.view(-1)].view(
                    -1, self.window_size[0] * self.window_size[1], (self.focal_window+2**k-1)**2,
                )  # nH, NWh*NWw, focal_region*focal_region
                attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \
                    attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + relative_position_bias_to_windows.unsqueeze(0)
                # add attentional mask
                if mask_all[0][k+1] is not None:
                    attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \
                        attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + \
                        mask_all[0][k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[0][k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[0][k+1].shape[-1])

                offset += (self.focal_window+2**k-1)**2
                # print(offset)
            for k in range(len(self.focal_l_clips)):
                focal_l_big_flag = False
                if self.focal_l_clips[k] > self.window_size[0]:
                    stride = 1
                    padding = 0
                    kernel_size = self.focal_kernel_clips[k]
                    kernel_size_true = kernel_size
                    focal_l_big_flag = True
                    # stride=math.ceil(self.window_size/focal_l_clips[k])
                    # padding=(kernel_size-stride)/2
                else:
                    stride = self.focal_l_clips[k]
                    # kernel_size
                    # kernel_size = 2*(self.focal_kernel_clips[k]// 2) + 2**self.focal_l_clips[k] + (2**self.focal_l_clips[k]-1)
                    kernel_size = self.focal_kernel_clips[k]
                    padding = kernel_size // 2
                    # kernel_size_true=self.focal_kernel_clips[k]+2**self.focal_l_clips[k]-1
                    kernel_size_true = kernel_size
                relative_position_index_k = getattr(self, 'relative_position_index_clips_{}'.format(k))
                relative_position_bias_to_windows = getattr(self, 'relative_position_bias_table_to_windows_clips_{}'.format(k))[:, relative_position_index_k.view(-1)].view(
                    -1, self.window_size[0] * self.window_size[1], (kernel_size_true)**2,
                )
                attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] = \
                    attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] + relative_position_bias_to_windows.unsqueeze(0)
                if mask_all[k+1] is not None:
                    attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] = \
                        attn[:, :, :window_area, offset:(offset + (kernel_size_true)**2)] + \
                        mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
                offset += (kernel_size_true)**2
                # print(offset)
                # relative_position_index_k = getattr(self, 'relative_position_index_{}'.format(k))
                # # relative_position_bias_to_windows = self.relative_position_bias_table_to_windows[k][:, relative_position_index_k.view(-1)].view(
                # #     -1, self.window_size[0] * self.window_size[1], (self.focal_window+2**k-1)**2,
                # # )  # nH, NWh*NWw, focal_region*focal_region
                # # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \
                # #     attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + relative_position_bias_to_windows.unsqueeze(0)
                # relative_position_bias_to_windows = self.relative_position_bias_table_to_windows[k][:, relative_position_index_k[-window_area:, :].view(-1)].view(
                #     -1, self.window_size[0] * self.window_size[1], num_clips*(self.focal_window+2**k-1)**2,
                # ).contiguous()  # nH, NWh*NWw, num_clips*focal_region*focal_region
                # relative_position_bias_to_windows = relative_position_bias_to_windows.view(self.num_heads,
                #     window_area,num_clips,-1).permute(0,1,3,2).contiguous().view(self.num_heads,window_area,-1)
                # attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] = \
                #     attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] + relative_position_bias_to_windows.unsqueeze(0)
                # # add attentional mask
                # if mask_all[k+1] is not None:
                #     # print("inside the mask, be careful 1")
                #     # attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] = \
                #     #     attn[:, :, :window_area, offset:(offset + (self.focal_window+2**k-1)**2)] + \
|
717 |
+
# # mask_all[k+1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1).view(-1, 1, 1, mask_all[k+1].shape[-1])
|
718 |
+
# # print("here: ", mask_all[k+1].shape, mask_all[k+1][:, :, None, None, :].shape)
|
719 |
+
|
720 |
+
# attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] = \
|
721 |
+
# attn[:, :, :window_area, offset:(offset + num_clips*(self.focal_window+2**k-1)**2)] + \
|
722 |
+
# mask_all[k+1][:, :, None, None, :,None].repeat(attn.shape[0] // mask_all[k+1].shape[1], 1, 1, 1, 1, num_clips).view(-1, 1, 1, mask_all[k+1].shape[-1]*num_clips)
|
723 |
+
# # print()
|
724 |
+
|
725 |
+
# offset += (self.focal_window+2**k-1)**2
|
726 |
+
|
727 |
+
# print("mask_all[0]: ", mask_all[0])
|
728 |
+
# exit()
|
729 |
+
if mask_all[0][0] is not None:
|
730 |
+
print("inside the mask, be careful 0")
|
731 |
+
nW = mask_all[0].shape[0]
|
732 |
+
attn = attn.view(attn.shape[0] // nW, nW, self.num_heads, window_area, N)
|
733 |
+
attn[:, :, :, :, :window_area] = attn[:, :, :, :, :window_area] + mask_all[0][None, :, None, :, :]
|
734 |
+
attn = attn.view(-1, self.num_heads, window_area, N)
|
735 |
+
attn = self.softmax(attn)
|
736 |
+
else:
|
737 |
+
attn = self.softmax(attn)
|
738 |
+
|
739 |
+
attn = self.attn_drop(attn)
|
740 |
+
|
741 |
+
x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area, C)
|
742 |
+
x = self.proj(x)
|
743 |
+
x = self.proj_drop(x)
|
744 |
+
# print(x.shape)
|
745 |
+
# x = x.view(B/num_clips, nH, nW, C )
|
746 |
+
# exit()
|
747 |
+
return x
|
748 |
+
|
749 |
+
def extra_repr(self) -> str:
|
750 |
+
return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
|
751 |
+
|
752 |
+
def flops(self, N, window_size, unfold_size):
|
753 |
+
# calculate flops for 1 window with token length of N
|
754 |
+
flops = 0
|
755 |
+
# qkv = self.qkv(x)
|
756 |
+
flops += N * self.dim * 3 * self.dim
|
757 |
+
# attn = (q @ k.transpose(-2, -1))
|
758 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * N
|
759 |
+
if self.pool_method != "none" and self.focal_level > 1:
|
760 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * (unfold_size * unfold_size)
|
761 |
+
if self.expand_size > 0 and self.focal_level > 0:
|
762 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * ((window_size + 2*self.expand_size)**2-window_size**2)
|
763 |
+
|
764 |
+
# x = (attn @ v)
|
765 |
+
flops += self.num_heads * N * N * (self.dim // self.num_heads)
|
766 |
+
if self.pool_method != "none" and self.focal_level > 1:
|
767 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * (unfold_size * unfold_size)
|
768 |
+
if self.expand_size > 0 and self.focal_level > 0:
|
769 |
+
flops += self.num_heads * N * (self.dim // self.num_heads) * ((window_size + 2*self.expand_size)**2-window_size**2)
|
770 |
+
|
771 |
+
# x = self.proj(x)
|
772 |
+
flops += N * self.dim * self.dim
|
773 |
+
return flops
|
774 |
+
|
775 |
+
|
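For readers unfamiliar with the indexing above, the following standalone sketch (illustrative sizes, not code from this diff) shows how a relative-position-bias table is built and looked up for a single Wh x Ww window, mirroring the lookup performed in WindowAttention3d3.forward.

import torch

Wh, Ww, num_heads = 7, 7, 8
bias_table = torch.zeros((2 * Wh - 1) * (2 * Ww - 1), num_heads)  # a learnable nn.Parameter in the real model

coords = torch.stack(torch.meshgrid(torch.arange(Wh), torch.arange(Ww), indexing="ij"))  # 2, Wh, Ww
coords_flat = coords.flatten(1)                                    # 2, Wh*Ww
rel = coords_flat[:, :, None] - coords_flat[:, None, :]            # 2, Wh*Ww, Wh*Ww
rel = rel.permute(1, 2, 0).contiguous()                            # Wh*Ww, Wh*Ww, 2
rel[:, :, 0] += Wh - 1                                             # shift offsets to start from 0
rel[:, :, 1] += Ww - 1
rel[:, :, 0] *= 2 * Ww - 1
relative_position_index = rel.sum(-1)                              # Wh*Ww, Wh*Ww

bias = bias_table[relative_position_index.view(-1)].view(Wh * Ww, Wh * Ww, -1)
bias = bias.permute(2, 0, 1).contiguous()                          # nH, Wh*Ww, Wh*Ww
print(bias.shape)  # torch.Size([8, 49, 49]); added to attn[:, :, :window_area, :window_area]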
776 |
+
class CffmTransformerBlock3d3(nn.Module):
|
777 |
+
r""" Focal Transformer Block.
|
778 |
+
|
779 |
+
Args:
|
780 |
+
dim (int): Number of input channels.
|
781 |
+
input_resolution (tuple[int]): Input resolution.
|
782 |
+
num_heads (int): Number of attention heads.
|
783 |
+
window_size (int): Window size.
|
784 |
+
expand_size (int): expand size at first focal level (finest level).
|
785 |
+
shift_size (int): Shift size for SW-MSA.
|
786 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
|
787 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
788 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
789 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
790 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
791 |
+
drop_path (float, optional): Stochastic depth rate. Default: 0.0
|
792 |
+
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
|
793 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
794 |
+
pool_method (str): window pooling method. Default: none, options: [none|fc|conv]
|
795 |
+
focal_level (int): number of focal levels. Default: 1.
|
796 |
+
focal_window (int): region size of focal attention. Default: 1
|
797 |
+
use_layerscale (bool): whether use layer scale for training stability. Default: False
|
798 |
+
layerscale_value (float): scaling value for layer scale. Default: 1e-4
|
799 |
+
"""
|
800 |
+
|
801 |
+
def __init__(self, dim, input_resolution, num_heads, window_size=7, expand_size=0, shift_size=0,
|
802 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
|
803 |
+
act_layer=nn.GELU, norm_layer=nn.LayerNorm, pool_method="none",
|
804 |
+
focal_level=1, focal_window=1, use_layerscale=False, layerscale_value=1e-4, focal_l_clips=[7,2,4], focal_kernel_clips=[7,5,3]):
|
805 |
+
super().__init__()
|
806 |
+
self.dim = dim
|
807 |
+
self.input_resolution = input_resolution
|
808 |
+
self.num_heads = num_heads
|
809 |
+
self.window_size = window_size
|
810 |
+
self.shift_size = shift_size
|
811 |
+
self.expand_size = expand_size
|
812 |
+
self.mlp_ratio = mlp_ratio
|
813 |
+
self.pool_method = pool_method
|
814 |
+
self.focal_level = focal_level
|
815 |
+
self.focal_window = focal_window
|
816 |
+
self.use_layerscale = use_layerscale
|
817 |
+
self.focal_l_clips=focal_l_clips
|
818 |
+
self.focal_kernel_clips=focal_kernel_clips
|
819 |
+
|
820 |
+
if min(self.input_resolution) <= self.window_size:
|
821 |
+
# if window size is larger than input resolution, we don't partition windows
|
822 |
+
self.expand_size = 0
|
823 |
+
self.shift_size = 0
|
824 |
+
self.window_size = min(self.input_resolution)
|
825 |
+
assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
|
826 |
+
|
827 |
+
self.window_size_glo = self.window_size
|
828 |
+
|
829 |
+
self.pool_layers = nn.ModuleList()
|
830 |
+
self.pool_layers_clips = nn.ModuleList()
|
831 |
+
if self.pool_method != "none":
|
832 |
+
for k in range(self.focal_level-1):
|
833 |
+
window_size_glo = math.floor(self.window_size_glo / (2 ** k))
|
834 |
+
if self.pool_method == "fc":
|
835 |
+
self.pool_layers.append(nn.Linear(window_size_glo * window_size_glo, 1))
|
836 |
+
self.pool_layers[-1].weight.data.fill_(1./(window_size_glo * window_size_glo))
|
837 |
+
self.pool_layers[-1].bias.data.fill_(0)
|
838 |
+
elif self.pool_method == "conv":
|
839 |
+
self.pool_layers.append(nn.Conv2d(dim, dim, kernel_size=window_size_glo, stride=window_size_glo, groups=dim))
|
840 |
+
for k in range(len(focal_l_clips)):
|
841 |
+
# window_size_glo = math.floor(self.window_size_glo / (2 ** k))
|
842 |
+
if focal_l_clips[k]>self.window_size:
|
843 |
+
window_size_glo = focal_l_clips[k]
|
844 |
+
else:
|
845 |
+
window_size_glo = math.floor(self.window_size_glo / (focal_l_clips[k]))
|
846 |
+
# window_size_glo = focal_l_clips[k]
|
847 |
+
if self.pool_method == "fc":
|
848 |
+
self.pool_layers_clips.append(nn.Linear(window_size_glo * window_size_glo, 1))
|
849 |
+
self.pool_layers_clips[-1].weight.data.fill_(1./(window_size_glo * window_size_glo))
|
850 |
+
self.pool_layers_clips[-1].bias.data.fill_(0)
|
851 |
+
elif self.pool_method == "conv":
|
852 |
+
self.pool_layers_clips.append(nn.Conv2d(dim, dim, kernel_size=window_size_glo, stride=window_size_glo, groups=dim))
|
853 |
+
|
854 |
+
self.norm1 = norm_layer(dim)
|
855 |
+
|
856 |
+
self.attn = WindowAttention3d3(
|
857 |
+
dim, expand_size=self.expand_size, window_size=to_2tuple(self.window_size),
|
858 |
+
focal_window=focal_window, focal_level=focal_level, num_heads=num_heads,
|
859 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, pool_method=pool_method, focal_l_clips=focal_l_clips, focal_kernel_clips=focal_kernel_clips)
|
860 |
+
|
861 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
862 |
+
self.norm2 = norm_layer(dim)
|
863 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
864 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
865 |
+
|
866 |
+
# print("******self.shift_size: ", self.shift_size)
|
867 |
+
|
868 |
+
if self.shift_size > 0:
|
869 |
+
# calculate attention mask for SW-MSA
|
870 |
+
H, W = self.input_resolution
|
871 |
+
img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
|
872 |
+
h_slices = (slice(0, -self.window_size),
|
873 |
+
slice(-self.window_size, -self.shift_size),
|
874 |
+
slice(-self.shift_size, None))
|
875 |
+
w_slices = (slice(0, -self.window_size),
|
876 |
+
slice(-self.window_size, -self.shift_size),
|
877 |
+
slice(-self.shift_size, None))
|
878 |
+
cnt = 0
|
879 |
+
for h in h_slices:
|
880 |
+
for w in w_slices:
|
881 |
+
img_mask[:, h, w, :] = cnt
|
882 |
+
cnt += 1
|
883 |
+
|
884 |
+
mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
|
885 |
+
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
|
886 |
+
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
|
887 |
+
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
|
888 |
+
else:
|
889 |
+
# print("here mask none")
|
890 |
+
attn_mask = None
|
891 |
+
self.register_buffer("attn_mask", attn_mask)
|
892 |
+
|
893 |
+
if self.use_layerscale:
|
894 |
+
self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
|
895 |
+
self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
|
896 |
+
|
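The SW-MSA attention mask registered in __init__ above can be reproduced in isolation. The sketch below uses assumed sizes (an 8x8 map, window 4, shift 2) and an inline window partition instead of the repo's window_partition helper.

import torch

H = W = 8
window_size, shift_size = 4, 2
img_mask = torch.zeros((1, H, W, 1))
slices = (slice(0, -window_size), slice(-window_size, -shift_size), slice(-shift_size, None))
cnt = 0
for h in slices:
    for w in slices:
        img_mask[:, h, w, :] = cnt
        cnt += 1

# inline equivalent of window_partition(img_mask, window_size)
mask_windows = img_mask.view(1, H // window_size, window_size, W // window_size, window_size, 1)
mask_windows = mask_windows.permute(0, 1, 3, 2, 4, 5).reshape(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, 0.0)
print(attn_mask.shape)  # (num_windows, window_size**2, window_size**2) = (4, 16, 16)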
897 |
+
def forward(self, x):
|
898 |
+
H0, W0 = self.input_resolution
|
899 |
+
# B, L, C = x.shape
|
900 |
+
B0, D0, H0, W0, C = x.shape
|
901 |
+
shortcut = x
|
902 |
+
# assert L == H * W, "input feature has wrong size"
|
903 |
+
x=x.reshape(B0*D0,H0,W0,C).reshape(B0*D0,H0*W0,C)
|
904 |
+
|
905 |
+
|
906 |
+
x = self.norm1(x)
|
907 |
+
x = x.reshape(B0*D0, H0, W0, C)
|
908 |
+
# print("here")
|
909 |
+
# exit()
|
910 |
+
|
911 |
+
# pad feature maps to multiples of window size
|
912 |
+
pad_l = pad_t = 0
|
913 |
+
pad_r = (self.window_size - W0 % self.window_size) % self.window_size
|
914 |
+
pad_b = (self.window_size - H0 % self.window_size) % self.window_size
|
915 |
+
if pad_r > 0 or pad_b > 0:
|
916 |
+
x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
|
917 |
+
|
918 |
+
B, H, W, C = x.shape ## B=B0*D0
|
919 |
+
|
920 |
+
if self.shift_size > 0:
|
921 |
+
shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
|
922 |
+
else:
|
923 |
+
shifted_x = x
|
924 |
+
|
925 |
+
# print("shifted_x.shape: ", shifted_x.shape)
|
926 |
+
shifted_x=shifted_x.view(B0,D0,H,W,C)
|
927 |
+
x_windows_all = [shifted_x[:,-1]]
|
928 |
+
x_windows_all_clips=[]
|
929 |
+
x_window_masks_all = [self.attn_mask]
|
930 |
+
x_window_masks_all_clips=[]
|
931 |
+
|
932 |
+
if self.focal_level > 1 and self.pool_method != "none":
|
933 |
+
# if we add coarser granularity and the pool method is not none
|
934 |
+
# pooling_index=0
|
935 |
+
for k in range(self.focal_level-1):
|
936 |
+
window_size_glo = math.floor(self.window_size_glo / (2 ** k))
|
937 |
+
pooled_h = math.ceil(H / self.window_size) * (2 ** k)
|
938 |
+
pooled_w = math.ceil(W / self.window_size) * (2 ** k)
|
939 |
+
H_pool = pooled_h * window_size_glo
|
940 |
+
W_pool = pooled_w * window_size_glo
|
941 |
+
|
942 |
+
x_level_k = shifted_x[:,-1]
|
943 |
+
# trim or pad shifted_x depending on the required size
|
944 |
+
if H > H_pool:
|
945 |
+
trim_t = (H - H_pool) // 2
|
946 |
+
trim_b = H - H_pool - trim_t
|
947 |
+
x_level_k = x_level_k[:, trim_t:-trim_b]
|
948 |
+
elif H < H_pool:
|
949 |
+
pad_t = (H_pool - H) // 2
|
950 |
+
pad_b = H_pool - H - pad_t
|
951 |
+
x_level_k = F.pad(x_level_k, (0,0,0,0,pad_t,pad_b))
|
952 |
+
|
953 |
+
if W > W_pool:
|
954 |
+
trim_l = (W - W_pool) // 2
|
955 |
+
trim_r = W - W_pool - trim_l
|
956 |
+
x_level_k = x_level_k[:, :, trim_l:-trim_r]
|
957 |
+
elif W < W_pool:
|
958 |
+
pad_l = (W_pool - W) // 2
|
959 |
+
pad_r = W_pool - W - pad_l
|
960 |
+
x_level_k = F.pad(x_level_k, (0,0,pad_l,pad_r))
|
961 |
+
|
962 |
+
x_windows_noreshape = window_partition_noreshape(x_level_k.contiguous(), window_size_glo) # B0, nw, nw, window_size, window_size, C
|
963 |
+
nWh, nWw = x_windows_noreshape.shape[1:3]
|
964 |
+
if self.pool_method == "mean":
|
965 |
+
x_windows_pooled = x_windows_noreshape.mean([3, 4]) # B0, nWh, nWw, C
|
966 |
+
elif self.pool_method == "max":
|
967 |
+
x_windows_pooled = x_windows_noreshape.max(-2)[0].max(-2)[0].view(B0, nWh, nWw, C) # B0, nWh, nWw, C
|
968 |
+
elif self.pool_method == "fc":
|
969 |
+
x_windows_noreshape = x_windows_noreshape.view(B0, nWh, nWw, window_size_glo*window_size_glo, C).transpose(3, 4) # B0, nWh, nWw, C, wsize**2
|
970 |
+
x_windows_pooled = self.pool_layers[k](x_windows_noreshape).flatten(-2) # B0, nWh, nWw, C
|
971 |
+
elif self.pool_method == "conv":
|
972 |
+
x_windows_noreshape = x_windows_noreshape.view(-1, window_size_glo, window_size_glo, C).permute(0, 3, 1, 2).contiguous() # B0 * nw * nw, C, wsize, wsize
|
973 |
+
x_windows_pooled = self.pool_layers[k](x_windows_noreshape).view(B0, nWh, nWw, C) # B0, nWh, nWw, C
|
974 |
+
|
975 |
+
x_windows_all += [x_windows_pooled]
|
976 |
+
# print(x_windows_pooled.shape)
|
977 |
+
x_window_masks_all += [None]
|
978 |
+
# pooling_index=pooling_index+1
|
979 |
+
|
980 |
+
x_windows_all_clips += [x_windows_all]
|
981 |
+
x_window_masks_all_clips += [x_window_masks_all]
|
982 |
+
for k in range(len(self.focal_l_clips)):
|
983 |
+
if self.focal_l_clips[k]>self.window_size:
|
984 |
+
window_size_glo = self.focal_l_clips[k]
|
985 |
+
else:
|
986 |
+
window_size_glo = math.floor(self.window_size_glo / (self.focal_l_clips[k]))
|
987 |
+
|
988 |
+
pooled_h = math.ceil(H / self.window_size) * (self.focal_l_clips[k])
|
989 |
+
pooled_w = math.ceil(W / self.window_size) * (self.focal_l_clips[k])
|
990 |
+
|
991 |
+
H_pool = pooled_h * window_size_glo
|
992 |
+
W_pool = pooled_w * window_size_glo
|
993 |
+
|
994 |
+
x_level_k = shifted_x[:,k]
|
995 |
+
if H!=H_pool or W!=W_pool:
|
996 |
+
x_level_k=F.interpolate(x_level_k.permute(0,3,1,2), size=(H_pool, W_pool), mode='bilinear').permute(0,2,3,1)
|
997 |
+
|
998 |
+
# print(x_level_k.shape)
|
999 |
+
x_windows_noreshape = window_partition_noreshape(x_level_k.contiguous(), window_size_glo) # B0, nw, nw, window_size, window_size, C
|
1000 |
+
nWh, nWw = x_windows_noreshape.shape[1:3]
|
1001 |
+
if self.pool_method == "mean":
|
1002 |
+
x_windows_pooled = x_windows_noreshape.mean([3, 4]) # B0, nWh, nWw, C
|
1003 |
+
elif self.pool_method == "max":
|
1004 |
+
x_windows_pooled = x_windows_noreshape.max(-2)[0].max(-2)[0].view(B0, nWh, nWw, C) # B0, nWh, nWw, C
|
1005 |
+
elif self.pool_method == "fc":
|
1006 |
+
x_windows_noreshape = x_windows_noreshape.view(B0, nWh, nWw, window_size_glo*window_size_glo, C).transpose(3, 4) # B0, nWh, nWw, C, wsize**2
|
1007 |
+
x_windows_pooled = self.pool_layers_clips[k](x_windows_noreshape).flatten(-2) # B0, nWh, nWw, C
|
1008 |
+
elif self.pool_method == "conv":
|
1009 |
+
x_windows_noreshape = x_windows_noreshape.view(-1, window_size_glo, window_size_glo, C).permute(0, 3, 1, 2).contiguous() # B0 * nw * nw, C, wsize, wsize
|
1010 |
+
x_windows_pooled = self.pool_layers_clips[k](x_windows_noreshape).view(B0, nWh, nWw, C) # B0, nWh, nWw, C
|
1011 |
+
|
1012 |
+
x_windows_all_clips += [x_windows_pooled]
|
1013 |
+
# print(x_windows_pooled.shape)
|
1014 |
+
x_window_masks_all_clips += [None]
|
1015 |
+
# pooling_index=pooling_index+1
|
1016 |
+
# exit()
|
1017 |
+
|
1018 |
+
attn_windows = self.attn(x_windows_all_clips, mask_all=x_window_masks_all_clips, batch_size=B0, num_clips=D0) # nW*B0, window_size*window_size, C
|
1019 |
+
|
1020 |
+
attn_windows = attn_windows[:, :self.window_size ** 2]
|
1021 |
+
|
1022 |
+
# merge windows
|
1023 |
+
attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
|
1024 |
+
shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H(padded) W(padded) C
|
1025 |
+
|
1026 |
+
# reverse cyclic shift
|
1027 |
+
if self.shift_size > 0:
|
1028 |
+
x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
|
1029 |
+
else:
|
1030 |
+
x = shifted_x
|
1031 |
+
# x = x[:, :self.input_resolution[0], :self.input_resolution[1]].contiguous().view(B, -1, C)
|
1032 |
+
x = x[:, :H0, :W0].contiguous().view(B0, -1, C)
|
1033 |
+
|
1034 |
+
# FFN
|
1035 |
+
# x = shortcut + self.drop_path(x if (not self.use_layerscale) else (self.gamma_1 * x))
|
1036 |
+
# x = x + self.drop_path(self.mlp(self.norm2(x)) if (not self.use_layerscale) else (self.gamma_2 * self.mlp(self.norm2(x))))
|
1037 |
+
|
1038 |
+
# print(x.shape, shortcut[:,-1].view(B0, -1, C).shape)
|
1039 |
+
x = shortcut[:,-1].view(B0, -1, C) + self.drop_path(x if (not self.use_layerscale) else (self.gamma_1 * x))
|
1040 |
+
x = x + self.drop_path(self.mlp(self.norm2(x)) if (not self.use_layerscale) else (self.gamma_2 * self.mlp(self.norm2(x))))
|
1041 |
+
|
1042 |
+
# x=torch.cat([shortcut[:,:-1],x.view(B0,self.input_resolution[0],self.input_resolution[1],C).unsqueeze(1)],1)
|
1043 |
+
x=torch.cat([shortcut[:,:-1],x.view(B0,H0,W0,C).unsqueeze(1)],1)
|
1044 |
+
|
1045 |
+
assert x.shape==shortcut.shape
|
1046 |
+
|
1047 |
+
# exit()
|
1048 |
+
|
1049 |
+
return x
|
1050 |
+
|
1051 |
+
def extra_repr(self) -> str:
|
1052 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
|
1053 |
+
f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
|
1054 |
+
|
1055 |
+
def flops(self):
|
1056 |
+
flops = 0
|
1057 |
+
H, W = self.input_resolution
|
1058 |
+
# norm1
|
1059 |
+
flops += self.dim * H * W
|
1060 |
+
|
1061 |
+
# W-MSA/SW-MSA
|
1062 |
+
nW = H * W / self.window_size / self.window_size
|
1063 |
+
flops += nW * self.attn.flops(self.window_size * self.window_size, self.window_size, self.focal_window)
|
1064 |
+
|
1065 |
+
if self.pool_method != "none" and self.focal_level > 1:
|
1066 |
+
for k in range(self.focal_level-1):
|
1067 |
+
window_size_glo = math.floor(self.window_size_glo / (2 ** k))
|
1068 |
+
nW_glo = nW * (2**k)
|
1069 |
+
# (sub)-window pooling
|
1070 |
+
flops += nW_glo * self.dim * window_size_glo * window_size_glo
|
1071 |
+
# qkv for global levels
|
1072 |
+
# NOTE: in our implementation, we pass the pooled window embedding to qkv embedding layer,
|
1073 |
+
# but theoretically, we only need to compute k and v.
|
1074 |
+
flops += nW_glo * self.dim * 3 * self.dim
|
1075 |
+
|
1076 |
+
# mlp
|
1077 |
+
flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
|
1078 |
+
# norm2
|
1079 |
+
flops += self.dim * H * W
|
1080 |
+
return flops
|
1081 |
+
|
1082 |
+
|
1083 |
+
class BasicLayer3d3(nn.Module):
|
1084 |
+
""" A basic Focal Transformer layer for one stage.
|
1085 |
+
|
1086 |
+
Args:
|
1087 |
+
dim (int): Number of input channels.
|
1088 |
+
input_resolution (tuple[int]): Input resolution.
|
1089 |
+
depth (int): Number of blocks.
|
1090 |
+
num_heads (int): Number of attention heads.
|
1091 |
+
window_size (int): Local window size.
|
1092 |
+
expand_size (int): expand size for focal level 1.
|
1093 |
+
expand_layer (str): expand layer. Default: all
|
1094 |
+
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
|
1095 |
+
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
|
1096 |
+
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
|
1097 |
+
drop (float, optional): Dropout rate. Default: 0.0
|
1098 |
+
attn_drop (float, optional): Attention dropout rate. Default: 0.0
|
1099 |
+
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
|
1100 |
+
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
|
1101 |
+
pool_method (str): Window pooling method. Default: none.
|
1102 |
+
focal_level (int): Number of focal levels. Default: 1.
|
1103 |
+
focal_window (int): region size at each focal level. Default: 1.
|
1104 |
+
use_conv_embed (bool): whether to use an overlapped convolutional patch embedding layer. Default: False
|
1105 |
+
use_shift (bool): Whether to use window shift as in Swin Transformer. Default: False
|
1106 |
+
use_pre_norm (bool): Whether to use pre-norm before the patch embedding projection for stability. Default: False
|
1107 |
+
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
|
1108 |
+
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
|
1109 |
+
use_layerscale (bool): Whether to use layer scale for stability. Default: False.
|
1110 |
+
layerscale_value (float): Layerscale value. Default: 1e-4.
|
1111 |
+
"""
|
1112 |
+
|
1113 |
+
def __init__(self, dim, input_resolution, depth, num_heads, window_size, expand_size, expand_layer="all",
|
1114 |
+
mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
|
1115 |
+
drop_path=0., norm_layer=nn.LayerNorm, pool_method="none",
|
1116 |
+
focal_level=1, focal_window=1, use_conv_embed=False, use_shift=False, use_pre_norm=False,
|
1117 |
+
downsample=None, use_checkpoint=False, use_layerscale=False, layerscale_value=1e-4, focal_l_clips=[16,8,2], focal_kernel_clips=[7,5,3]):
|
1118 |
+
|
1119 |
+
super().__init__()
|
1120 |
+
self.dim = dim
|
1121 |
+
self.input_resolution = input_resolution
|
1122 |
+
self.depth = depth
|
1123 |
+
self.use_checkpoint = use_checkpoint
|
1124 |
+
|
1125 |
+
if expand_layer == "even":
|
1126 |
+
expand_factor = 0
|
1127 |
+
elif expand_layer == "odd":
|
1128 |
+
expand_factor = 1
|
1129 |
+
elif expand_layer == "all":
|
1130 |
+
expand_factor = -1
|
1131 |
+
|
1132 |
+
# build blocks
|
1133 |
+
self.blocks = nn.ModuleList([
|
1134 |
+
CffmTransformerBlock3d3(dim=dim, input_resolution=input_resolution,
|
1135 |
+
num_heads=num_heads, window_size=window_size,
|
1136 |
+
shift_size=(0 if (i % 2 == 0) else window_size // 2) if use_shift else 0,
|
1137 |
+
expand_size=0 if (i % 2 == expand_factor) else expand_size,
|
1138 |
+
mlp_ratio=mlp_ratio,
|
1139 |
+
qkv_bias=qkv_bias, qk_scale=qk_scale,
|
1140 |
+
drop=drop,
|
1141 |
+
attn_drop=attn_drop,
|
1142 |
+
drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
|
1143 |
+
norm_layer=norm_layer,
|
1144 |
+
pool_method=pool_method,
|
1145 |
+
focal_level=focal_level,
|
1146 |
+
focal_window=focal_window,
|
1147 |
+
use_layerscale=use_layerscale,
|
1148 |
+
layerscale_value=layerscale_value,
|
1149 |
+
focal_l_clips=focal_l_clips,
|
1150 |
+
focal_kernel_clips=focal_kernel_clips)
|
1151 |
+
for i in range(depth)])
|
1152 |
+
|
1153 |
+
# patch merging layer
|
1154 |
+
if downsample is not None:
|
1155 |
+
self.downsample = downsample(
|
1156 |
+
img_size=input_resolution, patch_size=2, in_chans=dim, embed_dim=2*dim,
|
1157 |
+
use_conv_embed=use_conv_embed, norm_layer=norm_layer, use_pre_norm=use_pre_norm,
|
1158 |
+
is_stem=False
|
1159 |
+
)
|
1160 |
+
else:
|
1161 |
+
self.downsample = None
|
1162 |
+
|
1163 |
+
def forward(self, x, batch_size=None, num_clips=None, reg_tokens=None):
|
1164 |
+
B, D, C, H, W = x.shape
|
1165 |
+
x = rearrange(x, 'b d c h w -> b d h w c')
|
1166 |
+
for blk in self.blocks:
|
1167 |
+
if self.use_checkpoint:
|
1168 |
+
x = checkpoint.checkpoint(blk, x)
|
1169 |
+
else:
|
1170 |
+
x = blk(x)
|
1171 |
+
|
1172 |
+
if self.downsample is not None:
|
1173 |
+
x = x.view(x.shape[0], self.input_resolution[0], self.input_resolution[1], -1).permute(0, 3, 1, 2).contiguous()
|
1174 |
+
x = self.downsample(x)
|
1175 |
+
x = rearrange(x, 'b d h w c -> b d c h w')
|
1176 |
+
return x
|
1177 |
+
|
1178 |
+
def extra_repr(self) -> str:
|
1179 |
+
return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
|
1180 |
+
|
1181 |
+
def flops(self):
|
1182 |
+
flops = 0
|
1183 |
+
for blk in self.blocks:
|
1184 |
+
flops += blk.flops()
|
1185 |
+
if self.downsample is not None:
|
1186 |
+
flops += self.downsample.flops()
|
1187 |
+
return flops
|
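As a quick check of the padding arithmetic used in CffmTransformerBlock3d3.forward above, this short sketch (illustrative sizes only) pads H and W up to the next multiples of the window size.

import torch
import torch.nn.functional as F

window_size = 7
x = torch.randn(2, 45, 60, 32)                                   # (B*D, H, W, C)
pad_r = (window_size - x.shape[2] % window_size) % window_size   # 3 columns on the right
pad_b = (window_size - x.shape[1] % window_size) % window_size   # 4 rows at the bottom
x = F.pad(x, (0, 0, 0, pad_r, 0, pad_b))                         # pad order: C, then W, then H
print(x.shape)                                                   # torch.Size([2, 49, 63, 32])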
models/SpaTrackV2/models/depth_refiner/stablizer.py
ADDED
@@ -0,0 +1,342 @@
1 |
+
import numpy as np
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch
|
4 |
+
# from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule
|
5 |
+
from collections import OrderedDict
|
6 |
+
# from mmseg.ops import resize
|
7 |
+
from torch.nn.functional import interpolate as resize
|
8 |
+
# from builder import HEADS
|
9 |
+
from models.SpaTrackV2.models.depth_refiner.decode_head import BaseDecodeHead, BaseDecodeHead_clips, BaseDecodeHead_clips_flow
|
10 |
+
# from mmseg.models.utils import *
|
11 |
+
import attr
|
12 |
+
from IPython import embed
|
13 |
+
from models.SpaTrackV2.models.depth_refiner.stablilization_attention import BasicLayer3d3
|
14 |
+
import cv2
|
15 |
+
from models.SpaTrackV2.models.depth_refiner.network import *
|
16 |
+
import warnings
|
17 |
+
# from mmcv.utils import Registry, build_from_cfg
|
18 |
+
from torch import nn
|
19 |
+
from einops import rearrange
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from models.SpaTrackV2.models.blocks import (
|
22 |
+
AttnBlock, CrossAttnBlock, Mlp
|
23 |
+
)
|
24 |
+
|
25 |
+
class MLP(nn.Module):
|
26 |
+
"""
|
27 |
+
Linear Embedding
|
28 |
+
"""
|
29 |
+
def __init__(self, input_dim=2048, embed_dim=768):
|
30 |
+
super().__init__()
|
31 |
+
self.proj = nn.Linear(input_dim, embed_dim)
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
x = x.flatten(2).transpose(1, 2)
|
35 |
+
x = self.proj(x)
|
36 |
+
return x
|
37 |
+
|
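A shape sketch for the linear-embedding MLP above, with assumed sizes: a (N, C, H, W) feature map becomes per-pixel tokens before the projection.

import torch
import torch.nn as nn

proj = nn.Linear(2048, 768)
feat = torch.randn(4, 2048, 24, 24)        # (N, C, H, W) backbone feature
tokens = feat.flatten(2).transpose(1, 2)   # (4, 576, 2048): one token per pixel
out = proj(tokens)                         # (4, 576, 768)
print(out.shape)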
38 |
+
|
39 |
+
def scatter_multiscale_fast(
|
40 |
+
track2d: torch.Tensor,
|
41 |
+
trackfeature: torch.Tensor,
|
42 |
+
H: int,
|
43 |
+
W: int,
|
44 |
+
kernel_sizes = [1]
|
45 |
+
) -> torch.Tensor:
|
46 |
+
"""
|
47 |
+
Scatter sparse track features onto a dense image grid with weighted multi-scale pooling to handle zero-value gaps.
|
48 |
+
|
49 |
+
This function scatters sparse track features into a dense image grid and applies multi-scale average pooling
|
50 |
+
while excluding zero-value holes. The weight mask ensures that only valid feature regions contribute to the pooling,
|
51 |
+
avoiding dilution by empty pixels.
|
52 |
+
|
53 |
+
Args:
|
54 |
+
track2d (torch.Tensor): Float tensor of shape (B, T, N, 2) containing (x, y) pixel coordinates
|
55 |
+
for each track point across batches, frames, and points.
|
56 |
+
trackfeature (torch.Tensor): Float tensor of shape (B, T, N, C) with C-dimensional features
|
57 |
+
for each track point.
|
58 |
+
H (int): Height of the target output image.
|
59 |
+
W (int): Width of the target output image.
|
60 |
+
kernel_sizes (List[int]): List of odd integers for average pooling kernel sizes. Default: [1].
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
torch.Tensor: Multi-scale fused feature map of shape (B, T, C, H, W) with hole-resistant pooling.
|
64 |
+
"""
|
65 |
+
B, T, N, C = trackfeature.shape
|
66 |
+
device = trackfeature.device
|
67 |
+
|
68 |
+
# 1. Flatten coordinates and filter valid points within image bounds
|
69 |
+
coords_flat = track2d.round().long().reshape(-1, 2) # (B*T*N, 2)
|
70 |
+
x = coords_flat[:, 0] # x coordinates
|
71 |
+
y = coords_flat[:, 1] # y coordinates
|
72 |
+
feat_flat = trackfeature.reshape(-1, C) # Flatten features
|
73 |
+
|
74 |
+
valid_mask = (x >= 0) & (x < W) & (y >= 0) & (y < H)
|
75 |
+
x = x[valid_mask]
|
76 |
+
y = y[valid_mask]
|
77 |
+
feat_flat = feat_flat[valid_mask]
|
78 |
+
valid_count = x.shape[0]
|
79 |
+
|
80 |
+
if valid_count == 0:
|
81 |
+
return torch.zeros(B, T, C, H, W, device=device) # Handle no-valid-point case
|
82 |
+
|
83 |
+
# 2. Calculate linear indices and batch-frame indices for scattering
|
84 |
+
lin_idx = y * W + x # Linear index within a single frame (H*W range)
|
85 |
+
|
86 |
+
# Generate batch-frame indices (e.g., 0~B*T-1 for each frame in batch)
|
87 |
+
bt_idx_raw = (
|
88 |
+
torch.arange(B * T, device=device)
|
89 |
+
.view(B, T, 1)
|
90 |
+
.expand(B, T, N)
|
91 |
+
.reshape(-1)
|
92 |
+
)
|
93 |
+
bt_idx = bt_idx_raw[valid_mask] # Indices for valid points across batch and frames
|
94 |
+
|
95 |
+
# 3. Create accumulation buffers for features and weights
|
96 |
+
total_space = B * T * H * W
|
97 |
+
img_accum_flat = torch.zeros(total_space, C, device=device) # Feature accumulator
|
98 |
+
weight_accum_flat = torch.zeros(total_space, 1, device=device) # Weight accumulator (counts)
|
99 |
+
|
100 |
+
# 4. Scatter features and weights into accumulation buffers
|
101 |
+
idx_in_accum = bt_idx * (H * W) + lin_idx # Global index: batch_frame * H*W + pixel_index
|
102 |
+
|
103 |
+
# Add features to corresponding indices (index_add_ is efficient for sparse updates)
|
104 |
+
img_accum_flat.index_add_(0, idx_in_accum, feat_flat)
|
105 |
+
weight_accum_flat.index_add_(0, idx_in_accum, torch.ones((valid_count, 1), device=device))
|
106 |
+
|
107 |
+
# 5. Normalize features by valid weights, keep zeros for invalid regions
|
108 |
+
valid_mask_flat = weight_accum_flat > 0 # Binary mask for valid pixels
|
109 |
+
img_accum_flat = img_accum_flat / (weight_accum_flat + 1e-6) # Avoid division by zero
|
110 |
+
img_accum_flat = img_accum_flat * valid_mask_flat.float() # Mask out invalid regions
|
111 |
+
|
112 |
+
# 6. Reshape to (B, T, C, H, W) for further processing
|
113 |
+
img = (
|
114 |
+
img_accum_flat.view(B, T, H, W, C)
|
115 |
+
.permute(0, 1, 4, 2, 3)
|
116 |
+
.contiguous()
|
117 |
+
) # Shape: (B, T, C, H, W)
|
118 |
+
|
119 |
+
# 7. Multi-scale pooling with weight masking to exclude zero holes
|
120 |
+
blurred_outputs = []
|
121 |
+
for k in kernel_sizes:
|
122 |
+
pad = k // 2
|
123 |
+
img_bt = img.view(B*T, C, H, W) # Flatten batch and time for pooling
|
124 |
+
|
125 |
+
# Create weight mask for valid regions (1 where features exist, 0 otherwise)
|
126 |
+
weight_mask = (
|
127 |
+
weight_accum_flat.view(B, T, 1, H, W) > 0
|
128 |
+
).float().view(B*T, 1, H, W) # Shape: (B*T, 1, H, W)
|
129 |
+
|
130 |
+
# Calculate number of valid neighbors in each pooling window
|
131 |
+
weight_sum = F.conv2d(
|
132 |
+
weight_mask,
|
133 |
+
torch.ones((1, 1, k, k), device=device),
|
134 |
+
stride=1,
|
135 |
+
padding=pad
|
136 |
+
) # Shape: (B*T, 1, H, W)
|
137 |
+
|
138 |
+
# Sum features only in valid regions
|
139 |
+
feat_sum = F.conv2d(
|
140 |
+
img_bt * weight_mask, # Mask out invalid regions before summing
|
141 |
+
torch.ones((1, 1, k, k), device=device).expand(C, 1, k, k),
|
142 |
+
stride=1,
|
143 |
+
padding=pad,
|
144 |
+
groups=C
|
145 |
+
) # Shape: (B*T, C, H, W)
|
146 |
+
|
147 |
+
# Compute average only over valid neighbors
|
148 |
+
feat_avg = feat_sum / (weight_sum + 1e-6)
|
149 |
+
blurred_outputs.append(feat_avg)
|
150 |
+
|
151 |
+
# 8. Fuse multi-scale results by averaging across kernel sizes
|
152 |
+
fused = torch.stack(blurred_outputs).mean(dim=0) # Average over kernel sizes
|
153 |
+
return fused.view(B, T, C, H, W) # Restore original shape
|
154 |
+
|
155 |
+
#@HEADS.register_module()
|
156 |
+
class Stabilization_Network_Cross_Attention(BaseDecodeHead_clips_flow):
|
157 |
+
|
158 |
+
def __init__(self, feature_strides, **kwargs):
|
159 |
+
super(Stabilization_Network_Cross_Attention, self).__init__(input_transform='multiple_select', **kwargs)
|
160 |
+
self.training = False
|
161 |
+
assert len(feature_strides) == len(self.in_channels)
|
162 |
+
assert min(feature_strides) == feature_strides[0]
|
163 |
+
self.feature_strides = feature_strides
|
164 |
+
|
165 |
+
c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels
|
166 |
+
|
167 |
+
decoder_params = kwargs['decoder_params']
|
168 |
+
embedding_dim = decoder_params['embed_dim']
|
169 |
+
|
170 |
+
self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=embedding_dim)
|
171 |
+
self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=embedding_dim)
|
172 |
+
self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=embedding_dim)
|
173 |
+
self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=embedding_dim)
|
174 |
+
|
175 |
+
self.linear_fuse = nn.Sequential(nn.Conv2d(embedding_dim*4, embedding_dim, kernel_size=(1, 1), stride=(1, 1), bias=False),\
|
176 |
+
nn.ReLU(inplace=True))
|
177 |
+
|
178 |
+
self.proj_track = nn.Conv2d(100, 128, kernel_size=(1, 1), stride=(1, 1), bias=True)
|
179 |
+
|
180 |
+
depths = decoder_params['depths']
|
181 |
+
|
182 |
+
self.reg_tokens = nn.Parameter(torch.zeros(1, 2, embedding_dim))
|
183 |
+
self.global_patch = nn.Conv2d(embedding_dim, embedding_dim, kernel_size=(8, 8), stride=(8, 8), bias=True)
|
184 |
+
|
185 |
+
self.att_temporal = nn.ModuleList(
|
186 |
+
[
|
187 |
+
AttnBlock(embedding_dim, 8,
|
188 |
+
mlp_ratio=4, flash=True, ckpt_fwd=True)
|
189 |
+
for _ in range(8)
|
190 |
+
]
|
191 |
+
)
|
192 |
+
self.att_spatial = nn.ModuleList(
|
193 |
+
[
|
194 |
+
AttnBlock(embedding_dim, 8,
|
195 |
+
mlp_ratio=4, flash=True, ckpt_fwd=True)
|
196 |
+
for _ in range(8)
|
197 |
+
]
|
198 |
+
)
|
199 |
+
self.scale_shift_head = nn.Sequential(nn.Linear(embedding_dim, embedding_dim), nn.GELU(), nn.Linear(embedding_dim, 4))
|
200 |
+
|
201 |
+
|
202 |
+
# Initialize reg tokens
|
203 |
+
nn.init.trunc_normal_(self.reg_tokens, std=0.02)
|
204 |
+
|
205 |
+
self.decoder_focal=BasicLayer3d3(dim=embedding_dim,
|
206 |
+
input_resolution=(96,
|
207 |
+
96),
|
208 |
+
depth=depths,
|
209 |
+
num_heads=8,
|
210 |
+
window_size=7,
|
211 |
+
mlp_ratio=4.,
|
212 |
+
qkv_bias=True,
|
213 |
+
qk_scale=None,
|
214 |
+
drop=0.,
|
215 |
+
attn_drop=0.,
|
216 |
+
drop_path=0.,
|
217 |
+
norm_layer=nn.LayerNorm,
|
218 |
+
pool_method='fc',
|
219 |
+
downsample=None,
|
220 |
+
focal_level=2,
|
221 |
+
focal_window=5,
|
222 |
+
expand_size=3,
|
223 |
+
expand_layer="all",
|
224 |
+
use_conv_embed=False,
|
225 |
+
use_shift=False,
|
226 |
+
use_pre_norm=False,
|
227 |
+
use_checkpoint=False,
|
228 |
+
use_layerscale=False,
|
229 |
+
layerscale_value=1e-4,
|
230 |
+
focal_l_clips=[7,4,2],
|
231 |
+
focal_kernel_clips=[7,5,3])
|
232 |
+
|
233 |
+
self.ffm2 = FFM(inchannels= 256, midchannels= 256, outchannels = 128)
|
234 |
+
self.ffm1 = FFM(inchannels= 128, midchannels= 128, outchannels = 64)
|
235 |
+
self.ffm0 = FFM(inchannels= 64, midchannels= 64, outchannels = 32,upfactor=1)
|
236 |
+
self.AO = AO(32, outchannels=3, upfactor=1)
|
237 |
+
self._c2 = None
|
238 |
+
self._c_further = None
|
239 |
+
|
240 |
+
def buffer_forward(self, inputs, num_clips=None, imgs=None):#,infermode=1):
|
241 |
+
|
242 |
+
# input: B T 7 H W (7 means 3 rgb + 3 pointmap + 1 uncertainty) normalized
|
243 |
+
if self.training:
|
244 |
+
assert self.num_clips==num_clips
|
245 |
+
|
246 |
+
x = self._transform_inputs(inputs) # len=4, 1/4,1/8,1/16,1/32
|
247 |
+
c1, c2, c3, c4 = x
|
248 |
+
|
249 |
+
############## MLP decoder on C1-C4 ###########
|
250 |
+
n, _, h, w = c4.shape
|
251 |
+
batch_size = n // num_clips
|
252 |
+
|
253 |
+
_c4 = self.linear_c4(c4).permute(0,2,1).reshape(n, -1, c4.shape[2], c4.shape[3])
|
254 |
+
_c4 = resize(_c4, size=c1.size()[2:],mode='bilinear',align_corners=False)
|
255 |
+
|
256 |
+
_c3 = self.linear_c3(c3).permute(0,2,1).reshape(n, -1, c3.shape[2], c3.shape[3])
|
257 |
+
_c3 = resize(_c3, size=c1.size()[2:],mode='bilinear',align_corners=False)
|
258 |
+
|
259 |
+
_c2 = self.linear_c2(c2).permute(0,2,1).reshape(n, -1, c2.shape[2], c2.shape[3])
|
260 |
+
_c2 = resize(_c2, size=c1.size()[2:],mode='bilinear',align_corners=False)
|
261 |
+
|
262 |
+
_c1 = self.linear_c1(c1).permute(0,2,1).reshape(n, -1, c1.shape[2], c1.shape[3])
|
263 |
+
_c = self.linear_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))
|
264 |
+
|
265 |
+
_, _, h, w=_c.shape
|
266 |
+
_c_further=_c.reshape(batch_size, num_clips, -1, h, w) #h2w2
|
267 |
+
|
268 |
+
# Expand reg_tokens to match batch size
|
269 |
+
reg_tokens = self.reg_tokens.expand(batch_size*num_clips, -1, -1) # [B, 2, C]
|
270 |
+
|
271 |
+
_c2=self.decoder_focal(_c_further, batch_size=batch_size, num_clips=num_clips, reg_tokens=reg_tokens)
|
272 |
+
|
273 |
+
assert _c_further.shape==_c2.shape
|
274 |
+
self._c2 = _c2
|
275 |
+
self._c_further = _c_further
|
276 |
+
|
277 |
+
# compute the scale and shift of the global patch
|
278 |
+
global_patch = self.global_patch(_c2.view(batch_size*num_clips, -1, h, w)).view(batch_size*num_clips, _c2.shape[2], -1).permute(0,2,1)
|
279 |
+
global_patch = torch.cat([global_patch, reg_tokens], dim=1)
|
280 |
+
for i in range(8):
|
281 |
+
global_patch = self.att_temporal[i](global_patch)
|
282 |
+
global_patch = rearrange(global_patch, '(b t) n c -> (b n) t c', b=batch_size, t=num_clips, c=_c2.shape[2])
|
283 |
+
global_patch = self.att_spatial[i](global_patch)
|
284 |
+
global_patch = rearrange(global_patch, '(b n) t c -> (b t) n c', b=batch_size, t=num_clips, c=_c2.shape[2])
|
285 |
+
|
286 |
+
reg_tokens = global_patch[:, -2:, :]
|
287 |
+
s_ = self.scale_shift_head(reg_tokens)
|
288 |
+
scale = 1 + s_[:, 0, :1].view(batch_size, num_clips, 1, 1, 1)
|
289 |
+
shift = s_[:, 1, 1:].view(batch_size, num_clips, 3, 1, 1)
|
290 |
+
shift[:,:,:2,...] = 0
|
291 |
+
return scale, shift
|
292 |
+
|
293 |
+
def forward(self, inputs, edge_feat, edge_feat1, tracks, tracks_uvd, num_clips=None, imgs=None, vis_track=None):#,infermode=1):
|
294 |
+
|
295 |
+
if self._c2 is None:
|
296 |
+
scale, shift = self.buffer_forward(inputs,num_clips,imgs)
|
297 |
+
|
298 |
+
B, T, N, _ = tracks.shape
|
299 |
+
|
300 |
+
_c2 = self._c2
|
301 |
+
_c_further = self._c_further
|
302 |
+
|
303 |
+
# skip and head
|
304 |
+
_c_further = rearrange(_c_further, 'b t c h w -> (b t) c h w', b=B, t=T)
|
305 |
+
_c2 = rearrange(_c2, 'b t c h w -> (b t) c h w', b=B, t=T)
|
306 |
+
|
307 |
+
outframe = self.ffm2(_c_further, _c2)
|
308 |
+
|
309 |
+
tracks_uv = tracks_uvd[...,:2].clone()
|
310 |
+
track_feature = scatter_multiscale_fast(tracks_uv/2, tracks, outframe.shape[-2], outframe.shape[-1], kernel_sizes=[1, 3, 5])
|
311 |
+
# visualize track_feature as video
|
312 |
+
# import cv2
|
313 |
+
# import imageio
|
314 |
+
# import os
|
315 |
+
# BT, C, H, W = outframe.shape
|
316 |
+
# track_feature_vis = track_feature.view(B, T, 3, H, W).float().detach().cpu().numpy()
|
317 |
+
# track_feature_vis = track_feature_vis.transpose(0,1,3,4,2)
|
318 |
+
# track_feature_vis = (track_feature_vis - track_feature_vis.min()) / (track_feature_vis.max() - track_feature_vis.min() + 1e-6)
|
319 |
+
# track_feature_vis = (track_feature_vis * 255).astype(np.uint8)
|
320 |
+
# imgs =(imgs.detach() + 1) * 127.5
|
321 |
+
# vis_track.visualize(video=imgs, tracks=tracks_uv, filename="test")
|
322 |
+
# for b in range(B):
|
323 |
+
# frames = []
|
324 |
+
# for t in range(T):
|
325 |
+
# frame = track_feature_vis[b,t]
|
326 |
+
# frame = cv2.applyColorMap(frame[...,0], cv2.COLORMAP_JET)
|
327 |
+
# frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
328 |
+
# frames.append(frame)
|
329 |
+
# # Save as gif
|
330 |
+
# imageio.mimsave(f'track_feature_b{b}.gif', frames, duration=0.1)
|
331 |
+
# import pdb; pdb.set_trace()
|
332 |
+
track_feature = rearrange(track_feature, 'b t c h w -> (b t) c h w')
|
333 |
+
track_feature = self.proj_track(track_feature)
|
334 |
+
outframe = self.ffm1(edge_feat1 + track_feature,outframe)
|
335 |
+
outframe = self.ffm0(edge_feat,outframe)
|
336 |
+
outframe = self.AO(outframe)
|
337 |
+
|
338 |
+
return outframe
|
339 |
+
|
340 |
+
def reset_success(self):
|
341 |
+
self._c2 = None
|
342 |
+
self._c_further = None
|
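A hedged sketch of how per-frame (scale, shift) outputs like those returned by buffer_forward could be applied to a pointmap; the application itself and all variable names here are assumptions for illustration, not code from this diff.

import torch

B, T, H, W = 1, 8, 96, 96
pointmap = torch.randn(B, T, 3, H, W)                 # per-pixel (x, y, z) -- assumed layout
scale = 1 + 0.05 * torch.randn(B, T, 1, 1, 1)         # one global scale per frame
shift = torch.zeros(B, T, 3, 1, 1)
shift[:, :, 2:] = 0.1 * torch.randn(B, T, 1, 1, 1)    # only the z channel is shifted, as in buffer_forward
stabilized = pointmap * scale + shift
print(stabilized.shape)                               # torch.Size([1, 8, 3, 96, 96])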
models/SpaTrackV2/models/predictor.py
ADDED
@@ -0,0 +1,153 @@
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
|
10 |
+
from tqdm import tqdm
|
11 |
+
from models.SpaTrackV2.models.SpaTrack import SpaTrack2
|
12 |
+
from typing import Literal
|
13 |
+
import numpy as np
|
14 |
+
from pathlib import Path
|
15 |
+
from typing import Union, Optional
|
16 |
+
import cv2
|
17 |
+
import os
|
18 |
+
import decord
|
19 |
+
|
20 |
+
class Predictor(torch.nn.Module):
|
21 |
+
def __init__(self, args=None):
|
22 |
+
super().__init__()
|
23 |
+
self.args = args
|
24 |
+
self.spatrack = SpaTrack2(loggers=[None, None, None], **args)
|
25 |
+
self.S_wind = args.Track_cfg.s_wind
|
26 |
+
self.overlap = args.Track_cfg.overlap
|
27 |
+
|
28 |
+
def to(self, device: Union[str, torch.device]):
|
29 |
+
self.spatrack.to(device)
|
30 |
+
self.spatrack.base_model.to(device)
|
31 |
+
|
32 |
+
@classmethod
|
33 |
+
def from_pretrained(
|
34 |
+
cls,
|
35 |
+
pretrained_model_name_or_path: Union[str, Path],
|
36 |
+
*,
|
37 |
+
force_download: bool = False,
|
38 |
+
cache_dir: Optional[str] = None,
|
39 |
+
device: Optional[Union[str, torch.device]] = None,
|
40 |
+
model_cfg: Optional[dict] = None,
|
41 |
+
**kwargs,
|
42 |
+
) -> "SpaTrack2":
|
43 |
+
"""
|
44 |
+
Load a pretrained model from a local checkpoint file (remote download is not implemented yet).
|
45 |
+
|
46 |
+
Args:
|
47 |
+
pretrained_model_name_or_path (str or Path):
|
48 |
+
- Path to a local model file (e.g., `./model.pth`).
|
49 |
+
- HuggingFace Hub model ID (e.g., `username/model-name`).
|
50 |
+
force_download (bool, optional):
|
51 |
+
Whether to force re-download even if cached. Default: False.
|
52 |
+
cache_dir (str, optional):
|
53 |
+
Custom cache directory. Default: None (use default cache).
|
54 |
+
device (str or torch.device, optional):
|
55 |
+
Target device (e.g., "cuda", "cpu"). Default: None (keep original).
|
56 |
+
**kwargs:
|
57 |
+
Additional config overrides.
|
58 |
+
|
59 |
+
Returns:
|
60 |
+
SpaTrack2: Loaded pretrained model.
|
61 |
+
"""
|
62 |
+
# (1) check the path is local or remote
|
63 |
+
if isinstance(pretrained_model_name_or_path, Path):
|
64 |
+
model_path = str(pretrained_model_name_or_path)
|
65 |
+
else:
|
66 |
+
model_path = pretrained_model_name_or_path
|
67 |
+
# (2) if the path is remote, download it
|
68 |
+
if not os.path.exists(model_path):
|
69 |
+
raise NotImplementedError("Remote download not implemented yet. Use a local path.")
|
70 |
+
# (3) load the model weights
|
71 |
+
|
72 |
+
state_dict = torch.load(model_path, map_location="cpu")
|
73 |
+
# (4) initialize the model (can load config.json if exists)
|
74 |
+
config_path = os.path.join(os.path.dirname(model_path), "config.json")
|
75 |
+
config = {}
|
76 |
+
if os.path.exists(config_path):
|
77 |
+
import json
|
78 |
+
with open(config_path, "r") as f:
|
79 |
+
config.update(json.load(f))
|
80 |
+
config.update(kwargs) # allow override the config
|
81 |
+
if model_cfg is not None:
|
82 |
+
config = model_cfg
|
83 |
+
model = cls(config)
|
84 |
+
if "model" in state_dict:
|
85 |
+
model.spatrack.load_state_dict(state_dict["model"], strict=False)
|
86 |
+
else:
|
87 |
+
model.spatrack.load_state_dict(state_dict, strict=False)
|
88 |
+
# (5) device management
|
89 |
+
if device is not None:
|
90 |
+
model.to(device)
|
91 |
+
|
92 |
+
return model
|
93 |
+
|
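A hedged loading sketch for Predictor.from_pretrained: the checkpoint path is hypothetical, and reading the config with OmegaConf is an assumption (the diff does not state which config loader is used), although the yaml file name comes from this repo's config/ directory.

import torch
from omegaconf import OmegaConf   # assumption: any mapping with attribute access would do
from models.SpaTrackV2.models.predictor import Predictor

cfg = OmegaConf.load("config/magic_infer_moge.yaml")             # config shipped in this repo
model = Predictor.from_pretrained("checkpoints/SpaTrack2.pth",   # hypothetical local checkpoint
                                  model_cfg=cfg)
model.to("cuda" if torch.cuda.is_available() else "cpu")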
94 |
+
def forward(self, video: str|torch.Tensor|np.ndarray,
|
95 |
+
depth: str|torch.Tensor|np.ndarray=None,
|
96 |
+
unc_metric: str|torch.Tensor|np.ndarray=None,
|
97 |
+
intrs: str|torch.Tensor|np.ndarray=None,
|
98 |
+
extrs: str|torch.Tensor|np.ndarray=None,
|
99 |
+
queries=None, queries_3d=None, iters_track=4,
|
100 |
+
full_point=False, fps=30, track2d_gt=None,
|
101 |
+
fixed_cam=False, query_no_BA=False, stage=0,
|
102 |
+
support_frame=0, replace_ratio=0.6):
|
103 |
+
"""
|
104 |
+
video: this could be a path to a video, a tensor of shape (T, C, H, W) or a numpy array of shape (T, C, H, W)
|
105 |
+
queries: (B, N, 2)
|
106 |
+
"""
|
107 |
+
|
108 |
+
if isinstance(video, str):
|
109 |
+
video = decord.VideoReader(video)
|
110 |
+
video = video[::fps].asnumpy() # Convert to numpy array
|
111 |
+
video = np.array(video) # Ensure numpy array
|
112 |
+
video = torch.from_numpy(video).permute(0, 3, 1, 2).float()
|
113 |
+
elif isinstance(video, np.ndarray):
|
114 |
+
video = torch.from_numpy(video).float()
|
115 |
+
|
116 |
+
if isinstance(depth, np.ndarray):
|
117 |
+
depth = torch.from_numpy(depth).float()
|
118 |
+
if isinstance(intrs, np.ndarray):
|
119 |
+
intrs = torch.from_numpy(intrs).float()
|
120 |
+
if isinstance(extrs, np.ndarray):
|
121 |
+
extrs = torch.from_numpy(extrs).float()
|
122 |
+
if isinstance(unc_metric, np.ndarray):
|
123 |
+
unc_metric = torch.from_numpy(unc_metric).float()
|
124 |
+
|
125 |
+
T_, C, H, W = video.shape
|
126 |
+
step_slide = self.S_wind - self.overlap
|
127 |
+
if T_ > self.S_wind:
|
128 |
+
|
129 |
+
num_windows = (T_ - self.S_wind + step_slide) // step_slide
|
130 |
+
T = num_windows * step_slide + self.S_wind
|
131 |
+
pad_len = T - T_
|
132 |
+
|
133 |
+
video = torch.cat([video, video[-1:].repeat(T-video.shape[0], 1, 1, 1)], dim=0)
|
134 |
+
if depth is not None:
|
135 |
+
depth = torch.cat([depth, depth[-1:].repeat(T-depth.shape[0], 1, 1)], dim=0)
|
136 |
+
if intrs is not None:
|
137 |
+
intrs = torch.cat([intrs, intrs[-1:].repeat(T-intrs.shape[0], 1, 1)], dim=0)
|
138 |
+
if extrs is not None:
|
139 |
+
extrs = torch.cat([extrs, extrs[-1:].repeat(T-extrs.shape[0], 1, 1)], dim=0)
|
140 |
+
if unc_metric is not None:
|
141 |
+
unc_metric = torch.cat([unc_metric, unc_metric[-1:].repeat(T-unc_metric.shape[0], 1, 1)], dim=0)
|
142 |
+
with torch.no_grad():
|
143 |
+
ret = self.spatrack.forward_stream(video, queries, T_org=T_,
|
144 |
+
depth=depth, intrs=intrs, unc_metric_in=unc_metric, extrs=extrs, queries_3d=queries_3d,
|
145 |
+
window_len=self.S_wind, overlap_len=self.overlap, track2d_gt=track2d_gt, full_point=full_point, iters_track=iters_track,
|
146 |
+
fixed_cam=fixed_cam, query_no_BA=query_no_BA, stage=stage, support_frame=support_frame, replace_ratio=replace_ratio) + (video[:T_],)
|
147 |
+
|
148 |
+
|
149 |
+
return ret
|
150 |
+
|
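The sliding-window arithmetic in Predictor.forward, reproduced with illustrative numbers (window length 32, overlap 8, over a 100-frame clip):

T_, S_wind, overlap = 100, 32, 8
step = S_wind - overlap                        # 24 frames advance per window
num_windows = (T_ - S_wind + step) // step     # 3
T = num_windows * step + S_wind                # 104 -> the last 4 frames are padded copies
print(num_windows, T, T - T_)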
151 |
+
|
152 |
+
|
153 |
+
|
models/SpaTrackV2/models/tracker3D/TrackRefiner.py
ADDED
@@ -0,0 +1,1478 @@
1 |
+
import os, sys
|
2 |
+
import torch
|
3 |
+
import torch.amp
|
4 |
+
from models.SpaTrackV2.models.tracker3D.co_tracker.cotracker_base import CoTrackerThreeOffline, get_1d_sincos_pos_embed_from_grid
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from models.SpaTrackV2.utils.visualizer import Visualizer
|
7 |
+
from models.SpaTrackV2.utils.model_utils import sample_features5d
|
8 |
+
from models.SpaTrackV2.models.blocks import bilinear_sampler
|
9 |
+
import torch.nn as nn
|
10 |
+
from models.SpaTrackV2.models.tracker3D.co_tracker.utils import (
|
11 |
+
EfficientUpdateFormer, AttnBlock, Attention, CrossAttnBlock,
|
12 |
+
sequence_BCE_loss, sequence_loss, sequence_prob_loss, sequence_dyn_prob_loss, sequence_loss_xyz, balanced_binary_cross_entropy
|
13 |
+
)
|
14 |
+
from torchvision.io import write_video
|
15 |
+
import math
|
16 |
+
from models.SpaTrackV2.models.tracker3D.co_tracker.utils import (
|
17 |
+
Mlp, BasicEncoder, EfficientUpdateFormer, GeometryEncoder, NeighborTransformer, CorrPointformer
|
18 |
+
)
|
19 |
+
from models.SpaTrackV2.utils.embeddings import get_3d_sincos_pos_embed_from_grid
|
20 |
+
from einops import rearrange, repeat
|
21 |
+
from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import (
|
22 |
+
EfficientUpdateFormer3D, weighted_procrustes_torch, posenc, key_fr_wprocrustes, get_topo_mask,
|
23 |
+
TrackFusion, get_nth_visible_time_index
|
24 |
+
)
|
25 |
+
from models.SpaTrackV2.models.tracker3D.spatrack_modules.ba import extract_static_from_3DTracks, ba_pycolmap
|
26 |
+
from models.SpaTrackV2.models.tracker3D.spatrack_modules.pointmap_updator import PointMapUpdator
|
27 |
+
from models.SpaTrackV2.models.depth_refiner.depth_refiner import TrackStablizer
|
28 |
+
from models.SpaTrackV2.models.tracker3D.spatrack_modules.alignment import affine_invariant_global_loss
|
29 |
+
from models.SpaTrackV2.models.tracker3D.delta_utils.upsample_transformer import UpsampleTransformerAlibi
|
30 |
+
|
31 |
+
class TrackRefiner3D(CoTrackerThreeOffline):
|
32 |
+
|
33 |
+
def __init__(self, args=None):
|
34 |
+
super().__init__(**args.base)
|
35 |
+
|
36 |
+
"""
|
37 |
+
This is a 3D wrapper around CoTracker: it loads the pretrained CoTracker weights and
|
38 |
+
jointly refines the `camera pose`, `3D tracks`, `video depth`, `visibility` and `conf`.
|
39 |
+
"""
|
40 |
+
self.updateformer3D = EfficientUpdateFormer3D(self.updateformer)
|
41 |
+
self.corr_depth_mlp = Mlp(in_features=256, hidden_features=256, out_features=256)
|
42 |
+
self.rel_pos_mlp = Mlp(in_features=75, hidden_features=128, out_features=128)
|
43 |
+
self.rel_pos_glob_mlp = Mlp(in_features=75, hidden_features=128, out_features=256)
|
44 |
+
self.corr_xyz_mlp = Mlp(in_features=256, hidden_features=128, out_features=128)
|
45 |
+
self.xyz_mlp = Mlp(in_features=126, hidden_features=128, out_features=84)
|
46 |
+
# self.track_feat_mlp = Mlp(in_features=1110, hidden_features=128, out_features=128)
|
47 |
+
self.proj_xyz_embed = Mlp(in_features=1210+50, hidden_features=1110, out_features=1110)
|
48 |
+
# get the anchor point's embedding, and init the pts refiner
|
49 |
+
update_pts = True
|
50 |
+
# self.corr_transformer = nn.ModuleList([
|
51 |
+
# CorrPointformer(
|
52 |
+
# dim=128,
|
53 |
+
# num_heads=8,
|
54 |
+
# head_dim=128 // 8,
|
55 |
+
# mlp_ratio=4.0,
|
56 |
+
# )
|
57 |
+
# for _ in range(self.corr_levels)
|
58 |
+
# ])
|
59 |
+
self.corr_transformer = nn.ModuleList([
|
60 |
+
CorrPointformer(
|
61 |
+
dim=128,
|
62 |
+
num_heads=8,
|
63 |
+
head_dim=128 // 8,
|
64 |
+
mlp_ratio=4.0,
|
65 |
+
)
|
66 |
+
]
|
67 |
+
)
|
68 |
+
self.fnet = BasicEncoder(input_dim=3,
|
69 |
+
output_dim=self.latent_dim, stride=self.stride)
|
70 |
+
self.corr3d_radius = 3
|
71 |
+
|
72 |
+
if args.stablizer:
|
73 |
+
self.scale_shift_tokens = nn.Parameter(torch.randn(1, 2, self.latent_dim, requires_grad=True))
|
74 |
+
self.upsample_kernel_size = 5
|
75 |
+
self.residual_embedding = nn.Parameter(torch.randn(
|
76 |
+
self.latent_dim, self.model_resolution[0]//16,
|
77 |
+
self.model_resolution[1]//16, requires_grad=True))
|
78 |
+
self.dense_mlp = nn.Conv2d(2*self.latent_dim+63, self.latent_dim, kernel_size=1, stride=1, padding=0)
|
79 |
+
self.upsample_factor = 4
|
80 |
+
self.upsample_transformer = UpsampleTransformerAlibi(
|
81 |
+
kernel_size=self.upsample_kernel_size, # kernel_size=3, #
|
82 |
+
stride=self.stride,
|
83 |
+
latent_dim=self.latent_dim,
|
84 |
+
num_attn_blocks=2,
|
85 |
+
upsample_factor=4,
|
86 |
+
)
|
87 |
+
else:
|
88 |
+
self.update_pointmap = None
|
89 |
+
|
90 |
+
self.mode = args.mode
|
91 |
+
if self.mode == "online":
|
92 |
+
self.s_wind = args.s_wind
|
93 |
+
self.overlap = args.overlap
|
94 |
+
|
95 |
+
def upsample_with_mask(
|
96 |
+
self, inp: torch.Tensor, mask: torch.Tensor
|
97 |
+
) -> torch.Tensor:
|
98 |
+
"""Upsample flow field [H/P, W/P, 2] -> [H, W, 2] using convex combination"""
|
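# Rough shape sketch (assuming upsample_factor=4 and upsample_kernel_size=5, as set in __init__):
#   inp  : (B*T, C, H, W)         low-resolution input, e.g. the 3-channel residual point map
#   mask : (B*T, 1, 25, 4H, 4W)   per-pixel convex weights over the 5x5 neighbourhood
#   out  : (B*T, C, 4H, 4W)       each output pixel is the weighted sum of its unfolded 5x5 window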
99 |
+
H, W = inp.shape[-2:]
|
100 |
+
up_inp = F.unfold(
|
101 |
+
inp, [self.upsample_kernel_size, self.upsample_kernel_size], padding=(self.upsample_kernel_size - 1) // 2
|
102 |
+
)
|
103 |
+
up_inp = rearrange(up_inp, "b c (h w) -> b c h w", h=H, w=W)
|
104 |
+
up_inp = F.interpolate(up_inp, scale_factor=self.upsample_factor, mode="nearest")
|
105 |
+
up_inp = rearrange(
|
106 |
+
up_inp, "b (c i j) h w -> b c (i j) h w", i=self.upsample_kernel_size, j=self.upsample_kernel_size
|
107 |
+
)
|
108 |
+
|
109 |
+
up_inp = torch.sum(mask * up_inp, dim=2)
|
110 |
+
return up_inp
|
111 |
+
|
112 |
+
def track_from_cam(self, queries, c2w_traj, intrs,
|
113 |
+
rgbs=None, visualize=False):
|
114 |
+
"""
|
115 |
+
Generate 2D/3D tracks for every frame by transforming the query points with the camera trajectory.
|
116 |
+
|
117 |
+
Args:
|
118 |
+
queries: B T N 4
|
119 |
+
c2w_traj: B T 4 4
|
120 |
+
intrs: B T 3 3
|
121 |
+
"""
|
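# The chain of transforms implemented below, in pinhole notation:
#   x_cam_q   = K_q^{-1} [u, v, 1]^T * d_q           # unproject the query pixel with its depth
#   x_world   = c2w_q [x_cam_q, 1]^T                 # lift to world using the query frame's pose
#   x_cam_t   = c2w_t^{-1} x_world                   # express in every other frame t
#   [u_t,v_t] = K_t (x_cam_t / |z_t|),  depth = z_t  # reproject; the third output channel keeps depth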
122 |
+
B, T, N, _ = queries.shape
|
123 |
+
query_t = queries[:,0,:,0].to(torch.int64) # B N
|
124 |
+
query_c2w = torch.gather(c2w_traj,
|
125 |
+
dim=1, index=query_t[..., None, None].expand(-1, -1, 4, 4)) # B N 4 4
|
126 |
+
query_intr = torch.gather(intrs,
|
127 |
+
dim=1, index=query_t[..., None, None].expand(-1, -1, 3, 3)) # B N 3 3
|
128 |
+
query_pts = queries[:,0,:,1:4].clone() # B N 3
|
129 |
+
query_d = queries[:,0,:,3:4] # B N 3
|
130 |
+
query_pts[...,2] = 1
|
131 |
+
|
132 |
+
cam_pts = torch.einsum("bnij,bnj->bni", torch.inverse(query_intr), query_pts)*query_d # B N 3
|
133 |
+
# convert to world
|
134 |
+
cam_pts_h = torch.zeros(B, N, 4, device=cam_pts.device)
|
135 |
+
cam_pts_h[..., :3] = cam_pts
|
136 |
+
cam_pts_h[..., 3] = 1
|
137 |
+
world_pts = torch.einsum("bnij,bnj->bni", query_c2w, cam_pts_h)
|
138 |
+
# convert to other frames
|
139 |
+
cam_other_pts_ = torch.einsum("btnij,btnj->btni",
|
140 |
+
torch.inverse(c2w_traj[:,:,None].float().repeat(1,1,N,1,1)),
|
141 |
+
world_pts[:,None].repeat(1,T,1,1))
|
142 |
+
cam_depth = cam_other_pts_[...,2:3]
|
143 |
+
cam_other_pts = cam_other_pts_[...,:3] / (cam_other_pts_[...,2:3].abs()+1e-6)
|
144 |
+
cam_other_pts = torch.einsum("btnij,btnj->btni", intrs[:,:,None].repeat(1,1,N,1,1), cam_other_pts[...,:3])
|
145 |
+
cam_other_pts[..., 2:] = cam_depth
|
146 |
+
|
147 |
+
if visualize:
|
148 |
+
viser = Visualizer(save_dir=".", grayscale=True,
|
149 |
+
fps=10, pad_value=50, tracks_leave_trace=0)
|
150 |
+
cam_other_pts[..., 0] /= self.factor_x
|
151 |
+
cam_other_pts[..., 1] /= self.factor_y
|
152 |
+
viser.visualize(video=rgbs, tracks=cam_other_pts[..., :2], filename="test")
|
153 |
+
|
154 |
+
|
155 |
+
init_xyzs = cam_other_pts
|
156 |
+
|
157 |
+
return init_xyzs, world_pts[..., :3], cam_other_pts_[..., :3]
|
158 |
+
|
159 |
+
def cam_from_track(self, tracks, intrs,
|
160 |
+
dyn_prob=None, metric_unc=None,
|
161 |
+
vis_est=None, only_cam_pts=False,
|
162 |
+
track_feat_concat=None,
|
163 |
+
tracks_xyz=None,
|
164 |
+
query_pts=None,
|
165 |
+
fixed_cam=False,
|
166 |
+
depth_unproj=None,
|
167 |
+
cam_gt=None,
|
168 |
+
init_pose=False,
|
169 |
+
):
|
170 |
+
"""
|
171 |
+
Estimate the camera trajectory (and refined camera-space tracks) from the predicted 3D tracks.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
queries: B T N 3
|
175 |
+
scale_est: 1 1
|
176 |
+
shift_est: 1 1
|
177 |
+
intrs: B T 3 3
|
178 |
+
dyn_prob: B T N
|
179 |
+
metric_unc: B N 1
|
180 |
+
query_pts: B T N 3
|
181 |
+
"""
|
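# Sketch of the flow below (a summary, not extra behaviour):
#   1) unproject the 2D tracks with their depths into per-frame camera-space points;
#   2) run a weighted Procrustes alignment (key_fr_wprocrustes) over the visibility graph,
#      down-weighting dynamic points, to get an initial trajectory c2w_traj_init;
#   3) unless fixed_cam is set, select stable static tracks (extract_static_from_3DTracks)
#      and refine poses/intrinsics with a small bundle adjustment (ba_pycolmap);
#   4) map the world tracks back into each camera and reproject them for the outputs.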
182 |
+
if tracks_xyz is not None:
|
183 |
+
B, T, N, _ = tracks.shape
|
184 |
+
cam_pts = tracks_xyz
|
185 |
+
intr_repeat = intrs[:,:,None].repeat(1,1,N,1,1)
|
186 |
+
else:
|
187 |
+
B, T, N, _ = tracks.shape
|
188 |
+
# get the pts in cam coordinate
|
189 |
+
tracks_xy = tracks[...,:3].clone().detach() # B T N 3
|
190 |
+
# tracks_z = 1/(tracks[...,2:] * scale_est + shift_est) # B T N 1
|
191 |
+
tracks_z = tracks[...,2:].detach() # B T N 1
|
192 |
+
tracks_xy[...,2] = 1
|
193 |
+
intr_repeat = intrs[:,:,None].repeat(1,1,N,1,1)
|
194 |
+
cam_pts = torch.einsum("bnij,bnj->bni",
|
195 |
+
torch.inverse(intr_repeat.view(B*T,N,3,3)).float(),
|
196 |
+
tracks_xy.view(B*T, N, 3))*(tracks_z.view(B*T,N,1).abs()) # B*T N 3
|
197 |
+
cam_pts[...,2] *= torch.sign(tracks_z.view(B*T,N))
|
198 |
+
# get the normalized cam pts, and pts refiner
|
199 |
+
mask_z = (tracks_z.max(dim=1)[0]<200).squeeze()
|
200 |
+
cam_pts = cam_pts.view(B, T, N, 3)
|
201 |
+
|
202 |
+
if only_cam_pts:
|
203 |
+
return cam_pts
|
204 |
+
dyn_prob = dyn_prob.mean(dim=1)[..., None]
|
205 |
+
# B T N 3 -> local-frame coordinates: static points map to the same world point across frames, while dynamic points are transformed through C2T.inverse()
|
206 |
+
# get the cam pose
|
207 |
+
vis_est_ = vis_est[:,:,None,:]
|
208 |
+
graph_matrix = (vis_est_*vis_est_.permute(0, 2,1,3)).detach()
|
209 |
+
# find the max connected component
|
210 |
+
key_fr_idx = [0]
|
211 |
+
weight_final = (metric_unc) # * vis_est
|
212 |
+
|
213 |
+
|
214 |
+
with torch.amp.autocast(enabled=False, device_type='cuda'):
|
215 |
+
if fixed_cam:
|
216 |
+
c2w_traj_init = self.c2w_est_curr
|
217 |
+
c2w_traj_glob = c2w_traj_init
|
218 |
+
cam_pts_refine = cam_pts
|
219 |
+
intrs_refine = intrs
|
220 |
+
xy_refine = query_pts[...,1:3]
|
221 |
+
world_tracks_init = torch.einsum("btij,btnj->btni", c2w_traj_init[:,:,:3,:3], cam_pts) + c2w_traj_init[:,:,None,:3,3]
|
222 |
+
world_tracks_refined = world_tracks_init
|
223 |
+
# extract the stable static points for refine the camera pose
|
224 |
+
intrs_dn = intrs.clone()
|
225 |
+
intrs_dn[...,0,:] *= self.factor_x
|
226 |
+
intrs_dn[...,1,:] *= self.factor_y
|
227 |
+
_, query_world_pts, _ = self.track_from_cam(query_pts, c2w_traj_init, intrs_dn)
|
228 |
+
world_tracks_static, mask_static, mask_topk, vis_mask_static, tracks2d_static = extract_static_from_3DTracks(world_tracks_init,
|
229 |
+
dyn_prob, query_world_pts,
|
230 |
+
vis_est, tracks, img_size=self.image_size,
|
231 |
+
K=0)
|
232 |
+
world_static_refine = world_tracks_static
|
233 |
+
|
234 |
+
else:
|
235 |
+
|
236 |
+
if (not self.training):
|
237 |
+
# if (self.c2w_est_curr==torch.eye(4, device=cam_pts.device).repeat(B, T, 1, 1)).all():
|
238 |
+
campts_update = torch.einsum("btij,btnj->btni", self.c2w_est_curr[...,:3,:3], cam_pts) + self.c2w_est_curr[...,None,:3,3]
|
239 |
+
# campts_update = cam_pts
|
240 |
+
c2w_traj_init_update = key_fr_wprocrustes(campts_update, graph_matrix,
|
241 |
+
(weight_final*(1-dyn_prob)).permute(0,2,1), vis_est_.permute(0,1,3,2))
|
242 |
+
c2w_traj_init = [email protected]_est_curr
|
243 |
+
# else:
|
244 |
+
# c2w_traj_init = self.c2w_est_curr # extract the stable static points for refine the camera pose
|
245 |
+
else:
|
246 |
+
# if (self.c2w_est_curr==torch.eye(4, device=cam_pts.device).repeat(B, T, 1, 1)).all():
|
247 |
+
campts_update = torch.einsum("btij,btnj->btni", self.c2w_est_curr[...,:3,:3], cam_pts) + self.c2w_est_curr[...,None,:3,3]
|
248 |
+
# campts_update = cam_pts
|
249 |
+
c2w_traj_init_update = key_fr_wprocrustes(campts_update, graph_matrix,
|
250 |
+
(weight_final*(1-dyn_prob)).permute(0,2,1), vis_est_.permute(0,1,3,2))
|
251 |
+
c2w_traj_init = [email protected]_est_curr
|
252 |
+
# else:
|
253 |
+
# c2w_traj_init = self.c2w_est_curr # extract the stable static points for refine the camera pose
|
254 |
+
|
255 |
+
intrs_dn = intrs.clone()
|
256 |
+
intrs_dn[...,0,:] *= self.factor_x
|
257 |
+
intrs_dn[...,1,:] *= self.factor_y
|
258 |
+
_, query_world_pts, _ = self.track_from_cam(query_pts, c2w_traj_init, intrs_dn)
|
259 |
+
# refine the world tracks
|
260 |
+
world_tracks_init = torch.einsum("btij,btnj->btni", c2w_traj_init[:,:,:3,:3], cam_pts) + c2w_traj_init[:,:,None,:3,3]
|
261 |
+
world_tracks_static, mask_static, mask_topk, vis_mask_static, tracks2d_static = extract_static_from_3DTracks(world_tracks_init,
|
262 |
+
dyn_prob, query_world_pts,
|
263 |
+
vis_est, tracks, img_size=self.image_size,
|
264 |
+
K=150 if self.training else 1500)
|
265 |
+
# calculate the efficient ba
|
266 |
+
cam_tracks_static = cam_pts[:,:,mask_static.squeeze(),:][:,:,mask_topk.squeeze(),:]
|
267 |
+
cam_tracks_static[...,2] = depth_unproj.view(B, T, N)[:,:,mask_static.squeeze()][:,:,mask_topk.squeeze()]
|
268 |
+
|
269 |
+
c2w_traj_glob, world_static_refine, intrs_refine = ba_pycolmap(world_tracks_static, intrs,
|
270 |
+
c2w_traj_init, vis_mask_static,
|
271 |
+
tracks2d_static, self.image_size,
|
272 |
+
cam_tracks_static=cam_tracks_static,
|
273 |
+
training=self.training, query_pts=query_pts)
|
274 |
+
c2w_traj_glob = c2w_traj_glob.view(B, T, 4, 4)
|
275 |
+
world_tracks_refined = world_tracks_init
|
276 |
+
|
277 |
+
#NOTE: merge the index of static points and topk points
|
278 |
+
# merge_idx = torch.where(mask_static.squeeze()>0)[0][mask_topk.squeeze()]
|
279 |
+
# world_tracks_refined[:,:,merge_idx] = world_static_refine
|
280 |
+
|
281 |
+
# test the procrustes
|
282 |
+
w2c_traj_glob = torch.inverse(c2w_traj_init.detach())
|
283 |
+
cam_pts_refine = torch.einsum("btij,btnj->btni", w2c_traj_glob[:,:,:3,:3], world_tracks_refined) + w2c_traj_glob[:,:,None,:3,3]
|
284 |
+
# get the xyz_refine
|
285 |
+
#TODO: refiner
|
286 |
+
cam_pts4_proj = cam_pts_refine.clone()
|
287 |
+
cam_pts4_proj[...,2] *= torch.sign(cam_pts4_proj[...,2:3].view(B*T,N))
|
288 |
+
xy_refine = torch.einsum("btnij,btnj->btni", intrs_refine.view(B,T,1,3,3).repeat(1,1,N,1,1), cam_pts4_proj/cam_pts4_proj[...,2:3].abs())
|
289 |
+
xy_refine[..., 2] = cam_pts4_proj[...,2:3].view(B*T,N)
|
290 |
+
# xy_refine = torch.zeros_like(cam_pts_refine)[...,:2]
|
291 |
+
return c2w_traj_glob, cam_pts_refine, intrs_refine, xy_refine, world_tracks_init, world_tracks_refined, c2w_traj_init
|
292 |
+
|
293 |
+
def extract_img_feat(self, video, fmaps_chunk_size=200):
|
294 |
+
B, T, C, H, W = video.shape
|
295 |
+
dtype = video.dtype
|
296 |
+
H4, W4 = H // self.stride, W // self.stride
|
297 |
+
# Compute convolutional features for the video or for the current chunk in case of online mode
|
298 |
+
if T > fmaps_chunk_size:
|
299 |
+
fmaps = []
|
300 |
+
for t in range(0, T, fmaps_chunk_size):
|
301 |
+
video_chunk = video[:, t : t + fmaps_chunk_size]
|
302 |
+
fmaps_chunk = self.fnet(video_chunk.reshape(-1, C, H, W))
|
303 |
+
T_chunk = video_chunk.shape[1]
|
304 |
+
C_chunk, H_chunk, W_chunk = fmaps_chunk.shape[1:]
|
305 |
+
fmaps.append(fmaps_chunk.reshape(B, T_chunk, C_chunk, H_chunk, W_chunk))
|
306 |
+
fmaps = torch.cat(fmaps, dim=1).reshape(-1, C_chunk, H_chunk, W_chunk)
|
307 |
+
else:
|
308 |
+
fmaps = self.fnet(video.reshape(-1, C, H, W))
|
309 |
+
fmaps = fmaps.permute(0, 2, 3, 1)
|
310 |
+
fmaps = fmaps / torch.sqrt(
|
311 |
+
torch.maximum(
|
312 |
+
torch.sum(torch.square(fmaps), axis=-1, keepdims=True),
|
313 |
+
torch.tensor(1e-12, device=fmaps.device),
|
314 |
+
)
|
315 |
+
)
|
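# The division above L2-normalises each feature vector along the channel axis, with a 1e-12
# floor under the square root (roughly F.normalize(fmaps, dim=-1) with a small eps).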
316 |
+
fmaps = fmaps.permute(0, 3, 1, 2).reshape(
|
317 |
+
B, -1, self.latent_dim, H // self.stride, W // self.stride
|
318 |
+
)
|
319 |
+
fmaps = fmaps.to(dtype)
|
320 |
+
|
321 |
+
return fmaps
|
322 |
+
|
323 |
+
def norm_xyz(self, xyz):
|
324 |
+
"""
|
325 |
+
xyz can be (B T N 3) or (B T 3 H W) or (B N 3)
|
326 |
+
"""
|
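# self.min_pts / self.max_pts are set in forward() as mean -/+ 3*std of the sampled query
# point map, so this maps that box to [-1, 1]:
#   x_norm = 2 * (x - min_pts) / (max_pts - min_pts) - 1
# denorm_xyz() below applies the inverse mapping.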
327 |
+
if xyz.ndim == 3:
|
328 |
+
min_pts = self.min_pts
|
329 |
+
max_pts = self.max_pts
|
330 |
+
return (xyz - min_pts[None,None,:]) / (max_pts - min_pts)[None,None,:] * 2 - 1
|
331 |
+
elif xyz.ndim == 4:
|
332 |
+
min_pts = self.min_pts
|
333 |
+
max_pts = self.max_pts
|
334 |
+
return (xyz - min_pts[None,None,None,:]) / (max_pts - min_pts)[None,None,None,:] * 2 - 1
|
335 |
+
elif xyz.ndim == 5:
|
336 |
+
if xyz.shape[2] == 3:
|
337 |
+
min_pts = self.min_pts
|
338 |
+
max_pts = self.max_pts
|
339 |
+
return (xyz - min_pts[None,None,:,None,None]) / (max_pts - min_pts)[None,None,:,None,None] * 2 - 1
|
340 |
+
elif xyz.shape[-1] == 3:
|
341 |
+
min_pts = self.min_pts
|
342 |
+
max_pts = self.max_pts
|
343 |
+
return (xyz - min_pts[None,None,None,None,:]) / (max_pts - min_pts)[None,None,None,None,:] * 2 - 1
|
344 |
+
|
345 |
+
def denorm_xyz(self, xyz):
|
346 |
+
"""
|
347 |
+
xyz can be (B T N 3) or (B T 3 H W) or (B N 3)
|
348 |
+
"""
|
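# Inverse of norm_xyz: x = (x_norm + 1) / 2 * (max_pts - min_pts) + min_pts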
349 |
+
if xyz.ndim == 3:
|
350 |
+
min_pts = self.min_pts
|
351 |
+
max_pts = self.max_pts
|
352 |
+
return (xyz + 1) / 2 * (max_pts - min_pts)[None,None,:] + min_pts[None,None,:]
|
353 |
+
elif xyz.ndim == 4:
|
354 |
+
min_pts = self.min_pts
|
355 |
+
max_pts = self.max_pts
|
356 |
+
return (xyz + 1) / 2 * (max_pts - min_pts)[None,None,None,:] + min_pts[None,None,None,:]
|
357 |
+
elif xyz.ndim == 5:
|
358 |
+
if xyz.shape[2] == 3:
|
359 |
+
min_pts = self.min_pts
|
360 |
+
max_pts = self.max_pts
|
361 |
+
return (xyz + 1) / 2 * (max_pts - min_pts)[None,None,:,None,None] + min_pts[None,None,:,None,None]
|
362 |
+
elif xyz.shape[-1] == 3:
|
363 |
+
min_pts = self.min_pts
|
364 |
+
max_pts = self.max_pts
|
365 |
+
return (xyz + 1) / 2 * (max_pts - min_pts)[None,None,None,None,:] + min_pts[None,None,None,None,:]
|
366 |
+
|
367 |
+
def forward(
|
368 |
+
self,
|
369 |
+
video,
|
370 |
+
metric_depth,
|
371 |
+
metric_unc,
|
372 |
+
point_map,
|
373 |
+
queries,
|
374 |
+
pts_q_3d=None,
|
375 |
+
overlap_d=None,
|
376 |
+
iters=4,
|
377 |
+
add_space_attn=True,
|
378 |
+
fmaps_chunk_size=200,
|
379 |
+
intrs=None,
|
380 |
+
traj3d_gt=None,
|
381 |
+
custom_vid=False,
|
382 |
+
vis_gt=None,
|
383 |
+
prec_fx=None,
|
384 |
+
prec_fy=None,
|
385 |
+
cam_gt=None,
|
386 |
+
init_pose=False,
|
387 |
+
support_pts_q=None,
|
388 |
+
update_pointmap=True,
|
389 |
+
fixed_cam=False,
|
390 |
+
query_no_BA=False,
|
391 |
+
stage=0,
|
392 |
+
cache=None,
|
393 |
+
points_map_gt=None,
|
394 |
+
valid_only=False,
|
395 |
+
replace_ratio=0.6,
|
396 |
+
):
|
397 |
+
"""Predict tracks
|
398 |
+
|
399 |
+
Args:
|
400 |
+
video (FloatTensor[B, T, 3 H W]): input videos.
|
401 |
+
queries (FloatTensor[B, N, 3]): point queries.
|
402 |
+
iters (int, optional): number of updates. Defaults to 4.
|
403 |
+
vdp_feats_cache: the depth branch's last-layer features
|
404 |
+
tracks_init: B T N 3 the initialization of 3D tracks computed by cam pose
|
405 |
+
Returns:
|
406 |
+
- coords_predicted (FloatTensor[B, T, N, 2]):
|
407 |
+
- vis_predicted (FloatTensor[B, T, N]):
|
408 |
+
- train_data: `None` if `is_train` is false, otherwise:
|
409 |
+
- all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
|
410 |
+
- all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
|
411 |
+
- mask (BoolTensor[B, T, N]):
|
412 |
+
"""
|
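# Typical input shapes in the usual inference path, as implied by the reshapes below (illustrative):
#   video        : (B, T, 3, H, W)     RGB frames in [0, 255]
#   metric_depth : (B*T, 1, H_, W_)    per-frame metric depth
#   metric_unc   : (B*T, 1, H_, W_)    depth confidence in [0, 1]
#   point_map    : (B*T, H_, W_, 3)    per-pixel camera-space 3D points
#   queries      : (B, 1, N, 3)        rows of (t, x, y) query points
#   intrs        : (B, T, 3, 3)        pinhole intrinsics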
413 |
+
self.stage = stage
|
414 |
+
|
415 |
+
if cam_gt is not None:
|
416 |
+
cam_gt = cam_gt.clone()
|
417 |
+
cam_gt = torch.inverse(cam_gt[:,:1,...])@cam_gt
|
418 |
+
B, T, C, _, _ = video.shape
|
419 |
+
_, _, H_, W_ = metric_depth.shape
|
420 |
+
_, _, N, __ = queries.shape
|
421 |
+
if (vis_gt is not None)&(queries.shape[1] == T):
|
422 |
+
aug_visb = True
|
423 |
+
if aug_visb:
|
424 |
+
number_visible = vis_gt.sum(dim=1)
|
425 |
+
ratio_rand = torch.rand(B, N, device=vis_gt.device)
|
426 |
+
# first_positive_inds = get_nth_visible_time_index(vis_gt, 1)
|
427 |
+
first_positive_inds = get_nth_visible_time_index(vis_gt, (number_visible*ratio_rand).long().clamp(min=1, max=T))
|
428 |
+
|
429 |
+
assert (torch.gather(vis_gt, 1, first_positive_inds[:, None, :].repeat(1, T, 1)) < 0).sum() == 0
|
430 |
+
else:
|
431 |
+
__, first_positive_inds = torch.max(vis_gt, dim=1)
|
432 |
+
first_positive_inds = first_positive_inds.long()
|
433 |
+
gather = torch.gather(
|
434 |
+
queries, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
|
435 |
+
)
|
436 |
+
xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)
|
437 |
+
gather_xyz = torch.gather(
|
438 |
+
traj3d_gt, 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 3)
|
439 |
+
)
|
440 |
+
z_gt_query = torch.diagonal(gather_xyz, dim1=1, dim2=2).permute(0, 2, 1)[...,2]
|
441 |
+
queries = torch.cat([first_positive_inds[:, :, None], xys], dim=-1)
|
442 |
+
queries = torch.cat([queries, support_pts_q[:,0]], dim=1)
|
443 |
+
else:
|
444 |
+
# Generate the 768 points randomly in the whole video
|
445 |
+
queries = queries.squeeze(1)
|
446 |
+
ba_len = queries.shape[1]
|
447 |
+
z_gt_query = None
|
448 |
+
if support_pts_q is not None:
|
449 |
+
queries = torch.cat([queries, support_pts_q[:,0]], dim=1)
|
450 |
+
|
451 |
+
if (abs(prec_fx-1.0) > 1e-4) & (self.training) & (traj3d_gt is not None):
|
452 |
+
traj3d_gt[..., 0] /= prec_fx
|
453 |
+
traj3d_gt[..., 1] /= prec_fy
|
454 |
+
queries[...,1] /= prec_fx
|
455 |
+
queries[...,2] /= prec_fy
|
456 |
+
|
457 |
+
video_vis = F.interpolate(video.clone().view(B*T, 3, video.shape[-2], video.shape[-1]), (H_, W_), mode="bilinear", align_corners=False).view(B, T, 3, H_, W_)
|
458 |
+
|
459 |
+
self.image_size = torch.tensor([H_, W_])
|
460 |
+
# self.model_resolution = (H_, W_)
|
461 |
+
# resize the queries and intrs
|
462 |
+
self.factor_x = self.model_resolution[1]/W_
|
463 |
+
self.factor_y = self.model_resolution[0]/H_
|
464 |
+
queries[...,1] *= self.factor_x
|
465 |
+
queries[...,2] *= self.factor_y
|
466 |
+
intrs_org = intrs.clone()
|
467 |
+
intrs[...,0,:] *= self.factor_x
|
468 |
+
intrs[...,1,:] *= self.factor_y
|
469 |
+
|
470 |
+
# get the fmaps and color features
|
471 |
+
video = F.interpolate(video.view(B*T, 3, video.shape[-2], video.shape[-1]),
|
472 |
+
(self.model_resolution[0], self.model_resolution[1])).view(B, T, 3, self.model_resolution[0], self.model_resolution[1])
|
473 |
+
_, _, _, H, W = video.shape
|
474 |
+
if cache is not None:
|
475 |
+
T_cache = cache["fmaps"].shape[0]
|
476 |
+
fmaps = self.extract_img_feat(video[:,T_cache:], fmaps_chunk_size=fmaps_chunk_size)
|
477 |
+
fmaps = torch.cat([cache["fmaps"][None], fmaps], dim=1)
|
478 |
+
else:
|
479 |
+
fmaps = self.extract_img_feat(video, fmaps_chunk_size=fmaps_chunk_size)
|
480 |
+
fmaps_org = fmaps.clone()
|
481 |
+
|
482 |
+
metric_depth = F.interpolate(metric_depth.view(B*T, 1, H_, W_),
|
483 |
+
(self.model_resolution[0], self.model_resolution[1]),mode="nearest").view(B*T, 1, self.model_resolution[0], self.model_resolution[1]).clamp(0.01, 200)
|
484 |
+
self.metric_unc_org = metric_unc.clone()
|
485 |
+
metric_unc = F.interpolate(metric_unc.view(B*T, 1, H_, W_),
|
486 |
+
(self.model_resolution[0], self.model_resolution[1]),mode="nearest").view(B*T, 1, self.model_resolution[0], self.model_resolution[1])
|
487 |
+
if (self.stage == 2) & (self.training):
|
488 |
+
scale_rand = (torch.rand(B, T, device=video.device) - 0.5) + 1
|
489 |
+
point_map = scale_rand.view(B*T,1,1,1) * point_map
|
490 |
+
|
491 |
+
point_map_org = point_map.permute(0,3,1,2).view(B*T, 3, H_, W_).clone()
|
492 |
+
point_map = F.interpolate(point_map_org.clone(),
|
493 |
+
(self.model_resolution[0], self.model_resolution[1]),mode="nearest").view(B*T, 3, self.model_resolution[0], self.model_resolution[1])
|
494 |
+
# align the point map
|
495 |
+
point_map_org_train = point_map_org.view(B*T, 3, H_, W_).clone()
|
496 |
+
|
497 |
+
if (stage == 2):
|
498 |
+
# align the point map
|
499 |
+
try:
|
500 |
+
self.pred_points, scale_gt, shift_gt = affine_invariant_global_loss(
|
501 |
+
point_map_org_train.permute(0,2,3,1),
|
502 |
+
points_map_gt,
|
503 |
+
mask=self.metric_unc_org[:,0]>0.5,
|
504 |
+
align_resolution=32,
|
505 |
+
only_align=True
|
506 |
+
)
|
507 |
+
except Exception:
|
508 |
+
scale_gt, shift_gt = torch.ones(B*T).to(video.device), torch.zeros(B*T,3).to(video.device)
|
509 |
+
self.scale_gt, self.shift_gt = scale_gt, shift_gt
|
510 |
+
else:
|
511 |
+
scale_est, shift_est = None, None
|
512 |
+
|
513 |
+
# extract the pts features
|
514 |
+
device = queries.device
|
515 |
+
assert H % self.stride == 0 and W % self.stride == 0
|
516 |
+
|
517 |
+
B, N, __ = queries.shape
|
518 |
+
queries_z = sample_features5d(metric_depth.view(B, T, 1, H, W),
|
519 |
+
queries[:,None], interp_mode="nearest").squeeze(1)
|
520 |
+
queries_z_unc = sample_features5d(metric_unc.view(B, T, 1, H, W),
|
521 |
+
queries[:,None], interp_mode="nearest").squeeze(1)
|
522 |
+
|
523 |
+
queries_rgb = sample_features5d(video.view(B, T, C, H, W),
|
524 |
+
queries[:,None], interp_mode="nearest").squeeze(1)
|
525 |
+
queries_point_map = sample_features5d(point_map.view(B, T, 3, H, W),
|
526 |
+
queries[:,None], interp_mode="nearest").squeeze(1)
|
527 |
+
if ((queries_z > 100)*(queries_z == 0)).sum() > 0:
|
528 |
+
import pdb; pdb.set_trace()
|
529 |
+
|
530 |
+
if overlap_d is not None:
|
531 |
+
queries_z[:,:overlap_d.shape[1],:] = overlap_d[...,None]
|
532 |
+
queries_point_map[:,:overlap_d.shape[1],2:] = overlap_d[...,None]
|
533 |
+
|
534 |
+
if pts_q_3d is not None:
|
535 |
+
scale_factor = (pts_q_3d[...,-1].permute(0,2,1) / queries_z[:,:pts_q_3d.shape[2],:]).squeeze().median()
|
536 |
+
queries_z[:,:pts_q_3d.shape[2],:] = pts_q_3d[...,-1].permute(0,2,1) / scale_factor
|
537 |
+
queries_point_map[:,:pts_q_3d.shape[2],2:] = pts_q_3d[...,-1].permute(0,2,1) / scale_factor
|
538 |
+
|
539 |
+
# normalize the points
|
540 |
+
self.min_pts, self.max_pts = queries_point_map.mean(dim=(0,1)) - 3*queries_point_map.std(dim=(0,1)), queries_point_map.mean(dim=(0,1)) + 3*queries_point_map.std(dim=(0,1))
|
541 |
+
queries_point_map = self.norm_xyz(queries_point_map)
|
542 |
+
queries_point_map_ = queries_point_map.reshape(B, 1, N, 3).expand(B, T, N, 3).clone()
|
543 |
+
point_map = self.norm_xyz(point_map.view(B, T, 3, H, W)).view(B*T, 3, H, W)
|
544 |
+
|
545 |
+
if z_gt_query is not None:
|
546 |
+
queries_z[:,:z_gt_query.shape[1],:] = z_gt_query[:,:,None]
|
547 |
+
mask_traj_gt = ((queries_z[:,:z_gt_query.shape[1],:] - z_gt_query[:,:,None])).abs() < 0.1
|
548 |
+
else:
|
549 |
+
if traj3d_gt is not None:
|
550 |
+
mask_traj_gt = torch.ones_like(queries_z[:, :traj3d_gt.shape[2]]).bool()
|
551 |
+
else:
|
552 |
+
mask_traj_gt = torch.ones_like(queries_z).bool()
|
553 |
+
|
554 |
+
queries_xyz = torch.cat([queries, queries_z], dim=-1)[:,None].repeat(1, T, 1, 1)
|
555 |
+
if cache is not None:
|
556 |
+
cache_T, cache_N = cache["track2d_pred_cache"].shape[0], cache["track2d_pred_cache"].shape[1]
|
557 |
+
cachexy = cache["track2d_pred_cache"].clone()
|
558 |
+
cachexy[...,0] = cachexy[...,0] * self.factor_x
|
559 |
+
cachexy[...,1] = cachexy[...,1] * self.factor_y
|
560 |
+
# initialize the 2d points with cache
|
561 |
+
queries_xyz[:,:cache_T,:cache_N,1:] = cachexy
|
562 |
+
queries_xyz[:,cache_T:,:cache_N,1:] = cachexy[-1:]
|
563 |
+
# initialize the 3d points with cache
|
564 |
+
queries_point_map_[:,:cache_T,:cache_N,:] = self.norm_xyz(cache["track3d_pred_cache"][None])
|
565 |
+
queries_point_map_[:,cache_T:,:cache_N,:] = self.norm_xyz(cache["track3d_pred_cache"][-1:][None])
|
566 |
+
|
567 |
+
if cam_gt is not None:
|
568 |
+
q_static_proj, q_xyz_world, q_xyz_cam = self.track_from_cam(queries_xyz, cam_gt,
|
569 |
+
intrs, rgbs=video_vis, visualize=False)
|
570 |
+
q_static_proj[..., 0] /= self.factor_x
|
571 |
+
q_static_proj[..., 1] /= self.factor_y
|
572 |
+
|
573 |
+
|
574 |
+
assert T >= 1 # at least one frame is required (two or more are needed to actually track motion)
|
575 |
+
video = 2 * (video / 255.0) - 1.0
|
576 |
+
dtype = video.dtype
|
577 |
+
queried_frames = queries[:, :, 0].long()
|
578 |
+
|
579 |
+
queried_coords = queries[..., 1:3]
|
580 |
+
queried_coords = queried_coords / self.stride
|
581 |
+
|
582 |
+
# We store our predictions here
|
583 |
+
(all_coords_predictions, all_coords_xyz_predictions,all_vis_predictions,
|
584 |
+
all_confidence_predictions, all_cam_predictions, all_dynamic_prob_predictions,
|
585 |
+
all_cam_pts_predictions, all_world_tracks_predictions, all_world_tracks_refined_predictions,
|
586 |
+
all_scale_est, all_shift_est) = (
|
587 |
+
[],
|
588 |
+
[],
|
589 |
+
[],
|
590 |
+
[],
|
591 |
+
[],
|
592 |
+
[],
|
593 |
+
[],
|
594 |
+
[],
|
595 |
+
[],
|
596 |
+
[],
|
597 |
+
[]
|
598 |
+
)
|
599 |
+
|
600 |
+
# We compute track features
|
601 |
+
fmaps_pyramid = []
|
602 |
+
point_map_pyramid = []
|
603 |
+
track_feat_pyramid = []
|
604 |
+
track_feat_support_pyramid = []
|
605 |
+
track_feat3d_pyramid = []
|
606 |
+
track_feat_support3d_pyramid = []
|
607 |
+
track_depth_support_pyramid = []
|
608 |
+
track_point_map_pyramid = []
|
609 |
+
track_point_map_support_pyramid = []
|
610 |
+
fmaps_pyramid.append(fmaps)
|
611 |
+
metric_depth = metric_depth
|
612 |
+
point_map = point_map
|
613 |
+
metric_depth_align = F.interpolate(metric_depth, scale_factor=0.25, mode='nearest')
|
614 |
+
point_map_align = F.interpolate(point_map, scale_factor=0.25, mode='nearest')
|
615 |
+
point_map_pyramid.append(point_map_align.view(B, T, 3, point_map_align.shape[-2], point_map_align.shape[-1]))
|
616 |
+
for i in range(self.corr_levels - 1):
|
617 |
+
fmaps_ = fmaps.reshape(
|
618 |
+
B * T, self.latent_dim, fmaps.shape[-2], fmaps.shape[-1]
|
619 |
+
)
|
620 |
+
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
|
621 |
+
fmaps = fmaps_.reshape(
|
622 |
+
B, T, self.latent_dim, fmaps_.shape[-2], fmaps_.shape[-1]
|
623 |
+
)
|
624 |
+
fmaps_pyramid.append(fmaps)
|
625 |
+
# downsample the depth
|
626 |
+
metric_depth_ = metric_depth_align.reshape(B*T,1,metric_depth_align.shape[-2],metric_depth_align.shape[-1])
|
627 |
+
metric_depth_ = F.interpolate(metric_depth_, scale_factor=0.5, mode='nearest')
|
628 |
+
metric_depth_align = metric_depth_.reshape(B,T,1,metric_depth_.shape[-2], metric_depth_.shape[-1])
|
629 |
+
# downsample the point map
|
630 |
+
point_map_ = point_map_align.reshape(B*T,3,point_map_align.shape[-2],point_map_align.shape[-1])
|
631 |
+
point_map_ = F.interpolate(point_map_, scale_factor=0.5, mode='nearest')
|
632 |
+
point_map_align = point_map_.reshape(B,T,3,point_map_.shape[-2], point_map_.shape[-1])
|
633 |
+
point_map_pyramid.append(point_map_align)
|
634 |
+
|
635 |
+
for i in range(self.corr_levels):
|
636 |
+
if cache is not None:
|
637 |
+
cache_N = cache["track_feat_pyramid"][i].shape[2]
|
638 |
+
track_feat_cached, track_feat_support_cached = cache["track_feat_pyramid"][i], cache["track_feat_support_pyramid"][i]
|
639 |
+
track_feat3d_cached, track_feat_support3d_cached = cache["track_feat3d_pyramid"][i], cache["track_feat_support3d_pyramid"][i]
|
640 |
+
track_point_map_cached, track_point_map_support_cached = self.norm_xyz(cache["track_point_map_pyramid"][i]), self.norm_xyz(cache["track_point_map_support_pyramid"][i])
|
641 |
+
queried_coords_new = queried_coords[:,cache_N:,:] / 2**i
|
642 |
+
queried_frames_new = queried_frames[:,cache_N:]
|
643 |
+
else:
|
644 |
+
queried_coords_new = queried_coords / 2**i
|
645 |
+
queried_frames_new = queried_frames
|
646 |
+
track_feat, track_feat_support = self.get_track_feat(
|
647 |
+
fmaps_pyramid[i],
|
648 |
+
queried_frames_new,
|
649 |
+
queried_coords_new,
|
650 |
+
support_radius=self.corr_radius,
|
651 |
+
)
|
652 |
+
# get 3d track feat
|
653 |
+
track_point_map, track_point_map_support = self.get_track_feat(
|
654 |
+
point_map_pyramid[i],
|
655 |
+
queried_frames_new,
|
656 |
+
queried_coords_new,
|
657 |
+
support_radius=self.corr3d_radius,
|
658 |
+
)
|
659 |
+
track_feat3d, track_feat_support3d = self.get_track_feat(
|
660 |
+
fmaps_pyramid[i],
|
661 |
+
queried_frames_new,
|
662 |
+
queried_coords_new,
|
663 |
+
support_radius=self.corr3d_radius,
|
664 |
+
)
|
665 |
+
if cache is not None:
|
666 |
+
track_feat = torch.cat([track_feat_cached, track_feat], dim=2)
|
667 |
+
track_point_map = torch.cat([track_point_map_cached, track_point_map], dim=2)
|
668 |
+
track_feat_support = torch.cat([track_feat_support_cached[:,0], track_feat_support], dim=2)
|
669 |
+
track_point_map_support = torch.cat([track_point_map_support_cached[:,0], track_point_map_support], dim=2)
|
670 |
+
track_feat3d = torch.cat([track_feat3d_cached, track_feat3d], dim=2)
|
671 |
+
track_feat_support3d = torch.cat([track_feat_support3d_cached[:,0], track_feat_support3d], dim=2)
|
672 |
+
track_feat_pyramid.append(track_feat.repeat(1, T, 1, 1))
|
673 |
+
track_feat_support_pyramid.append(track_feat_support.unsqueeze(1))
|
674 |
+
track_feat3d_pyramid.append(track_feat3d.repeat(1, T, 1, 1))
|
675 |
+
track_feat_support3d_pyramid.append(track_feat_support3d.unsqueeze(1))
|
676 |
+
track_point_map_pyramid.append(track_point_map.repeat(1, T, 1, 1))
|
677 |
+
track_point_map_support_pyramid.append(track_point_map_support.unsqueeze(1))
|
678 |
+
|
679 |
+
|
680 |
+
D_coords = 2
|
681 |
+
(coord_preds, coords_xyz_preds, vis_preds, confidence_preds,
|
682 |
+
dynamic_prob_preds, cam_preds, pts3d_cam_pred, world_tracks_pred,
|
683 |
+
world_tracks_refined_pred, point_map_preds, scale_ests, shift_ests) = (
|
684 |
+
[], [], [], [], [], [], [], [], [], [], [], []
|
685 |
+
)
|
686 |
+
|
687 |
+
c2w_ests = []
|
688 |
+
vis = torch.zeros((B, T, N), device=device).float()
|
689 |
+
confidence = torch.zeros((B, T, N), device=device).float()
|
690 |
+
dynamic_prob = torch.zeros((B, T, N), device=device).float()
|
691 |
+
pro_analysis_w = torch.zeros((B, T, N), device=device).float()
|
692 |
+
|
693 |
+
coords = queries_xyz[...,1:].clone()
|
694 |
+
coords[...,:2] /= self.stride
|
695 |
+
# coords[...,:2] = queried_coords.reshape(B, 1, N, 2).expand(B, T, N, 2).float()[...,:2]
|
696 |
+
# initialize the 3d points
|
697 |
+
coords_xyz = queries_point_map_.clone()
|
698 |
+
|
699 |
+
# if cache is not None:
|
700 |
+
# viser = Visualizer(save_dir=".", grayscale=True,
|
701 |
+
# fps=10, pad_value=50, tracks_leave_trace=0)
|
702 |
+
# coords_clone = coords.clone()
|
703 |
+
# coords_clone[...,:2] *= self.stride
|
704 |
+
# coords_clone[..., 0] /= self.factor_x
|
705 |
+
# coords_clone[..., 1] /= self.factor_y
|
706 |
+
# viser.visualize(video=video_vis, tracks=coords_clone[..., :2], filename="test")
|
707 |
+
# import pdb; pdb.set_trace()
|
708 |
+
|
709 |
+
if init_pose:
|
710 |
+
q_init_proj, q_xyz_world, q_xyz_cam = self.track_from_cam(queries_xyz, cam_gt,
|
711 |
+
intrs, rgbs=video_vis, visualize=False)
|
712 |
+
q_init_proj[..., 0] /= self.stride
|
713 |
+
q_init_proj[..., 1] /= self.stride
|
714 |
+
|
715 |
+
r = 2 * self.corr_radius + 1
|
716 |
+
r_depth = 2 * self.corr3d_radius + 1
|
717 |
+
anchor_loss = 0
|
718 |
+
# two current states
|
719 |
+
self.c2w_est_curr = torch.eye(4, device=device).repeat(B, T , 1, 1)
|
720 |
+
coords_proj_curr = coords.view(B * T, N, 3)[...,:2]
|
721 |
+
if init_pose:
|
722 |
+
self.c2w_est_curr = cam_gt.to(coords_proj_curr.device).to(coords_proj_curr.dtype)
|
723 |
+
sync_loss = 0
|
724 |
+
if stage == 2:
|
725 |
+
extra_sparse_tokens = self.scale_shift_tokens[:,:,None,:].repeat(B, 1, T, 1)
|
726 |
+
extra_dense_tokens = self.residual_embedding[None,None].repeat(B, T, 1, 1, 1)
|
727 |
+
xyz_pos_enc = posenc(point_map_pyramid[-2].permute(0,1,3,4,2), min_deg=0, max_deg=10).permute(0,1,4,2,3)
|
728 |
+
extra_dense_tokens = torch.cat([xyz_pos_enc, extra_dense_tokens, fmaps_pyramid[-2]], dim=2)
|
729 |
+
extra_dense_tokens = rearrange(extra_dense_tokens, 'b t c h w -> (b t) c h w')
|
730 |
+
extra_dense_tokens = self.dense_mlp(extra_dense_tokens)
|
731 |
+
extra_dense_tokens = rearrange(extra_dense_tokens, '(b t) c h w -> b t c h w', b=B, t=T)
|
732 |
+
else:
|
733 |
+
extra_sparse_tokens = None
|
734 |
+
extra_dense_tokens = None
|
735 |
+
|
736 |
+
scale_est, shift_est = torch.ones(B, T, 1, 1, device=device), torch.zeros(B, T, 1, 3, device=device)
|
737 |
+
residual_point = torch.zeros(B, T, 3, self.model_resolution[0]//self.stride,
|
738 |
+
self.model_resolution[1]//self.stride, device=device)
|
739 |
+
|
740 |
+
for it in range(iters):
|
741 |
+
# query points scale and shift
|
742 |
+
scale_est_query = torch.gather(scale_est, dim=1, index=queries[:,:,None,:1].long())
|
743 |
+
shift_est_query = torch.gather(shift_est, dim=1, index=queries[:,:,None,:1].long().repeat(1, 1, 1, 3))
|
744 |
+
|
745 |
+
coords = coords.detach() # B T N 3
|
746 |
+
coords_xyz = coords_xyz.detach()
|
747 |
+
vis = vis.detach()
|
748 |
+
confidence = confidence.detach()
|
749 |
+
dynamic_prob = dynamic_prob.detach()
|
750 |
+
pro_analysis_w = pro_analysis_w.detach()
|
751 |
+
coords_init = coords.view(B * T, N, 3)
|
752 |
+
coords_xyz_init = coords_xyz.view(B * T, N, 3)
|
753 |
+
corr_embs = []
|
754 |
+
corr_depth_embs = []
|
755 |
+
corr_feats = []
|
756 |
+
for i in range(self.corr_levels):
|
757 |
+
# K_level = int(32*0.8**(i))
|
758 |
+
K_level = 16
|
759 |
+
corr_feat = self.get_correlation_feat(
|
760 |
+
fmaps_pyramid[i], coords_init[...,:2] / 2**i
|
761 |
+
)
|
762 |
+
#NOTE: update the point map
|
763 |
+
residual_point_i = F.interpolate(residual_point.view(B*T,3,residual_point.shape[-2],residual_point.shape[-1]),
|
764 |
+
size=(point_map_pyramid[i].shape[-2], point_map_pyramid[i].shape[-1]), mode='nearest')
|
765 |
+
point_map_pyramid_i = (self.denorm_xyz(point_map_pyramid[i]) * scale_est[...,None]
|
766 |
+
+ shift_est.permute(0,1,3,2)[...,None] + residual_point_i.view(B,T,3,point_map_pyramid[i].shape[-2], point_map_pyramid[i].shape[-1])).clone().detach()
|
767 |
+
|
768 |
+
corr_point_map = self.get_correlation_feat(
|
769 |
+
self.norm_xyz(point_map_pyramid_i), coords_proj_curr / 2**i, radius=self.corr3d_radius
|
770 |
+
)
|
771 |
+
|
772 |
+
corr_point_feat = self.get_correlation_feat(
|
773 |
+
fmaps_pyramid[i], coords_proj_curr / 2**i, radius=self.corr3d_radius
|
774 |
+
)
|
775 |
+
track_feat_support = (
|
776 |
+
track_feat_support_pyramid[i]
|
777 |
+
.view(B, 1, r, r, N, self.latent_dim)
|
778 |
+
.squeeze(1)
|
779 |
+
.permute(0, 3, 1, 2, 4)
|
780 |
+
)
|
781 |
+
track_feat_support3d = (
|
782 |
+
track_feat_support3d_pyramid[i]
|
783 |
+
.view(B, 1, r_depth, r_depth, N, self.latent_dim)
|
784 |
+
.squeeze(1)
|
785 |
+
.permute(0, 3, 1, 2, 4)
|
786 |
+
)
|
787 |
+
#NOTE: update the point map
|
788 |
+
track_point_map_support_pyramid_i = (self.denorm_xyz(track_point_map_support_pyramid[i]) * scale_est_query.view(B,1,1,N,1)
|
789 |
+
+ shift_est_query.view(B,1,1,N,3)).clone().detach()
|
790 |
+
|
791 |
+
track_point_map_support = (
|
792 |
+
self.norm_xyz(track_point_map_support_pyramid_i)
|
793 |
+
.view(B, 1, r_depth, r_depth, N, 3)
|
794 |
+
.squeeze(1)
|
795 |
+
.permute(0, 3, 1, 2, 4)
|
796 |
+
)
|
797 |
+
corr_volume = torch.einsum(
|
798 |
+
"btnhwc,bnijc->btnhwij", corr_feat, track_feat_support
|
799 |
+
)
|
800 |
+
corr_emb = self.corr_mlp(corr_volume.reshape(B, T, N, r * r * r * r))
|
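# corr_volume correlates the r x r feature patch stored for each query (track_feat_support)
# with the r x r patch sampled around the current 2D estimate (corr_feat), giving an
# (r*r) x (r*r) similarity volume per point and frame; corr_mlp then compresses the
# flattened r*r*r*r vector into a fixed-size correlation embedding.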
801 |
+
|
802 |
+
with torch.no_grad():
|
803 |
+
rel_pos_query_ = track_point_map_support - track_point_map_support[:,:,self.corr3d_radius,self.corr3d_radius,:][...,None,None,:]
|
804 |
+
rel_pos_target_ = corr_point_map - coords_xyz_init.view(B, T, N, 1, 1, 3)
|
805 |
+
# select the top 9 points
|
806 |
+
rel_pos_query_idx = rel_pos_query_.norm(dim=-1).view(B, N, -1).topk(K_level+1, dim=-1, largest=False)[1][...,1:,None]
|
807 |
+
rel_pos_target_idx = rel_pos_target_.norm(dim=-1).view(B, T, N, -1).topk(K_level+1, dim=-1, largest=False)[1][...,1:,None]
|
808 |
+
rel_pos_query_ = torch.gather(rel_pos_query_.view(B, N, -1, 3), dim=-2, index=rel_pos_query_idx.expand(B, N, K_level, 3))
|
809 |
+
rel_pos_target_ = torch.gather(rel_pos_target_.view(B, T, N, -1, 3), dim=-2, index=rel_pos_target_idx.expand(B, T, N, K_level, 3))
|
810 |
+
rel_pos_query = rel_pos_query_
|
811 |
+
rel_pos_target = rel_pos_target_
|
812 |
+
rel_pos_query = posenc(rel_pos_query, min_deg=0, max_deg=12)
|
813 |
+
rel_pos_target = posenc(rel_pos_target, min_deg=0, max_deg=12)
|
814 |
+
rel_pos_target = self.rel_pos_mlp(rel_pos_target)
|
815 |
+
rel_pos_query = self.rel_pos_mlp(rel_pos_query)
|
816 |
+
with torch.no_grad():
|
817 |
+
# integrate with feature
|
818 |
+
track_feat_support_ = rearrange(track_feat_support3d, 'b n r k c -> b n (r k) c', r=r_depth, k=r_depth, n=N, b=B)
|
819 |
+
track_feat_support_ = torch.gather(track_feat_support_, dim=-2, index=rel_pos_query_idx.expand(B, N, K_level, 128))
|
820 |
+
queried_feat = torch.cat([rel_pos_query, track_feat_support_], dim=-1)
|
821 |
+
corr_feat_ = rearrange(corr_point_feat, 'b t n r k c -> b t n (r k) c', t=T, n=N, b=B)
|
822 |
+
corr_feat_ = torch.gather(corr_feat_, dim=-2, index=rel_pos_target_idx.expand(B, T, N, K_level, 128))
|
823 |
+
target_feat = torch.cat([rel_pos_target, corr_feat_], dim=-1)
|
824 |
+
|
825 |
+
# 3d attention
|
826 |
+
queried_feat = self.corr_xyz_mlp(queried_feat)
|
827 |
+
target_feat = self.corr_xyz_mlp(target_feat)
|
828 |
+
queried_feat = repeat(queried_feat, 'b n k c -> b t n k c', k=K_level, t=T, n=N, b=B)
|
829 |
+
corr_depth_emb = self.corr_transformer[0](queried_feat.reshape(B*T*N,-1,128),
|
830 |
+
target_feat.reshape(B*T*N,-1,128),
|
831 |
+
target_rel_pos=rel_pos_target.reshape(B*T*N,-1,128))
|
832 |
+
corr_depth_emb = rearrange(corr_depth_emb, '(b t n) 1 c -> b t n c', t=T, n=N, b=B)
|
833 |
+
corr_depth_emb = self.corr_depth_mlp(corr_depth_emb)
|
834 |
+
valid_mask = self.denorm_xyz(coords_xyz_init).view(B, T, N, -1)[...,2:3] > 0
|
835 |
+
corr_depth_embs.append(corr_depth_emb*valid_mask)
|
836 |
+
|
837 |
+
corr_embs.append(corr_emb)
|
838 |
+
corr_embs = torch.cat(corr_embs, dim=-1)
|
839 |
+
corr_embs = corr_embs.view(B, T, N, corr_embs.shape[-1])
|
840 |
+
corr_depth_embs = torch.cat(corr_depth_embs, dim=-1)
|
841 |
+
corr_depth_embs = corr_depth_embs.view(B, T, N, corr_depth_embs.shape[-1])
|
842 |
+
transformer_input = [vis[..., None], confidence[..., None], corr_embs]
|
843 |
+
transformer_input_depth = [vis[..., None], confidence[..., None], corr_depth_embs]
|
844 |
+
|
845 |
+
rel_coords_forward = coords[:,:-1,...,:2] - coords[:,1:,...,:2]
|
846 |
+
rel_coords_backward = coords[:, 1:,...,:2] - coords[:, :-1,...,:2]
|
847 |
+
|
848 |
+
rel_xyz_forward = coords_xyz[:,:-1,...,:3] - coords_xyz[:,1:,...,:3]
|
849 |
+
rel_xyz_backward = coords_xyz[:, 1:,...,:3] - coords_xyz[:, :-1,...,:3]
|
850 |
+
|
851 |
+
rel_coords_forward = torch.nn.functional.pad(
|
852 |
+
rel_coords_forward, (0, 0, 0, 0, 0, 1)
|
853 |
+
)
|
854 |
+
rel_coords_backward = torch.nn.functional.pad(
|
855 |
+
rel_coords_backward, (0, 0, 0, 0, 1, 0)
|
856 |
+
)
|
857 |
+
rel_xyz_forward = torch.nn.functional.pad(
|
858 |
+
rel_xyz_forward, (0, 0, 0, 0, 0, 1)
|
859 |
+
)
|
860 |
+
rel_xyz_backward = torch.nn.functional.pad(
|
861 |
+
rel_xyz_backward, (0, 0, 0, 0, 1, 0)
|
862 |
+
)
|
863 |
+
|
864 |
+
scale = (
|
865 |
+
torch.tensor(
|
866 |
+
[self.model_resolution[1], self.model_resolution[0]],
|
867 |
+
device=coords.device,
|
868 |
+
)
|
869 |
+
/ self.stride
|
870 |
+
)
|
871 |
+
rel_coords_forward = rel_coords_forward / scale
|
872 |
+
rel_coords_backward = rel_coords_backward / scale
|
873 |
+
|
874 |
+
rel_pos_emb_input = posenc(
|
875 |
+
torch.cat([rel_coords_forward, rel_coords_backward], dim=-1),
|
876 |
+
min_deg=0,
|
877 |
+
max_deg=10,
|
878 |
+
) # batch, num_points, num_frames, 84
|
879 |
+
rel_xyz_emb_input = posenc(
|
880 |
+
torch.cat([rel_xyz_forward, rel_xyz_backward], dim=-1),
|
881 |
+
min_deg=0,
|
882 |
+
max_deg=10,
|
883 |
+
) # batch, num_points, num_frames, 126
|
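# On the channel counts noted above: with min_deg=0, max_deg=10 and the identity kept,
# posenc maps a 4-dim input (fwd+bwd 2D offsets) to 4 + 4*2*10 = 84 channels and a 6-dim
# input (fwd+bwd 3D offsets) to 6 + 6*2*10 = 126, matching the two comments.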
884 |
+
rel_xyz_emb_input = self.xyz_mlp(rel_xyz_emb_input)
|
885 |
+
transformer_input.append(rel_pos_emb_input)
|
886 |
+
transformer_input_depth.append(rel_xyz_emb_input)
|
887 |
+
# get the queries world
|
888 |
+
with torch.no_grad():
|
889 |
+
# update the query points with scale and shift
|
890 |
+
queries_xyz_i = queries_xyz.clone().detach()
|
891 |
+
queries_xyz_i[..., -1] = queries_xyz_i[..., -1] * scale_est_query.view(B,1,N) + shift_est_query.view(B,1,N,3)[...,2]
|
892 |
+
_, _, q_xyz_cam = self.track_from_cam(queries_xyz_i, self.c2w_est_curr,
|
893 |
+
intrs, rgbs=None, visualize=False)
|
894 |
+
q_xyz_cam = self.norm_xyz(q_xyz_cam)
|
895 |
+
|
896 |
+
query_t = queries[:,None,:,:1].repeat(B, T, 1, 1)
|
897 |
+
q_xyz_cam = torch.cat([query_t/T, q_xyz_cam], dim=-1)
|
898 |
+
T_all = torch.arange(T, device=device)[None,:,None,None].repeat(B, 1, N, 1)
|
899 |
+
current_xyzt = torch.cat([T_all/T, coords_xyz_init.view(B, T, N, -1)], dim=-1)
|
900 |
+
rel_pos_query_glob = q_xyz_cam - current_xyzt
|
901 |
+
# embed the confidence and dynamic probability
|
902 |
+
confidence_curr = torch.sigmoid(confidence[...,None])
|
903 |
+
dynamic_prob_curr = torch.sigmoid(dynamic_prob[...,None]).mean(dim=1, keepdim=True).repeat(1,T,1,1)
|
904 |
+
# embed the confidence and dynamic probability
|
905 |
+
rel_pos_query_glob = torch.cat([rel_pos_query_glob, confidence_curr, dynamic_prob_curr], dim=-1)
|
906 |
+
rel_pos_query_glob = posenc(rel_pos_query_glob, min_deg=0, max_deg=12)
|
907 |
+
transformer_input_depth.append(rel_pos_query_glob)
|
908 |
+
|
909 |
+
x = (
|
910 |
+
torch.cat(transformer_input, dim=-1)
|
911 |
+
.permute(0, 2, 1, 3)
|
912 |
+
.reshape(B * N, T, -1)
|
913 |
+
)
|
914 |
+
x_depth = (
|
915 |
+
torch.cat(transformer_input_depth, dim=-1)
|
916 |
+
.permute(0, 2, 1, 3)
|
917 |
+
.reshape(B * N, T, -1)
|
918 |
+
)
|
919 |
+
x_depth = self.proj_xyz_embed(x_depth)
|
920 |
+
|
921 |
+
x = x + self.interpolate_time_embed(x, T)
|
922 |
+
x = x.view(B, N, T, -1) # (B N) T D -> B N T D
|
923 |
+
x_depth = x_depth + self.interpolate_time_embed(x_depth, T)
|
924 |
+
x_depth = x_depth.view(B, N, T, -1) # (B N) T D -> B N T D
|
925 |
+
delta, delta_depth, delta_dynamic_prob, delta_pro_analysis_w, scale_shift_out, dense_res_out = self.updateformer3D(
|
926 |
+
x,
|
927 |
+
x_depth,
|
928 |
+
self.updateformer,
|
929 |
+
add_space_attn=add_space_attn,
|
930 |
+
extra_sparse_tokens=extra_sparse_tokens,
|
931 |
+
extra_dense_tokens=extra_dense_tokens,
|
932 |
+
)
|
933 |
+
# update the scale and shift
|
934 |
+
if scale_shift_out is not None:
|
935 |
+
extra_sparse_tokens = extra_sparse_tokens + scale_shift_out[...,:128]
|
936 |
+
scale_update = scale_shift_out[:,:1,:,-1].permute(0,2,1)[...,None]
|
937 |
+
shift_update = scale_shift_out[:,1:,:,-1].permute(0,2,1)[...,None]
|
938 |
+
scale_est = scale_est + scale_update
|
939 |
+
shift_est[...,2:] = shift_est[...,2:] + shift_update / 10
|
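# scale_est / shift_est act as a per-frame affine correction of the point map; only the
# z-component of the shift is updated (damped by /10), and both are re-applied to the
# pyramid point maps at the start of the next refinement iteration.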
940 |
+
# dense tokens update
|
941 |
+
extra_dense_tokens = extra_dense_tokens + dense_res_out[:,:,-128:]
|
942 |
+
res_low = dense_res_out[:,:,:3]
|
943 |
+
up_mask = self.upsample_transformer(extra_dense_tokens.mean(dim=1), res_low)
|
944 |
+
up_mask = repeat(up_mask, "b k h w -> b s k h w", s=T)
|
945 |
+
up_mask = rearrange(up_mask, "b s c h w -> (b s) 1 c h w")
|
946 |
+
res_up = self.upsample_with_mask(
|
947 |
+
rearrange(res_low, 'b t c h w -> (b t) c h w'),
|
948 |
+
up_mask,
|
949 |
+
)
|
950 |
+
res_up = rearrange(res_up, "(b t) c h w -> b t c h w", b=B, t=T)
|
951 |
+
# residual_point = residual_point + res_up
|
952 |
+
|
953 |
+
delta_coords = delta[..., :D_coords].permute(0, 2, 1, 3)
|
954 |
+
delta_vis = delta[..., D_coords].permute(0, 2, 1)
|
955 |
+
delta_confidence = delta[..., D_coords + 1].permute(0, 2, 1)
|
956 |
+
|
957 |
+
vis = vis + delta_vis
|
958 |
+
confidence = confidence + delta_confidence
|
959 |
+
dynamic_prob = dynamic_prob + delta_dynamic_prob[...,0].permute(0, 2, 1)
|
960 |
+
pro_analysis_w = pro_analysis_w + delta_pro_analysis_w[...,0].permute(0, 2, 1)
|
961 |
+
# update the depth
|
962 |
+
vis_est = torch.sigmoid(vis.detach())
|
963 |
+
|
964 |
+
delta_xyz = delta_depth[...,:3].permute(0,2,1,3)
|
965 |
+
denorm_delta_depth = (self.denorm_xyz(coords_xyz+delta_xyz)-self.denorm_xyz(coords_xyz))[...,2:3]
|
966 |
+
|
967 |
+
|
968 |
+
delta_depth_ = denorm_delta_depth.detach()
|
969 |
+
delta_coords = torch.cat([delta_coords, delta_depth_],dim=-1)
|
970 |
+
coords = coords + delta_coords
|
971 |
+
coords_append = coords.clone()
|
972 |
+
coords_xyz_append = self.denorm_xyz(coords_xyz + delta_xyz).clone()
|
973 |
+
|
974 |
+
coords_append[..., :2] = coords_append[..., :2] * float(self.stride)
|
975 |
+
coords_append[..., 0] /= self.factor_x
|
976 |
+
coords_append[..., 1] /= self.factor_y
|
977 |
+
|
978 |
+
# get the camera pose from tracks
|
979 |
+
dynamic_prob_curr = torch.sigmoid(dynamic_prob.detach())*torch.sigmoid(pro_analysis_w)
|
980 |
+
mask_out = (coords_append[...,0]<W_)&(coords_append[...,0]>0)&(coords_append[...,1]<H_)&(coords_append[...,1]>0)
|
981 |
+
if query_no_BA:
|
982 |
+
dynamic_prob_curr[:,:,:ba_len] = torch.ones_like(dynamic_prob_curr[:,:,:ba_len])
|
983 |
+
point_map_org_i = scale_est.view(B*T,1,1,1)*point_map_org.clone().detach() + shift_est.view(B*T,3,1,1)
|
984 |
+
# depth_unproj = bilinear_sampler(point_map_org_i, coords_append[...,:2].view(B*T, N, 1, 2), mode="nearest")[:,2,:,0].detach()
|
985 |
+
|
986 |
+
depth_unproj_neg = self.get_correlation_feat(
|
987 |
+
point_map_org_i.view(B,T,3,point_map_org_i.shape[-2], point_map_org_i.shape[-1]),
|
988 |
+
coords_append[...,:2].view(B*T, N, 2), radius=self.corr3d_radius
|
989 |
+
)[..., 2]
|
990 |
+
depth_diff = (depth_unproj_neg.view(B,T,N,-1) - coords_append[...,2:]).abs()
|
991 |
+
idx_neg = torch.argmin(depth_diff, dim=-1)
|
992 |
+
depth_unproj = depth_unproj_neg.view(B,T,N,-1)[torch.arange(B)[:, None, None, None],
|
993 |
+
torch.arange(T)[None, :, None, None],
|
994 |
+
torch.arange(N)[None, None, :, None],
|
995 |
+
idx_neg.view(B,T,N,1)].view(B*T, N)
|
996 |
+
|
997 |
+
unc_unproj = bilinear_sampler(self.metric_unc_org, coords_append[...,:2].view(B*T, N, 1, 2), mode="nearest")[:,0,:,0].detach()
|
998 |
+
depth_unproj[unc_unproj<0.5] = 0.0
|
999 |
+
|
1000 |
+
# replace the depth for visible and solid points
|
1001 |
+
conf_est = torch.sigmoid(confidence.detach())
|
1002 |
+
replace_mask = (depth_unproj.view(B,T,N)>0.0) * (vis_est>0.5) # * (conf_est>0.5)
|
1003 |
+
#NOTE: way1: find the jitter points
|
1004 |
+
depth_rel = (depth_unproj.view(B, T, N) - queries_z.permute(0, 2, 1))
|
1005 |
+
depth_ddt1 = depth_rel[:, 1:, :] - depth_rel[:, :-1, :]
|
1006 |
+
depth_ddt2 = depth_rel[:, 2:, :] - 2 * depth_rel[:, 1:-1, :] + depth_rel[:, :-2, :]
|
1007 |
+
jitter_mask = torch.zeros_like(depth_rel, dtype=torch.bool)
|
1008 |
+
if depth_ddt2.abs().max()>0:
|
1009 |
+
thre2 = torch.quantile(depth_ddt2.abs()[depth_ddt2.abs()>0], replace_ratio)
|
1010 |
+
jitter_mask[:, 1:-1, :] = (depth_ddt2.abs() < thre2)
|
1011 |
+
thre1 = torch.quantile(depth_ddt1.abs()[depth_ddt1.abs()>0], replace_ratio)
|
1012 |
+
jitter_mask[:, :-1, :] *= (depth_ddt1.abs() < thre1)
|
1013 |
+
replace_mask = replace_mask * jitter_mask
|
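# Summary of the test above: with depth_rel(t) = unprojected depth - query depth, the first
# and second temporal differences are thresholded at the `replace_ratio` quantile of their
# nonzero magnitudes; only temporally smooth (non-jittering) samples keep replace_mask=True
# and have their depth replaced by the unprojected value below.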
1014 |
+
|
1015 |
+
#NOTE: way2: top k topological change detection
|
1016 |
+
# coords_2d_lift = coords_append.clone()
|
1017 |
+
# coords_2d_lift[...,2][replace_mask] = depth_unproj.view(B,T,N)[replace_mask]
|
1018 |
+
# coords_2d_lift = self.cam_from_track(coords_2d_lift.clone(), intrs_org, only_cam_pts=True)
|
1019 |
+
# coords_2d_lift[~replace_mask] = coords_xyz_append[~replace_mask]
|
1020 |
+
# import pdb; pdb.set_trace()
|
1021 |
+
# jitter_mask = get_topo_mask(coords_xyz_append, coords_2d_lift, replace_ratio)
|
1022 |
+
# replace_mask = replace_mask * jitter_mask
|
1023 |
+
|
1024 |
+
# replace the depth
|
1025 |
+
if self.training:
|
1026 |
+
replace_mask = torch.zeros_like(replace_mask)
|
1027 |
+
coords_append[...,2][replace_mask] = depth_unproj.view(B,T,N)[replace_mask]
|
1028 |
+
coords_xyz_unproj = self.cam_from_track(coords_append.clone(), intrs_org, only_cam_pts=True)
|
1029 |
+
coords[...,2][replace_mask] = depth_unproj.view(B,T,N)[replace_mask]
|
1030 |
+
# coords_xyz_append[replace_mask] = coords_xyz_unproj[replace_mask]
|
1031 |
+
coords_xyz_append_refine = coords_xyz_append.clone()
|
1032 |
+
coords_xyz_append_refine[replace_mask] = coords_xyz_unproj[replace_mask]
|
1033 |
+
|
1034 |
+
c2w_traj_est, cam_pts_est, intrs_refine, coords_refine, world_tracks, world_tracks_refined, c2w_traj_init = self.cam_from_track(coords_append.clone(),
|
1035 |
+
intrs_org, dynamic_prob_curr, queries_z_unc, conf_est*vis_est*mask_out.float(),
|
1036 |
+
track_feat_concat=x_depth, tracks_xyz=coords_xyz_append_refine, init_pose=init_pose,
|
1037 |
+
query_pts=queries_xyz_i, fixed_cam=fixed_cam, depth_unproj=depth_unproj, cam_gt=cam_gt)
|
1038 |
+
intrs_org = intrs_refine.view(B, T, 3, 3).to(intrs_org.dtype)
|
1039 |
+
|
1040 |
+
# get the queries world
|
1041 |
+
self.c2w_est_curr = c2w_traj_est.detach()
|
1042 |
+
|
1043 |
+
# update coords and coords_append
|
1044 |
+
coords[..., 2] = (cam_pts_est)[...,2]
|
1045 |
+
coords_append[..., 2] = (cam_pts_est)[...,2]
|
1046 |
+
|
1047 |
+
# update coords_xyz_append
|
1048 |
+
# coords_xyz_append = cam_pts_est
|
1049 |
+
coords_xyz = self.norm_xyz(cam_pts_est)
|
1050 |
+
|
1051 |
+
|
1052 |
+
# proj
|
1053 |
+
coords_xyz_de = coords_xyz_append.clone()
|
1054 |
+
coords_xyz_de[coords_xyz_de[...,2].abs()<1e-6] = -1e-4
|
1055 |
+
mask_nan = coords_xyz_de[...,2].abs()<1e-2
|
1056 |
+
coords_proj = torch.einsum("btij,btnj->btni", intrs_org, coords_xyz_de/coords_xyz_de[...,2:3].abs())[...,:2]
|
1057 |
+
coords_proj[...,0] *= self.factor_x
|
1058 |
+
coords_proj[...,1] *= self.factor_y
|
1059 |
+
coords_proj[...,:2] /= float(self.stride)
|
1060 |
+
# make sure it is aligned with 2d tracking
|
1061 |
+
coords_proj_curr = coords[...,:2].view(B*T, N, 2).detach()
|
1062 |
+
vis_est = (vis_est>0.5).float()
|
1063 |
+
sync_loss += (vis_est.detach()[...,None]*(coords_proj_curr - coords_proj).norm(dim=-1, keepdim=True)*(1-mask_nan[...,None].float())).mean()
|
1064 |
+
# coords_proj_curr[~mask_nan.view(B*T, N)] = coords_proj.view(B*T, N, 2)[~mask_nan.view(B*T, N)].to(coords_proj_curr.dtype)
|
1065 |
+
# if torch.isnan(coords_proj_curr).sum()>0:
|
1066 |
+
# import pdb; pdb.set_trace()
|
1067 |
+
|
1068 |
+
if False:
|
1069 |
+
point_map_resize = point_map.clone().view(B, T, 3, H, W)
|
1070 |
+
update_input = torch.cat([point_map_resize, metric_unc.view(B,T,1,H,W)], dim=2)
|
1071 |
+
coords_append_resize = coords.clone().detach()
|
1072 |
+
coords_append_resize[..., :2] = coords_append_resize[..., :2] * float(self.stride)
|
1073 |
+
update_track_input = self.norm_xyz(cam_pts_est)*5
|
1074 |
+
update_track_input = torch.cat([update_track_input, vis_est[...,None]], dim=-1)
|
1075 |
+
update_track_input = posenc(update_track_input, min_deg=0, max_deg=12)
|
1076 |
+
update = self.update_pointmap.stablizer(update_input,
|
1077 |
+
update_track_input, coords_append_resize)#, imgs=video, vis_track=viser)
|
1078 |
+
#NOTE: update the point map
|
1079 |
+
point_map_resize += update
|
1080 |
+
point_map_refine_out = F.interpolate(point_map_resize.view(B*T, -1, H, W),
|
1081 |
+
size=(self.image_size[0].item(), self.image_size[1].item()), mode='nearest')
|
1082 |
+
point_map_refine_out = rearrange(point_map_refine_out, '(b t) c h w -> b t c h w', t=T, b=B)
|
1083 |
+
point_map_preds.append(self.denorm_xyz(point_map_refine_out))
|
1084 |
+
point_map_org = self.denorm_xyz(point_map_refine_out).view(B*T, 3, H_, W_)
|
1085 |
+
|
1086 |
+
# if torch.isnan(coords).sum()>0:
|
1087 |
+
# import pdb; pdb.set_trace()
|
1088 |
+
#NOTE: the 2d tracking + unproject depth
|
1089 |
+
fix_cam_est = coords_append.clone()
|
1090 |
+
fix_cam_est[...,2] = depth_unproj
|
1091 |
+
fix_cam_pts = self.cam_from_track(
|
1092 |
+
fix_cam_est, intrs_org, only_cam_pts=True
|
1093 |
+
)
|
1094 |
+
|
1095 |
+
coord_preds.append(coords_append)
|
1096 |
+
coords_xyz_preds.append(coords_xyz_append)
|
1097 |
+
vis_preds.append(vis)
|
1098 |
+
cam_preds.append(c2w_traj_init)
|
1099 |
+
pts3d_cam_pred.append(cam_pts_est)
|
1100 |
+
world_tracks_pred.append(world_tracks)
|
1101 |
+
world_tracks_refined_pred.append(world_tracks_refined)
|
1102 |
+
confidence_preds.append(confidence)
|
1103 |
+
dynamic_prob_preds.append(dynamic_prob)
|
1104 |
+
scale_ests.append(scale_est)
|
1105 |
+
shift_ests.append(shift_est)
|
1106 |
+
|
1107 |
+
if stage!=0:
|
1108 |
+
all_coords_predictions.append([coord for coord in coord_preds])
|
1109 |
+
all_coords_xyz_predictions.append([coord_xyz for coord_xyz in coords_xyz_preds])
|
1110 |
+
all_vis_predictions.append(vis_preds)
|
1111 |
+
all_confidence_predictions.append(confidence_preds)
|
1112 |
+
all_dynamic_prob_predictions.append(dynamic_prob_preds)
|
1113 |
+
all_cam_predictions.append([cam for cam in cam_preds])
|
1114 |
+
all_cam_pts_predictions.append([pts for pts in pts3d_cam_pred])
|
1115 |
+
all_world_tracks_predictions.append([world_tracks for world_tracks in world_tracks_pred])
|
1116 |
+
all_world_tracks_refined_predictions.append([world_tracks_refined for world_tracks_refined in world_tracks_refined_pred])
|
1117 |
+
all_scale_est.append(scale_ests)
|
1118 |
+
all_shift_est.append(shift_ests)
|
1119 |
+
if stage!=0:
|
1120 |
+
train_data = (
|
1121 |
+
all_coords_predictions,
|
1122 |
+
all_coords_xyz_predictions,
|
1123 |
+
all_vis_predictions,
|
1124 |
+
all_confidence_predictions,
|
1125 |
+
all_dynamic_prob_predictions,
|
1126 |
+
all_cam_predictions,
|
1127 |
+
all_cam_pts_predictions,
|
1128 |
+
all_world_tracks_predictions,
|
1129 |
+
all_world_tracks_refined_predictions,
|
1130 |
+
all_scale_est,
|
1131 |
+
all_shift_est,
|
1132 |
+
torch.ones_like(vis_preds[-1], device=vis_preds[-1].device),
|
1133 |
+
)
|
1134 |
+
else:
|
1135 |
+
train_data = None
|
1136 |
+
# resize back
|
1137 |
+
# init the trajectories by camera motion
|
1138 |
+
|
1139 |
+
# if cache is not None:
|
1140 |
+
# viser = Visualizer(save_dir=".", grayscale=True,
|
1141 |
+
# fps=10, pad_value=50, tracks_leave_trace=0)
|
1142 |
+
# coords_clone = coords.clone()
|
1143 |
+
# coords_clone[...,:2] *= self.stride
|
1144 |
+
# coords_clone[..., 0] /= self.factor_x
|
1145 |
+
# coords_clone[..., 1] /= self.factor_y
|
1146 |
+
# viser.visualize(video=video_vis, tracks=coords_clone[..., :2], filename="test_refine")
|
1147 |
+
# import pdb; pdb.set_trace()
|
1148 |
+
|
1149 |
+
if train_data is not None:
|
1150 |
+
# get the gt pts in the world coordinate
|
1151 |
+
self_supervised = False
|
1152 |
+
if (traj3d_gt is not None):
|
1153 |
+
if traj3d_gt[...,2].abs().max()>0:
|
1154 |
+
gt_cam_pts = self.cam_from_track(
|
1155 |
+
traj3d_gt, intrs_org, only_cam_pts=True
|
1156 |
+
)
|
1157 |
+
else:
|
1158 |
+
self_supervised = True
|
1159 |
+
else:
|
1160 |
+
self_supervised = True
|
1161 |
+
|
1162 |
+
if self_supervised:
|
1163 |
+
gt_cam_pts = self.cam_from_track(
|
1164 |
+
coord_preds[-1].detach(), intrs_org, only_cam_pts=True
|
1165 |
+
)
|
1166 |
+
|
1167 |
+
if cam_gt is not None:
|
1168 |
+
gt_world_pts = torch.einsum(
|
1169 |
+
"btij,btnj->btni",
|
1170 |
+
cam_gt[...,:3,:3],
|
1171 |
+
gt_cam_pts
|
1172 |
+
) + cam_gt[...,None, :3,3] # B T N 3
|
1173 |
+
else:
|
1174 |
+
gt_world_pts = torch.einsum(
|
1175 |
+
"btij,btnj->btni",
|
1176 |
+
self.c2w_est_curr[...,:3,:3],
|
1177 |
+
gt_cam_pts
|
1178 |
+
) + self.c2w_est_curr[...,None, :3,3] # B T N 3
|
1179 |
+
# update the query points with scale and shift
|
1180 |
+
queries_xyz_i = queries_xyz.clone().detach()
|
1181 |
+
queries_xyz_i[..., -1] = queries_xyz_i[..., -1] * scale_est_query.view(B,1,N) + shift_est_query.view(B,1,N,3)[...,2]
|
1182 |
+
q_static_proj, q_xyz_world, q_xyz_cam = self.track_from_cam(queries_xyz_i,
|
1183 |
+
self.c2w_est_curr,
|
1184 |
+
intrs, rgbs=video_vis, visualize=False)
|
1185 |
+
|
1186 |
+
q_static_proj[..., 0] /= self.factor_x
|
1187 |
+
q_static_proj[..., 1] /= self.factor_y
|
1188 |
+
cam_gt = self.c2w_est_curr[:,:,:3,:]
|
1189 |
+
|
1190 |
+
if traj3d_gt is not None:
|
1191 |
+
ret_loss = self.loss(train_data, traj3d_gt,
|
1192 |
+
vis_gt, None, cam_gt, queries_z_unc,
|
1193 |
+
q_xyz_world, q_static_proj, anchor_loss=anchor_loss, fix_cam_pts=fix_cam_pts, video_vis=video_vis, stage=stage,
|
1194 |
+
gt_world_pts=gt_world_pts, mask_traj_gt=mask_traj_gt, intrs=intrs_org, custom_vid=custom_vid, valid_only=valid_only,
|
1195 |
+
c2w_ests=c2w_ests, point_map_preds=point_map_preds, points_map_gt=points_map_gt, metric_unc=metric_unc, scale_est=scale_est,
|
1196 |
+
shift_est=shift_est, point_map_org_train=point_map_org_train)
|
1197 |
+
else:
|
1198 |
+
ret_loss = self.loss(train_data, traj3d_gt,
|
1199 |
+
vis_gt, None, cam_gt, queries_z_unc,
|
1200 |
+
q_xyz_world, q_static_proj, anchor_loss=anchor_loss, fix_cam_pts=fix_cam_pts, video_vis=video_vis, stage=stage,
|
1201 |
+
gt_world_pts=gt_world_pts, mask_traj_gt=mask_traj_gt, intrs=intrs_org, custom_vid=custom_vid, valid_only=valid_only,
|
1202 |
+
c2w_ests=c2w_ests, point_map_preds=point_map_preds, points_map_gt=points_map_gt, metric_unc=metric_unc, scale_est=scale_est,
|
1203 |
+
shift_est=shift_est, point_map_org_train=point_map_org_train)
|
1204 |
+
if custom_vid:
|
1205 |
+
sync_loss = 0*sync_loss
|
1206 |
+
if (sync_loss > 50) and (stage==1):
|
1207 |
+
ret_loss = (0*sync_loss, 0*sync_loss, 0*sync_loss, 0*sync_loss, 0*sync_loss, 0*sync_loss, 0*sync_loss) + (0*sync_loss,)
|
1208 |
+
else:
|
1209 |
+
ret_loss = ret_loss+(10*sync_loss,)
|
1210 |
+
|
1211 |
+
else:
|
1212 |
+
ret_loss = None
|
1213 |
+
|
1214 |
+
color_pts = torch.cat([pts3d_cam_pred[-1], queries_rgb[:,None].repeat(1, T, 1, 1)], dim=-1)
|
1215 |
+
|
1216 |
+
#TODO: For evaluation. We found our model have some bias on invisible points after training. (to be fixed)
|
1217 |
+
vis_pred_out = torch.sigmoid(vis_preds[-1]) + 0.2
|
1218 |
+
|
1219 |
+
ret = {"preds": coord_preds[-1], "vis_pred": vis_pred_out,
|
1220 |
+
"conf_pred": torch.sigmoid(confidence_preds[-1]),
|
1221 |
+
"cam_pred": self.c2w_est_curr,"loss": ret_loss}
|
1222 |
+
|
1223 |
+
cache = {
|
1224 |
+
"fmaps": fmaps_org[0].detach(),
|
1225 |
+
"track_feat_support3d_pyramid": [track_feat_support3d_pyramid[i].detach() for i in range(len(track_feat_support3d_pyramid))],
|
1226 |
+
"track_point_map_support_pyramid": [self.denorm_xyz(track_point_map_support_pyramid[i].detach()) for i in range(len(track_point_map_support_pyramid))],
|
1227 |
+
"track_feat3d_pyramid": [track_feat3d_pyramid[i].detach() for i in range(len(track_feat3d_pyramid))],
|
1228 |
+
"track_point_map_pyramid": [self.denorm_xyz(track_point_map_pyramid[i].detach()) for i in range(len(track_point_map_pyramid))],
|
1229 |
+
"track_feat_pyramid": [track_feat_pyramid[i].detach() for i in range(len(track_feat_pyramid))],
|
1230 |
+
"track_feat_support_pyramid": [track_feat_support_pyramid[i].detach() for i in range(len(track_feat_support_pyramid))],
|
1231 |
+
"track2d_pred_cache": coord_preds[-1][0].clone().detach(),
|
1232 |
+
"track3d_pred_cache": pts3d_cam_pred[-1][0].clone().detach(),
|
1233 |
+
}
|
1234 |
+
#NOTE: update the point map
|
1235 |
+
point_map_org = scale_est.view(B*T,1,1,1)*point_map_org + shift_est.view(B*T,3,1,1)
|
1236 |
+
point_map_org_refined = point_map_org
|
1237 |
+
return ret, torch.sigmoid(dynamic_prob_preds[-1])*queries_z_unc[:,None,:,0], coord_preds[-1], color_pts, intrs_org, point_map_org_refined, cache
|
1238 |
+
|
1239 |
+
def track_d2_loss(self, tracks3d, stride=[1,2,3], dyn_prob=None, mask=None):
|
1240 |
+
"""
|
1241 |
+
tracks3d: B T N 3
|
1242 |
+
dyn_prob: B T N 1
|
1243 |
+
"""
|
1244 |
+
r = 0.8
|
1245 |
+
t_diff_total = 0.0
|
1246 |
+
for i, s_ in enumerate(stride):
|
1247 |
+
w_ = r**i
|
1248 |
+
tracks3d_stride = tracks3d[:, ::s_, :, :] # B T//s_ N 3
|
1249 |
+
t_diff_tracks3d = (tracks3d_stride[:, 1:, :, :] - tracks3d_stride[:, :-1, :, :])
|
1250 |
+
t_diff2 = (t_diff_tracks3d[:, 1:, :, :] - t_diff_tracks3d[:, :-1, :, :])
|
1251 |
+
t_diff_total += w_*(t_diff2.norm(dim=-1).mean())
|
1252 |
+
|
1253 |
+
return 1e2*t_diff_total
|
1254 |
+
|
1255 |
+
def loss(self, train_data, traj3d_gt=None,
|
1256 |
+
vis_gt=None, static_tracks_gt=None, cam_gt=None,
|
1257 |
+
z_unc=None, q_xyz_world=None, q_static_proj=None, anchor_loss=0, valid_only=False,
|
1258 |
+
gt_world_pts=None, mask_traj_gt=None, intrs=None, c2w_ests=None, custom_vid=False, video_vis=None, stage=0,
|
1259 |
+
fix_cam_pts=None, point_map_preds=None, points_map_gt=None, metric_unc=None, scale_est=None, shift_est=None, point_map_org_train=None):
|
1260 |
+
"""
|
1261 |
+
Compute the loss of 3D tracking problem
|
1262 |
+
|
1263 |
+
"""
|
1264 |
+
|
1265 |
+
(
|
1266 |
+
coord_predictions, coords_xyz_predictions, vis_predictions, confidence_predicitons,
|
1267 |
+
dynamic_prob_predictions, camera_predictions, cam_pts_predictions, world_tracks_predictions,
|
1268 |
+
world_tracks_refined_predictions, scale_ests, shift_ests, valid_mask
|
1269 |
+
) = train_data
|
1270 |
+
B, T, _, _ = cam_gt.shape
|
1271 |
+
if (stage == 2) and self.training:
|
1272 |
+
# get the scale and shift gt
|
1273 |
+
self.metric_unc_org[:,0] = self.metric_unc_org[:,0] * (points_map_gt.norm(dim=-1)>0).float() * (self.metric_unc_org[:,0]>0.5).float()
|
1274 |
+
if not (self.scale_gt==torch.ones(B*T).to(self.scale_gt.device)).all():
|
1275 |
+
scale_gt, shift_gt = self.scale_gt, self.shift_gt
|
1276 |
+
scale_re = scale_gt[:4].mean()
|
1277 |
+
scale_loss = 0.0
|
1278 |
+
shift_loss = 0.0
|
1279 |
+
for i_scale in range(len(scale_ests[0])):
|
1280 |
+
scale_loss += 0.8**(len(scale_ests[0])-i_scale-1)*10*(scale_gt - scale_re*scale_ests[0][i_scale].view(-1)).abs().mean()
|
1281 |
+
shift_loss += 0.8**(len(shift_ests[0])-i_scale-1)*10*(shift_gt - scale_re*shift_ests[0][i_scale].view(-1,3)).abs().mean()
|
1282 |
+
else:
|
1283 |
+
scale_loss = 0.0 * scale_ests[0][0].mean()
|
1284 |
+
shift_loss = 0.0 * shift_ests[0][0].mean()
|
1285 |
+
scale_re = 1.0
|
1286 |
+
else:
|
1287 |
+
scale_loss = 0.0
|
1288 |
+
shift_loss = 0.0
|
1289 |
+
|
1290 |
+
if len(point_map_preds)>0:
|
1291 |
+
point_map_loss = 0.0
|
1292 |
+
for i in range(len(point_map_preds)):
|
1293 |
+
point_map_preds_i = point_map_preds[i]
|
1294 |
+
point_map_preds_i = rearrange(point_map_preds_i, 'b t c h w -> (b t) c h w', b=B, t=T)
|
1295 |
+
base_loss = ((self.pred_points - points_map_gt).norm(dim=-1) * self.metric_unc_org[:,0]).mean()
|
1296 |
+
point_map_loss_i = ((point_map_preds_i - points_map_gt.permute(0,3,1,2)).norm(dim=1) * self.metric_unc_org[:,0]).mean()
|
1297 |
+
point_map_loss += point_map_loss_i
|
1298 |
+
# point_map_loss += ((point_map_org_train - points_map_gt.permute(0,3,1,2)).norm(dim=1) * self.metric_unc_org[:,0]).mean()
|
1299 |
+
if scale_loss == 0.0:
|
1300 |
+
point_map_loss = 0*point_map_preds_i.sum()
|
1301 |
+
else:
|
1302 |
+
point_map_loss = 0.0
|
1303 |
+
|
1304 |
+
# camera loss
|
1305 |
+
cam_loss = 0.0
|
1306 |
+
dyn_loss = 0.0
|
1307 |
+
N_gt = gt_world_pts.shape[2]
|
1308 |
+
|
1309 |
+
# self supervised dynamic mask
|
1310 |
+
H_org, W_org = self.image_size[0], self.image_size[1]
|
1311 |
+
q_static_proj[torch.isnan(q_static_proj)] = -200
|
1312 |
+
in_view_mask = (q_static_proj[...,0]>0) & (q_static_proj[...,0]<W_org) & (q_static_proj[...,1]>0) & (q_static_proj[...,1]<H_org)
|
1313 |
+
dyn_mask_final = (((coord_predictions[0][-1] - q_static_proj))[...,:2].norm(dim=-1) * in_view_mask)
|
1314 |
+
dyn_mask_final = dyn_mask_final.sum(dim=1) / (in_view_mask.sum(dim=1) + 1e-2)
|
1315 |
+
dyn_mask_final = dyn_mask_final > 6
|
1316 |
+
|
1317 |
+
for iter_, cam_pred_i in enumerate(camera_predictions[0]):
|
1318 |
+
# points loss
|
1319 |
+
pts_i_world = world_tracks_predictions[0][iter_].view(B, T, -1, 3)
|
1320 |
+
|
1321 |
+
coords_xyz_i_world = coords_xyz_predictions[0][iter_].view(B, T, -1, 3)
|
1322 |
+
coords_i = coord_predictions[0][iter_].view(B, T, -1, 3)[..., :2]
|
1323 |
+
pts_i_world_refined = torch.einsum(
|
1324 |
+
"btij,btnj->btni",
|
1325 |
+
cam_gt[...,:3,:3],
|
1326 |
+
coords_xyz_i_world
|
1327 |
+
) + cam_gt[...,None, :3,3] # B T N 3
|
1328 |
+
|
1329 |
+
# pts_i_world_refined = world_tracks_refined_predictions[0][iter_].view(B, T, -1, 3)
|
1330 |
+
pts_world = pts_i_world
|
1331 |
+
dyn_prob_i_logits = dynamic_prob_predictions[0][iter_].mean(dim=1)
|
1332 |
+
dyn_prob_i = torch.sigmoid(dyn_prob_i_logits).detach()
|
1333 |
+
mask = pts_world.norm(dim=-1) < 200
|
1334 |
+
|
1335 |
+
# general
|
1336 |
+
vis_i_logits = vis_predictions[0][iter_]
|
1337 |
+
vis_i = torch.sigmoid(vis_i_logits).detach()
|
1338 |
+
if mask_traj_gt is not None:
|
1339 |
+
try:
|
1340 |
+
N_gt_mask = mask_traj_gt.shape[1]
|
1341 |
+
align_loss = (gt_world_pts - q_xyz_world[:,None,:N_gt,:,]).norm(dim=-1)[...,:N_gt_mask] * (mask_traj_gt.permute(0,2,1))
|
1342 |
+
visb_traj = (align_loss * vis_i[:,:,:N_gt_mask]).sum(dim=1)/vis_i[:,:,:N_gt_mask].sum(dim=1)
|
1343 |
+
except:
|
1344 |
+
import pdb; pdb.set_trace()
|
1345 |
+
else:
|
1346 |
+
visb_traj = ((gt_world_pts - q_xyz_world[:,None,:N_gt,:,]).norm(dim=-1) * vis_i[:,:,:N_gt]).sum(dim=1)/vis_i[:,:,:N_gt].sum(dim=1)
|
1347 |
+
|
1348 |
+
# pts_loss = ((q_xyz_world[:,None,...] - pts_world)[:,:,:N_gt,:].norm(dim=-1)*(1-dyn_prob_i[:,None,:N_gt])) # - 0.1*(1-dyn_prob_i[:,None,:N_gt]).log()
|
1349 |
+
pts_loss = 0
|
1350 |
+
static_mask = ~dyn_mask_final # more strict for static points
|
1351 |
+
dyn_mask = dyn_mask_final
|
1352 |
+
pts_loss_refined = ((q_xyz_world[:,None,...] - pts_i_world_refined).norm(dim=-1)*static_mask[:,None,:]).sum()/static_mask.sum() # - 0.1*(1-dyn_prob_i[:,None,:N_gt]).log()
|
1353 |
+
vis_logits_final = vis_predictions[0][-1].detach()
|
1354 |
+
vis_final = torch.sigmoid(vis_logits_final)+0.2 > 0.5 # more strict for visible points
|
1355 |
+
dyn_vis_mask = dyn_mask*vis_final * (fix_cam_pts[...,2] > 0.1)
|
1356 |
+
pts_loss_dynamic = ((fix_cam_pts - coords_xyz_i_world).norm(dim=-1)*dyn_vis_mask[:,None,:]).sum()/dyn_vis_mask.sum()
|
1357 |
+
|
1358 |
+
# pts_loss_refined = 0
|
1359 |
+
if traj3d_gt is not None:
|
1360 |
+
tap_traj = (gt_world_pts[:,:-1,...] - gt_world_pts[:,1:,...]).norm(dim=-1).sum(dim=1)[...,:N_gt_mask]
|
1361 |
+
mask_dyn = tap_traj>0.5
|
1362 |
+
if mask_traj_gt.sum() > 0:
|
1363 |
+
dyn_loss_i = 20*balanced_binary_cross_entropy(dyn_prob_i_logits[:,:N_gt_mask][mask_traj_gt.squeeze(-1)],
|
1364 |
+
mask_dyn.float()[mask_traj_gt.squeeze(-1)])
|
1365 |
+
else:
|
1366 |
+
dyn_loss_i = 0
|
1367 |
+
else:
|
1368 |
+
dyn_loss_i = 10*balanced_binary_cross_entropy(dyn_prob_i_logits, dyn_mask_final.float())
|
1369 |
+
|
1370 |
+
dyn_loss += dyn_loss_i
|
1371 |
+
|
1372 |
+
# visible loss for out of view points
|
1373 |
+
vis_i_train = torch.sigmoid(vis_i_logits)
|
1374 |
+
out_of_view_mask = (coords_i[...,0]<0)|(coords_i[...,0]>self.image_size[1])|(coords_i[...,1]<0)|(coords_i[...,1]>self.image_size[0])
|
1375 |
+
vis_loss_out_of_view = vis_i_train[out_of_view_mask].sum() / out_of_view_mask.sum()
|
1376 |
+
|
1377 |
+
|
1378 |
+
if traj3d_gt is not None:
|
1379 |
+
world_pts_loss = (((gt_world_pts - pts_i_world_refined[:,:,:gt_world_pts.shape[2],...]).norm(dim=-1))[...,:N_gt_mask] * mask_traj_gt.permute(0,2,1)).sum() / mask_traj_gt.sum()
|
1380 |
+
# world_pts_init_loss = (((gt_world_pts - pts_i_world[:,:,:gt_world_pts.shape[2],...]).norm(dim=-1))[...,:N_gt_mask] * mask_traj_gt.permute(0,2,1)).sum() / mask_traj_gt.sum()
|
1381 |
+
else:
|
1382 |
+
world_pts_loss = 0
|
1383 |
+
|
1384 |
+
# cam regress
|
1385 |
+
t_err = (cam_pred_i[...,:3,3] - cam_gt[...,:3,3]).norm(dim=-1).sum()
|
1386 |
+
|
1387 |
+
# xyz loss
|
1388 |
+
in_view_mask_large = (q_static_proj[...,0]>-50) & (q_static_proj[...,0]<W_org+50) & (q_static_proj[...,1]>-50) & (q_static_proj[...,1]<H_org+50)
|
1389 |
+
static_vis_mask = (q_static_proj[...,2]>0.05).float() * static_mask[:,None,:] * in_view_mask_large
|
1390 |
+
xyz_loss = ((coord_predictions[0][iter_] - q_static_proj)).abs()[...,:2].norm(dim=-1)*static_vis_mask
|
1391 |
+
xyz_loss = xyz_loss.sum()/static_vis_mask.sum()
|
1392 |
+
|
1393 |
+
# visualize the q_static_proj
|
1394 |
+
# viser = Visualizer(save_dir=".", grayscale=True,
|
1395 |
+
# fps=10, pad_value=50, tracks_leave_trace=0)
|
1396 |
+
# video_vis_ = F.interpolate(video_vis.view(B*T,3,video_vis.shape[-2],video_vis.shape[-1]), (H_org, W_org), mode='bilinear', align_corners=False)
|
1397 |
+
# viser.visualize(video=video_vis_, tracks=q_static_proj[:,:,dyn_mask_final.squeeze(), :2], filename="test")
|
1398 |
+
# viser.visualize(video=video_vis_, tracks=coord_predictions[0][-1][:,:,dyn_mask_final.squeeze(), :2], filename="test_pred")
|
1399 |
+
# import pdb; pdb.set_trace()
|
1400 |
+
|
1401 |
+
# temporal loss
|
1402 |
+
t_loss = self.track_d2_loss(pts_i_world_refined, [1,2,3], dyn_prob=dyn_prob_i, mask=mask)
|
1403 |
+
R_err = (cam_pred_i[...,:3,:3] - cam_gt[...,:3,:3]).abs().sum(dim=-1).mean()
|
1404 |
+
if self.stage == 1:
|
1405 |
+
cam_loss += 0.8**(len(camera_predictions[0])-iter_-1)*(10*t_err + 500*R_err + 20*pts_loss_refined + 10*xyz_loss + 20*pts_loss_dynamic + 10*vis_loss_out_of_view) #+ 5*(pts_loss + pts_loss_refined + world_pts_loss) + t_loss)
|
1406 |
+
elif self.stage == 3:
|
1407 |
+
cam_loss += 0.8**(len(camera_predictions[0])-iter_-1)*(10*t_err + 500*R_err + 10*vis_loss_out_of_view) #+ 5*(pts_loss + pts_loss_refined + world_pts_loss) + t_loss)
|
1408 |
+
else:
|
1409 |
+
cam_loss += 0*vis_loss_out_of_view
|
1410 |
+
|
1411 |
+
if (cam_loss > 20000)|(torch.isnan(cam_loss)):
|
1412 |
+
cam_loss = torch.zeros_like(cam_loss)
|
1413 |
+
|
1414 |
+
|
1415 |
+
if traj3d_gt is None:
|
1416 |
+
# ================ Condition 1: The self-supervised signals from the self-consistency ===================
|
1417 |
+
return cam_loss, train_data[0][0][0].mean()*0, dyn_loss, train_data[0][0][0].mean()*0, point_map_loss, scale_loss, shift_loss
|
1418 |
+
|
1419 |
+
|
1420 |
+
# ================ Condition 2: The supervision signal given by the ground truth trajectories ===================
|
1421 |
+
if (
|
1422 |
+
(torch.isnan(traj3d_gt).any()
|
1423 |
+
or traj3d_gt.abs().max() > 2000) and (custom_vid==False)
|
1424 |
+
):
|
1425 |
+
return cam_loss, train_data[0][0][0].mean()*0, dyn_loss, train_data[0][0][0].mean()*0, point_map_loss, scale_loss, shift_loss
|
1426 |
+
|
1427 |
+
|
1428 |
+
vis_gts = [vis_gt.float()]
|
1429 |
+
invis_gts = [1-vis_gt.float()]
|
1430 |
+
traj_gts = [traj3d_gt]
|
1431 |
+
valids_gts = [valid_mask]
|
1432 |
+
seq_loss_all = sequence_loss(
|
1433 |
+
coord_predictions,
|
1434 |
+
traj_gts,
|
1435 |
+
valids_gts,
|
1436 |
+
vis=vis_gts,
|
1437 |
+
gamma=0.8,
|
1438 |
+
add_huber_loss=False,
|
1439 |
+
loss_only_for_visible=False if custom_vid==False else True,
|
1440 |
+
z_unc=z_unc,
|
1441 |
+
mask_traj_gt=mask_traj_gt
|
1442 |
+
)
|
1443 |
+
|
1444 |
+
confidence_loss = sequence_prob_loss(
|
1445 |
+
coord_predictions, confidence_predicitons, traj_gts, vis_gts
|
1446 |
+
)
|
1447 |
+
|
1448 |
+
seq_loss_xyz = sequence_loss_xyz(
|
1449 |
+
coords_xyz_predictions,
|
1450 |
+
traj_gts,
|
1451 |
+
valids_gts,
|
1452 |
+
intrs=intrs,
|
1453 |
+
vis=vis_gts,
|
1454 |
+
gamma=0.8,
|
1455 |
+
add_huber_loss=False,
|
1456 |
+
loss_only_for_visible=False,
|
1457 |
+
mask_traj_gt=mask_traj_gt
|
1458 |
+
)
|
1459 |
+
|
1460 |
+
# filter the blinking points
|
1461 |
+
mask_vis = vis_gts[0].clone() # B T N
|
1462 |
+
mask_vis[mask_vis==0] = -1
|
1463 |
+
blink_mask = mask_vis[:,:-1,:] * mask_vis[:,1:,:] # first derivative B (T-1) N
|
1464 |
+
mask_vis[:,:-1,:], mask_vis[:,-1,:] = (blink_mask == 1), 0
|
1465 |
+
|
1466 |
+
vis_loss = sequence_BCE_loss(vis_predictions, vis_gts, mask=[mask_vis])
|
1467 |
+
|
1468 |
+
track_loss_out = (seq_loss_all+2*seq_loss_xyz + cam_loss)
|
1469 |
+
if valid_only:
|
1470 |
+
vis_loss = 0.0*vis_loss
|
1471 |
+
if custom_vid:
|
1472 |
+
return seq_loss_all, 0.0*seq_loss_all, 0.0*seq_loss_all, 10*vis_loss, 0.0*seq_loss_all, 0.0*seq_loss_all, 0.0*seq_loss_all
|
1473 |
+
|
1474 |
+
return track_loss_out, confidence_loss, dyn_loss, 10*vis_loss, point_map_loss, scale_loss, shift_loss
|
1475 |
+
|
1476 |
+
|
1477 |
+
|
1478 |
+
|
models/SpaTrackV2/models/tracker3D/co_tracker/cotracker_base.py
ADDED
@@ -0,0 +1,418 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
import torch.nn.functional as F
|
10 |
+
from models.SpaTrackV2.utils.model_utils import sample_features5d, bilinear_sampler
|
11 |
+
|
12 |
+
from models.SpaTrackV2.models.tracker3D.co_tracker.utils import (
|
13 |
+
Mlp, BasicEncoder, EfficientUpdateFormer
|
14 |
+
)
|
15 |
+
|
16 |
+
torch.manual_seed(0)
|
17 |
+
|
18 |
+
|
19 |
+
def get_1d_sincos_pos_embed_from_grid(
|
20 |
+
embed_dim: int, pos: torch.Tensor
|
21 |
+
) -> torch.Tensor:
|
22 |
+
"""
|
23 |
+
This function generates a 1D positional embedding from a given grid using sine and cosine functions.
|
24 |
+
|
25 |
+
Args:
|
26 |
+
- embed_dim: The embedding dimension.
|
27 |
+
- pos: The position to generate the embedding from.
|
28 |
+
|
29 |
+
Returns:
|
30 |
+
- emb: The generated 1D positional embedding.
|
31 |
+
"""
|
32 |
+
assert embed_dim % 2 == 0
|
33 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.double)
|
34 |
+
omega /= embed_dim / 2.0
|
35 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
36 |
+
|
37 |
+
pos = pos.reshape(-1) # (M,)
|
38 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
39 |
+
|
40 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
41 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
42 |
+
|
43 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
44 |
+
return emb[None].float()
|
45 |
+
|
46 |
+
def posenc(x, min_deg, max_deg):
|
47 |
+
"""Cat x with a positional encoding of x with scales 2^[min_deg, max_deg-1].
|
48 |
+
Instead of computing [sin(x), cos(x)], we use the trig identity
|
49 |
+
cos(x) = sin(x + pi/2) and do one vectorized call to sin([x, x+pi/2]).
|
50 |
+
Args:
|
51 |
+
x: torch.Tensor, variables to be encoded. Note that x should be in [-pi, pi].
|
52 |
+
min_deg: int, the minimum (inclusive) degree of the encoding.
|
53 |
+
max_deg: int, the maximum (exclusive) degree of the encoding.
|
54 |
+
legacy_posenc_order: bool, keep the same ordering as the original tf code.
|
55 |
+
Returns:
|
56 |
+
encoded: torch.Tensor, encoded variables.
|
57 |
+
"""
|
58 |
+
if min_deg == max_deg:
|
59 |
+
return x
|
60 |
+
scales = torch.tensor(
|
61 |
+
[2**i for i in range(min_deg, max_deg)], dtype=x.dtype, device=x.device
|
62 |
+
)
|
63 |
+
|
64 |
+
xb = (x[..., None, :] * scales[:, None]).reshape(list(x.shape[:-1]) + [-1])
|
65 |
+
four_feat = torch.sin(torch.cat([xb, xb + 0.5 * torch.pi], dim=-1))
|
66 |
+
return torch.cat([x] + [four_feat], dim=-1)
|
67 |
+
|
68 |
+
class CoTrackerThreeBase(nn.Module):
|
69 |
+
def __init__(
|
70 |
+
self,
|
71 |
+
window_len=8,
|
72 |
+
stride=4,
|
73 |
+
corr_radius=3,
|
74 |
+
corr_levels=4,
|
75 |
+
num_virtual_tracks=64,
|
76 |
+
model_resolution=(384, 512),
|
77 |
+
add_space_attn=True,
|
78 |
+
linear_layer_for_vis_conf=True,
|
79 |
+
):
|
80 |
+
super(CoTrackerThreeBase, self).__init__()
|
81 |
+
self.window_len = window_len
|
82 |
+
self.stride = stride
|
83 |
+
self.corr_radius = corr_radius
|
84 |
+
self.corr_levels = corr_levels
|
85 |
+
self.hidden_dim = 256
|
86 |
+
self.latent_dim = 128
|
87 |
+
|
88 |
+
self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
|
89 |
+
self.fnet = BasicEncoder(input_dim=3, output_dim=self.latent_dim, stride=stride)
|
90 |
+
|
91 |
+
highres_dim = 128
|
92 |
+
lowres_dim = 256
|
93 |
+
|
94 |
+
self.num_virtual_tracks = num_virtual_tracks
|
95 |
+
self.model_resolution = model_resolution
|
96 |
+
|
97 |
+
self.input_dim = 1110
|
98 |
+
|
99 |
+
self.updateformer = EfficientUpdateFormer(
|
100 |
+
space_depth=3,
|
101 |
+
time_depth=3,
|
102 |
+
input_dim=self.input_dim,
|
103 |
+
hidden_size=384,
|
104 |
+
output_dim=4,
|
105 |
+
mlp_ratio=4.0,
|
106 |
+
num_virtual_tracks=num_virtual_tracks,
|
107 |
+
add_space_attn=add_space_attn,
|
108 |
+
linear_layer_for_vis_conf=linear_layer_for_vis_conf,
|
109 |
+
)
|
110 |
+
self.corr_mlp = Mlp(in_features=49 * 49, hidden_features=384, out_features=256)
|
111 |
+
|
112 |
+
time_grid = torch.linspace(0, window_len - 1, window_len).reshape(
|
113 |
+
1, window_len, 1
|
114 |
+
)
|
115 |
+
|
116 |
+
self.register_buffer(
|
117 |
+
"time_emb", get_1d_sincos_pos_embed_from_grid(self.input_dim, time_grid[0])
|
118 |
+
)
|
119 |
+
|
120 |
+
def get_support_points(self, coords, r, reshape_back=True):
|
121 |
+
B, _, N, _ = coords.shape
|
122 |
+
device = coords.device
|
123 |
+
centroid_lvl = coords.reshape(B, N, 1, 1, 3)
|
124 |
+
|
125 |
+
dx = torch.linspace(-r, r, 2 * r + 1, device=device)
|
126 |
+
dy = torch.linspace(-r, r, 2 * r + 1, device=device)
|
127 |
+
|
128 |
+
xgrid, ygrid = torch.meshgrid(dy, dx, indexing="ij")
|
129 |
+
zgrid = torch.zeros_like(xgrid, device=device)
|
130 |
+
delta = torch.stack([zgrid, xgrid, ygrid], axis=-1)
|
131 |
+
delta_lvl = delta.view(1, 1, 2 * r + 1, 2 * r + 1, 3)
|
132 |
+
coords_lvl = centroid_lvl + delta_lvl
|
133 |
+
|
134 |
+
if reshape_back:
|
135 |
+
return coords_lvl.reshape(B, N, (2 * r + 1) ** 2, 3).permute(0, 2, 1, 3)
|
136 |
+
else:
|
137 |
+
return coords_lvl
|
138 |
+
|
139 |
+
def get_track_feat(self, fmaps, queried_frames, queried_coords, support_radius=0):
|
140 |
+
|
141 |
+
sample_frames = queried_frames[:, None, :, None]
|
142 |
+
sample_coords = torch.cat(
|
143 |
+
[
|
144 |
+
sample_frames,
|
145 |
+
queried_coords[:, None],
|
146 |
+
],
|
147 |
+
dim=-1,
|
148 |
+
)
|
149 |
+
support_points = self.get_support_points(sample_coords, support_radius)
|
150 |
+
support_track_feats = sample_features5d(fmaps, support_points)
|
151 |
+
return (
|
152 |
+
support_track_feats[:, None, support_track_feats.shape[1] // 2],
|
153 |
+
support_track_feats,
|
154 |
+
)
|
155 |
+
|
156 |
+
def get_correlation_feat(self, fmaps, queried_coords, radius=None, padding_mode="border"):
|
157 |
+
B, T, D, H_, W_ = fmaps.shape
|
158 |
+
N = queried_coords.shape[1]
|
159 |
+
if radius is None:
|
160 |
+
r = self.corr_radius
|
161 |
+
else:
|
162 |
+
r = radius
|
163 |
+
sample_coords = torch.cat(
|
164 |
+
[torch.zeros_like(queried_coords[..., :1]), queried_coords], dim=-1
|
165 |
+
)[:, None]
|
166 |
+
support_points = self.get_support_points(sample_coords, r, reshape_back=False)
|
167 |
+
correlation_feat = bilinear_sampler(
|
168 |
+
fmaps.reshape(B * T, D, 1, H_, W_), support_points, padding_mode=padding_mode
|
169 |
+
)
|
170 |
+
return correlation_feat.view(B, T, D, N, (2 * r + 1), (2 * r + 1)).permute(
|
171 |
+
0, 1, 3, 4, 5, 2
|
172 |
+
)
|
173 |
+
|
174 |
+
def interpolate_time_embed(self, x, t):
|
175 |
+
previous_dtype = x.dtype
|
176 |
+
T = self.time_emb.shape[1]
|
177 |
+
|
178 |
+
if t == T:
|
179 |
+
return self.time_emb
|
180 |
+
|
181 |
+
time_emb = self.time_emb.float()
|
182 |
+
time_emb = F.interpolate(
|
183 |
+
time_emb.permute(0, 2, 1), size=t, mode="linear"
|
184 |
+
).permute(0, 2, 1)
|
185 |
+
return time_emb.to(previous_dtype)
|
186 |
+
|
187 |
+
class CoTrackerThreeOffline(CoTrackerThreeBase):
|
188 |
+
def __init__(self, **args):
|
189 |
+
super(CoTrackerThreeOffline, self).__init__(**args)
|
190 |
+
|
191 |
+
def forward(
|
192 |
+
self,
|
193 |
+
video,
|
194 |
+
queries,
|
195 |
+
iters=4,
|
196 |
+
is_train=False,
|
197 |
+
add_space_attn=True,
|
198 |
+
fmaps_chunk_size=200,
|
199 |
+
):
|
200 |
+
"""Predict tracks
|
201 |
+
|
202 |
+
Args:
|
203 |
+
video (FloatTensor[B, T, 3]): input videos.
|
204 |
+
queries (FloatTensor[B, N, 3]): point queries.
|
205 |
+
iters (int, optional): number of updates. Defaults to 4.
|
206 |
+
is_train (bool, optional): enables training mode. Defaults to False.
|
207 |
+
Returns:
|
208 |
+
- coords_predicted (FloatTensor[B, T, N, 2]):
|
209 |
+
- vis_predicted (FloatTensor[B, T, N]):
|
210 |
+
- train_data: `None` if `is_train` is false, otherwise:
|
211 |
+
- all_vis_predictions (List[FloatTensor[B, S, N, 1]]):
|
212 |
+
- all_coords_predictions (List[FloatTensor[B, S, N, 2]]):
|
213 |
+
- mask (BoolTensor[B, T, N]):
|
214 |
+
"""
|
215 |
+
|
216 |
+
B, T, C, H, W = video.shape
|
217 |
+
device = queries.device
|
218 |
+
assert H % self.stride == 0 and W % self.stride == 0
|
219 |
+
|
220 |
+
B, N, __ = queries.shape
|
221 |
+
# B = batch size
|
222 |
+
# S_trimmed = actual number of frames in the window
|
223 |
+
# N = number of tracks
|
224 |
+
# C = color channels (3 for RGB)
|
225 |
+
# E = positional embedding size
|
226 |
+
# LRR = local receptive field radius
|
227 |
+
# D = dimension of the transformer input tokens
|
228 |
+
|
229 |
+
# video = B T C H W
|
230 |
+
# queries = B N 3
|
231 |
+
# coords_init = B T N 2
|
232 |
+
# vis_init = B T N 1
|
233 |
+
|
234 |
+
assert T >= 1 # A tracker needs at least two frames to track something
|
235 |
+
|
236 |
+
video = 2 * (video / 255.0) - 1.0
|
237 |
+
dtype = video.dtype
|
238 |
+
queried_frames = queries[:, :, 0].long()
|
239 |
+
|
240 |
+
queried_coords = queries[..., 1:3]
|
241 |
+
queried_coords = queried_coords / self.stride
|
242 |
+
|
243 |
+
# We store our predictions here
|
244 |
+
all_coords_predictions, all_vis_predictions, all_confidence_predictions = (
|
245 |
+
[],
|
246 |
+
[],
|
247 |
+
[],
|
248 |
+
)
|
249 |
+
C_ = C
|
250 |
+
H4, W4 = H // self.stride, W // self.stride
|
251 |
+
# Compute convolutional features for the video or for the current chunk in case of online mode
|
252 |
+
|
253 |
+
if T > fmaps_chunk_size:
|
254 |
+
fmaps = []
|
255 |
+
for t in range(0, T, fmaps_chunk_size):
|
256 |
+
video_chunk = video[:, t : t + fmaps_chunk_size]
|
257 |
+
fmaps_chunk = self.fnet(video_chunk.reshape(-1, C_, H, W))
|
258 |
+
T_chunk = video_chunk.shape[1]
|
259 |
+
C_chunk, H_chunk, W_chunk = fmaps_chunk.shape[1:]
|
260 |
+
fmaps.append(fmaps_chunk.reshape(B, T_chunk, C_chunk, H_chunk, W_chunk))
|
261 |
+
fmaps = torch.cat(fmaps, dim=1).reshape(-1, C_chunk, H_chunk, W_chunk)
|
262 |
+
else:
|
263 |
+
fmaps = self.fnet(video.reshape(-1, C_, H, W))
|
264 |
+
fmaps = fmaps.permute(0, 2, 3, 1)
|
265 |
+
fmaps = fmaps / torch.sqrt(
|
266 |
+
torch.maximum(
|
267 |
+
torch.sum(torch.square(fmaps), axis=-1, keepdims=True),
|
268 |
+
torch.tensor(1e-12, device=fmaps.device),
|
269 |
+
)
|
270 |
+
)
|
271 |
+
fmaps = fmaps.permute(0, 3, 1, 2).reshape(
|
272 |
+
B, -1, self.latent_dim, H // self.stride, W // self.stride
|
273 |
+
)
|
274 |
+
fmaps = fmaps.to(dtype)
|
275 |
+
|
276 |
+
# We compute track features
|
277 |
+
fmaps_pyramid = []
|
278 |
+
track_feat_pyramid = []
|
279 |
+
track_feat_support_pyramid = []
|
280 |
+
fmaps_pyramid.append(fmaps)
|
281 |
+
for i in range(self.corr_levels - 1):
|
282 |
+
fmaps_ = fmaps.reshape(
|
283 |
+
B * T, self.latent_dim, fmaps.shape[-2], fmaps.shape[-1]
|
284 |
+
)
|
285 |
+
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
|
286 |
+
fmaps = fmaps_.reshape(
|
287 |
+
B, T, self.latent_dim, fmaps_.shape[-2], fmaps_.shape[-1]
|
288 |
+
)
|
289 |
+
fmaps_pyramid.append(fmaps)
|
290 |
+
|
291 |
+
for i in range(self.corr_levels):
|
292 |
+
track_feat, track_feat_support = self.get_track_feat(
|
293 |
+
fmaps_pyramid[i],
|
294 |
+
queried_frames,
|
295 |
+
queried_coords / 2**i,
|
296 |
+
support_radius=self.corr_radius,
|
297 |
+
)
|
298 |
+
track_feat_pyramid.append(track_feat.repeat(1, T, 1, 1))
|
299 |
+
track_feat_support_pyramid.append(track_feat_support.unsqueeze(1))
|
300 |
+
|
301 |
+
D_coords = 2
|
302 |
+
|
303 |
+
coord_preds, vis_preds, confidence_preds = [], [], []
|
304 |
+
|
305 |
+
vis = torch.zeros((B, T, N), device=device).float()
|
306 |
+
confidence = torch.zeros((B, T, N), device=device).float()
|
307 |
+
coords = queried_coords.reshape(B, 1, N, 2).expand(B, T, N, 2).float()
|
308 |
+
|
309 |
+
r = 2 * self.corr_radius + 1
|
310 |
+
|
311 |
+
for it in range(iters):
|
312 |
+
coords = coords.detach() # B T N 2
|
313 |
+
coords_init = coords.view(B * T, N, 2)
|
314 |
+
corr_embs = []
|
315 |
+
corr_feats = []
|
316 |
+
for i in range(self.corr_levels):
|
317 |
+
corr_feat = self.get_correlation_feat(
|
318 |
+
fmaps_pyramid[i], coords_init / 2**i
|
319 |
+
)
|
320 |
+
track_feat_support = (
|
321 |
+
track_feat_support_pyramid[i]
|
322 |
+
.view(B, 1, r, r, N, self.latent_dim)
|
323 |
+
.squeeze(1)
|
324 |
+
.permute(0, 3, 1, 2, 4)
|
325 |
+
)
|
326 |
+
corr_volume = torch.einsum(
|
327 |
+
"btnhwc,bnijc->btnhwij", corr_feat, track_feat_support
|
328 |
+
)
|
329 |
+
corr_emb = self.corr_mlp(corr_volume.reshape(B * T * N, r * r * r * r))
|
330 |
+
corr_embs.append(corr_emb)
|
331 |
+
corr_embs = torch.cat(corr_embs, dim=-1)
|
332 |
+
corr_embs = corr_embs.view(B, T, N, corr_embs.shape[-1])
|
333 |
+
|
334 |
+
transformer_input = [vis[..., None], confidence[..., None], corr_embs]
|
335 |
+
|
336 |
+
rel_coords_forward = coords[:, :-1] - coords[:, 1:]
|
337 |
+
rel_coords_backward = coords[:, 1:] - coords[:, :-1]
|
338 |
+
|
339 |
+
rel_coords_forward = torch.nn.functional.pad(
|
340 |
+
rel_coords_forward, (0, 0, 0, 0, 0, 1)
|
341 |
+
)
|
342 |
+
rel_coords_backward = torch.nn.functional.pad(
|
343 |
+
rel_coords_backward, (0, 0, 0, 0, 1, 0)
|
344 |
+
)
|
345 |
+
scale = (
|
346 |
+
torch.tensor(
|
347 |
+
[self.model_resolution[1], self.model_resolution[0]],
|
348 |
+
device=coords.device,
|
349 |
+
)
|
350 |
+
/ self.stride
|
351 |
+
)
|
352 |
+
rel_coords_forward = rel_coords_forward / scale
|
353 |
+
rel_coords_backward = rel_coords_backward / scale
|
354 |
+
|
355 |
+
rel_pos_emb_input = posenc(
|
356 |
+
torch.cat([rel_coords_forward, rel_coords_backward], dim=-1),
|
357 |
+
min_deg=0,
|
358 |
+
max_deg=10,
|
359 |
+
) # batch, num_points, num_frames, 84
|
360 |
+
transformer_input.append(rel_pos_emb_input)
|
361 |
+
|
362 |
+
x = (
|
363 |
+
torch.cat(transformer_input, dim=-1)
|
364 |
+
.permute(0, 2, 1, 3)
|
365 |
+
.reshape(B * N, T, -1)
|
366 |
+
)
|
367 |
+
|
368 |
+
x = x + self.interpolate_time_embed(x, T)
|
369 |
+
x = x.view(B, N, T, -1) # (B N) T D -> B N T D
|
370 |
+
|
371 |
+
delta = self.updateformer(
|
372 |
+
x,
|
373 |
+
add_space_attn=add_space_attn,
|
374 |
+
)
|
375 |
+
|
376 |
+
delta_coords = delta[..., :D_coords].permute(0, 2, 1, 3)
|
377 |
+
delta_vis = delta[..., D_coords].permute(0, 2, 1)
|
378 |
+
delta_confidence = delta[..., D_coords + 1].permute(0, 2, 1)
|
379 |
+
|
380 |
+
vis = vis + delta_vis
|
381 |
+
confidence = confidence + delta_confidence
|
382 |
+
|
383 |
+
coords = coords + delta_coords
|
384 |
+
coords_append = coords.clone()
|
385 |
+
coords_append[..., :2] = coords_append[..., :2] * float(self.stride)
|
386 |
+
coord_preds.append(coords_append)
|
387 |
+
vis_preds.append(torch.sigmoid(vis))
|
388 |
+
confidence_preds.append(torch.sigmoid(confidence))
|
389 |
+
|
390 |
+
if is_train:
|
391 |
+
all_coords_predictions.append([coord[..., :2] for coord in coord_preds])
|
392 |
+
all_vis_predictions.append(vis_preds)
|
393 |
+
all_confidence_predictions.append(confidence_preds)
|
394 |
+
|
395 |
+
if is_train:
|
396 |
+
train_data = (
|
397 |
+
all_coords_predictions,
|
398 |
+
all_vis_predictions,
|
399 |
+
all_confidence_predictions,
|
400 |
+
torch.ones_like(vis_preds[-1], device=vis_preds[-1].device),
|
401 |
+
)
|
402 |
+
else:
|
403 |
+
train_data = None
|
404 |
+
|
405 |
+
return coord_preds[-1][..., :2], vis_preds[-1], confidence_preds[-1], train_data
|
406 |
+
|
407 |
+
|
408 |
+
if __name__ == "__main__":
|
409 |
+
cotrack_cktp = "/data0/xyx/scaled_offline.pth"
|
410 |
+
cotracker = CoTrackerThreeOffline(
|
411 |
+
stride=4, corr_radius=3, window_len=60
|
412 |
+
)
|
413 |
+
with open(cotrack_cktp, "rb") as f:
|
414 |
+
state_dict = torch.load(f, map_location="cpu")
|
415 |
+
if "model" in state_dict:
|
416 |
+
state_dict = state_dict["model"]
|
417 |
+
cotracker.load_state_dict(state_dict)
|
418 |
+
import pdb; pdb.set_trace()
|
models/SpaTrackV2/models/tracker3D/co_tracker/utils.py
ADDED
@@ -0,0 +1,929 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from functools import partial
|
5 |
+
from typing import Callable, List
|
6 |
+
import collections
|
7 |
+
from torch import Tensor
|
8 |
+
from itertools import repeat
|
9 |
+
from models.SpaTrackV2.utils.model_utils import bilinear_sampler
|
10 |
+
from models.SpaTrackV2.models.blocks import CrossAttnBlock as CrossAttnBlock_F
|
11 |
+
from torch.nn.functional import scaled_dot_product_attention
|
12 |
+
from torch.nn.attention import sdpa_kernel, SDPBackend
|
13 |
+
# import flash_attn
|
14 |
+
EPS = 1e-6
|
15 |
+
|
16 |
+
|
17 |
+
class ResidualBlock(nn.Module):
|
18 |
+
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
19 |
+
super(ResidualBlock, self).__init__()
|
20 |
+
|
21 |
+
self.conv1 = nn.Conv2d(
|
22 |
+
in_planes,
|
23 |
+
planes,
|
24 |
+
kernel_size=3,
|
25 |
+
padding=1,
|
26 |
+
stride=stride,
|
27 |
+
padding_mode="zeros",
|
28 |
+
)
|
29 |
+
self.conv2 = nn.Conv2d(
|
30 |
+
planes, planes, kernel_size=3, padding=1, padding_mode="zeros"
|
31 |
+
)
|
32 |
+
self.relu = nn.ReLU(inplace=True)
|
33 |
+
|
34 |
+
num_groups = planes // 8
|
35 |
+
|
36 |
+
if norm_fn == "group":
|
37 |
+
self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
38 |
+
self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
39 |
+
if not stride == 1:
|
40 |
+
self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
|
41 |
+
|
42 |
+
elif norm_fn == "batch":
|
43 |
+
self.norm1 = nn.BatchNorm2d(planes)
|
44 |
+
self.norm2 = nn.BatchNorm2d(planes)
|
45 |
+
if not stride == 1:
|
46 |
+
self.norm3 = nn.BatchNorm2d(planes)
|
47 |
+
|
48 |
+
elif norm_fn == "instance":
|
49 |
+
self.norm1 = nn.InstanceNorm2d(planes)
|
50 |
+
self.norm2 = nn.InstanceNorm2d(planes)
|
51 |
+
if not stride == 1:
|
52 |
+
self.norm3 = nn.InstanceNorm2d(planes)
|
53 |
+
|
54 |
+
elif norm_fn == "none":
|
55 |
+
self.norm1 = nn.Sequential()
|
56 |
+
self.norm2 = nn.Sequential()
|
57 |
+
if not stride == 1:
|
58 |
+
self.norm3 = nn.Sequential()
|
59 |
+
|
60 |
+
if stride == 1:
|
61 |
+
self.downsample = None
|
62 |
+
|
63 |
+
else:
|
64 |
+
self.downsample = nn.Sequential(
|
65 |
+
nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3
|
66 |
+
)
|
67 |
+
|
68 |
+
def forward(self, x):
|
69 |
+
y = x
|
70 |
+
y = self.relu(self.norm1(self.conv1(y)))
|
71 |
+
y = self.relu(self.norm2(self.conv2(y)))
|
72 |
+
|
73 |
+
if self.downsample is not None:
|
74 |
+
x = self.downsample(x)
|
75 |
+
|
76 |
+
return self.relu(x + y)
|
77 |
+
|
78 |
+
def reduce_masked_mean(input, mask, dim=None, keepdim=False):
|
79 |
+
r"""Masked mean
|
80 |
+
|
81 |
+
`reduce_masked_mean(x, mask)` computes the mean of a tensor :attr:`input`
|
82 |
+
over a mask :attr:`mask`, returning
|
83 |
+
|
84 |
+
.. math::
|
85 |
+
\text{output} =
|
86 |
+
\frac
|
87 |
+
{\sum_{i=1}^N \text{input}_i \cdot \text{mask}_i}
|
88 |
+
{\epsilon + \sum_{i=1}^N \text{mask}_i}
|
89 |
+
|
90 |
+
where :math:`N` is the number of elements in :attr:`input` and
|
91 |
+
:attr:`mask`, and :math:`\epsilon` is a small constant to avoid
|
92 |
+
division by zero.
|
93 |
+
|
94 |
+
`reduced_masked_mean(x, mask, dim)` computes the mean of a tensor
|
95 |
+
:attr:`input` over a mask :attr:`mask` along a dimension :attr:`dim`.
|
96 |
+
Optionally, the dimension can be kept in the output by setting
|
97 |
+
:attr:`keepdim` to `True`. Tensor :attr:`mask` must be broadcastable to
|
98 |
+
the same dimension as :attr:`input`.
|
99 |
+
|
100 |
+
The interface is similar to `torch.mean()`.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
inout (Tensor): input tensor.
|
104 |
+
mask (Tensor): mask.
|
105 |
+
dim (int, optional): Dimension to sum over. Defaults to None.
|
106 |
+
keepdim (bool, optional): Keep the summed dimension. Defaults to False.
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
Tensor: mean tensor.
|
110 |
+
"""
|
111 |
+
|
112 |
+
mask = mask.expand_as(input)
|
113 |
+
|
114 |
+
prod = input * mask
|
115 |
+
|
116 |
+
if dim is None:
|
117 |
+
numer = torch.sum(prod)
|
118 |
+
denom = torch.sum(mask)
|
119 |
+
else:
|
120 |
+
numer = torch.sum(prod, dim=dim, keepdim=keepdim)
|
121 |
+
denom = torch.sum(mask, dim=dim, keepdim=keepdim)
|
122 |
+
|
123 |
+
mean = numer / (EPS + denom)
|
124 |
+
return mean
|
125 |
+
|
126 |
+
class GeometryEncoder(nn.Module):
|
127 |
+
def __init__(self, input_dim=3, output_dim=128, stride=4):
|
128 |
+
super(GeometryEncoder, self).__init__()
|
129 |
+
self.stride = stride
|
130 |
+
self.norm_fn = "instance"
|
131 |
+
self.in_planes = output_dim // 2
|
132 |
+
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
133 |
+
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
134 |
+
self.conv1 = nn.Conv2d(
|
135 |
+
input_dim,
|
136 |
+
self.in_planes,
|
137 |
+
kernel_size=7,
|
138 |
+
stride=2,
|
139 |
+
padding=3,
|
140 |
+
padding_mode="zeros",
|
141 |
+
)
|
142 |
+
self.relu1 = nn.ReLU(inplace=True)
|
143 |
+
self.layer1 = self._make_layer(output_dim // 2, stride=1)
|
144 |
+
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
|
145 |
+
|
146 |
+
self.conv2 = nn.Conv2d(
|
147 |
+
output_dim * 5 // 4,
|
148 |
+
output_dim,
|
149 |
+
kernel_size=3,
|
150 |
+
padding=1,
|
151 |
+
padding_mode="zeros",
|
152 |
+
)
|
153 |
+
self.relu2 = nn.ReLU(inplace=True)
|
154 |
+
self.conv3 = nn.Conv2d(output_dim, output_dim, kernel_size=1)
|
155 |
+
for m in self.modules():
|
156 |
+
if isinstance(m, nn.Conv2d):
|
157 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
158 |
+
elif isinstance(m, (nn.InstanceNorm2d)):
|
159 |
+
if m.weight is not None:
|
160 |
+
nn.init.constant_(m.weight, 1)
|
161 |
+
if m.bias is not None:
|
162 |
+
nn.init.constant_(m.bias, 0)
|
163 |
+
|
164 |
+
def _make_layer(self, dim, stride=1):
|
165 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
166 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
167 |
+
layers = (layer1, layer2)
|
168 |
+
|
169 |
+
self.in_planes = dim
|
170 |
+
return nn.Sequential(*layers)
|
171 |
+
|
172 |
+
def forward(self, x):
|
173 |
+
_, _, H, W = x.shape
|
174 |
+
x = self.conv1(x)
|
175 |
+
x = self.norm1(x)
|
176 |
+
x = self.relu1(x)
|
177 |
+
a = self.layer1(x)
|
178 |
+
b = self.layer2(a)
|
179 |
+
def _bilinear_intepolate(x):
|
180 |
+
return F.interpolate(
|
181 |
+
x,
|
182 |
+
(H // self.stride, W // self.stride),
|
183 |
+
mode="bilinear",
|
184 |
+
align_corners=True,
|
185 |
+
)
|
186 |
+
a = _bilinear_intepolate(a)
|
187 |
+
b = _bilinear_intepolate(b)
|
188 |
+
x = self.conv2(torch.cat([a, b], dim=1))
|
189 |
+
x = self.norm2(x)
|
190 |
+
x = self.relu2(x)
|
191 |
+
x = self.conv3(x)
|
192 |
+
return x
|
193 |
+
|
194 |
+
class BasicEncoder(nn.Module):
|
195 |
+
def __init__(self, input_dim=3, output_dim=128, stride=4):
|
196 |
+
super(BasicEncoder, self).__init__()
|
197 |
+
self.stride = stride
|
198 |
+
self.norm_fn = "instance"
|
199 |
+
self.in_planes = output_dim // 2
|
200 |
+
self.norm1 = nn.InstanceNorm2d(self.in_planes)
|
201 |
+
self.norm2 = nn.InstanceNorm2d(output_dim * 2)
|
202 |
+
|
203 |
+
self.conv1 = nn.Conv2d(
|
204 |
+
input_dim,
|
205 |
+
self.in_planes,
|
206 |
+
kernel_size=7,
|
207 |
+
stride=2,
|
208 |
+
padding=3,
|
209 |
+
padding_mode="zeros",
|
210 |
+
)
|
211 |
+
self.relu1 = nn.ReLU(inplace=True)
|
212 |
+
self.layer1 = self._make_layer(output_dim // 2, stride=1)
|
213 |
+
self.layer2 = self._make_layer(output_dim // 4 * 3, stride=2)
|
214 |
+
self.layer3 = self._make_layer(output_dim, stride=2)
|
215 |
+
self.layer4 = self._make_layer(output_dim, stride=2)
|
216 |
+
|
217 |
+
self.conv2 = nn.Conv2d(
|
218 |
+
output_dim * 3 + output_dim // 4,
|
219 |
+
output_dim * 2,
|
220 |
+
kernel_size=3,
|
221 |
+
padding=1,
|
222 |
+
padding_mode="zeros",
|
223 |
+
)
|
224 |
+
self.relu2 = nn.ReLU(inplace=True)
|
225 |
+
self.conv3 = nn.Conv2d(output_dim * 2, output_dim, kernel_size=1)
|
226 |
+
for m in self.modules():
|
227 |
+
if isinstance(m, nn.Conv2d):
|
228 |
+
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
229 |
+
elif isinstance(m, (nn.InstanceNorm2d)):
|
230 |
+
if m.weight is not None:
|
231 |
+
nn.init.constant_(m.weight, 1)
|
232 |
+
if m.bias is not None:
|
233 |
+
nn.init.constant_(m.bias, 0)
|
234 |
+
|
235 |
+
def _make_layer(self, dim, stride=1):
|
236 |
+
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
237 |
+
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
238 |
+
layers = (layer1, layer2)
|
239 |
+
|
240 |
+
self.in_planes = dim
|
241 |
+
return nn.Sequential(*layers)
|
242 |
+
|
243 |
+
def forward(self, x):
|
244 |
+
_, _, H, W = x.shape
|
245 |
+
|
246 |
+
x = self.conv1(x)
|
247 |
+
x = self.norm1(x)
|
248 |
+
x = self.relu1(x)
|
249 |
+
|
250 |
+
a = self.layer1(x)
|
251 |
+
b = self.layer2(a)
|
252 |
+
c = self.layer3(b)
|
253 |
+
d = self.layer4(c)
|
254 |
+
|
255 |
+
def _bilinear_intepolate(x):
|
256 |
+
return F.interpolate(
|
257 |
+
x,
|
258 |
+
(H // self.stride, W // self.stride),
|
259 |
+
mode="bilinear",
|
260 |
+
align_corners=True,
|
261 |
+
)
|
262 |
+
|
263 |
+
a = _bilinear_intepolate(a)
|
264 |
+
b = _bilinear_intepolate(b)
|
265 |
+
c = _bilinear_intepolate(c)
|
266 |
+
d = _bilinear_intepolate(d)
|
267 |
+
|
268 |
+
x = self.conv2(torch.cat([a, b, c, d], dim=1))
|
269 |
+
x = self.norm2(x)
|
270 |
+
x = self.relu2(x)
|
271 |
+
x = self.conv3(x)
|
272 |
+
return x
|
273 |
+
|
274 |
+
# From PyTorch internals
|
275 |
+
def _ntuple(n):
|
276 |
+
def parse(x):
|
277 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
278 |
+
return tuple(x)
|
279 |
+
return tuple(repeat(x, n))
|
280 |
+
|
281 |
+
return parse
|
282 |
+
|
283 |
+
|
284 |
+
def exists(val):
|
285 |
+
return val is not None
|
286 |
+
|
287 |
+
|
288 |
+
def default(val, d):
|
289 |
+
return val if exists(val) else d
|
290 |
+
|
291 |
+
|
292 |
+
to_2tuple = _ntuple(2)
|
293 |
+
|
294 |
+
|
295 |
+
class Mlp(nn.Module):
|
296 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
297 |
+
|
298 |
+
def __init__(
|
299 |
+
self,
|
300 |
+
in_features,
|
301 |
+
hidden_features=None,
|
302 |
+
out_features=None,
|
303 |
+
act_layer=nn.GELU,
|
304 |
+
norm_layer=None,
|
305 |
+
bias=True,
|
306 |
+
drop=0.0,
|
307 |
+
use_conv=False,
|
308 |
+
):
|
309 |
+
super().__init__()
|
310 |
+
out_features = out_features or in_features
|
311 |
+
hidden_features = hidden_features or in_features
|
312 |
+
bias = to_2tuple(bias)
|
313 |
+
drop_probs = to_2tuple(drop)
|
314 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
315 |
+
|
316 |
+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
317 |
+
self.act = act_layer()
|
318 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
319 |
+
self.norm = (
|
320 |
+
norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
321 |
+
)
|
322 |
+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
323 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
324 |
+
|
325 |
+
def forward(self, x):
|
326 |
+
x = self.fc1(x)
|
327 |
+
x = self.act(x)
|
328 |
+
x = self.drop1(x)
|
329 |
+
x = self.fc2(x)
|
330 |
+
x = self.drop2(x)
|
331 |
+
return x
|
332 |
+
|
333 |
+
|
334 |
+
class Attention(nn.Module):
|
335 |
+
def __init__(
|
336 |
+
self, query_dim, context_dim=None, num_heads=8, dim_head=48, qkv_bias=False
|
337 |
+
):
|
338 |
+
super().__init__()
|
339 |
+
inner_dim = dim_head * num_heads
|
340 |
+
self.inner_dim = inner_dim
|
341 |
+
context_dim = default(context_dim, query_dim)
|
342 |
+
self.scale = dim_head**-0.5
|
343 |
+
self.heads = num_heads
|
344 |
+
|
345 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
|
346 |
+
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
|
347 |
+
self.to_out = nn.Linear(inner_dim, query_dim)
|
348 |
+
|
349 |
+
def forward(self, x, context=None, attn_bias=None, flash=True):
|
350 |
+
B, N1, C = x.shape
|
351 |
+
h = self.heads
|
352 |
+
|
353 |
+
q = self.to_q(x).reshape(B, N1, h, self.inner_dim // h).permute(0, 2, 1, 3)
|
354 |
+
context = default(context, x)
|
355 |
+
k, v = self.to_kv(context).chunk(2, dim=-1)
|
356 |
+
|
357 |
+
N2 = context.shape[1]
|
358 |
+
k = k.reshape(B, N2, h, self.inner_dim // h).permute(0, 2, 1, 3)
|
359 |
+
v = v.reshape(B, N2, h, self.inner_dim // h).permute(0, 2, 1, 3)
|
360 |
+
|
        if (
            (N1 < 64 and N2 < 64) or
            (B > 1e4) or
            (q.shape[1] != k.shape[1]) or
            (q.shape[1] % k.shape[1] != 0)
        ):
            flash = False

        if not flash:
            sim = (q @ k.transpose(-2, -1)) * self.scale
            if attn_bias is not None:
                sim = sim + attn_bias
            attn = sim.softmax(dim=-1)
            x = (attn @ v).transpose(1, 2).reshape(B, N1, self.inner_dim)
        else:
            input_args = [x.contiguous() for x in [q, k, v]]
            try:
                with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
                    x = F.scaled_dot_product_attention(*input_args).permute(0, 2, 1, 3).reshape(B, N1, -1)  # type: ignore
            except Exception as e:
                # If the flash kernel cannot handle these shapes/dtypes, fall back to the
                # default kernel selection instead of leaving x unset.
                print(e)
                x = F.scaled_dot_product_attention(*input_args).permute(0, 2, 1, 3).reshape(B, N1, -1)

        if self.to_out.bias.dtype != x.dtype:
            x = x.to(self.to_out.bias.dtype)

        return self.to_out(x)


class CrossAttnBlock(nn.Module):
    def __init__(
        self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.norm_context = nn.LayerNorm(context_dim)
        self.cross_attn = Attention(
            hidden_size,
            context_dim=context_dim,
            num_heads=num_heads,
            qkv_bias=True,
            **block_kwargs
        )

        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )

    def forward(self, x, context, mask=None):
        attn_bias = None
        if mask is not None:
            if mask.shape[1] == x.shape[1]:
                mask = mask[:, None, :, None].expand(
                    -1, self.cross_attn.heads, -1, context.shape[1]
                )
            else:
                mask = mask[:, None, None].expand(
                    -1, self.cross_attn.heads, x.shape[1], -1
                )

            max_neg_value = -torch.finfo(x.dtype).max
            attn_bias = (~mask) * max_neg_value
        x = x + self.cross_attn(
            self.norm1(x), context=self.norm_context(context), attn_bias=attn_bias
        )
        x = x + self.mlp(self.norm2(x))
        return x


class AttnBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        num_heads,
        attn_class: Callable[..., nn.Module] = Attention,
        mlp_ratio=4.0,
        **block_kwargs
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = attn_class(
            hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs
        )

        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        approx_gelu = lambda: nn.GELU(approximate="tanh")
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=mlp_hidden_dim,
            act_layer=approx_gelu,
            drop=0,
        )

    def forward(self, x, mask=None):
        attn_bias = mask
        if mask is not None:
            mask = (
                (mask[:, None] * mask[:, :, None])
                .unsqueeze(1)
                .expand(-1, self.attn.num_heads, -1, -1)
            )
            max_neg_value = -torch.finfo(x.dtype).max
            attn_bias = (~mask) * max_neg_value
        x = x + self.attn(self.norm1(x), attn_bias=attn_bias)
        x = x + self.mlp(self.norm2(x))
        return x


class EfficientUpdateFormer(nn.Module):
    """
    Transformer model that updates track estimates.
    """

    def __init__(
        self,
        space_depth=6,
        time_depth=6,
        input_dim=320,
        hidden_size=384,
        num_heads=8,
        output_dim=130,
        mlp_ratio=4.0,
        num_virtual_tracks=64,
        add_space_attn=True,
        linear_layer_for_vis_conf=False,
        patch_feat=False,
        patch_dim=128,
    ):
        super().__init__()
        self.out_channels = 2
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
        if linear_layer_for_vis_conf:
            self.flow_head = torch.nn.Linear(hidden_size, output_dim - 2, bias=True)
            self.vis_conf_head = torch.nn.Linear(hidden_size, 2, bias=True)
        else:
            self.flow_head = torch.nn.Linear(hidden_size, output_dim, bias=True)

        if not patch_feat:
            self.virual_tracks = nn.Parameter(
                torch.randn(1, num_virtual_tracks, 1, hidden_size)
            )
            self.num_virtual_tracks = num_virtual_tracks
        else:
            self.patch_proj = nn.Linear(patch_dim, hidden_size, bias=True)

        self.add_space_attn = add_space_attn
        self.linear_layer_for_vis_conf = linear_layer_for_vis_conf
        self.time_blocks = nn.ModuleList(
            [
                AttnBlock(
                    hidden_size,
                    num_heads,
                    mlp_ratio=mlp_ratio,
                    attn_class=Attention,
                )
                for _ in range(time_depth)
            ]
        )

        if add_space_attn:
            self.space_virtual_blocks = nn.ModuleList(
                [
                    AttnBlock(
                        hidden_size,
                        num_heads,
                        mlp_ratio=mlp_ratio,
                        attn_class=Attention,
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_point2virtual_blocks = nn.ModuleList(
                [
                    CrossAttnBlock(
                        hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
                    )
                    for _ in range(space_depth)
                ]
            )
            self.space_virtual2point_blocks = nn.ModuleList(
                [
                    CrossAttnBlock(
                        hidden_size, hidden_size, num_heads, mlp_ratio=mlp_ratio
                    )
                    for _ in range(space_depth)
                ]
            )
            assert len(self.time_blocks) >= len(self.space_virtual2point_blocks)
        self.initialize_weights()

    def initialize_weights(self):
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            torch.nn.init.trunc_normal_(self.flow_head.weight, std=0.001)
            if self.linear_layer_for_vis_conf:
                torch.nn.init.trunc_normal_(self.vis_conf_head.weight, std=0.001)

        def _trunc_init(module):
            """ViT weight initialization, original timm impl (for reproducibility)"""
            if isinstance(module, nn.Linear):
                torch.nn.init.trunc_normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

        self.apply(_basic_init)

    def forward(self, input_tensor, mask=None, add_space_attn=True, patch_feat=None):
        tokens = self.input_transform(input_tensor)

        B, _, T, _ = tokens.shape
        if patch_feat is None:
            virtual_tokens = self.virual_tracks.repeat(B, 1, T, 1)
            tokens = torch.cat([tokens, virtual_tokens], dim=1)
        else:
            patch_feat = self.patch_proj(patch_feat.detach())
            tokens = torch.cat([tokens, patch_feat], dim=1)
            self.num_virtual_tracks = patch_feat.shape[1]

        _, N, _, _ = tokens.shape
        j = 0
        for i in range(len(self.time_blocks)):
            time_tokens = tokens.contiguous().view(B * N, T, -1)  # B N T C -> (B N) T C
            time_tokens = torch.utils.checkpoint.checkpoint(
                self.time_blocks[i],
                time_tokens
            )

            tokens = time_tokens.view(B, N, T, -1)  # (B N) T C -> B N T C
            if (
                add_space_attn
                and hasattr(self, "space_virtual_blocks")
                and (i % (len(self.time_blocks) // len(self.space_virtual_blocks)) == 0)
            ):
                space_tokens = (
                    tokens.permute(0, 2, 1, 3).contiguous().view(B * T, N, -1)
                )  # B N T C -> (B T) N C

                point_tokens = space_tokens[:, : N - self.num_virtual_tracks]
                virtual_tokens = space_tokens[:, N - self.num_virtual_tracks :]

                virtual_tokens = torch.utils.checkpoint.checkpoint(
                    self.space_virtual2point_blocks[j],
                    virtual_tokens, point_tokens, mask
                )

                virtual_tokens = torch.utils.checkpoint.checkpoint(
                    self.space_virtual_blocks[j],
                    virtual_tokens
                )

                point_tokens = torch.utils.checkpoint.checkpoint(
                    self.space_point2virtual_blocks[j],
                    point_tokens, virtual_tokens, mask
                )

                space_tokens = torch.cat([point_tokens, virtual_tokens], dim=1)
                tokens = space_tokens.view(B, T, N, -1).permute(
                    0, 2, 1, 3
                )  # (B T) N C -> B N T C
                j += 1
        tokens = tokens[:, : N - self.num_virtual_tracks]

        flow = self.flow_head(tokens)
        if self.linear_layer_for_vis_conf:
            vis_conf = self.vis_conf_head(tokens)
            flow = torch.cat([flow, vis_conf], dim=-1)

        return flow


def focal_loss(logits, targets, alpha=0.25, gamma=2.0):
    probs = torch.sigmoid(logits)
    ce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction='none')
    p_t = probs * targets + (1 - probs) * (1 - targets)
    loss = alpha * (1 - p_t) ** gamma * ce_loss
    return loss.mean()


def balanced_binary_cross_entropy(logits, targets, balance_weight=1.0, eps=1e-6, reduction="mean", pos_bias=0.0, mask=None):
    """
    logits: tensor of arbitrary shape
    targets: same shape as logits
    balance_weight: scaling factor applied to the final loss
    reduction: 'mean', 'sum', or 'none'
    """
    targets = targets.float()
    positive = (targets == 1).float().sum()
    total = targets.numel()
    positive_ratio = positive / (total + eps)

    pos_weight = (1 - positive_ratio) / (positive_ratio + eps)
    pos_weight = pos_weight.clamp(min=0.1, max=10.0)
    loss = F.binary_cross_entropy_with_logits(
        logits,
        targets,
        pos_weight=pos_weight + pos_bias,
        reduction=reduction
    )
    if mask is not None:
        loss = (loss * mask).sum() / (mask.sum() + eps)
    return balance_weight * loss


def sequence_loss(
    flow_preds,
    flow_gt,
    valids,
    vis=None,
    gamma=0.8,
    add_huber_loss=False,
    loss_only_for_visible=False,
    depth_sample=None,
    z_unc=None,
    mask_traj_gt=None
):
    """Loss function defined over sequence of flow predictions"""
    total_flow_loss = 0.0
    for j in range(len(flow_gt)):
        B, S, N, D = flow_gt[j].shape
        B, S2, N = valids[j].shape
        assert S == S2
        n_predictions = len(flow_preds[j])
        flow_loss = 0.0
        for i in range(n_predictions):
            i_weight = gamma ** (n_predictions - i - 1)
            flow_pred = flow_preds[j][i][:, :, :flow_gt[j].shape[2]]
            track_z_loss = 0.0
            if add_huber_loss:
                i_loss = huber_loss(flow_pred, flow_gt[j], delta=6.0)
            else:
                if flow_gt[j][..., 2].abs().max() != 0:
                    track_z_loss = (flow_pred - flow_gt[j])[..., 2].abs().mean()
                    if mask_traj_gt is not None:
                        track_z_loss = ((flow_pred - flow_gt[j])[..., 2].abs() * mask_traj_gt.permute(0, 2, 1)).sum() / (mask_traj_gt.sum(dim=1) + 1e-6)
                i_loss = (flow_pred[..., :2] - flow_gt[j][..., :2]).abs()  # B, S, N, 2
            i_loss = torch.mean(i_loss, dim=3)  # B, S, N
            valid_ = valids[j].clone()[:, :, :flow_gt[j].shape[2]]  # ensure valid_ has the same shape as i_loss
            valid_ = valid_ * (flow_gt[j][..., :2].norm(dim=-1) > 0).float()
            if loss_only_for_visible:
                valid_ = valid_ * vis[j]
            flow_loss += i_weight * (reduce_masked_mean(i_loss, valid_) + track_z_loss + 10 * reduce_masked_mean(i_loss, valid_ * vis[j]))
        flow_loss = flow_loss / n_predictions
        total_flow_loss += flow_loss
    return total_flow_loss / len(flow_gt)


def sequence_loss_xyz(
    flow_preds,
    flow_gt,
    valids,
    intrs,
    vis=None,
    gamma=0.8,
    add_huber_loss=False,
    loss_only_for_visible=False,
    mask_traj_gt=None
):
    """Loss function defined over sequence of flow predictions, compared in camera space"""
    total_flow_loss = 0.0
    for j in range(len(flow_gt)):
        B, S, N, D = flow_gt[j].shape
        B, S2, N = valids[j].shape
        assert S == S2
        n_predictions = len(flow_preds[j])
        flow_loss = 0.0
        for i in range(n_predictions):
            i_weight = gamma ** (n_predictions - i - 1)
            flow_pred = flow_preds[j][i][:, :, :flow_gt[j].shape[2]]
            flow_gt_ = flow_gt[j]
            flow_gt_one = torch.cat([flow_gt_[..., :2], torch.ones_like(flow_gt_[:, :, :, :1])], dim=-1)
            flow_gt_cam = torch.einsum('btsc,btnc->btns', torch.inverse(intrs), flow_gt_one)
            flow_gt_cam *= flow_gt_[..., 2:3].abs()
            flow_gt_cam[..., 2] *= torch.sign(flow_gt_cam[..., 2])

            if add_huber_loss:
                i_loss = huber_loss(flow_pred, flow_gt_cam, delta=6.0)
            else:
                i_loss = (flow_pred - flow_gt_cam).norm(dim=-1, keepdim=True)  # B, S, N, 1

            i_loss = torch.mean(i_loss, dim=3)  # B, S, N
            valid_ = valids[j].clone()[:, :, :flow_gt[j].shape[2]]  # ensure valid_ has the same shape as i_loss
            if loss_only_for_visible:
                valid_ = valid_ * vis[j]
            flow_loss += i_weight * reduce_masked_mean(i_loss, valid_) * 1000
        flow_loss = flow_loss / n_predictions
        total_flow_loss += flow_loss
    return total_flow_loss / len(flow_gt)


def huber_loss(x, y, delta=1.0):
    """Calculate element-wise Huber loss between x and y"""
    diff = x - y
    abs_diff = diff.abs()
    flag = (abs_diff <= delta).float()
    return flag * 0.5 * diff**2 + (1 - flag) * delta * (abs_diff - 0.5 * delta)


def sequence_BCE_loss(vis_preds, vis_gts, mask=None):
    total_bce_loss = 0.0
    for j in range(len(vis_preds)):
        n_predictions = len(vis_preds[j])
        bce_loss = 0.0
        for i in range(n_predictions):
            N_gt = vis_gts[j].shape[-1]
            if mask is not None:
                vis_loss = balanced_binary_cross_entropy(vis_preds[j][i][..., :N_gt], vis_gts[j], mask=mask[j], reduction="none")
            else:
                vis_loss = balanced_binary_cross_entropy(vis_preds[j][i][..., :N_gt], vis_gts[j]) + focal_loss(vis_preds[j][i][..., :N_gt], vis_gts[j])
            bce_loss += vis_loss
        bce_loss = bce_loss / n_predictions
        total_bce_loss += bce_loss
    return total_bce_loss / len(vis_preds)


def sequence_prob_loss(
    tracks: torch.Tensor,
    confidence: torch.Tensor,
    target_points: torch.Tensor,
    visibility: torch.Tensor,
    expected_dist_thresh: float = 12.0,
):
    """Loss for classifying if a point is within pixel threshold of its target."""
    # Points with an error larger than the threshold are likely to be useless; marking
    # them as occluded will actually improve Jaccard metrics and give
    # qualitatively better results.
    total_logprob_loss = 0.0
    for j in range(len(tracks)):
        n_predictions = len(tracks[j])
        logprob_loss = 0.0
        for i in range(n_predictions):
            N_gt = target_points[j].shape[2]
            err = torch.sum((tracks[j][i].detach()[:, :, :N_gt, :2] - target_points[j][..., :2]) ** 2, dim=-1)
            valid = (err <= expected_dist_thresh**2).float()
            logprob = balanced_binary_cross_entropy(confidence[j][i][..., :N_gt], valid, reduction="none")
            logprob *= visibility[j]
            logprob = torch.mean(logprob, dim=[1, 2])
            logprob_loss += logprob
        logprob_loss = logprob_loss / n_predictions
        total_logprob_loss += logprob_loss
    return total_logprob_loss / len(tracks)


def sequence_dyn_prob_loss(
    tracks: torch.Tensor,
    confidence: torch.Tensor,
    target_points: torch.Tensor,
    visibility: torch.Tensor,
    expected_dist_thresh: float = 6.0,
):
    """Loss for classifying if a point ever stays within the distance threshold of its target."""
    # Points with an error larger than the threshold are likely to be useless; marking
    # them as occluded will actually improve Jaccard metrics and give
    # qualitatively better results.
    total_logprob_loss = 0.0
    for j in range(len(tracks)):
        n_predictions = len(tracks[j])
        logprob_loss = 0.0
        for i in range(n_predictions):
            err = torch.sum((tracks[j][i].detach() - target_points[j]) ** 2, dim=-1)
            valid = (err <= expected_dist_thresh**2).float()
            valid = (valid.sum(dim=1) > 0).float()
            logprob = balanced_binary_cross_entropy(confidence[j][i].mean(dim=1), valid, reduction="none")
            # logprob *= visibility[j]
            logprob = torch.mean(logprob, dim=[0, 1])
            logprob_loss += logprob
        logprob_loss = logprob_loss / n_predictions
        total_logprob_loss += logprob_loss
    return total_logprob_loss / len(tracks)


def masked_mean(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
    if mask is None:
        return data.mean(dim=dim, keepdim=True)
    mask = mask.float()
    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
        mask_sum, min=1.0
    )
    return mask_mean


def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
    if mask is None:
        return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
    mask = mask.float()
    mask_sum = torch.sum(mask, dim=dim, keepdim=True)
    mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
        mask_sum, min=1.0
    )
    mask_var = torch.sum(
        mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
    ) / torch.clamp(mask_sum, min=1.0)
    return mask_mean.squeeze(dim), mask_var.squeeze(dim)


class NeighborTransformer(nn.Module):
    def __init__(self, dim: int, num_heads: int, head_dim: int, mlp_ratio: float):
        super().__init__()
        self.dim = dim
        self.output_token_1 = nn.Parameter(torch.randn(1, dim))
        self.output_token_2 = nn.Parameter(torch.randn(1, dim))
        self.xblock1_2 = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
        self.xblock2_1 = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
        self.aggr1 = Attention(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim)
        self.aggr2 = Attention(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        from einops import rearrange, repeat
        import torch.utils.checkpoint as checkpoint

        assert len(x.shape) == 3, "x should be of shape (B, N, D)"
        assert len(y.shape) == 3, "y should be of shape (B, N, D)"

        # Note: this symmetric neighbor aggregation has not worked very well in practice.
        def forward_chunk(x, y):
            new_x = self.xblock1_2(x, y)
            new_y = self.xblock2_1(y, x)
            out1 = self.aggr1(repeat(self.output_token_1, 'n d -> b n d', b=x.shape[0]), context=new_x)
            out2 = self.aggr2(repeat(self.output_token_2, 'n d -> b n d', b=x.shape[0]), context=new_y)
            return out1 + out2

        return checkpoint.checkpoint(forward_chunk, x, y)


class CorrPointformer(nn.Module):
    def __init__(self, dim: int, num_heads: int, head_dim: int, mlp_ratio: float):
        super().__init__()
        self.dim = dim
        self.xblock1_2 = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
        # self.xblock2_1 = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
        self.aggr = CrossAttnBlock(dim, context_dim=dim, num_heads=num_heads, dim_head=head_dim, mlp_ratio=mlp_ratio)
        self.out_proj = nn.Linear(dim, 2 * dim)

    def forward(self, query: torch.Tensor, target: torch.Tensor, target_rel_pos: torch.Tensor) -> torch.Tensor:
        from einops import rearrange, repeat
        import torch.utils.checkpoint as checkpoint

        def forward_chunk(query, target, target_rel_pos):
            new_query = self.xblock1_2(query, target).mean(dim=1, keepdim=True)
            # new_target = self.xblock2_1(target, query).mean(dim=1, keepdim=True)
            # new_aggr = new_query + new_target
            out = self.aggr(new_query, target + target_rel_pos)  # (potential delta xyz) (target - center)
            out = self.out_proj(out)
            return out

        return checkpoint.checkpoint(forward_chunk, query, target, target_rel_pos)
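A toy example of how balanced_binary_cross_entropy re-weights rare positives (numbers are illustrative):

import torch

# Usage sketch, not part of the module: roughly 5% positives, so
# pos_weight = (1 - 0.05) / 0.05 = 19, clamped to the maximum of 10.
targets = (torch.rand(4, 250) < 0.05).float()
logits = torch.randn(4, 250)
loss = balanced_binary_cross_entropy(logits, targets)
print(loss)  # scalar loss with positives up-weighted by the clamped pos_weight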
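A minimal sketch of how sequence_loss expects its inputs to be nested (a list per sample/window, each holding one prediction per refinement iteration); it assumes reduce_masked_mean is available in this module, as the function above requires, and all sizes are illustrative:

import torch

# Usage sketch, not part of the module.
B, S, N = 1, 8, 16
flow_gt = [torch.randn(B, S, N, 3)]                         # one window of 3D targets
flow_preds = [[torch.randn(B, S, N, 3) for _ in range(4)]]  # 4 refinement iterations
valids = [torch.ones(B, S, N)]
vis = [torch.ones(B, S, N)]

loss = sequence_loss(flow_preds, flow_gt, valids, vis=vis, gamma=0.8)
print(loss)  # later iterations receive larger weights (gamma ** (n - i - 1))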
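A quick sanity check of the masked statistics helpers on a toy tensor (values are illustrative):

import torch

# Usage sketch, not part of the module: only the first two entries are counted.
data = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])

print(masked_mean(data, mask, dim=[1]))      # tensor([[1.5000]])
mean, var = masked_mean_var(data, mask, dim=[1])
print(mean, var)                             # ~1.5 and ~0.25 over the masked entries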
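A minimal usage sketch for CorrPointformer (shapes are illustrative; it assumes the Attention and CrossAttnBlock definitions above accept the dim_head keyword they are given here):

import torch

# Usage sketch, not part of the module: one query patch aggregates correlation
# features from 16 neighbouring target features.
former = CorrPointformer(dim=128, num_heads=4, head_dim=32, mlp_ratio=4.0)
query = torch.randn(2, 9, 128)             # e.g. a 3x3 patch of query features
target = torch.randn(2, 16, 128)           # neighbouring target features
target_rel_pos = torch.randn(2, 16, 128)   # encoding of target offsets from the center

out = former(query, target, target_rel_pos)
print(out.shape)  # torch.Size([2, 1, 256]) -- pooled token projected to 2*dim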