sengerchen committed
Commit 1bb1365 · verified · 1 Parent(s): e4c6df8

Upload folder using huggingface_hub

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +17 -0
  2. .gitignore +48 -0
  3. .gitmodules +3 -0
  4. .pre-commit-config.yaml +22 -0
  5. LICENSE +124 -0
  6. README.md +79 -8
  7. assets/advance/backyard-7_0.jpg +0 -0
  8. assets/advance/backyard-7_1.jpg +0 -0
  9. assets/advance/backyard-7_2.jpg +0 -0
  10. assets/advance/backyard-7_3.jpg +0 -0
  11. assets/advance/backyard-7_4.jpg +0 -0
  12. assets/advance/backyard-7_5.jpg +0 -0
  13. assets/advance/backyard-7_6.jpg +0 -0
  14. assets/advance/blue-car.jpg +3 -0
  15. assets/advance/garden-4_0.jpg +3 -0
  16. assets/advance/garden-4_1.jpg +3 -0
  17. assets/advance/garden-4_2.jpg +3 -0
  18. assets/advance/garden-4_3.jpg +3 -0
  19. assets/advance/telebooth-2_0.jpg +0 -0
  20. assets/advance/telebooth-2_1.jpg +0 -0
  21. assets/advance/vgg-lab-4_0.png +3 -0
  22. assets/advance/vgg-lab-4_1.png +3 -0
  23. assets/advance/vgg-lab-4_2.png +3 -0
  24. assets/advance/vgg-lab-4_3.png +3 -0
  25. assets/basic/blue-car.jpg +3 -0
  26. assets/basic/hilly-countryside.jpg +3 -0
  27. assets/basic/lily-dragon.png +3 -0
  28. assets/basic/llff-room.jpg +0 -0
  29. assets/basic/mountain-lake.jpg +0 -0
  30. assets/basic/vasedeck.jpg +0 -0
  31. assets/basic/vgg-lab-4_0.png +3 -0
  32. benchmark/README.md +156 -0
  33. benchmark/export_reconfusion_example.py +137 -0
  34. demo.py +407 -0
  35. demo_gr.py +1248 -0
  36. docs/CLI_USAGE.md +169 -0
  37. docs/GR_USAGE.md +76 -0
  38. docs/INSTALL.md +39 -0
  39. pyproject.toml +39 -0
  40. seva/__init__.py +0 -0
  41. seva/data_io.py +553 -0
  42. seva/eval.py +1990 -0
  43. seva/geometry.py +811 -0
  44. seva/gui.py +975 -0
  45. seva/model.py +234 -0
  46. seva/modules/__init__.py +0 -0
  47. seva/modules/autoencoder.py +51 -0
  48. seva/modules/conditioner.py +39 -0
  49. seva/modules/layers.py +139 -0
  50. seva/modules/preprocessor.py +116 -0
.gitattributes CHANGED
@@ -33,3 +33,20 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/advance/blue-car.jpg filter=lfs diff=lfs merge=lfs -text
37
+ assets/advance/garden-4_0.jpg filter=lfs diff=lfs merge=lfs -text
38
+ assets/advance/garden-4_1.jpg filter=lfs diff=lfs merge=lfs -text
39
+ assets/advance/garden-4_2.jpg filter=lfs diff=lfs merge=lfs -text
40
+ assets/advance/garden-4_3.jpg filter=lfs diff=lfs merge=lfs -text
41
+ assets/advance/vgg-lab-4_0.png filter=lfs diff=lfs merge=lfs -text
42
+ assets/advance/vgg-lab-4_1.png filter=lfs diff=lfs merge=lfs -text
43
+ assets/advance/vgg-lab-4_2.png filter=lfs diff=lfs merge=lfs -text
44
+ assets/advance/vgg-lab-4_3.png filter=lfs diff=lfs merge=lfs -text
45
+ assets/basic/blue-car.jpg filter=lfs diff=lfs merge=lfs -text
46
+ assets/basic/hilly-countryside.jpg filter=lfs diff=lfs merge=lfs -text
47
+ assets/basic/lily-dragon.png filter=lfs diff=lfs merge=lfs -text
48
+ assets/basic/vgg-lab-4_0.png filter=lfs diff=lfs merge=lfs -text
49
+ third_party/dust3r/assets/demo.jpg filter=lfs diff=lfs merge=lfs -text
50
+ third_party/dust3r/assets/matching.jpg filter=lfs diff=lfs merge=lfs -text
51
+ third_party/dust3r/croco/assets/Chateau1.png filter=lfs diff=lfs merge=lfs -text
52
+ third_party/dust3r/croco/assets/Chateau2.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,48 @@
1
+ .envrc
2
+ .venv/
3
+ .gradio/
4
+ work_dirs*
5
+
6
+ # Byte-compiled files
7
+ __pycache__/
8
+ *.py[cod]
9
+
10
+ # Virtual environments
11
+ env/
12
+ venv/
13
+ ENV/
14
+ .VENV/
15
+
16
+ # Distribution files
17
+ build/
18
+ dist/
19
+ *.egg-info/
20
+
21
+ # Logs and temporary files
22
+ *.log
23
+ *.tmp
24
+ *.bak
25
+ *.swp
26
+
27
+ # IDE files
28
+ .idea/
29
+ .vscode/
30
+ *.sublime-workspace
31
+ *.sublime-project
32
+
33
+ # OS files
34
+ .DS_Store
35
+ Thumbs.db
36
+
37
+ # Testing and coverage
38
+ htmlcov/
39
+ .coverage
40
+ *.cover
41
+ *.py,cover
42
+ .cache/
43
+
44
+ # Jupyter Notebook checkpoints
45
+ .ipynb_checkpoints/
46
+
47
+ # Pre-commit hooks
48
+ .pre-commit-config.yaml~
.gitmodules ADDED
@@ -0,0 +1,3 @@
1
+ [submodule "third_party/dust3r"]
2
+ path = third_party/dust3r
3
+ url = https://github.com/jensenstability/dust3r
.pre-commit-config.yaml ADDED
@@ -0,0 +1,22 @@
1
+ default_language_version:
2
+ python: python3
3
+ default_stages: [pre-commit]
4
+ repos:
5
+ - repo: https://github.com/pre-commit/pre-commit-hooks
6
+ rev: v5.0.0
7
+ hooks:
8
+ - id: trailing-whitespace
9
+ - id: end-of-file-fixer
10
+ - repo: https://github.com/charliermarsh/ruff-pre-commit
11
+ rev: v0.8.3
12
+ hooks:
13
+ - id: ruff
14
+ types_or: [python, pyi, jupyter]
15
+ args: [--fix, --extend-ignore=E402]
16
+ - id: ruff-format
17
+ types_or: [python, pyi, jupyter]
18
+ - repo: https://github.com/pre-commit/mirrors-prettier
19
+ rev: v3.1.0
20
+ hooks:
21
+ - id: prettier
22
+ types_or: [markdown]
LICENSE ADDED
@@ -0,0 +1,124 @@
1
+ Stability AI Non-Commercial License Agreement
2
+ Last Updated: February 20, 2025
3
+
4
+ I. INTRODUCTION
5
+
6
+ This Stability AI Non-Commercial License Agreement (the “Agreement”) applies to any individual person or entity
7
+ (“You”, “Your” or “Licensee”) that uses or distributes any portion or element of the Stability AI Materials or
8
+ Derivative Works thereof for any Research & Non-Commercial use. Capitalized terms not otherwise defined herein
9
+ are defined in Section IV below.
10
+
11
+ This Agreement is intended to allow research and non-commercial uses of the Model free of charge.
12
+
13
+ By clicking “I Accept” or by using or distributing any portion or element of the Stability Materials
14
+ or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement.
15
+
16
+ If You are acting on behalf of a company, organization, or other entity, then “You” includes you and that entity,
17
+ and You agree that You:
18
+ (i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and
19
+ (ii) You agree to the terms of this Agreement on that entity’s behalf.
20
+
21
+ ---
22
+
23
+ II. RESEARCH & NON-COMMERCIAL USE LICENSE
24
+
25
+ Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable,
26
+ non-sublicensable, revocable, and royalty-free limited license under Stability AI’s intellectual property or other
27
+ rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create
28
+ Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose.
29
+
30
+ - **“Research Purpose”** means academic or scientific advancement, and in each case, is not primarily intended
31
+ for commercial advantage or monetary compensation to You or others.
32
+ - **“Non-Commercial Purpose”** means any purpose other than a Research Purpose that is not primarily intended
33
+ for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist)
34
+ or evaluation and testing.
35
+
36
+ ---
37
+
38
+ III. GENERAL TERMS
39
+
40
+ Your Research or Non-Commercial license under this Agreement is subject to the following terms.
41
+
42
+ ### a. Distribution & Attribution
43
+ If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product
44
+ or service that uses any portion of them, You shall:
45
+ 1. Provide a copy of this Agreement to that third party.
46
+ 2. Retain the following attribution notice within a **"Notice"** text file distributed as a part of such copies:
47
+
48
+ **"This Stability AI Model is licensed under the Stability AI Non-Commercial License,
49
+ Copyright © Stability AI Ltd. All Rights Reserved."**
50
+
51
+ 3. Prominently display **“Powered by Stability AI”** on a related website, user interface, blog post,
52
+ about page, or product documentation.
53
+ 4. If You create a Derivative Work, You may add your own attribution notice(s) to the **"Notice"** text file
54
+ included with that Derivative Work, provided that You clearly indicate which attributions apply to the
55
+ Stability AI Materials and state in the **"Notice"** text file that You changed the Stability AI Materials
56
+ and how it was modified.
57
+
58
+ ### b. Use Restrictions
59
+ Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability
60
+ AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control
61
+ Laws and equivalent regulations) and adhere to the Documentation and Stability AI’s AUP, which is hereby
62
+ incorporated by reference.
63
+
64
+ Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the
65
+ Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model
66
+ (excluding the Model or Derivative Works).
67
+
68
+ ### c. Intellectual Property
69
+
70
+ #### (i) Trademark License
71
+ No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials
72
+ or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of
73
+ its Affiliates, except as required under Section IV(a) herein.
74
+
75
+ #### (ii) Ownership of Derivative Works
76
+ As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI’s
77
+ ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
78
+
79
+ #### (iii) Ownership of Outputs
80
+ As between You and Stability AI, You own any outputs generated from the Model or Derivative Works to the extent
81
+ permitted by applicable law.
82
+
83
+ #### (iv) Disputes
84
+ If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a
85
+ cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works, or
86
+ associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual
87
+ property or other rights owned or licensable by You, then any licenses granted to You under this Agreement
88
+ shall terminate as of the date such litigation or claim is filed or instituted.
89
+
90
+ You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out
91
+ of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of
92
+ this Agreement.
93
+
94
+ #### (v) Feedback
95
+ From time to time, You may provide Stability AI with verbal and/or written suggestions, comments, or other
96
+ feedback related to Stability AI’s existing or prospective technology, products, or services (collectively,
97
+ “Feedback”).
98
+
99
+ You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant
100
+ Stability AI a **perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive,
101
+ worldwide right and license** to exploit the Feedback in any manner without restriction.
102
+
103
+ Your Feedback is provided **“AS IS”** and You make no warranties whatsoever about any Feedback.
104
+
105
+ ---
106
+
107
+ IV. DEFINITIONS
108
+
109
+ - **“Affiliate(s)”** means any entity that directly or indirectly controls, is controlled by, or is under common
110
+ control with the subject entity. For purposes of this definition, “control” means direct or indirect ownership
111
+ or control of more than 50% of the voting interests of the subject entity.
112
+ - **“AUP”** means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may
113
+ be updated from time to time.
114
+ - **"Derivative Work(s)"** means:
115
+ (a) Any derivative work of the Stability AI Materials as recognized by U.S. copyright laws.
116
+ (b) Any modifications to a Model, and any other model created which is based on or derived from the Model or
117
+ the Model’s output, including **fine-tune** and **low-rank adaptation** models derived from a Model or
118
+ a Model’s output, but does not include the output of any Model.
119
+ - **“Model”** means Stability AI’s Stable Virtual Camera model.
120
+ - **"Stability AI" or "we"** means Stability AI Ltd. and its Affiliates.
121
+ - **"Software"** means Stability AI’s proprietary software made available under this Agreement now or in the future.
122
+ - **“Stability AI Materials”** means, collectively, Stability’s proprietary Model, Software, and Documentation
123
+ (and any portion or combination thereof) made available under this Agreement.
124
+ - **“Trade Control Laws”** means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
README.md CHANGED
@@ -1,12 +1,83 @@
1
  ---
2
- title: Stable Virtual Camera
3
- emoji: 💻
4
- colorFrom: green
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 5.22.0
8
- app_file: app.py
9
- pinned: false
10
  ---
 
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
  ---
2
+ title: stable-virtual-camera
3
+ app_file: demo_gr.py
4
  sdk: gradio
5
+ sdk_version: 5.20.1
6
  ---
7
+ # Stable Virtual Camera
8
 
9
+ <a href="https://stable-virtual-camera.github.io"><img src="https://img.shields.io/badge/%F0%9F%8F%A0%20Project%20Page-gray.svg"></a>
10
+ <a href="http://arxiv.org/abs/2503.14489"><img src="https://img.shields.io/badge/%F0%9F%93%84%20arXiv-2503.14489-B31B1B.svg"></a>
11
+ <a href="https://stability.ai/news/introducing-stable-virtual-camera-multi-view-video-generation-with-3d-camera-control"><img src="https://img.shields.io/badge/%F0%9F%93%83%20Blog-Stability%20AI-orange.svg"></a>
12
+ <a href="https://huggingface.co/stabilityai/stable-virtual-camera"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a>
13
+ <a href="https://huggingface.co/spaces/stabilityai/stable-virtual-camera"><img src="https://img.shields.io/badge/%F0%9F%9A%80%20Gradio%20Demo-Huggingface-orange"></a>
14
+ <a href="https://www.youtube.com/channel/UCLLlVDcS7nNenT_zzO3OPxQ"><img src="https://img.shields.io/badge/%F0%9F%8E%AC%20Video-YouTube-orange"></a>
15
+
16
+ `Stable Virtual Camera (Seva)` is a 1.3B generalist diffusion model for Novel View Synthesis (NVS), generating 3D consistent novel views of a scene, given any number of input views and target cameras.
17
+
18
+ # :tada: News
19
+
20
+ - March 2025 - `Stable Virtual Camera` is out everywhere.
21
+
22
+ # :wrench: Installation
23
+
24
+ ```bash
25
+ git clone --recursive https://github.com/Stability-AI/stable-virtual-camera
26
+ cd stable-virtual-camera
27
+ pip install -e .
28
+ ```
29
+
30
+ Please note that you will need `python>=3.10` and `torch>=2.6.0`.
31
+
32
+ Check [INSTALL.md](docs/INSTALL.md) for other dependencies if you want to use our demos or develop from this repo.
33
+ For Windows users, please use WSL, as flash attention isn't supported on native Windows [yet](https://github.com/pytorch/pytorch/issues/108175).
34
+
35
+ # :open_book: Usage
36
+
37
+ You need to authenticate with Hugging Face to download our model weights. Once set up, our code will handle the download automatically on your first run. You can authenticate by running
38
+
39
+ ```bash
40
+ # This will prompt you to enter your Hugging Face credentials.
41
+ huggingface-cli login
42
+ ```
43
+
44
+ Once authenticated, go to our model card [here](https://huggingface.co/stabilityai/stable-virtual-camera) and enter your information for access.
45
+
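If you prefer non-interactive authentication (e.g., on a remote machine or in CI), `huggingface_hub` also accepts a token programmatically; this is an alternative to the command above, not part of the original instructions. A minimal sketch, assuming your access token is exported as `HF_TOKEN`:

```python
# Minimal sketch: non-interactive Hugging Face authentication.
# Assumes an access token is stored in the HF_TOKEN environment variable.
import os

from huggingface_hub import login

login(token=os.environ["HF_TOKEN"])  # same effect as `huggingface-cli login`
```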
46
+ We provide two demos for interacting with `Stable Virtual Camera`.
47
+
48
+ ### :rocket: Gradio demo
49
+
50
+ This Gradio demo is a GUI that requires no specialized knowledge and is suitable for general users. Simply run
51
+
52
+ ```bash
53
+ python demo_gr.py
54
+ ```
55
+
56
+ For a more detailed guide, follow [GR_USAGE.md](docs/GR_USAGE.md).
57
+
58
+ ### :computer: CLI demo
59
+
60
+ This CLI demo lets you pass in more options and control the model in a fine-grained way, making it suitable for power users and academic researchers. An example command looks as simple as
61
+
62
+ ```bash
63
+ python demo.py --data_path <data_path> [additional arguments]
64
+ ```
65
+
66
+ For a more detailed guide, follow [CLI_USAGE.md](docs/CLI_USAGE.md).
67
+
68
+ For users interested in benchmarking NVS models from the command line, check [`benchmark`](benchmark/), which contains details about the scenes, splits, and input/target views reported in the <a href="http://arxiv.org/abs/2503.14489">paper</a>.
69
+
70
+ # :books: Citing
71
+
72
+ If you find this repository useful, please consider giving it a star :star: and a citation.
73
+
74
+ ```
75
+ @article{zhou2025stable,
76
+ title={Stable Virtual Camera: Generative View Synthesis with Diffusion Models},
77
+ author={Jensen (Jinghao) Zhou and Hang Gao and Vikram Voleti and Aaryaman Vasishta and Chun-Han Yao and Mark Boss and
78
+ Philip Torr and Christian Rupprecht and Varun Jampani
79
+ },
80
+ journal={arXiv preprint},
81
+ year={2025}
82
+ }
83
+ ```
assets/advance/backyard-7_0.jpg ADDED
assets/advance/backyard-7_1.jpg ADDED
assets/advance/backyard-7_2.jpg ADDED
assets/advance/backyard-7_3.jpg ADDED
assets/advance/backyard-7_4.jpg ADDED
assets/advance/backyard-7_5.jpg ADDED
assets/advance/backyard-7_6.jpg ADDED
assets/advance/blue-car.jpg ADDED

Git LFS Details

  • SHA256: 0cf493d0f738830223949fd24bb3ab0a1c078804fdb744efa95a1fdcfcfb5332
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB
assets/advance/garden-4_0.jpg ADDED

Git LFS Details

  • SHA256: 38fbe78f699fc84a1f4268ef8bacef9ddacfd32e9eb8fbcb605e46cfd52b988e
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
assets/advance/garden-4_1.jpg ADDED

Git LFS Details

  • SHA256: 1975effeffc9b2011a28f6eb04d1b0bd2f37f765c194249c95e6b3783d698a42
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
assets/advance/garden-4_2.jpg ADDED

Git LFS Details

  • SHA256: 4112ff5f2ceaa3b469bb402853e7cde10396f858e5a2ceba93b095e1e3d8d335
  • Pointer size: 132 Bytes
  • Size of remote file: 1.04 MB
assets/advance/garden-4_3.jpg ADDED

Git LFS Details

  • SHA256: a750b648c389f78f2f6b26d78f753eace13a41d355f725850c2667f864f709cd
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
assets/advance/telebooth-2_0.jpg ADDED
assets/advance/telebooth-2_1.jpg ADDED
assets/advance/vgg-lab-4_0.png ADDED

Git LFS Details

  • SHA256: d1442eb509af02273cf7168f5212b3221142df4db99991b38395f42f8b239960
  • Pointer size: 131 Bytes
  • Size of remote file: 412 kB
assets/advance/vgg-lab-4_1.png ADDED

Git LFS Details

  • SHA256: c2bb10b9574247ceb0948aa00afea588f001f0271f51908b8132d63587fc43d0
  • Pointer size: 131 Bytes
  • Size of remote file: 443 kB
assets/advance/vgg-lab-4_2.png ADDED

Git LFS Details

  • SHA256: 7fa884bb6d783fd9385bd38042f3461f430bec8311e7b2171474b6a906538030
  • Pointer size: 131 Bytes
  • Size of remote file: 410 kB
assets/advance/vgg-lab-4_3.png ADDED

Git LFS Details

  • SHA256: 99469f816604c92c9c27a7cff119cb3649d3dfa4c41dcef89525b7b3cbd885a4
  • Pointer size: 131 Bytes
  • Size of remote file: 475 kB
assets/basic/blue-car.jpg ADDED

Git LFS Details

  • SHA256: 0cf493d0f738830223949fd24bb3ab0a1c078804fdb744efa95a1fdcfcfb5332
  • Pointer size: 131 Bytes
  • Size of remote file: 106 kB
assets/basic/hilly-countryside.jpg ADDED

Git LFS Details

  • SHA256: 4ae3b8cb5d989b62ceaf4930afea55790048657fa459f383f8bd809b3bdcfca0
  • Pointer size: 131 Bytes
  • Size of remote file: 107 kB
assets/basic/lily-dragon.png ADDED

Git LFS Details

  • SHA256: c545057ee2feeced73566f708311bf758350ef0ded844d7bd438e48fca7f5bd2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.57 MB
assets/basic/llff-room.jpg ADDED
assets/basic/mountain-lake.jpg ADDED
assets/basic/vasedeck.jpg ADDED
assets/basic/vgg-lab-4_0.png ADDED

Git LFS Details

  • SHA256: d1442eb509af02273cf7168f5212b3221142df4db99991b38395f42f8b239960
  • Pointer size: 131 Bytes
  • Size of remote file: 412 kB
benchmark/README.md ADDED
@@ -0,0 +1,156 @@
1
+ # :bar_chart: Benchmark
2
+
3
+ We provide <a href="https://github.com/Stability-AI/stable-virtual-camera/releases/tag/benchmark">in this release</a> (`benchmark.zip`) the following 17 entries as a benchmark to evaluate NVS models.
4
+ We hope this will help standardize the evaluation of NVS models and facilitate fair comparison between different methods.
5
+
6
+ <table>
7
+ <thead>
8
+ <tr>
9
+ <th align="center">Dataset</th>
10
+ <th align="center">Split</th>
11
+ <th align="center">Path</th>
12
+ <th align="center">Content</th>
13
+ <th align="center">Image Preprocessing</th>
14
+ <th align="center">Image Postprocessing</th>
15
+ </tr>
16
+ </thead>
17
+ <tbody>
18
+ <tr>
19
+ <td align="center">OmniObject3D</td>
20
+ <td align="center"><code>S</code> (SV3D), <code>O</code> (Ours) </td>
21
+ <td align="center"><code>omniobject3d</code></td>
22
+ <td align="center"><code>train_test_split_*.json</code></td>
23
+ <td align="center">center crop to 576</td>
24
+ <td align="center">\</td>
25
+ </tr>
26
+ <tr>
27
+ <td align="center">GSO</td>
28
+ <td align="center"><code>S</code> (SV3D), <code>O</code> (Ours) </td>
29
+ <td align="center"><code>gso</code></td>
30
+ <td align="center"><code>train_test_split_*.json</code></td>
31
+ <td align="center">center crop to 576</td>
32
+ <td align="center">\</td>
33
+ </tr>
34
+ <tr>
35
+ <td align="center" rowspan="4">RealEstate10K</td>
36
+ <td align="center"><code>D</code> (4DiM) </td>
37
+ <td align="center"><code>re10k-4dim</code></td>
38
+ <td align="center"><code>train_test_split_*.json</code></td>
39
+ <td align="center">center crop to 576</td>
40
+ <td align="center">resize to 256</td>
41
+ </tr>
42
+ <tr>
43
+ <td align="center"><code>R</code> (ReconFusion) </td>
44
+ <td align="center"><code>re10k</code></td>
45
+ <td align="center"><code>train_test_split_*.json</code></td>
46
+ <td align="center">center crop to 576</td>
47
+ <td align="center">\</td>
48
+ </tr>
49
+ <tr>
50
+ <td align="center"><code>P</code> (pixelSplat) </td>
51
+ <td align="center"><code>re10k-pixelsplat</code></td>
52
+ <td align="center"><code>train_test_split_*.json</code></td>
53
+ <td align="center">center crop to 576</td>
54
+ <td align="center">resize to 256</td>
55
+ </tr>
56
+ <tr>
57
+ <td align="center"><code>V</code> (ViewCrafter) </td>
58
+ <td align="center"><code>re10k-viewcrafter</code></td>
59
+ <td align="center"><code>images/*.png</code>,<code>transforms.json</code>,<code>train_test_split_*.json</code></td>
60
+ <td align="center">resize the shortest side to 576 (<code>--L_short 576</code>)</td>
61
+ <td align="center">center crop</td>
62
+ </tr>
63
+ <tr>
64
+ <td align="center">LLFF</td>
65
+ <td align="center"><code>R</code> (ReconFusion) </td>
66
+ <td align="center"><code>llff</code></td>
67
+ <td align="center"><code>train_test_split_*.json</code></td>
68
+ <td align="center">center crop to 576</td>
69
+ <td align="center">\</td>
70
+ </tr>
71
+ <tr>
72
+ <td align="center">DTU</td>
73
+ <td align="center"><code>R</code> (ReconFusion) </td>
74
+ <td align="center"><code>dtu</code></td>
75
+ <td align="center"><code>train_test_split_*.json</code></td>
76
+ <td align="center">center crop to 576</td>
77
+ <td align="center">\</td>
78
+ </tr>
79
+ <tr>
80
+ <td align="center" rowspan="2">CO3D</td>
81
+ <td align="center"><code>R</code> (ReconFusion) </td>
82
+ <td align="center"><code>co3d</code></td>
83
+ <td align="center"><code>train_test_split_*.json</code></td>
84
+ <td align="center">center crop to 576</td>
85
+ <td align="center">\</td>
86
+ </tr>
87
+ <tr>
88
+ <td align="center"><code>V</code> (ViewCrafter) </td>
89
+ <td align="center"><code>co3d-viewcrafter</code></td>
90
+ <td align="center"><code>images/*.png</code>,<code>transforms.json</code>,<code>train_test_split_*.json</code></td>
91
+ <td align="center">resize the shortest side to 576 (<code>--L_short 576</code>)</td>
92
+ <td align="center">center crop</td>
93
+ </tr>
94
+ <tr>
95
+ <td align="center" rowspan="2" >WildRGB-D</td>
96
+ <td align="center"><code>Oₑ</code> (Ours, easy) </td>
97
+ <td align="center"><code>wildgbd/easy</code></td>
98
+ <td align="center"><code>train_test_split_*.json</code></td>
99
+ <td align="center">center crop to 576</td>
100
+ <td align="center">\</td>
101
+ </tr>
102
+ <tr>
103
+ <td align="center"><code>Oₕ</code> (Ours, hard) </td>
104
+ <td align="center"><code>wildgbd/hard</code></td>
105
+ <td align="center"><code>train_test_split_*.json</code></td>
106
+ <td align="center">center crop to 576</td>
107
+ <td align="center">\</td>
108
+ </tr>
109
+ <tr>
110
+ <td align="center">Mip-NeRF360</td>
111
+ <td align="center"><code>R</code> (ReconFusion) </td>
112
+ <td align="center"><code>mipnerf360</code></td>
113
+ <td align="center"><code>train_test_split_*.json</code></td>
114
+ <td align="center">center crop to 576</td>
115
+ <td align="center">\</td>
116
+ </tr>
117
+ <tr>
118
+ <td align="center" rowspan="2">DL3DV-140</td>
119
+ <td align="center"><code>O</code> (Ours) </td>
120
+ <td align="center"><code>dl3dv10</code></td>
121
+ <td align="center"><code>train_test_split_*.json</code></td>
122
+ <td align="center">center crop to 576</td>
123
+ <td align="center">\</td>
124
+ </tr>
125
+ <tr>
126
+ <td align="center"><code>L</code> (Long-LRM) </td>
127
+ <td align="center"><code>dl3dv140</code></td>
128
+ <td align="center"><code>train_test_split_*.json</code></td>
129
+ <td align="center">center crop to 576</td>
130
+ <td align="center">\</td>
131
+ </tr>
132
+ <tr>
133
+ <td align="center" rowspan="2">Tanks and Temples</td>
134
+ <td align="center"><code>V</code> (ViewCrafter) </td>
135
+ <td align="center"><code>tnt-viewcrafter</code></td>
136
+ <td align="center"><code>images/*.png</code>,<code>transforms.json</code>,<code>train_test_split_*.json</code></td>
137
+ <td align="center">resize the shortest side to 576 (<code>--L_short 576</code>)</td>
138
+ <td align="center">center crop</td>
139
+ </tr>
140
+ <tr>
141
+ <td align="center"><code>L</code> (Long-LRM) </td>
142
+ <td align="center"><code>tnt-longlrm</code></td>
143
+ <td align="center"><code>train_test_split_*.json</code></td>
144
+ <td align="center">center crop to 576</td>
145
+ <td align="center">\</td>
146
+ </tr>
147
+ </tbody>
148
+ </table>
149
+
150
+ - For entries without `images/*.png` and `transforms.json`, we use the images from the original dataset after converting them into the `reconfusion` format, which is then parsable by `ReconfusionParser` (`seva/data_io.py`).
151
+ Please note that during this conversion, you should sort the images by `sorted(image_paths)` so that the sorted list is directly indexable by our train/test ids (see the sketch after this list). We provide in `benchmark/export_reconfusion_example.py` an example script converting an existing academic dataset into the scene folders.
152
+ - For evaluation and benchmarking, we first apply the operations in the `Image Preprocessing` column to the model input and then the operations in the `Image Postprocessing` column to the model output. The final processed samples are used for metric computation.
153
+
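For illustration, here is a minimal sketch of consuming a converted scene folder. It assumes the layout described above (an `images/` directory plus a `train_test_split_*.json` containing `train_ids`/`test_ids`, as written by `benchmark/export_reconfusion_example.py`); the scene path and split filename below are placeholders.

```python
# Minimal sketch (hypothetical paths): index sorted images by train/test ids.
import glob
import json
import os

scene_dir = "/path/to/scene"  # placeholder scene folder
split_path = os.path.join(scene_dir, "train_test_split_3.json")  # e.g. 3 input views

with open(split_path) as f:
    split = json.load(f)

# Images must be sorted so that the split ids index them consistently.
image_paths = sorted(glob.glob(os.path.join(scene_dir, "images", "*.png")))

input_paths = [image_paths[i] for i in split["train_ids"]]
target_paths = [image_paths[i] for i in split["test_ids"]]
print(f"{len(input_paths)} input views, {len(target_paths)} target views")
```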
154
+ ## Acknowledgment
155
+
156
+ We would like to thank Wangbo Yu, Aleksander Hołyński, Saurabh Saxena, and Ziwen Chen for their kind clarification on experiment settings.
benchmark/export_reconfusion_example.py ADDED
@@ -0,0 +1,137 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+ try:
9
+ from sklearn.cluster import KMeans # type: ignore[import]
10
+ except ImportError:
11
+ print("Please install scikit-learn to use this script.")
12
+ exit(1)
13
+
14
+ # Define the folder containing the image and JSON files
15
+ subfolder = "/path/to/your/dataset"
16
+ output_file = os.path.join(subfolder, "transforms.json")
17
+
18
+ # List to hold the frames
19
+ frames = []
20
+
21
+ # Iterate over the files in the folder
22
+ for file in sorted(os.listdir(subfolder)):
23
+ if file.endswith(".json"):
24
+ # Read the JSON file containing camera extrinsics and intrinsics
25
+ json_path = os.path.join(subfolder, file)
26
+ with open(json_path, "r") as f:
27
+ data = json.load(f)
28
+
29
+ # Read the corresponding image file
30
+ image_file = file.replace(".json", ".png")
31
+ image_path = os.path.join(subfolder, image_file)
32
+ if not os.path.exists(image_path):
33
+ print(f"Image file not found for {file}, skipping...")
34
+ continue
35
+ with Image.open(image_path) as img:
36
+ w, h = img.size
37
+
38
+ # Extract and normalize intrinsic matrix K
39
+ K = data["K"]
40
+ fx = K[0][0] * w
41
+ fy = K[1][1] * h
42
+ cx = K[0][2] * w
43
+ cy = K[1][2] * h
44
+
45
+ # Extract the transformation matrix
46
+ transform_matrix = np.array(data["c2w"])
47
+ # Adjust for OpenGL convention
48
+ transform_matrix[..., [1, 2]] *= -1
49
+
50
+ # Add the frame data
51
+ frames.append(
52
+ {
53
+ "fl_x": fx,
54
+ "fl_y": fy,
55
+ "cx": cx,
56
+ "cy": cy,
57
+ "w": w,
58
+ "h": h,
59
+ "file_path": f"./{os.path.relpath(image_path, subfolder)}",
60
+ "transform_matrix": transform_matrix.tolist(),
61
+ }
62
+ )
63
+
64
+ # Create the output dictionary
65
+ transforms_data = {"orientation_override": "none", "frames": frames}
66
+
67
+ # Write to the transforms.json file
68
+ with open(output_file, "w") as f:
69
+ json.dump(transforms_data, f, indent=4)
70
+
71
+ print(f"transforms.json generated at {output_file}")
72
+
73
+
74
+ # Train-test split function using K-means clustering with stride
75
+ def create_train_test_split(frames, n, output_path, stride):
76
+ # Prepare the data for K-means
77
+ positions = []
78
+ for frame in frames:
79
+ transform_matrix = np.array(frame["transform_matrix"])
80
+ position = transform_matrix[:3, 3] # 3D camera position
81
+ direction = transform_matrix[:3, 2] / np.linalg.norm(
82
+ transform_matrix[:3, 2]
83
+ ) # Normalized 3D direction
84
+ positions.append(np.concatenate([position, direction]))
85
+
86
+ positions = np.array(positions)
87
+
88
+ # Apply K-means clustering
89
+ kmeans = KMeans(n_clusters=n, random_state=42)
90
+ kmeans.fit(positions)
91
+ centers = kmeans.cluster_centers_
92
+
93
+ # Find the index closest to each cluster center
94
+ train_ids = []
95
+ for center in centers:
96
+ distances = np.linalg.norm(positions - center, axis=1)
97
+ train_ids.append(int(np.argmin(distances))) # Convert to Python int
98
+
99
+ # Remaining indices as test_ids, applying stride
100
+ all_indices = set(range(len(frames)))
101
+ remaining_indices = sorted(all_indices - set(train_ids))
102
+ test_ids = [
103
+ int(idx) for idx in remaining_indices[::stride]
104
+ ] # Convert to Python int
105
+
106
+ # Create the split data
107
+ split_data = {"train_ids": sorted(train_ids), "test_ids": test_ids}
108
+
109
+ with open(output_path, "w") as f:
110
+ json.dump(split_data, f, indent=4)
111
+
112
+ print(f"Train-test split file generated at {output_path}")
113
+
114
+
115
+ # Parse arguments
116
+ if __name__ == "__main__":
117
+ parser = argparse.ArgumentParser(
118
+ description="Generate train-test split JSON file using K-means clustering."
119
+ )
120
+ parser.add_argument(
121
+ "--n",
122
+ type=int,
123
+ required=True,
124
+ help="Number of frames to include in the training set.",
125
+ )
126
+ parser.add_argument(
127
+ "--stride",
128
+ type=int,
129
+ default=1,
130
+ help="Stride for selecting test frames from the remaining non-train frames (does not affect the K-means train selection).",
131
+ )
132
+
133
+ args = parser.parse_args()
134
+
135
+ # Create train-test split
136
+ train_test_split_path = os.path.join(subfolder, f"train_test_split_{args.n}.json")
137
+ create_train_test_split(frames, args.n, train_test_split_path, args.stride)
demo.py ADDED
@@ -0,0 +1,407 @@
1
+ import glob
2
+ import os
3
+ import os.path as osp
4
+
5
+ import fire
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+
12
+ from seva.data_io import get_parser
13
+ from seva.eval import (
14
+ IS_TORCH_NIGHTLY,
15
+ compute_relative_inds,
16
+ create_transforms_simple,
17
+ infer_prior_inds,
18
+ infer_prior_stats,
19
+ run_one_scene,
20
+ )
21
+ from seva.geometry import (
22
+ generate_interpolated_path,
23
+ generate_spiral_path,
24
+ get_arc_horizontal_w2cs,
25
+ get_default_intrinsics,
26
+ get_lookat,
27
+ get_preset_pose_fov,
28
+ )
29
+ from seva.model import SGMWrapper
30
+ from seva.modules.autoencoder import AutoEncoder
31
+ from seva.modules.conditioner import CLIPConditioner
32
+ from seva.sampling import DDPMDiscretization, DiscreteDenoiser
33
+ from seva.utils import load_model
34
+
35
+ device = "cuda:0"
36
+
37
+
38
+ # Constants.
39
+ WORK_DIR = "work_dirs/demo"
40
+
41
+ if IS_TORCH_NIGHTLY:
42
+ COMPILE = True
43
+ os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
44
+ os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
45
+ else:
46
+ COMPILE = False
47
+
48
+ MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
49
+ AE = AutoEncoder(chunk_size=1).to(device)
50
+ CONDITIONER = CLIPConditioner().to(device)
51
+ DISCRETIZATION = DDPMDiscretization()
52
+ DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
53
+ VERSION_DICT = {
54
+ "H": 576,
55
+ "W": 576,
56
+ "T": 21,
57
+ "C": 4,
58
+ "f": 8,
59
+ "options": {},
60
+ }
61
+
62
+ if COMPILE:
63
+ MODEL = torch.compile(MODEL, dynamic=False)
64
+ CONDITIONER = torch.compile(CONDITIONER, dynamic=False)
65
+ AE = torch.compile(AE, dynamic=False)
66
+
67
+
68
+ def parse_task(
69
+ task,
70
+ scene,
71
+ num_inputs,
72
+ T,
73
+ version_dict,
74
+ ):
75
+ options = version_dict["options"]
76
+
77
+ anchor_indices = None
78
+ anchor_c2ws = None
79
+ anchor_Ks = None
80
+
81
+ if task == "img2trajvid_s-prob":
82
+ if num_inputs is not None:
83
+ assert (
84
+ num_inputs == 1
85
+ ), "Task `img2trajvid_s-prob` only support 1-view conditioning..."
86
+ else:
87
+ num_inputs = 1
88
+ num_targets = options.get("num_targets", T - 1)
89
+ num_anchors = infer_prior_stats(
90
+ T,
91
+ num_inputs,
92
+ num_total_frames=num_targets,
93
+ version_dict=version_dict,
94
+ )
95
+
96
+ input_indices = [0]
97
+ anchor_indices = np.linspace(1, num_targets, num_anchors).tolist()
98
+
99
+ all_imgs_path = [scene] + [None] * num_targets
100
+
101
+ c2ws, fovs = get_preset_pose_fov(
102
+ option=options.get("traj_prior", "orbit"),
103
+ num_frames=num_targets + 1,
104
+ start_w2c=torch.eye(4),
105
+ look_at=torch.Tensor([0, 0, 10]),
106
+ )
107
+
108
+ with Image.open(scene) as img:
109
+ W, H = img.size
110
+ aspect_ratio = W / H
111
+ Ks = get_default_intrinsics(fovs, aspect_ratio=aspect_ratio) # normalized
112
+ Ks[:, :2] *= (
113
+ torch.tensor([W, H]).reshape(1, -1, 1).repeat(Ks.shape[0], 1, 1)
114
+ ) # unnormalized (pixel units)
115
+ Ks = Ks.numpy()
116
+
117
+ anchor_c2ws = c2ws[[round(ind) for ind in anchor_indices]]
118
+ anchor_Ks = Ks[[round(ind) for ind in anchor_indices]]
119
+
120
+ else:
121
+ parser = get_parser(
122
+ parser_type="reconfusion",
123
+ data_dir=scene,
124
+ normalize=False,
125
+ )
126
+ all_imgs_path = parser.image_paths
127
+ c2ws = parser.camtoworlds
128
+ camera_ids = parser.camera_ids
129
+ Ks = np.concatenate([parser.Ks_dict[cam_id][None] for cam_id in camera_ids], 0)
130
+
131
+ if num_inputs is None:
132
+ assert len(parser.splits_per_num_input_frames.keys()) == 1
133
+ num_inputs = list(parser.splits_per_num_input_frames.keys())[0]
134
+ split_dict = parser.splits_per_num_input_frames[num_inputs] # type: ignore
135
+ elif isinstance(num_inputs, str):
136
+ split_dict = parser.splits_per_num_input_frames[num_inputs] # type: ignore
137
+ num_inputs = int(num_inputs.split("-")[0]) # for example 1_from32
138
+ else:
139
+ split_dict = parser.splits_per_num_input_frames[num_inputs] # type: ignore
140
+
141
+ num_targets = len(split_dict["test_ids"])
142
+
143
+ if task == "img2img":
144
+ # Note: in this setting, we should refrain from using any camera info other
145
+ # than that from sampled_indices, and most importantly, preserve its order.
146
+ num_anchors = infer_prior_stats(
147
+ T,
148
+ num_inputs,
149
+ num_total_frames=num_targets,
150
+ version_dict=version_dict,
151
+ )
152
+
153
+ sampled_indices = np.sort(
154
+ np.array(split_dict["train_ids"] + split_dict["test_ids"])
155
+ ) # we always sort all indices first
156
+
157
+ traj_prior = options.get("traj_prior", None)
158
+ if traj_prior == "spiral":
159
+ assert parser.bounds is not None
160
+ anchor_c2ws = generate_spiral_path(
161
+ c2ws[sampled_indices] @ np.diagflat([1, -1, -1, 1]),
162
+ parser.bounds[sampled_indices],
163
+ n_frames=num_anchors + 1,
164
+ n_rots=2,
165
+ zrate=0.5,
166
+ endpoint=False,
167
+ )[1:] @ np.diagflat([1, -1, -1, 1])
168
+ elif traj_prior == "interpolated":
169
+ assert num_inputs > 1
170
+ anchor_c2ws = generate_interpolated_path(
171
+ c2ws[split_dict["train_ids"], :3],
172
+ round((num_anchors + 1) / (num_inputs - 1)),
173
+ endpoint=False,
174
+ )[1 : num_anchors + 1]
175
+ elif traj_prior == "orbit":
176
+ c2ws_th = torch.as_tensor(c2ws)
177
+ lookat = get_lookat(
178
+ c2ws_th[sampled_indices, :3, 3],
179
+ c2ws_th[sampled_indices, :3, 2],
180
+ )
181
+ anchor_c2ws = torch.linalg.inv(
182
+ get_arc_horizontal_w2cs(
183
+ torch.linalg.inv(c2ws_th[split_dict["train_ids"][0]]),
184
+ lookat,
185
+ -F.normalize(
186
+ c2ws_th[split_dict["train_ids"]][:, :3, 1].mean(0),
187
+ dim=-1,
188
+ ),
189
+ num_frames=num_anchors + 1,
190
+ endpoint=False,
191
+ )
192
+ ).numpy()[1:, :3]
193
+ else:
194
+ anchor_c2ws = None
195
+ # anchor_Ks is default to be the first from target_Ks
196
+
197
+ all_imgs_path = [all_imgs_path[i] for i in sampled_indices]
198
+ c2ws = c2ws[sampled_indices]
199
+ Ks = Ks[sampled_indices]
200
+
201
+ # absolute to relative indices
202
+ input_indices = compute_relative_inds(
203
+ sampled_indices,
204
+ np.array(split_dict["train_ids"]),
205
+ )
206
+ anchor_indices = np.arange(
207
+ sampled_indices.shape[0],
208
+ sampled_indices.shape[0] + num_anchors,
209
+ ).tolist() # the order has no meaning here
210
+
211
+ elif task == "img2vid":
212
+ num_targets = len(all_imgs_path) - num_inputs
213
+ num_anchors = infer_prior_stats(
214
+ T,
215
+ num_inputs,
216
+ num_total_frames=num_targets,
217
+ version_dict=version_dict,
218
+ )
219
+
220
+ input_indices = split_dict["train_ids"]
221
+ anchor_indices = infer_prior_inds(
222
+ c2ws,
223
+ num_prior_frames=num_anchors,
224
+ input_frame_indices=input_indices,
225
+ options=options,
226
+ ).tolist()
227
+ num_anchors = len(anchor_indices)
228
+ anchor_c2ws = c2ws[anchor_indices, :3]
229
+ anchor_Ks = Ks[anchor_indices]
230
+
231
+ elif task == "img2trajvid":
232
+ num_anchors = infer_prior_stats(
233
+ T,
234
+ num_inputs,
235
+ num_total_frames=num_targets,
236
+ version_dict=version_dict,
237
+ )
238
+
239
+ target_c2ws = c2ws[split_dict["test_ids"], :3]
240
+ target_Ks = Ks[split_dict["test_ids"]]
241
+ anchor_c2ws = target_c2ws[
242
+ np.linspace(0, num_targets - 1, num_anchors).round().astype(np.int64)
243
+ ]
244
+ anchor_Ks = target_Ks[
245
+ np.linspace(0, num_targets - 1, num_anchors).round().astype(np.int64)
246
+ ]
247
+
248
+ sampled_indices = split_dict["train_ids"] + split_dict["test_ids"]
249
+ all_imgs_path = [all_imgs_path[i] for i in sampled_indices]
250
+ c2ws = c2ws[sampled_indices]
251
+ Ks = Ks[sampled_indices]
252
+
253
+ input_indices = np.arange(num_inputs).tolist()
254
+ anchor_indices = np.linspace(
255
+ num_inputs, num_inputs + num_targets - 1, num_anchors
256
+ ).tolist()
257
+
258
+ else:
259
+ raise ValueError(f"Unknown task: {task}")
260
+
261
+ return (
262
+ all_imgs_path,
263
+ num_inputs,
264
+ num_targets,
265
+ input_indices,
266
+ anchor_indices,
267
+ torch.tensor(c2ws[:, :3]).float(),
268
+ torch.tensor(Ks).float(),
269
+ (torch.tensor(anchor_c2ws[:, :3]).float() if anchor_c2ws is not None else None),
270
+ (torch.tensor(anchor_Ks).float() if anchor_Ks is not None else None),
271
+ )
272
+
273
+
274
+ def main(
275
+ data_path,
276
+ data_items=None,
277
+ task="img2img",
278
+ save_subdir="",
279
+ H=None,
280
+ W=None,
281
+ T=None,
282
+ use_traj_prior=False,
283
+ **overwrite_options,
284
+ ):
285
+ if H is not None:
286
+ VERSION_DICT["H"] = H
287
+ if W is not None:
288
+ VERSION_DICT["W"] = W
289
+ if T is not None:
290
+ VERSION_DICT["T"] = [int(t) for t in T.split(",")] if isinstance(T, str) else T
291
+
292
+ options = VERSION_DICT["options"]
293
+ options["chunk_strategy"] = "nearest-gt"
294
+ options["video_save_fps"] = 30.0
295
+ options["beta_linear_start"] = 5e-6
296
+ options["log_snr_shift"] = 2.4
297
+ options["guider_types"] = 1
298
+ options["cfg"] = 2.0
299
+ options["camera_scale"] = 2.0
300
+ options["num_steps"] = 50
301
+ options["cfg_min"] = 1.2
302
+ options["encoding_t"] = 1
303
+ options["decoding_t"] = 1
304
+ options["num_inputs"] = None
305
+ options["seed"] = 23
306
+ options.update(overwrite_options)
307
+
308
+ num_inputs = options["num_inputs"]
309
+ seed = options["seed"]
310
+
311
+ if data_items is not None:
312
+ if not isinstance(data_items, (list, tuple)):
313
+ data_items = data_items.split(",")
314
+ scenes = [os.path.join(data_path, item) for item in data_items]
315
+ else:
316
+ scenes = glob.glob(osp.join(data_path, "*"))
317
+
318
+ for scene in tqdm(scenes):
319
+ save_path_scene = os.path.join(
320
+ WORK_DIR, task, save_subdir, os.path.splitext(os.path.basename(scene))[0]
321
+ )
322
+ if options.get("skip_saved", False) and os.path.exists(
323
+ os.path.join(save_path_scene, "transforms.json")
324
+ ):
325
+ print(f"Skipping {scene} as it is already sampled.")
326
+ continue
327
+
328
+ # parse_task -> infer_prior_stats modifies VERSION_DICT["T"] in-place.
329
+ (
330
+ all_imgs_path,
331
+ num_inputs,
332
+ num_targets,
333
+ input_indices,
334
+ anchor_indices,
335
+ c2ws,
336
+ Ks,
337
+ anchor_c2ws,
338
+ anchor_Ks,
339
+ ) = parse_task(
340
+ task,
341
+ scene,
342
+ num_inputs,
343
+ VERSION_DICT["T"],
344
+ VERSION_DICT,
345
+ )
346
+ assert num_inputs is not None
347
+ # Create image conditioning.
348
+ image_cond = {
349
+ "img": all_imgs_path,
350
+ "input_indices": input_indices,
351
+ "prior_indices": anchor_indices,
352
+ }
353
+ # Create camera conditioning.
354
+ camera_cond = {
355
+ "c2w": c2ws.clone(),
356
+ "K": Ks.clone(),
357
+ "input_indices": list(range(num_inputs + num_targets)),
358
+ }
359
+ # run_one_scene -> transform_img_and_K modifies VERSION_DICT["H"] and VERSION_DICT["W"] in-place.
360
+ video_path_generator = run_one_scene(
361
+ task,
362
+ VERSION_DICT, # H, W may be updated in run_one_scene
363
+ model=MODEL,
364
+ ae=AE,
365
+ conditioner=CONDITIONER,
366
+ denoiser=DENOISER,
367
+ image_cond=image_cond,
368
+ camera_cond=camera_cond,
369
+ save_path=save_path_scene,
370
+ use_traj_prior=use_traj_prior,
371
+ traj_prior_Ks=anchor_Ks,
372
+ traj_prior_c2ws=anchor_c2ws,
373
+ seed=seed, # to ensure the sampled video can be reproduced regardless of start and i
374
+ )
375
+ for _ in video_path_generator:
376
+ pass
377
+
378
+ # Convert from OpenCV to OpenGL camera format.
379
+ c2ws = c2ws @ torch.tensor(np.diag([1, -1, -1, 1])).float()
380
+ img_paths = sorted(glob.glob(osp.join(save_path_scene, "samples-rgb", "*.png")))
381
+ if len(img_paths) != len(c2ws):
382
+ input_img_paths = sorted(
383
+ glob.glob(osp.join(save_path_scene, "input", "*.png"))
384
+ )
385
+ assert len(img_paths) == num_targets
386
+ assert len(input_img_paths) == num_inputs
387
+ assert c2ws.shape[0] == num_inputs + num_targets
388
+ target_indices = [i for i in range(c2ws.shape[0]) if i not in input_indices]
389
+ img_paths = [
390
+ input_img_paths[input_indices.index(i)]
391
+ if i in input_indices
392
+ else img_paths[target_indices.index(i)]
393
+ for i in range(c2ws.shape[0])
394
+ ]
395
+ create_transforms_simple(
396
+ save_path=save_path_scene,
397
+ img_paths=img_paths,
398
+ img_whs=np.array([VERSION_DICT["W"], VERSION_DICT["H"]])[None].repeat(
399
+ num_inputs + num_targets, 0
400
+ ),
401
+ c2ws=c2ws,
402
+ Ks=Ks,
403
+ )
404
+
405
+
406
+ if __name__ == "__main__":
407
+ fire.Fire(main)
demo_gr.py ADDED
@@ -0,0 +1,1248 @@
1
+ import copy
2
+ import json
3
+ import os
4
+ import os.path as osp
5
+ import queue
6
+ import secrets
7
+ import threading
8
+ import time
9
+ from datetime import datetime
10
+ from glob import glob
11
+ from pathlib import Path
12
+ from typing import Literal
13
+
14
+ import gradio as gr
15
+ import httpx
16
+ import imageio.v3 as iio
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn.functional as F
20
+ import tyro
21
+ import viser
22
+ import viser.transforms as vt
23
+ from einops import rearrange
24
+ from gradio import networking
25
+ from gradio.context import LocalContext
26
+ from gradio.tunneling import CERTIFICATE_PATH, Tunnel
27
+
28
+ from seva.eval import (
29
+ IS_TORCH_NIGHTLY,
30
+ chunk_input_and_test,
31
+ create_transforms_simple,
32
+ infer_prior_stats,
33
+ run_one_scene,
34
+ transform_img_and_K,
35
+ )
36
+ from seva.geometry import (
37
+ DEFAULT_FOV_RAD,
38
+ get_default_intrinsics,
39
+ get_preset_pose_fov,
40
+ normalize_scene,
41
+ )
42
+ from seva.gui import define_gui
43
+ from seva.model import SGMWrapper
44
+ from seva.modules.autoencoder import AutoEncoder
45
+ from seva.modules.conditioner import CLIPConditioner
46
+ from seva.modules.preprocessor import Dust3rPipeline
47
+ from seva.sampling import DDPMDiscretization, DiscreteDenoiser
48
+ from seva.utils import load_model
49
+
50
+ device = "cpu"
51
+
52
+
53
+ # Constants.
54
+ WORK_DIR = "work_dirs/demo_gr"
55
+ MAX_SESSIONS = 1
56
+ ADVANCE_EXAMPLE_MAP = [
57
+ (
58
+ "assets/advance/blue-car.jpg",
59
+ ["assets/advance/blue-car.jpg"],
60
+ ),
61
+ (
62
+ "assets/advance/garden-4_0.jpg",
63
+ [
64
+ "assets/advance/garden-4_0.jpg",
65
+ "assets/advance/garden-4_1.jpg",
66
+ "assets/advance/garden-4_2.jpg",
67
+ "assets/advance/garden-4_3.jpg",
68
+ ],
69
+ ),
70
+ (
71
+ "assets/advance/vgg-lab-4_0.png",
72
+ [
73
+ "assets/advance/vgg-lab-4_0.png",
74
+ "assets/advance/vgg-lab-4_1.png",
75
+ "assets/advance/vgg-lab-4_2.png",
76
+ "assets/advance/vgg-lab-4_3.png",
77
+ ],
78
+ ),
79
+ (
80
+ "assets/advance/telebooth-2_0.jpg",
81
+ [
82
+ "assets/advance/telebooth-2_0.jpg",
83
+ "assets/advance/telebooth-2_1.jpg",
84
+ ],
85
+ ),
86
+ (
87
+ "assets/advance/backyard-7_0.jpg",
88
+ [
89
+ "assets/advance/backyard-7_0.jpg",
90
+ "assets/advance/backyard-7_1.jpg",
91
+ "assets/advance/backyard-7_2.jpg",
92
+ "assets/advance/backyard-7_3.jpg",
93
+ "assets/advance/backyard-7_4.jpg",
94
+ "assets/advance/backyard-7_5.jpg",
95
+ "assets/advance/backyard-7_6.jpg",
96
+ ],
97
+ ),
98
+ ]
99
+
100
+ if IS_TORCH_NIGHTLY:
101
+ COMPILE = True
102
+ os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
103
+ os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
104
+ else:
105
+ COMPILE = False
106
+
107
+ # Shared global variables across sessions.
108
+ DUST3R = Dust3rPipeline(device=device) # type: ignore
109
+ MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
110
+ AE = AutoEncoder(chunk_size=1).to(device)
111
+ CONDITIONER = CLIPConditioner().to(device)
112
+ DISCRETIZATION = DDPMDiscretization()
113
+ DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
114
+ VERSION_DICT = {
115
+ "H": 576,
116
+ "W": 576,
117
+ "T": 21,
118
+ "C": 4,
119
+ "f": 8,
120
+ "options": {},
121
+ }
122
+ SERVERS = {}
123
+ ABORT_EVENTS = {}
124
+
125
+ if COMPILE:
126
+ MODEL = torch.compile(MODEL)
127
+ CONDITIONER = torch.compile(CONDITIONER)
128
+ AE = torch.compile(AE)
129
+
130
+
131
+ class SevaRenderer(object):
132
+ def __init__(self, server: viser.ViserServer):
133
+ self.server = server
134
+ self.gui_state = None
135
+
136
+ def preprocess(
137
+ self, input_img_path_or_tuples: list[tuple[str, None]] | str
138
+ ) -> tuple[dict, dict, dict]:
139
+ # Simply hardcode these so that the aspect ratio is always kept and the
140
+ # shorter side is resized to 576. This only reduces the number of GUI
141
+ # options; changing it still works.
142
+ shorter: int = 576
143
+ # Has to be 64 multiple for the network.
144
+ shorter = round(shorter / 64) * 64
145
+
146
+ if isinstance(input_img_path_or_tuples, str):
147
+ # Assume `Basic` demo mode: just hardcode the camera parameters and ignore points.
148
+ input_imgs = torch.as_tensor(
149
+ iio.imread(input_img_path_or_tuples) / 255.0, dtype=torch.float32
150
+ )[None, ..., :3]
151
+ input_imgs = transform_img_and_K(
152
+ input_imgs.permute(0, 3, 1, 2),
153
+ shorter,
154
+ K=None,
155
+ size_stride=64,
156
+ )[0].permute(0, 2, 3, 1)
157
+ input_Ks = get_default_intrinsics(
158
+ aspect_ratio=input_imgs.shape[2] / input_imgs.shape[1]
159
+ )
160
+ input_c2ws = torch.eye(4)[None]
161
+ # Simulate a small time interval such that gradio can update
162
+ # propgress properly.
163
+ time.sleep(0.1)
164
+ return (
165
+ {
166
+ "input_imgs": input_imgs,
167
+ "input_Ks": input_Ks,
168
+ "input_c2ws": input_c2ws,
169
+ "input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
170
+ "points": [np.zeros((0, 3))],
171
+ "point_colors": [np.zeros((0, 3))],
172
+ "scene_scale": 1.0,
173
+ },
174
+ gr.update(visible=False),
175
+ gr.update(),
176
+ )
177
+ else:
178
+ # Assume `Advance` demo mode: use dust3r to extract camera parameters and points.
179
+ img_paths = [p for (p, _) in input_img_path_or_tuples]
180
+ (
181
+ input_imgs,
182
+ input_Ks,
183
+ input_c2ws,
184
+ points,
185
+ point_colors,
186
+ ) = DUST3R.infer_cameras_and_points(img_paths)
187
+ num_inputs = len(img_paths)
188
+ if num_inputs == 1:
189
+ input_imgs, input_Ks, input_c2ws, points, point_colors = (
190
+ input_imgs[:1],
191
+ input_Ks[:1],
192
+ input_c2ws[:1],
193
+ points[:1],
194
+ point_colors[:1],
195
+ )
196
+ input_imgs = [img[..., :3] for img in input_imgs]
197
+ # Normalize the scene.
198
+ point_chunks = [p.shape[0] for p in points]
199
+ point_indices = np.cumsum(point_chunks)[:-1]
200
+ input_c2ws, points, _ = normalize_scene( # type: ignore
201
+ input_c2ws,
202
+ np.concatenate(points, 0),
203
+ camera_center_method="poses",
204
+ )
205
+ points = np.split(points, point_indices, 0)
206
+ # Scale camera and points for viewport visualization.
207
+ scene_scale = np.median(
208
+ np.ptp(np.concatenate([input_c2ws[:, :3, 3], *points], 0), -1)
209
+ )
210
+ input_c2ws[:, :3, 3] /= scene_scale
211
+ points = [point / scene_scale for point in points]
212
+ input_imgs = [
213
+ torch.as_tensor(img / 255.0, dtype=torch.float32) for img in input_imgs
214
+ ]
215
+ input_Ks = torch.as_tensor(input_Ks)
216
+ input_c2ws = torch.as_tensor(input_c2ws)
217
+ new_input_imgs, new_input_Ks = [], []
218
+ for img, K in zip(input_imgs, input_Ks):
219
+ img = rearrange(img, "h w c -> 1 c h w")
220
+ # If you don't want to keep aspect ratio and want to always center crop, use this:
221
+ # img, K = transform_img_and_K(img, (shorter, shorter), K=K[None])
222
+ img, K = transform_img_and_K(img, shorter, K=K[None], size_stride=64)
223
+ assert isinstance(K, torch.Tensor)
224
+ K = K / K.new_tensor([img.shape[-1], img.shape[-2], 1])[:, None]
225
+ new_input_imgs.append(img)
226
+ new_input_Ks.append(K)
227
+ input_imgs = torch.cat(new_input_imgs, 0)
228
+ input_imgs = rearrange(input_imgs, "b c h w -> b h w c")[..., :3]
229
+ input_Ks = torch.cat(new_input_Ks, 0)
230
+ return (
231
+ {
232
+ "input_imgs": input_imgs,
233
+ "input_Ks": input_Ks,
234
+ "input_c2ws": input_c2ws,
235
+ "input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
236
+ "points": points,
237
+ "point_colors": point_colors,
238
+ "scene_scale": scene_scale,
239
+ },
240
+ gr.update(visible=False),
241
+ gr.update()
242
+ if num_inputs <= 10
243
+ else gr.update(choices=["interp"], value="interp"),
244
+ )
245
+
246
+ def visualize_scene(self, preprocessed: dict):
247
+ server = self.server
248
+ server.scene.reset()
249
+ server.gui.reset()
250
+ set_bkgd_color(server)
251
+
252
+ (
253
+ input_imgs,
254
+ input_Ks,
255
+ input_c2ws,
256
+ input_wh,
257
+ points,
258
+ point_colors,
259
+ scene_scale,
260
+ ) = (
261
+ preprocessed["input_imgs"],
262
+ preprocessed["input_Ks"],
263
+ preprocessed["input_c2ws"],
264
+ preprocessed["input_wh"],
265
+ preprocessed["points"],
266
+ preprocessed["point_colors"],
267
+ preprocessed["scene_scale"],
268
+ )
269
+ W, H = input_wh
270
+
271
+ server.scene.set_up_direction(-input_c2ws[..., :3, 1].mean(0).numpy())
272
+
273
+ # Use first image as default fov.
274
+ assert input_imgs[0].shape[:2] == (H, W)
275
+ if H > W:
276
+ init_fov = 2 * np.arctan(1 / (2 * input_Ks[0, 0, 0].item()))
277
+ else:
278
+ init_fov = 2 * np.arctan(1 / (2 * input_Ks[0, 1, 1].item()))
279
+ init_fov_deg = float(init_fov / np.pi * 180.0)
280
+
281
+ frustum_nodes, pcd_nodes = [], []
282
+ for i in range(len(input_imgs)):
283
+ K = input_Ks[i]
284
+ frustum = server.scene.add_camera_frustum(
285
+ f"/scene_assets/cameras/{i}",
286
+ fov=2 * np.arctan(1 / (2 * K[1, 1].item())),
287
+ aspect=W / H,
288
+ scale=0.1 * scene_scale,
289
+ image=(input_imgs[i].numpy() * 255.0).astype(np.uint8),
290
+ wxyz=vt.SO3.from_matrix(input_c2ws[i, :3, :3].numpy()).wxyz,
291
+ position=input_c2ws[i, :3, 3].numpy(),
292
+ )
293
+
294
+ def get_handler(frustum):
295
+ def handler(event: viser.GuiEvent) -> None:
296
+ assert event.client_id is not None
297
+ client = server.get_clients()[event.client_id]
298
+ with client.atomic():
299
+ client.camera.position = frustum.position
300
+ client.camera.wxyz = frustum.wxyz
301
+ # Set look_at as the projected origin onto the
302
+ # frustum's forward direction.
303
+ look_direction = vt.SO3(frustum.wxyz).as_matrix()[:, 2]
304
+ position_origin = -frustum.position
305
+ client.camera.look_at = (
306
+ frustum.position
307
+ + np.dot(look_direction, position_origin)
308
+ / np.linalg.norm(position_origin)
309
+ * look_direction
310
+ )
311
+
312
+ return handler
313
+
314
+ frustum.on_click(get_handler(frustum)) # type: ignore
315
+ frustum_nodes.append(frustum)
316
+
317
+ pcd = server.scene.add_point_cloud(
318
+ f"/scene_assets/points/{i}",
319
+ points[i],
320
+ point_colors[i],
321
+ point_size=0.01 * scene_scale,
322
+ point_shape="circle",
323
+ )
324
+ pcd_nodes.append(pcd)
325
+
326
+ with server.gui.add_folder("Scene scale", expand_by_default=False, order=200):
327
+ camera_scale_slider = server.gui.add_slider(
328
+ "Log camera scale", initial_value=0.0, min=-2.0, max=2.0, step=0.1
329
+ )
330
+
331
+ @camera_scale_slider.on_update
332
+ def _(_) -> None:
333
+ for i in range(len(frustum_nodes)):
334
+ frustum_nodes[i].scale = (
335
+ 0.1 * scene_scale * 10**camera_scale_slider.value
336
+ )
337
+
338
+ point_scale_slider = server.gui.add_slider(
339
+ "Log point scale", initial_value=0.0, min=-2.0, max=2.0, step=0.1
340
+ )
341
+
342
+ @point_scale_slider.on_update
343
+ def _(_) -> None:
344
+ for i in range(len(pcd_nodes)):
345
+ pcd_nodes[i].point_size = (
346
+ 0.01 * scene_scale * 10**point_scale_slider.value
347
+ )
348
+
349
+ self.gui_state = define_gui(
350
+ server,
351
+ init_fov=init_fov_deg,
352
+ img_wh=input_wh,
353
+ scene_scale=scene_scale,
354
+ )
355
+
356
+ def get_target_c2ws_and_Ks_from_gui(self, preprocessed: dict):
357
+ input_wh = preprocessed["input_wh"]
358
+ W, H = input_wh
359
+ gui_state = self.gui_state
360
+ assert gui_state is not None and gui_state.camera_traj_list is not None
361
+ target_c2ws, target_Ks = [], []
362
+ for item in gui_state.camera_traj_list:
363
+ target_c2ws.append(item["w2c"])
364
+ assert item["img_wh"] == input_wh
365
+ K = np.array(item["K"]).reshape(3, 3) / np.array([W, H, 1])[:, None]
366
+ target_Ks.append(K)
367
+ target_c2ws = torch.as_tensor(
368
+ np.linalg.inv(np.array(target_c2ws).reshape(-1, 4, 4))
369
+ )
370
+ target_Ks = torch.as_tensor(np.array(target_Ks).reshape(-1, 3, 3))
371
+ return target_c2ws, target_Ks
372
+
373
+ def get_target_c2ws_and_Ks_from_preset(
374
+ self,
375
+ preprocessed: dict,
376
+ preset_traj: Literal[
377
+ "orbit",
378
+ "spiral",
379
+ "lemniscate",
380
+ "zoom-in",
381
+ "zoom-out",
382
+ "dolly zoom-in",
383
+ "dolly zoom-out",
384
+ "move-forward",
385
+ "move-backward",
386
+ "move-up",
387
+ "move-down",
388
+ "move-left",
389
+ "move-right",
390
+ ],
391
+ num_frames: int,
392
+ zoom_factor: float | None,
393
+ ):
394
+ img_wh = preprocessed["input_wh"]
395
+ start_c2w = preprocessed["input_c2ws"][0]
396
+ start_w2c = torch.linalg.inv(start_c2w)
397
+ look_at = torch.tensor([0, 0, 10])
398
+ start_fov = DEFAULT_FOV_RAD
399
+ target_c2ws, target_fovs = get_preset_pose_fov(
400
+ preset_traj,
401
+ num_frames,
402
+ start_w2c,
403
+ look_at,
404
+ -start_c2w[:3, 1],
405
+ start_fov,
406
+ spiral_radii=[1.0, 1.0, 0.5],
407
+ zoom_factor=zoom_factor,
408
+ )
409
+ target_c2ws = torch.as_tensor(target_c2ws)
410
+ target_fovs = torch.as_tensor(target_fovs)
411
+ target_Ks = get_default_intrinsics(
412
+ target_fovs, # type: ignore
413
+ aspect_ratio=img_wh[0] / img_wh[1],
414
+ )
415
+ return target_c2ws, target_Ks
416
+
417
+ def export_output_data(self, preprocessed: dict, output_dir: str):
418
+ input_imgs, input_Ks, input_c2ws, input_wh = (
419
+ preprocessed["input_imgs"],
420
+ preprocessed["input_Ks"],
421
+ preprocessed["input_c2ws"],
422
+ preprocessed["input_wh"],
423
+ )
424
+ target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_gui(preprocessed)
425
+
426
+ num_inputs = len(input_imgs)
427
+ num_targets = len(target_c2ws)
428
+
429
+ input_imgs = (input_imgs.cpu().numpy() * 255.0).astype(np.uint8)
430
+ input_c2ws = input_c2ws.cpu().numpy()
431
+ input_Ks = input_Ks.cpu().numpy()
432
+ target_c2ws = target_c2ws.cpu().numpy()
433
+ target_Ks = target_Ks.cpu().numpy()
434
+ img_whs = np.array(input_wh)[None].repeat(len(input_imgs) + len(target_Ks), 0)
435
+
436
+ os.makedirs(output_dir, exist_ok=True)
437
+ img_paths = []
438
+ for i, img in enumerate(input_imgs):
439
+ iio.imwrite(img_path := osp.join(output_dir, f"{i:03d}.png"), img)
440
+ img_paths.append(img_path)
441
+ for i in range(num_targets):
442
+ iio.imwrite(
443
+ img_path := osp.join(output_dir, f"{i + num_inputs:03d}.png"),
444
+ np.zeros((input_wh[1], input_wh[0], 3), dtype=np.uint8),
445
+ )
446
+ img_paths.append(img_path)
447
+
448
+ # Convert from OpenCV to OpenGL camera format.
449
+ all_c2ws = np.concatenate([input_c2ws, target_c2ws])
450
+ all_Ks = np.concatenate([input_Ks, target_Ks])
451
+ all_c2ws = all_c2ws @ np.diag([1, -1, -1, 1])
452
+ create_transforms_simple(output_dir, img_paths, img_whs, all_c2ws, all_Ks)
453
+ split_dict = {
454
+ "train_ids": list(range(num_inputs)),
455
+ "test_ids": list(range(num_inputs, num_inputs + num_targets)),
456
+ }
457
+ with open(
458
+ osp.join(output_dir, f"train_test_split_{num_inputs}.json"), "w"
459
+ ) as f:
460
+ json.dump(split_dict, f, indent=4)
461
+ gr.Info(f"Output data saved to {output_dir}", duration=1)
462
+
463
+ def render(
464
+ self,
465
+ preprocessed: dict,
466
+ session_hash: str,
467
+ seed: int,
468
+ chunk_strategy: str,
469
+ cfg: float,
470
+ preset_traj: Literal[
471
+ "orbit",
472
+ "spiral",
473
+ "lemniscate",
474
+ "zoom-in",
475
+ "zoom-out",
476
+ "dolly zoom-in",
477
+ "dolly zoom-out",
478
+ "move-forward",
479
+ "move-backward",
480
+ "move-up",
481
+ "move-down",
482
+ "move-left",
483
+ "move-right",
484
+ ]
485
+ | None,
486
+ num_frames: int | None,
487
+ zoom_factor: float | None,
488
+ camera_scale: float,
489
+ ):
490
+ render_name = datetime.now().strftime("%Y%m%d_%H%M%S")
491
+ render_dir = osp.join(WORK_DIR, render_name)
492
+
493
+ input_imgs, input_Ks, input_c2ws, (W, H) = (
494
+ preprocessed["input_imgs"],
495
+ preprocessed["input_Ks"],
496
+ preprocessed["input_c2ws"],
497
+ preprocessed["input_wh"],
498
+ )
499
+ num_inputs = len(input_imgs)
500
+ if preset_traj is None:
501
+ target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_gui(preprocessed)
502
+ else:
503
+ assert num_frames is not None
504
+ assert num_inputs == 1
505
+ input_c2ws = torch.eye(4)[None].to(dtype=input_c2ws.dtype)
506
+ target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_preset(
507
+ preprocessed, preset_traj, num_frames, zoom_factor
508
+ )
509
+ all_c2ws = torch.cat([input_c2ws, target_c2ws], 0)
510
+ all_Ks = (
511
+ torch.cat([input_Ks, target_Ks], 0)
512
+ * input_Ks.new_tensor([W, H, 1])[:, None]
513
+ )
514
+ num_targets = len(target_c2ws)
515
+ input_indices = list(range(num_inputs))
516
+ target_indices = np.arange(num_inputs, num_inputs + num_targets).tolist()
517
+ # Get anchor cameras.
518
+ T = VERSION_DICT["T"]
519
+ version_dict = copy.deepcopy(VERSION_DICT)
520
+ num_anchors = infer_prior_stats(
521
+ T,
522
+ num_inputs,
523
+ num_total_frames=num_targets,
524
+ version_dict=version_dict,
525
+ )
526
+ # infer_prior_stats modifies T in-place.
527
+ T = version_dict["T"]
528
+ assert isinstance(num_anchors, int)
529
+ anchor_indices = np.linspace(
530
+ num_inputs,
531
+ num_inputs + num_targets - 1,
532
+ num_anchors,
533
+ ).tolist()
534
+ anchor_c2ws = all_c2ws[[round(ind) for ind in anchor_indices]]
535
+ anchor_Ks = all_Ks[[round(ind) for ind in anchor_indices]]
536
+ # Create image conditioning.
537
+ all_imgs_np = (
538
+ F.pad(input_imgs, (0, 0, 0, 0, 0, 0, 0, num_targets), value=0.0).numpy()
539
+ * 255.0
540
+ ).astype(np.uint8)
541
+ image_cond = {
542
+ "img": all_imgs_np,
543
+ "input_indices": input_indices,
544
+ "prior_indices": anchor_indices,
545
+ }
546
+ # Create camera conditioning (K is unnormalized).
547
+ camera_cond = {
548
+ "c2w": all_c2ws,
549
+ "K": all_Ks,
550
+ "input_indices": list(range(num_inputs + num_targets)),
551
+ }
552
+ # Run rendering.
553
+ num_steps = 50
554
+ options_ori = VERSION_DICT["options"]
555
+ options = copy.deepcopy(options_ori)
556
+ options["chunk_strategy"] = chunk_strategy
557
+ options["video_save_fps"] = 30.0
558
+ options["beta_linear_start"] = 5e-6
559
+ options["log_snr_shift"] = 2.4
560
+ options["guider_types"] = [1, 2]
561
+ options["cfg"] = [
562
+ float(cfg),
563
+ 3.0 if num_inputs >= 9 else 2.0,
564
+ ] # We define semi-dense-view regime to have 9 input views.
565
+ options["camera_scale"] = camera_scale
566
+ options["num_steps"] = num_steps
567
+ options["cfg_min"] = 1.2
568
+ options["encoding_t"] = 1
569
+ options["decoding_t"] = 1
570
+ assert session_hash in ABORT_EVENTS
571
+ abort_event = ABORT_EVENTS[session_hash]
572
+ abort_event.clear()
573
+ options["abort_event"] = abort_event
574
+ task = "img2trajvid"
575
+ # Get number of first pass chunks.
576
+ T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
577
+ chunk_strategy_first_pass = options.get(
578
+ "chunk_strategy_first_pass", "gt-nearest"
579
+ )
580
+ num_chunks_0 = len(
581
+ chunk_input_and_test(
582
+ T_first_pass,
583
+ input_c2ws,
584
+ anchor_c2ws,
585
+ input_indices,
586
+ image_cond["prior_indices"],
587
+ options={**options, "sampler_verbose": False},
588
+ task=task,
589
+ chunk_strategy=chunk_strategy_first_pass,
590
+ gt_input_inds=list(range(input_c2ws.shape[0])),
591
+ )[1]
592
+ )
593
+ # Get number of second pass chunks.
594
+ anchor_argsort = np.argsort(input_indices + anchor_indices).tolist()
595
+ anchor_indices = np.array(input_indices + anchor_indices)[
596
+ anchor_argsort
597
+ ].tolist()
598
+ gt_input_inds = [anchor_argsort.index(i) for i in range(input_c2ws.shape[0])]
599
+ anchor_c2ws_second_pass = torch.cat([input_c2ws, anchor_c2ws], dim=0)[
600
+ anchor_argsort
601
+ ]
602
+ T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
603
+ chunk_strategy = options.get("chunk_strategy", "nearest")
604
+ num_chunks_1 = len(
605
+ chunk_input_and_test(
606
+ T_second_pass,
607
+ anchor_c2ws_second_pass,
608
+ target_c2ws,
609
+ anchor_indices,
610
+ target_indices,
611
+ options={**options, "sampler_verbose": False},
612
+ task=task,
613
+ chunk_strategy=chunk_strategy,
614
+ gt_input_inds=gt_input_inds,
615
+ )[1]
616
+ )
617
+ second_pass_pbar = gr.Progress().tqdm(
618
+ iterable=None,
619
+ desc="Second pass sampling",
620
+ total=num_chunks_1 * num_steps,
621
+ )
622
+ first_pass_pbar = gr.Progress().tqdm(
623
+ iterable=None,
624
+ desc="First pass sampling",
625
+ total=num_chunks_0 * num_steps,
626
+ )
627
+ video_path_generator = run_one_scene(
628
+ task=task,
629
+ version_dict={
630
+ "H": H,
631
+ "W": W,
632
+ "T": T,
633
+ "C": VERSION_DICT["C"],
634
+ "f": VERSION_DICT["f"],
635
+ "options": options,
636
+ },
637
+ model=MODEL,
638
+ ae=AE,
639
+ conditioner=CONDITIONER,
640
+ denoiser=DENOISER,
641
+ image_cond=image_cond,
642
+ camera_cond=camera_cond,
643
+ save_path=render_dir,
644
+ use_traj_prior=True,
645
+ traj_prior_c2ws=anchor_c2ws,
646
+ traj_prior_Ks=anchor_Ks,
647
+ seed=seed,
648
+ gradio=True,
649
+ first_pass_pbar=first_pass_pbar,
650
+ second_pass_pbar=second_pass_pbar,
651
+ abort_event=abort_event,
652
+ )
653
+ output_queue = queue.Queue()
654
+
655
+ blocks = LocalContext.blocks.get()
656
+ event_id = LocalContext.event_id.get()
657
+
658
+ def worker():
659
+ # gradio intentionally doesn't support progress updates from worker threads, so
660
+ # we need to work around it by propagating the local context.
661
+ LocalContext.blocks.set(blocks)
662
+ LocalContext.event_id.set(event_id)
663
+ for i, video_path in enumerate(video_path_generator):
664
+ if i == 0:
665
+ output_queue.put(
666
+ (
667
+ video_path,
668
+ gr.update(),
669
+ gr.update(),
670
+ gr.update(),
671
+ )
672
+ )
673
+ elif i == 1:
674
+ output_queue.put(
675
+ (
676
+ video_path,
677
+ gr.update(visible=True),
678
+ gr.update(visible=False),
679
+ gr.update(visible=False),
680
+ )
681
+ )
682
+ else:
683
+ gr.Error("More than two passes during rendering.")
684
+
685
+ thread = threading.Thread(target=worker, daemon=True)
686
+ thread.start()
687
+
688
+ while thread.is_alive() or not output_queue.empty():
689
+ if abort_event.is_set():
690
+ thread.join()
691
+ abort_event.clear()
692
+ yield (
693
+ gr.update(),
694
+ gr.update(visible=True),
695
+ gr.update(visible=False),
696
+ gr.update(visible=False),
697
+ )
698
+ time.sleep(0.1)
699
+ while not output_queue.empty():
700
+ yield output_queue.get()
701
+
702
+
703
+ # This is basically a copy of the original `networking.setup_tunnel` function,
704
+ # but it also returns the tunnel object for proper cleanup.
705
+ def setup_tunnel(
706
+ local_host: str, local_port: int, share_token: str, share_server_address: str | None
707
+ ) -> tuple[str, Tunnel]:
708
+ share_server_address = (
709
+ networking.GRADIO_SHARE_SERVER_ADDRESS
710
+ if share_server_address is None
711
+ else share_server_address
712
+ )
713
+ if share_server_address is None:
714
+ try:
715
+ response = httpx.get(networking.GRADIO_API_SERVER, timeout=30)
716
+ payload = response.json()[0]
717
+ remote_host, remote_port = payload["host"], int(payload["port"])
718
+ certificate = payload["root_ca"]
719
+ Path(CERTIFICATE_PATH).parent.mkdir(parents=True, exist_ok=True)
720
+ with open(CERTIFICATE_PATH, "w") as f:
721
+ f.write(certificate)
722
+ except Exception as e:
723
+ raise RuntimeError(
724
+ "Could not get share link from Gradio API Server."
725
+ ) from e
726
+ else:
727
+ remote_host, remote_port = share_server_address.split(":")
728
+ remote_port = int(remote_port)
729
+ tunnel = Tunnel(remote_host, remote_port, local_host, local_port, share_token)
730
+ address = tunnel.start_tunnel()
731
+ return address, tunnel
732
+
733
+
734
+ def set_bkgd_color(server: viser.ViserServer | viser.ClientHandle):
735
+ server.scene.set_background_image(np.array([[[39, 39, 42]]], dtype=np.uint8))
736
+
737
+
738
+ def start_server_and_abort_event(request: gr.Request):
739
+ server = viser.ViserServer()
740
+
741
+ @server.on_client_connect
742
+ def _(client: viser.ClientHandle):
743
+ # Force dark mode that blends well with gradio's dark theme.
744
+ client.gui.configure_theme(
745
+ dark_mode=True,
746
+ show_share_button=False,
747
+ control_layout="collapsible",
748
+ )
749
+ set_bkgd_color(client)
750
+
751
+ print(f"Starting server {server.get_port()}")
752
+ server_url, tunnel = setup_tunnel(
753
+ local_host=server.get_host(),
754
+ local_port=server.get_port(),
755
+ share_token=secrets.token_urlsafe(32),
756
+ share_server_address=None,
757
+ )
758
+ SERVERS[request.session_hash] = (server, tunnel)
759
+ if server_url is None:
760
+ raise gr.Error(
761
+ "Failed to get a viewport URL. Please check your network connection."
762
+ )
763
+ # Give it enough time to start.
764
+ time.sleep(1)
765
+
766
+ ABORT_EVENTS[request.session_hash] = threading.Event()
767
+
768
+ return (
769
+ SevaRenderer(server),
770
+ gr.HTML(
771
+ f'<iframe src="{server_url}" style="display: block; margin: auto; width: 100%; height: min(60vh, 600px);" frameborder="0"></iframe>',
772
+ container=True,
773
+ ),
774
+ request.session_hash,
775
+ )
776
+
777
+
778
+ def stop_server_and_abort_event(request: gr.Request):
779
+ if request.session_hash in SERVERS:
780
+ print(f"Stopping server {request.session_hash}")
781
+ server, tunnel = SERVERS.pop(request.session_hash)
782
+ server.stop()
783
+ tunnel.kill()
784
+
785
+ if request.session_hash in ABORT_EVENTS:
786
+ print(f"Setting abort event {request.session_hash}")
787
+ ABORT_EVENTS[request.session_hash].set()
788
+ # Give it enough time to abort jobs.
789
+ time.sleep(5)
790
+ ABORT_EVENTS.pop(request.session_hash)
791
+
792
+
793
+ def set_abort_event(request: gr.Request):
794
+ if request.session_hash in ABORT_EVENTS:
795
+ print(f"Setting abort event {request.session_hash}")
796
+ ABORT_EVENTS[request.session_hash].set()
797
+
798
+
799
+ def get_advance_examples(selection: gr.SelectData):
800
+ index = selection.index
801
+ return (
802
+ gr.Gallery(ADVANCE_EXAMPLE_MAP[index][1], visible=True),
803
+ gr.update(visible=True),
804
+ gr.update(visible=True),
805
+ gr.Gallery(visible=False),
806
+ )
807
+
808
+
809
+ def get_preamble():
810
+ gr.Markdown("""
811
+ # Stable Virtual Camera
812
+ <span style="display: flex; flex-wrap: wrap; gap: 5px;">
813
+ <a href="https://stable-virtual-camera.github.io"><img src="https://img.shields.io/badge/%F0%9F%8F%A0%20Project%20Page-gray.svg"></a>
814
+ <a href="http://arxiv.org/abs/2503.14489"><img src="https://img.shields.io/badge/%F0%9F%93%84%20arXiv-2503.14489-B31B1B.svg"></a>
815
+ <a href="https://stability.ai/news/introducing-stable-virtual-camera-multi-view-video-generation-with-3d-camera-control"><img src="https://img.shields.io/badge/%F0%9F%93%83%20Blog-Stability%20AI-orange.svg"></a>
816
+ <a href="https://huggingface.co/stabilityai/stable-virtual-camera"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a>
817
+ <a href="https://huggingface.co/spaces/stabilityai/stable-virtual-camera"><img src="https://img.shields.io/badge/%F0%9F%9A%80%20Gradio%20Demo-Huggingface-orange"></a>
818
+ <a href="https://www.youtube.com/channel/UCLLlVDcS7nNenT_zzO3OPxQ"><img src="https://img.shields.io/badge/%F0%9F%8E%AC%20Video-YouTube-orange"></a>
819
+ </span>
820
+
821
+ Welcome to the demo of <strong>Stable Virtual Camera (Seva)</strong>! Given any number of input views and their cameras, this demo will allow you to generate novel views of a scene at any target camera of interest.
822
+
823
+ We provide two ways to use our demo (selected by the tab below, documented [here](https://github.com/Stability-AI/stable-virtual-camera/blob/main/docs/GR_USAGE.md)):
824
+ 1. **[Basic](https://github.com/user-attachments/assets/4d965fa6-d8eb-452c-b773-6e09c88ca705)**: Given a single image, you can generate a video following one of our preset camera trajectories.
825
+ 2. **[Advanced](https://github.com/user-attachments/assets/dcec1be0-bd10-441e-879c-d1c2b63091ba)**: Given any number of input images, you can generate a video following any camera trajectory of your choice by our key-frame-based interface.
826
+
827
+ > This is a research preview and comes with a few [limitations](https://stable-virtual-camera.github.io/#limitations):
828
+ > - Limited quality in certain subjects due to training data, including humans, animals, and dynamic textures.
829
+ > - Limited quality in some highly ambiguous scenes and camera trajectories, including extreme views and collision into objects.
830
+ """)
831
+
832
+
833
+ # Make sure that gradio uses dark theme.
834
+ _APP_JS = """
835
+ function refresh() {
836
+ const url = new URL(window.location);
837
+ if (url.searchParams.get('__theme') !== 'dark') {
838
+ url.searchParams.set('__theme', 'dark');
839
+ }
840
+ }
841
+ """
842
+
843
+
844
+ def main(server_port: int | None = None, share: bool = True):
845
+ with gr.Blocks(js=_APP_JS) as app:
846
+ renderer = gr.State()
847
+ session_hash = gr.State()
848
+ _ = get_preamble()
849
+ with gr.Tabs():
850
+ with gr.Tab("Basic"):
851
+ render_btn = gr.Button("Render video", interactive=False, render=False)
852
+ with gr.Row():
853
+ with gr.Column():
854
+ with gr.Group():
855
+ # Initially disable the Preprocess Images button until an image is selected.
856
+ preprocess_btn = gr.Button("Preprocess images", interactive=False)
857
+ preprocess_progress = gr.Textbox(
858
+ label="",
859
+ visible=False,
860
+ interactive=False,
861
+ )
862
+ with gr.Group():
863
+ input_imgs = gr.Image(
864
+ type="filepath",
865
+ label="Input",
866
+ height=200,
867
+ )
868
+ _ = gr.Examples(
869
+ examples=sorted(glob("assets/basic/*")),
870
+ inputs=[input_imgs],
871
+ label="Example",
872
+ )
873
+ chunk_strategy = gr.Dropdown(
874
+ ["interp", "interp-gt"],
875
+ label="Chunk strategy",
876
+ render=False,
877
+ )
878
+ preprocessed = gr.State()
879
+ # Enable the Preprocess Images button only if an image is selected.
880
+ input_imgs.change(
881
+ lambda img: gr.update(interactive=bool(img)),
882
+ inputs=input_imgs,
883
+ outputs=preprocess_btn,
884
+ )
885
+ preprocess_btn.click(
886
+ lambda r, *args: [
887
+ *r.preprocess(*args),
888
+ gr.update(interactive=True),
889
+ ],
890
+ inputs=[renderer, input_imgs],
891
+ outputs=[
892
+ preprocessed,
893
+ preprocess_progress,
894
+ chunk_strategy,
895
+ render_btn,
896
+ ],
897
+ show_progress_on=[preprocess_progress],
898
+ concurrency_limit=1,
899
+ concurrency_id="gpu_queue",
900
+ )
901
+ preprocess_btn.click(
902
+ lambda: gr.update(visible=True),
903
+ outputs=[preprocess_progress],
904
+ )
905
+ with gr.Row():
906
+ preset_traj = gr.Dropdown(
907
+ choices=[
908
+ "orbit",
909
+ "spiral",
910
+ "lemniscate",
911
+ "zoom-in",
912
+ "zoom-out",
913
+ "dolly zoom-in",
914
+ "dolly zoom-out",
915
+ "move-forward",
916
+ "move-backward",
917
+ "move-up",
918
+ "move-down",
919
+ "move-left",
920
+ "move-right",
921
+ ],
922
+ label="Preset trajectory",
923
+ value="orbit",
924
+ )
925
+ num_frames = gr.Slider(30, 150, 80, label="#Frames")
926
+ zoom_factor = gr.Slider(
927
+ step=0.01, label="Zoom factor", visible=False
928
+ )
929
+ with gr.Row():
930
+ seed = gr.Number(value=23, label="Random seed")
931
+ chunk_strategy.render()
932
+ cfg = gr.Slider(1.0, 7.0, value=4.0, label="CFG value")
933
+ with gr.Row():
934
+ camera_scale = gr.Slider(
935
+ 0.1,
936
+ 15.0,
937
+ value=2.0,
938
+ label="Camera scale",
939
+ )
940
+
941
+ def default_cfg_preset_traj(traj):
942
+ # These are just some hand-tuned values that we
943
+ # found work the best.
944
+ if traj in ["zoom-out", "move-down"]:
945
+ value = 5.0
946
+ elif traj in [
947
+ "orbit",
948
+ "dolly zoom-out",
949
+ "move-backward",
950
+ "move-up",
951
+ "move-left",
952
+ "move-right",
953
+ ]:
954
+ value = 4.0
955
+ else:
956
+ value = 3.0
957
+ return value
958
+
959
+ preset_traj.change(
960
+ default_cfg_preset_traj,
961
+ inputs=[preset_traj],
962
+ outputs=[cfg],
963
+ )
964
+ preset_traj.change(
965
+ lambda traj: gr.update(
966
+ value=(
967
+ 10.0 if "dolly" in traj or "pan" in traj else 2.0
968
+ )
969
+ ),
970
+ inputs=[preset_traj],
971
+ outputs=[camera_scale],
972
+ )
973
+
974
+ def zoom_factor_preset_traj(traj):
975
+ visible = traj in [
976
+ "zoom-in",
977
+ "zoom-out",
978
+ "dolly zoom-in",
979
+ "dolly zoom-out",
980
+ ]
981
+ is_zoomin = traj.endswith("zoom-in")
982
+ if is_zoomin:
983
+ minimum = 0.1
984
+ maximum = 0.5
985
+ value = 0.28
986
+ else:
987
+ minimum = 1.2
988
+ maximum = 3
989
+ value = 1.5
990
+ return gr.update(
991
+ visible=visible,
992
+ minimum=minimum,
993
+ maximum=maximum,
994
+ value=value,
995
+ )
996
+
997
+ preset_traj.change(
998
+ zoom_factor_preset_traj,
999
+ inputs=[preset_traj],
1000
+ outputs=[zoom_factor],
1001
+ )
1002
+ with gr.Column():
1003
+ with gr.Group():
1004
+ abort_btn = gr.Button("Abort rendering", visible=False)
1005
+ render_btn.render()
1006
+ render_progress = gr.Textbox(
1007
+ label="", visible=False, interactive=False
1008
+ )
1009
+ output_video = gr.Video(
1010
+ label="Output", interactive=False, autoplay=True, loop=True
1011
+ )
1012
+ render_btn.click(
1013
+ lambda r, *args: (yield from r.render(*args)),
1014
+ inputs=[
1015
+ renderer,
1016
+ preprocessed,
1017
+ session_hash,
1018
+ seed,
1019
+ chunk_strategy,
1020
+ cfg,
1021
+ preset_traj,
1022
+ num_frames,
1023
+ zoom_factor,
1024
+ camera_scale,
1025
+ ],
1026
+ outputs=[
1027
+ output_video,
1028
+ render_btn,
1029
+ abort_btn,
1030
+ render_progress,
1031
+ ],
1032
+ show_progress_on=[render_progress],
1033
+ concurrency_id="gpu_queue",
1034
+ )
1035
+ render_btn.click(
1036
+ lambda: [
1037
+ gr.update(visible=False),
1038
+ gr.update(visible=True),
1039
+ gr.update(visible=True),
1040
+ ],
1041
+ outputs=[render_btn, abort_btn, render_progress],
1042
+ )
1043
+ abort_btn.click(set_abort_event)
1044
+ with gr.Tab("Advanced"):
1045
+ render_btn = gr.Button("Render video", interactive=False, render=False)
1046
+ viewport = gr.HTML(container=True, render=False)
1047
+ gr.Timer(0.1).tick(
1048
+ lambda renderer: gr.update(
1049
+ interactive=renderer is not None
1050
+ and renderer.gui_state is not None
1051
+ and renderer.gui_state.camera_traj_list is not None
1052
+ ),
1053
+ inputs=[renderer],
1054
+ outputs=[render_btn],
1055
+ )
1056
+ with gr.Row():
1057
+ viewport.render()
1058
+ with gr.Row():
1059
+ with gr.Column():
1060
+ with gr.Group():
1061
+ # Initially disable the Preprocess Images button until images are selected.
1062
+ preprocess_btn = gr.Button("Preprocess images", interactive=False)
1063
+ preprocess_progress = gr.Textbox(
1064
+ label="",
1065
+ visible=False,
1066
+ interactive=False,
1067
+ )
1068
+ with gr.Group():
1069
+ input_imgs = gr.Gallery(
1070
+ interactive=True,
1071
+ label="Input",
1072
+ columns=4,
1073
+ height=200,
1074
+ )
1075
+ # Define example images (gradio doesn't support variable length
1076
+ # examples so we need to hack it).
1077
+ example_imgs = gr.Gallery(
1078
+ [e[0] for e in ADVANCE_EXAMPLE_MAP],
1079
+ allow_preview=False,
1080
+ preview=False,
1081
+ label="Example",
1082
+ columns=20,
1083
+ rows=1,
1084
+ height=115,
1085
+ )
1086
+ example_imgs_expander = gr.Gallery(
1087
+ visible=False,
1088
+ interactive=False,
1089
+ label="Example",
1090
+ preview=True,
1091
+ columns=20,
1092
+ rows=1,
1093
+ )
1094
+ chunk_strategy = gr.Dropdown(
1095
+ ["interp-gt", "interp"],
1096
+ label="Chunk strategy",
1097
+ value="interp-gt",
1098
+ render=False,
1099
+ )
1100
+ with gr.Row():
1101
+ example_imgs_backer = gr.Button(
1102
+ "Go back", visible=False
1103
+ )
1104
+ example_imgs_confirmer = gr.Button(
1105
+ "Confirm", visible=False
1106
+ )
1107
+ example_imgs.select(
1108
+ get_advance_examples,
1109
+ outputs=[
1110
+ example_imgs_expander,
1111
+ example_imgs_confirmer,
1112
+ example_imgs_backer,
1113
+ example_imgs,
1114
+ ],
1115
+ )
1116
+ example_imgs_confirmer.click(
1117
+ lambda x: (
1118
+ x,
1119
+ gr.update(visible=False),
1120
+ gr.update(visible=False),
1121
+ gr.update(visible=False),
1122
+ gr.update(visible=True),
1123
+ gr.update(interactive=bool(x))
1124
+ ),
1125
+ inputs=[example_imgs_expander],
1126
+ outputs=[
1127
+ input_imgs,
1128
+ example_imgs_expander,
1129
+ example_imgs_confirmer,
1130
+ example_imgs_backer,
1131
+ example_imgs,
1132
+ preprocess_btn
1133
+ ],
1134
+ )
1135
+ example_imgs_backer.click(
1136
+ lambda: (
1137
+ gr.update(visible=False),
1138
+ gr.update(visible=False),
1139
+ gr.update(visible=False),
1140
+ gr.update(visible=True),
1141
+ ),
1142
+ outputs=[
1143
+ example_imgs_expander,
1144
+ example_imgs_confirmer,
1145
+ example_imgs_backer,
1146
+ example_imgs,
1147
+ ],
1148
+ )
1149
+ preprocessed = gr.State()
1150
+ preprocess_btn.click(
1151
+ lambda r, *args: r.preprocess(*args),
1152
+ inputs=[renderer, input_imgs],
1153
+ outputs=[
1154
+ preprocessed,
1155
+ preprocess_progress,
1156
+ chunk_strategy,
1157
+ ],
1158
+ show_progress_on=[preprocess_progress],
1159
+ concurrency_id="gpu_queue",
1160
+ )
1161
+ preprocess_btn.click(
1162
+ lambda: gr.update(visible=True),
1163
+ outputs=[preprocess_progress],
1164
+ )
1165
+ preprocessed.change(
1166
+ lambda r, *args: r.visualize_scene(*args),
1167
+ inputs=[renderer, preprocessed],
1168
+ )
1169
+ with gr.Row():
1170
+ seed = gr.Number(value=23, label="Random seed")
1171
+ chunk_strategy.render()
1172
+ cfg = gr.Slider(1.0, 7.0, value=3.0, label="CFG value")
1173
+ with gr.Row():
1174
+ camera_scale = gr.Slider(
1175
+ 0.1,
1176
+ 15.0,
1177
+ value=2.0,
1178
+ label="Camera scale (useful for single-view input)",
1179
+ )
1180
+ with gr.Group():
1181
+ output_data_dir = gr.Textbox(label="Output data directory")
1182
+ output_data_btn = gr.Button("Export output data")
1183
+ output_data_btn.click(
1184
+ lambda r, *args: r.export_output_data(*args),
1185
+ inputs=[renderer, preprocessed, output_data_dir],
1186
+ )
1187
+ with gr.Column():
1188
+ with gr.Group():
1189
+ abort_btn = gr.Button("Abort rendering", visible=False)
1190
+ render_btn.render()
1191
+ render_progress = gr.Textbox(
1192
+ label="", visible=False, interactive=False
1193
+ )
1194
+ output_video = gr.Video(
1195
+ label="Output", interactive=False, autoplay=True, loop=True
1196
+ )
1197
+ render_btn.click(
1198
+ lambda r, *args: (yield from r.render(*args)),
1199
+ inputs=[
1200
+ renderer,
1201
+ preprocessed,
1202
+ session_hash,
1203
+ seed,
1204
+ chunk_strategy,
1205
+ cfg,
1206
+ gr.State(),
1207
+ gr.State(),
1208
+ gr.State(),
1209
+ camera_scale,
1210
+ ],
1211
+ outputs=[
1212
+ output_video,
1213
+ render_btn,
1214
+ abort_btn,
1215
+ render_progress,
1216
+ ],
1217
+ show_progress_on=[render_progress],
1218
+ concurrency_id="gpu_queue",
1219
+ )
1220
+ render_btn.click(
1221
+ lambda: [
1222
+ gr.update(visible=False),
1223
+ gr.update(visible=True),
1224
+ gr.update(visible=True),
1225
+ ],
1226
+ outputs=[render_btn, abort_btn, render_progress],
1227
+ )
1228
+ abort_btn.click(set_abort_event)
1229
+
1230
+ # Register the session initialization and cleanup functions.
1231
+ app.load(
1232
+ start_server_and_abort_event,
1233
+ outputs=[renderer, viewport, session_hash],
1234
+ )
1235
+ app.unload(stop_server_and_abort_event)
1236
+
1237
+ app.queue(max_size=5).launch(
1238
+ share=share,
1239
+ server_port=server_port,
1240
+ show_error=True,
1241
+ allowed_paths=[WORK_DIR],
1242
+ # Badge rendering will be broken otherwise.
1243
+ ssr_mode=False,
1244
+ )
1245
+
1246
+
1247
+ if __name__ == "__main__":
1248
+ tyro.cli(main)
docs/CLI_USAGE.md ADDED
@@ -0,0 +1,169 @@
1
+ # :computer: CLI Demo
2
+
3
+ This CLI demo allows you to pass in more options and control the model in a fine-grained way, which makes it suitable for power users and academic researchers. An example command line is as simple as
4
+
5
+ ```bash
6
+ python demo.py --data_path <data_path> [additional arguments]
7
+ ```
8
+
9
+ We first discuss some key attributes:
10
+
11
+ - `Procedural Two-Pass Sampling`: We recommend enabling procedural sampling by setting `--use_traj_prior True --chunk_strategy <chunk_strategy>` with `<chunk_strategy>` set according to the type of the task.
12
+ - `Resolution and Aspect-Ratio`: Default image preprocessing includes center cropping, so all inputs and outputs are square images of size $576\times 576$. To override this, the code supports passing in `--W <W> --H <H>` directly. We recommend passing in `--L_short 576` so that the aspect ratio of the original image is kept while the shorter side is resized to $576$ (see the sketch below).
13
+
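+ As a concrete illustration of `--L_short`, below is a minimal sketch (not the code's exact implementation) of the output resolution for a given input size; snapping each side to a multiple of 64 mirrors the `size_stride=64` used in preprocessing, but the exact rounding rule here is an assumption.
+
+ ```python
+ # Minimal sketch: resolution produced by `--L_short 576` for an (H, W) input.
+ def target_resolution(h: int, w: int, l_short: int = 576, stride: int = 64) -> tuple[int, int]:
+     scale = l_short / min(h, w)  # keep aspect ratio, shorter side -> l_short
+     new_h, new_w = round(h * scale), round(w * scale)
+     # Snap both sides to the model's stride (assumed rounding-down behavior).
+     new_h = max(stride, (new_h // stride) * stride)
+     new_w = max(stride, (new_w // stride) * stride)
+     return new_h, new_w
+
+ print(target_resolution(1080, 1920))  # (576, 1024)
+ ```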
14
+ ## Task
15
+
16
+ Before diving into the command lines, we introduce `Task` (specified by `--task <task>`) to bucket different use cases depending on the data constraints in the input and output domains (e.g., whether the ordering is available).
17
+
18
+ | Task | Type of NVS | Format of `<data_path>` | Target Views Sorted? | Input and Target Views Sorted? | Recommended Usage |
19
+ | :------------------: | :------------: | :--------------------------------------: | :------------------: | :----------------------------: | :----------------------: |
20
+ | `img2img` | set NVS | folder (parsable by `ReconfusionParser`) | :x: | :x: | evaluation, benchmarking |
21
+ | `img2vid` | trajectory NVS | folder (parsable by `ReconfusionParser`) | :white_check_mark: | :white_check_mark: | evaluation, benchmarking |
22
+ | `img2trajvid_s-prob` | trajectory NVS | single image | :white_check_mark: | :white_check_mark: | general |
23
+ | `img2trajvid` | trajectory NVS | folder (parsable by `ReconfusionParser`) | :white_check_mark: | :x: | general |
24
+
25
+ ### Format of `<data_path>`
26
+
27
+ For the `img2trajvid_s-prob` task, we generate a trajectory video following preset camera motions or effects given only one input image, so the data format is as simple as
28
+
29
+ ```bash
30
+ <data_path>/
31
+ ├── scene_1.png
32
+ ├── scene_2.png
33
+ └── scene_3.png
34
+ ```
35
+
36
+ For all the other tasks, we use a folder for each scene that is parsable by `ReconfusionParser` (see `seva/data_io.py`). It contains (1) a subdirectory containing all views; (2) `transforms.json` defining the intrinsics and extrinsics (OpenGL convention) for each image; and (3) a `train_test_split_*.json` file splitting the input and target views, with `*` indicating the number of input views (a minimal sketch for generating such a split file is shown after the directory listing below).
37
+
38
+ We provide several example scenes <a href="https://github.com/Stability-AI/stable-virtual-camera/releases/tag/assets_demo_cli">in this release</a> (`assets_demo_cli.zip`) for reference. Target views are available when the data come from academic sources; when target views are unavailable, we create dummy black images as placeholders (e.g., the `garden_flythrough` scene). The general data structure follows
39
+
40
+ ```bash
41
+ <data_path>/
42
+ ├── scene_1/
43
+ ├── train_test_split_1.json # for single-view regime
44
+ ├── train_test_split_6.json # for sparse-view regime
45
+ ├── train_test_split_32.json # for semi-dense-view regime
46
+ ├── transforms.json
47
+ └── images/
48
+ ├── image_0.png
49
+ ├── image_1.png
50
+ ├── ...
51
+ └── image_1000.png
52
+ ├── scene_2
53
+ └── scene_3
54
+ ```
55
+
56
+ You can specify which scene to run by passing in `--data_items scene_1,scene_2` to run, for example, `scene_1` and `scene_2`.
57
+
58
+ ### Recommended Usage
59
+
60
+ - `img2img` and `img2vid` are recommended for evaluation and benchmarking. These two tasks are used for the quantitative evaluation in our <a href="http://arxiv.org/abs/2503.14489">paper</a>. The data is converted from academic datasets, so the ground-truth target views are available for metric computation. Check the [`benchmark`](../benchmark/) folder for the detailed splits we organize to benchmark different NVS models.
61
+ - `img2vid` requires both the input and target views to be sorted, which is usually not guaranteed in general usage.
62
+ - `img2trajvid_s-prob` is for general usage but only for single-view regime and fixed preset camera control.
63
+ - `img2trajvid` is the task designed for general usage since it does not require the input views to be ordered. This is the task used in the gradio demo.
64
+
65
+ Next, we go over all tasks and provide an example command line for each.
66
+
67
+ ## `img2img`
68
+
69
+ ```bash
70
+ python demo.py \
71
+ --data_path <data_path> \
72
+ --num_inputs <P> \
73
+ --video_save_fps 10
74
+ ```
75
+
76
+ - `--num_inputs <P>` is only necessary if there are multiple `train_test_split_*.json` files in the scene folder.
77
+ - The above command works for datasets without a trajectory prior (e.g., DL3DV-140). When a trajectory prior is available for a benchmarking dataset, for example, the `orbit` trajectory prior for the CO3D dataset, we use the `nearest-gt` chunking strategy by setting `--use_traj_prior True --traj_prior orbit --chunk_strategy nearest-gt`. We find this leads to more 3D-consistent results.
78
+ - For all single-view conditioning test scenarios, we set `--camera_scale <camera_scale>` with `<camera_scale>` sweeping 20 different camera scales `0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9 2.0` (see the sketch after this list).
79
+ - In the single-view regime for the RealEstate10K dataset, we find increasing `cfg` helpful: we additionally set `--cfg 6.0` (`cfg` is `2.0` by default).
80
+ - For the evaluation in the semi-dense-view regime (i.e., the DL3DV-140 and Tanks and Temples datasets) with `32` input views, we zero-shot extend `T` to fit all input and target views in one forward pass. Specifically, we set `--T 90` for the DL3DV-140 dataset and `--T 80` for the Tanks and Temples dataset.
81
+ - For the evaluation on the ViewCrafter split (including the RealEstate10K, CO3D, and Tanks and Temples datasets), we find that zero-shot extending `T` to `25` to fit all input and target views in one forward pass works better. Also, this split uses the original image resolutions: we therefore set `--T 25 --L_short 576`.
82
+
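+ A minimal sketch of the camera-scale sweep mentioned above (invoking `demo.py` once per scale; the data path is a placeholder):
+
+ ```python
+ import subprocess
+
+ for camera_scale in [round(0.1 * i, 1) for i in range(1, 21)]:  # 0.1 ... 2.0
+     subprocess.run(
+         [
+             "python", "demo.py",
+             "--data_path", "<data_path>",
+             "--num_inputs", "1",
+             "--camera_scale", str(camera_scale),
+             "--video_save_fps", "10",
+         ],
+         check=True,
+     )
+ ```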
83
+ For example, you can run the following command on the example `dl3d140-165f5af8bfe32f70595a1c9393a6e442acf7af019998275144f605b89a306557` with 3 input views:
84
+
85
+ ```bash
86
+ python demo.py \
87
+ --data_path /path/to/assets_demo_cli/ \
88
+ --data_items dl3d140-165f5af8bfe32f70595a1c9393a6e442acf7af019998275144f605b89a306557 \
89
+ --num_inputs 3 \
90
+ --video_save_fps 10
91
+ ```
92
+
93
+ ## `img2vid`
94
+
95
+ ```bash
96
+ python demo.py \
97
+ --data_path <data_path> \
98
+ --task img2vid \
99
+ --replace_or_include_input True \
100
+ --num_inputs <P> \
101
+ --use_traj_prior True \
102
+ --chunk_strategy interp \
103
+ ```
104
+
105
+ - `--replace_or_include_input True` is necessary here since input views and target views are mutually exclusive and together form a trajectory in this task, so we need to append the input views back to the generated target views.
106
+ - `--num_inputs <P>` is only necessary if there are multiple `train_test_split_*.json` files in the scene folder.
107
+ - We use `interp` chunking strategy by default.
108
+ - For the evaluation on the ViewCrafter split (including the RealEstate10K, CO3D, and Tanks and Temples datasets), we find that zero-shot extending `T` to `25` to fit all input and target views in one forward pass works better. Also, this split uses the original image resolutions: we therefore set `--T 25 --L_short 576`.
109
+
110
+ ## `img2trajvid_s-prob`
111
+
112
+ ```bash
113
+ python demo.py \
114
+ --data_path <data_path> \
115
+ --task img2trajvid_s-prob \
116
+ --replace_or_include_input True \
117
+ --traj_prior orbit \
118
+ --cfg 4.0,2.0 \
119
+ --guider 1,2 \
120
+ --num_targets 111 \
121
+ --L_short 576 \
122
+ --use_traj_prior True \
123
+ --chunk_strategy interp
124
+ ```
125
+
126
+ - `--replace_or_include_input True` is necessary here since input views and target views are mutually exclusive and together form a trajectory in this task, so we need to append the input views back to the generated target views.
127
+ - Default `cfg` should be adjusted according to `traj_prior`.
128
+ - Default chunking strategy is `interp`.
129
+ - Default guider is `--guider 1,2` (instead of `1`, `1` still works but `1,2` is slightly better).
130
+ - `camera_scale` (default is `2.0`) can be adjusted according to `traj_prior`. The model has scale ambiguity with single-view input, especially for panning motions. We encourage tuning `camera_scale` up to `10.0` for all panning motions (`--traj_prior pan-*/dolly*`) if you expect a larger camera motion.
131
+
132
+ ## `img2trajvid`
133
+
134
+ ### Sparse-view regime ($P\leq 8$)
135
+
136
+ ```bash
137
+ python demo.py \
138
+ --data_path <data_path> \
139
+ --task img2trajvid \
140
+ --num_inputs <P> \
141
+ --cfg 3.0,2.0 \
142
+ --use_traj_prior True \
143
+ --chunk_strategy interp-gt
144
+ ```
145
+
146
+ - `--num_inputs <P>` is only necessary if there are multiple `train_test_split_*.json` files in the scene folder.
147
+ - Default `cfg` should be set to `3,2` (`3` being the `cfg` for the first pass and `2` the `cfg` for the second pass). Try increasing the first-pass `cfg` above `3` if you observe blurry areas (which usually happens for harder scenes with a fair amount of unseen regions).
148
+ - Default chunking strategy should be set to `interp-gt` (instead of `interp`; `interp` can work but is usually a bit worse).
149
+ - `--chunk_strategy_first_pass` is set to `gt-nearest` by default, so it can automatically adapt when $P$ is large (up to a thousand frames).
150
+
151
+ ### Semi-dense-view regime ($P\geq 9$)
152
+
153
+ ```bash
154
+ python demo.py \
155
+ --data_path <data_path> \
156
+ --task img2trajvid \
157
+ --num_inputs <P> \
158
+ --cfg 3.0 \
159
+ --L_short 576 \
160
+ --use_traj_prior True \
161
+ --chunk_strategy interp
162
+ ```
163
+
164
+ - `--num_inputs <P>` is only necessary if there are multiple `train_test_split_*.json` files in the scene folder.
165
+ - Default `cfg` should be set to `3`.
166
+ - Default chunking strategy should be set to `interp` (instead of `interp-gt`, `interp-gt` is also supported but the results do not look good).
167
+ - `T` can be overwritten by `--T <N>,21` (`<N>` being the extended `T` for the first pass, and `21` being the default `T` for the second pass). `<N>` is decided dynamically in the code but can also be set manually. This is useful when you observe two very dissimilar adjacent anchors that make the interpolation in the second pass impossible. There are two ways:
168
+ - `--T 96,21`: this overwrites the `T` in the first pass to be exactly `96`.
169
+ - `--num_prior_frames_ratio 1.2`: this dynamically enlarges the first-pass `T` by a factor of `1.2` (for example, a first-pass `T` that would have been `80` becomes `96`, matching `--T 96,21` above).
docs/GR_USAGE.md ADDED
@@ -0,0 +1,76 @@
1
+ # :rocket: Gradio Demo
2
+
3
+ This gradio demo is the simplest starting point for you to play with our project.
4
+
5
+ You can either visit it at our Hugging Face space [here](https://huggingface.co/spaces/stabilityai/stable-virtual-camera) or run it locally yourself with
6
+
7
+ ```bash
8
+ python demo_gr.py
9
+ ```
10
+
11
+ We provide two ways to use our demo:
12
+
13
+ 1. `Basic` mode, where the user can upload a single image and set a target camera trajectory from our preset options. This is the most straightforward way to use our model, and is suitable for most users.
14
+ 2. `Advanced` mode, where the user can upload one or multiple images and set a target camera trajectory by interacting with a 3D viewport (powered by [viser](https://viser.studio/latest)). This is suitable for power users and academic researchers.
15
+
16
+ ### `Basic`
17
+
18
+ This is the default mode when entering our demo (given its simplicity).
19
+
20
+ The user can upload a single image and set a target camera trajectory from our preset options. This is the most straightforward way to use our model and is suitable for most users.
21
+
22
+ Here is a video walkthrough:
23
+
24
+ https://github.com/user-attachments/assets/4d965fa6-d8eb-452c-b773-6e09c88ca705
25
+
26
+ You can choose from 13 preset trajectories that are common for NVS (`move-forward/backward` are omitted for visualization purposes):
27
+
28
+ https://github.com/user-attachments/assets/b2cf8700-3d85-44b9-8d52-248e82f1fb55
29
+
30
+ More formally:
31
+
32
+ - `orbit/spiral/lemniscate` are good for showing the "3D-ness" of the scene.
33
+ - `zoom-in/out` keep the camera position the same while increasing/decreasing the focal length.
34
+ - `dolly zoom-in/out` move the camera backward/forward while increasing/decreasing the focal length.
35
+ - `move-forward/backward/up/down/left/right` translate the camera in the corresponding directions (see the sketch after this list).
36
+
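+ Below is a minimal sketch (not the demo's exact implementation) contrasting these trajectory families, assuming a pinhole camera with a 4x4 camera-to-world pose whose third column is the viewing direction and a normalized focal length `f`.
+
+ ```python
+ import numpy as np
+
+ def zoom(f: float, zoom_factor: float) -> float:
+     # Zooming changes only the focal length (i.e., the field of view);
+     # the camera position stays fixed. The factor convention is hypothetical.
+     return f * zoom_factor
+
+ def move_forward(c2w: np.ndarray, step: float) -> np.ndarray:
+     # Moving (and the translation part of a dolly zoom) shifts the camera
+     # along its forward axis, leaving the focal length untouched.
+     out = c2w.copy()
+     out[:3, 3] += step * out[:3, 2]  # assumes +z is the forward axis
+     return out
+
+ # A dolly zoom combines both: translate backward/forward while increasing/
+ # decreasing the focal length.
+ ```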
37
+ Notes:
38
+
39
+ - For an 80-frame video at `768x576` resolution, it takes around 20 seconds for the first-pass generation and around 2 minutes for the second-pass generation, tested with a single H100 GPU.
40
+ - Please expect around 2-3x longer on the HF space.
41
+
42
+ ### `Advanced`
43
+
44
+ This is the power mode where you can have very fine-grained control over camera trajectories.
45
+
46
+ The user can upload one or multiple images and set a target camera trajectory by interacting with a 3D viewport. This is suitable for power users and academic researchers.
47
+
48
+ Here is a video walkthrough:
49
+
50
+ https://github.com/user-attachments/assets/dcec1be0-bd10-441e-879c-d1c2b63091ba
51
+
52
+ Notes:
53
+
54
+ - For a 134-frame video at `576x576` resolution, it takes around 16 seconds for the first-pass generation and around 4 minutes for the second-pass generation, tested with a single H100 GPU.
55
+ - Please expect around 2-3x longer on the HF space.
56
+
57
+ ### Pro tips
58
+
59
+ - If the first-pass sampling result is bad, click the "Abort rendering" button in the GUI to avoid getting stuck in the second-pass sampling, so that you can try something else.
60
+
61
+ ### Performance benchmark
62
+
63
+ We have tested our gradio demo in both a local environment and the HF space environment, across different modes and compilation settings. Here are our results:
64
+ | Total time (s) | `Basic` first pass | `Basic` second pass | `Advanced` first pass | `Advanced` second pass |
65
+ |:------------------------:|:-----------------:|:------------------:|:--------------------:|:---------------------:|
66
+ | HF (L40S, w/o comp.) | 68 | 484 | 48 | 780 |
67
+ | HF (L40S, w/ comp.) | 51 | 362 | 36 | 587 |
68
+ | Local (H100, w/o comp.) | 35 | 204 | 20 | 313 |
69
+ | Local (H100, w/ comp.) | 21 | 144 | 16 | 234 |
70
+
71
+ Notes:
72
+
73
+ - The HF space uses an L40S GPU, and our local environment uses an H100 GPU.
74
+ - We opt into compilation via `torch.compile`.
75
+ - `Basic` mode is tested by generating 80 frames at `768x576` resolution.
76
+ - `Advanced` mode is tested by generating 134 frames at `576x576` resolution.
docs/INSTALL.md ADDED
@@ -0,0 +1,39 @@
1
+ # :wrench: Installation
2
+
3
+ ### Model Dependencies
4
+
5
+ ```bash
6
+ # Install seva model dependencies.
7
+ pip install -e .
8
+ ```
9
+
10
+ ### Demo Dependencies
11
+
12
+ To use the cli demo (`demo.py`) or the gradio demo (`demo_gr.py`), do the following:
13
+
14
+ ```bash
15
+ # Initialize and update submodules for demo.
16
+ git submodule update --init --recursive
17
+
18
+ # Install pycolmap dependencies for cli and gradio demo (our model is not dependent on it).
19
+ echo "Installing pycolmap (for both cli and gradio demo)..."
20
+ pip install git+https://github.com/jensenz-sai/pycolmap@543266bc316df2fe407b3a33d454b310b1641042
21
+
22
+ # Install dust3r dependencies for gradio demo (our model is not dependent on it).
23
+ echo "Installing dust3r dependencies (only for gradio demo)..."
24
+ pushd third_party/dust3r
25
+ pip install -r requirements.txt
26
+ popd
27
+ ```
28
+
29
+ ### Dev and Speeding Up (Optional)
30
+
31
+ ```bash
32
+ # [OPTIONAL] Install seva dependencies for development.
33
+ pip install -e ".[dev]"
34
+ pre-commit install
35
+
36
+ # [OPTIONAL] Install the torch nightly version for faster JIT via torch.compile (speeds up sampling by 2x in our testing).
37
+ # Please adjust to your own CUDA version. For example, if you have CUDA 11.8, use the following command.
38
+ pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
39
+ ```
pyproject.toml ADDED
@@ -0,0 +1,39 @@
1
+ [build-system]
2
+ requires = ["setuptools>=65.5.3"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "seva"
7
+ version = "0.0.0"
8
+ requires-python = ">=3.10"
9
+ dependencies = [
10
+ "torch>=2.6.0",
11
+ "roma",
12
+ "viser",
13
+ "tyro",
14
+ "fire",
15
+ "ninja",
16
+ "gradio==5.17.0",
17
+ "einops",
18
+ "colorama",
19
+ "splines",
20
+ "kornia",
21
+ "open-clip-torch",
22
+ "diffusers",
23
+ "numpy==1.24.4",
24
+ "imageio[ffmpeg]",
25
+ "huggingface-hub",
26
+ "opencv-python",
27
+ ]
28
+
29
+ [project.optional-dependencies]
30
+ dev = ["ruff", "ipdb", "pytest", "line_profiler", "pre-commit"]
31
+
32
+ [tool.setuptools.packages.find]
33
+ include = ["seva"]
34
+
35
+ [tool.pyright]
36
+ extraPaths = ["third_party/dust3r"]
37
+
38
+ [tool.ruff]
39
+ lint.ignore = ["E741"]
seva/__init__.py ADDED
File without changes
seva/data_io.py ADDED
@@ -0,0 +1,553 @@
1
+ import json
2
+ import os
3
+ import os.path as osp
4
+ from glob import glob
5
+ from typing import Any, Dict, List, Optional, Tuple
6
+
7
+ import cv2
8
+ import imageio.v3 as iio
9
+ import numpy as np
10
+ import torch
11
+
12
+ from seva.geometry import (
13
+ align_principle_axes,
14
+ similarity_from_cameras,
15
+ transform_cameras,
16
+ transform_points,
17
+ )
18
+
19
+
20
+ def _get_rel_paths(path_dir: str) -> List[str]:
21
+ """Recursively get relative paths of files in a directory."""
22
+ paths = []
23
+ for dp, _, fn in os.walk(path_dir):
24
+ for f in fn:
25
+ paths.append(os.path.relpath(os.path.join(dp, f), path_dir))
26
+ return paths
27
+
28
+
29
+ class BaseParser(object):
30
+ def __init__(
31
+ self,
32
+ data_dir: str,
33
+ factor: int = 1,
34
+ normalize: bool = False,
35
+ test_every: Optional[int] = 8,
36
+ ):
37
+ self.data_dir = data_dir
38
+ self.factor = factor
39
+ self.normalize = normalize
40
+ self.test_every = test_every
41
+
42
+ self.image_names: List[str] = [] # (num_images,)
43
+ self.image_paths: List[str] = [] # (num_images,)
44
+ self.camtoworlds: np.ndarray = np.zeros((0, 4, 4)) # (num_images, 4, 4)
45
+ self.camera_ids: List[int] = [] # (num_images,)
46
+ self.Ks_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> K
47
+ self.params_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> params
48
+ self.imsize_dict: Dict[
49
+ int, Tuple[int, int]
50
+ ] = {} # Dict of camera_id -> (width, height)
51
+ self.points: np.ndarray = np.zeros((0, 3)) # (num_points, 3)
52
+ self.points_err: np.ndarray = np.zeros((0,)) # (num_points,)
53
+ self.points_rgb: np.ndarray = np.zeros((0, 3)) # (num_points, 3)
54
+ self.point_indices: Dict[str, np.ndarray] = {} # Dict of image_name -> (M,)
55
+ self.transform: np.ndarray = np.zeros((4, 4)) # (4, 4)
56
+
57
+ self.mapx_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> (H, W)
58
+ self.mapy_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> (H, W)
59
+ self.roi_undist_dict: Dict[int, Tuple[int, int, int, int]] = (
60
+ dict()
61
+ ) # Dict of camera_id -> (x, y, w, h)
62
+ self.scene_scale: float = 1.0
63
+
64
+
65
+ class DirectParser(BaseParser):
66
+ def __init__(
67
+ self,
68
+ imgs: List[np.ndarray],
69
+ c2ws: np.ndarray,
70
+ Ks: np.ndarray,
71
+ points: Optional[np.ndarray] = None,
72
+ points_rgb: Optional[np.ndarray] = None, # uint8
73
+ mono_disps: Optional[List[np.ndarray]] = None,
74
+ normalize: bool = False,
75
+ test_every: Optional[int] = None,
76
+ ):
77
+ super().__init__("", 1, normalize, test_every)
78
+
79
+ self.image_names = [f"{i:06d}" for i in range(len(imgs))]
80
+ self.image_paths = ["null" for _ in range(len(imgs))]
81
+ self.camtoworlds = c2ws
82
+ self.camera_ids = [i for i in range(len(imgs))]
83
+ self.Ks_dict = {i: K for i, K in enumerate(Ks)}
84
+ self.imsize_dict = {
85
+ i: (img.shape[1], img.shape[0]) for i, img in enumerate(imgs)
86
+ }
87
+ if points is not None:
88
+ self.points = points
89
+ assert points_rgb is not None
90
+ self.points_rgb = points_rgb
91
+ self.points_err = np.zeros((len(points),))
92
+
93
+ self.imgs = imgs
94
+ self.mono_disps = mono_disps
95
+
96
+ # Normalize the world space.
97
+ if normalize:
98
+ T1 = similarity_from_cameras(self.camtoworlds)
99
+ self.camtoworlds = transform_cameras(T1, self.camtoworlds)
100
+
101
+ if points is not None:
102
+ self.points = transform_points(T1, self.points)
103
+ T2 = align_principle_axes(self.points)
104
+ self.camtoworlds = transform_cameras(T2, self.camtoworlds)
105
+ self.points = transform_points(T2, self.points)
106
+ else:
107
+ T2 = np.eye(4)
108
+
109
+ self.transform = T2 @ T1
110
+ else:
111
+ self.transform = np.eye(4)
112
+
113
+ # size of the scene measured by cameras
114
+ camera_locations = self.camtoworlds[:, :3, 3]
115
+ scene_center = np.mean(camera_locations, axis=0)
116
+ dists = np.linalg.norm(camera_locations - scene_center, axis=1)
117
+ self.scene_scale = np.max(dists)
118
+
119
+
120
+ class COLMAPParser(BaseParser):
121
+ """COLMAP parser."""
122
+
123
+ def __init__(
124
+ self,
125
+ data_dir: str,
126
+ factor: int = 1,
127
+ normalize: bool = False,
128
+ test_every: Optional[int] = 8,
129
+ image_folder: str = "images",
130
+ colmap_folder: str = "sparse/0",
131
+ ):
132
+ super().__init__(data_dir, factor, normalize, test_every)
133
+
134
+ colmap_dir = os.path.join(data_dir, colmap_folder)
135
+ assert os.path.exists(
136
+ colmap_dir
137
+ ), f"COLMAP directory {colmap_dir} does not exist."
138
+
139
+ try:
140
+ from pycolmap import SceneManager
141
+ except ImportError:
142
+ raise ImportError(
143
+ "Please install pycolmap to use the data parsers: "
144
+ " `pip install git+https://github.com/jensenz-sai/pycolmap.git@543266bc316df2fe407b3a33d454b310b1641042`"
145
+ )
146
+
147
+ manager = SceneManager(colmap_dir)
148
+ manager.load_cameras()
149
+ manager.load_images()
150
+ manager.load_points3D()
151
+
152
+ # Extract extrinsic matrices in world-to-camera format.
153
+ imdata = manager.images
154
+ w2c_mats = []
155
+ camera_ids = []
156
+ Ks_dict = dict()
157
+ params_dict = dict()
158
+ imsize_dict = dict() # width, height
159
+ bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
160
+ for k in imdata:
161
+ im = imdata[k]
162
+ rot = im.R()
163
+ trans = im.tvec.reshape(3, 1)
164
+ w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0)
165
+ w2c_mats.append(w2c)
166
+
167
+ # support different camera intrinsics
168
+ camera_id = im.camera_id
169
+ camera_ids.append(camera_id)
170
+
171
+ # camera intrinsics
172
+ cam = manager.cameras[camera_id]
173
+ fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy
174
+ K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
175
+ K[:2, :] /= factor
176
+ Ks_dict[camera_id] = K
177
+
178
+ # Get distortion parameters.
179
+ type_ = cam.camera_type
180
+ if type_ == 0 or type_ == "SIMPLE_PINHOLE":
181
+ params = np.empty(0, dtype=np.float32)
182
+ camtype = "perspective"
183
+ elif type_ == 1 or type_ == "PINHOLE":
184
+ params = np.empty(0, dtype=np.float32)
185
+ camtype = "perspective"
186
+ elif type_ == 2 or type_ == "SIMPLE_RADIAL":
187
+ params = np.array([cam.k1, 0.0, 0.0, 0.0], dtype=np.float32)
188
+ camtype = "perspective"
189
+ elif type_ == 3 or type_ == "RADIAL":
190
+ params = np.array([cam.k1, cam.k2, 0.0, 0.0], dtype=np.float32)
191
+ camtype = "perspective"
192
+ elif type_ == 4 or type_ == "OPENCV":
193
+ params = np.array([cam.k1, cam.k2, cam.p1, cam.p2], dtype=np.float32)
194
+ camtype = "perspective"
195
+ elif type_ == 5 or type_ == "OPENCV_FISHEYE":
196
+ params = np.array([cam.k1, cam.k2, cam.k3, cam.k4], dtype=np.float32)
197
+ camtype = "fisheye"
198
+ assert (
199
+ camtype == "perspective" # type: ignore
200
+ ), f"Only support perspective camera model, got {type_}"
201
+
202
+ params_dict[camera_id] = params # type: ignore
203
+
204
+ # image size
205
+ imsize_dict[camera_id] = (cam.width // factor, cam.height // factor)
206
+
207
+ print(
208
+ f"[Parser] {len(imdata)} images, taken by {len(set(camera_ids))} cameras."
209
+ )
210
+
211
+ if len(imdata) == 0:
212
+ raise ValueError("No images found in COLMAP.")
213
+ if not (type_ == 0 or type_ == 1): # type: ignore
214
+ print("Warning: COLMAP Camera is not PINHOLE. Images have distortion.")
215
+
216
+ w2c_mats = np.stack(w2c_mats, axis=0)
217
+
218
+ # Convert extrinsics to camera-to-world.
219
+ camtoworlds = np.linalg.inv(w2c_mats)
220
+
221
+ # Image names from COLMAP. No need for permuting the poses according to
222
+ # image names anymore.
223
+ image_names = [imdata[k].name for k in imdata]
224
+
225
+ # Previous NeRF results were generated with images sorted by filename,
226
+ # ensure metrics are reported on the same test set.
227
+ inds = np.argsort(image_names)
228
+ image_names = [image_names[i] for i in inds]
229
+ camtoworlds = camtoworlds[inds]
230
+ camera_ids = [camera_ids[i] for i in inds]
231
+
232
+ # Load images.
233
+ if factor > 1:
234
+ image_dir_suffix = f"_{factor}"
235
+ else:
236
+ image_dir_suffix = ""
237
+ colmap_image_dir = os.path.join(data_dir, image_folder)
238
+ image_dir = os.path.join(data_dir, image_folder + image_dir_suffix)
239
+ for d in [image_dir, colmap_image_dir]:
240
+ if not os.path.exists(d):
241
+ raise ValueError(f"Image folder {d} does not exist.")
242
+
243
+ # Downsampled images may have different names vs images used for COLMAP,
244
+ # so we need to map between the two sorted lists of files.
245
+ colmap_files = sorted(_get_rel_paths(colmap_image_dir))
246
+ image_files = sorted(_get_rel_paths(image_dir))
247
+ colmap_to_image = dict(zip(colmap_files, image_files))
248
+ image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names]
249
+
250
+ # 3D points and {image_name -> [point_idx]}
251
+ points = manager.points3D.astype(np.float32) # type: ignore
252
+ points_err = manager.point3D_errors.astype(np.float32) # type: ignore
253
+ points_rgb = manager.point3D_colors.astype(np.uint8) # type: ignore
254
+ point_indices = dict()
255
+
256
+ image_id_to_name = {v: k for k, v in manager.name_to_image_id.items()}
257
+ for point_id, data in manager.point3D_id_to_images.items():
258
+ for image_id, _ in data:
259
+ image_name = image_id_to_name[image_id]
260
+ point_idx = manager.point3D_id_to_point3D_idx[point_id]
261
+ point_indices.setdefault(image_name, []).append(point_idx)
262
+ point_indices = {
263
+ k: np.array(v).astype(np.int32) for k, v in point_indices.items()
264
+ }
265
+
266
+ # Normalize the world space.
267
+ if normalize:
268
+ T1 = similarity_from_cameras(camtoworlds)
269
+ camtoworlds = transform_cameras(T1, camtoworlds)
270
+ points = transform_points(T1, points)
271
+
272
+ T2 = align_principle_axes(points)
273
+ camtoworlds = transform_cameras(T2, camtoworlds)
274
+ points = transform_points(T2, points)
275
+
276
+ transform = T2 @ T1
277
+ else:
278
+ transform = np.eye(4)
279
+
280
+ self.image_names = image_names # List[str], (num_images,)
281
+ self.image_paths = image_paths # List[str], (num_images,)
282
+ self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4)
283
+ self.camera_ids = camera_ids # List[int], (num_images,)
284
+ self.Ks_dict = Ks_dict # Dict of camera_id -> K
285
+ self.params_dict = params_dict # Dict of camera_id -> params
286
+ self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height)
287
+ self.points = points # np.ndarray, (num_points, 3)
288
+ self.points_err = points_err # np.ndarray, (num_points,)
289
+ self.points_rgb = points_rgb # np.ndarray, (num_points, 3)
290
+ self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,]
291
+ self.transform = transform # np.ndarray, (4, 4)
292
+
293
+ # undistortion
294
+ self.mapx_dict = dict()
295
+ self.mapy_dict = dict()
296
+ self.roi_undist_dict = dict()
297
+ for camera_id in self.params_dict.keys():
298
+ params = self.params_dict[camera_id]
299
+ if len(params) == 0:
300
+ continue # no distortion
301
+ assert camera_id in self.Ks_dict, f"Missing K for camera {camera_id}"
302
+ assert (
303
+ camera_id in self.params_dict
304
+ ), f"Missing params for camera {camera_id}"
305
+ K = self.Ks_dict[camera_id]
306
+ width, height = self.imsize_dict[camera_id]
307
+ K_undist, roi_undist = cv2.getOptimalNewCameraMatrix(
308
+ K, params, (width, height), 0
309
+ )
310
+ mapx, mapy = cv2.initUndistortRectifyMap(
311
+ K,
312
+ params,
313
+ None,
314
+ K_undist,
315
+ (width, height),
316
+ cv2.CV_32FC1, # type: ignore
317
+ )
318
+ self.Ks_dict[camera_id] = K_undist
319
+ self.mapx_dict[camera_id] = mapx
320
+ self.mapy_dict[camera_id] = mapy
321
+ self.roi_undist_dict[camera_id] = roi_undist # type: ignore
322
+
323
+ # size of the scene measured by cameras
324
+ camera_locations = camtoworlds[:, :3, 3]
325
+ scene_center = np.mean(camera_locations, axis=0)
326
+ dists = np.linalg.norm(camera_locations - scene_center, axis=1)
327
+ self.scene_scale = np.max(dists)
328
+
329
+
330
+ class ReconfusionParser(BaseParser):
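+ """Parser for the ReconFusion-style evaluation layout: reads `transforms.json`
+ for cameras/intrinsics and `train_test_split_*.json` for the per-#-input-view
+ train/test splits."""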
331
+ def __init__(self, data_dir: str, normalize: bool = False):
332
+ super().__init__(data_dir, 1, normalize, test_every=None)
333
+
334
+ def get_num(p):
335
+ return p.split("_")[-1].removesuffix(".json")
336
+
337
+ splits_per_num_input_frames = {}
338
+ num_input_frames = [
339
+ int(get_num(p)) if get_num(p).isdigit() else get_num(p)
340
+ for p in sorted(glob(osp.join(data_dir, "train_test_split_*.json")))
341
+ ]
342
+ for num_input_frames in num_input_frames:
343
+ with open(
344
+ osp.join(
345
+ data_dir,
346
+ f"train_test_split_{num_input_frames}.json",
347
+ )
348
+ ) as f:
349
+ splits_per_num_input_frames[num_input_frames] = json.load(f)
350
+ self.splits_per_num_input_frames = splits_per_num_input_frames
351
+
352
+ with open(osp.join(data_dir, "transforms.json")) as f:
353
+ metadata = json.load(f)
354
+
355
+ image_names, image_paths, camtoworlds = [], [], []
356
+ for frame in metadata["frames"]:
357
+ if frame["file_path"] is None:
358
+ image_path = image_name = None
359
+ else:
360
+ image_path = osp.join(data_dir, frame["file_path"])
361
+ image_name = osp.basename(image_path)
362
+ image_paths.append(image_path)
363
+ image_names.append(image_name)
364
+ camtoworld = np.array(frame["transform_matrix"])
365
+ if "applied_transform" in metadata:
366
+ applied_transform = np.concatenate(
367
+ [metadata["applied_transform"], [[0, 0, 0, 1]]], axis=0
368
+ )
369
+ camtoworld = applied_transform @ camtoworld
370
+ camtoworlds.append(camtoworld)
371
+ camtoworlds = np.array(camtoworlds)
372
+ camtoworlds[:, :, [1, 2]] *= -1
373
+
374
+ # Normalize the world space.
375
+ if normalize:
376
+ T1 = similarity_from_cameras(camtoworlds)
377
+ camtoworlds = transform_cameras(T1, camtoworlds)
378
+ self.transform = T1
379
+ else:
380
+ self.transform = np.eye(4)
381
+
382
+ self.image_names = image_names
383
+ self.image_paths = image_paths
384
+ self.camtoworlds = camtoworlds
385
+ self.camera_ids = list(range(len(image_paths)))
386
+ self.Ks_dict = {
387
+ i: np.array(
388
+ [
389
+ [
390
+ metadata.get("fl_x", frame.get("fl_x", None)),
391
+ 0.0,
392
+ metadata.get("cx", frame.get("cx", None)),
393
+ ],
394
+ [
395
+ 0.0,
396
+ metadata.get("fl_y", frame.get("fl_y", None)),
397
+ metadata.get("cy", frame.get("cy", None)),
398
+ ],
399
+ [0.0, 0.0, 1.0],
400
+ ]
401
+ )
402
+ for i, frame in enumerate(metadata["frames"])
403
+ }
404
+ self.imsize_dict = {
405
+ i: (
406
+ metadata.get("w", frame.get("w", None)),
407
+ metadata.get("h", frame.get("h", None)),
408
+ )
409
+ for i, frame in enumerate(metadata["frames"])
410
+ }
411
+ # When num_input_frames is None, use all frames for both training and
412
+ # testing.
413
+ # self.splits_per_num_input_frames[None] = {
414
+ # "train_ids": list(range(len(image_paths))),
415
+ # "test_ids": list(range(len(image_paths))),
416
+ # }
417
+
418
+ # size of the scene measured by cameras
419
+ camera_locations = camtoworlds[:, :3, 3]
420
+ scene_center = np.mean(camera_locations, axis=0)
421
+ dists = np.linalg.norm(camera_locations - scene_center, axis=1)
422
+ self.scene_scale = np.max(dists)
423
+
424
+ self.bounds = None
425
+ if osp.exists(osp.join(data_dir, "bounds.npy")):
426
+ self.bounds = np.load(osp.join(data_dir, "bounds.npy"))
427
+ scaling = np.linalg.norm(self.transform[0, :3])
428
+ self.bounds = self.bounds / scaling
429
+
430
+
431
+ class Dataset(torch.utils.data.Dataset):
432
+ """A simple dataset class."""
433
+
434
+ def __init__(
435
+ self,
436
+ parser: BaseParser,
437
+ split: str = "train",
438
+ num_input_frames: Optional[int] = None,
439
+ patch_size: Optional[int] = None,
440
+ load_depths: bool = False,
441
+ load_mono_disps: bool = False,
442
+ ):
443
+ self.parser = parser
444
+ self.split = split
445
+ self.num_input_frames = num_input_frames
446
+ self.patch_size = patch_size
447
+ self.load_depths = load_depths
448
+ self.load_mono_disps = load_mono_disps
449
+ if load_mono_disps:
450
+ assert isinstance(parser, DirectParser)
451
+ assert parser.mono_disps is not None
452
+ if isinstance(parser, ReconfusionParser):
453
+ ids_per_split = parser.splits_per_num_input_frames[num_input_frames]
454
+ self.indices = ids_per_split[
455
+ "train_ids" if split == "train" else "test_ids"
456
+ ]
457
+ else:
458
+ indices = np.arange(len(self.parser.image_names))
459
+ if split == "train":
460
+ self.indices = (
461
+ indices[indices % self.parser.test_every != 0]
462
+ if self.parser.test_every is not None
463
+ else indices
464
+ )
465
+ else:
466
+ self.indices = (
467
+ indices[indices % self.parser.test_every == 0]
468
+ if self.parser.test_every is not None
469
+ else indices
470
+ )
471
+
472
+ def __len__(self):
473
+ return len(self.indices)
474
+
475
+ def __getitem__(self, item: int) -> Dict[str, Any]:
476
+ index = self.indices[item]
477
+ if isinstance(self.parser, DirectParser):
478
+ image = self.parser.imgs[index]
479
+ else:
480
+ image = iio.imread(self.parser.image_paths[index])[..., :3]
481
+ camera_id = self.parser.camera_ids[index]
482
+ K = self.parser.Ks_dict[camera_id].copy() # undistorted K
483
+ params = self.parser.params_dict.get(camera_id, None)
484
+ camtoworlds = self.parser.camtoworlds[index]
485
+
486
+ x, y, w, h = 0, 0, image.shape[1], image.shape[0]
487
+ if params is not None and len(params) > 0:
488
+ # Images are distorted. Undistort them.
489
+ mapx, mapy = (
490
+ self.parser.mapx_dict[camera_id],
491
+ self.parser.mapy_dict[camera_id],
492
+ )
493
+ image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR)
494
+ x, y, w, h = self.parser.roi_undist_dict[camera_id]
495
+ image = image[y : y + h, x : x + w]
496
+
497
+ if self.patch_size is not None:
498
+ # Random crop.
499
+ h, w = image.shape[:2]
500
+ x = np.random.randint(0, max(w - self.patch_size, 1))
501
+ y = np.random.randint(0, max(h - self.patch_size, 1))
502
+ image = image[y : y + self.patch_size, x : x + self.patch_size]
503
+ K[0, 2] -= x
504
+ K[1, 2] -= y
505
+
506
+ data = {
507
+ "K": torch.from_numpy(K).float(),
508
+ "camtoworld": torch.from_numpy(camtoworlds).float(),
509
+ "image": torch.from_numpy(image).float(),
510
+ "image_id": item, # the index of the image in the dataset
511
+ }
512
+
513
+ if self.load_depths:
514
+ # projected points to image plane to get depths
515
+ worldtocams = np.linalg.inv(camtoworlds)
516
+ image_name = self.parser.image_names[index]
517
+ point_indices = self.parser.point_indices[image_name]
518
+ points_world = self.parser.points[point_indices]
519
+ points_cam = (worldtocams[:3, :3] @ points_world.T + worldtocams[:3, 3:4]).T
520
+ points_proj = (K @ points_cam.T).T
521
+ points = points_proj[:, :2] / points_proj[:, 2:3] # (M, 2)
522
+ depths = points_cam[:, 2] # (M,)
523
+ if self.patch_size is not None:
524
+ points[:, 0] -= x
525
+ points[:, 1] -= y
526
+ # filter out points outside the image
527
+ selector = (
528
+ (points[:, 0] >= 0)
529
+ & (points[:, 0] < image.shape[1])
530
+ & (points[:, 1] >= 0)
531
+ & (points[:, 1] < image.shape[0])
532
+ & (depths > 0)
533
+ )
534
+ points = points[selector]
535
+ depths = depths[selector]
536
+ data["points"] = torch.from_numpy(points).float()
537
+ data["depths"] = torch.from_numpy(depths).float()
538
+ if self.load_mono_disps:
539
+ data["mono_disps"] = torch.from_numpy(self.parser.mono_disps[index]).float() # type: ignore
540
+
541
+ return data
542
+
543
+
544
+ def get_parser(parser_type: str, **kwargs) -> BaseParser:
545
+ if parser_type == "colmap":
546
+ parser = COLMAPParser(**kwargs)
547
+ elif parser_type == "direct":
548
+ parser = DirectParser(**kwargs)
549
+ elif parser_type == "reconfusion":
550
+ parser = ReconfusionParser(**kwargs)
551
+ else:
552
+ raise ValueError(f"Unknown parser type: {parser_type}")
553
+ return parser
seva/eval.py ADDED
@@ -0,0 +1,1990 @@
1
+ import collections
2
+ import json
3
+ import math
4
+ import os
5
+ import re
6
+ import threading
7
+ from typing import List, Literal, Optional, Tuple, Union
8
+
9
+ import gradio as gr
10
+ from colorama import Fore, Style, init
11
+
12
+ init(autoreset=True)
13
+
14
+ import imageio.v3 as iio
15
+ import numpy as np
16
+ import torch
17
+ import torch.nn.functional as F
18
+ import torchvision.transforms.functional as TF
19
+ from einops import repeat
20
+ from PIL import Image
21
+ from tqdm.auto import tqdm
22
+
23
+ from seva.geometry import get_camera_dist, get_plucker_coordinates, to_hom_pose
24
+ from seva.sampling import (
25
+ EulerEDMSampler,
26
+ MultiviewCFG,
27
+ MultiviewTemporalCFG,
28
+ VanillaCFG,
29
+ )
30
+ from seva.utils import seed_everything
31
+
32
+ try:
33
+ # Detect PyTorch nightly builds (version strings containing 'dev')
34
+ version = torch.__version__
35
+ IS_TORCH_NIGHTLY = "dev" in version
36
+ if IS_TORCH_NIGHTLY:
37
+ torch._dynamo.config.cache_size_limit = 128 # type: ignore[assignment]
38
+ torch._dynamo.config.accumulated_cache_size_limit = 1024 # type: ignore[assignment]
39
+ torch._dynamo.config.force_parameter_static_shapes = False # type: ignore[assignment]
40
+ except Exception:
41
+ IS_TORCH_NIGHTLY = False
42
+
43
+
44
+ def pad_indices(
45
+ input_indices: List[int],
46
+ test_indices: List[int],
47
+ T: int,
48
+ padding_mode: Literal["first", "last", "none"] = "last",
49
+ ):
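+ # Pad the input/test index lists so that together they fill a window of length T
+ # (for "last" padding, extra slots repeat the last input or test frame, whichever
+ # comes later). Returns the padded index lists plus two arrays mapping each window
+ # slot to an entry of `input`/`test`, with -1 marking slots owned by the other side.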
50
+ assert padding_mode in ["last", "none"], "`first` padding is not supported yet."
51
+ if padding_mode == "last":
52
+ padded_indices = [
53
+ i for i in range(T) if i not in (input_indices + test_indices)
54
+ ]
55
+ else:
56
+ padded_indices = []
57
+ input_selects = list(range(len(input_indices)))
58
+ test_selects = list(range(len(test_indices)))
59
+ if max(input_indices) > max(test_indices):
60
+ # last elem from input
61
+ input_selects += [input_selects[-1]] * len(padded_indices)
62
+ input_indices = input_indices + padded_indices
63
+ sorted_inds = np.argsort(input_indices)
64
+ input_indices = [input_indices[ind] for ind in sorted_inds]
65
+ input_selects = [input_selects[ind] for ind in sorted_inds]
66
+ else:
67
+ # last elem from test
68
+ test_selects += [test_selects[-1]] * len(padded_indices)
69
+ test_indices = test_indices + padded_indices
70
+ sorted_inds = np.argsort(test_indices)
71
+ test_indices = [test_indices[ind] for ind in sorted_inds]
72
+ test_selects = [test_selects[ind] for ind in sorted_inds]
73
+
74
+ if padding_mode == "last":
75
+ input_maps = np.array([-1] * T)
76
+ test_maps = np.array([-1] * T)
77
+ else:
78
+ input_maps = np.array([-1] * (len(input_indices) + len(test_indices)))
79
+ test_maps = np.array([-1] * (len(input_indices) + len(test_indices)))
80
+ input_maps[input_indices] = input_selects
81
+ test_maps[test_indices] = test_selects
82
+ return input_indices, test_indices, input_maps, test_maps
83
+
84
+
85
+ def assemble(
86
+ input,
87
+ test,
88
+ input_maps,
89
+ test_maps,
90
+ ):
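+ # Re-interleave input and test frames into one length-T sequence using the slot
+ # maps from `pad_indices`; each slot must be claimed by exactly one of the two.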
91
+ T = len(input_maps)
92
+ assembled = torch.zeros_like(test[-1:]).repeat_interleave(T, dim=0)
93
+ assembled[input_maps != -1] = input[input_maps[input_maps != -1]]
94
+ assembled[test_maps != -1] = test[test_maps[test_maps != -1]]
95
+ assert np.logical_xor(input_maps != -1, test_maps != -1).all()
96
+ return assembled
97
+
98
+
99
+ def get_resizing_factor(
100
+ target_shape: Tuple[int, int], # H, W
101
+ current_shape: Tuple[int, int], # H, W
102
+ cover_target: bool = True,
103
+ # If True, the output shape will fully cover the target shape.
104
+ # If False, the target shape will fully cover the output shape.
105
+ ) -> float:
106
+ r_bound = target_shape[1] / target_shape[0]
107
+ aspect_r = current_shape[1] / current_shape[0]
108
+ if r_bound >= 1.0:
109
+ if cover_target:
110
+ if aspect_r >= r_bound:
111
+ factor = min(target_shape) / min(current_shape)
112
+ elif aspect_r < 1.0:
113
+ factor = max(target_shape) / min(current_shape)
114
+ else:
115
+ factor = max(target_shape) / max(current_shape)
116
+ else:
117
+ if aspect_r >= r_bound:
118
+ factor = max(target_shape) / max(current_shape)
119
+ elif aspect_r < 1.0:
120
+ factor = min(target_shape) / max(current_shape)
121
+ else:
122
+ factor = min(target_shape) / min(current_shape)
123
+ else:
124
+ if cover_target:
125
+ if aspect_r <= r_bound:
126
+ factor = min(target_shape) / min(current_shape)
127
+ elif aspect_r > 1.0:
128
+ factor = max(target_shape) / min(current_shape)
129
+ else:
130
+ factor = max(target_shape) / max(current_shape)
131
+ else:
132
+ if aspect_r <= r_bound:
133
+ factor = max(target_shape) / max(current_shape)
134
+ elif aspect_r > 1.0:
135
+ factor = min(target_shape) / max(current_shape)
136
+ else:
137
+ factor = min(target_shape) / min(current_shape)
138
+ return factor
139
+
140
+
141
+ def get_unique_embedder_keys_from_conditioner(conditioner):
142
+ keys = [x.input_key for x in conditioner.embedders if x.input_key is not None]
143
+ keys = [item for sublist in keys for item in sublist] # Flatten list
144
+ return set(keys)
145
+
146
+
147
+ def get_wh_with_fixed_shortest_side(w, h, size):
148
+ # if size is None or non-positive, return the original (w, h)
149
+ if size is None or size <= 0:
150
+ return w, h
151
+ if w < h:
152
+ new_w = size
153
+ new_h = int(size * h / w)
154
+ else:
155
+ new_h = size
156
+ new_w = int(size * w / h)
157
+ return new_w, new_h
158
+
159
+
160
+ def load_img_and_K(
161
+ image_path_or_size: Union[str, torch.Size],
162
+ size: Optional[Union[int, Tuple[int, int]]],
163
+ scale: float = 1.0,
164
+ center: Tuple[float, float] = (0.5, 0.5),
165
+ K: torch.Tensor | None = None,
166
+ size_stride: int = 1,
167
+ center_crop: bool = False,
168
+ image_as_tensor: bool = True,
169
+ context_rgb: np.ndarray | None = None,
170
+ device: str = "cuda",
171
+ ):
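+ # Load an image from disk (or create a blank canvas if a torch.Size is given),
+ # composite its alpha channel over `context_rgb` or white, rescale/crop it to the
+ # requested size, and update the intrinsics K to account for the resize and crop.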
172
+ if isinstance(image_path_or_size, torch.Size):
173
+ image = Image.new("RGBA", image_path_or_size[::-1])
174
+ else:
175
+ image = Image.open(image_path_or_size).convert("RGBA")
176
+
177
+ w, h = image.size
178
+ if size is None:
179
+ size = (w, h)
180
+
181
+ image = np.array(image).astype(np.float32) / 255
182
+ if image.shape[-1] == 4:
183
+ rgb, alpha = image[:, :, :3], image[:, :, 3:]
184
+ if context_rgb is not None:
185
+ image = rgb * alpha + context_rgb * (1 - alpha)
186
+ else:
187
+ image = rgb * alpha + (1 - alpha)
188
+ image = image.transpose(2, 0, 1)
189
+ image = torch.from_numpy(image).to(dtype=torch.float32)
190
+ image = image.unsqueeze(0)
191
+
192
+ if isinstance(size, (tuple, list)):
193
+ # => if size is a tuple or list, we first rescale to fully cover the `size`
194
+ # area and then crop the `size` area from the rescaled image
195
+ W, H = size
196
+ else:
197
+ # => if size is int, we rescale the image to fit the shortest side to size
198
+ # => if size is None, no rescaling is applied
199
+ W, H = get_wh_with_fixed_shortest_side(w, h, size)
200
+ W, H = (
201
+ math.floor(W / size_stride + 0.5) * size_stride,
202
+ math.floor(H / size_stride + 0.5) * size_stride,
203
+ )
204
+
205
+ rfs = get_resizing_factor((math.floor(H * scale), math.floor(W * scale)), (h, w))
206
+ resize_size = rh, rw = [int(np.ceil(rfs * s)) for s in (h, w)]
207
+ image = torch.nn.functional.interpolate(
208
+ image, resize_size, mode="area", antialias=False
209
+ )
210
+ if scale < 1.0:
211
+ pw = math.ceil((W - resize_size[1]) * 0.5)
212
+ ph = math.ceil((H - resize_size[0]) * 0.5)
213
+ image = F.pad(image, (pw, pw, ph, ph), "constant", 1.0)
214
+
215
+ cy_center = int(center[1] * image.shape[-2])
216
+ cx_center = int(center[0] * image.shape[-1])
217
+ if center_crop:
218
+ side = min(H, W)
219
+ ct = max(0, cy_center - side // 2)
220
+ cl = max(0, cx_center - side // 2)
221
+ ct = min(ct, image.shape[-2] - side)
222
+ cl = min(cl, image.shape[-1] - side)
223
+ image = TF.crop(image, top=ct, left=cl, height=side, width=side)
224
+ else:
225
+ ct = max(0, cy_center - H // 2)
226
+ cl = max(0, cx_center - W // 2)
227
+ ct = min(ct, image.shape[-2] - H)
228
+ cl = min(cl, image.shape[-1] - W)
229
+ image = TF.crop(image, top=ct, left=cl, height=H, width=W)
230
+
231
+ if K is not None:
232
+ K = K.clone()
233
+ if torch.all(K[:2, -1] >= 0) and torch.all(K[:2, -1] <= 1):
234
+ K[:2] *= K.new_tensor([rw, rh])[:, None] # normalized K
235
+ else:
236
+ K[:2] *= K.new_tensor([rw / w, rh / h])[:, None] # unnormalized K
237
+ K[:2, 2] -= K.new_tensor([cl, ct])
238
+
239
+ if image_as_tensor:
240
+ # tensor of shape (1, 3, H, W) with values ranging from (-1, 1)
241
+ image = image.to(device) * 2.0 - 1.0
242
+ else:
243
+ # PIL Image with values ranging from (0, 255)
244
+ image = image.permute(0, 2, 3, 1).numpy()[0]
245
+ image = Image.fromarray((image * 255).astype(np.uint8))
246
+ return image, K
247
+
248
+
249
+ def transform_img_and_K(
250
+ image: torch.Tensor,
251
+ size: Union[int, Tuple[int, int]],
252
+ scale: float = 1.0,
253
+ center: Tuple[float, float] = (0.5, 0.5),
254
+ K: torch.Tensor | None = None,
255
+ size_stride: int = 1,
256
+ mode: str = "crop",
257
+ ):
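+ # Same size/intrinsics bookkeeping as `load_img_and_K`, but for an image tensor
+ # that is already in memory; `mode` chooses between cropping, padding, and
+ # stretching to reach the target resolution.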
258
+ assert mode in [
259
+ "crop",
260
+ "pad",
261
+ "stretch",
262
+ ], f"mode should be one of ['crop', 'pad', 'stretch'], got {mode}"
263
+
264
+ h, w = image.shape[-2:]
265
+ if isinstance(size, (tuple, list)):
266
+ # => if size is a tuple or list, we first rescale to fully cover the `size`
267
+ # area and then crop the `size` area from the rescaled image
268
+ W, H = size
269
+ else:
270
+ # => if size is int, we rescale the image to fit the shortest side to size
271
+ # => if size is None, no rescaling is applied
272
+ W, H = get_wh_with_fixed_shortest_side(w, h, size)
273
+ W, H = (
274
+ math.floor(W / size_stride + 0.5) * size_stride,
275
+ math.floor(H / size_stride + 0.5) * size_stride,
276
+ )
277
+
278
+ if mode == "stretch":
279
+ rh, rw = H, W
280
+ else:
281
+ rfs = get_resizing_factor(
282
+ (H, W),
283
+ (h, w),
284
+ cover_target=mode != "pad",
285
+ )
286
+ (rh, rw) = [int(np.ceil(rfs * s)) for s in (h, w)]
287
+
288
+ rh, rw = int(rh / scale), int(rw / scale)
289
+ image = torch.nn.functional.interpolate(
290
+ image, (rh, rw), mode="area", antialias=False
291
+ )
292
+
293
+ cy_center = int(center[1] * image.shape[-2])
294
+ cx_center = int(center[0] * image.shape[-1])
295
+ if mode != "pad":
296
+ ct = max(0, cy_center - H // 2)
297
+ cl = max(0, cx_center - W // 2)
298
+ ct = min(ct, image.shape[-2] - H)
299
+ cl = min(cl, image.shape[-1] - W)
300
+ image = TF.crop(image, top=ct, left=cl, height=H, width=W)
301
+ pl, pt = 0, 0
302
+ else:
303
+ pt = max(0, H // 2 - cy_center)
304
+ pl = max(0, W // 2 - cx_center)
305
+ pb = max(0, H - pt - image.shape[-2])
306
+ pr = max(0, W - pl - image.shape[-1])
307
+ image = TF.pad(
308
+ image,
309
+ [pl, pt, pr, pb],
310
+ )
311
+ cl, ct = 0, 0
312
+
313
+ if K is not None:
314
+ K = K.clone()
315
+ # K[:, :2, 2] += K.new_tensor([pl, pt])
316
+ if torch.all(K[:, :2, -1] >= 0) and torch.all(K[:, :2, -1] <= 1):
317
+ K[:, :2] *= K.new_tensor([rw, rh])[None, :, None] # normalized K
318
+ else:
319
+ K[:, :2] *= K.new_tensor([rw / w, rh / h])[None, :, None] # unnormalized K
320
+ K[:, :2, 2] += K.new_tensor([pl - cl, pt - ct])
321
+
322
+ return image, K
323
+
324
+
325
+ lowvram_mode = False
326
+
327
+
328
+ def set_lowvram_mode(mode):
329
+ global lowvram_mode
330
+ lowvram_mode = mode
331
+
332
+
333
+ def load_model(model, device: str = "cuda"):
334
+ model.to(device)
335
+
336
+
337
+ def unload_model(model):
338
+ global lowvram_mode
339
+ if lowvram_mode:
340
+ model.cpu()
341
+ torch.cuda.empty_cache()
342
+
343
+
344
+ def infer_prior_stats(
345
+ T,
346
+ num_input_frames,
347
+ num_total_frames,
348
+ version_dict,
349
+ ):
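+ # Estimate how many prior (anchor) frames the first pass should produce so the
+ # second pass can cover all target frames given the chunking strategy, updating
+ # version_dict["T"] (the per-pass context lengths) in place when necessary.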
350
+ options = version_dict["options"]
351
+ chunk_strategy = options.get("chunk_strategy", "nearest")
352
+ T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
353
+ T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
354
+ # get traj_prior_c2ws for 2-pass sampling
355
+ if chunk_strategy.startswith("interp"):
356
+ # Start and end have already taken up two slots.
357
+ # The +1 means we need X + 1 prior frames to bound X forward passes covering all test frames.
358
+
359
+ # Tuning up `num_prior_frames_ratio` is helpful when you observe a sudden jump in the
360
+ # generated frames due to insufficient prior frames. This option is effective for
361
+ # complicated trajectories and when the `interp` strategy is used (usually the semi-dense-view
362
+ # regime). Recommended range is [1.0 (default), 1.5].
363
+ if num_input_frames >= options.get("num_input_semi_dense", 9):
364
+ num_prior_frames = (
365
+ math.ceil(
366
+ num_total_frames
367
+ / (T_second_pass - 2)
368
+ * options.get("num_prior_frames_ratio", 1.0)
369
+ )
370
+ + 1
371
+ )
372
+
373
+ if num_prior_frames + num_input_frames < T_first_pass:
374
+ num_prior_frames = T_first_pass - num_input_frames
375
+
376
+ num_prior_frames = max(
377
+ num_prior_frames,
378
+ options.get("num_prior_frames", 0),
379
+ )
380
+
381
+ T_first_pass = num_prior_frames + num_input_frames
382
+
383
+ if "gt" in chunk_strategy:
384
+ T_second_pass = T_second_pass + num_input_frames
385
+
386
+ # Dynamically update context window length.
387
+ version_dict["T"] = [T_first_pass, T_second_pass]
388
+
389
+ else:
390
+ num_prior_frames = (
391
+ math.ceil(
392
+ num_total_frames
393
+ / (
394
+ T_second_pass
395
+ - 2
396
+ - (num_input_frames if "gt" in chunk_strategy else 0)
397
+ )
398
+ * options.get("num_prior_frames_ratio", 1.0)
399
+ )
400
+ + 1
401
+ )
402
+
403
+ if num_prior_frames + num_input_frames < T_first_pass:
404
+ num_prior_frames = T_first_pass - num_input_frames
405
+
406
+ num_prior_frames = max(
407
+ num_prior_frames,
408
+ options.get("num_prior_frames", 0),
409
+ )
410
+ else:
411
+ num_prior_frames = max(
412
+ T_first_pass - num_input_frames,
413
+ options.get("num_prior_frames", 0),
414
+ )
415
+
416
+ if num_input_frames >= options.get("num_input_semi_dense", 9):
417
+ T_first_pass = num_prior_frames + num_input_frames
418
+
419
+ # Dynamically update context window length.
420
+ version_dict["T"] = [T_first_pass, T_second_pass]
421
+
422
+ return num_prior_frames
423
+
424
+
425
+ def infer_prior_inds(
426
+ c2ws,
427
+ num_prior_frames,
428
+ input_frame_indices,
429
+ options,
430
+ ):
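+ # Choose which target-view indices serve as prior anchors: evenly spaced along
+ # the trajectory for `interp*` strategies, otherwise greedily picking the view
+ # farthest from everything selected so far.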
431
+ chunk_strategy = options.get("chunk_strategy", "nearest")
432
+ if chunk_strategy.startswith("interp"):
433
+ prior_frame_indices = np.array(
434
+ [i for i in range(c2ws.shape[0]) if i not in input_frame_indices]
435
+ )
436
+ prior_frame_indices = prior_frame_indices[
437
+ np.ceil(
438
+ np.linspace(
439
+ 0, prior_frame_indices.shape[0] - 1, num_prior_frames, endpoint=True
440
+ )
441
+ ).astype(int)
442
+ ] # having a ceil here is actually safer for corner cases
443
+ else:
444
+ prior_frame_indices = []
445
+ while len(prior_frame_indices) < num_prior_frames:
446
+ closest_distance = np.abs(
447
+ np.arange(c2ws.shape[0])[None]
448
+ - np.concatenate(
449
+ [np.array(input_frame_indices), np.array(prior_frame_indices)]
450
+ )[:, None]
451
+ ).min(0)
452
+ prior_frame_indices.append(np.argsort(closest_distance)[-1])
453
+ return np.sort(prior_frame_indices)
454
+
455
+
456
+ def compute_relative_inds(
457
+ source_inds,
458
+ target_inds,
459
+ ):
460
+ assert len(source_inds) > 2
461
+ # compute relative indices of target_inds within source_inds
462
+ relative_inds = []
463
+ for ind in target_inds:
464
+ if ind in source_inds:
465
+ relative_ind = int(np.where(source_inds == ind)[0][0])
466
+ elif ind < source_inds[0]:
467
+ # extrapolate
468
+ relative_ind = -((source_inds[0] - ind) / (source_inds[1] - source_inds[0]))
469
+ elif ind > source_inds[-1]:
470
+ # extrapolate
471
+ relative_ind = len(source_inds) + (
472
+ (ind - source_inds[-1]) / (source_inds[-1] - source_inds[-2])
473
+ )
474
+ else:
475
+ # interpolate
476
+ lower_inds = source_inds[source_inds < ind]
477
+ upper_inds = source_inds[source_inds > ind]
478
+ if len(lower_inds) > 0 and len(upper_inds) > 0:
479
+ lower_ind = lower_inds[-1]
480
+ upper_ind = upper_inds[0]
481
+ relative_lower_ind = int(np.where(source_inds == lower_ind)[0][0])
482
+ relative_upper_ind = int(np.where(source_inds == upper_ind)[0][0])
483
+ relative_ind = relative_lower_ind + (ind - lower_ind) / (
484
+ upper_ind - lower_ind
485
+ ) * (relative_upper_ind - relative_lower_ind)
486
+ else:
487
+ # Out of range
488
+ relative_ind = float("nan")  # out of range; record NaN as a placeholder
489
+ relative_inds.append(relative_ind)
490
+ return relative_inds
491
+
492
+
493
+ def find_nearest_source_inds(
494
+ source_c2ws,
495
+ target_c2ws,
496
+ nearest_num=1,
497
+ mode="translation",
498
+ ):
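+ # For each target camera, return the indices of the `nearest_num` closest source
+ # cameras under the chosen distance mode ("rotation" or "translation").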
499
+ dists = get_camera_dist(source_c2ws, target_c2ws, mode=mode).cpu().numpy()
500
+ sorted_inds = np.argsort(dists, axis=0).T
501
+ return sorted_inds[:, :nearest_num]
502
+
503
+
504
+ def chunk_input_and_test(
505
+ T,
506
+ input_c2ws,
507
+ test_c2ws,
508
+ input_ords, # orders
509
+ test_ords, # orders
510
+ options,
511
+ task: str = "img2img",
512
+ chunk_strategy: str = "gt",
513
+ gt_input_inds: list = [],
514
+ ):
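+ # Partition input and test views into chunks of length T according to
+ # `chunk_strategy`. Each chunk entry is "!xxx" (input/conditioning view),
+ # ">xxx" (test view), or "NULL" (padding), e.g.
+ # ["!000", "!001", ">000", ">001", ..., "NULL"]; the returned per-chunk lists
+ # record each view's index in the raw sequence and its slot inside the chunk.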
515
+ M, N = input_c2ws.shape[0], test_c2ws.shape[0]
516
+
517
+ chunks = []
518
+ if chunk_strategy.startswith("gt"):
519
+ assert len(gt_input_inds) < T, (
520
+ f"Number of gt input frames {len(gt_input_inds)} should be "
521
+ f"less than {T} when `gt` chunking strategy is used."
522
+ )
523
+ assert (
524
+ list(range(M)) == gt_input_inds
525
+ ), "All input_c2ws should be gt when `gt` chunking strategy is used."
526
+
527
+ # LEGACY CHUNKING STRATEGY
528
+ # num_test_per_chunk = T - len(gt_input_inds)
529
+ # test_inds_per_chunk = [i for i in range(T) if i not in gt_input_inds]
530
+ # for i in range(0, test_c2ws.shape[0], num_test_per_chunk):
531
+ # chunk = ["NULL"] * T
532
+ # for j, k in enumerate(gt_input_inds):
533
+ # chunk[k] = f"!{j:03d}"
534
+ # for j, k in enumerate(
535
+ # test_inds_per_chunk[: test_c2ws[i : i + num_test_per_chunk].shape[0]]
536
+ # ):
537
+ # chunk[k] = f">{i + j:03d}"
538
+ # chunks.append(chunk)
539
+
540
+ num_test_seen = 0
541
+ while num_test_seen < N:
542
+ chunk = [f"!{i:03d}" for i in gt_input_inds]
543
+ if chunk_strategy != "gt" and num_test_seen > 0:
544
+ pseudo_num_ratio = options.get("pseudo_num_ratio", 0.33)
545
+ if (N - num_test_seen) >= math.floor(
546
+ (T - len(gt_input_inds)) * pseudo_num_ratio
547
+ ):
548
+ pseudo_num = math.ceil((T - len(gt_input_inds)) * pseudo_num_ratio)
549
+ else:
550
+ pseudo_num = (T - len(gt_input_inds)) - (N - num_test_seen)
551
+ pseudo_num = min(pseudo_num, options.get("pseudo_num_max", 10000))
552
+
553
+ if "ltr" in chunk_strategy:
554
+ chunk.extend(
555
+ [
556
+ f"!{i + len(gt_input_inds):03d}"
557
+ for i in range(num_test_seen - pseudo_num, num_test_seen)
558
+ ]
559
+ )
560
+ elif "nearest" in chunk_strategy:
561
+ source_inds = np.concatenate(
562
+ [
563
+ find_nearest_source_inds(
564
+ test_c2ws[:num_test_seen],
565
+ test_c2ws[num_test_seen:],
566
+ nearest_num=1, # pseudo_num,
567
+ mode="rotation",
568
+ ),
569
+ find_nearest_source_inds(
570
+ test_c2ws[:num_test_seen],
571
+ test_c2ws[num_test_seen:],
572
+ nearest_num=1, # pseudo_num,
573
+ mode="translation",
574
+ ),
575
+ ],
576
+ axis=1,
577
+ )
578
+ ####### [HACK ALERT] keep running until pseudo num is stablized ########
579
+ temp_pseudo_num = pseudo_num
580
+ while True:
581
+ nearest_source_inds = np.concatenate(
582
+ [
583
+ np.sort(
584
+ [
585
+ ind
586
+ for (ind, _) in collections.Counter(
587
+ [
588
+ item
589
+ for item in source_inds[
590
+ : T
591
+ - len(gt_input_inds)
592
+ - temp_pseudo_num
593
+ ]
594
+ .flatten()
595
+ .tolist()
596
+ if item
597
+ != (
598
+ num_test_seen - 1
599
+ ) # exclude the last one here
600
+ ]
601
+ ).most_common(pseudo_num - 1)
602
+ ],
603
+ ).astype(int),
604
+ [num_test_seen - 1], # always keep the last one
605
+ ]
606
+ )
607
+ if len(nearest_source_inds) >= temp_pseudo_num:
608
+ break # stabilized
609
+ else:
610
+ temp_pseudo_num = len(nearest_source_inds)
611
+ pseudo_num = len(nearest_source_inds)
612
+ ########################################################################
613
+ chunk.extend(
614
+ [f"!{i + len(gt_input_inds):03d}" for i in nearest_source_inds]
615
+ )
616
+ else:
617
+ raise NotImplementedError(
618
+ f"Chunking strategy {chunk_strategy} for the first pass is not implemented."
619
+ )
620
+
621
+ chunk.extend(
622
+ [
623
+ f">{i:03d}"
624
+ for i in range(
625
+ num_test_seen,
626
+ min(num_test_seen + T - len(gt_input_inds) - pseudo_num, N),
627
+ )
628
+ ]
629
+ )
630
+ else:
631
+ chunk.extend(
632
+ [
633
+ f">{i:03d}"
634
+ for i in range(
635
+ num_test_seen,
636
+ min(num_test_seen + T - len(gt_input_inds), N),
637
+ )
638
+ ]
639
+ )
640
+
641
+ num_test_seen += sum([1 for c in chunk if c.startswith(">")])
642
+ if len(chunk) < T:
643
+ chunk.extend(["NULL"] * (T - len(chunk)))
644
+ chunks.append(chunk)
645
+
646
+ elif chunk_strategy.startswith("nearest"):
647
+ input_imgs = np.array([f"!{i:03d}" for i in range(M)])
648
+ test_imgs = np.array([f">{i:03d}" for i in range(N)])
649
+
650
+ match = re.match(r"^nearest-(\d+)$", chunk_strategy)
651
+ if match:
652
+ nearest_num = int(match.group(1))
653
+ assert (
654
+ nearest_num < T
655
+ ), f"Nearest number of {nearest_num} should be less than {T}."
656
+ source_inds = find_nearest_source_inds(
657
+ input_c2ws,
658
+ test_c2ws,
659
+ nearest_num=nearest_num,
660
+ mode="translation", # during the second pass, consider translation only is enough
661
+ )
662
+
663
+ for i in range(0, N, T - nearest_num):
664
+ nearest_source_inds = np.sort(
665
+ [
666
+ ind
667
+ for (ind, _) in collections.Counter(
668
+ source_inds[i : i + T - nearest_num].flatten().tolist()
669
+ ).most_common(nearest_num)
670
+ ]
671
+ )
672
+ chunk = (
673
+ input_imgs[nearest_source_inds].tolist()
674
+ + test_imgs[i : i + T - nearest_num].tolist()
675
+ )
676
+ chunks.append(chunk + ["NULL"] * (T - len(chunk)))
677
+
678
+ else:
679
+ # do not always condition on gt cond frames
680
+ if "gt" not in chunk_strategy:
681
+ gt_input_inds = []
682
+
683
+ source_inds = find_nearest_source_inds(
684
+ input_c2ws,
685
+ test_c2ws,
686
+ nearest_num=1,
687
+ mode="translation", # during the second pass, consider translation only is enough
688
+ )[:, 0]
689
+
690
+ test_inds_per_input = {}
691
+ for test_idx, input_idx in enumerate(source_inds):
692
+ if input_idx not in test_inds_per_input:
693
+ test_inds_per_input[input_idx] = []
694
+ test_inds_per_input[input_idx].append(test_idx)
695
+
696
+ num_test_seen = 0
697
+ chunk = input_imgs[gt_input_inds].tolist()
698
+ candidate_input_inds = sorted(list(test_inds_per_input.keys()))
699
+
700
+ while num_test_seen < N:
701
+ input_idx = candidate_input_inds[0]
702
+ test_inds = test_inds_per_input[input_idx]
703
+ input_is_cond = input_idx in gt_input_inds
704
+ prefix_inds = [] if input_is_cond else [input_idx]
705
+
706
+ if len(chunk) == T - len(prefix_inds) or not candidate_input_inds:
707
+ if chunk:
708
+ chunk += ["NULL"] * (T - len(chunk))
709
+ chunks.append(chunk)
710
+ chunk = input_imgs[gt_input_inds].tolist()
711
+ if num_test_seen >= N:
712
+ break
713
+ continue
714
+
715
+ candidate_chunk = (
716
+ input_imgs[prefix_inds].tolist() + test_imgs[test_inds].tolist()
717
+ )
718
+
719
+ space_left = T - len(chunk)
720
+ if len(candidate_chunk) <= space_left:
721
+ chunk.extend(candidate_chunk)
722
+ num_test_seen += len(test_inds)
723
+ candidate_input_inds.pop(0)
724
+ else:
725
+ chunk.extend(candidate_chunk[:space_left])
726
+ num_input_idx = 0 if input_is_cond else 1
727
+ num_test_seen += space_left - num_input_idx
728
+ test_inds_per_input[input_idx] = test_inds[
729
+ space_left - num_input_idx :
730
+ ]
731
+
732
+ if len(chunk) == T:
733
+ chunks.append(chunk)
734
+ chunk = input_imgs[gt_input_inds].tolist()
735
+
736
+ if chunk and chunk != input_imgs[gt_input_inds].tolist():
737
+ chunks.append(chunk + ["NULL"] * (T - len(chunk)))
738
+
739
+ elif chunk_strategy.startswith("interp"):
740
+ # `interp` chunk requires ordering info
741
+ assert input_ords is not None and test_ords is not None, (
742
+ "When using `interp` chunking strategy, ordering of input "
743
+ "and test frames should be provided."
744
+ )
745
+
746
+ # if chunk_strategy is `interp*` and task is `img2trajvid*`, we will not
747
+ # use input views since their order info within target views is unknown
748
+ if "img2trajvid" in task:
749
+ assert (
750
+ list(range(len(gt_input_inds))) == gt_input_inds
751
+ ), "`img2trajvid` task should put `gt_input_inds` in start."
752
+ input_c2ws = input_c2ws[
753
+ [ind for ind in range(M) if ind not in gt_input_inds]
754
+ ]
755
+ input_ords = [
756
+ input_ords[ind] for ind in range(M) if ind not in gt_input_inds
757
+ ]
758
+ M = input_c2ws.shape[0]
759
+
760
+ input_ords = [0] + input_ords # this is a hack accounting for test views
761
+ # before the first input view
762
+ input_ords[-1] += 0.01 # this is a hack ensuring last test stop is included
763
+ # in the last forward when input_ords[-1] == test_ords[-1]
764
+ input_ords = np.array(input_ords)[:, None]
765
+ input_ords_ = np.concatenate([input_ords[1:], np.full((1, 1), np.inf)])
766
+ test_ords = np.array(test_ords)[None]
767
+
768
+ in_stop_ranges = np.logical_and(
769
+ np.repeat(input_ords, N, axis=1) <= np.repeat(test_ords, M + 1, axis=0),
770
+ np.repeat(input_ords_, N, axis=1) > np.repeat(test_ords, M + 1, axis=0),
771
+ ) # (M, N)
772
+ assert (in_stop_ranges.sum(1) <= T - 2).all(), (
773
+ "More input frames need to be sampled during the first pass to ensure "
774
+ f"#test frames during each forard in the second pass will not exceed {T - 2}."
775
+ )
776
+ if input_ords[1, 0] <= test_ords[0, 0]:
777
+ assert not in_stop_ranges[0].any()
778
+ if input_ords[-1, 0] >= test_ords[0, -1]:
779
+ assert not in_stop_ranges[-1].any()
780
+
781
+ gt_chunk = (
782
+ [f"!{i:03d}" for i in gt_input_inds] if "gt" in chunk_strategy else []
783
+ )
784
+ chunk = gt_chunk + []
785
+ # any test views before the first input views
786
+ if in_stop_ranges[0].any():
787
+ for j, in_range in enumerate(in_stop_ranges[0]):
788
+ if in_range:
789
+ chunk.append(f">{j:03d}")
790
+ in_stop_ranges = in_stop_ranges[1:]
791
+
792
+ i = 0
793
+ base_i = len(gt_input_inds) if "img2trajvid" in task else 0
794
+ chunk.append(f"!{i + base_i:03d}")
795
+ while i < len(in_stop_ranges):
796
+ in_stop_range = in_stop_ranges[i]
797
+ if not in_stop_range.any():
798
+ i += 1
799
+ continue
800
+
801
+ input_left = i + 1 < M
802
+ space_left = T - len(chunk)
803
+ if sum(in_stop_range) + input_left <= space_left:
804
+ for j, in_range in enumerate(in_stop_range):
805
+ if in_range:
806
+ chunk.append(f">{j:03d}")
807
+ i += 1
808
+ if input_left:
809
+ chunk.append(f"!{i + base_i:03d}")
810
+
811
+ else:
812
+ chunk += ["NULL"] * space_left
813
+ chunks.append(chunk)
814
+ chunk = gt_chunk + [f"!{i + base_i:03d}"]
815
+
816
+ if len(chunk) > 1:
817
+ chunk += ["NULL"] * (T - len(chunk))
818
+ chunks.append(chunk)
819
+
820
+ else:
821
+ raise NotImplementedError
822
+
823
+ (
824
+ input_inds_per_chunk,
825
+ input_sels_per_chunk,
826
+ test_inds_per_chunk,
827
+ test_sels_per_chunk,
828
+ ) = (
829
+ [],
830
+ [],
831
+ [],
832
+ [],
833
+ )
834
+ for chunk in chunks:
835
+ input_inds = [
836
+ int(img.removeprefix("!")) for img in chunk if img.startswith("!")
837
+ ]
838
+ input_sels = [chunk.index(img) for img in chunk if img.startswith("!")]
839
+ test_inds = [int(img.removeprefix(">")) for img in chunk if img.startswith(">")]
840
+ test_sels = [chunk.index(img) for img in chunk if img.startswith(">")]
841
+ input_inds_per_chunk.append(input_inds)
842
+ input_sels_per_chunk.append(input_sels)
843
+ test_inds_per_chunk.append(test_inds)
844
+ test_sels_per_chunk.append(test_sels)
845
+
846
+ if options.get("sampler_verbose", True):
847
+
848
+ def colorize(item):
849
+ if item.startswith("!"):
850
+ return f"{Fore.RED}{item}{Style.RESET_ALL}" # Red for items starting with '!'
851
+ elif item.startswith(">"):
852
+ return f"{Fore.GREEN}{item}{Style.RESET_ALL}" # Green for items starting with '>'
853
+ return item # Default color if neither '!' nor '>'
854
+
855
+ print("\nchunks:")
856
+ for chunk in chunks:
857
+ print(", ".join(colorize(item) for item in chunk))
858
+
859
+ return (
860
+ chunks,
861
+ input_inds_per_chunk, # ordering of input in raw sequence
862
+ input_sels_per_chunk, # ordering of input in one-forward sequence of length T
863
+ test_inds_per_chunk, # ordering of test in raw sequence
864
+ test_sels_per_chunk, # ordering of test in one-forward sequence of length T
865
+ )
866
+
867
+
868
+ def is_k_in_dict(d, k):
869
+ return any(map(lambda x: x.startswith(k), d.keys()))
870
+
871
+
872
+ def get_k_from_dict(d, k):
873
+ media_d = {}
874
+ for key, value in d.items():
875
+ if key == k:
876
+ return value
877
+ if key.startswith(k):
878
+ media = key.split("/")[-1]
879
+ if media == "raw":
880
+ return value
881
+ media_d[media] = value
882
+ if len(media_d) == 0:
883
+ return torch.tensor([])
884
+ assert (
885
+ len(media_d) == 1
886
+ ), f"multiple media found in {d} for key {k}: {media_d.keys()}"
887
+ return media_d[media]
888
+
889
+
890
+ def update_kv_for_dict(d, k, v):
891
+ for key in d.keys():
892
+ if key.startswith(k):
893
+ d[key] = v
894
+ return d
895
+
896
+
897
+ def extend_dict(ds, d):
898
+ for key in d.keys():
899
+ if key in ds:
900
+ ds[key] = torch.cat([ds[key], d[key]], 0)
901
+ else:
902
+ ds[key] = d[key]
903
+ return ds
904
+
905
+
906
+ def replace_or_include_input_for_dict(
907
+ samples,
908
+ test_indices,
909
+ imgs,
910
+ c2w,
911
+ K,
912
+ ):
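+ # Write the sampled rgb/c2w/intrinsics values back into the full-length buffers
+ # at the test indices, so the returned dict always covers the whole sequence.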
913
+ samples_new = {}
914
+ for sample, value in samples.items():
915
+ if "rgb" in sample:
916
+ imgs[test_indices] = (
917
+ value[test_indices] if value.shape[0] == imgs.shape[0] else value
918
+ ).to(device=imgs.device, dtype=imgs.dtype)
919
+ samples_new[sample] = imgs
920
+ elif "c2w" in sample:
921
+ c2w[test_indices] = (
922
+ value[test_indices] if value.shape[0] == c2w.shape[0] else value
923
+ ).to(device=c2w.device, dtype=c2w.dtype)
924
+ samples_new[sample] = c2w
925
+ elif "intrinsics" in sample:
926
+ K[test_indices] = (
927
+ value[test_indices] if value.shape[0] == K.shape[0] else value
928
+ ).to(device=K.device, dtype=K.dtype)
929
+ samples_new[sample] = K
930
+ else:
931
+ samples_new[sample] = value
932
+ return samples_new
933
+
934
+
935
+ def decode_output(
936
+ samples,
937
+ T,
938
+ indices=None,
939
+ ):
940
+ # decode model output into dict if it is not
941
+ if isinstance(samples, dict):
942
+ # model with postprocessor and outputs dict
943
+ for sample, value in samples.items():
944
+ if isinstance(value, torch.Tensor):
945
+ value = value.detach().cpu()
946
+ elif isinstance(value, np.ndarray):
947
+ value = torch.from_numpy(value)
948
+ else:
949
+ value = torch.tensor(value)
950
+
951
+ if indices is not None and value.shape[0] == T:
952
+ value = value[indices]
953
+ samples[sample] = value
954
+ else:
955
+ # model without postprocessor and outputs tensor (rgb)
956
+ samples = samples.detach().cpu()
957
+
958
+ if indices is not None and samples.shape[0] == T:
959
+ samples = samples[indices]
960
+ samples = {"samples-rgb/image": samples}
961
+
962
+ return samples
963
+
964
+
965
+ def save_output(
966
+ samples,
967
+ save_path,
968
+ video_save_fps=2,
969
+ ):
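+ # Save each entry of `samples` according to its media-type suffix: "image" writes
+ # per-frame PNGs plus an mp4, "video" writes an mp4, and "raw" saves a .pt tensor.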
970
+ os.makedirs(save_path, exist_ok=True)
971
+ for sample in samples:
972
+ media_type = "video"
973
+ if "/" in sample:
974
+ sample_, media_type = sample.split("/")
975
+ else:
976
+ sample_ = sample
977
+
978
+ value = samples[sample]
979
+ if isinstance(value, torch.Tensor):
980
+ value = value.detach().cpu()
981
+ elif isinstance(value, np.ndarray):
982
+ value = torch.from_numpy(value)
983
+ else:
984
+ value = torch.tensor(value)
985
+
986
+ if media_type == "image":
987
+ value = (value.permute(0, 2, 3, 1) + 1) / 2.0
988
+ value = (value * 255).clamp(0, 255).to(torch.uint8)
989
+ iio.imwrite(
990
+ os.path.join(save_path, f"{sample_}.mp4")
991
+ if sample_
992
+ else f"{save_path}.mp4",
993
+ value,
994
+ fps=video_save_fps,
995
+ macro_block_size=1,
996
+ ffmpeg_log_level="error",
997
+ )
998
+ os.makedirs(os.path.join(save_path, sample_), exist_ok=True)
999
+ for i, s in enumerate(value):
1000
+ iio.imwrite(
1001
+ os.path.join(save_path, sample_, f"{i:03d}.png"),
1002
+ s,
1003
+ )
1004
+ elif media_type == "video":
1005
+ value = (value.permute(0, 2, 3, 1) + 1) / 2.0
1006
+ value = (value * 255).clamp(0, 255).to(torch.uint8)
1007
+ iio.imwrite(
1008
+ os.path.join(save_path, f"{sample_}.mp4"),
1009
+ value,
1010
+ fps=video_save_fps,
1011
+ macro_block_size=1,
1012
+ ffmpeg_log_level="error",
1013
+ )
1014
+ elif media_type == "raw":
1015
+ torch.save(
1016
+ value,
1017
+ os.path.join(save_path, f"{sample_}.pt"),
1018
+ )
1019
+ else:
1020
+ pass
1021
+
1022
+
1023
+ def create_transforms_simple(save_path, img_paths, img_whs, c2ws, Ks):
1024
+ import os.path as osp
1025
+
1026
+ out_frames = []
1027
+ for img_path, img_wh, c2w, K in zip(img_paths, img_whs, c2ws, Ks):
1028
+ out_frame = {
1029
+ "fl_x": K[0][0].item(),
1030
+ "fl_y": K[1][1].item(),
1031
+ "cx": K[0][2].item(),
1032
+ "cy": K[1][2].item(),
1033
+ "w": img_wh[0].item(),
1034
+ "h": img_wh[1].item(),
1035
+ "file_path": f"./{osp.relpath(img_path, start=save_path)}"
1036
+ if img_path is not None
1037
+ else None,
1038
+ "transform_matrix": c2w.tolist(),
1039
+ }
1040
+ out_frames.append(out_frame)
1041
+ out = {
1042
+ # "camera_model": "PINHOLE",
1043
+ "orientation_override": "none",
1044
+ "frames": out_frames,
1045
+ }
1046
+ with open(osp.join(save_path, "transforms.json"), "w") as of:
1047
+ json.dump(out, of, indent=5)
1048
+
1049
+
1050
+ class GradioTrackedSampler(EulerEDMSampler):
1051
+ """
1052
+ A thin wrapper around the EulerEDMSampler that allows tracking progress and
1053
+ aborting sampling for the gradio demo.
1054
+ """
1055
+
1056
+ def __init__(self, abort_event: threading.Event, *args, **kwargs):
1057
+ super().__init__(*args, **kwargs)
1058
+ self.abort_event = abort_event
1059
+
1060
+ def __call__( # type: ignore
1061
+ self,
1062
+ denoiser,
1063
+ x: torch.Tensor,
1064
+ scale: float | torch.Tensor,
1065
+ cond: dict,
1066
+ uc: dict | None = None,
1067
+ num_steps: int | None = None,
1068
+ verbose: bool = True,
1069
+ global_pbar: gr.Progress | None = None,
1070
+ **guider_kwargs,
1071
+ ) -> torch.Tensor | None:
1072
+ uc = cond if uc is None else uc
1073
+ x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
1074
+ x,
1075
+ cond,
1076
+ uc,
1077
+ num_steps,
1078
+ )
1079
+ for i in self.get_sigma_gen(num_sigmas, verbose=verbose):
1080
+ gamma = (
1081
+ min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
1082
+ if self.s_tmin <= sigmas[i] <= self.s_tmax
1083
+ else 0.0
1084
+ )
1085
+ x = self.sampler_step(
1086
+ s_in * sigmas[i],
1087
+ s_in * sigmas[i + 1],
1088
+ denoiser,
1089
+ x,
1090
+ scale,
1091
+ cond,
1092
+ uc,
1093
+ gamma,
1094
+ **guider_kwargs,
1095
+ )
1096
+ # Allow tracking progress in gradio demo.
1097
+ if global_pbar is not None:
1098
+ global_pbar.update()
1099
+ # Allow aborting sampling in gradio demo.
1100
+ if self.abort_event.is_set():
1101
+ return None
1102
+ return x
1103
+
1104
+
1105
+ def create_samplers(
1106
+ guider_types: int | list[int],
1107
+ discretization,
1108
+ num_frames: list[int] | None,
1109
+ num_steps: int,
1110
+ cfg_min: float = 1.0,
1111
+ device: str | torch.device = "cuda",
1112
+ abort_event: threading.Event | None = None,
1113
+ ):
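+ # Build one sampler per requested guider type (0: vanilla CFG, 1: multiview CFG,
+ # 2: multiview temporal CFG); a GradioTrackedSampler is used instead of the plain
+ # EulerEDMSampler when an abort_event is supplied so the demo can cancel sampling.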
1114
+ guider_mapping = {
1115
+ 0: VanillaCFG,
1116
+ 1: MultiviewCFG,
1117
+ 2: MultiviewTemporalCFG,
1118
+ }
1119
+ samplers = []
1120
+ if not isinstance(guider_types, (list, tuple)):
1121
+ guider_types = [guider_types]
1122
+ for i, guider_type in enumerate(guider_types):
1123
+ if guider_type not in guider_mapping:
1124
+ raise ValueError(
1125
+ f"Invalid guider type {guider_type}. Must be one of {list(guider_mapping.keys())}"
1126
+ )
1127
+ guider_cls = guider_mapping[guider_type]
1128
+ guider_args = ()
1129
+ if guider_type > 0:
1130
+ guider_args += (cfg_min,)
1131
+ if guider_type == 2:
1132
+ assert num_frames is not None
1133
+ guider_args = (num_frames[i], cfg_min)
1134
+ guider = guider_cls(*guider_args)
1135
+
1136
+ if abort_event is not None:
1137
+ sampler = GradioTrackedSampler(
1138
+ abort_event,
1139
+ discretization=discretization,
1140
+ guider=guider,
1141
+ num_steps=num_steps,
1142
+ s_churn=0.0,
1143
+ s_tmin=0.0,
1144
+ s_tmax=999.0,
1145
+ s_noise=1.0,
1146
+ verbose=True,
1147
+ device=device,
1148
+ )
1149
+ else:
1150
+ sampler = EulerEDMSampler(
1151
+ discretization=discretization,
1152
+ guider=guider,
1153
+ num_steps=num_steps,
1154
+ s_churn=0.0,
1155
+ s_tmin=0.0,
1156
+ s_tmax=999.0,
1157
+ s_noise=1.0,
1158
+ verbose=True,
1159
+ device=device,
1160
+ )
1161
+ samplers.append(sampler)
1162
+ return samplers
1163
+
1164
+
1165
+ def get_value_dict(
1166
+ curr_imgs,
1167
+ curr_imgs_clip,
1168
+ curr_input_frame_indices,
1169
+ curr_c2ws,
1170
+ curr_Ks,
1171
+ curr_input_camera_indices,
1172
+ all_c2ws,
1173
+ camera_scale,
1174
+ ):
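+ # Assemble the conditioning dict for one chunk: conditioning frames and masks,
+ # recentered and rescaled camera poses, intrinsics, and per-pixel plucker coordinates.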
1175
+ assert sorted(curr_input_camera_indices) == sorted(
1176
+ range(len(curr_input_camera_indices))
1177
+ )
1178
+ H, W, T, F = curr_imgs.shape[-2], curr_imgs.shape[-1], len(curr_imgs), 8
1179
+
1180
+ value_dict = {}
1181
+ value_dict["cond_frames_without_noise"] = curr_imgs_clip[curr_input_frame_indices]
1182
+ value_dict["cond_frames"] = curr_imgs + 0.0 * torch.randn_like(curr_imgs)
1183
+ value_dict["cond_frames_mask"] = torch.zeros(T, dtype=torch.bool)
1184
+ value_dict["cond_frames_mask"][curr_input_frame_indices] = True
1185
+ value_dict["cond_aug"] = 0.0
1186
+
1187
+ c2w = to_hom_pose(curr_c2ws.float())
1188
+ w2c = torch.linalg.inv(c2w)
1189
+
1190
+ # camera centering
1191
+ ref_c2ws = all_c2ws
1192
+ camera_dist_2med = torch.norm(
1193
+ ref_c2ws[:, :3, 3] - ref_c2ws[:, :3, 3].median(0, keepdim=True).values,
1194
+ dim=-1,
1195
+ )
1196
+ valid_mask = camera_dist_2med <= torch.clamp(
1197
+ torch.quantile(camera_dist_2med, 0.97) * 10,
1198
+ max=1e6,
1199
+ )
1200
+ c2w[:, :3, 3] -= ref_c2ws[valid_mask, :3, 3].mean(0, keepdim=True)
1201
+ w2c = torch.linalg.inv(c2w)
1202
+
1203
+ # camera normalization
1204
+ camera_dists = c2w[:, :3, 3].clone()
1205
+ translation_scaling_factor = (
1206
+ camera_scale
1207
+ if torch.isclose(
1208
+ torch.norm(camera_dists[0]),
1209
+ torch.zeros(1),
1210
+ atol=1e-5,
1211
+ ).any()
1212
+ else (camera_scale / torch.norm(camera_dists[0]))
1213
+ )
1214
+ w2c[:, :3, 3] *= translation_scaling_factor
1215
+ c2w[:, :3, 3] *= translation_scaling_factor
1216
+ value_dict["plucker_coordinate"], _ = get_plucker_coordinates(
1217
+ extrinsics_src=w2c[0],
1218
+ extrinsics=w2c,
1219
+ intrinsics=curr_Ks.float().clone(),
1220
+ mode="plucker",
1221
+ rel_zero_translation=True,
1222
+ target_size=(H // F, W // F),
1223
+ return_grid_cam=True,
1224
+ )
1225
+
1226
+ value_dict["c2w"] = c2w
1227
+ value_dict["K"] = curr_Ks
1228
+ value_dict["camera_mask"] = torch.zeros(T, dtype=torch.bool)
1229
+ value_dict["camera_mask"][curr_input_camera_indices] = True
1230
+
1231
+ return value_dict
1232
+
1233
+
1234
+ def do_sample(
1235
+ model,
1236
+ ae,
1237
+ conditioner,
1238
+ denoiser,
1239
+ sampler,
1240
+ value_dict,
1241
+ H,
1242
+ W,
1243
+ C,
1244
+ F,
1245
+ T,
1246
+ cfg,
1247
+ encoding_t=1,
1248
+ decoding_t=1,
1249
+ verbose=True,
1250
+ global_pbar=None,
1251
+ **_,
1252
+ ):
1253
+ imgs = value_dict["cond_frames"].to("cuda")
1254
+ input_masks = value_dict["cond_frames_mask"].to("cuda")
1255
+ pluckers = value_dict["plucker_coordinate"].to("cuda")
1256
+
1257
+ num_samples = [1, T]
1258
+ with torch.inference_mode(), torch.autocast("cuda"):
1259
+ load_model(ae)
1260
+ load_model(conditioner)
1261
+ latents = torch.nn.functional.pad(
1262
+ ae.encode(imgs[input_masks], encoding_t), (0, 0, 0, 0, 0, 1), value=1.0
1263
+ )
1264
+ c_crossattn = repeat(conditioner(imgs[input_masks]).mean(0), "d -> n 1 d", n=T)
1265
+ uc_crossattn = torch.zeros_like(c_crossattn)
1266
+ c_replace = latents.new_zeros(T, *latents.shape[1:])
1267
+ c_replace[input_masks] = latents
1268
+ uc_replace = torch.zeros_like(c_replace)
1269
+ c_concat = torch.cat(
1270
+ [
1271
+ repeat(
1272
+ input_masks,
1273
+ "n -> n 1 h w",
1274
+ h=pluckers.shape[2],
1275
+ w=pluckers.shape[3],
1276
+ ),
1277
+ pluckers,
1278
+ ],
1279
+ 1,
1280
+ )
1281
+ uc_concat = torch.cat(
1282
+ [pluckers.new_zeros(T, 1, *pluckers.shape[-2:]), pluckers], 1
1283
+ )
1284
+ c_dense_vector = pluckers
1285
+ uc_dense_vector = c_dense_vector
1286
+ c = {
1287
+ "crossattn": c_crossattn,
1288
+ "replace": c_replace,
1289
+ "concat": c_concat,
1290
+ "dense_vector": c_dense_vector,
1291
+ }
1292
+ uc = {
1293
+ "crossattn": uc_crossattn,
1294
+ "replace": uc_replace,
1295
+ "concat": uc_concat,
1296
+ "dense_vector": uc_dense_vector,
1297
+ }
1298
+ unload_model(ae)
1299
+ unload_model(conditioner)
1300
+
1301
+ additional_model_inputs = {"num_frames": T}
1302
+ additional_sampler_inputs = {
1303
+ "c2w": value_dict["c2w"].to("cuda"),
1304
+ "K": value_dict["K"].to("cuda"),
1305
+ "input_frame_mask": value_dict["cond_frames_mask"].to("cuda"),
1306
+ }
1307
+ if global_pbar is not None:
1308
+ additional_sampler_inputs["global_pbar"] = global_pbar
1309
+
1310
+ shape = (math.prod(num_samples), C, H // F, W // F)
1311
+ randn = torch.randn(shape).to("cuda")
1312
+
1313
+ load_model(model)
1314
+ samples_z = sampler(
1315
+ lambda input, sigma, c: denoiser(
1316
+ model,
1317
+ input,
1318
+ sigma,
1319
+ c,
1320
+ **additional_model_inputs,
1321
+ ),
1322
+ randn,
1323
+ scale=cfg,
1324
+ cond=c,
1325
+ uc=uc,
1326
+ verbose=verbose,
1327
+ **additional_sampler_inputs,
1328
+ )
1329
+ if samples_z is None:
1330
+ return
1331
+ unload_model(model)
1332
+
1333
+ load_model(ae)
1334
+ samples = ae.decode(samples_z, decoding_t)
1335
+ unload_model(ae)
1336
+
1337
+ return samples
1338
+
1339
+
1340
+ def run_one_scene(
1341
+ task,
1342
+ version_dict,
1343
+ model,
1344
+ ae,
1345
+ conditioner,
1346
+ denoiser,
1347
+ image_cond,
1348
+ camera_cond,
1349
+ save_path,
1350
+ use_traj_prior,
1351
+ traj_prior_Ks,
1352
+ traj_prior_c2ws,
1353
+ seed=23,
1354
+ gradio=False,
1355
+ abort_event=None,
1356
+ first_pass_pbar=None,
1357
+ second_pass_pbar=None,
1358
+ ):
1359
+ H, W, T, C, F, options = (
1360
+ version_dict["H"],
1361
+ version_dict["W"],
1362
+ version_dict["T"],
1363
+ version_dict["C"],
1364
+ version_dict["f"],
1365
+ version_dict["options"],
1366
+ )
1367
+
1368
+ if isinstance(image_cond, str):
1369
+ image_cond = {"img": [image_cond]}
1370
+ imgs_clip, imgs, img_size = [], [], None
1371
+ for i, (img, K) in enumerate(zip(image_cond["img"], camera_cond["K"])):
1372
+ if isinstance(img, str) or img is None:
1373
+ img, K = load_img_and_K(img or img_size, None, K=K, device="cpu") # type: ignore
1374
+ img_size = img.shape[-2:]
1375
+ if options.get("L_short", -1) == -1:
1376
+ img, K = transform_img_and_K(
1377
+ img,
1378
+ (W, H),
1379
+ K=K[None],
1380
+ mode=(
1381
+ options.get("transform_input", "crop")
1382
+ if i in image_cond["input_indices"]
1383
+ else options.get("transform_target", "crop")
1384
+ ),
1385
+ scale=(
1386
+ 1.0
1387
+ if i in image_cond["input_indices"]
1388
+ else options.get("transform_scale", 1.0)
1389
+ ),
1390
+ )
1391
+ else:
1392
+ downsample = 3
1393
+ assert options["L_short"] % F * 2**downsample == 0, (
1394
+ "Short side of the image should be divisible by "
1395
+ f"F*2**{downsample}={F * 2**downsample}."
1396
+ )
1397
+ img, K = transform_img_and_K(
1398
+ img,
1399
+ options["L_short"],
1400
+ K=K[None],
1401
+ size_stride=F * 2**downsample,
1402
+ mode=(
1403
+ options.get("transform_input", "crop")
1404
+ if i in image_cond["input_indices"]
1405
+ else options.get("transform_target", "crop")
1406
+ ),
1407
+ scale=(
1408
+ 1.0
1409
+ if i in image_cond["input_indices"]
1410
+ else options.get("transform_scale", 1.0)
1411
+ ),
1412
+ )
1413
+ version_dict["W"] = W = img.shape[-1]
1414
+ version_dict["H"] = H = img.shape[-2]
1415
+ K = K[0]
1416
+ K[0] /= W
1417
+ K[1] /= H
1418
+ camera_cond["K"][i] = K
1419
+ img_clip = img
1420
+ elif isinstance(img, np.ndarray):
1421
+ img_size = torch.Size(img.shape[:2])
1422
+ img = torch.as_tensor(img).permute(2, 0, 1)
1423
+ img = img.unsqueeze(0)
1424
+ img = img / 255.0 * 2.0 - 1.0
1425
+ if not gradio:
1426
+ img, K = transform_img_and_K(img, (W, H), K=K[None])
1427
+ assert K is not None
1428
+ K = K[0]
1429
+ K[0] /= W
1430
+ K[1] /= H
1431
+ camera_cond["K"][i] = K
1432
+ img_clip = img
1433
+ else:
1434
+ assert (
1435
+ False
1436
+ ), f"Variable `img` got {type(img)} type which is not supported!!!"
1437
+ imgs_clip.append(img_clip)
1438
+ imgs.append(img)
1439
+ imgs_clip = torch.cat(imgs_clip, dim=0)
1440
+ imgs = torch.cat(imgs, dim=0)
1441
+
1442
+ if traj_prior_Ks is not None:
1443
+ assert img_size is not None
1444
+ for i, prior_k in enumerate(traj_prior_Ks):
1445
+ img, prior_k = load_img_and_K(img_size, None, K=prior_k, device="cpu") # type: ignore
1446
+ img, prior_k = transform_img_and_K(
1447
+ img,
1448
+ (W, H),
1449
+ K=prior_k[None],
1450
+ mode=options.get(
1451
+ "transform_target", "crop"
1452
+ ), # mode for prior is always same as target
1453
+ scale=options.get(
1454
+ "transform_scale", 1.0
1455
+ ), # scale for prior is always same as target
1456
+ )
1457
+ prior_k = prior_k[0]
1458
+ prior_k[0] /= W
1459
+ prior_k[1] /= H
1460
+ traj_prior_Ks[i] = prior_k
1461
+
1462
+ options["num_frames"] = T
1463
+ discretization = denoiser.discretization
1464
+ torch.cuda.empty_cache()
1465
+
1466
+ seed_everything(seed)
1467
+
1468
+ # Get Data
1469
+ input_indices = image_cond["input_indices"]
1470
+ input_imgs = imgs[input_indices]
1471
+ input_imgs_clip = imgs_clip[input_indices]
1472
+ input_c2ws = camera_cond["c2w"][input_indices]
1473
+ input_Ks = camera_cond["K"][input_indices]
1474
+
1475
+ test_indices = [i for i in range(len(imgs)) if i not in input_indices]
1476
+ test_imgs = imgs[test_indices]
1477
+ test_imgs_clip = imgs_clip[test_indices]
1478
+ test_c2ws = camera_cond["c2w"][test_indices]
1479
+ test_Ks = camera_cond["K"][test_indices]
1480
+
1481
+ if options.get("save_input", True):
1482
+ save_output(
1483
+ {"/image": input_imgs},
1484
+ save_path=os.path.join(save_path, "input"),
1485
+ video_save_fps=2,
1486
+ )
1487
+
1488
+ if not use_traj_prior:
1489
+ chunk_strategy = options.get("chunk_strategy", "gt")
1490
+
1491
+ (
1492
+ _,
1493
+ input_inds_per_chunk,
1494
+ input_sels_per_chunk,
1495
+ test_inds_per_chunk,
1496
+ test_sels_per_chunk,
1497
+ ) = chunk_input_and_test(
1498
+ T,
1499
+ input_c2ws,
1500
+ test_c2ws,
1501
+ input_indices,
1502
+ test_indices,
1503
+ options=options,
1504
+ task=task,
1505
+ chunk_strategy=chunk_strategy,
1506
+ gt_input_inds=list(range(input_c2ws.shape[0])),
1507
+ )
1508
+ print(
1509
+ f"One pass - chunking with `{chunk_strategy}` strategy: total "
1510
+ f"{len(input_inds_per_chunk)} forward(s) ..."
1511
+ )
1512
+
1513
+ all_samples = {}
1514
+ all_test_inds = []
1515
+ for i, (
1516
+ chunk_input_inds,
1517
+ chunk_input_sels,
1518
+ chunk_test_inds,
1519
+ chunk_test_sels,
1520
+ ) in tqdm(
1521
+ enumerate(
1522
+ zip(
1523
+ input_inds_per_chunk,
1524
+ input_sels_per_chunk,
1525
+ test_inds_per_chunk,
1526
+ test_sels_per_chunk,
1527
+ )
1528
+ ),
1529
+ total=len(input_inds_per_chunk),
1530
+ leave=False,
1531
+ ):
1532
+ (
1533
+ curr_input_sels,
1534
+ curr_test_sels,
1535
+ curr_input_maps,
1536
+ curr_test_maps,
1537
+ ) = pad_indices(
1538
+ chunk_input_sels,
1539
+ chunk_test_sels,
1540
+ T=T,
1541
+ padding_mode=options.get("t_padding_mode", "last"),
1542
+ )
1543
+ curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [
1544
+ assemble(
1545
+ input=x[chunk_input_inds],
1546
+ test=y[chunk_test_inds],
1547
+ input_maps=curr_input_maps,
1548
+ test_maps=curr_test_maps,
1549
+ )
1550
+ for x, y in zip(
1551
+ [
1552
+ torch.cat(
1553
+ [
1554
+ input_imgs,
1555
+ get_k_from_dict(all_samples, "samples-rgb").to(
1556
+ input_imgs.device
1557
+ ),
1558
+ ],
1559
+ dim=0,
1560
+ ),
1561
+ torch.cat(
1562
+ [
1563
+ input_imgs_clip,
1564
+ get_k_from_dict(all_samples, "samples-rgb").to(
1565
+ input_imgs.device
1566
+ ),
1567
+ ],
1568
+ dim=0,
1569
+ ),
1570
+ torch.cat([input_c2ws, test_c2ws[all_test_inds]], dim=0),
1571
+ torch.cat([input_Ks, test_Ks[all_test_inds]], dim=0),
1572
+ ], # procedurally append generated views to the input views
1573
+ [test_imgs, test_imgs_clip, test_c2ws, test_Ks],
1574
+ )
1575
+ ]
1576
+ value_dict = get_value_dict(
1577
+ curr_imgs.to("cuda"),
1578
+ curr_imgs_clip.to("cuda"),
1579
+ curr_input_sels
1580
+ + [
1581
+ sel
1582
+ for (ind, sel) in zip(
1583
+ np.array(chunk_test_inds)[curr_test_maps[curr_test_maps != -1]],
1584
+ curr_test_sels,
1585
+ )
1586
+ if test_indices[ind] in image_cond["input_indices"]
1587
+ ],
1588
+ curr_c2ws,
1589
+ curr_Ks,
1590
+ curr_input_sels
1591
+ + [
1592
+ sel
1593
+ for (ind, sel) in zip(
1594
+ np.array(chunk_test_inds)[curr_test_maps[curr_test_maps != -1]],
1595
+ curr_test_sels,
1596
+ )
1597
+ if test_indices[ind] in camera_cond["input_indices"]
1598
+ ],
1599
+ all_c2ws=camera_cond["c2w"],
1600
+ camera_scale=options.get("camera_scale", 2.0),
1601
+ )
1602
+ samplers = create_samplers(
1603
+ options["guider_types"],
1604
+ discretization,
1605
+ [len(curr_imgs)],
1606
+ options["num_steps"],
1607
+ options["cfg_min"],
1608
+ abort_event=abort_event,
1609
+ )
1610
+ assert len(samplers) == 1
1611
+ samples = do_sample(
1612
+ model,
1613
+ ae,
1614
+ conditioner,
1615
+ denoiser,
1616
+ samplers[0],
1617
+ value_dict,
1618
+ H,
1619
+ W,
1620
+ C,
1621
+ F,
1622
+ T=len(curr_imgs),
1623
+ cfg=(
1624
+ options["cfg"][0]
1625
+ if isinstance(options["cfg"], (list, tuple))
1626
+ else options["cfg"]
1627
+ ),
1628
+ **{k: options[k] for k in options if k not in ["cfg", "T"]},
1629
+ )
1630
+ samples = decode_output(
1631
+ samples, len(curr_imgs), chunk_test_sels
1632
+ ) # decode into dict
1633
+ if options.get("save_first_pass", False):
1634
+ save_output(
1635
+ replace_or_include_input_for_dict(
1636
+ samples,
1637
+ chunk_test_sels,
1638
+ curr_imgs,
1639
+ curr_c2ws,
1640
+ curr_Ks,
1641
+ ),
1642
+ save_path=os.path.join(save_path, "first-pass", f"forward_{i}"),
1643
+ video_save_fps=2,
1644
+ )
1645
+ extend_dict(all_samples, samples)
1646
+ all_test_inds.extend(chunk_test_inds)
1647
+ else:
1648
+ assert traj_prior_c2ws is not None, (
1649
+ "`traj_prior_c2ws` should be set when using 2-pass sampling. One "
1650
+ "potential reason is that the amount of input frames is larger than "
1651
+ "T. Set `num_prior_frames` manually to overwrite the infered stats."
1652
+ )
1653
+ traj_prior_c2ws = torch.as_tensor(
1654
+ traj_prior_c2ws,
1655
+ device=input_c2ws.device,
1656
+ dtype=input_c2ws.dtype,
1657
+ )
1658
+
1659
+ if traj_prior_Ks is None:
1660
+ traj_prior_Ks = test_Ks[:1].repeat_interleave(
1661
+ traj_prior_c2ws.shape[0], dim=0
1662
+ )
1663
+
1664
+ traj_prior_imgs = imgs.new_zeros(traj_prior_c2ws.shape[0], *imgs.shape[1:])
1665
+ traj_prior_imgs_clip = imgs_clip.new_zeros(
1666
+ traj_prior_c2ws.shape[0], *imgs_clip.shape[1:]
1667
+ )
1668
+
1669
+ # ---------------------------------- first pass ----------------------------------
1670
+ T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
1671
+ T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
1672
+ chunk_strategy_first_pass = options.get(
1673
+ "chunk_strategy_first_pass", "gt-nearest"
1674
+ )
1675
+ (
1676
+ _,
1677
+ input_inds_per_chunk,
1678
+ input_sels_per_chunk,
1679
+ prior_inds_per_chunk,
1680
+ prior_sels_per_chunk,
1681
+ ) = chunk_input_and_test(
1682
+ T_first_pass,
1683
+ input_c2ws,
1684
+ traj_prior_c2ws,
1685
+ input_indices,
1686
+ image_cond["prior_indices"],
1687
+ options=options,
1688
+ task=task,
1689
+ chunk_strategy=chunk_strategy_first_pass,
1690
+ gt_input_inds=list(range(input_c2ws.shape[0])),
1691
+ )
1692
+ print(
1693
+ f"Two passes (first) - chunking with `{chunk_strategy_first_pass}` strategy: total "
1694
+ f"{len(input_inds_per_chunk)} forward(s) ..."
1695
+ )
1696
+
1697
+ all_samples = {}
1698
+ all_prior_inds = []
1699
+ for i, (
1700
+ chunk_input_inds,
1701
+ chunk_input_sels,
1702
+ chunk_prior_inds,
1703
+ chunk_prior_sels,
1704
+ ) in tqdm(
1705
+ enumerate(
1706
+ zip(
1707
+ input_inds_per_chunk,
1708
+ input_sels_per_chunk,
1709
+ prior_inds_per_chunk,
1710
+ prior_sels_per_chunk,
1711
+ )
1712
+ ),
1713
+ total=len(input_inds_per_chunk),
1714
+ leave=False,
1715
+ ):
1716
+ (
1717
+ curr_input_sels,
1718
+ curr_prior_sels,
1719
+ curr_input_maps,
1720
+ curr_prior_maps,
1721
+ ) = pad_indices(
1722
+ chunk_input_sels,
1723
+ chunk_prior_sels,
1724
+ T=T_first_pass,
1725
+ padding_mode=options.get("t_padding_mode", "last"),
1726
+ )
1727
+ curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [
1728
+ assemble(
1729
+ input=x[chunk_input_inds],
1730
+ test=y[chunk_prior_inds],
1731
+ input_maps=curr_input_maps,
1732
+ test_maps=curr_prior_maps,
1733
+ )
1734
+ for x, y in zip(
1735
+ [
1736
+ torch.cat(
1737
+ [
1738
+ input_imgs,
1739
+ get_k_from_dict(all_samples, "samples-rgb").to(
1740
+ input_imgs.device
1741
+ ),
1742
+ ],
1743
+ dim=0,
1744
+ ),
1745
+ torch.cat(
1746
+ [
1747
+ input_imgs_clip,
1748
+ get_k_from_dict(all_samples, "samples-rgb").to(
1749
+ input_imgs.device
1750
+ ),
1751
+ ],
1752
+ dim=0,
1753
+ ),
1754
+ torch.cat([input_c2ws, traj_prior_c2ws[all_prior_inds]], dim=0),
1755
+ torch.cat([input_Ks, traj_prior_Ks[all_prior_inds]], dim=0),
1756
+ ], # procedurally append generated prior views to the input views
1757
+ [
1758
+ traj_prior_imgs,
1759
+ traj_prior_imgs_clip,
1760
+ traj_prior_c2ws,
1761
+ traj_prior_Ks,
1762
+ ],
1763
+ )
1764
+ ]
1765
+ value_dict = get_value_dict(
1766
+ curr_imgs.to("cuda"),
1767
+ curr_imgs_clip.to("cuda"),
1768
+ curr_input_sels,
1769
+ curr_c2ws,
1770
+ curr_Ks,
1771
+ list(range(T_first_pass)),
1772
+ all_c2ws=camera_cond["c2w"],
1773
+ camera_scale=options.get("camera_scale", 2.0),
1774
+ )
1775
+ samplers = create_samplers(
1776
+ options["guider_types"],
1777
+ discretization,
1778
+ [T_first_pass, T_second_pass],
1779
+ options["num_steps"],
1780
+ options["cfg_min"],
1781
+ abort_event=abort_event,
1782
+ )
1783
+ samples = do_sample(
1784
+ model,
1785
+ ae,
1786
+ conditioner,
1787
+ denoiser,
1788
+ (
1789
+ samplers[1]
1790
+ if len(samplers) > 1
1791
+ and options.get("ltr_first_pass", False)
1792
+ and chunk_strategy_first_pass != "gt"
1793
+ and i > 0
1794
+ else samplers[0]
1795
+ ),
1796
+ value_dict,
1797
+ H,
1798
+ W,
1799
+ C,
1800
+ F,
1801
+ cfg=(
1802
+ options["cfg"][0]
1803
+ if isinstance(options["cfg"], (list, tuple))
1804
+ else options["cfg"]
1805
+ ),
1806
+ T=T_first_pass,
1807
+ global_pbar=first_pass_pbar,
1808
+ **{k: options[k] for k in options if k not in ["cfg", "T", "sampler"]},
1809
+ )
1810
+ if samples is None:
1811
+ return
1812
+ samples = decode_output(
1813
+ samples, T_first_pass, chunk_prior_sels
1814
+ ) # decode into dict
1815
+ extend_dict(all_samples, samples)
1816
+ all_prior_inds.extend(chunk_prior_inds)
1817
+
1818
+ if options.get("save_first_pass", True):
1819
+ save_output(
1820
+ all_samples,
1821
+ save_path=os.path.join(save_path, "first-pass"),
1822
+ video_save_fps=5,
1823
+ )
1824
+ video_path_0 = os.path.join(save_path, "first-pass", "samples-rgb.mp4")
1825
+ yield video_path_0
1826
+
1827
+ # ---------------------------------- second pass ----------------------------------
1828
+ prior_indices = image_cond["prior_indices"]
1829
+ assert (
1830
+ prior_indices is not None
1831
+ ), "`prior_frame_indices` needs to be set if using 2-pass sampling."
1832
+ prior_argsort = np.argsort(input_indices + prior_indices).tolist()
1833
+ prior_indices = np.array(input_indices + prior_indices)[prior_argsort].tolist()
1834
+ gt_input_inds = [prior_argsort.index(i) for i in range(input_c2ws.shape[0])]
1835
+
1836
+ traj_prior_imgs = torch.cat(
1837
+ [input_imgs, get_k_from_dict(all_samples, "samples-rgb")], dim=0
1838
+ )[prior_argsort]
1839
+ traj_prior_imgs_clip = torch.cat(
1840
+ [
1841
+ input_imgs_clip,
1842
+ get_k_from_dict(all_samples, "samples-rgb"),
1843
+ ],
1844
+ dim=0,
1845
+ )[prior_argsort]
1846
+ traj_prior_c2ws = torch.cat([input_c2ws, traj_prior_c2ws], dim=0)[prior_argsort]
1847
+ traj_prior_Ks = torch.cat([input_Ks, traj_prior_Ks], dim=0)[prior_argsort]
1848
+
1849
+ update_kv_for_dict(all_samples, "samples-rgb", traj_prior_imgs)
1850
+ update_kv_for_dict(all_samples, "samples-c2ws", traj_prior_c2ws)
1851
+ update_kv_for_dict(all_samples, "samples-intrinsics", traj_prior_Ks)
1852
+
1853
+ chunk_strategy = options.get("chunk_strategy", "nearest")
1854
+ (
1855
+ _,
1856
+ prior_inds_per_chunk,
1857
+ prior_sels_per_chunk,
1858
+ test_inds_per_chunk,
1859
+ test_sels_per_chunk,
1860
+ ) = chunk_input_and_test(
1861
+ T_second_pass,
1862
+ traj_prior_c2ws,
1863
+ test_c2ws,
1864
+ prior_indices,
1865
+ test_indices,
1866
+ options=options,
1867
+ task=task,
1868
+ chunk_strategy=chunk_strategy,
1869
+ gt_input_inds=gt_input_inds,
1870
+ )
1871
+ print(
1872
+ f"Two passes (second) - chunking with `{chunk_strategy}` strategy: total "
1873
+ f"{len(prior_inds_per_chunk)} forward(s) ..."
1874
+ )
1875
+
1876
+ all_samples = {}
1877
+ all_test_inds = []
1878
+ for i, (
1879
+ chunk_prior_inds,
1880
+ chunk_prior_sels,
1881
+ chunk_test_inds,
1882
+ chunk_test_sels,
1883
+ ) in tqdm(
1884
+ enumerate(
1885
+ zip(
1886
+ prior_inds_per_chunk,
1887
+ prior_sels_per_chunk,
1888
+ test_inds_per_chunk,
1889
+ test_sels_per_chunk,
1890
+ )
1891
+ ),
1892
+ total=len(prior_inds_per_chunk),
1893
+ leave=False,
1894
+ ):
1895
+ (
1896
+ curr_prior_sels,
1897
+ curr_test_sels,
1898
+ curr_prior_maps,
1899
+ curr_test_maps,
1900
+ ) = pad_indices(
1901
+ chunk_prior_sels,
1902
+ chunk_test_sels,
1903
+ T=T_second_pass,
1904
+ padding_mode="last",
1905
+ )
1906
+ curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [
1907
+ assemble(
1908
+ input=x[chunk_prior_inds],
1909
+ test=y[chunk_test_inds],
1910
+ input_maps=curr_prior_maps,
1911
+ test_maps=curr_test_maps,
1912
+ )
1913
+ for x, y in zip(
1914
+ [
1915
+ traj_prior_imgs,
1916
+ traj_prior_imgs_clip,
1917
+ traj_prior_c2ws,
1918
+ traj_prior_Ks,
1919
+ ],
1920
+ [test_imgs, test_imgs_clip, test_c2ws, test_Ks],
1921
+ )
1922
+ ]
1923
+ value_dict = get_value_dict(
1924
+ curr_imgs.to("cuda"),
1925
+ curr_imgs_clip.to("cuda"),
1926
+ curr_prior_sels,
1927
+ curr_c2ws,
1928
+ curr_Ks,
1929
+ list(range(T_second_pass)),
1930
+ all_c2ws=camera_cond["c2w"],
1931
+ camera_scale=options.get("camera_scale", 2.0),
1932
+ )
1933
+ samples = do_sample(
1934
+ model,
1935
+ ae,
1936
+ conditioner,
1937
+ denoiser,
1938
+ samplers[1] if len(samplers) > 1 else samplers[0],
1939
+ value_dict,
1940
+ H,
1941
+ W,
1942
+ C,
1943
+ F,
1944
+ T=T_second_pass,
1945
+ cfg=(
1946
+ options["cfg"][1]
1947
+ if isinstance(options["cfg"], (list, tuple))
1948
+ and len(options["cfg"]) > 1
1949
+ else options["cfg"]
1950
+ ),
1951
+ global_pbar=second_pass_pbar,
1952
+ **{k: options[k] for k in options if k not in ["cfg", "T", "sampler"]},
1953
+ )
1954
+ if samples is None:
1955
+ return
1956
+ samples = decode_output(
1957
+ samples, T_second_pass, chunk_test_sels
1958
+ ) # decode into dict
1959
+ if options.get("save_second_pass", False):
1960
+ save_output(
1961
+ replace_or_include_input_for_dict(
1962
+ samples,
1963
+ chunk_test_sels,
1964
+ curr_imgs,
1965
+ curr_c2ws,
1966
+ curr_Ks,
1967
+ ),
1968
+ save_path=os.path.join(save_path, "second-pass", f"forward_{i}"),
1969
+ video_save_fps=2,
1970
+ )
1971
+ extend_dict(all_samples, samples)
1972
+ all_test_inds.extend(chunk_test_inds)
1973
+ all_samples = {
1974
+ key: value[np.argsort(all_test_inds)] for key, value in all_samples.items()
1975
+ }
1976
+ save_output(
1977
+ replace_or_include_input_for_dict(
1978
+ all_samples,
1979
+ test_indices,
1980
+ imgs.clone(),
1981
+ camera_cond["c2w"].clone(),
1982
+ camera_cond["K"].clone(),
1983
+ )
1984
+ if options.get("replace_or_include_input", False)
1985
+ else all_samples,
1986
+ save_path=save_path,
1987
+ video_save_fps=options.get("video_save_fps", 2),
1988
+ )
1989
+ video_path_1 = os.path.join(save_path, "samples-rgb.mp4")
1990
+ yield video_path_1
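Since `run_one_scene` is a generator that yields the first-pass and final `samples-rgb.mp4` paths as they are written, a caller is expected to iterate over it. A hypothetical driver loop (the actual entry points live in demo.py / demo_gr.py, which are not shown here; `prior_c2ws` is a placeholder for precomputed prior-trajectory poses):

    # Hypothetical usage sketch -- argument values depend on the demo scripts.
    for video_path in run_one_scene(
        task, version_dict, model, ae, conditioner, denoiser,
        image_cond, camera_cond, save_path,
        use_traj_prior=True, traj_prior_Ks=None, traj_prior_c2ws=prior_c2ws,
    ):
        print("rendered:", video_path)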
seva/geometry.py ADDED
@@ -0,0 +1,811 @@
1
+ from typing import Literal
2
+
3
+ import numpy as np
4
+ import roma
5
+ import scipy.interpolate
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ DEFAULT_FOV_RAD = 0.9424777960769379 # 54 degrees by default
10
+
11
+
12
+ def get_camera_dist(
13
+ source_c2ws: torch.Tensor, # N x 3 x 4
14
+ target_c2ws: torch.Tensor, # M x 3 x 4
15
+ mode: str = "translation",
16
+ ):
17
+ if mode == "rotation":
18
+ dists = torch.acos(
19
+ (
20
+ (
21
+ torch.matmul(
22
+ source_c2ws[:, None, :3, :3],
23
+ target_c2ws[None, :, :3, :3].transpose(-1, -2),
24
+ )
25
+ .diagonal(offset=0, dim1=-2, dim2=-1)
26
+ .sum(-1)
27
+ - 1
28
+ )
29
+ / 2
30
+ ).clamp(-1, 1)
31
+ ) * (180 / torch.pi)
32
+ elif mode == "translation":
33
+ dists = torch.norm(
34
+ source_c2ws[:, None, :3, 3] - target_c2ws[None, :, :3, 3], dim=-1
35
+ )
36
+ else:
37
+ raise NotImplementedError(
38
+ f"Mode {mode} is not implemented for finding nearest source indices."
39
+ )
40
+ return dists
41
+
42
+
43
+ def to_hom(X):
44
+ # get homogeneous coordinates of the input
45
+ X_hom = torch.cat([X, torch.ones_like(X[..., :1])], dim=-1)
46
+ return X_hom
47
+
48
+
49
+ def to_hom_pose(pose):
50
+ # get homogeneous coordinates of the input pose
51
+ if pose.shape[-2:] == (3, 4):
52
+ pose_hom = torch.eye(4, device=pose.device)[None].repeat(pose.shape[0], 1, 1)
53
+ pose_hom[:, :3, :] = pose
54
+ return pose_hom
55
+ return pose
56
+
57
+
58
+ def get_default_intrinsics(
59
+ fov_rad=DEFAULT_FOV_RAD,
60
+ aspect_ratio=1.0,
61
+ ):
62
+ if not isinstance(fov_rad, torch.Tensor):
63
+ fov_rad = torch.tensor(
64
+ [fov_rad] if isinstance(fov_rad, (int, float)) else fov_rad
65
+ )
66
+ if aspect_ratio >= 1.0: # W >= H
67
+ focal_x = 0.5 / torch.tan(0.5 * fov_rad)
68
+ focal_y = focal_x * aspect_ratio
69
+ else: # W < H
70
+ focal_y = 0.5 / torch.tan(0.5 * fov_rad)
71
+ focal_x = focal_y / aspect_ratio
72
+ intrinsics = focal_x.new_zeros((focal_x.shape[0], 3, 3))
73
+ intrinsics[:, torch.eye(3, device=focal_x.device, dtype=bool)] = torch.stack(
74
+ [focal_x, focal_y, torch.ones_like(focal_x)], dim=-1
75
+ )
76
+ intrinsics[:, :, -1] = torch.tensor(
77
+ [0.5, 0.5, 1.0], device=focal_x.device, dtype=focal_x.dtype
78
+ )
79
+ return intrinsics
80
+
81
+
82
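The intrinsics returned above are expressed in normalized image coordinates (focal lengths divided by the image size, principal point at 0.5). A small illustrative check:

    # Illustrative: normalized pinhole K for the default 54-degree FOV.
    K = get_default_intrinsics()   # shape (1, 3, 3)
    # K[0] is approximately
    # [[0.98, 0.00, 0.50],
    #  [0.00, 0.98, 0.50],
    #  [0.00, 0.00, 1.00]]   since 0.5 / tan(27 deg) ~= 0.98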
+ def get_image_grid(img_h, img_w):
83
+ # Adding 0.5 is VERY important, especially when img_h and img_w
84
+ # are not very large (e.g., 72)!
85
+ y_range = torch.arange(img_h, dtype=torch.float32).add_(0.5)
86
+ x_range = torch.arange(img_w, dtype=torch.float32).add_(0.5)
87
+ Y, X = torch.meshgrid(y_range, x_range, indexing="ij") # [H,W]
88
+ xy_grid = torch.stack([X, Y], dim=-1).view(-1, 2) # [HW,2]
89
+ return to_hom(xy_grid) # [HW,3]
90
+
91
+
92
+ def img2cam(X, cam_intr):
93
+ return X @ cam_intr.inverse().transpose(-1, -2)
94
+
95
+
96
+ def cam2world(X, pose):
97
+ X_hom = to_hom(X)
98
+ pose_inv = torch.linalg.inv(to_hom_pose(pose))[..., :3, :4]
99
+ return X_hom @ pose_inv.transpose(-1, -2)
100
+
101
+
102
+ def get_center_and_ray(
103
+ img_h, img_w, pose, intr, zero_center_for_debugging=False
104
+ ): # [HW,2]
105
+ # given the intrinsic/extrinsic matrices, get the camera center and ray directions
106
+ # assert(opt.camera.model=="perspective")
107
+
108
+ # compute center and ray
109
+ grid_img = get_image_grid(img_h, img_w) # [HW,3]
110
+ grid_3D_cam = img2cam(grid_img.to(intr.device), intr.float()) # [B,HW,3]
111
+ center_3D_cam = torch.zeros_like(grid_3D_cam) # [B,HW,3]
112
+
113
+ # transform from camera to world coordinates
114
+ grid_3D = cam2world(grid_3D_cam, pose) # [B,HW,3]
115
+ center_3D = cam2world(center_3D_cam, pose) # [B,HW,3]
116
+ ray = grid_3D - center_3D # [B,HW,3]
117
+
118
+ return center_3D_cam if zero_center_for_debugging else center_3D, ray, grid_3D_cam
119
+
120
+
121
+ def get_plucker_coordinates(
122
+ extrinsics_src,
123
+ extrinsics,
124
+ intrinsics=None,
125
+ fov_rad=DEFAULT_FOV_RAD,
126
+ mode="plucker",
127
+ rel_zero_translation=True,
128
+ zero_center_for_debugging=False,
129
+ target_size=[72, 72], # 576-size image
130
+ return_grid_cam=False, # save for later use if want restore
131
+ ):
132
+ if intrinsics is None:
133
+ intrinsics = get_default_intrinsics(fov_rad).to(extrinsics.device)
134
+ else:
135
+ # for some data preprocessed in the early stage (e.g., MVI and CO3D),
136
+ # intrinsics are expressed in raw pixel space (e.g., 576x576) instead
137
+ # of normalized image coordinates
138
+ if not (
139
+ torch.all(intrinsics[:, :2, -1] >= 0)
140
+ and torch.all(intrinsics[:, :2, -1] <= 1)
141
+ ):
142
+ intrinsics[:, :2] /= intrinsics.new_tensor(target_size).view(1, -1, 1) * 8
143
+ # You should ensure the intrinsics are expressed in
144
+ # resolution-independent normalized image coordinates; we only perform a
145
+ # very simple verification here, checking that the principal points are
146
+ # between 0 and 1.
147
+ assert (
148
+ torch.all(intrinsics[:, :2, -1] >= 0)
149
+ and torch.all(intrinsics[:, :2, -1] <= 1)
150
+ ), "Intrinsics should be expressed in resolution-independent normalized image coordinates."
151
+
152
+ c2w_src = torch.linalg.inv(extrinsics_src)
153
+ if not rel_zero_translation:
154
+ c2w_src[:3, 3] = c2w_src[3, :3] = 0.0
155
+ # transform coordinates from the source camera's coordinate system to the coordinate system of the respective camera
156
+ extrinsics_rel = torch.einsum(
157
+ "vnm,vmp->vnp", extrinsics, c2w_src[None].repeat(extrinsics.shape[0], 1, 1)
158
+ )
159
+
160
+ intrinsics[:, :2] *= extrinsics.new_tensor(
161
+ [
162
+ target_size[1], # w
163
+ target_size[0], # h
164
+ ]
165
+ ).view(1, -1, 1)
166
+ centers, rays, grid_cam = get_center_and_ray(
167
+ img_h=target_size[0],
168
+ img_w=target_size[1],
169
+ pose=extrinsics_rel[:, :3, :],
170
+ intr=intrinsics,
171
+ zero_center_for_debugging=zero_center_for_debugging,
172
+ )
173
+
174
+ if mode == "plucker" or "v1" in mode:
175
+ rays = torch.nn.functional.normalize(rays, dim=-1)
176
+ plucker = torch.cat((rays, torch.cross(centers, rays, dim=-1)), dim=-1)
177
+ else:
178
+ raise ValueError(f"Unknown Plucker coordinate mode: {mode}")
179
+
180
+ plucker = plucker.permute(0, 2, 1).reshape(plucker.shape[0], -1, *target_size)
181
+ if return_grid_cam:
182
+ return plucker, grid_cam.reshape(-1, *target_size, 3)
183
+ return plucker
184
+
185
+
186
+ def rt_to_mat4(
187
+ R: torch.Tensor, t: torch.Tensor, s: torch.Tensor | None = None
188
+ ) -> torch.Tensor:
189
+ """
190
+ Args:
191
+ R (torch.Tensor): (..., 3, 3).
192
+ t (torch.Tensor): (..., 3).
193
+ s (torch.Tensor): (...,).
194
+
195
+ Returns:
196
+ torch.Tensor: (..., 4, 4)
197
+ """
198
+ mat34 = torch.cat([R, t[..., None]], dim=-1)
199
+ if s is None:
200
+ bottom = (
201
+ mat34.new_tensor([[0.0, 0.0, 0.0, 1.0]])
202
+ .reshape((1,) * (mat34.dim() - 2) + (1, 4))
203
+ .expand(mat34.shape[:-2] + (1, 4))
204
+ )
205
+ else:
206
+ bottom = F.pad(1.0 / s[..., None, None], (3, 0), value=0.0)
207
+ mat4 = torch.cat([mat34, bottom], dim=-2)
208
+ return mat4
209
+
210
+
211
+ def get_preset_pose_fov(
212
+ option: Literal[
213
+ "orbit",
214
+ "spiral",
215
+ "lemniscate",
216
+ "zoom-in",
217
+ "zoom-out",
218
+ "dolly zoom-in",
219
+ "dolly zoom-out",
220
+ "move-forward",
221
+ "move-backward",
222
+ "move-up",
223
+ "move-down",
224
+ "move-left",
225
+ "move-right",
226
+ "roll",
227
+ ],
228
+ num_frames: int,
229
+ start_w2c: torch.Tensor,
230
+ look_at: torch.Tensor,
231
+ up_direction: torch.Tensor | None = None,
232
+ fov: float = DEFAULT_FOV_RAD,
233
+ spiral_radii: list[float] = [0.5, 0.5, 0.2],
234
+ zoom_factor: float | None = None,
235
+ ):
236
+ poses = fovs = None
237
+ if option == "orbit":
238
+ poses = torch.linalg.inv(
239
+ get_arc_horizontal_w2cs(
240
+ start_w2c,
241
+ look_at,
242
+ up_direction,
243
+ num_frames=num_frames,
244
+ endpoint=False,
245
+ )
246
+ ).numpy()
247
+ fovs = np.full((num_frames,), fov)
248
+ elif option == "spiral":
249
+ poses = generate_spiral_path(
250
+ torch.linalg.inv(start_w2c)[None].numpy() @ np.diagflat([1, -1, -1, 1]),
251
+ np.array([1, 5]),
252
+ n_frames=num_frames,
253
+ n_rots=2,
254
+ zrate=0.5,
255
+ radii=spiral_radii,
256
+ endpoint=False,
257
+ ) @ np.diagflat([1, -1, -1, 1])
258
+ poses = np.concatenate(
259
+ [
260
+ poses,
261
+ np.array([0.0, 0.0, 0.0, 1.0])[None, None].repeat(len(poses), 0),
262
+ ],
263
+ 1,
264
+ )
265
+ # We want the spiral trajectory to always start from start_w2c. Thus we
266
+ # apply the relative pose to get the final trajectory.
267
+ poses = (
268
+ np.linalg.inv(start_w2c.numpy())[None] @ np.linalg.inv(poses[:1]) @ poses
269
+ )
270
+ fovs = np.full((num_frames,), fov)
271
+ elif option == "lemniscate":
272
+ poses = torch.linalg.inv(
273
+ get_lemniscate_w2cs(
274
+ start_w2c,
275
+ look_at,
276
+ up_direction,
277
+ num_frames,
278
+ degree=60.0,
279
+ endpoint=False,
280
+ )
281
+ ).numpy()
282
+ fovs = np.full((num_frames,), fov)
283
+ elif option == "roll":
284
+ poses = torch.linalg.inv(
285
+ get_roll_w2cs(
286
+ start_w2c,
287
+ look_at,
288
+ None,
289
+ num_frames,
290
+ degree=360.0,
291
+ endpoint=False,
292
+ )
293
+ ).numpy()
294
+ fovs = np.full((num_frames,), fov)
295
+ elif option in [
296
+ "dolly zoom-in",
297
+ "dolly zoom-out",
298
+ "zoom-in",
299
+ "zoom-out",
300
+ ]:
301
+ if option.startswith("dolly"):
302
+ direction = "backward" if option == "dolly zoom-in" else "forward"
303
+ poses = torch.linalg.inv(
304
+ get_moving_w2cs(
305
+ start_w2c,
306
+ look_at,
307
+ up_direction,
308
+ num_frames,
309
+ endpoint=True,
310
+ direction=direction,
311
+ )
312
+ ).numpy()
313
+ else:
314
+ poses = torch.linalg.inv(start_w2c)[None].repeat(num_frames, 1, 1).numpy()
315
+ fov_rad_start = fov
316
+ if zoom_factor is None:
317
+ zoom_factor = 0.28 if option.endswith("zoom-in") else 1.5
318
+ fov_rad_end = zoom_factor * fov
319
+ fovs = (
320
+ np.linspace(0, 1, num_frames) * (fov_rad_end - fov_rad_start)
321
+ + fov_rad_start
322
+ )
323
+ elif option in [
324
+ "move-forward",
325
+ "move-backward",
326
+ "move-up",
327
+ "move-down",
328
+ "move-left",
329
+ "move-right",
330
+ ]:
331
+ poses = torch.linalg.inv(
332
+ get_moving_w2cs(
333
+ start_w2c,
334
+ look_at,
335
+ up_direction,
336
+ num_frames,
337
+ endpoint=True,
338
+ direction=option.removeprefix("move-"),
339
+ )
340
+ ).numpy()
341
+ fovs = np.full((num_frames,), fov)
342
+ else:
343
+ raise ValueError(f"Unknown preset option {option}.")
344
+
345
+ return poses, fovs
346
+
347
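A minimal sketch of how the presets above can be queried (values are illustrative; `up_direction` defaults to the camera's own up axis):

    # Illustrative: 80 orbit poses around a look-at point, keeping the default FOV.
    poses, fovs = get_preset_pose_fov(
        "orbit",
        num_frames=80,
        start_w2c=torch.eye(4),
        look_at=torch.tensor([0.0, 0.0, 1.0]),
    )
    # poses: (80, 4, 4) numpy array of c2w matrices; fovs: (80,) numpy array of FOVs in radians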
+
348
+ def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor:
349
+ """Triangulate a set of rays to find a single lookat point.
350
+
351
+ Args:
352
+ origins (torch.Tensor): A (N, 3) array of ray origins.
353
+ viewdirs (torch.Tensor): A (N, 3) array of ray view directions.
354
+
355
+ Returns:
356
+ torch.Tensor: A (3,) lookat point.
357
+ """
358
+
359
+ viewdirs = torch.nn.functional.normalize(viewdirs, dim=-1)
360
+ eye = torch.eye(3, device=origins.device, dtype=origins.dtype)[None]
361
+ # Calculate projection matrix I - rr^T
362
+ I_min_cov = eye - (viewdirs[..., None] * viewdirs[..., None, :])
363
+ # Compute sum of projections
364
+ sum_proj = I_min_cov.matmul(origins[..., None]).sum(dim=-3)
365
+ # Solve for the intersection point using least squares
366
+ lookat = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
367
+ # Check NaNs.
368
+ assert not torch.any(torch.isnan(lookat))
369
+ return lookat
370
+
371
+
372
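A quick worked example of the least-squares triangulation above (illustrative only): two rays that intersect at the origin recover that point.

    origins = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
    viewdirs = torch.tensor([[-1.0, 0.0, 0.0], [0.0, -1.0, 0.0]])
    get_lookat(origins, viewdirs)   # ~= tensor([0., 0., 0.])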
+ def get_lookat_w2cs(
373
+ positions: torch.Tensor,
374
+ lookat: torch.Tensor,
375
+ up: torch.Tensor,
376
+ face_off: bool = False,
377
+ ):
378
+ """
379
+ Args:
380
+ positions: (N, 3) tensor of camera positions
381
+ lookat: (3,) tensor of lookat point
382
+ up: (3,) or (N, 3) tensor of up vector
383
+
384
+ Returns:
385
+ w2cs: (N, 4, 4) tensor of world-to-camera matrices
386
+ """
387
+ forward_vectors = F.normalize(lookat - positions, dim=-1)
388
+ if face_off:
389
+ forward_vectors = -forward_vectors
390
+ if up.dim() == 1:
391
+ up = up[None]
392
+ right_vectors = F.normalize(torch.cross(forward_vectors, up, dim=-1), dim=-1)
393
+ down_vectors = F.normalize(
394
+ torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1
395
+ )
396
+ Rs = torch.stack([right_vectors, down_vectors, forward_vectors], dim=-1)
397
+ w2cs = torch.linalg.inv(rt_to_mat4(Rs, positions))
398
+ return w2cs
399
+
400
+
401
+ def get_arc_horizontal_w2cs(
402
+ ref_w2c: torch.Tensor,
403
+ lookat: torch.Tensor,
404
+ up: torch.Tensor | None,
405
+ num_frames: int,
406
+ clockwise: bool = True,
407
+ face_off: bool = False,
408
+ endpoint: bool = False,
409
+ degree: float = 360.0,
410
+ ref_up_shift: float = 0.0,
411
+ ref_radius_scale: float = 1.0,
412
+ **_,
413
+ ) -> torch.Tensor:
414
+ ref_c2w = torch.linalg.inv(ref_w2c)
415
+ ref_position = ref_c2w[:3, 3]
416
+ if up is None:
417
+ up = -ref_c2w[:3, 1]
418
+ assert up is not None
419
+ ref_position += up * ref_up_shift
420
+ ref_position *= ref_radius_scale
421
+ thetas = (
422
+ torch.linspace(0.0, torch.pi * degree / 180, num_frames, device=ref_w2c.device)
423
+ if endpoint
424
+ else torch.linspace(
425
+ 0.0, torch.pi * degree / 180, num_frames + 1, device=ref_w2c.device
426
+ )[:-1]
427
+ )
428
+ if not clockwise:
429
+ thetas = -thetas
430
+ positions = (
431
+ torch.einsum(
432
+ "nij,j->ni",
433
+ roma.rotvec_to_rotmat(thetas[:, None] * up[None]),
434
+ ref_position - lookat,
435
+ )
436
+ + lookat
437
+ )
438
+ return get_lookat_w2cs(positions, lookat, up, face_off=face_off)
439
+
440
+
441
+ def get_lemniscate_w2cs(
442
+ ref_w2c: torch.Tensor,
443
+ lookat: torch.Tensor,
444
+ up: torch.Tensor | None,
445
+ num_frames: int,
446
+ degree: float,
447
+ endpoint: bool = False,
448
+ **_,
449
+ ) -> torch.Tensor:
450
+ ref_c2w = torch.linalg.inv(ref_w2c)
451
+ a = torch.linalg.norm(ref_c2w[:3, 3] - lookat) * np.tan(degree / 360 * np.pi)
452
+ # Lemniscate curve in camera space. Starting at the origin.
453
+ thetas = (
454
+ torch.linspace(0, 2 * torch.pi, num_frames, device=ref_w2c.device)
455
+ if endpoint
456
+ else torch.linspace(0, 2 * torch.pi, num_frames + 1, device=ref_w2c.device)[:-1]
457
+ ) + torch.pi / 2
458
+ positions = torch.stack(
459
+ [
460
+ a * torch.cos(thetas) / (1 + torch.sin(thetas) ** 2),
461
+ a * torch.cos(thetas) * torch.sin(thetas) / (1 + torch.sin(thetas) ** 2),
462
+ torch.zeros(num_frames, device=ref_w2c.device),
463
+ ],
464
+ dim=-1,
465
+ )
466
+ # Transform to world space.
467
+ positions = torch.einsum(
468
+ "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)
469
+ )
470
+ if up is None:
471
+ up = -ref_c2w[:3, 1]
472
+ assert up is not None
473
+ return get_lookat_w2cs(positions, lookat, up)
474
+
475
+
476
+ def get_moving_w2cs(
477
+ ref_w2c: torch.Tensor,
478
+ lookat: torch.Tensor,
479
+ up: torch.Tensor | None,
480
+ num_frames: int,
481
+ endpoint: bool = False,
482
+ direction: str = "forward",
483
+ tilt_xy: torch.Tensor = None,
484
+ ):
485
+ """
486
+ Args:
487
+ ref_w2c: (4, 4) tensor of the reference world-to-camera matrix
488
+ lookat: (3,) tensor of lookat point
489
+ up: (3,) tensor of up vector
490
+
491
+ Returns:
492
+ w2cs: (N, 4, 4) tensor of world-to-camera matrices
493
+ """
494
+ ref_c2w = torch.linalg.inv(ref_w2c)
495
+ ref_position = ref_c2w[:3, -1]
496
+ if up is None:
497
+ up = -ref_c2w[:3, 1]
498
+
499
+ direction_vectors = {
500
+ "forward": (lookat - ref_position).clone(),
501
+ "backward": -(lookat - ref_position).clone(),
502
+ "up": up.clone(),
503
+ "down": -up.clone(),
504
+ "right": torch.cross((lookat - ref_position), up, dim=0),
505
+ "left": -torch.cross((lookat - ref_position), up, dim=0),
506
+ }
507
+ if direction not in direction_vectors:
508
+ raise ValueError(
509
+ f"Invalid direction: {direction}. Must be one of {list(direction_vectors.keys())}"
510
+ )
511
+
512
+ positions = ref_position + (
513
+ F.normalize(direction_vectors[direction], dim=0)
514
+ * (
515
+ torch.linspace(0, 0.99, num_frames, device=ref_w2c.device)
516
+ if endpoint
517
+ else torch.linspace(0, 1, num_frames + 1, device=ref_w2c.device)[:-1]
518
+ )[:, None]
519
+ )
520
+
521
+ if tilt_xy is not None:
522
+ positions[:, :2] += tilt_xy
523
+
524
+ return get_lookat_w2cs(positions, lookat, up)
525
+
526
+
527
+ def get_roll_w2cs(
528
+ ref_w2c: torch.Tensor,
529
+ lookat: torch.Tensor,
530
+ up: torch.Tensor | None,
531
+ num_frames: int,
532
+ endpoint: bool = False,
533
+ degree: float = 360.0,
534
+ **_,
535
+ ) -> torch.Tensor:
536
+ ref_c2w = torch.linalg.inv(ref_w2c)
537
+ ref_position = ref_c2w[:3, 3]
538
+ if up is None:
539
+ up = -ref_c2w[:3, 1] # Infer the up vector from the reference.
540
+
541
+ # Create vertical angles
542
+ thetas = (
543
+ torch.linspace(0.0, torch.pi * degree / 180, num_frames, device=ref_w2c.device)
544
+ if endpoint
545
+ else torch.linspace(
546
+ 0.0, torch.pi * degree / 180, num_frames + 1, device=ref_w2c.device
547
+ )[:-1]
548
+ )[:, None]
549
+
550
+ lookat_vector = F.normalize(lookat[None].float(), dim=-1)
551
+ up = up[None]
552
+ up = (
553
+ up * torch.cos(thetas)
554
+ + torch.cross(lookat_vector, up) * torch.sin(thetas)
555
+ + lookat_vector
556
+ * torch.einsum("ij,ij->i", lookat_vector, up)[:, None]
557
+ * (1 - torch.cos(thetas))
558
+ )
559
+
560
+ # Normalize the camera orientation
561
+ return get_lookat_w2cs(ref_position[None].repeat(num_frames, 1), lookat, up)
562
+
563
+
564
+ def normalize(x):
565
+ """Normalization helper function."""
566
+ return x / np.linalg.norm(x)
567
+
568
+
569
+ def viewmatrix(lookdir, up, position, subtract_position=False):
570
+ """Construct lookat view matrix."""
571
+ vec2 = normalize((lookdir - position) if subtract_position else lookdir)
572
+ vec0 = normalize(np.cross(up, vec2))
573
+ vec1 = normalize(np.cross(vec2, vec0))
574
+ m = np.stack([vec0, vec1, vec2, position], axis=1)
575
+ return m
576
+
577
+
578
+ def poses_avg(poses):
579
+ """New pose using average position, z-axis, and up vector of input poses."""
580
+ position = poses[:, :3, 3].mean(0)
581
+ z_axis = poses[:, :3, 2].mean(0)
582
+ up = poses[:, :3, 1].mean(0)
583
+ cam2world = viewmatrix(z_axis, up, position)
584
+ return cam2world
585
+
586
+
587
+ def generate_spiral_path(
588
+ poses, bounds, n_frames=120, n_rots=2, zrate=0.5, endpoint=False, radii=None
589
+ ):
590
+ """Calculates a forward facing spiral path for rendering."""
591
+ # Find a reasonable 'focus depth' for this dataset as a weighted average
592
+ # of near and far bounds in disparity space.
593
+ close_depth, inf_depth = bounds.min() * 0.9, bounds.max() * 5.0
594
+ dt = 0.75
595
+ focal = 1 / ((1 - dt) / close_depth + dt / inf_depth)
596
+
597
+ # Get radii for spiral path using 90th percentile of camera positions.
598
+ positions = poses[:, :3, 3]
599
+ if radii is None:
600
+ radii = np.percentile(np.abs(positions), 90, 0)
601
+ radii = np.concatenate([radii, [1.0]])
602
+
603
+ # Generate poses for spiral path.
604
+ render_poses = []
605
+ cam2world = poses_avg(poses)
606
+ up = poses[:, :3, 1].mean(0)
607
+ for theta in np.linspace(0.0, 2.0 * np.pi * n_rots, n_frames, endpoint=endpoint):
608
+ t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]
609
+ position = cam2world @ t
610
+ lookat = cam2world @ [0, 0, -focal, 1.0]
611
+ z_axis = position - lookat
612
+ render_poses.append(viewmatrix(z_axis, up, position))
613
+ render_poses = np.stack(render_poses, axis=0)
614
+ return render_poses
615
+
616
+
617
+ def generate_interpolated_path(
618
+ poses: np.ndarray,
619
+ n_interp: int,
620
+ spline_degree: int = 5,
621
+ smoothness: float = 0.03,
622
+ rot_weight: float = 0.1,
623
+ endpoint: bool = False,
624
+ ):
625
+ """Creates a smooth spline path between input keyframe camera poses.
626
+
627
+ Spline is calculated with poses in format (position, lookat-point, up-point).
628
+
629
+ Args:
630
+ poses: (n, 3, 4) array of input pose keyframes.
631
+ n_interp: returned path will have n_interp * (n - 1) total poses.
632
+ spline_degree: polynomial degree of B-spline.
633
+ smoothness: parameter for spline smoothing, 0 forces exact interpolation.
634
+ rot_weight: relative weighting of rotation/translation in spline solve.
635
+
636
+ Returns:
637
+ Array of new camera poses with shape (n_interp * (n - 1), 3, 4).
638
+ """
639
+
640
+ def poses_to_points(poses, dist):
641
+ """Converts from pose matrices to (position, lookat, up) format."""
642
+ pos = poses[:, :3, -1]
643
+ lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
644
+ up = poses[:, :3, -1] + dist * poses[:, :3, 1]
645
+ return np.stack([pos, lookat, up], 1)
646
+
647
+ def points_to_poses(points):
648
+ """Converts from (position, lookat, up) format to pose matrices."""
649
+ return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points])
650
+
651
+ def interp(points, n, k, s):
652
+ """Runs multidimensional B-spline interpolation on the input points."""
653
+ sh = points.shape
654
+ pts = np.reshape(points, (sh[0], -1))
655
+ k = min(k, sh[0] - 1)
656
+ tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s)
657
+ u = np.linspace(0, 1, n, endpoint=endpoint)
658
+ new_points = np.array(scipy.interpolate.splev(u, tck))
659
+ new_points = np.reshape(new_points.T, (n, sh[1], sh[2]))
660
+ return new_points
661
+
662
+ points = poses_to_points(poses, dist=rot_weight)
663
+ new_points = interp(
664
+ points, n_interp * (points.shape[0] - 1), k=spline_degree, s=smoothness
665
+ )
666
+ return points_to_poses(new_points)
667
+
668
+
669
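A short usage sketch for the spline interpolation above (the keyframes here are random placeholders, not real camera poses):

    import numpy as np
    rng = np.random.default_rng(0)
    keyframes = rng.normal(size=(4, 3, 4))              # hypothetical (n, 3, 4) c2w keyframes
    path = generate_interpolated_path(keyframes, n_interp=30)
    assert path.shape == (30 * (4 - 1), 3, 4)           # n_interp * (n - 1) poses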
+ def similarity_from_cameras(c2w, strict_scaling=False, center_method="focus"):
670
+ """
671
+ reference: nerf-factory
672
+ Get a similarity transform to normalize dataset
673
+ from c2w (OpenCV convention) cameras
674
+ :param c2w: (N, 4)
675
+ :return T (4,4) , scale (float)
676
+ """
677
+ t = c2w[:, :3, 3]
678
+ R = c2w[:, :3, :3]
679
+
680
+ # (1) Rotate the world so that z+ is the up axis
681
+ # we estimate the up axis by averaging the camera up axes
682
+ ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1)
683
+ world_up = np.mean(ups, axis=0)
684
+ world_up /= np.linalg.norm(world_up)
685
+
686
+ up_camspace = np.array([0.0, -1.0, 0.0])
687
+ c = (up_camspace * world_up).sum()
688
+ cross = np.cross(world_up, up_camspace)
689
+ skew = np.array(
690
+ [
691
+ [0.0, -cross[2], cross[1]],
692
+ [cross[2], 0.0, -cross[0]],
693
+ [-cross[1], cross[0], 0.0],
694
+ ]
695
+ )
696
+ if c > -1:
697
+ R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c)
698
+ else:
699
+ # In the unlikely case the original data has y+ up axis,
700
+ # rotate 180-deg about x axis
701
+ R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
702
+
703
+ # R_align = np.eye(3) # DEBUG
704
+ R = R_align @ R
705
+ fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1)
706
+ t = (R_align @ t[..., None])[..., 0]
707
+
708
+ # (2) Recenter the scene.
709
+ if center_method == "focus":
710
+ # find the closest point to the origin for each camera's center ray
711
+ nearest = t + (fwds * -t).sum(-1)[:, None] * fwds
712
+ translate = -np.median(nearest, axis=0)
713
+ elif center_method == "poses":
714
+ # use center of the camera positions
715
+ translate = -np.median(t, axis=0)
716
+ else:
717
+ raise ValueError(f"Unknown center_method {center_method}")
718
+
719
+ transform = np.eye(4)
720
+ transform[:3, 3] = translate
721
+ transform[:3, :3] = R_align
722
+
723
+ # (3) Rescale the scene using camera distances
724
+ scale_fn = np.max if strict_scaling else np.median
725
+ inv_scale = scale_fn(np.linalg.norm(t + translate, axis=-1))
726
+ if inv_scale == 0:
727
+ inv_scale = 1.0
728
+ scale = 1.0 / inv_scale
729
+ transform[:3, :] *= scale
730
+
731
+ return transform
732
+
733
+
734
+ def align_principle_axes(point_cloud):
735
+ # Compute centroid
736
+ centroid = np.median(point_cloud, axis=0)
737
+
738
+ # Translate point cloud to centroid
739
+ translated_point_cloud = point_cloud - centroid
740
+
741
+ # Compute covariance matrix
742
+ covariance_matrix = np.cov(translated_point_cloud, rowvar=False)
743
+
744
+ # Compute eigenvectors and eigenvalues
745
+ eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix)
746
+
747
+ # Sort eigenvectors by eigenvalues (descending order) so that the z-axis
748
+ # is the principal axis with the smallest eigenvalue.
749
+ sort_indices = eigenvalues.argsort()[::-1]
750
+ eigenvectors = eigenvectors[:, sort_indices]
751
+
752
+ # Check orientation of eigenvectors. If the determinant of the eigenvectors is
753
+ # negative, then we need to flip the sign of one of the eigenvectors.
754
+ if np.linalg.det(eigenvectors) < 0:
755
+ eigenvectors[:, 0] *= -1
756
+
757
+ # Create rotation matrix
758
+ rotation_matrix = eigenvectors.T
759
+
760
+ # Create SE(3) matrix (4x4 transformation matrix)
761
+ transform = np.eye(4)
762
+ transform[:3, :3] = rotation_matrix
763
+ transform[:3, 3] = -rotation_matrix @ centroid
764
+
765
+ return transform
766
+
767
+
768
+ def transform_points(matrix, points):
769
+ """Transform points using a SE(4) matrix.
770
+
771
+ Args:
772
+ matrix: 4x4 SE(3) matrix
773
+ points: Nx3 array of points
774
+
775
+ Returns:
776
+ Nx3 array of transformed points
777
+ """
778
+ assert matrix.shape == (4, 4)
779
+ assert len(points.shape) == 2 and points.shape[1] == 3
780
+ return points @ matrix[:3, :3].T + matrix[:3, 3]
781
+
782
+
783
+ def transform_cameras(matrix, camtoworlds):
784
+ """Transform cameras using a SE(4) matrix.
785
+
786
+ Args:
787
+ matrix: 4x4 SE(3) matrix
788
+ camtoworlds: Nx4x4 array of camera-to-world matrices
789
+
790
+ Returns:
791
+ Nx4x4 array of transformed camera-to-world matrices
792
+ """
793
+ assert matrix.shape == (4, 4)
794
+ assert len(camtoworlds.shape) == 3 and camtoworlds.shape[1:] == (4, 4)
795
+ camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, matrix)
796
+ scaling = np.linalg.norm(camtoworlds[:, 0, :3], axis=1)
797
+ camtoworlds[:, :3, :3] = camtoworlds[:, :3, :3] / scaling[:, None, None]
798
+ return camtoworlds
799
+
800
+
801
+ def normalize_scene(camtoworlds, points=None, camera_center_method="focus"):
802
+ T1 = similarity_from_cameras(camtoworlds, center_method=camera_center_method)
803
+ camtoworlds = transform_cameras(T1, camtoworlds)
804
+ if points is not None:
805
+ points = transform_points(T1, points)
806
+ T2 = align_principle_axes(points)
807
+ camtoworlds = transform_cameras(T2, camtoworlds)
808
+ points = transform_points(T2, points)
809
+ return camtoworlds, points, T2 @ T1
810
+ else:
811
+ return camtoworlds, T1
seva/gui.py ADDED
@@ -0,0 +1,975 @@
1
+ import colorsys
2
+ import dataclasses
3
+ import threading
4
+ import time
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import scipy
9
+ import splines
10
+ import splines.quaternion
11
+ import torch
12
+ import viser
13
+ import viser.transforms as vt
14
+
15
+ from seva.geometry import get_preset_pose_fov
16
+
17
+
18
+ @dataclasses.dataclass
19
+ class Keyframe(object):
20
+ position: np.ndarray
21
+ wxyz: np.ndarray
22
+ override_fov_enabled: bool
23
+ override_fov_rad: float
24
+ aspect: float
25
+ override_transition_enabled: bool
26
+ override_transition_sec: float | None
27
+
28
+ @staticmethod
29
+ def from_camera(camera: viser.CameraHandle, aspect: float) -> "Keyframe":
30
+ return Keyframe(
31
+ camera.position,
32
+ camera.wxyz,
33
+ override_fov_enabled=False,
34
+ override_fov_rad=camera.fov,
35
+ aspect=aspect,
36
+ override_transition_enabled=False,
37
+ override_transition_sec=None,
38
+ )
39
+
40
+ @staticmethod
41
+ def from_se3(se3: vt.SE3, fov: float, aspect: float) -> "Keyframe":
42
+ return Keyframe(
43
+ se3.translation(),
44
+ se3.rotation().wxyz,
45
+ override_fov_enabled=False,
46
+ override_fov_rad=fov,
47
+ aspect=aspect,
48
+ override_transition_enabled=False,
49
+ override_transition_sec=None,
50
+ )
51
+
52
+
53
+ class CameraTrajectory(object):
54
+ def __init__(
55
+ self,
56
+ server: viser.ViserServer,
57
+ duration_element: viser.GuiInputHandle[float],
58
+ scene_scale: float,
59
+ scene_node_prefix: str = "/",
60
+ ):
61
+ self._server = server
62
+ self._keyframes: dict[int, tuple[Keyframe, viser.CameraFrustumHandle]] = {}
63
+ self._keyframe_counter: int = 0
64
+ self._spline_nodes: list[viser.SceneNodeHandle] = []
65
+ self._camera_edit_panel: viser.Gui3dContainerHandle | None = None
66
+
67
+ self._orientation_spline: splines.quaternion.KochanekBartels | None = None
68
+ self._position_spline: splines.KochanekBartels | None = None
69
+ self._fov_spline: splines.KochanekBartels | None = None
70
+
71
+ self._keyframes_visible: bool = True
72
+
73
+ self._duration_element = duration_element
74
+ self._scene_node_prefix = scene_node_prefix
75
+
76
+ self.scene_scale = scene_scale
77
+ # These parameters should be overridden externally.
78
+ self.loop: bool = False
79
+ self.framerate: float = 30.0
80
+ self.tension: float = 0.0 # Tension / alpha term.
81
+ self.default_fov: float = 0.0
82
+ self.default_transition_sec: float = 0.0
83
+ self.show_spline: bool = True
84
+
85
+ def set_keyframes_visible(self, visible: bool) -> None:
86
+ self._keyframes_visible = visible
87
+ for keyframe in self._keyframes.values():
88
+ keyframe[1].visible = visible
89
+
90
+ def add_camera(self, keyframe: Keyframe, keyframe_index: int | None = None) -> None:
91
+ """Add a new camera, or replace an old one if `keyframe_index` is passed in."""
92
+ server = self._server
93
+
94
+ # Add a keyframe if we aren't replacing an existing one.
95
+ if keyframe_index is None:
96
+ keyframe_index = self._keyframe_counter
97
+ self._keyframe_counter += 1
98
+
99
+ print(
100
+ f"{keyframe.wxyz=} {keyframe.position=} {keyframe_index=} {keyframe.aspect=}"
101
+ )
102
+ frustum_handle = server.scene.add_camera_frustum(
103
+ str(Path(self._scene_node_prefix) / f"cameras/{keyframe_index}"),
104
+ fov=(
105
+ keyframe.override_fov_rad
106
+ if keyframe.override_fov_enabled
107
+ else self.default_fov
108
+ ),
109
+ aspect=keyframe.aspect,
110
+ scale=0.1 * self.scene_scale,
111
+ color=(200, 10, 30),
112
+ wxyz=keyframe.wxyz,
113
+ position=keyframe.position,
114
+ visible=self._keyframes_visible,
115
+ )
116
+ self._server.scene.add_icosphere(
117
+ str(Path(self._scene_node_prefix) / f"cameras/{keyframe_index}/sphere"),
118
+ radius=0.03,
119
+ color=(200, 10, 30),
120
+ )
121
+
122
+ @frustum_handle.on_click
123
+ def _(_) -> None:
124
+ if self._camera_edit_panel is not None:
125
+ self._camera_edit_panel.remove()
126
+ self._camera_edit_panel = None
127
+
128
+ with server.scene.add_3d_gui_container(
129
+ "/camera_edit_panel",
130
+ position=keyframe.position,
131
+ ) as camera_edit_panel:
132
+ self._camera_edit_panel = camera_edit_panel
133
+ override_fov = server.gui.add_checkbox(
134
+ "Override FOV", initial_value=keyframe.override_fov_enabled
135
+ )
136
+ override_fov_degrees = server.gui.add_slider(
137
+ "Override FOV (degrees)",
138
+ 5.0,
139
+ 175.0,
140
+ step=0.1,
141
+ initial_value=keyframe.override_fov_rad * 180.0 / np.pi,
142
+ disabled=not keyframe.override_fov_enabled,
143
+ )
144
+ delete_button = server.gui.add_button(
145
+ "Delete", color="red", icon=viser.Icon.TRASH
146
+ )
147
+ go_to_button = server.gui.add_button("Go to")
148
+ close_button = server.gui.add_button("Close")
149
+
150
+ @override_fov.on_update
151
+ def _(_) -> None:
152
+ keyframe.override_fov_enabled = override_fov.value
153
+ override_fov_degrees.disabled = not override_fov.value
154
+ self.add_camera(keyframe, keyframe_index)
155
+
156
+ @override_fov_degrees.on_update
157
+ def _(_) -> None:
158
+ keyframe.override_fov_rad = override_fov_degrees.value / 180.0 * np.pi
159
+ self.add_camera(keyframe, keyframe_index)
160
+
161
+ @delete_button.on_click
162
+ def _(event: viser.GuiEvent) -> None:
163
+ assert event.client is not None
164
+ with event.client.gui.add_modal("Confirm") as modal:
165
+ event.client.gui.add_markdown("Delete keyframe?")
166
+ confirm_button = event.client.gui.add_button(
167
+ "Yes", color="red", icon=viser.Icon.TRASH
168
+ )
169
+ exit_button = event.client.gui.add_button("Cancel")
170
+
171
+ @confirm_button.on_click
172
+ def _(_) -> None:
173
+ assert camera_edit_panel is not None
174
+
175
+ keyframe_id = None
176
+ for i, keyframe_tuple in self._keyframes.items():
177
+ if keyframe_tuple[1] is frustum_handle:
178
+ keyframe_id = i
179
+ break
180
+ assert keyframe_id is not None
181
+
182
+ self._keyframes.pop(keyframe_id)
183
+ frustum_handle.remove()
184
+ camera_edit_panel.remove()
185
+ self._camera_edit_panel = None
186
+ modal.close()
187
+ self.update_spline()
188
+
189
+ @exit_button.on_click
190
+ def _(_) -> None:
191
+ modal.close()
192
+
193
+ @go_to_button.on_click
194
+ def _(event: viser.GuiEvent) -> None:
195
+ assert event.client is not None
196
+ client = event.client
197
+ T_world_current = vt.SE3.from_rotation_and_translation(
198
+ vt.SO3(client.camera.wxyz), client.camera.position
199
+ )
200
+ T_world_target = vt.SE3.from_rotation_and_translation(
201
+ vt.SO3(keyframe.wxyz), keyframe.position
202
+ ) @ vt.SE3.from_translation(np.array([0.0, 0.0, -0.5]))
203
+
204
+ T_current_target = T_world_current.inverse() @ T_world_target
205
+
206
+ for j in range(10):
207
+ T_world_set = T_world_current @ vt.SE3.exp(
208
+ T_current_target.log() * j / 9.0
209
+ )
210
+
211
+ # Important bit: we atomically set both the orientation and
212
+ # the position of the camera.
213
+ with client.atomic():
214
+ client.camera.wxyz = T_world_set.rotation().wxyz
215
+ client.camera.position = T_world_set.translation()
216
+ time.sleep(1.0 / 30.0)
217
+
218
+ @close_button.on_click
219
+ def _(_) -> None:
220
+ assert camera_edit_panel is not None
221
+ camera_edit_panel.remove()
222
+ self._camera_edit_panel = None
223
+
224
+ self._keyframes[keyframe_index] = (keyframe, frustum_handle)
225
+
226
+ def update_aspect(self, aspect: float) -> None:
227
+ for keyframe_index, frame in self._keyframes.items():
228
+ frame = dataclasses.replace(frame[0], aspect=aspect)
229
+ self.add_camera(frame, keyframe_index=keyframe_index)
230
+
231
+ def get_aspect(self) -> float:
232
+ """Get W/H aspect ratio, which is shared across all keyframes."""
233
+ assert len(self._keyframes) > 0
234
+ return next(iter(self._keyframes.values()))[0].aspect
235
+
236
+ def reset(self) -> None:
237
+ for frame in self._keyframes.values():
238
+ print(f"removing {frame[1]}")
239
+ frame[1].remove()
240
+ self._keyframes.clear()
241
+ self.update_spline()
242
+ print("camera traj reset")
243
+
244
+ def spline_t_from_t_sec(self, time: np.ndarray) -> np.ndarray:
245
+ """From a time value in seconds, compute a t value for our geometric
246
+ spline interpolation. An increment of 1 for the latter will move the
247
+ camera forward by one keyframe.
248
+
249
+ We use a PCHIP spline here to guarantee monotonicity.
250
+ """
251
+ transition_times_cumsum = self.compute_transition_times_cumsum()
252
+ spline_indices = np.arange(transition_times_cumsum.shape[0])
253
+
254
+ if self.loop:
255
+ # In the case of a loop, we pad the spline to match the start/end
256
+ # slopes.
257
+ interpolator = scipy.interpolate.PchipInterpolator(
258
+ x=np.concatenate(
259
+ [
260
+ [-(transition_times_cumsum[-1] - transition_times_cumsum[-2])],
261
+ transition_times_cumsum,
262
+ transition_times_cumsum[-1:] + transition_times_cumsum[1:2],
263
+ ],
264
+ axis=0,
265
+ ),
266
+ y=np.concatenate(
267
+ [[-1], spline_indices, [spline_indices[-1] + 1]], # type: ignore
268
+ axis=0,
269
+ ),
270
+ )
271
+ else:
272
+ interpolator = scipy.interpolate.PchipInterpolator(
273
+ x=transition_times_cumsum, y=spline_indices
274
+ )
275
+
276
+ # Clip to account for floating point error.
277
+ return np.clip(interpolator(time), 0, spline_indices[-1])
278
+
279
+ def interpolate_pose_and_fov_rad(
280
+ self, normalized_t: float
281
+ ) -> tuple[vt.SE3, float] | None:
282
+ if len(self._keyframes) < 2:
283
+ return None
284
+
285
+ self._fov_spline = splines.KochanekBartels(
286
+ [
287
+ (
288
+ keyframe[0].override_fov_rad
289
+ if keyframe[0].override_fov_enabled
290
+ else self.default_fov
291
+ )
292
+ for keyframe in self._keyframes.values()
293
+ ],
294
+ tcb=(self.tension, 0.0, 0.0),
295
+ endconditions="closed" if self.loop else "natural",
296
+ )
297
+
298
+ assert self._orientation_spline is not None
299
+ assert self._position_spline is not None
300
+ assert self._fov_spline is not None
301
+
302
+ max_t = self.compute_duration()
303
+ t = max_t * normalized_t
304
+ spline_t = float(self.spline_t_from_t_sec(np.array(t)))
305
+
306
+ quat = self._orientation_spline.evaluate(spline_t)
307
+ assert isinstance(quat, splines.quaternion.UnitQuaternion)
308
+ return (
309
+ vt.SE3.from_rotation_and_translation(
310
+ vt.SO3(np.array([quat.scalar, *quat.vector])),
311
+ self._position_spline.evaluate(spline_t),
312
+ ),
313
+ float(self._fov_spline.evaluate(spline_t)),
314
+ )
315
+
316
+ def update_spline(self) -> None:
317
+ num_frames = int(self.compute_duration() * self.framerate)
318
+ keyframes = list(self._keyframes.values())
319
+
320
+ if num_frames <= 0 or not self.show_spline or len(keyframes) < 2:
321
+ for node in self._spline_nodes:
322
+ node.remove()
323
+ self._spline_nodes.clear()
324
+ return
325
+
326
+ transition_times_cumsum = self.compute_transition_times_cumsum()
327
+
328
+ self._orientation_spline = splines.quaternion.KochanekBartels(
329
+ [
330
+ splines.quaternion.UnitQuaternion.from_unit_xyzw(
331
+ np.roll(keyframe[0].wxyz, shift=-1)
332
+ )
333
+ for keyframe in keyframes
334
+ ],
335
+ tcb=(self.tension, 0.0, 0.0),
336
+ endconditions="closed" if self.loop else "natural",
337
+ )
338
+ self._position_spline = splines.KochanekBartels(
339
+ [keyframe[0].position for keyframe in keyframes],
340
+ tcb=(self.tension, 0.0, 0.0),
341
+ endconditions="closed" if self.loop else "natural",
342
+ )
343
+
344
+ # Update visualized spline.
345
+ points_array = self._position_spline.evaluate(
346
+ self.spline_t_from_t_sec(
347
+ np.linspace(0, transition_times_cumsum[-1], num_frames)
348
+ )
349
+ )
350
+ colors_array = np.array(
351
+ [
352
+ colorsys.hls_to_rgb(h, 0.5, 1.0)
353
+ for h in np.linspace(0.0, 1.0, len(points_array))
354
+ ]
355
+ )
356
+
357
+ # Clear prior spline nodes.
358
+ for node in self._spline_nodes:
359
+ node.remove()
360
+ self._spline_nodes.clear()
361
+
362
+ self._spline_nodes.append(
363
+ self._server.scene.add_spline_catmull_rom(
364
+ str(Path(self._scene_node_prefix) / "camera_spline"),
365
+ positions=points_array,
366
+ color=(220, 220, 220),
367
+ closed=self.loop,
368
+ line_width=1.0,
369
+ segments=points_array.shape[0] + 1,
370
+ )
371
+ )
372
+ self._spline_nodes.append(
373
+ self._server.scene.add_point_cloud(
374
+ str(Path(self._scene_node_prefix) / "camera_spline/points"),
375
+ points=points_array,
376
+ colors=colors_array,
377
+ point_size=0.04,
378
+ )
379
+ )
380
+
381
+ def make_transition_handle(i: int) -> None:
382
+ assert self._position_spline is not None
383
+ transition_pos = self._position_spline.evaluate(
384
+ float(
385
+ self.spline_t_from_t_sec(
386
+ (transition_times_cumsum[i] + transition_times_cumsum[i + 1])
387
+ / 2.0,
388
+ )
389
+ )
390
+ )
391
+ transition_sphere = self._server.scene.add_icosphere(
392
+ str(Path(self._scene_node_prefix) / f"camera_spline/transition_{i}"),
393
+ radius=0.04,
394
+ color=(255, 0, 0),
395
+ position=transition_pos,
396
+ )
397
+ self._spline_nodes.append(transition_sphere)
398
+
399
+ @transition_sphere.on_click
400
+ def _(_) -> None:
401
+ server = self._server
402
+
403
+ if self._camera_edit_panel is not None:
404
+ self._camera_edit_panel.remove()
405
+ self._camera_edit_panel = None
406
+
407
+ keyframe_index = (i + 1) % len(self._keyframes)
408
+ keyframe = keyframes[keyframe_index][0]
409
+
410
+ with server.scene.add_3d_gui_container(
411
+ "/camera_edit_panel",
412
+ position=transition_pos,
413
+ ) as camera_edit_panel:
414
+ self._camera_edit_panel = camera_edit_panel
415
+ override_transition_enabled = server.gui.add_checkbox(
416
+ "Override transition",
417
+ initial_value=keyframe.override_transition_enabled,
418
+ )
419
+ override_transition_sec = server.gui.add_number(
420
+ "Override transition (sec)",
421
+ initial_value=(
422
+ keyframe.override_transition_sec
423
+ if keyframe.override_transition_sec is not None
424
+ else self.default_transition_sec
425
+ ),
426
+ min=0.001,
427
+ max=30.0,
428
+ step=0.001,
429
+ disabled=not override_transition_enabled.value,
430
+ )
431
+ close_button = server.gui.add_button("Close")
432
+
433
+ @override_transition_enabled.on_update
434
+ def _(_) -> None:
435
+ keyframe.override_transition_enabled = (
436
+ override_transition_enabled.value
437
+ )
438
+ override_transition_sec.disabled = (
439
+ not override_transition_enabled.value
440
+ )
441
+ self._duration_element.value = self.compute_duration()
442
+
443
+ @override_transition_sec.on_update
444
+ def _(_) -> None:
445
+ keyframe.override_transition_sec = override_transition_sec.value
446
+ self._duration_element.value = self.compute_duration()
447
+
448
+ @close_button.on_click
449
+ def _(_) -> None:
450
+ assert camera_edit_panel is not None
451
+ camera_edit_panel.remove()
452
+ self._camera_edit_panel = None
453
+
454
+ (num_transitions_plus_1,) = transition_times_cumsum.shape
455
+ for i in range(num_transitions_plus_1 - 1):
456
+ make_transition_handle(i)
457
+
458
+ def compute_duration(self) -> float:
459
+ """Compute the total duration of the trajectory."""
460
+ total = 0.0
461
+ for i, (keyframe, frustum) in enumerate(self._keyframes.values()):
462
+ if i == 0 and not self.loop:
463
+ continue
464
+ del frustum
465
+ total += (
466
+ keyframe.override_transition_sec
467
+ if keyframe.override_transition_enabled
468
+ and keyframe.override_transition_sec is not None
469
+ else self.default_transition_sec
470
+ )
471
+ return total
472
+
473
+ def compute_transition_times_cumsum(self) -> np.ndarray:
474
+ """Compute the total duration of the trajectory."""
475
+ total = 0.0
476
+ out = [0.0]
477
+ for i, (keyframe, frustum) in enumerate(self._keyframes.values()):
478
+ if i == 0:
479
+ continue
480
+ del frustum
481
+ total += (
482
+ keyframe.override_transition_sec
483
+ if keyframe.override_transition_enabled
484
+ and keyframe.override_transition_sec is not None
485
+ else self.default_transition_sec
486
+ )
487
+ out.append(total)
488
+
489
+ if self.loop:
490
+ keyframe = next(iter(self._keyframes.values()))[0]
491
+ total += (
492
+ keyframe.override_transition_sec
493
+ if keyframe.override_transition_enabled
494
+ and keyframe.override_transition_sec is not None
495
+ else self.default_transition_sec
496
+ )
497
+ out.append(total)
498
+
499
+ return np.array(out)
500
+
501
+
502
+ @dataclasses.dataclass
503
+ class GuiState:
504
+ preview_render: bool
505
+ preview_fov: float
506
+ preview_aspect: float
507
+ camera_traj_list: list | None
508
+ active_input_index: int
509
+
510
+
511
+ def define_gui(
512
+ server: viser.ViserServer,
513
+ init_fov: float = 75.0,
514
+ img_wh: tuple[int, int] = (576, 576),
515
+ **kwargs,
516
+ ) -> GuiState:
517
+ gui_state = GuiState(
518
+ preview_render=False,
519
+ preview_fov=0.0,
520
+ preview_aspect=1.0,
521
+ camera_traj_list=None,
522
+ active_input_index=0,
523
+ )
524
+
525
+ with server.gui.add_folder(
526
+ "Preset camera trajectories", order=99, expand_by_default=False
527
+ ):
528
+ preset_traj_dropdown = server.gui.add_dropdown(
529
+ "Options",
530
+ [
531
+ "orbit",
532
+ "spiral",
533
+ "lemniscate",
534
+ "zoom-out",
535
+ "dolly zoom-out",
536
+ ],
537
+ initial_value="orbit",
538
+ hint="Select a preset camera trajectory.",
539
+ )
540
+ preset_duration_num = server.gui.add_number(
541
+ "Duration (sec)",
542
+ min=1.0,
543
+ max=60.0,
544
+ step=0.5,
545
+ initial_value=2.0,
546
+ )
547
+ preset_submit_button = server.gui.add_button(
548
+ "Submit",
549
+ icon=viser.Icon.PICK,
550
+ hint="Add a new keyframe at the current pose.",
551
+ )
552
+
553
+ @preset_submit_button.on_click
554
+ def _(event: viser.GuiEvent) -> None:
555
+ camera_traj.reset()
556
+ gui_state.camera_traj_list = None
557
+
558
+ duration = preset_duration_num.value
559
+ fps = framerate_number.value
560
+ num_frames = int(duration * fps)
561
+ transition_sec = duration / num_frames
562
+ transition_sec_number.value = transition_sec
563
+ assert event.client_id is not None
564
+ transition_sec_number.disabled = True
565
+ loop_checkbox.disabled = True
566
+ add_keyframe_button.disabled = True
567
+
568
+ camera = server.get_clients()[event.client_id].camera
569
+ start_w2c = torch.linalg.inv(
570
+ torch.as_tensor(
571
+ vt.SE3.from_rotation_and_translation(
572
+ vt.SO3(camera.wxyz), camera.position
573
+ ).as_matrix(),
574
+ dtype=torch.float32,
575
+ )
576
+ )
577
+ look_at = torch.as_tensor(camera.look_at, dtype=torch.float32)
578
+ up_direction = torch.as_tensor(camera.up_direction, dtype=torch.float32)
579
+ poses, fovs = get_preset_pose_fov(
580
+ option=preset_traj_dropdown.value, # type: ignore
581
+ num_frames=num_frames,
582
+ start_w2c=start_w2c,
583
+ look_at=look_at,
584
+ up_direction=up_direction,
585
+ fov=camera.fov,
586
+ )
587
+ assert poses is not None and fovs is not None
588
+ for pose, fov in zip(poses, fovs):
589
+ camera_traj.add_camera(
590
+ Keyframe.from_se3(
591
+ vt.SE3.from_matrix(pose),
592
+ fov=fov,
593
+ aspect=img_wh[0] / img_wh[1],
594
+ )
595
+ )
596
+
597
+ duration_number.value = camera_traj.compute_duration()
598
+ camera_traj.update_spline()
599
+
600
+ with server.gui.add_folder("Advanced", expand_by_default=False, order=100):
601
+ transition_sec_number = server.gui.add_number(
602
+ "Transition (sec)",
603
+ min=0.001,
604
+ max=30.0,
605
+ step=0.001,
606
+ initial_value=1.5,
607
+ hint="Time in seconds between each keyframe, which can also be overridden on a per-transition basis.",
608
+ )
609
+ framerate_number = server.gui.add_number(
610
+ "FPS", min=0.1, max=240.0, step=1e-2, initial_value=30.0
611
+ )
612
+ framerate_buttons = server.gui.add_button_group("", ("24", "30", "60"))
613
+ duration_number = server.gui.add_number(
614
+ "Duration (sec)",
615
+ min=0.0,
616
+ max=1e8,
617
+ step=0.001,
618
+ initial_value=0.0,
619
+ disabled=True,
620
+ )
621
+
622
+ @framerate_buttons.on_click
623
+ def _(_) -> None:
624
+ framerate_number.value = float(framerate_buttons.value)
625
+
626
+ fov_degree_slider = server.gui.add_slider(
627
+ "FOV",
628
+ initial_value=init_fov,
629
+ min=0.1,
630
+ max=175.0,
631
+ step=0.01,
632
+ hint="Field-of-view for rendering, which can also be overridden on a per-keyframe basis.",
633
+ )
634
+
635
+ @fov_degree_slider.on_update
636
+ def _(_) -> None:
637
+ fov_radians = fov_degree_slider.value / 180.0 * np.pi
638
+ for client in server.get_clients().values():
639
+ client.camera.fov = fov_radians
640
+ camera_traj.default_fov = fov_radians
641
+
642
+ # Updating the aspect ratio will also re-render the camera frustums.
643
+ # Could rethink this.
644
+ camera_traj.update_aspect(img_wh[0] / img_wh[1])
645
+ compute_and_update_preview_camera_state()
646
+
647
+ scene_node_prefix = "/render_assets"
648
+ base_scene_node = server.scene.add_frame(scene_node_prefix, show_axes=False)
649
+ add_keyframe_button = server.gui.add_button(
650
+ "Add keyframe",
651
+ icon=viser.Icon.PLUS,
652
+ hint="Add a new keyframe at the current pose.",
653
+ )
654
+
655
+ @add_keyframe_button.on_click
656
+ def _(event: viser.GuiEvent) -> None:
657
+ assert event.client_id is not None
658
+ camera = server.get_clients()[event.client_id].camera
659
+ pose = vt.SE3.from_rotation_and_translation(
660
+ vt.SO3(camera.wxyz), camera.position
661
+ )
662
+ print(f"client {event.client_id} at {camera.position} {camera.wxyz}")
663
+ print(f"camera pose {pose.as_matrix()}")
664
+
665
+ # Add this camera to the trajectory.
666
+ camera_traj.add_camera(
667
+ Keyframe.from_camera(
668
+ camera,
669
+ aspect=img_wh[0] / img_wh[1],
670
+ ),
671
+ )
672
+ duration_number.value = camera_traj.compute_duration()
673
+ camera_traj.update_spline()
674
+
675
+ clear_keyframes_button = server.gui.add_button(
676
+ "Clear keyframes",
677
+ icon=viser.Icon.TRASH,
678
+ hint="Remove all keyframes from the render trajectory.",
679
+ )
680
+
681
+ @clear_keyframes_button.on_click
682
+ def _(event: viser.GuiEvent) -> None:
683
+ assert event.client_id is not None
684
+ client = server.get_clients()[event.client_id]
685
+ with client.atomic(), client.gui.add_modal("Confirm") as modal:
686
+ client.gui.add_markdown("Clear all keyframes?")
687
+ confirm_button = client.gui.add_button(
688
+ "Yes", color="red", icon=viser.Icon.TRASH
689
+ )
690
+ exit_button = client.gui.add_button("Cancel")
691
+
692
+ @confirm_button.on_click
693
+ def _(_) -> None:
694
+ camera_traj.reset()
695
+ modal.close()
696
+
697
+ duration_number.value = camera_traj.compute_duration()
698
+ add_keyframe_button.disabled = False
699
+ transition_sec_number.disabled = False
700
+ transition_sec_number.value = 1.5
701
+ loop_checkbox.disabled = False
702
+
703
+ nonlocal gui_state
704
+ gui_state.camera_traj_list = None
705
+
706
+ @exit_button.on_click
707
+ def _(_) -> None:
708
+ modal.close()
709
+
710
+ play_button = server.gui.add_button("Play", icon=viser.Icon.PLAYER_PLAY)
711
+ pause_button = server.gui.add_button(
712
+ "Pause", icon=viser.Icon.PLAYER_PAUSE, visible=False
713
+ )
714
+
715
+ # Poll the play button to see if we should be playing endlessly.
716
+ def play() -> None:
717
+ while True:
718
+ while not play_button.visible:
719
+ max_frame = int(framerate_number.value * duration_number.value)
720
+ if max_frame > 0:
721
+ assert preview_frame_slider is not None
722
+ preview_frame_slider.value = (
723
+ preview_frame_slider.value + 1
724
+ ) % max_frame
725
+ time.sleep(1.0 / framerate_number.value)
726
+ time.sleep(0.1)
727
+
728
+ threading.Thread(target=play).start()
729
+
730
+ # Play the camera trajectory when the play button is pressed.
731
+ @play_button.on_click
732
+ def _(_) -> None:
733
+ play_button.visible = False
734
+ pause_button.visible = True
735
+
736
+ # Pause trajectory playback when the pause button is pressed.
737
+ @pause_button.on_click
738
+ def _(_) -> None:
739
+ play_button.visible = True
740
+ pause_button.visible = False
741
+
742
+ preview_render_button = server.gui.add_button(
743
+ "Preview render",
744
+ hint="Show a preview of the render in the viewport.",
745
+ icon=viser.Icon.CAMERA_CHECK,
746
+ )
747
+ preview_render_stop_button = server.gui.add_button(
748
+ "Exit render preview",
749
+ color="red",
750
+ icon=viser.Icon.CAMERA_CANCEL,
751
+ visible=False,
752
+ )
753
+
754
+ @preview_render_button.on_click
755
+ def _(_) -> None:
756
+ gui_state.preview_render = True
757
+ preview_render_button.visible = False
758
+ preview_render_stop_button.visible = True
759
+ play_button.visible = False
760
+ pause_button.visible = True
761
+ preset_submit_button.disabled = True
762
+
763
+ maybe_pose_and_fov_rad = compute_and_update_preview_camera_state()
764
+ if maybe_pose_and_fov_rad is None:
765
+ remove_preview_camera()
766
+ return
767
+ pose, fov = maybe_pose_and_fov_rad
768
+ del fov
769
+
770
+ # Hide all render assets when we're previewing the render.
771
+ nonlocal base_scene_node
772
+ base_scene_node.visible = False
773
+
774
+ # Back up and then set camera poses.
775
+ for client in server.get_clients().values():
776
+ camera_pose_backup_from_id[client.client_id] = (
777
+ client.camera.position,
778
+ client.camera.look_at,
779
+ client.camera.up_direction,
780
+ )
781
+ with client.atomic():
782
+ client.camera.wxyz = pose.rotation().wxyz
783
+ client.camera.position = pose.translation()
784
+
785
+ def stop_preview_render() -> None:
786
+ gui_state.preview_render = False
787
+ preview_render_button.visible = True
788
+ preview_render_stop_button.visible = False
789
+ play_button.visible = True
790
+ pause_button.visible = False
791
+ preset_submit_button.disabled = False
792
+
793
+ # Revert camera poses.
794
+ for client in server.get_clients().values():
795
+ if client.client_id not in camera_pose_backup_from_id:
796
+ continue
797
+ cam_position, cam_look_at, cam_up = camera_pose_backup_from_id.pop(
798
+ client.client_id
799
+ )
800
+ with client.atomic():
801
+ client.camera.position = cam_position
802
+ client.camera.look_at = cam_look_at
803
+ client.camera.up_direction = cam_up
804
+ client.flush()
805
+
806
+ # Un-hide render assets.
807
+ nonlocal base_scene_node
808
+ base_scene_node.visible = True
809
+ remove_preview_camera()
810
+
811
+ @preview_render_stop_button.on_click
812
+ def _(_) -> None:
813
+ stop_preview_render()
814
+
815
+ def get_max_frame_index() -> int:
816
+ return max(1, int(framerate_number.value * duration_number.value) - 1)
817
+
818
+ def add_preview_frame_slider() -> viser.GuiInputHandle[int] | None:
819
+ """Helper for creating the current frame # slider. This is removed and
820
+ re-added anytime the `max` value changes."""
821
+
822
+ preview_frame_slider = server.gui.add_slider(
823
+ "Preview frame",
824
+ min=0,
825
+ max=get_max_frame_index(),
826
+ step=1,
827
+ initial_value=0,
828
+ order=set_traj_button.order + 0.01,
829
+ disabled=get_max_frame_index() == 1,
830
+ )
831
+ play_button.disabled = preview_frame_slider.disabled
832
+ preview_render_button.disabled = preview_frame_slider.disabled
833
+ set_traj_button.disabled = preview_frame_slider.disabled
834
+
835
+ @preview_frame_slider.on_update
836
+ def _(_) -> None:
837
+ nonlocal preview_camera_handle
838
+ maybe_pose_and_fov_rad = compute_and_update_preview_camera_state()
839
+ if maybe_pose_and_fov_rad is None:
840
+ return
841
+ pose, fov_rad = maybe_pose_and_fov_rad
842
+
843
+ preview_camera_handle = server.scene.add_camera_frustum(
844
+ str(Path(scene_node_prefix) / "preview_camera"),
845
+ fov=fov_rad,
846
+ aspect=img_wh[0] / img_wh[1],
847
+ scale=0.35,
848
+ wxyz=pose.rotation().wxyz,
849
+ position=pose.translation(),
850
+ color=(10, 200, 30),
851
+ )
852
+ if gui_state.preview_render:
853
+ for client in server.get_clients().values():
854
+ with client.atomic():
855
+ client.camera.wxyz = pose.rotation().wxyz
856
+ client.camera.position = pose.translation()
857
+
858
+ return preview_frame_slider
859
+
860
+ set_traj_button = server.gui.add_button(
861
+ "Set camera trajectory",
862
+ color="green",
863
+ icon=viser.Icon.CHECK,
864
+ hint="Save the camera trajectory for rendering.",
865
+ )
866
+
867
+ @set_traj_button.on_click
868
+ def _(event: viser.GuiEvent) -> None:
869
+ assert event.client is not None
870
+ num_frames = int(framerate_number.value * duration_number.value)
871
+
872
+ def get_intrinsics(W, H, fov_rad):
873
+ focal = 0.5 * H / np.tan(0.5 * fov_rad)
874
+ return np.array(
875
+ [[focal, 0.0, 0.5 * W], [0.0, focal, 0.5 * H], [0.0, 0.0, 1.0]]
876
+ )
877
+
878
+ camera_traj_list = []
879
+ for i in range(num_frames):
880
+ maybe_pose_and_fov_rad = camera_traj.interpolate_pose_and_fov_rad(
881
+ i / num_frames
882
+ )
883
+ if maybe_pose_and_fov_rad is None:
884
+ return
885
+ pose, fov_rad = maybe_pose_and_fov_rad
886
+ H = img_wh[1]
887
+ W = img_wh[0]
888
+ K = get_intrinsics(W, H, fov_rad)
889
+ w2c = pose.inverse().as_matrix()
890
+ camera_traj_list.append(
891
+ {
892
+ "w2c": w2c.flatten().tolist(),
893
+ "K": K.flatten().tolist(),
894
+ "img_wh": (W, H),
895
+ }
896
+ )
897
+ nonlocal gui_state
898
+ gui_state.camera_traj_list = camera_traj_list
899
+ print(f"Get camera_traj_list: {gui_state.camera_traj_list}")
900
+
901
+ stop_preview_render()
902
+
903
+ preview_frame_slider = add_preview_frame_slider()
904
+
905
+ loop_checkbox = server.gui.add_checkbox(
906
+ "Loop", False, hint="Add a segment between the first and last keyframes."
907
+ )
908
+
909
+ @loop_checkbox.on_update
910
+ def _(_) -> None:
911
+ camera_traj.loop = loop_checkbox.value
912
+ duration_number.value = camera_traj.compute_duration()
913
+
914
+ @transition_sec_number.on_update
915
+ def _(_) -> None:
916
+ camera_traj.default_transition_sec = transition_sec_number.value
917
+ duration_number.value = camera_traj.compute_duration()
918
+
919
+ preview_camera_handle: viser.SceneNodeHandle | None = None
920
+
921
+ def remove_preview_camera() -> None:
922
+ nonlocal preview_camera_handle
923
+ if preview_camera_handle is not None:
924
+ preview_camera_handle.remove()
925
+ preview_camera_handle = None
926
+
927
+ def compute_and_update_preview_camera_state() -> tuple[vt.SE3, float] | None:
928
+ """Update the render tab state with the current preview camera pose.
929
+ Returns current camera pose + FOV if available."""
930
+
931
+ if preview_frame_slider is None:
932
+ return None
933
+ maybe_pose_and_fov_rad = camera_traj.interpolate_pose_and_fov_rad(
934
+ preview_frame_slider.value / get_max_frame_index()
935
+ )
936
+ if maybe_pose_and_fov_rad is None:
937
+ remove_preview_camera()
938
+ return None
939
+ pose, fov_rad = maybe_pose_and_fov_rad
940
+ gui_state.preview_fov = fov_rad
941
+ gui_state.preview_aspect = camera_traj.get_aspect()
942
+ return pose, fov_rad
943
+
944
+ # We back up the camera poses before and after we start previewing renders.
945
+ camera_pose_backup_from_id: dict[int, tuple] = {}
946
+
947
+ # Update the # of frames.
948
+ @duration_number.on_update
949
+ @framerate_number.on_update
950
+ def _(_) -> None:
951
+ remove_preview_camera() # Will be re-added when slider is updated.
952
+
953
+ nonlocal preview_frame_slider
954
+ old = preview_frame_slider
955
+ assert old is not None
956
+
957
+ preview_frame_slider = add_preview_frame_slider()
958
+ if preview_frame_slider is not None:
959
+ old.remove()
960
+ else:
961
+ preview_frame_slider = old
962
+
963
+ camera_traj.framerate = framerate_number.value
964
+ camera_traj.update_spline()
965
+
966
+ camera_traj = CameraTrajectory(
967
+ server,
968
+ duration_number,
969
+ scene_node_prefix=scene_node_prefix,
970
+ **kwargs,
971
+ )
972
+ camera_traj.default_fov = fov_degree_slider.value / 180.0 * np.pi
973
+ camera_traj.default_transition_sec = transition_sec_number.value
974
+
975
+ return gui_state
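
For orientation, a minimal usage sketch of the GUI defined above (not part of the commit): it assumes a running viser server, that `define_gui` is importable from `seva.gui`, and that any extra keyword arguments accepted by `CameraTrajectory` (e.g. `scene_scale`) keep their defaults; the polling loop is illustrative only.

    import time
    import viser
    from seva.gui import define_gui  # assumed import path

    server = viser.ViserServer(port=8080)
    # Build the trajectory editor; extra kwargs are forwarded to CameraTrajectory.
    gui_state = define_gui(server, init_fov=60.0, img_wh=(576, 576))

    # Block until the user clicks "Set camera trajectory" in the browser UI.
    while gui_state.camera_traj_list is None:
        time.sleep(0.5)

    # Each entry carries a flattened 4x4 w2c matrix, a flattened 3x3 K, and (W, H).
    print(f"Collected {len(gui_state.camera_traj_list)} camera poses.")
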
seva/model.py ADDED
@@ -0,0 +1,234 @@
1
+ from dataclasses import dataclass, field
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from seva.modules.layers import (
7
+ Downsample,
8
+ GroupNorm32,
9
+ ResBlock,
10
+ TimestepEmbedSequential,
11
+ Upsample,
12
+ timestep_embedding,
13
+ )
14
+ from seva.modules.transformer import MultiviewTransformer
15
+
16
+
17
+ @dataclass
18
+ class SevaParams(object):
19
+ in_channels: int = 11
20
+ model_channels: int = 320
21
+ out_channels: int = 4
22
+ num_frames: int = 21
23
+ num_res_blocks: int = 2
24
+ attention_resolutions: list[int] = field(default_factory=lambda: [4, 2, 1])
25
+ channel_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
26
+ num_head_channels: int = 64
27
+ transformer_depth: list[int] = field(default_factory=lambda: [1, 1, 1, 1])
28
+ context_dim: int = 1024
29
+ dense_in_channels: int = 6
30
+ dropout: float = 0.0
31
+ unflatten_names: list[str] = field(
32
+ default_factory=lambda: ["middle_ds8", "output_ds4", "output_ds2"]
33
+ )
34
+
35
+ def __post_init__(self):
36
+ assert len(self.channel_mult) == len(self.transformer_depth)
37
+
38
+
39
+ class Seva(nn.Module):
40
+ def __init__(self, params: SevaParams) -> None:
41
+ super().__init__()
42
+ self.params = params
43
+ self.model_channels = params.model_channels
44
+ self.out_channels = params.out_channels
45
+ self.num_head_channels = params.num_head_channels
46
+
47
+ time_embed_dim = params.model_channels * 4
48
+ self.time_embed = nn.Sequential(
49
+ nn.Linear(params.model_channels, time_embed_dim),
50
+ nn.SiLU(),
51
+ nn.Linear(time_embed_dim, time_embed_dim),
52
+ )
53
+
54
+ self.input_blocks = nn.ModuleList(
55
+ [
56
+ TimestepEmbedSequential(
57
+ nn.Conv2d(params.in_channels, params.model_channels, 3, padding=1)
58
+ )
59
+ ]
60
+ )
61
+ self._feature_size = params.model_channels
62
+ input_block_chans = [params.model_channels]
63
+ ch = params.model_channels
64
+ ds = 1
65
+ for level, mult in enumerate(params.channel_mult):
66
+ for _ in range(params.num_res_blocks):
67
+ input_layers: list[ResBlock | MultiviewTransformer | Downsample] = [
68
+ ResBlock(
69
+ channels=ch,
70
+ emb_channels=time_embed_dim,
71
+ out_channels=mult * params.model_channels,
72
+ dense_in_channels=params.dense_in_channels,
73
+ dropout=params.dropout,
74
+ )
75
+ ]
76
+ ch = mult * params.model_channels
77
+ if ds in params.attention_resolutions:
78
+ num_heads = ch // params.num_head_channels
79
+ dim_head = params.num_head_channels
80
+ input_layers.append(
81
+ MultiviewTransformer(
82
+ ch,
83
+ num_heads,
84
+ dim_head,
85
+ name=f"input_ds{ds}",
86
+ depth=params.transformer_depth[level],
87
+ context_dim=params.context_dim,
88
+ unflatten_names=params.unflatten_names,
89
+ )
90
+ )
91
+ self.input_blocks.append(TimestepEmbedSequential(*input_layers))
92
+ self._feature_size += ch
93
+ input_block_chans.append(ch)
94
+ if level != len(params.channel_mult) - 1:
95
+ ds *= 2
96
+ out_ch = ch
97
+ self.input_blocks.append(
98
+ TimestepEmbedSequential(Downsample(ch, out_channels=out_ch))
99
+ )
100
+ ch = out_ch
101
+ input_block_chans.append(ch)
102
+ self._feature_size += ch
103
+
104
+ num_heads = ch // params.num_head_channels
105
+ dim_head = params.num_head_channels
106
+
107
+ self.middle_block = TimestepEmbedSequential(
108
+ ResBlock(
109
+ channels=ch,
110
+ emb_channels=time_embed_dim,
111
+ out_channels=None,
112
+ dense_in_channels=params.dense_in_channels,
113
+ dropout=params.dropout,
114
+ ),
115
+ MultiviewTransformer(
116
+ ch,
117
+ num_heads,
118
+ dim_head,
119
+ name=f"middle_ds{ds}",
120
+ depth=params.transformer_depth[-1],
121
+ context_dim=params.context_dim,
122
+ unflatten_names=params.unflatten_names,
123
+ ),
124
+ ResBlock(
125
+ channels=ch,
126
+ emb_channels=time_embed_dim,
127
+ out_channels=None,
128
+ dense_in_channels=params.dense_in_channels,
129
+ dropout=params.dropout,
130
+ ),
131
+ )
132
+ self._feature_size += ch
133
+
134
+ self.output_blocks = nn.ModuleList([])
135
+ for level, mult in list(enumerate(params.channel_mult))[::-1]:
136
+ for i in range(params.num_res_blocks + 1):
137
+ ich = input_block_chans.pop()
138
+ output_layers: list[ResBlock | MultiviewTransformer | Upsample] = [
139
+ ResBlock(
140
+ channels=ch + ich,
141
+ emb_channels=time_embed_dim,
142
+ out_channels=params.model_channels * mult,
143
+ dense_in_channels=params.dense_in_channels,
144
+ dropout=params.dropout,
145
+ )
146
+ ]
147
+ ch = params.model_channels * mult
148
+ if ds in params.attention_resolutions:
149
+ num_heads = ch // params.num_head_channels
150
+ dim_head = params.num_head_channels
151
+
152
+ output_layers.append(
153
+ MultiviewTransformer(
154
+ ch,
155
+ num_heads,
156
+ dim_head,
157
+ name=f"output_ds{ds}",
158
+ depth=params.transformer_depth[level],
159
+ context_dim=params.context_dim,
160
+ unflatten_names=params.unflatten_names,
161
+ )
162
+ )
163
+ if level and i == params.num_res_blocks:
164
+ out_ch = ch
165
+ ds //= 2
166
+ output_layers.append(Upsample(ch, out_ch))
167
+ self.output_blocks.append(TimestepEmbedSequential(*output_layers))
168
+ self._feature_size += ch
169
+
170
+ self.out = nn.Sequential(
171
+ GroupNorm32(32, ch),
172
+ nn.SiLU(),
173
+ nn.Conv2d(self.model_channels, params.out_channels, 3, padding=1),
174
+ )
175
+
176
+ def forward(
177
+ self,
178
+ x: torch.Tensor,
179
+ t: torch.Tensor,
180
+ y: torch.Tensor,
181
+ dense_y: torch.Tensor,
182
+ num_frames: int | None = None,
183
+ ) -> torch.Tensor:
184
+ num_frames = num_frames or self.params.num_frames
185
+ t_emb = timestep_embedding(t, self.model_channels)
186
+ t_emb = self.time_embed(t_emb)
187
+
188
+ hs = []
189
+ h = x
190
+ for module in self.input_blocks:
191
+ h = module(
192
+ h,
193
+ emb=t_emb,
194
+ context=y,
195
+ dense_emb=dense_y,
196
+ num_frames=num_frames,
197
+ )
198
+ hs.append(h)
199
+ h = self.middle_block(
200
+ h,
201
+ emb=t_emb,
202
+ context=y,
203
+ dense_emb=dense_y,
204
+ num_frames=num_frames,
205
+ )
206
+ for module in self.output_blocks:
207
+ h = torch.cat([h, hs.pop()], dim=1)
208
+ h = module(
209
+ h,
210
+ emb=t_emb,
211
+ context=y,
212
+ dense_emb=dense_y,
213
+ num_frames=num_frames,
214
+ )
215
+ h = h.type(x.dtype)
216
+ return self.out(h)
217
+
218
+
219
+ class SGMWrapper(nn.Module):
220
+ def __init__(self, module: Seva):
221
+ super().__init__()
222
+ self.module = module
223
+
224
+ def forward(
225
+ self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
226
+ ) -> torch.Tensor:
227
+ x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
228
+ return self.module(
229
+ x,
230
+ t=t,
231
+ y=c["crossattn"],
232
+ dense_y=c["dense_vector"],
233
+ **kwargs,
234
+ )
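
A brief, hedged sketch of instantiating the denoiser defined above; it only constructs the module and reports its size (the import path follows this file's location, and no pretrained weights are loaded here):

    from seva.model import Seva, SevaParams, SGMWrapper

    params = SevaParams()            # dataclass defaults from above
    model = Seva(params).eval()
    wrapped = SGMWrapper(model)      # adds the channel-concat conditioning path

    num_params = sum(p.numel() for p in model.parameters())
    print(f"Seva UNet: {num_params / 1e6:.1f}M parameters")
    # Seva.forward(x, t, y, dense_y) expects `in_channels`-channel latents, timesteps,
    # CLIP context of width `context_dim`, and dense pose embeddings (see SevaParams).
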
seva/modules/__init__.py ADDED
File without changes
seva/modules/autoencoder.py ADDED
@@ -0,0 +1,51 @@
1
+ import torch
2
+ from diffusers.models import AutoencoderKL # type: ignore
3
+ from torch import nn
4
+
5
+
6
+ class AutoEncoder(nn.Module):
7
+ scale_factor: float = 0.18215
8
+ downsample: int = 8
9
+
10
+ def __init__(self, chunk_size: int | None = None):
11
+ super().__init__()
12
+ self.module = AutoencoderKL.from_pretrained(
13
+ "stabilityai/stable-diffusion-2-1-base",
14
+ subfolder="vae",
15
+ force_download=False,
16
+ low_cpu_mem_usage=False,
17
+ )
18
+ self.module.eval().requires_grad_(False) # type: ignore
19
+ self.chunk_size = chunk_size
20
+
21
+ def _encode(self, x: torch.Tensor) -> torch.Tensor:
22
+ return (
23
+ self.module.encode(x).latent_dist.mean # type: ignore
24
+ * self.scale_factor
25
+ )
26
+
27
+ def encode(self, x: torch.Tensor, chunk_size: int | None = None) -> torch.Tensor:
28
+ chunk_size = chunk_size or self.chunk_size
29
+ if chunk_size is not None:
30
+ return torch.cat(
31
+ [self._encode(x_chunk) for x_chunk in x.split(chunk_size)],
32
+ dim=0,
33
+ )
34
+ else:
35
+ return self._encode(x)
36
+
37
+ def _decode(self, z: torch.Tensor) -> torch.Tensor:
38
+ return self.module.decode(z / self.scale_factor).sample # type: ignore
39
+
40
+ def decode(self, z: torch.Tensor, chunk_size: int | None = None) -> torch.Tensor:
41
+ chunk_size = chunk_size or self.chunk_size
42
+ if chunk_size is not None:
43
+ return torch.cat(
44
+ [self._decode(z_chunk) for z_chunk in z.split(chunk_size)],
45
+ dim=0,
46
+ )
47
+ else:
48
+ return self._decode(z)
49
+
50
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
51
+ return self.decode(self.encode(x))
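
A hedged round-trip sketch of the VAE wrapper above (it downloads the Stable Diffusion 2.1 VAE weights on first use; the input range and image size are assumptions consistent with the 8x downsample noted in the class):

    import torch
    from seva.modules.autoencoder import AutoEncoder

    vae = AutoEncoder(chunk_size=2).to("cuda")                   # chunking bounds peak memory
    images = torch.rand(4, 3, 576, 576, device="cuda") * 2 - 1   # assumed [-1, 1] inputs
    with torch.no_grad():
        latents = vae.encode(images)   # -> (4, 4, 72, 72) with downsample == 8
        recon = vae.decode(latents)    # -> (4, 3, 576, 576)
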
seva/modules/conditioner.py ADDED
@@ -0,0 +1,39 @@
1
+ import kornia
2
+ import open_clip
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ class CLIPConditioner(nn.Module):
8
+ mean: torch.Tensor
9
+ std: torch.Tensor
10
+
11
+ def __init__(self):
12
+ super().__init__()
13
+ self.module = open_clip.create_model_and_transforms(
14
+ "ViT-H-14", pretrained="laion2b_s32b_b79k"
15
+ )[0]
16
+ self.module.eval().requires_grad_(False) # type: ignore
17
+ self.register_buffer(
18
+ "mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
19
+ )
20
+ self.register_buffer(
21
+ "std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
22
+ )
23
+
24
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
25
+ x = kornia.geometry.resize(
26
+ x,
27
+ (224, 224),
28
+ interpolation="bicubic",
29
+ align_corners=True,
30
+ antialias=True,
31
+ )
32
+ x = (x + 1.0) / 2.0
33
+ x = kornia.enhance.normalize(x, self.mean, self.std)
34
+ return x
35
+
36
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
37
+ x = self.preprocess(x)
38
+ x = self.module.encode_image(x)
39
+ return x
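
A short, hedged sketch of the CLIP conditioner above; inputs are assumed to be in [-1, 1] (the preprocess step rescales to [0, 1] before normalization), and the 1024-dim output matches `context_dim` in the model config:

    import torch
    from seva.modules.conditioner import CLIPConditioner

    clip = CLIPConditioner().to("cuda")
    frames = torch.rand(2, 3, 576, 576, device="cuda") * 2 - 1
    with torch.no_grad():
        emb = clip(frames)
    print(emb.shape)  # expected: torch.Size([2, 1024]) for ViT-H-14
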
seva/modules/layers.py ADDED
@@ -0,0 +1,139 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from einops import repeat
6
+ from torch import nn
7
+
8
+ from .transformer import MultiviewTransformer
9
+
10
+
11
+ def timestep_embedding(
12
+ timesteps: torch.Tensor,
13
+ dim: int,
14
+ max_period: int = 10000,
15
+ repeat_only: bool = False,
16
+ ) -> torch.Tensor:
17
+ if not repeat_only:
18
+ half = dim // 2
19
+ freqs = torch.exp(
20
+ -math.log(max_period)
21
+ * torch.arange(start=0, end=half, dtype=torch.float32)
22
+ / half
23
+ ).to(device=timesteps.device)
24
+ args = timesteps[:, None].float() * freqs[None]
25
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
26
+ if dim % 2:
27
+ embedding = torch.cat(
28
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
29
+ )
30
+ else:
31
+ embedding = repeat(timesteps, "b -> b d", d=dim)
32
+ return embedding
33
+
34
+
35
+ class Upsample(nn.Module):
36
+ def __init__(self, channels: int, out_channels: int | None = None):
37
+ super().__init__()
38
+ self.channels = channels
39
+ self.out_channels = out_channels or channels
40
+ self.conv = nn.Conv2d(self.channels, self.out_channels, 3, 1, 1)
41
+
42
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
43
+ assert x.shape[1] == self.channels
44
+ x = F.interpolate(x, scale_factor=2, mode="nearest")
45
+ x = self.conv(x)
46
+ return x
47
+
48
+
49
+ class Downsample(nn.Module):
50
+ def __init__(self, channels: int, out_channels: int | None = None):
51
+ super().__init__()
52
+ self.channels = channels
53
+ self.out_channels = out_channels or channels
54
+ self.op = nn.Conv2d(self.channels, self.out_channels, 3, 2, 1)
55
+
56
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
57
+ assert x.shape[1] == self.channels
58
+ return self.op(x)
59
+
60
+
61
+ class GroupNorm32(nn.GroupNorm):
62
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
63
+ return super().forward(input.float()).type(input.dtype)
64
+
65
+
66
+ class TimestepEmbedSequential(nn.Sequential):
67
+ def forward( # type: ignore[override]
68
+ self,
69
+ x: torch.Tensor,
70
+ emb: torch.Tensor,
71
+ context: torch.Tensor,
72
+ dense_emb: torch.Tensor,
73
+ num_frames: int,
74
+ ) -> torch.Tensor:
75
+ for layer in self:
76
+ if isinstance(layer, MultiviewTransformer):
77
+ assert num_frames is not None
78
+ x = layer(x, context, num_frames)
79
+ elif isinstance(layer, ResBlock):
80
+ x = layer(x, emb, dense_emb)
81
+ else:
82
+ x = layer(x)
83
+ return x
84
+
85
+
86
+ class ResBlock(nn.Module):
87
+ def __init__(
88
+ self,
89
+ channels: int,
90
+ emb_channels: int,
91
+ out_channels: int | None,
92
+ dense_in_channels: int,
93
+ dropout: float,
94
+ ):
95
+ super().__init__()
96
+ out_channels = out_channels or channels
97
+
98
+ self.in_layers = nn.Sequential(
99
+ GroupNorm32(32, channels),
100
+ nn.SiLU(),
101
+ nn.Conv2d(channels, out_channels, 3, 1, 1),
102
+ )
103
+ self.emb_layers = nn.Sequential(
104
+ nn.SiLU(), nn.Linear(emb_channels, out_channels)
105
+ )
106
+ self.dense_emb_layers = nn.Sequential(
107
+ nn.Conv2d(dense_in_channels, 2 * channels, 1, 1, 0)
108
+ )
109
+ self.out_layers = nn.Sequential(
110
+ GroupNorm32(32, out_channels),
111
+ nn.SiLU(),
112
+ nn.Dropout(dropout),
113
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1),
114
+ )
115
+ if out_channels == channels:
116
+ self.skip_connection = nn.Identity()
117
+ else:
118
+ self.skip_connection = nn.Conv2d(channels, out_channels, 1, 1, 0)
119
+
120
+ def forward(
121
+ self, x: torch.Tensor, emb: torch.Tensor, dense_emb: torch.Tensor
122
+ ) -> torch.Tensor:
123
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
124
+ h = in_rest(x)
125
+ dense = self.dense_emb_layers(
126
+ F.interpolate(
127
+ dense_emb, size=h.shape[2:], mode="bilinear", align_corners=True
128
+ )
129
+ ).type(h.dtype)
130
+ dense_scale, dense_shift = torch.chunk(dense, 2, dim=1)
131
+ h = h * (1 + dense_scale) + dense_shift
132
+ h = in_conv(h)
133
+ emb_out = self.emb_layers(emb).type(h.dtype)
134
+ while len(emb_out.shape) < len(h.shape):
135
+ emb_out = emb_out[..., None]
136
+ h = h + emb_out
137
+ h = self.out_layers(h)
138
+ h = self.skip_connection(x) + h
139
+ return h
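
To make the sinusoidal `timestep_embedding` above concrete, a tiny example (the import assumes this file's location in the package):

    import torch
    from seva.modules.layers import timestep_embedding

    t = torch.tensor([0.0, 250.0, 999.0])
    emb = timestep_embedding(t, dim=320)  # concatenated cos/sin features, shape (3, 320)
    print(emb.shape)
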
seva/modules/preprocessor.py ADDED
@@ -0,0 +1,116 @@
1
+ import contextlib
2
+ import os
3
+ import os.path as osp
4
+ import sys
5
+ from typing import cast
6
+
7
+ import imageio.v3 as iio
8
+ import numpy as np
9
+ import torch
10
+
11
+
12
+ class Dust3rPipeline(object):
13
+ def __init__(self, device: str | torch.device = "cuda"):
14
+ submodule_path = osp.realpath(
15
+ osp.join(osp.dirname(__file__), "../../third_party/dust3r/")
16
+ )
17
+ if submodule_path not in sys.path:
18
+ sys.path.insert(0, submodule_path)
19
+ try:
20
+ with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
21
+ from dust3r.cloud_opt import ( # type: ignore[import]
22
+ GlobalAlignerMode,
23
+ global_aligner,
24
+ )
25
+ from dust3r.image_pairs import make_pairs # type: ignore[import]
26
+ from dust3r.inference import inference # type: ignore[import]
27
+ from dust3r.model import AsymmetricCroCo3DStereo # type: ignore[import]
28
+ from dust3r.utils.image import load_images # type: ignore[import]
29
+ except ImportError:
30
+ raise ImportError(
31
+ "Missing required submodule: 'dust3r'. Please ensure that all submodules are properly set up.\n\n"
32
+ "To initialize them, run the following command in the project root:\n"
33
+ " git submodule update --init --recursive"
34
+ )
35
+
36
+ self.device = torch.device(device)
37
+ self.model = AsymmetricCroCo3DStereo.from_pretrained(
38
+ "naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
39
+ ).to(self.device)
40
+
41
+ self._GlobalAlignerMode = GlobalAlignerMode
42
+ self._global_aligner = global_aligner
43
+ self._make_pairs = make_pairs
44
+ self._inference = inference
45
+ self._load_images = load_images
46
+
47
+ def infer_cameras_and_points(
48
+ self,
49
+ img_paths: list[str],
50
+ Ks: list[list] = None,
51
+ c2ws: list[list] = None,
52
+ batch_size: int = 16,
53
+ schedule: str = "cosine",
54
+ lr: float = 0.01,
55
+ niter: int = 500,
56
+ min_conf_thr: int = 3,
57
+ ) -> tuple[
58
+ list[np.ndarray], np.ndarray, np.ndarray, list[np.ndarray], list[np.ndarray]
59
+ ]:
60
+ num_img = len(img_paths)
61
+ if num_img == 1:
62
+ print("Only one image found, duplicating it to create a stereo pair.")
63
+ img_paths = img_paths * 2
64
+
65
+ images = self._load_images(img_paths, size=512)
66
+ pairs = self._make_pairs(
67
+ images,
68
+ scene_graph="complete",
69
+ prefilter=None,
70
+ symmetrize=True,
71
+ )
72
+ output = self._inference(pairs, self.model, self.device, batch_size=batch_size)
73
+
74
+ ori_imgs = [iio.imread(p) for p in img_paths]
75
+ ori_img_whs = np.array([img.shape[1::-1] for img in ori_imgs])
76
+ img_whs = np.concatenate([image["true_shape"][:, ::-1] for image in images], 0)
77
+
78
+ scene = self._global_aligner(
79
+ output,
80
+ device=self.device,
81
+ mode=self._GlobalAlignerMode.PointCloudOptimizer,
82
+ same_focals=True,
83
+ optimize_pp=False, # True,
84
+ min_conf_thr=min_conf_thr,
85
+ )
86
+
87
+ # if Ks is not None:
88
+ # scene.preset_focal(
89
+ # torch.tensor([[K[0, 0], K[1, 1]] for K in Ks])
90
+ # )
91
+
92
+ if c2ws is not None:
93
+ scene.preset_pose(c2ws)
94
+
95
+ _ = scene.compute_global_alignment(
96
+ init="msp", niter=niter, schedule=schedule, lr=lr
97
+ )
98
+
99
+ imgs = cast(list, scene.imgs)
100
+ Ks = scene.get_intrinsics().detach().cpu().numpy().copy()
101
+ c2ws = scene.get_im_poses().detach().cpu().numpy() # type: ignore
102
+ pts3d = [x.detach().cpu().numpy() for x in scene.get_pts3d()] # type: ignore
103
+ if num_img > 1:
104
+ masks = [x.detach().cpu().numpy() for x in scene.get_masks()]
105
+ points = [p[m] for p, m in zip(pts3d, masks)]
106
+ point_colors = [img[m] for img, m in zip(imgs, masks)]
107
+ else:
108
+ points = [p.reshape(-1, 3) for p in pts3d]
109
+ point_colors = [img.reshape(-1, 3) for img in imgs]
110
+
111
+ # Convert back to the original image size.
112
+ imgs = ori_imgs
113
+ Ks[:, :2, -1] *= ori_img_whs / img_whs
114
+ Ks[:, :2, :2] *= (ori_img_whs / img_whs).mean(axis=1, keepdims=True)[..., None]
115
+
116
+ return imgs, Ks, c2ws, points, point_colors
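
Finally, a hedged sketch of driving the DUSt3R-based preprocessor above; the asset path is taken from the repository listing, the checkpoint is downloaded on first use, and the dust3r submodule must be initialized as described in the import error message:

    from seva.modules.preprocessor import Dust3rPipeline

    pipeline = Dust3rPipeline(device="cuda")
    imgs, Ks, c2ws, points, point_colors = pipeline.infer_cameras_and_points(
        ["assets/basic/llff-room.jpg"]        # a single image is duplicated internally
    )
    print(Ks.shape, c2ws.shape)               # per-view 3x3 intrinsics and camera-to-world poses
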