Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +17 -0
- .gitignore +48 -0
- .gitmodules +3 -0
- .pre-commit-config.yaml +22 -0
- LICENSE +124 -0
- README.md +79 -8
- assets/advance/backyard-7_0.jpg +0 -0
- assets/advance/backyard-7_1.jpg +0 -0
- assets/advance/backyard-7_2.jpg +0 -0
- assets/advance/backyard-7_3.jpg +0 -0
- assets/advance/backyard-7_4.jpg +0 -0
- assets/advance/backyard-7_5.jpg +0 -0
- assets/advance/backyard-7_6.jpg +0 -0
- assets/advance/blue-car.jpg +3 -0
- assets/advance/garden-4_0.jpg +3 -0
- assets/advance/garden-4_1.jpg +3 -0
- assets/advance/garden-4_2.jpg +3 -0
- assets/advance/garden-4_3.jpg +3 -0
- assets/advance/telebooth-2_0.jpg +0 -0
- assets/advance/telebooth-2_1.jpg +0 -0
- assets/advance/vgg-lab-4_0.png +3 -0
- assets/advance/vgg-lab-4_1.png +3 -0
- assets/advance/vgg-lab-4_2.png +3 -0
- assets/advance/vgg-lab-4_3.png +3 -0
- assets/basic/blue-car.jpg +3 -0
- assets/basic/hilly-countryside.jpg +3 -0
- assets/basic/lily-dragon.png +3 -0
- assets/basic/llff-room.jpg +0 -0
- assets/basic/mountain-lake.jpg +0 -0
- assets/basic/vasedeck.jpg +0 -0
- assets/basic/vgg-lab-4_0.png +3 -0
- benchmark/README.md +156 -0
- benchmark/export_reconfusion_example.py +137 -0
- demo.py +407 -0
- demo_gr.py +1248 -0
- docs/CLI_USAGE.md +169 -0
- docs/GR_USAGE.md +76 -0
- docs/INSTALL.md +39 -0
- pyproject.toml +39 -0
- seva/__init__.py +0 -0
- seva/data_io.py +553 -0
- seva/eval.py +1990 -0
- seva/geometry.py +811 -0
- seva/gui.py +975 -0
- seva/model.py +234 -0
- seva/modules/__init__.py +0 -0
- seva/modules/autoencoder.py +51 -0
- seva/modules/conditioner.py +39 -0
- seva/modules/layers.py +139 -0
- seva/modules/preprocessor.py +116 -0
.gitattributes
CHANGED
@@ -33,3 +33,20 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/advance/blue-car.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/advance/garden-4_0.jpg filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/advance/garden-4_1.jpg filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/advance/garden-4_2.jpg filter=lfs diff=lfs merge=lfs -text
|
40 |
+
assets/advance/garden-4_3.jpg filter=lfs diff=lfs merge=lfs -text
|
41 |
+
assets/advance/vgg-lab-4_0.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
assets/advance/vgg-lab-4_1.png filter=lfs diff=lfs merge=lfs -text
|
43 |
+
assets/advance/vgg-lab-4_2.png filter=lfs diff=lfs merge=lfs -text
|
44 |
+
assets/advance/vgg-lab-4_3.png filter=lfs diff=lfs merge=lfs -text
|
45 |
+
assets/basic/blue-car.jpg filter=lfs diff=lfs merge=lfs -text
|
46 |
+
assets/basic/hilly-countryside.jpg filter=lfs diff=lfs merge=lfs -text
|
47 |
+
assets/basic/lily-dragon.png filter=lfs diff=lfs merge=lfs -text
|
48 |
+
assets/basic/vgg-lab-4_0.png filter=lfs diff=lfs merge=lfs -text
|
49 |
+
third_party/dust3r/assets/demo.jpg filter=lfs diff=lfs merge=lfs -text
|
50 |
+
third_party/dust3r/assets/matching.jpg filter=lfs diff=lfs merge=lfs -text
|
51 |
+
third_party/dust3r/croco/assets/Chateau1.png filter=lfs diff=lfs merge=lfs -text
|
52 |
+
third_party/dust3r/croco/assets/Chateau2.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.envrc
|
2 |
+
.venv/
|
3 |
+
.gradio/
|
4 |
+
work_dirs*
|
5 |
+
|
6 |
+
# Byte-compiled files
|
7 |
+
__pycache__/
|
8 |
+
*.py[cod]
|
9 |
+
|
10 |
+
# Virtual environments
|
11 |
+
env/
|
12 |
+
venv/
|
13 |
+
ENV/
|
14 |
+
.VENV/
|
15 |
+
|
16 |
+
# Distribution files
|
17 |
+
build/
|
18 |
+
dist/
|
19 |
+
*.egg-info/
|
20 |
+
|
21 |
+
# Logs and temporary files
|
22 |
+
*.log
|
23 |
+
*.tmp
|
24 |
+
*.bak
|
25 |
+
*.swp
|
26 |
+
|
27 |
+
# IDE files
|
28 |
+
.idea/
|
29 |
+
.vscode/
|
30 |
+
*.sublime-workspace
|
31 |
+
*.sublime-project
|
32 |
+
|
33 |
+
# OS files
|
34 |
+
.DS_Store
|
35 |
+
Thumbs.db
|
36 |
+
|
37 |
+
# Testing and coverage
|
38 |
+
htmlcov/
|
39 |
+
.coverage
|
40 |
+
*.cover
|
41 |
+
*.py,cover
|
42 |
+
.cache/
|
43 |
+
|
44 |
+
# Jupyter Notebook checkpoints
|
45 |
+
.ipynb_checkpoints/
|
46 |
+
|
47 |
+
# Pre-commit hooks
|
48 |
+
.pre-commit-config.yaml~
|
.gitmodules
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
[submodule "third_party/dust3r"]
|
2 |
+
path = third_party/dust3r
|
3 |
+
url = https://github.com/jensenstability/dust3r
|
.pre-commit-config.yaml
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
default_language_version:
|
2 |
+
python: python3
|
3 |
+
default_stages: [pre-commit]
|
4 |
+
repos:
|
5 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
6 |
+
rev: v5.0.0
|
7 |
+
hooks:
|
8 |
+
- id: trailing-whitespace
|
9 |
+
- id: end-of-file-fixer
|
10 |
+
- repo: https://github.com/charliermarsh/ruff-pre-commit
|
11 |
+
rev: v0.8.3
|
12 |
+
hooks:
|
13 |
+
- id: ruff
|
14 |
+
types_or: [python, pyi, jupyter]
|
15 |
+
args: [--fix, --extend-ignore=E402]
|
16 |
+
- id: ruff-format
|
17 |
+
types_or: [python, pyi, jupyter]
|
18 |
+
- repo: https://github.com/pre-commit/mirrors-prettier
|
19 |
+
rev: v3.1.0
|
20 |
+
hooks:
|
21 |
+
- id: prettier
|
22 |
+
types_or: [markdown]
|
LICENSE
ADDED
@@ -0,0 +1,124 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Stability AI Non-Commercial License Agreement
|
2 |
+
Last Updated: February 20, 2025
|
3 |
+
|
4 |
+
I. INTRODUCTION
|
5 |
+
|
6 |
+
This Stability AI Non-Commercial License Agreement (the “Agreement”) applies to any individual person or entity
|
7 |
+
(“You”, “Your” or “Licensee”) that uses or distributes any portion or element of the Stability AI Materials or
|
8 |
+
Derivative Works thereof for any Research & Non-Commercial use. Capitalized terms not otherwise defined herein
|
9 |
+
are defined in Section IV below.
|
10 |
+
|
11 |
+
This Agreement is intended to allow research and non-commercial uses of the Model free of charge.
|
12 |
+
|
13 |
+
By clicking “I Accept” or by using or distributing or using any portion or element of the Stability Materials
|
14 |
+
or Derivative Works, You agree that You have read, understood and are bound by the terms of this Agreement.
|
15 |
+
|
16 |
+
If You are acting on behalf of a company, organization, or other entity, then “You” includes you and that entity,
|
17 |
+
and You agree that You:
|
18 |
+
(i) are an authorized representative of such entity with the authority to bind such entity to this Agreement, and
|
19 |
+
(ii) You agree to the terms of this Agreement on that entity’s behalf.
|
20 |
+
|
21 |
+
---
|
22 |
+
|
23 |
+
II. RESEARCH & NON-COMMERCIAL USE LICENSE
|
24 |
+
|
25 |
+
Subject to the terms of this Agreement, Stability AI grants You a non-exclusive, worldwide, non-transferable,
|
26 |
+
non-sublicensable, revocable, and royalty-free limited license under Stability AI’s intellectual property or other
|
27 |
+
rights owned by Stability AI embodied in the Stability AI Materials to use, reproduce, distribute, and create
|
28 |
+
Derivative Works of, and make modifications to, the Stability AI Materials for any Research or Non-Commercial Purpose.
|
29 |
+
|
30 |
+
- **“Research Purpose”** means academic or scientific advancement, and in each case, is not primarily intended
|
31 |
+
for commercial advantage or monetary compensation to You or others.
|
32 |
+
- **“Non-Commercial Purpose”** means any purpose other than a Research Purpose that is not primarily intended
|
33 |
+
for commercial advantage or monetary compensation to You or others, such as personal use (i.e., hobbyist)
|
34 |
+
or evaluation and testing.
|
35 |
+
|
36 |
+
---
|
37 |
+
|
38 |
+
III. GENERAL TERMS
|
39 |
+
|
40 |
+
Your Research or Non-Commercial license under this Agreement is subject to the following terms.
|
41 |
+
|
42 |
+
### a. Distribution & Attribution
|
43 |
+
If You distribute or make available the Stability AI Materials or a Derivative Work to a third party, or a product
|
44 |
+
or service that uses any portion of them, You shall:
|
45 |
+
1. Provide a copy of this Agreement to that third party.
|
46 |
+
2. Retain the following attribution notice within a **"Notice"** text file distributed as a part of such copies:
|
47 |
+
|
48 |
+
**"This Stability AI Model is licensed under the Stability AI Non-Commercial License,
|
49 |
+
Copyright © Stability AI Ltd. All Rights Reserved."**
|
50 |
+
|
51 |
+
3. Prominently display **“Powered by Stability AI”** on a related website, user interface, blog post,
|
52 |
+
about page, or product documentation.
|
53 |
+
4. If You create a Derivative Work, You may add your own attribution notice(s) to the **"Notice"** text file
|
54 |
+
included with that Derivative Work, provided that You clearly indicate which attributions apply to the
|
55 |
+
Stability AI Materials and state in the **"Notice"** text file that You changed the Stability AI Materials
|
56 |
+
and how it was modified.
|
57 |
+
|
58 |
+
### b. Use Restrictions
|
59 |
+
Your use of the Stability AI Materials and Derivative Works, including any output or results of the Stability
|
60 |
+
AI Materials or Derivative Works, must comply with applicable laws and regulations (including Trade Control
|
61 |
+
Laws and equivalent regulations) and adhere to the Documentation and Stability AI’s AUP, which is hereby
|
62 |
+
incorporated by reference.
|
63 |
+
|
64 |
+
Furthermore, You will not use the Stability AI Materials or Derivative Works, or any output or results of the
|
65 |
+
Stability AI Materials or Derivative Works, to create or improve any foundational generative AI model
|
66 |
+
(excluding the Model or Derivative Works).
|
67 |
+
|
68 |
+
### c. Intellectual Property
|
69 |
+
|
70 |
+
#### (i) Trademark License
|
71 |
+
No trademark licenses are granted under this Agreement, and in connection with the Stability AI Materials
|
72 |
+
or Derivative Works, You may not use any name or mark owned by or associated with Stability AI or any of
|
73 |
+
its Affiliates, except as required under Section IV(a) herein.
|
74 |
+
|
75 |
+
#### (ii) Ownership of Derivative Works
|
76 |
+
As between You and Stability AI, You are the owner of Derivative Works You create, subject to Stability AI’s
|
77 |
+
ownership of the Stability AI Materials and any Derivative Works made by or for Stability AI.
|
78 |
+
|
79 |
+
#### (iii) Ownership of Outputs
|
80 |
+
As between You and Stability AI, You own any outputs generated from the Model or Derivative Works to the extent
|
81 |
+
permitted by applicable law.
|
82 |
+
|
83 |
+
#### (iv) Disputes
|
84 |
+
If You or Your Affiliate(s) institute litigation or other proceedings against Stability AI (including a
|
85 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Stability AI Materials, Derivative Works, or
|
86 |
+
associated outputs or results, or any portion of any of the foregoing, constitutes infringement of intellectual
|
87 |
+
property or other rights owned or licensable by You, then any licenses granted to You under this Agreement
|
88 |
+
shall terminate as of the date such litigation or claim is filed or instituted.
|
89 |
+
|
90 |
+
You will indemnify and hold harmless Stability AI from and against any claim by any third party arising out
|
91 |
+
of or related to Your use or distribution of the Stability AI Materials or Derivative Works in violation of
|
92 |
+
this Agreement.
|
93 |
+
|
94 |
+
#### (v) Feedback
|
95 |
+
From time to time, You may provide Stability AI with verbal and/or written suggestions, comments, or other
|
96 |
+
feedback related to Stability AI’s existing or prospective technology, products, or services (collectively,
|
97 |
+
“Feedback”).
|
98 |
+
|
99 |
+
You are not obligated to provide Stability AI with Feedback, but to the extent that You do, You hereby grant
|
100 |
+
Stability AI a **perpetual, irrevocable, royalty-free, fully-paid, sub-licensable, transferable, non-exclusive,
|
101 |
+
worldwide right and license** to exploit the Feedback in any manner without restriction.
|
102 |
+
|
103 |
+
Your Feedback is provided **“AS IS”** and You make no warranties whatsoever about any Feedback.
|
104 |
+
|
105 |
+
---
|
106 |
+
|
107 |
+
IV. DEFINITIONS
|
108 |
+
|
109 |
+
- **“Affiliate(s)”** means any entity that directly or indirectly controls, is controlled by, or is under common
|
110 |
+
control with the subject entity. For purposes of this definition, “control” means direct or indirect ownership
|
111 |
+
or control of more than 50% of the voting interests of the subject entity.
|
112 |
+
- **“AUP”** means the Stability AI Acceptable Use Policy available at https://stability.ai/use-policy, as may
|
113 |
+
be updated from time to time.
|
114 |
+
- **"Derivative Work(s)"** means:
|
115 |
+
(a) Any derivative work of the Stability AI Materials as recognized by U.S. copyright laws.
|
116 |
+
(b) Any modifications to a Model, and any other model created which is based on or derived from the Model or
|
117 |
+
the Model’s output, including **fine-tune** and **low-rank adaptation** models derived from a Model or
|
118 |
+
a Model’s output, but does not include the output of any Model.
|
119 |
+
- **“Model”** means Stability AI’s Stable Virtual Camera model.
|
120 |
+
- **"Stability AI" or "we"** means Stability AI Ltd. and its Affiliates.
|
121 |
+
- **"Software"** means Stability AI’s proprietary software made available under this Agreement now or in the future.
|
122 |
+
- **“Stability AI Materials”** means, collectively, Stability’s proprietary Model, Software, and Documentation
|
123 |
+
(and any portion or combination thereof) made available under this Agreement.
|
124 |
+
- **“Trade Control Laws”** means any applicable U.S. and non-U.S. export control and trade sanctions laws and regulations.
|
README.md
CHANGED
@@ -1,12 +1,83 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: green
|
5 |
-
colorTo: indigo
|
6 |
sdk: gradio
|
7 |
-
sdk_version: 5.
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
|
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: stable-virtual-camera
|
3 |
+
app_file: demo_gr.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
+
sdk_version: 5.20.1
|
|
|
|
|
6 |
---
|
7 |
+
# Stable Virtual Camera
|
8 |
|
9 |
+
<a href="https://stable-virtual-camera.github.io"><img src="https://img.shields.io/badge/%F0%9F%8F%A0%20Project%20Page-gray.svg"></a>
|
10 |
+
<a href="http://arxiv.org/abs/2503.14489"><img src="https://img.shields.io/badge/%F0%9F%93%84%20arXiv-2503.14489-B31B1B.svg"></a>
|
11 |
+
<a href="https://stability.ai/news/introducing-stable-virtual-camera-multi-view-video-generation-with-3d-camera-control"><img src="https://img.shields.io/badge/%F0%9F%93%83%20Blog-Stability%20AI-orange.svg"></a>
|
12 |
+
<a href="https://huggingface.co/stabilityai/stable-virtual-camera"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a>
|
13 |
+
<a href="https://huggingface.co/spaces/stabilityai/stable-virtual-camera"><img src="https://img.shields.io/badge/%F0%9F%9A%80%20Gradio%20Demo-Huggingface-orange"></a>
|
14 |
+
<a href="https://www.youtube.com/channel/UCLLlVDcS7nNenT_zzO3OPxQ"><img src="https://img.shields.io/badge/%F0%9F%8E%AC%20Video-YouTube-orange"></a>
|
15 |
+
|
16 |
+
`Stable Virtual Camera (Seva)` is a 1.3B generalist diffusion model for Novel View Synthesis (NVS), generating 3D consistent novel views of a scene, given any number of input views and target cameras.
|
17 |
+
|
18 |
+
# :tada: News
|
19 |
+
|
20 |
+
- March 2025 - `Stable Virtual Camera` is out everywhere.
|
21 |
+
|
22 |
+
# :wrench: Installation
|
23 |
+
|
24 |
+
```bash
|
25 |
+
git clone --recursive https://github.com/Stability-AI/stable-virtual-camera
|
26 |
+
cd stable-virtual-camera
|
27 |
+
pip install -e .
|
28 |
+
```
|
29 |
+
|
30 |
+
Please note that you will need `python>=3.10` and `torch>=2.6.0`.
|
31 |
+
|
32 |
+
Check [INSTALL.md](docs/INSTALL.md) for other dependencies if you want to use our demos or develop from this repo.
|
33 |
+
For windows users, please use WSL as flash attention isn't supported on native Windows [yet](https://github.com/pytorch/pytorch/issues/108175).
|
34 |
+
|
35 |
+
# :open_book: Usage
|
36 |
+
|
37 |
+
You need to properly authenticate with Hugging Face to download our model weights. Once set up, our code will handle it automatically at your first run. You can authenticate by running
|
38 |
+
|
39 |
+
```bash
|
40 |
+
# This will prompt you to enter your Hugging Face credentials.
|
41 |
+
huggingface-cli login
|
42 |
+
```
|
43 |
+
|
44 |
+
Once authenticated, go to our model card [here](https://huggingface.co/stabilityai/stable-virtual-camera) and enter your information for access.
|
45 |
+
|
46 |
+
We provide two demos for you to interative with `Stable Virtual Camera`.
|
47 |
+
|
48 |
+
### :rocket: Gradio demo
|
49 |
+
|
50 |
+
This gradio demo is a GUI interface that requires no expertised knowledge, suitable for general users. Simply run
|
51 |
+
|
52 |
+
```bash
|
53 |
+
python demo_gr.py
|
54 |
+
```
|
55 |
+
|
56 |
+
For a more detailed guide, follow [GR_USAGE.md](docs/GR_USAGE.md).
|
57 |
+
|
58 |
+
### :computer: CLI demo
|
59 |
+
|
60 |
+
This cli demo allows you to pass in more options and control the model in a fine-grained way, suitable for power users and academic researchers. An examplar command line looks as simple as
|
61 |
+
|
62 |
+
```bash
|
63 |
+
python demo.py --data_path <data_path> [additional arguments]
|
64 |
+
```
|
65 |
+
|
66 |
+
For a more detailed guide, follow [CLI_USAGE.md](docs/CLI_USAGE.md).
|
67 |
+
|
68 |
+
For users interested in benchmarking NVS models using command lines, check [`benchmark`](benchmark/) containing the details about scenes, splits, and input/target views we reported in the <a href="http://arxiv.org/abs/2503.14489">paper</a>.
|
69 |
+
|
70 |
+
# :books: Citing
|
71 |
+
|
72 |
+
If you find this repository useful, please consider giving a star :star: and citation.
|
73 |
+
|
74 |
+
```
|
75 |
+
@article{zhou2025stable,
|
76 |
+
title={Stable Virtual Camera: Generative View Synthesis with Diffusion Models},
|
77 |
+
author={Jensen (Jinghao) Zhou and Hang Gao and Vikram Voleti and Aaryaman Vasishta and Chun-Han Yao and Mark Boss and
|
78 |
+
Philip Torr and Christian Rupprecht and Varun Jampani
|
79 |
+
},
|
80 |
+
journal={arXiv preprint},
|
81 |
+
year={2025}
|
82 |
+
}
|
83 |
+
```
|
assets/advance/backyard-7_0.jpg
ADDED
![]() |
assets/advance/backyard-7_1.jpg
ADDED
![]() |
assets/advance/backyard-7_2.jpg
ADDED
![]() |
assets/advance/backyard-7_3.jpg
ADDED
![]() |
assets/advance/backyard-7_4.jpg
ADDED
![]() |
assets/advance/backyard-7_5.jpg
ADDED
![]() |
assets/advance/backyard-7_6.jpg
ADDED
![]() |
assets/advance/blue-car.jpg
ADDED
![]() |
Git LFS Details
|
assets/advance/garden-4_0.jpg
ADDED
![]() |
Git LFS Details
|
assets/advance/garden-4_1.jpg
ADDED
![]() |
Git LFS Details
|
assets/advance/garden-4_2.jpg
ADDED
![]() |
Git LFS Details
|
assets/advance/garden-4_3.jpg
ADDED
![]() |
Git LFS Details
|
assets/advance/telebooth-2_0.jpg
ADDED
![]() |
assets/advance/telebooth-2_1.jpg
ADDED
![]() |
assets/advance/vgg-lab-4_0.png
ADDED
![]() |
Git LFS Details
|
assets/advance/vgg-lab-4_1.png
ADDED
![]() |
Git LFS Details
|
assets/advance/vgg-lab-4_2.png
ADDED
![]() |
Git LFS Details
|
assets/advance/vgg-lab-4_3.png
ADDED
![]() |
Git LFS Details
|
assets/basic/blue-car.jpg
ADDED
![]() |
Git LFS Details
|
assets/basic/hilly-countryside.jpg
ADDED
![]() |
Git LFS Details
|
assets/basic/lily-dragon.png
ADDED
![]() |
Git LFS Details
|
assets/basic/llff-room.jpg
ADDED
![]() |
assets/basic/mountain-lake.jpg
ADDED
![]() |
assets/basic/vasedeck.jpg
ADDED
![]() |
assets/basic/vgg-lab-4_0.png
ADDED
![]() |
Git LFS Details
|
benchmark/README.md
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# :bar_chart: Benchmark
|
2 |
+
|
3 |
+
We provide <a href="https://github.com/Stability-AI/stable-virtual-camera/releases/tag/benchmark">in this release</a> (`benchmark.zip`) with the following 17 entries as a benchmark to evaluate NVS models.
|
4 |
+
We hope this will help standardize the evaluation of NVS models and facilitate fair comparison between different methods.
|
5 |
+
|
6 |
+
<table>
|
7 |
+
<thead>
|
8 |
+
<tr>
|
9 |
+
<th align="center">Dataset</th>
|
10 |
+
<th align="center">Split</th>
|
11 |
+
<th align="center">Path</th>
|
12 |
+
<th align="center">Content</th>
|
13 |
+
<th align="center">Image Preprocessing</th>
|
14 |
+
<th align="center">Image Postprocessing</th>
|
15 |
+
</tr>
|
16 |
+
</thead>
|
17 |
+
<tbody>
|
18 |
+
<tr>
|
19 |
+
<td align="center">OmniObject3D</td>
|
20 |
+
<td align="center"><code>S</code> (SV3D), <code>O</code> (Ours) </td>
|
21 |
+
<td align="center"><code>omniobject3d</code></td>
|
22 |
+
<td align="center"><code>train_test_split_*.json</code></td>
|
23 |
+
<td align="center">center crop to 576</td>
|
24 |
+
<td align="center">\</td>
|
25 |
+
</tr>
|
26 |
+
<tr>
|
27 |
+
<td align="center">GSO</td>
|
28 |
+
<td align="center"><code>S</code> (SV3D), <code>O</code> (Ours) </td>
|
29 |
+
<td align="center"><code>gso</code></td>
|
30 |
+
<td align="center"><code>train_test_split_*.json</code></td>
|
31 |
+
<td align="center">center crop to 576</td>
|
32 |
+
<td align="center">\</td>
|
33 |
+
</tr>
|
34 |
+
<tr>
|
35 |
+
<td align="center" rowspan="4">RealEstate10K</td>
|
36 |
+
<td align="center"><code>D</code> (4DiM) </td>
|
37 |
+
<td align="center"><code>re10k-4dim</code></td>
|
38 |
+
<td align="center"><code>train_test_split_*.json</code></td>
|
39 |
+
<td align="center">center crop to 576</td>
|
40 |
+
<td align="center">resize to 256</td>
|
41 |
+
</tr>
|
42 |
+
<tr>
|
43 |
+
<td align="center"><code>R</code> (ReconFusion) </td>
|
44 |
+
<td align="center"><code>re10k</code></td>
|
45 |
+
<td align="center"><code>train_test_split_*.json</code></td>
|
46 |
+
<td align="center">center crop to 576</td>
|
47 |
+
<td align="center">\</td>
|
48 |
+
</tr>
|
49 |
+
<tr>
|
50 |
+
<td align="center"><code>P</code> (pixelSplat) </td>
|
51 |
+
<td align="center"><code>re10k-pixelsplat</code></td>
|
52 |
+
<td align="center"><code>train_test_split_*.json</code></td>
|
53 |
+
<td align="center">center crop to 576</td>
|
54 |
+
<td align="center">resize to 256</td>
|
55 |
+
</tr>
|
56 |
+
<tr>
|
57 |
+
<td align="center"><code>V</code> (ViewCrafter) </td>
|
58 |
+
<td align="center"><code>re10k-viewcrafter</code></td>
|
59 |
+
<td align="center"><code>images/*.png</code>,<code>transforms.json</code>,<code>train_test_split_*.json</code></td>
|
60 |
+
<td align="center">resize the shortest side to 576 (<code>--L_short 576</code>)</td>
|
61 |
+
<td align="center">center crop</td>
|
62 |
+
</tr>
|
63 |
+
<tr>
|
64 |
+
<td align="center">LLFF</td>
|
65 |
+
<td align="center"><code>R</code> (ReconFusion) </td>
|
66 |
+
<td align="center"><code>llff</code></td>
|
67 |
+
<td align="center"><code>train_test_split_*.json</code></td>
|
68 |
+
<td align="center">center crop to 576</td>
|
69 |
+
<td align="center">\</td>
|
70 |
+
</tr>
|
71 |
+
<tr>
|
72 |
+
<td align="center">DTU</td>
|
73 |
+
<td align="center"><code>R</code> (ReconFusion) </td>
|
74 |
+
<td align="center"><code>dtu</code></td>
|
75 |
+
<td align="center"><code>train_test_split_*.json</code></td>
|
76 |
+
<td align="center">center crop to 576</td>
|
77 |
+
<td align="center">\</td>
|
78 |
+
</tr>
|
79 |
+
<tr>
|
80 |
+
<td align="center" rowspan="2">CO3D</td>
|
81 |
+
<td align="center"><code>R</code> (ReconFusion) </td>
|
82 |
+
<td align="center"><code>co3d</code></td>
|
83 |
+
<td align="center"><code>train_test_split_*.json</code></td>
|
84 |
+
<td align="center">center crop to 576</td>
|
85 |
+
<td align="center">\</td>
|
86 |
+
</tr>
|
87 |
+
<tr>
|
88 |
+
<td align="center"><code>V</code> (ViewCrafter) </td>
|
89 |
+
<td align="center"><code>co3d-viewcrafter</code></td>
|
90 |
+
<td align="center"><code>images/*.png</code>,<code>transforms.json</code>,<code>train_test_split_*.json</code></td>
|
91 |
+
<td align="center">resize the shortest side to 576 (<code>--L_short 576</code>)</td>
|
92 |
+
<td align="center">center crop</td>
|
93 |
+
</tr>
|
94 |
+
<tr>
|
95 |
+
<td align="center" rowspan="2" >WildRGB-D</td>
|
96 |
+
<td align="center"><code>Oₑ</code> (Ours, easy) </td>
|
97 |
+
<td align="center"><code>wildgbd/easy</code></td>
|
98 |
+
<td align="center"><code>train_test_split_*.json</code></td>
|
99 |
+
<td align="center">center crop to 576</td>
|
100 |
+
<td align="center">\</td>
|
101 |
+
</tr>
|
102 |
+
<tr>
|
103 |
+
<td align="center"><code>Oₕ</code> (Ours, hard) </td>
|
104 |
+
<td align="center"><code>wildgbd/hard</code></td>
|
105 |
+
<td align="center"><code>train_test_split_*.json</code></td>
|
106 |
+
<td align="center">center crop to 576</td>
|
107 |
+
<td align="center">\</td>
|
108 |
+
</tr>
|
109 |
+
<tr>
|
110 |
+
<td align="center">Mip-NeRF360</td>
|
111 |
+
<td align="center"><code>R</code> (ReconFusion) </td>
|
112 |
+
<td align="center"><code>mipnerf360</code></td>
|
113 |
+
<td align="center"><code>train_test_split_*.json</code></td>
|
114 |
+
<td align="center">center crop to 576</td>
|
115 |
+
<td align="center">\</td>
|
116 |
+
</tr>
|
117 |
+
<tr>
|
118 |
+
<td align="center" rowspan="2">DL3DV-140</td>
|
119 |
+
<td align="center"><code>O</code> (Ours) </td>
|
120 |
+
<td align="center"><code>dl3dv10</code></td>
|
121 |
+
<td align="center"><code>train_test_split_*.json</code></td>
|
122 |
+
<td align="center">center crop to 576</td>
|
123 |
+
<td align="center">\</td>
|
124 |
+
</tr>
|
125 |
+
<tr>
|
126 |
+
<td align="center"><code>L</code> (Long-LRM) </td>
|
127 |
+
<td align="center"><code>dl3dv140</code></td>
|
128 |
+
<td align="center"><code>train_test_split_*.json</code></td>
|
129 |
+
<td align="center">center crop to 576</td>
|
130 |
+
<td align="center">\</td>
|
131 |
+
</tr>
|
132 |
+
<tr>
|
133 |
+
<td align="center" rowspan="2">Tanks and Temples</td>
|
134 |
+
<td align="center"><code>V</code> (ViewCrafter) </td>
|
135 |
+
<td align="center"><code>tnt-viewcrafter</code></td>
|
136 |
+
<td align="center"><code>images/*.png</code>,<code>transforms.json</code>,<code>train_test_split_*.json</code></td>
|
137 |
+
<td align="center">resize the shortest side to 576 (<code>--L_short 576</code>)</td>
|
138 |
+
<td align="center">center crop</td>
|
139 |
+
</tr>
|
140 |
+
<tr>
|
141 |
+
<td align="center"><code>L</code> (Long-LRM) </td>
|
142 |
+
<td align="center"><code>tnt-longlrm</code></td>
|
143 |
+
<td align="center"><code>train_test_split_*.json</code></td>
|
144 |
+
<td align="center">center crop to 576</td>
|
145 |
+
<td align="center">\</td>
|
146 |
+
</tr>
|
147 |
+
</tbody>
|
148 |
+
</table>
|
149 |
+
|
150 |
+
- For entries without `images/*.png` and `transforms.json`, we use the images from the original dataset after converting them into the `reconfusion` format, which is then parsable by `ReconfusionParser` (`seva/data_io.py`).
|
151 |
+
Please note that during this conversion, you should sort the images by `sorted(image_paths)`, which is then directly indexable by our train/test ids. We provide in `benchmark/export_reconfusion_example.py` an example script converting an existing academic dataset into the the scene folders.
|
152 |
+
- For evaluation and benchmarking, we first conduct operations in the `Image Preprocessing` column to the model input and then operations in the `Image Postprocessing` column to the model output. The final processed samples are used for metric computation.
|
153 |
+
|
154 |
+
## Acknowledgment
|
155 |
+
|
156 |
+
We would like to thank Wangbo Yu, Aleksander Hołyński, Saurabh Saxena, and Ziwen Chen for their kind clarification on experiment settings.
|
benchmark/export_reconfusion_example.py
ADDED
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from PIL import Image
|
7 |
+
|
8 |
+
try:
|
9 |
+
from sklearn.cluster import KMeans # type: ignore[import]
|
10 |
+
except ImportError:
|
11 |
+
print("Please install sklearn to use this script.")
|
12 |
+
exit(1)
|
13 |
+
|
14 |
+
# Define the folder containing the image and JSON files
|
15 |
+
subfolder = "/path/to/your/dataset"
|
16 |
+
output_file = os.path.join(subfolder, "transforms.json")
|
17 |
+
|
18 |
+
# List to hold the frames
|
19 |
+
frames = []
|
20 |
+
|
21 |
+
# Iterate over the files in the folder
|
22 |
+
for file in sorted(os.listdir(subfolder)):
|
23 |
+
if file.endswith(".json"):
|
24 |
+
# Read the JSON file containing camera extrinsics and intrinsics
|
25 |
+
json_path = os.path.join(subfolder, file)
|
26 |
+
with open(json_path, "r") as f:
|
27 |
+
data = json.load(f)
|
28 |
+
|
29 |
+
# Read the corresponding image file
|
30 |
+
image_file = file.replace(".json", ".png")
|
31 |
+
image_path = os.path.join(subfolder, image_file)
|
32 |
+
if not os.path.exists(image_path):
|
33 |
+
print(f"Image file not found for {file}, skipping...")
|
34 |
+
continue
|
35 |
+
with Image.open(image_path) as img:
|
36 |
+
w, h = img.size
|
37 |
+
|
38 |
+
# Extract and normalize intrinsic matrix K
|
39 |
+
K = data["K"]
|
40 |
+
fx = K[0][0] * w
|
41 |
+
fy = K[1][1] * h
|
42 |
+
cx = K[0][2] * w
|
43 |
+
cy = K[1][2] * h
|
44 |
+
|
45 |
+
# Extract the transformation matrix
|
46 |
+
transform_matrix = np.array(data["c2w"])
|
47 |
+
# Adjust for OpenGL convention
|
48 |
+
transform_matrix[..., [1, 2]] *= -1
|
49 |
+
|
50 |
+
# Add the frame data
|
51 |
+
frames.append(
|
52 |
+
{
|
53 |
+
"fl_x": fx,
|
54 |
+
"fl_y": fy,
|
55 |
+
"cx": cx,
|
56 |
+
"cy": cy,
|
57 |
+
"w": w,
|
58 |
+
"h": h,
|
59 |
+
"file_path": f"./{os.path.relpath(image_path, subfolder)}",
|
60 |
+
"transform_matrix": transform_matrix.tolist(),
|
61 |
+
}
|
62 |
+
)
|
63 |
+
|
64 |
+
# Create the output dictionary
|
65 |
+
transforms_data = {"orientation_override": "none", "frames": frames}
|
66 |
+
|
67 |
+
# Write to the transforms.json file
|
68 |
+
with open(output_file, "w") as f:
|
69 |
+
json.dump(transforms_data, f, indent=4)
|
70 |
+
|
71 |
+
print(f"transforms.json generated at {output_file}")
|
72 |
+
|
73 |
+
|
74 |
+
# Train-test split function using K-means clustering with stride
|
75 |
+
def create_train_test_split(frames, n, output_path, stride):
|
76 |
+
# Prepare the data for K-means
|
77 |
+
positions = []
|
78 |
+
for frame in frames:
|
79 |
+
transform_matrix = np.array(frame["transform_matrix"])
|
80 |
+
position = transform_matrix[:3, 3] # 3D camera position
|
81 |
+
direction = transform_matrix[:3, 2] / np.linalg.norm(
|
82 |
+
transform_matrix[:3, 2]
|
83 |
+
) # Normalized 3D direction
|
84 |
+
positions.append(np.concatenate([position, direction]))
|
85 |
+
|
86 |
+
positions = np.array(positions)
|
87 |
+
|
88 |
+
# Apply K-means clustering
|
89 |
+
kmeans = KMeans(n_clusters=n, random_state=42)
|
90 |
+
kmeans.fit(positions)
|
91 |
+
centers = kmeans.cluster_centers_
|
92 |
+
|
93 |
+
# Find the index closest to each cluster center
|
94 |
+
train_ids = []
|
95 |
+
for center in centers:
|
96 |
+
distances = np.linalg.norm(positions - center, axis=1)
|
97 |
+
train_ids.append(int(np.argmin(distances))) # Convert to Python int
|
98 |
+
|
99 |
+
# Remaining indices as test_ids, applying stride
|
100 |
+
all_indices = set(range(len(frames)))
|
101 |
+
remaining_indices = sorted(all_indices - set(train_ids))
|
102 |
+
test_ids = [
|
103 |
+
int(idx) for idx in remaining_indices[::stride]
|
104 |
+
] # Convert to Python int
|
105 |
+
|
106 |
+
# Create the split data
|
107 |
+
split_data = {"train_ids": sorted(train_ids), "test_ids": test_ids}
|
108 |
+
|
109 |
+
with open(output_path, "w") as f:
|
110 |
+
json.dump(split_data, f, indent=4)
|
111 |
+
|
112 |
+
print(f"Train-test split file generated at {output_path}")
|
113 |
+
|
114 |
+
|
115 |
+
# Parse arguments
|
116 |
+
if __name__ == "__main__":
|
117 |
+
parser = argparse.ArgumentParser(
|
118 |
+
description="Generate train-test split JSON file using K-means clustering."
|
119 |
+
)
|
120 |
+
parser.add_argument(
|
121 |
+
"--n",
|
122 |
+
type=int,
|
123 |
+
required=True,
|
124 |
+
help="Number of frames to include in the training set.",
|
125 |
+
)
|
126 |
+
parser.add_argument(
|
127 |
+
"--stride",
|
128 |
+
type=int,
|
129 |
+
default=1,
|
130 |
+
help="Stride for selecting test frames (not used with K-means).",
|
131 |
+
)
|
132 |
+
|
133 |
+
args = parser.parse_args()
|
134 |
+
|
135 |
+
# Create train-test split
|
136 |
+
train_test_split_path = os.path.join(subfolder, f"train_test_split_{args.n}.json")
|
137 |
+
create_train_test_split(frames, args.n, train_test_split_path, args.stride)
|
demo.py
ADDED
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
|
5 |
+
import fire
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from PIL import Image
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from seva.data_io import get_parser
|
13 |
+
from seva.eval import (
|
14 |
+
IS_TORCH_NIGHTLY,
|
15 |
+
compute_relative_inds,
|
16 |
+
create_transforms_simple,
|
17 |
+
infer_prior_inds,
|
18 |
+
infer_prior_stats,
|
19 |
+
run_one_scene,
|
20 |
+
)
|
21 |
+
from seva.geometry import (
|
22 |
+
generate_interpolated_path,
|
23 |
+
generate_spiral_path,
|
24 |
+
get_arc_horizontal_w2cs,
|
25 |
+
get_default_intrinsics,
|
26 |
+
get_lookat,
|
27 |
+
get_preset_pose_fov,
|
28 |
+
)
|
29 |
+
from seva.model import SGMWrapper
|
30 |
+
from seva.modules.autoencoder import AutoEncoder
|
31 |
+
from seva.modules.conditioner import CLIPConditioner
|
32 |
+
from seva.sampling import DDPMDiscretization, DiscreteDenoiser
|
33 |
+
from seva.utils import load_model
|
34 |
+
|
35 |
+
device = "cuda:0"
|
36 |
+
|
37 |
+
|
38 |
+
# Constants.
|
39 |
+
WORK_DIR = "work_dirs/demo"
|
40 |
+
|
41 |
+
if IS_TORCH_NIGHTLY:
|
42 |
+
COMPILE = True
|
43 |
+
os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
|
44 |
+
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
|
45 |
+
else:
|
46 |
+
COMPILE = False
|
47 |
+
|
48 |
+
MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
|
49 |
+
AE = AutoEncoder(chunk_size=1).to(device)
|
50 |
+
CONDITIONER = CLIPConditioner().to(device)
|
51 |
+
DISCRETIZATION = DDPMDiscretization()
|
52 |
+
DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
|
53 |
+
VERSION_DICT = {
|
54 |
+
"H": 576,
|
55 |
+
"W": 576,
|
56 |
+
"T": 21,
|
57 |
+
"C": 4,
|
58 |
+
"f": 8,
|
59 |
+
"options": {},
|
60 |
+
}
|
61 |
+
|
62 |
+
if COMPILE:
|
63 |
+
MODEL = torch.compile(MODEL, dynamic=False)
|
64 |
+
CONDITIONER = torch.compile(CONDITIONER, dynamic=False)
|
65 |
+
AE = torch.compile(AE, dynamic=False)
|
66 |
+
|
67 |
+
|
68 |
+
def parse_task(
|
69 |
+
task,
|
70 |
+
scene,
|
71 |
+
num_inputs,
|
72 |
+
T,
|
73 |
+
version_dict,
|
74 |
+
):
|
75 |
+
options = version_dict["options"]
|
76 |
+
|
77 |
+
anchor_indices = None
|
78 |
+
anchor_c2ws = None
|
79 |
+
anchor_Ks = None
|
80 |
+
|
81 |
+
if task == "img2trajvid_s-prob":
|
82 |
+
if num_inputs is not None:
|
83 |
+
assert (
|
84 |
+
num_inputs == 1
|
85 |
+
), "Task `img2trajvid_s-prob` only support 1-view conditioning..."
|
86 |
+
else:
|
87 |
+
num_inputs = 1
|
88 |
+
num_targets = options.get("num_targets", T - 1)
|
89 |
+
num_anchors = infer_prior_stats(
|
90 |
+
T,
|
91 |
+
num_inputs,
|
92 |
+
num_total_frames=num_targets,
|
93 |
+
version_dict=version_dict,
|
94 |
+
)
|
95 |
+
|
96 |
+
input_indices = [0]
|
97 |
+
anchor_indices = np.linspace(1, num_targets, num_anchors).tolist()
|
98 |
+
|
99 |
+
all_imgs_path = [scene] + [None] * num_targets
|
100 |
+
|
101 |
+
c2ws, fovs = get_preset_pose_fov(
|
102 |
+
option=options.get("traj_prior", "orbit"),
|
103 |
+
num_frames=num_targets + 1,
|
104 |
+
start_w2c=torch.eye(4),
|
105 |
+
look_at=torch.Tensor([0, 0, 10]),
|
106 |
+
)
|
107 |
+
|
108 |
+
with Image.open(scene) as img:
|
109 |
+
W, H = img.size
|
110 |
+
aspect_ratio = W / H
|
111 |
+
Ks = get_default_intrinsics(fovs, aspect_ratio=aspect_ratio) # unormalized
|
112 |
+
Ks[:, :2] *= (
|
113 |
+
torch.tensor([W, H]).reshape(1, -1, 1).repeat(Ks.shape[0], 1, 1)
|
114 |
+
) # normalized
|
115 |
+
Ks = Ks.numpy()
|
116 |
+
|
117 |
+
anchor_c2ws = c2ws[[round(ind) for ind in anchor_indices]]
|
118 |
+
anchor_Ks = Ks[[round(ind) for ind in anchor_indices]]
|
119 |
+
|
120 |
+
else:
|
121 |
+
parser = get_parser(
|
122 |
+
parser_type="reconfusion",
|
123 |
+
data_dir=scene,
|
124 |
+
normalize=False,
|
125 |
+
)
|
126 |
+
all_imgs_path = parser.image_paths
|
127 |
+
c2ws = parser.camtoworlds
|
128 |
+
camera_ids = parser.camera_ids
|
129 |
+
Ks = np.concatenate([parser.Ks_dict[cam_id][None] for cam_id in camera_ids], 0)
|
130 |
+
|
131 |
+
if num_inputs is None:
|
132 |
+
assert len(parser.splits_per_num_input_frames.keys()) == 1
|
133 |
+
num_inputs = list(parser.splits_per_num_input_frames.keys())[0]
|
134 |
+
split_dict = parser.splits_per_num_input_frames[num_inputs] # type: ignore
|
135 |
+
elif isinstance(num_inputs, str):
|
136 |
+
split_dict = parser.splits_per_num_input_frames[num_inputs] # type: ignore
|
137 |
+
num_inputs = int(num_inputs.split("-")[0]) # for example 1_from32
|
138 |
+
else:
|
139 |
+
split_dict = parser.splits_per_num_input_frames[num_inputs] # type: ignore
|
140 |
+
|
141 |
+
num_targets = len(split_dict["test_ids"])
|
142 |
+
|
143 |
+
if task == "img2img":
|
144 |
+
# Note in this setting, we should refrain from using all the other camera
|
145 |
+
# info except ones from sampled_indices, and most importantly, the order.
|
146 |
+
num_anchors = infer_prior_stats(
|
147 |
+
T,
|
148 |
+
num_inputs,
|
149 |
+
num_total_frames=num_targets,
|
150 |
+
version_dict=version_dict,
|
151 |
+
)
|
152 |
+
|
153 |
+
sampled_indices = np.sort(
|
154 |
+
np.array(split_dict["train_ids"] + split_dict["test_ids"])
|
155 |
+
) # we always sort all indices first
|
156 |
+
|
157 |
+
traj_prior = options.get("traj_prior", None)
|
158 |
+
if traj_prior == "spiral":
|
159 |
+
assert parser.bounds is not None
|
160 |
+
anchor_c2ws = generate_spiral_path(
|
161 |
+
c2ws[sampled_indices] @ np.diagflat([1, -1, -1, 1]),
|
162 |
+
parser.bounds[sampled_indices],
|
163 |
+
n_frames=num_anchors + 1,
|
164 |
+
n_rots=2,
|
165 |
+
zrate=0.5,
|
166 |
+
endpoint=False,
|
167 |
+
)[1:] @ np.diagflat([1, -1, -1, 1])
|
168 |
+
elif traj_prior == "interpolated":
|
169 |
+
assert num_inputs > 1
|
170 |
+
anchor_c2ws = generate_interpolated_path(
|
171 |
+
c2ws[split_dict["train_ids"], :3],
|
172 |
+
round((num_anchors + 1) / (num_inputs - 1)),
|
173 |
+
endpoint=False,
|
174 |
+
)[1 : num_anchors + 1]
|
175 |
+
elif traj_prior == "orbit":
|
176 |
+
c2ws_th = torch.as_tensor(c2ws)
|
177 |
+
lookat = get_lookat(
|
178 |
+
c2ws_th[sampled_indices, :3, 3],
|
179 |
+
c2ws_th[sampled_indices, :3, 2],
|
180 |
+
)
|
181 |
+
anchor_c2ws = torch.linalg.inv(
|
182 |
+
get_arc_horizontal_w2cs(
|
183 |
+
torch.linalg.inv(c2ws_th[split_dict["train_ids"][0]]),
|
184 |
+
lookat,
|
185 |
+
-F.normalize(
|
186 |
+
c2ws_th[split_dict["train_ids"]][:, :3, 1].mean(0),
|
187 |
+
dim=-1,
|
188 |
+
),
|
189 |
+
num_frames=num_anchors + 1,
|
190 |
+
endpoint=False,
|
191 |
+
)
|
192 |
+
).numpy()[1:, :3]
|
193 |
+
else:
|
194 |
+
anchor_c2ws = None
|
195 |
+
# anchor_Ks is default to be the first from target_Ks
|
196 |
+
|
197 |
+
all_imgs_path = [all_imgs_path[i] for i in sampled_indices]
|
198 |
+
c2ws = c2ws[sampled_indices]
|
199 |
+
Ks = Ks[sampled_indices]
|
200 |
+
|
201 |
+
# absolute to relative indices
|
202 |
+
input_indices = compute_relative_inds(
|
203 |
+
sampled_indices,
|
204 |
+
np.array(split_dict["train_ids"]),
|
205 |
+
)
|
206 |
+
anchor_indices = np.arange(
|
207 |
+
sampled_indices.shape[0],
|
208 |
+
sampled_indices.shape[0] + num_anchors,
|
209 |
+
).tolist() # the order has no meaning here
|
210 |
+
|
211 |
+
elif task == "img2vid":
|
212 |
+
num_targets = len(all_imgs_path) - num_inputs
|
213 |
+
num_anchors = infer_prior_stats(
|
214 |
+
T,
|
215 |
+
num_inputs,
|
216 |
+
num_total_frames=num_targets,
|
217 |
+
version_dict=version_dict,
|
218 |
+
)
|
219 |
+
|
220 |
+
input_indices = split_dict["train_ids"]
|
221 |
+
anchor_indices = infer_prior_inds(
|
222 |
+
c2ws,
|
223 |
+
num_prior_frames=num_anchors,
|
224 |
+
input_frame_indices=input_indices,
|
225 |
+
options=options,
|
226 |
+
).tolist()
|
227 |
+
num_anchors = len(anchor_indices)
|
228 |
+
anchor_c2ws = c2ws[anchor_indices, :3]
|
229 |
+
anchor_Ks = Ks[anchor_indices]
|
230 |
+
|
231 |
+
elif task == "img2trajvid":
|
232 |
+
num_anchors = infer_prior_stats(
|
233 |
+
T,
|
234 |
+
num_inputs,
|
235 |
+
num_total_frames=num_targets,
|
236 |
+
version_dict=version_dict,
|
237 |
+
)
|
238 |
+
|
239 |
+
target_c2ws = c2ws[split_dict["test_ids"], :3]
|
240 |
+
target_Ks = Ks[split_dict["test_ids"]]
|
241 |
+
anchor_c2ws = target_c2ws[
|
242 |
+
np.linspace(0, num_targets - 1, num_anchors).round().astype(np.int64)
|
243 |
+
]
|
244 |
+
anchor_Ks = target_Ks[
|
245 |
+
np.linspace(0, num_targets - 1, num_anchors).round().astype(np.int64)
|
246 |
+
]
|
247 |
+
|
248 |
+
sampled_indices = split_dict["train_ids"] + split_dict["test_ids"]
|
249 |
+
all_imgs_path = [all_imgs_path[i] for i in sampled_indices]
|
250 |
+
c2ws = c2ws[sampled_indices]
|
251 |
+
Ks = Ks[sampled_indices]
|
252 |
+
|
253 |
+
input_indices = np.arange(num_inputs).tolist()
|
254 |
+
anchor_indices = np.linspace(
|
255 |
+
num_inputs, num_inputs + num_targets - 1, num_anchors
|
256 |
+
).tolist()
|
257 |
+
|
258 |
+
else:
|
259 |
+
raise ValueError(f"Unknown task: {task}")
|
260 |
+
|
261 |
+
return (
|
262 |
+
all_imgs_path,
|
263 |
+
num_inputs,
|
264 |
+
num_targets,
|
265 |
+
input_indices,
|
266 |
+
anchor_indices,
|
267 |
+
torch.tensor(c2ws[:, :3]).float(),
|
268 |
+
torch.tensor(Ks).float(),
|
269 |
+
(torch.tensor(anchor_c2ws[:, :3]).float() if anchor_c2ws is not None else None),
|
270 |
+
(torch.tensor(anchor_Ks).float() if anchor_Ks is not None else None),
|
271 |
+
)
|
272 |
+
|
273 |
+
|
274 |
+
def main(
|
275 |
+
data_path,
|
276 |
+
data_items=None,
|
277 |
+
task="img2img",
|
278 |
+
save_subdir="",
|
279 |
+
H=None,
|
280 |
+
W=None,
|
281 |
+
T=None,
|
282 |
+
use_traj_prior=False,
|
283 |
+
**overwrite_options,
|
284 |
+
):
|
285 |
+
if H is not None:
|
286 |
+
VERSION_DICT["H"] = H
|
287 |
+
if W is not None:
|
288 |
+
VERSION_DICT["W"] = W
|
289 |
+
if T is not None:
|
290 |
+
VERSION_DICT["T"] = [int(t) for t in T.split(",")] if isinstance(T, str) else T
|
291 |
+
|
292 |
+
options = VERSION_DICT["options"]
|
293 |
+
options["chunk_strategy"] = "nearest-gt"
|
294 |
+
options["video_save_fps"] = 30.0
|
295 |
+
options["beta_linear_start"] = 5e-6
|
296 |
+
options["log_snr_shift"] = 2.4
|
297 |
+
options["guider_types"] = 1
|
298 |
+
options["cfg"] = 2.0
|
299 |
+
options["camera_scale"] = 2.0
|
300 |
+
options["num_steps"] = 50
|
301 |
+
options["cfg_min"] = 1.2
|
302 |
+
options["encoding_t"] = 1
|
303 |
+
options["decoding_t"] = 1
|
304 |
+
options["num_inputs"] = None
|
305 |
+
options["seed"] = 23
|
306 |
+
options.update(overwrite_options)
|
307 |
+
|
308 |
+
num_inputs = options["num_inputs"]
|
309 |
+
seed = options["seed"]
|
310 |
+
|
311 |
+
if data_items is not None:
|
312 |
+
if not isinstance(data_items, (list, tuple)):
|
313 |
+
data_items = data_items.split(",")
|
314 |
+
scenes = [os.path.join(data_path, item) for item in data_items]
|
315 |
+
else:
|
316 |
+
scenes = glob.glob(osp.join(data_path, "*"))
|
317 |
+
|
318 |
+
for scene in tqdm(scenes):
|
319 |
+
save_path_scene = os.path.join(
|
320 |
+
WORK_DIR, task, save_subdir, os.path.splitext(os.path.basename(scene))[0]
|
321 |
+
)
|
322 |
+
if options.get("skip_saved", False) and os.path.exists(
|
323 |
+
os.path.join(save_path_scene, "transforms.json")
|
324 |
+
):
|
325 |
+
print(f"Skipping {scene} as it is already sampled.")
|
326 |
+
continue
|
327 |
+
|
328 |
+
# parse_task -> infer_prior_stats modifies VERSION_DICT["T"] in-place.
|
329 |
+
(
|
330 |
+
all_imgs_path,
|
331 |
+
num_inputs,
|
332 |
+
num_targets,
|
333 |
+
input_indices,
|
334 |
+
anchor_indices,
|
335 |
+
c2ws,
|
336 |
+
Ks,
|
337 |
+
anchor_c2ws,
|
338 |
+
anchor_Ks,
|
339 |
+
) = parse_task(
|
340 |
+
task,
|
341 |
+
scene,
|
342 |
+
num_inputs,
|
343 |
+
VERSION_DICT["T"],
|
344 |
+
VERSION_DICT,
|
345 |
+
)
|
346 |
+
assert num_inputs is not None
|
347 |
+
# Create image conditioning.
|
348 |
+
image_cond = {
|
349 |
+
"img": all_imgs_path,
|
350 |
+
"input_indices": input_indices,
|
351 |
+
"prior_indices": anchor_indices,
|
352 |
+
}
|
353 |
+
# Create camera conditioning.
|
354 |
+
camera_cond = {
|
355 |
+
"c2w": c2ws.clone(),
|
356 |
+
"K": Ks.clone(),
|
357 |
+
"input_indices": list(range(num_inputs + num_targets)),
|
358 |
+
}
|
359 |
+
# run_one_scene -> transform_img_and_K modifies VERSION_DICT["H"] and VERSION_DICT["W"] in-place.
|
360 |
+
video_path_generator = run_one_scene(
|
361 |
+
task,
|
362 |
+
VERSION_DICT, # H, W maybe updated in run_one_scene
|
363 |
+
model=MODEL,
|
364 |
+
ae=AE,
|
365 |
+
conditioner=CONDITIONER,
|
366 |
+
denoiser=DENOISER,
|
367 |
+
image_cond=image_cond,
|
368 |
+
camera_cond=camera_cond,
|
369 |
+
save_path=save_path_scene,
|
370 |
+
use_traj_prior=use_traj_prior,
|
371 |
+
traj_prior_Ks=anchor_Ks,
|
372 |
+
traj_prior_c2ws=anchor_c2ws,
|
373 |
+
seed=seed, # to ensure sampled video can be reproduced in regardless of start and i
|
374 |
+
)
|
375 |
+
for _ in video_path_generator:
|
376 |
+
pass
|
377 |
+
|
378 |
+
# Convert from OpenCV to OpenGL camera format.
|
379 |
+
c2ws = c2ws @ torch.tensor(np.diag([1, -1, -1, 1])).float()
|
380 |
+
img_paths = sorted(glob.glob(osp.join(save_path_scene, "samples-rgb", "*.png")))
|
381 |
+
if len(img_paths) != len(c2ws):
|
382 |
+
input_img_paths = sorted(
|
383 |
+
glob.glob(osp.join(save_path_scene, "input", "*.png"))
|
384 |
+
)
|
385 |
+
assert len(img_paths) == num_targets
|
386 |
+
assert len(input_img_paths) == num_inputs
|
387 |
+
assert c2ws.shape[0] == num_inputs + num_targets
|
388 |
+
target_indices = [i for i in range(c2ws.shape[0]) if i not in input_indices]
|
389 |
+
img_paths = [
|
390 |
+
input_img_paths[input_indices.index(i)]
|
391 |
+
if i in input_indices
|
392 |
+
else img_paths[target_indices.index(i)]
|
393 |
+
for i in range(c2ws.shape[0])
|
394 |
+
]
|
395 |
+
create_transforms_simple(
|
396 |
+
save_path=save_path_scene,
|
397 |
+
img_paths=img_paths,
|
398 |
+
img_whs=np.array([VERSION_DICT["W"], VERSION_DICT["H"]])[None].repeat(
|
399 |
+
num_inputs + num_targets, 0
|
400 |
+
),
|
401 |
+
c2ws=c2ws,
|
402 |
+
Ks=Ks,
|
403 |
+
)
|
404 |
+
|
405 |
+
|
406 |
+
if __name__ == "__main__":
|
407 |
+
fire.Fire(main)
|
demo_gr.py
ADDED
@@ -0,0 +1,1248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import os.path as osp
|
5 |
+
import queue
|
6 |
+
import secrets
|
7 |
+
import threading
|
8 |
+
import time
|
9 |
+
from datetime import datetime
|
10 |
+
from glob import glob
|
11 |
+
from pathlib import Path
|
12 |
+
from typing import Literal
|
13 |
+
|
14 |
+
import gradio as gr
|
15 |
+
import httpx
|
16 |
+
import imageio.v3 as iio
|
17 |
+
import numpy as np
|
18 |
+
import torch
|
19 |
+
import torch.nn.functional as F
|
20 |
+
import tyro
|
21 |
+
import viser
|
22 |
+
import viser.transforms as vt
|
23 |
+
from einops import rearrange
|
24 |
+
from gradio import networking
|
25 |
+
from gradio.context import LocalContext
|
26 |
+
from gradio.tunneling import CERTIFICATE_PATH, Tunnel
|
27 |
+
|
28 |
+
from seva.eval import (
|
29 |
+
IS_TORCH_NIGHTLY,
|
30 |
+
chunk_input_and_test,
|
31 |
+
create_transforms_simple,
|
32 |
+
infer_prior_stats,
|
33 |
+
run_one_scene,
|
34 |
+
transform_img_and_K,
|
35 |
+
)
|
36 |
+
from seva.geometry import (
|
37 |
+
DEFAULT_FOV_RAD,
|
38 |
+
get_default_intrinsics,
|
39 |
+
get_preset_pose_fov,
|
40 |
+
normalize_scene,
|
41 |
+
)
|
42 |
+
from seva.gui import define_gui
|
43 |
+
from seva.model import SGMWrapper
|
44 |
+
from seva.modules.autoencoder import AutoEncoder
|
45 |
+
from seva.modules.conditioner import CLIPConditioner
|
46 |
+
from seva.modules.preprocessor import Dust3rPipeline
|
47 |
+
from seva.sampling import DDPMDiscretization, DiscreteDenoiser
|
48 |
+
from seva.utils import load_model
|
49 |
+
|
50 |
+
device = "cpu"
|
51 |
+
|
52 |
+
|
53 |
+
# Constants.
|
54 |
+
WORK_DIR = "work_dirs/demo_gr"
|
55 |
+
MAX_SESSIONS = 1
|
56 |
+
ADVANCE_EXAMPLE_MAP = [
|
57 |
+
(
|
58 |
+
"assets/advance/blue-car.jpg",
|
59 |
+
["assets/advance/blue-car.jpg"],
|
60 |
+
),
|
61 |
+
(
|
62 |
+
"assets/advance/garden-4_0.jpg",
|
63 |
+
[
|
64 |
+
"assets/advance/garden-4_0.jpg",
|
65 |
+
"assets/advance/garden-4_1.jpg",
|
66 |
+
"assets/advance/garden-4_2.jpg",
|
67 |
+
"assets/advance/garden-4_3.jpg",
|
68 |
+
],
|
69 |
+
),
|
70 |
+
(
|
71 |
+
"assets/advance/vgg-lab-4_0.png",
|
72 |
+
[
|
73 |
+
"assets/advance/vgg-lab-4_0.png",
|
74 |
+
"assets/advance/vgg-lab-4_1.png",
|
75 |
+
"assets/advance/vgg-lab-4_2.png",
|
76 |
+
"assets/advance/vgg-lab-4_3.png",
|
77 |
+
],
|
78 |
+
),
|
79 |
+
(
|
80 |
+
"assets/advance/telebooth-2_0.jpg",
|
81 |
+
[
|
82 |
+
"assets/advance/telebooth-2_0.jpg",
|
83 |
+
"assets/advance/telebooth-2_1.jpg",
|
84 |
+
],
|
85 |
+
),
|
86 |
+
(
|
87 |
+
"assets/advance/backyard-7_0.jpg",
|
88 |
+
[
|
89 |
+
"assets/advance/backyard-7_0.jpg",
|
90 |
+
"assets/advance/backyard-7_1.jpg",
|
91 |
+
"assets/advance/backyard-7_2.jpg",
|
92 |
+
"assets/advance/backyard-7_3.jpg",
|
93 |
+
"assets/advance/backyard-7_4.jpg",
|
94 |
+
"assets/advance/backyard-7_5.jpg",
|
95 |
+
"assets/advance/backyard-7_6.jpg",
|
96 |
+
],
|
97 |
+
),
|
98 |
+
]
|
99 |
+
|
100 |
+
if IS_TORCH_NIGHTLY:
|
101 |
+
COMPILE = True
|
102 |
+
os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
|
103 |
+
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
|
104 |
+
else:
|
105 |
+
COMPILE = False
|
106 |
+
|
107 |
+
# Shared global variables across sessions.
|
108 |
+
DUST3R = Dust3rPipeline(device=device) # type: ignore
|
109 |
+
MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
|
110 |
+
AE = AutoEncoder(chunk_size=1).to(device)
|
111 |
+
CONDITIONER = CLIPConditioner().to(device)
|
112 |
+
DISCRETIZATION = DDPMDiscretization()
|
113 |
+
DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
|
114 |
+
VERSION_DICT = {
|
115 |
+
"H": 576,
|
116 |
+
"W": 576,
|
117 |
+
"T": 21,
|
118 |
+
"C": 4,
|
119 |
+
"f": 8,
|
120 |
+
"options": {},
|
121 |
+
}
|
122 |
+
SERVERS = {}
|
123 |
+
ABORT_EVENTS = {}
|
124 |
+
|
125 |
+
if COMPILE:
|
126 |
+
MODEL = torch.compile(MODEL)
|
127 |
+
CONDITIONER = torch.compile(CONDITIONER)
|
128 |
+
AE = torch.compile(AE)
|
129 |
+
|
130 |
+
|
131 |
+
class SevaRenderer(object):
|
132 |
+
def __init__(self, server: viser.ViserServer):
|
133 |
+
self.server = server
|
134 |
+
self.gui_state = None
|
135 |
+
|
136 |
+
def preprocess(
|
137 |
+
self, input_img_path_or_tuples: list[tuple[str, None]] | str
|
138 |
+
) -> tuple[dict, dict, dict]:
|
139 |
+
# Simply hardcode these such that aspect ratio is always kept and
|
140 |
+
# shorter side is resized to 576. This is only to make GUI option fewer
|
141 |
+
# though, changing it still works.
|
142 |
+
shorter: int = 576
|
143 |
+
# Has to be 64 multiple for the network.
|
144 |
+
shorter = round(shorter / 64) * 64
|
145 |
+
|
146 |
+
if isinstance(input_img_path_or_tuples, str):
|
147 |
+
# Assume `Basic` demo mode: just hardcode the camera parameters and ignore points.
|
148 |
+
input_imgs = torch.as_tensor(
|
149 |
+
iio.imread(input_img_path_or_tuples) / 255.0, dtype=torch.float32
|
150 |
+
)[None, ..., :3]
|
151 |
+
input_imgs = transform_img_and_K(
|
152 |
+
input_imgs.permute(0, 3, 1, 2),
|
153 |
+
shorter,
|
154 |
+
K=None,
|
155 |
+
size_stride=64,
|
156 |
+
)[0].permute(0, 2, 3, 1)
|
157 |
+
input_Ks = get_default_intrinsics(
|
158 |
+
aspect_ratio=input_imgs.shape[2] / input_imgs.shape[1]
|
159 |
+
)
|
160 |
+
input_c2ws = torch.eye(4)[None]
|
161 |
+
# Simulate a small time interval such that gradio can update
|
162 |
+
# propgress properly.
|
163 |
+
time.sleep(0.1)
|
164 |
+
return (
|
165 |
+
{
|
166 |
+
"input_imgs": input_imgs,
|
167 |
+
"input_Ks": input_Ks,
|
168 |
+
"input_c2ws": input_c2ws,
|
169 |
+
"input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
|
170 |
+
"points": [np.zeros((0, 3))],
|
171 |
+
"point_colors": [np.zeros((0, 3))],
|
172 |
+
"scene_scale": 1.0,
|
173 |
+
},
|
174 |
+
gr.update(visible=False),
|
175 |
+
gr.update(),
|
176 |
+
)
|
177 |
+
else:
|
178 |
+
# Assume `Advance` demo mode: use dust3r to extract camera parameters and points.
|
179 |
+
img_paths = [p for (p, _) in input_img_path_or_tuples]
|
180 |
+
(
|
181 |
+
input_imgs,
|
182 |
+
input_Ks,
|
183 |
+
input_c2ws,
|
184 |
+
points,
|
185 |
+
point_colors,
|
186 |
+
) = DUST3R.infer_cameras_and_points(img_paths)
|
187 |
+
num_inputs = len(img_paths)
|
188 |
+
if num_inputs == 1:
|
189 |
+
input_imgs, input_Ks, input_c2ws, points, point_colors = (
|
190 |
+
input_imgs[:1],
|
191 |
+
input_Ks[:1],
|
192 |
+
input_c2ws[:1],
|
193 |
+
points[:1],
|
194 |
+
point_colors[:1],
|
195 |
+
)
|
196 |
+
input_imgs = [img[..., :3] for img in input_imgs]
|
197 |
+
# Normalize the scene.
|
198 |
+
point_chunks = [p.shape[0] for p in points]
|
199 |
+
point_indices = np.cumsum(point_chunks)[:-1]
|
200 |
+
input_c2ws, points, _ = normalize_scene( # type: ignore
|
201 |
+
input_c2ws,
|
202 |
+
np.concatenate(points, 0),
|
203 |
+
camera_center_method="poses",
|
204 |
+
)
|
205 |
+
points = np.split(points, point_indices, 0)
|
206 |
+
# Scale camera and points for viewport visualization.
|
207 |
+
scene_scale = np.median(
|
208 |
+
np.ptp(np.concatenate([input_c2ws[:, :3, 3], *points], 0), -1)
|
209 |
+
)
|
210 |
+
input_c2ws[:, :3, 3] /= scene_scale
|
211 |
+
points = [point / scene_scale for point in points]
|
212 |
+
input_imgs = [
|
213 |
+
torch.as_tensor(img / 255.0, dtype=torch.float32) for img in input_imgs
|
214 |
+
]
|
215 |
+
input_Ks = torch.as_tensor(input_Ks)
|
216 |
+
input_c2ws = torch.as_tensor(input_c2ws)
|
217 |
+
new_input_imgs, new_input_Ks = [], []
|
218 |
+
for img, K in zip(input_imgs, input_Ks):
|
219 |
+
img = rearrange(img, "h w c -> 1 c h w")
|
220 |
+
# If you don't want to keep aspect ratio and want to always center crop, use this:
|
221 |
+
# img, K = transform_img_and_K(img, (shorter, shorter), K=K[None])
|
222 |
+
img, K = transform_img_and_K(img, shorter, K=K[None], size_stride=64)
|
223 |
+
assert isinstance(K, torch.Tensor)
|
224 |
+
K = K / K.new_tensor([img.shape[-1], img.shape[-2], 1])[:, None]
|
225 |
+
new_input_imgs.append(img)
|
226 |
+
new_input_Ks.append(K)
|
227 |
+
input_imgs = torch.cat(new_input_imgs, 0)
|
228 |
+
input_imgs = rearrange(input_imgs, "b c h w -> b h w c")[..., :3]
|
229 |
+
input_Ks = torch.cat(new_input_Ks, 0)
|
230 |
+
return (
|
231 |
+
{
|
232 |
+
"input_imgs": input_imgs,
|
233 |
+
"input_Ks": input_Ks,
|
234 |
+
"input_c2ws": input_c2ws,
|
235 |
+
"input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
|
236 |
+
"points": points,
|
237 |
+
"point_colors": point_colors,
|
238 |
+
"scene_scale": scene_scale,
|
239 |
+
},
|
240 |
+
gr.update(visible=False),
|
241 |
+
gr.update()
|
242 |
+
if num_inputs <= 10
|
243 |
+
else gr.update(choices=["interp"], value="interp"),
|
244 |
+
)
|
245 |
+
|
246 |
+
def visualize_scene(self, preprocessed: dict):
|
247 |
+
server = self.server
|
248 |
+
server.scene.reset()
|
249 |
+
server.gui.reset()
|
250 |
+
set_bkgd_color(server)
|
251 |
+
|
252 |
+
(
|
253 |
+
input_imgs,
|
254 |
+
input_Ks,
|
255 |
+
input_c2ws,
|
256 |
+
input_wh,
|
257 |
+
points,
|
258 |
+
point_colors,
|
259 |
+
scene_scale,
|
260 |
+
) = (
|
261 |
+
preprocessed["input_imgs"],
|
262 |
+
preprocessed["input_Ks"],
|
263 |
+
preprocessed["input_c2ws"],
|
264 |
+
preprocessed["input_wh"],
|
265 |
+
preprocessed["points"],
|
266 |
+
preprocessed["point_colors"],
|
267 |
+
preprocessed["scene_scale"],
|
268 |
+
)
|
269 |
+
W, H = input_wh
|
270 |
+
|
271 |
+
server.scene.set_up_direction(-input_c2ws[..., :3, 1].mean(0).numpy())
|
272 |
+
|
273 |
+
# Use first image as default fov.
|
274 |
+
assert input_imgs[0].shape[:2] == (H, W)
|
275 |
+
if H > W:
|
276 |
+
init_fov = 2 * np.arctan(1 / (2 * input_Ks[0, 0, 0].item()))
|
277 |
+
else:
|
278 |
+
init_fov = 2 * np.arctan(1 / (2 * input_Ks[0, 1, 1].item()))
|
279 |
+
init_fov_deg = float(init_fov / np.pi * 180.0)
|
280 |
+
|
281 |
+
frustum_nodes, pcd_nodes = [], []
|
282 |
+
for i in range(len(input_imgs)):
|
283 |
+
K = input_Ks[i]
|
284 |
+
frustum = server.scene.add_camera_frustum(
|
285 |
+
f"/scene_assets/cameras/{i}",
|
286 |
+
fov=2 * np.arctan(1 / (2 * K[1, 1].item())),
|
287 |
+
aspect=W / H,
|
288 |
+
scale=0.1 * scene_scale,
|
289 |
+
image=(input_imgs[i].numpy() * 255.0).astype(np.uint8),
|
290 |
+
wxyz=vt.SO3.from_matrix(input_c2ws[i, :3, :3].numpy()).wxyz,
|
291 |
+
position=input_c2ws[i, :3, 3].numpy(),
|
292 |
+
)
|
293 |
+
|
294 |
+
def get_handler(frustum):
|
295 |
+
def handler(event: viser.GuiEvent) -> None:
|
296 |
+
assert event.client_id is not None
|
297 |
+
client = server.get_clients()[event.client_id]
|
298 |
+
with client.atomic():
|
299 |
+
client.camera.position = frustum.position
|
300 |
+
client.camera.wxyz = frustum.wxyz
|
301 |
+
# Set look_at as the projected origin onto the
|
302 |
+
# frustum's forward direction.
|
303 |
+
look_direction = vt.SO3(frustum.wxyz).as_matrix()[:, 2]
|
304 |
+
position_origin = -frustum.position
|
305 |
+
client.camera.look_at = (
|
306 |
+
frustum.position
|
307 |
+
+ np.dot(look_direction, position_origin)
|
308 |
+
/ np.linalg.norm(position_origin)
|
309 |
+
* look_direction
|
310 |
+
)
|
311 |
+
|
312 |
+
return handler
|
313 |
+
|
314 |
+
frustum.on_click(get_handler(frustum)) # type: ignore
|
315 |
+
frustum_nodes.append(frustum)
|
316 |
+
|
317 |
+
pcd = server.scene.add_point_cloud(
|
318 |
+
f"/scene_assets/points/{i}",
|
319 |
+
points[i],
|
320 |
+
point_colors[i],
|
321 |
+
point_size=0.01 * scene_scale,
|
322 |
+
point_shape="circle",
|
323 |
+
)
|
324 |
+
pcd_nodes.append(pcd)
|
325 |
+
|
326 |
+
with server.gui.add_folder("Scene scale", expand_by_default=False, order=200):
|
327 |
+
camera_scale_slider = server.gui.add_slider(
|
328 |
+
"Log camera scale", initial_value=0.0, min=-2.0, max=2.0, step=0.1
|
329 |
+
)
|
330 |
+
|
331 |
+
@camera_scale_slider.on_update
|
332 |
+
def _(_) -> None:
|
333 |
+
for i in range(len(frustum_nodes)):
|
334 |
+
frustum_nodes[i].scale = (
|
335 |
+
0.1 * scene_scale * 10**camera_scale_slider.value
|
336 |
+
)
|
337 |
+
|
338 |
+
point_scale_slider = server.gui.add_slider(
|
339 |
+
"Log point scale", initial_value=0.0, min=-2.0, max=2.0, step=0.1
|
340 |
+
)
|
341 |
+
|
342 |
+
@point_scale_slider.on_update
|
343 |
+
def _(_) -> None:
|
344 |
+
for i in range(len(pcd_nodes)):
|
345 |
+
pcd_nodes[i].point_size = (
|
346 |
+
0.01 * scene_scale * 10**point_scale_slider.value
|
347 |
+
)
|
348 |
+
|
349 |
+
self.gui_state = define_gui(
|
350 |
+
server,
|
351 |
+
init_fov=init_fov_deg,
|
352 |
+
img_wh=input_wh,
|
353 |
+
scene_scale=scene_scale,
|
354 |
+
)
|
355 |
+
|
356 |
+
def get_target_c2ws_and_Ks_from_gui(self, preprocessed: dict):
|
357 |
+
input_wh = preprocessed["input_wh"]
|
358 |
+
W, H = input_wh
|
359 |
+
gui_state = self.gui_state
|
360 |
+
assert gui_state is not None and gui_state.camera_traj_list is not None
|
361 |
+
target_c2ws, target_Ks = [], []
|
362 |
+
for item in gui_state.camera_traj_list:
|
363 |
+
target_c2ws.append(item["w2c"])
|
364 |
+
assert item["img_wh"] == input_wh
|
365 |
+
K = np.array(item["K"]).reshape(3, 3) / np.array([W, H, 1])[:, None]
|
366 |
+
target_Ks.append(K)
|
367 |
+
target_c2ws = torch.as_tensor(
|
368 |
+
np.linalg.inv(np.array(target_c2ws).reshape(-1, 4, 4))
|
369 |
+
)
|
370 |
+
target_Ks = torch.as_tensor(np.array(target_Ks).reshape(-1, 3, 3))
|
371 |
+
return target_c2ws, target_Ks
|
372 |
+
|
373 |
+
def get_target_c2ws_and_Ks_from_preset(
|
374 |
+
self,
|
375 |
+
preprocessed: dict,
|
376 |
+
preset_traj: Literal[
|
377 |
+
"orbit",
|
378 |
+
"spiral",
|
379 |
+
"lemniscate",
|
380 |
+
"zoom-in",
|
381 |
+
"zoom-out",
|
382 |
+
"dolly zoom-in",
|
383 |
+
"dolly zoom-out",
|
384 |
+
"move-forward",
|
385 |
+
"move-backward",
|
386 |
+
"move-up",
|
387 |
+
"move-down",
|
388 |
+
"move-left",
|
389 |
+
"move-right",
|
390 |
+
],
|
391 |
+
num_frames: int,
|
392 |
+
zoom_factor: float | None,
|
393 |
+
):
|
394 |
+
img_wh = preprocessed["input_wh"]
|
395 |
+
start_c2w = preprocessed["input_c2ws"][0]
|
396 |
+
start_w2c = torch.linalg.inv(start_c2w)
|
397 |
+
look_at = torch.tensor([0, 0, 10])
|
398 |
+
start_fov = DEFAULT_FOV_RAD
|
399 |
+
target_c2ws, target_fovs = get_preset_pose_fov(
|
400 |
+
preset_traj,
|
401 |
+
num_frames,
|
402 |
+
start_w2c,
|
403 |
+
look_at,
|
404 |
+
-start_c2w[:3, 1],
|
405 |
+
start_fov,
|
406 |
+
spiral_radii=[1.0, 1.0, 0.5],
|
407 |
+
zoom_factor=zoom_factor,
|
408 |
+
)
|
409 |
+
target_c2ws = torch.as_tensor(target_c2ws)
|
410 |
+
target_fovs = torch.as_tensor(target_fovs)
|
411 |
+
target_Ks = get_default_intrinsics(
|
412 |
+
target_fovs, # type: ignore
|
413 |
+
aspect_ratio=img_wh[0] / img_wh[1],
|
414 |
+
)
|
415 |
+
return target_c2ws, target_Ks
|
416 |
+
|
417 |
+
def export_output_data(self, preprocessed: dict, output_dir: str):
|
418 |
+
input_imgs, input_Ks, input_c2ws, input_wh = (
|
419 |
+
preprocessed["input_imgs"],
|
420 |
+
preprocessed["input_Ks"],
|
421 |
+
preprocessed["input_c2ws"],
|
422 |
+
preprocessed["input_wh"],
|
423 |
+
)
|
424 |
+
target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_gui(preprocessed)
|
425 |
+
|
426 |
+
num_inputs = len(input_imgs)
|
427 |
+
num_targets = len(target_c2ws)
|
428 |
+
|
429 |
+
input_imgs = (input_imgs.cpu().numpy() * 255.0).astype(np.uint8)
|
430 |
+
input_c2ws = input_c2ws.cpu().numpy()
|
431 |
+
input_Ks = input_Ks.cpu().numpy()
|
432 |
+
target_c2ws = target_c2ws.cpu().numpy()
|
433 |
+
target_Ks = target_Ks.cpu().numpy()
|
434 |
+
img_whs = np.array(input_wh)[None].repeat(len(input_imgs) + len(target_Ks), 0)
|
435 |
+
|
436 |
+
os.makedirs(output_dir, exist_ok=True)
|
437 |
+
img_paths = []
|
438 |
+
for i, img in enumerate(input_imgs):
|
439 |
+
iio.imwrite(img_path := osp.join(output_dir, f"{i:03d}.png"), img)
|
440 |
+
img_paths.append(img_path)
|
441 |
+
for i in range(num_targets):
|
442 |
+
iio.imwrite(
|
443 |
+
img_path := osp.join(output_dir, f"{i + num_inputs:03d}.png"),
|
444 |
+
np.zeros((input_wh[1], input_wh[0], 3), dtype=np.uint8),
|
445 |
+
)
|
446 |
+
img_paths.append(img_path)
|
447 |
+
|
448 |
+
# Convert from OpenCV to OpenGL camera format.
|
449 |
+
all_c2ws = np.concatenate([input_c2ws, target_c2ws])
|
450 |
+
all_Ks = np.concatenate([input_Ks, target_Ks])
|
451 |
+
all_c2ws = all_c2ws @ np.diag([1, -1, -1, 1])
|
452 |
+
create_transforms_simple(output_dir, img_paths, img_whs, all_c2ws, all_Ks)
|
453 |
+
split_dict = {
|
454 |
+
"train_ids": list(range(num_inputs)),
|
455 |
+
"test_ids": list(range(num_inputs, num_inputs + num_targets)),
|
456 |
+
}
|
457 |
+
with open(
|
458 |
+
osp.join(output_dir, f"train_test_split_{num_inputs}.json"), "w"
|
459 |
+
) as f:
|
460 |
+
json.dump(split_dict, f, indent=4)
|
461 |
+
gr.Info(f"Output data saved to {output_dir}", duration=1)
|
462 |
+
|
463 |
+
def render(
|
464 |
+
self,
|
465 |
+
preprocessed: dict,
|
466 |
+
session_hash: str,
|
467 |
+
seed: int,
|
468 |
+
chunk_strategy: str,
|
469 |
+
cfg: float,
|
470 |
+
preset_traj: Literal[
|
471 |
+
"orbit",
|
472 |
+
"spiral",
|
473 |
+
"lemniscate",
|
474 |
+
"zoom-in",
|
475 |
+
"zoom-out",
|
476 |
+
"dolly zoom-in",
|
477 |
+
"dolly zoom-out",
|
478 |
+
"move-forward",
|
479 |
+
"move-backward",
|
480 |
+
"move-up",
|
481 |
+
"move-down",
|
482 |
+
"move-left",
|
483 |
+
"move-right",
|
484 |
+
]
|
485 |
+
| None,
|
486 |
+
num_frames: int | None,
|
487 |
+
zoom_factor: float | None,
|
488 |
+
camera_scale: float,
|
489 |
+
):
|
490 |
+
render_name = datetime.now().strftime("%Y%m%d_%H%M%S")
|
491 |
+
render_dir = osp.join(WORK_DIR, render_name)
|
492 |
+
|
493 |
+
input_imgs, input_Ks, input_c2ws, (W, H) = (
|
494 |
+
preprocessed["input_imgs"],
|
495 |
+
preprocessed["input_Ks"],
|
496 |
+
preprocessed["input_c2ws"],
|
497 |
+
preprocessed["input_wh"],
|
498 |
+
)
|
499 |
+
num_inputs = len(input_imgs)
|
500 |
+
if preset_traj is None:
|
501 |
+
target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_gui(preprocessed)
|
502 |
+
else:
|
503 |
+
assert num_frames is not None
|
504 |
+
assert num_inputs == 1
|
505 |
+
input_c2ws = torch.eye(4)[None].to(dtype=input_c2ws.dtype)
|
506 |
+
target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_preset(
|
507 |
+
preprocessed, preset_traj, num_frames, zoom_factor
|
508 |
+
)
|
509 |
+
all_c2ws = torch.cat([input_c2ws, target_c2ws], 0)
|
510 |
+
all_Ks = (
|
511 |
+
torch.cat([input_Ks, target_Ks], 0)
|
512 |
+
* input_Ks.new_tensor([W, H, 1])[:, None]
|
513 |
+
)
|
514 |
+
num_targets = len(target_c2ws)
|
515 |
+
input_indices = list(range(num_inputs))
|
516 |
+
target_indices = np.arange(num_inputs, num_inputs + num_targets).tolist()
|
517 |
+
# Get anchor cameras.
|
518 |
+
T = VERSION_DICT["T"]
|
519 |
+
version_dict = copy.deepcopy(VERSION_DICT)
|
520 |
+
num_anchors = infer_prior_stats(
|
521 |
+
T,
|
522 |
+
num_inputs,
|
523 |
+
num_total_frames=num_targets,
|
524 |
+
version_dict=version_dict,
|
525 |
+
)
|
526 |
+
# infer_prior_stats modifies T in-place.
|
527 |
+
T = version_dict["T"]
|
528 |
+
assert isinstance(num_anchors, int)
|
529 |
+
anchor_indices = np.linspace(
|
530 |
+
num_inputs,
|
531 |
+
num_inputs + num_targets - 1,
|
532 |
+
num_anchors,
|
533 |
+
).tolist()
|
534 |
+
anchor_c2ws = all_c2ws[[round(ind) for ind in anchor_indices]]
|
535 |
+
anchor_Ks = all_Ks[[round(ind) for ind in anchor_indices]]
|
536 |
+
# Create image conditioning.
|
537 |
+
all_imgs_np = (
|
538 |
+
F.pad(input_imgs, (0, 0, 0, 0, 0, 0, 0, num_targets), value=0.0).numpy()
|
539 |
+
* 255.0
|
540 |
+
).astype(np.uint8)
|
541 |
+
image_cond = {
|
542 |
+
"img": all_imgs_np,
|
543 |
+
"input_indices": input_indices,
|
544 |
+
"prior_indices": anchor_indices,
|
545 |
+
}
|
546 |
+
# Create camera conditioning (K is unnormalized).
|
547 |
+
camera_cond = {
|
548 |
+
"c2w": all_c2ws,
|
549 |
+
"K": all_Ks,
|
550 |
+
"input_indices": list(range(num_inputs + num_targets)),
|
551 |
+
}
|
552 |
+
# Run rendering.
|
553 |
+
num_steps = 50
|
554 |
+
options_ori = VERSION_DICT["options"]
|
555 |
+
options = copy.deepcopy(options_ori)
|
556 |
+
options["chunk_strategy"] = chunk_strategy
|
557 |
+
options["video_save_fps"] = 30.0
|
558 |
+
options["beta_linear_start"] = 5e-6
|
559 |
+
options["log_snr_shift"] = 2.4
|
560 |
+
options["guider_types"] = [1, 2]
|
561 |
+
options["cfg"] = [
|
562 |
+
float(cfg),
|
563 |
+
3.0 if num_inputs >= 9 else 2.0,
|
564 |
+
] # We define semi-dense-view regime to have 9 input views.
|
565 |
+
options["camera_scale"] = camera_scale
|
566 |
+
options["num_steps"] = num_steps
|
567 |
+
options["cfg_min"] = 1.2
|
568 |
+
options["encoding_t"] = 1
|
569 |
+
options["decoding_t"] = 1
|
570 |
+
assert session_hash in ABORT_EVENTS
|
571 |
+
abort_event = ABORT_EVENTS[session_hash]
|
572 |
+
abort_event.clear()
|
573 |
+
options["abort_event"] = abort_event
|
574 |
+
task = "img2trajvid"
|
575 |
+
# Get number of first pass chunks.
|
576 |
+
T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
|
577 |
+
chunk_strategy_first_pass = options.get(
|
578 |
+
"chunk_strategy_first_pass", "gt-nearest"
|
579 |
+
)
|
580 |
+
num_chunks_0 = len(
|
581 |
+
chunk_input_and_test(
|
582 |
+
T_first_pass,
|
583 |
+
input_c2ws,
|
584 |
+
anchor_c2ws,
|
585 |
+
input_indices,
|
586 |
+
image_cond["prior_indices"],
|
587 |
+
options={**options, "sampler_verbose": False},
|
588 |
+
task=task,
|
589 |
+
chunk_strategy=chunk_strategy_first_pass,
|
590 |
+
gt_input_inds=list(range(input_c2ws.shape[0])),
|
591 |
+
)[1]
|
592 |
+
)
|
593 |
+
# Get number of second pass chunks.
|
594 |
+
anchor_argsort = np.argsort(input_indices + anchor_indices).tolist()
|
595 |
+
anchor_indices = np.array(input_indices + anchor_indices)[
|
596 |
+
anchor_argsort
|
597 |
+
].tolist()
|
598 |
+
gt_input_inds = [anchor_argsort.index(i) for i in range(input_c2ws.shape[0])]
|
599 |
+
anchor_c2ws_second_pass = torch.cat([input_c2ws, anchor_c2ws], dim=0)[
|
600 |
+
anchor_argsort
|
601 |
+
]
|
602 |
+
T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
|
603 |
+
chunk_strategy = options.get("chunk_strategy", "nearest")
|
604 |
+
num_chunks_1 = len(
|
605 |
+
chunk_input_and_test(
|
606 |
+
T_second_pass,
|
607 |
+
anchor_c2ws_second_pass,
|
608 |
+
target_c2ws,
|
609 |
+
anchor_indices,
|
610 |
+
target_indices,
|
611 |
+
options={**options, "sampler_verbose": False},
|
612 |
+
task=task,
|
613 |
+
chunk_strategy=chunk_strategy,
|
614 |
+
gt_input_inds=gt_input_inds,
|
615 |
+
)[1]
|
616 |
+
)
|
617 |
+
second_pass_pbar = gr.Progress().tqdm(
|
618 |
+
iterable=None,
|
619 |
+
desc="Second pass sampling",
|
620 |
+
total=num_chunks_1 * num_steps,
|
621 |
+
)
|
622 |
+
first_pass_pbar = gr.Progress().tqdm(
|
623 |
+
iterable=None,
|
624 |
+
desc="First pass sampling",
|
625 |
+
total=num_chunks_0 * num_steps,
|
626 |
+
)
|
627 |
+
video_path_generator = run_one_scene(
|
628 |
+
task=task,
|
629 |
+
version_dict={
|
630 |
+
"H": H,
|
631 |
+
"W": W,
|
632 |
+
"T": T,
|
633 |
+
"C": VERSION_DICT["C"],
|
634 |
+
"f": VERSION_DICT["f"],
|
635 |
+
"options": options,
|
636 |
+
},
|
637 |
+
model=MODEL,
|
638 |
+
ae=AE,
|
639 |
+
conditioner=CONDITIONER,
|
640 |
+
denoiser=DENOISER,
|
641 |
+
image_cond=image_cond,
|
642 |
+
camera_cond=camera_cond,
|
643 |
+
save_path=render_dir,
|
644 |
+
use_traj_prior=True,
|
645 |
+
traj_prior_c2ws=anchor_c2ws,
|
646 |
+
traj_prior_Ks=anchor_Ks,
|
647 |
+
seed=seed,
|
648 |
+
gradio=True,
|
649 |
+
first_pass_pbar=first_pass_pbar,
|
650 |
+
second_pass_pbar=second_pass_pbar,
|
651 |
+
abort_event=abort_event,
|
652 |
+
)
|
653 |
+
output_queue = queue.Queue()
|
654 |
+
|
655 |
+
blocks = LocalContext.blocks.get()
|
656 |
+
event_id = LocalContext.event_id.get()
|
657 |
+
|
658 |
+
def worker():
|
659 |
+
# gradio doesn't support threading with progress intentionally, so
|
660 |
+
# we need to hack this.
|
661 |
+
LocalContext.blocks.set(blocks)
|
662 |
+
LocalContext.event_id.set(event_id)
|
663 |
+
for i, video_path in enumerate(video_path_generator):
|
664 |
+
if i == 0:
|
665 |
+
output_queue.put(
|
666 |
+
(
|
667 |
+
video_path,
|
668 |
+
gr.update(),
|
669 |
+
gr.update(),
|
670 |
+
gr.update(),
|
671 |
+
)
|
672 |
+
)
|
673 |
+
elif i == 1:
|
674 |
+
output_queue.put(
|
675 |
+
(
|
676 |
+
video_path,
|
677 |
+
gr.update(visible=True),
|
678 |
+
gr.update(visible=False),
|
679 |
+
gr.update(visible=False),
|
680 |
+
)
|
681 |
+
)
|
682 |
+
else:
|
683 |
+
gr.Error("More than two passes during rendering.")
|
684 |
+
|
685 |
+
thread = threading.Thread(target=worker, daemon=True)
|
686 |
+
thread.start()
|
687 |
+
|
688 |
+
while thread.is_alive() or not output_queue.empty():
|
689 |
+
if abort_event.is_set():
|
690 |
+
thread.join()
|
691 |
+
abort_event.clear()
|
692 |
+
yield (
|
693 |
+
gr.update(),
|
694 |
+
gr.update(visible=True),
|
695 |
+
gr.update(visible=False),
|
696 |
+
gr.update(visible=False),
|
697 |
+
)
|
698 |
+
time.sleep(0.1)
|
699 |
+
while not output_queue.empty():
|
700 |
+
yield output_queue.get()
|
701 |
+
|
702 |
+
|
703 |
+
# This is basically a copy of the original `networking.setup_tunnel` function,
|
704 |
+
# but it also returns the tunnel object for proper cleanup.
|
705 |
+
def setup_tunnel(
|
706 |
+
local_host: str, local_port: int, share_token: str, share_server_address: str | None
|
707 |
+
) -> tuple[str, Tunnel]:
|
708 |
+
share_server_address = (
|
709 |
+
networking.GRADIO_SHARE_SERVER_ADDRESS
|
710 |
+
if share_server_address is None
|
711 |
+
else share_server_address
|
712 |
+
)
|
713 |
+
if share_server_address is None:
|
714 |
+
try:
|
715 |
+
response = httpx.get(networking.GRADIO_API_SERVER, timeout=30)
|
716 |
+
payload = response.json()[0]
|
717 |
+
remote_host, remote_port = payload["host"], int(payload["port"])
|
718 |
+
certificate = payload["root_ca"]
|
719 |
+
Path(CERTIFICATE_PATH).parent.mkdir(parents=True, exist_ok=True)
|
720 |
+
with open(CERTIFICATE_PATH, "w") as f:
|
721 |
+
f.write(certificate)
|
722 |
+
except Exception as e:
|
723 |
+
raise RuntimeError(
|
724 |
+
"Could not get share link from Gradio API Server."
|
725 |
+
) from e
|
726 |
+
else:
|
727 |
+
remote_host, remote_port = share_server_address.split(":")
|
728 |
+
remote_port = int(remote_port)
|
729 |
+
tunnel = Tunnel(remote_host, remote_port, local_host, local_port, share_token)
|
730 |
+
address = tunnel.start_tunnel()
|
731 |
+
return address, tunnel
|
732 |
+
|
733 |
+
|
734 |
+
def set_bkgd_color(server: viser.ViserServer | viser.ClientHandle):
|
735 |
+
server.scene.set_background_image(np.array([[[39, 39, 42]]], dtype=np.uint8))
|
736 |
+
|
737 |
+
|
738 |
+
def start_server_and_abort_event(request: gr.Request):
|
739 |
+
server = viser.ViserServer()
|
740 |
+
|
741 |
+
@server.on_client_connect
|
742 |
+
def _(client: viser.ClientHandle):
|
743 |
+
# Force dark mode that blends well with gradio's dark theme.
|
744 |
+
client.gui.configure_theme(
|
745 |
+
dark_mode=True,
|
746 |
+
show_share_button=False,
|
747 |
+
control_layout="collapsible",
|
748 |
+
)
|
749 |
+
set_bkgd_color(client)
|
750 |
+
|
751 |
+
print(f"Starting server {server.get_port()}")
|
752 |
+
server_url, tunnel = setup_tunnel(
|
753 |
+
local_host=server.get_host(),
|
754 |
+
local_port=server.get_port(),
|
755 |
+
share_token=secrets.token_urlsafe(32),
|
756 |
+
share_server_address=None,
|
757 |
+
)
|
758 |
+
SERVERS[request.session_hash] = (server, tunnel)
|
759 |
+
if server_url is None:
|
760 |
+
raise gr.Error(
|
761 |
+
"Failed to get a viewport URL. Please check your network connection."
|
762 |
+
)
|
763 |
+
# Give it enough time to start.
|
764 |
+
time.sleep(1)
|
765 |
+
|
766 |
+
ABORT_EVENTS[request.session_hash] = threading.Event()
|
767 |
+
|
768 |
+
return (
|
769 |
+
SevaRenderer(server),
|
770 |
+
gr.HTML(
|
771 |
+
f'<iframe src="{server_url}" style="display: block; margin: auto; width: 100%; height: min(60vh, 600px);" frameborder="0"></iframe>',
|
772 |
+
container=True,
|
773 |
+
),
|
774 |
+
request.session_hash,
|
775 |
+
)
|
776 |
+
|
777 |
+
|
778 |
+
def stop_server_and_abort_event(request: gr.Request):
|
779 |
+
if request.session_hash in SERVERS:
|
780 |
+
print(f"Stopping server {request.session_hash}")
|
781 |
+
server, tunnel = SERVERS.pop(request.session_hash)
|
782 |
+
server.stop()
|
783 |
+
tunnel.kill()
|
784 |
+
|
785 |
+
if request.session_hash in ABORT_EVENTS:
|
786 |
+
print(f"Setting abort event {request.session_hash}")
|
787 |
+
ABORT_EVENTS[request.session_hash].set()
|
788 |
+
# Give it enough time to abort jobs.
|
789 |
+
time.sleep(5)
|
790 |
+
ABORT_EVENTS.pop(request.session_hash)
|
791 |
+
|
792 |
+
|
793 |
+
def set_abort_event(request: gr.Request):
|
794 |
+
if request.session_hash in ABORT_EVENTS:
|
795 |
+
print(f"Setting abort event {request.session_hash}")
|
796 |
+
ABORT_EVENTS[request.session_hash].set()
|
797 |
+
|
798 |
+
|
799 |
+
def get_advance_examples(selection: gr.SelectData):
|
800 |
+
index = selection.index
|
801 |
+
return (
|
802 |
+
gr.Gallery(ADVANCE_EXAMPLE_MAP[index][1], visible=True),
|
803 |
+
gr.update(visible=True),
|
804 |
+
gr.update(visible=True),
|
805 |
+
gr.Gallery(visible=False),
|
806 |
+
)
|
807 |
+
|
808 |
+
|
809 |
+
def get_preamble():
|
810 |
+
gr.Markdown("""
|
811 |
+
# Stable Virtual Camera
|
812 |
+
<span style="display: flex; flex-wrap: wrap; gap: 5px;">
|
813 |
+
<a href="https://stable-virtual-camera.github.io"><img src="https://img.shields.io/badge/%F0%9F%8F%A0%20Project%20Page-gray.svg"></a>
|
814 |
+
<a href="http://arxiv.org/abs/2503.14489"><img src="https://img.shields.io/badge/%F0%9F%93%84%20arXiv-2503.14489-B31B1B.svg"></a>
|
815 |
+
<a href="https://stability.ai/news/introducing-stable-virtual-camera-multi-view-video-generation-with-3d-camera-control"><img src="https://img.shields.io/badge/%F0%9F%93%83%20Blog-Stability%20AI-orange.svg"></a>
|
816 |
+
<a href="https://huggingface.co/stabilityai/stable-virtual-camera"><img src="https://img.shields.io/badge/%F0%9F%A4%97%20Model_Card-Huggingface-orange"></a>
|
817 |
+
<a href="https://huggingface.co/spaces/stabilityai/stable-virtual-camera"><img src="https://img.shields.io/badge/%F0%9F%9A%80%20Gradio%20Demo-Huggingface-orange"></a>
|
818 |
+
<a href="https://www.youtube.com/channel/UCLLlVDcS7nNenT_zzO3OPxQ"><img src="https://img.shields.io/badge/%F0%9F%8E%AC%20Video-YouTube-orange"></a>
|
819 |
+
</span>
|
820 |
+
|
821 |
+
Welcome to the demo of <strong>Stable Virtual Camera (Seva)</strong>! Given any number of input views and their cameras, this demo will allow you to generate novel views of a scene at any target camera of interest.
|
822 |
+
|
823 |
+
We provide two ways to use our demo (selected by the tab below, documented [here](https://github.com/Stability-AI/stable-virtual-camera/blob/main/docs/GR_USAGE.md)):
|
824 |
+
1. **[Basic](https://github.com/user-attachments/assets/4d965fa6-d8eb-452c-b773-6e09c88ca705)**: Given a single image, you can generate a video following one of our preset camera trajectories.
|
825 |
+
2. **[Advanced](https://github.com/user-attachments/assets/dcec1be0-bd10-441e-879c-d1c2b63091ba)**: Given any number of input images, you can generate a video following any camera trajectory of your choice by our key-frame-based interface.
|
826 |
+
|
827 |
+
> This is a research preview and comes with a few [limitations](https://stable-virtual-camera.github.io/#limitations):
|
828 |
+
> - Limited quality in certain subjects due to training data, including humans, animals, and dynamic textures.
|
829 |
+
> - Limited quality in some highly ambiguous scenes and camera trajectories, including extreme views and collision into objects.
|
830 |
+
""")
|
831 |
+
|
832 |
+
|
833 |
+
# Make sure that gradio uses dark theme.
|
834 |
+
_APP_JS = """
|
835 |
+
function refresh() {
|
836 |
+
const url = new URL(window.location);
|
837 |
+
if (url.searchParams.get('__theme') !== 'dark') {
|
838 |
+
url.searchParams.set('__theme', 'dark');
|
839 |
+
}
|
840 |
+
}
|
841 |
+
"""
|
842 |
+
|
843 |
+
|
844 |
+
def main(server_port: int | None = None, share: bool = True):
|
845 |
+
with gr.Blocks(js=_APP_JS) as app:
|
846 |
+
renderer = gr.State()
|
847 |
+
session_hash = gr.State()
|
848 |
+
_ = get_preamble()
|
849 |
+
with gr.Tabs():
|
850 |
+
with gr.Tab("Basic"):
|
851 |
+
render_btn = gr.Button("Render video", interactive=False, render=False)
|
852 |
+
with gr.Row():
|
853 |
+
with gr.Column():
|
854 |
+
with gr.Group():
|
855 |
+
# Initially disable the Preprocess Images button until an image is selected.
|
856 |
+
preprocess_btn = gr.Button("Preprocess images", interactive=False)
|
857 |
+
preprocess_progress = gr.Textbox(
|
858 |
+
label="",
|
859 |
+
visible=False,
|
860 |
+
interactive=False,
|
861 |
+
)
|
862 |
+
with gr.Group():
|
863 |
+
input_imgs = gr.Image(
|
864 |
+
type="filepath",
|
865 |
+
label="Input",
|
866 |
+
height=200,
|
867 |
+
)
|
868 |
+
_ = gr.Examples(
|
869 |
+
examples=sorted(glob("assets/basic/*")),
|
870 |
+
inputs=[input_imgs],
|
871 |
+
label="Example",
|
872 |
+
)
|
873 |
+
chunk_strategy = gr.Dropdown(
|
874 |
+
["interp", "interp-gt"],
|
875 |
+
label="Chunk strategy",
|
876 |
+
render=False,
|
877 |
+
)
|
878 |
+
preprocessed = gr.State()
|
879 |
+
# Enable the Preprocess Images button only if an image is selected.
|
880 |
+
input_imgs.change(
|
881 |
+
lambda img: gr.update(interactive=bool(img)),
|
882 |
+
inputs=input_imgs,
|
883 |
+
outputs=preprocess_btn,
|
884 |
+
)
|
885 |
+
preprocess_btn.click(
|
886 |
+
lambda r, *args: [
|
887 |
+
*r.preprocess(*args),
|
888 |
+
gr.update(interactive=True),
|
889 |
+
],
|
890 |
+
inputs=[renderer, input_imgs],
|
891 |
+
outputs=[
|
892 |
+
preprocessed,
|
893 |
+
preprocess_progress,
|
894 |
+
chunk_strategy,
|
895 |
+
render_btn,
|
896 |
+
],
|
897 |
+
show_progress_on=[preprocess_progress],
|
898 |
+
concurrency_limit=1,
|
899 |
+
concurrency_id="gpu_queue",
|
900 |
+
)
|
901 |
+
preprocess_btn.click(
|
902 |
+
lambda: gr.update(visible=True),
|
903 |
+
outputs=[preprocess_progress],
|
904 |
+
)
|
905 |
+
with gr.Row():
|
906 |
+
preset_traj = gr.Dropdown(
|
907 |
+
choices=[
|
908 |
+
"orbit",
|
909 |
+
"spiral",
|
910 |
+
"lemniscate",
|
911 |
+
"zoom-in",
|
912 |
+
"zoom-out",
|
913 |
+
"dolly zoom-in",
|
914 |
+
"dolly zoom-out",
|
915 |
+
"move-forward",
|
916 |
+
"move-backward",
|
917 |
+
"move-up",
|
918 |
+
"move-down",
|
919 |
+
"move-left",
|
920 |
+
"move-right",
|
921 |
+
],
|
922 |
+
label="Preset trajectory",
|
923 |
+
value="orbit",
|
924 |
+
)
|
925 |
+
num_frames = gr.Slider(30, 150, 80, label="#Frames")
|
926 |
+
zoom_factor = gr.Slider(
|
927 |
+
step=0.01, label="Zoom factor", visible=False
|
928 |
+
)
|
929 |
+
with gr.Row():
|
930 |
+
seed = gr.Number(value=23, label="Random seed")
|
931 |
+
chunk_strategy.render()
|
932 |
+
cfg = gr.Slider(1.0, 7.0, value=4.0, label="CFG value")
|
933 |
+
with gr.Row():
|
934 |
+
camera_scale = gr.Slider(
|
935 |
+
0.1,
|
936 |
+
15.0,
|
937 |
+
value=2.0,
|
938 |
+
label="Camera scale",
|
939 |
+
)
|
940 |
+
|
941 |
+
def default_cfg_preset_traj(traj):
|
942 |
+
# These are just some hand-tuned values that we
|
943 |
+
# found work the best.
|
944 |
+
if traj in ["zoom-out", "move-down"]:
|
945 |
+
value = 5.0
|
946 |
+
elif traj in [
|
947 |
+
"orbit",
|
948 |
+
"dolly zoom-out",
|
949 |
+
"move-backward",
|
950 |
+
"move-up",
|
951 |
+
"move-left",
|
952 |
+
"move-right",
|
953 |
+
]:
|
954 |
+
value = 4.0
|
955 |
+
else:
|
956 |
+
value = 3.0
|
957 |
+
return value
|
958 |
+
|
959 |
+
preset_traj.change(
|
960 |
+
default_cfg_preset_traj,
|
961 |
+
inputs=[preset_traj],
|
962 |
+
outputs=[cfg],
|
963 |
+
)
|
964 |
+
preset_traj.change(
|
965 |
+
lambda traj: gr.update(
|
966 |
+
value=(
|
967 |
+
10.0 if "dolly" in traj or "pan" in traj else 2.0
|
968 |
+
)
|
969 |
+
),
|
970 |
+
inputs=[preset_traj],
|
971 |
+
outputs=[camera_scale],
|
972 |
+
)
|
973 |
+
|
974 |
+
def zoom_factor_preset_traj(traj):
|
975 |
+
visible = traj in [
|
976 |
+
"zoom-in",
|
977 |
+
"zoom-out",
|
978 |
+
"dolly zoom-in",
|
979 |
+
"dolly zoom-out",
|
980 |
+
]
|
981 |
+
is_zoomin = traj.endswith("zoom-in")
|
982 |
+
if is_zoomin:
|
983 |
+
minimum = 0.1
|
984 |
+
maximum = 0.5
|
985 |
+
value = 0.28
|
986 |
+
else:
|
987 |
+
minimum = 1.2
|
988 |
+
maximum = 3
|
989 |
+
value = 1.5
|
990 |
+
return gr.update(
|
991 |
+
visible=visible,
|
992 |
+
minimum=minimum,
|
993 |
+
maximum=maximum,
|
994 |
+
value=value,
|
995 |
+
)
|
996 |
+
|
997 |
+
preset_traj.change(
|
998 |
+
zoom_factor_preset_traj,
|
999 |
+
inputs=[preset_traj],
|
1000 |
+
outputs=[zoom_factor],
|
1001 |
+
)
|
1002 |
+
with gr.Column():
|
1003 |
+
with gr.Group():
|
1004 |
+
abort_btn = gr.Button("Abort rendering", visible=False)
|
1005 |
+
render_btn.render()
|
1006 |
+
render_progress = gr.Textbox(
|
1007 |
+
label="", visible=False, interactive=False
|
1008 |
+
)
|
1009 |
+
output_video = gr.Video(
|
1010 |
+
label="Output", interactive=False, autoplay=True, loop=True
|
1011 |
+
)
|
1012 |
+
render_btn.click(
|
1013 |
+
lambda r, *args: (yield from r.render(*args)),
|
1014 |
+
inputs=[
|
1015 |
+
renderer,
|
1016 |
+
preprocessed,
|
1017 |
+
session_hash,
|
1018 |
+
seed,
|
1019 |
+
chunk_strategy,
|
1020 |
+
cfg,
|
1021 |
+
preset_traj,
|
1022 |
+
num_frames,
|
1023 |
+
zoom_factor,
|
1024 |
+
camera_scale,
|
1025 |
+
],
|
1026 |
+
outputs=[
|
1027 |
+
output_video,
|
1028 |
+
render_btn,
|
1029 |
+
abort_btn,
|
1030 |
+
render_progress,
|
1031 |
+
],
|
1032 |
+
show_progress_on=[render_progress],
|
1033 |
+
concurrency_id="gpu_queue",
|
1034 |
+
)
|
1035 |
+
render_btn.click(
|
1036 |
+
lambda: [
|
1037 |
+
gr.update(visible=False),
|
1038 |
+
gr.update(visible=True),
|
1039 |
+
gr.update(visible=True),
|
1040 |
+
],
|
1041 |
+
outputs=[render_btn, abort_btn, render_progress],
|
1042 |
+
)
|
1043 |
+
abort_btn.click(set_abort_event)
|
1044 |
+
with gr.Tab("Advanced"):
|
1045 |
+
render_btn = gr.Button("Render video", interactive=False, render=False)
|
1046 |
+
viewport = gr.HTML(container=True, render=False)
|
1047 |
+
gr.Timer(0.1).tick(
|
1048 |
+
lambda renderer: gr.update(
|
1049 |
+
interactive=renderer is not None
|
1050 |
+
and renderer.gui_state is not None
|
1051 |
+
and renderer.gui_state.camera_traj_list is not None
|
1052 |
+
),
|
1053 |
+
inputs=[renderer],
|
1054 |
+
outputs=[render_btn],
|
1055 |
+
)
|
1056 |
+
with gr.Row():
|
1057 |
+
viewport.render()
|
1058 |
+
with gr.Row():
|
1059 |
+
with gr.Column():
|
1060 |
+
with gr.Group():
|
1061 |
+
# Initially disable the Preprocess Images button until images are selected.
|
1062 |
+
preprocess_btn = gr.Button("Preprocess images", interactive=False)
|
1063 |
+
preprocess_progress = gr.Textbox(
|
1064 |
+
label="",
|
1065 |
+
visible=False,
|
1066 |
+
interactive=False,
|
1067 |
+
)
|
1068 |
+
with gr.Group():
|
1069 |
+
input_imgs = gr.Gallery(
|
1070 |
+
interactive=True,
|
1071 |
+
label="Input",
|
1072 |
+
columns=4,
|
1073 |
+
height=200,
|
1074 |
+
)
|
1075 |
+
# Define example images (gradio doesn't support variable length
|
1076 |
+
# examples so we need to hack it).
|
1077 |
+
example_imgs = gr.Gallery(
|
1078 |
+
[e[0] for e in ADVANCE_EXAMPLE_MAP],
|
1079 |
+
allow_preview=False,
|
1080 |
+
preview=False,
|
1081 |
+
label="Example",
|
1082 |
+
columns=20,
|
1083 |
+
rows=1,
|
1084 |
+
height=115,
|
1085 |
+
)
|
1086 |
+
example_imgs_expander = gr.Gallery(
|
1087 |
+
visible=False,
|
1088 |
+
interactive=False,
|
1089 |
+
label="Example",
|
1090 |
+
preview=True,
|
1091 |
+
columns=20,
|
1092 |
+
rows=1,
|
1093 |
+
)
|
1094 |
+
chunk_strategy = gr.Dropdown(
|
1095 |
+
["interp-gt", "interp"],
|
1096 |
+
label="Chunk strategy",
|
1097 |
+
value="interp-gt",
|
1098 |
+
render=False,
|
1099 |
+
)
|
1100 |
+
with gr.Row():
|
1101 |
+
example_imgs_backer = gr.Button(
|
1102 |
+
"Go back", visible=False
|
1103 |
+
)
|
1104 |
+
example_imgs_confirmer = gr.Button(
|
1105 |
+
"Confirm", visible=False
|
1106 |
+
)
|
1107 |
+
example_imgs.select(
|
1108 |
+
get_advance_examples,
|
1109 |
+
outputs=[
|
1110 |
+
example_imgs_expander,
|
1111 |
+
example_imgs_confirmer,
|
1112 |
+
example_imgs_backer,
|
1113 |
+
example_imgs,
|
1114 |
+
],
|
1115 |
+
)
|
1116 |
+
example_imgs_confirmer.click(
|
1117 |
+
lambda x: (
|
1118 |
+
x,
|
1119 |
+
gr.update(visible=False),
|
1120 |
+
gr.update(visible=False),
|
1121 |
+
gr.update(visible=False),
|
1122 |
+
gr.update(visible=True),
|
1123 |
+
gr.update(interactive=bool(x))
|
1124 |
+
),
|
1125 |
+
inputs=[example_imgs_expander],
|
1126 |
+
outputs=[
|
1127 |
+
input_imgs,
|
1128 |
+
example_imgs_expander,
|
1129 |
+
example_imgs_confirmer,
|
1130 |
+
example_imgs_backer,
|
1131 |
+
example_imgs,
|
1132 |
+
preprocess_btn
|
1133 |
+
],
|
1134 |
+
)
|
1135 |
+
example_imgs_backer.click(
|
1136 |
+
lambda: (
|
1137 |
+
gr.update(visible=False),
|
1138 |
+
gr.update(visible=False),
|
1139 |
+
gr.update(visible=False),
|
1140 |
+
gr.update(visible=True),
|
1141 |
+
),
|
1142 |
+
outputs=[
|
1143 |
+
example_imgs_expander,
|
1144 |
+
example_imgs_confirmer,
|
1145 |
+
example_imgs_backer,
|
1146 |
+
example_imgs,
|
1147 |
+
],
|
1148 |
+
)
|
1149 |
+
preprocessed = gr.State()
|
1150 |
+
preprocess_btn.click(
|
1151 |
+
lambda r, *args: r.preprocess(*args),
|
1152 |
+
inputs=[renderer, input_imgs],
|
1153 |
+
outputs=[
|
1154 |
+
preprocessed,
|
1155 |
+
preprocess_progress,
|
1156 |
+
chunk_strategy,
|
1157 |
+
],
|
1158 |
+
show_progress_on=[preprocess_progress],
|
1159 |
+
concurrency_id="gpu_queue",
|
1160 |
+
)
|
1161 |
+
preprocess_btn.click(
|
1162 |
+
lambda: gr.update(visible=True),
|
1163 |
+
outputs=[preprocess_progress],
|
1164 |
+
)
|
1165 |
+
preprocessed.change(
|
1166 |
+
lambda r, *args: r.visualize_scene(*args),
|
1167 |
+
inputs=[renderer, preprocessed],
|
1168 |
+
)
|
1169 |
+
with gr.Row():
|
1170 |
+
seed = gr.Number(value=23, label="Random seed")
|
1171 |
+
chunk_strategy.render()
|
1172 |
+
cfg = gr.Slider(1.0, 7.0, value=3.0, label="CFG value")
|
1173 |
+
with gr.Row():
|
1174 |
+
camera_scale = gr.Slider(
|
1175 |
+
0.1,
|
1176 |
+
15.0,
|
1177 |
+
value=2.0,
|
1178 |
+
label="Camera scale (useful for single-view input)",
|
1179 |
+
)
|
1180 |
+
with gr.Group():
|
1181 |
+
output_data_dir = gr.Textbox(label="Output data directory")
|
1182 |
+
output_data_btn = gr.Button("Export output data")
|
1183 |
+
output_data_btn.click(
|
1184 |
+
lambda r, *args: r.export_output_data(*args),
|
1185 |
+
inputs=[renderer, preprocessed, output_data_dir],
|
1186 |
+
)
|
1187 |
+
with gr.Column():
|
1188 |
+
with gr.Group():
|
1189 |
+
abort_btn = gr.Button("Abort rendering", visible=False)
|
1190 |
+
render_btn.render()
|
1191 |
+
render_progress = gr.Textbox(
|
1192 |
+
label="", visible=False, interactive=False
|
1193 |
+
)
|
1194 |
+
output_video = gr.Video(
|
1195 |
+
label="Output", interactive=False, autoplay=True, loop=True
|
1196 |
+
)
|
1197 |
+
render_btn.click(
|
1198 |
+
lambda r, *args: (yield from r.render(*args)),
|
1199 |
+
inputs=[
|
1200 |
+
renderer,
|
1201 |
+
preprocessed,
|
1202 |
+
session_hash,
|
1203 |
+
seed,
|
1204 |
+
chunk_strategy,
|
1205 |
+
cfg,
|
1206 |
+
gr.State(),
|
1207 |
+
gr.State(),
|
1208 |
+
gr.State(),
|
1209 |
+
camera_scale,
|
1210 |
+
],
|
1211 |
+
outputs=[
|
1212 |
+
output_video,
|
1213 |
+
render_btn,
|
1214 |
+
abort_btn,
|
1215 |
+
render_progress,
|
1216 |
+
],
|
1217 |
+
show_progress_on=[render_progress],
|
1218 |
+
concurrency_id="gpu_queue",
|
1219 |
+
)
|
1220 |
+
render_btn.click(
|
1221 |
+
lambda: [
|
1222 |
+
gr.update(visible=False),
|
1223 |
+
gr.update(visible=True),
|
1224 |
+
gr.update(visible=True),
|
1225 |
+
],
|
1226 |
+
outputs=[render_btn, abort_btn, render_progress],
|
1227 |
+
)
|
1228 |
+
abort_btn.click(set_abort_event)
|
1229 |
+
|
1230 |
+
# Register the session initialization and cleanup functions.
|
1231 |
+
app.load(
|
1232 |
+
start_server_and_abort_event,
|
1233 |
+
outputs=[renderer, viewport, session_hash],
|
1234 |
+
)
|
1235 |
+
app.unload(stop_server_and_abort_event)
|
1236 |
+
|
1237 |
+
app.queue(max_size=5).launch(
|
1238 |
+
share=share,
|
1239 |
+
server_port=server_port,
|
1240 |
+
show_error=True,
|
1241 |
+
allowed_paths=[WORK_DIR],
|
1242 |
+
# Badget rendering will be broken otherwise.
|
1243 |
+
ssr_mode=False,
|
1244 |
+
)
|
1245 |
+
|
1246 |
+
|
1247 |
+
if __name__ == "__main__":
|
1248 |
+
tyro.cli(main)
|
docs/CLI_USAGE.md
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# :computer: CLI Demo
|
2 |
+
|
3 |
+
This cli demo allows you to pass in more options and control the model in a fine-grained way, suitable for power users and academic researchers. An examplar command line looks as simple as
|
4 |
+
|
5 |
+
```bash
|
6 |
+
python demo.py --data_path <data_path> [additional arguments]
|
7 |
+
```
|
8 |
+
|
9 |
+
We discuss here first some key attributes:
|
10 |
+
|
11 |
+
- `Procedural Two-Pass Sampling`: We recommend enabling procedural sampling by setting `--use_traj_prior True --chunk_strategy <chunk_strategy>` with `<chunk_strategy>` set according to the type of the task.
|
12 |
+
- `Resolution and Aspect-Ratio`: Default image preprocessing include center cropping. All input and output are square images of size $576\times 576$. To overwrite, the code support to pass in `--W <W> --H <H>` directly. We recommend passing in `--L_short 576` such that the aspect-ratio of original image is kept while the shortest side will be resized to $576$.
|
13 |
+
|
14 |
+
## Task
|
15 |
+
|
16 |
+
Before diving into the command lines, we introduce `Task` (specified by `--task <task>`) to bucket different usage cases depending on the data constraints in input and output domains (e.g., if the ordering is available).
|
17 |
+
|
18 |
+
| Task | Type of NVS | Format of `<data_path>` | Target Views Sorted? | Input and Target Views Sorted? | Recommended Usage |
|
19 |
+
| :------------------: | :------------: | :--------------------------------------: | :------------------: | :----------------------------: | :----------------------: |
|
20 |
+
| `img2img` | set NVS | folder (parsable by `ReconfusionParser`) | :x: | :x: | evaluation, benchmarking |
|
21 |
+
| `img2vid` | trajectory NVS | folder (parsable by `ReconfusionParser`) | :white_check_mark: | :white_check_mark: | evaluation, benchmarking |
|
22 |
+
| `img2trajvid_s-prob` | trajectory NVS | single image | :white_check_mark: | :white_check_mark: | general |
|
23 |
+
| `img2trajvid` | trajectory NVS | folder (parsable by `ReconfusionParser`) | :white_check_mark: | :x: | general |
|
24 |
+
|
25 |
+
### Format of `<data_path>`
|
26 |
+
|
27 |
+
For `img2trajvid_s-prob` task, we are generating a trajectory video following preset camera motions or effects given only one input image, the data format as simple as
|
28 |
+
|
29 |
+
```bash
|
30 |
+
<data_path>/
|
31 |
+
├── scene_1.png
|
32 |
+
├── scene_2.png
|
33 |
+
└── scene_3.png
|
34 |
+
```
|
35 |
+
|
36 |
+
For all the other tasks, we use a folder for each scene that is parsable by `ReconfusionParser` (see `seva/data_io.py`). It contains (1) a subdirectory containing all views; (2) `transforms.json` defining the intrinsics and extrinsics (OpenGL convention) for each image; and (3) `train_test_split_*.json` file splitting the input and target views, with `*` indicating the number of the input views.
|
37 |
+
|
38 |
+
We provide <a href="https://github.com/Stability-AI/stable-virtual-camera/releases/tag/assets_demo_cli">in this release</a> (`assets_demo_cli.zip`) several examplar scenes for you to take reference from. Target views is available if you the data are from academic sources, but in the case where target views is unavailble, we will create dummy black images as placeholders (e.g., the `garden_flythrough` scene). The general data structure follows
|
39 |
+
|
40 |
+
```bash
|
41 |
+
<data_path>/
|
42 |
+
├── scene_1/
|
43 |
+
├── train_test_split_1.json # for single-view regime
|
44 |
+
├── train_test_split_6.json # for sparse-veiw regime
|
45 |
+
├── train_test_split_32.json # for semi-dense-view regime
|
46 |
+
├── transforms.json
|
47 |
+
└── images/
|
48 |
+
├── image_0.png
|
49 |
+
├── image_1.png
|
50 |
+
├── ...
|
51 |
+
└── image_1000.png
|
52 |
+
├── scene_2
|
53 |
+
└── scene_3
|
54 |
+
```
|
55 |
+
|
56 |
+
You can specify which scene to run by passing in `--data_items scene_1,scene_2` to run, for example, `scene_1` and `scene_2`.
|
57 |
+
|
58 |
+
### Recommended Usage
|
59 |
+
|
60 |
+
- `img2img` and `img2vid` are recommended to be used for evaluation and benchmarking. These two tasks are used for the quantitative evalution in our <a href="http://arxiv.org/abs/2503.14489">paper</a>. The data is converted from academic datasets so the groundtruth target views are available for metric computation. Check the [`benchmark`](../benchmark/) folder for detailed splits we organize to benchmark different NVS models.
|
61 |
+
- `img2vid` requries both the input and target views to be sorted, which is usually not guaranteed in general usage.
|
62 |
+
- `img2trajvid_s-prob` is for general usage but only for single-view regime and fixed preset camera control.
|
63 |
+
- `img2trajvid` is the task designed for general usage since it does not need the ordering of the input views. This is the task used in the gradio demo.
|
64 |
+
|
65 |
+
Next we go over all tasks and provide for each task an examplar command line.
|
66 |
+
|
67 |
+
## `img2img`
|
68 |
+
|
69 |
+
```bash
|
70 |
+
python demo.py \
|
71 |
+
--data_path <data_path> \
|
72 |
+
--num_inputs <P> \
|
73 |
+
--video_save_fps 10
|
74 |
+
```
|
75 |
+
|
76 |
+
- `--num_inputs <P>` is only necessary if there are multiple `train_test_split_*.json` files in the scene folder.
|
77 |
+
- The above command works for the dataset without trajectory prior (e.g., DL3DV-140). When the trajectory prior is available given a benchmarking dataset, for example, `orbit` trajectory prior for the CO3D dataset, we use the `nearest-gt` chunking strategy by setting `--use_traj_prior True --traj_prior orbit --chunking_strategy nearest-gt`. We find this leads to more 3D consistent results.
|
78 |
+
- For all the single-view conditioning test scenarios: we set `--camera_scale <camera_scale>` with `<camera_scale>` sweeping 20 different camera scales `0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0 1.1 1.2 1.3 1.4 1.5 1.6 1.7 1.8 1.9 2.0`.
|
79 |
+
- In single-view regime for the RealEstate10K dataset, we find increasing `cfg` is helpful: we additionally set `--cfg 6.0` (`cfg` is `2.0` by default).
|
80 |
+
- For the evaluation in semi-dense-view regime (i.e., DL3DV-140 and Tanks and Temples dataset) with `32` input views, we zero-shot extend `T` to fit all input and target views in one forward. Specifically, we set `--T 90` for the DL3DV-140 dataset and `--T 80` for the Tanks and Temples dataset.
|
81 |
+
- For the evaluation on ViewCrafter split (including the RealEastate10K, CO3D, and Tanks and Temples dataset), we find zero-shot extending `T` to `25` to fit all input and target views in one forward is better. Also, the V split uses the original image resolutions: we therefore set `--T 25 --L_short 576`.
|
82 |
+
|
83 |
+
For example, you can run the following command on the example `dl3d140-165f5af8bfe32f70595a1c9393a6e442acf7af019998275144f605b89a306557` with 3 input views:
|
84 |
+
|
85 |
+
```bash
|
86 |
+
python demo.py \
|
87 |
+
--data_path /path/to/assets_demo_cli/ \
|
88 |
+
--data_items dl3d140-165f5af8bfe32f70595a1c9393a6e442acf7af019998275144f605b89a306557 \
|
89 |
+
--num_inputs 3 \
|
90 |
+
--video_save_fps 10
|
91 |
+
```
|
92 |
+
|
93 |
+
## `img2vid`
|
94 |
+
|
95 |
+
```bash
|
96 |
+
python demo.py \
|
97 |
+
--data_path <data_path> \
|
98 |
+
--task img2vid \
|
99 |
+
--replace_or_include_input True \
|
100 |
+
--num_inputs <P> \
|
101 |
+
--use_traj_prior True \
|
102 |
+
--chunk_strategy interp \
|
103 |
+
```
|
104 |
+
|
105 |
+
- `--replace_or_include_input True` is necessary here since input views and target views are mutually exclusive, forming a trajectory together in this task, so we need to append back the input views to the generated target views.
|
106 |
+
- `--num_inputs <P>` is only necessary if there are multiple `train_test_split_*.json` files in the scene folder.
|
107 |
+
- We use `interp` chunking strategy by default.
|
108 |
+
- For the evaluation on ViewCrafter split (including the RealEastate10K, CO3D, and Tanks and Temples dataset), we find zero-shot extending `T` to `25` to fit all input and target views in one forward is better. Also, the V split uses the original image resolutions: we therefore set `--T 25 --L_short 576`.
|
109 |
+
|
110 |
+
## `img2trajvid_s-prob`
|
111 |
+
|
112 |
+
```bash
|
113 |
+
python demo.py \
|
114 |
+
--data_path <data_path> \
|
115 |
+
--task img2trajvid_s-prob \
|
116 |
+
--replace_or_include_input True \
|
117 |
+
--traj_prior orbit \
|
118 |
+
--cfg 4.0,2.0 \
|
119 |
+
--guider 1,2 \
|
120 |
+
--num_targets 111 \
|
121 |
+
--L_short 576 \
|
122 |
+
--use_traj_prior True \
|
123 |
+
--chunk_strategy interp
|
124 |
+
```
|
125 |
+
|
126 |
+
- `--replace_or_include_input True` is necessary here since input views and target views are mutually exclusive, forming a trajectory together in this task, so we need to append back the input views to the generated target views.
|
127 |
+
- Default `cfg` should be adusted according to `traj_prior`.
|
128 |
+
- Default chunking strategy is `interp`.
|
129 |
+
- Default guider is `--guider 1,2` (instead of `1`, `1` still works but `1,2` is slightly better).
|
130 |
+
- `camera_scale` (default is `2.0`) can be adjusted according to `traj_prior`. The model has scale ambiguity with single-view input, especially for panning motions. We encourage to tune up `camera_scale` to `10.0` for all panning motions (`--traj_prior pan-*/dolly*`) if you expect a larger camera motion.
|
131 |
+
|
132 |
+
## `img2trajvid`
|
133 |
+
|
134 |
+
### Sparse-view regime ($P\leq 8$)
|
135 |
+
|
136 |
+
```bash
|
137 |
+
python demo.py \
|
138 |
+
--data_path <data_path> \
|
139 |
+
--task img2trajvid \
|
140 |
+
--num_inputs <P> \
|
141 |
+
--cfg 3.0,2.0 \
|
142 |
+
--use_traj_prior True \
|
143 |
+
--chunk_strategy interp-gt
|
144 |
+
```
|
145 |
+
|
146 |
+
- `--num_inputs <P>` is only necessary if there are multiple `train_test_split_*.json` files in the scene folder.
|
147 |
+
- Default `cfg` should be set to `3,2` (`3` being `cfg` for the first pass, and `2` being the `cfg` for the second pass). Try to increase the `cfg` for the first pass from `3` to higher values if you observe blurry areas (usually happens for harder scenes with a fair amount of unseen regions).
|
148 |
+
- Default chunking strategy should be set to `interp+gt` (instead of `interp`, `interp` can work but usually a bit worse).
|
149 |
+
- The `--chunk_strategy_first_pass` is set as `gt-nearest` by default. So it can automatically adapt when $P$ is large (up to a thousand frames).
|
150 |
+
|
151 |
+
### Semi-dense-view regime ($P>9$)
|
152 |
+
|
153 |
+
```bash
|
154 |
+
python demo.py \
|
155 |
+
--data_path <data_path> \
|
156 |
+
--task img2trajvid \
|
157 |
+
--num_inputs <P> \
|
158 |
+
--cfg 3.0 \
|
159 |
+
--L_short 576 \
|
160 |
+
--use_traj_prior True \
|
161 |
+
--chunk_strategy interp
|
162 |
+
```
|
163 |
+
|
164 |
+
- `--num_inputs <P>` is only necessary if there are multiple `train_test_split_*.json` files in the scene folder.
|
165 |
+
- Default `cfg` should be set to `3`.
|
166 |
+
- Default chunking strategy should be set to `interp` (instead of `interp-gt`, `interp-gt` is also supported but the results do not look good).
|
167 |
+
- `T` can be overwritten by `--T <N>,21` (X being extended `T` for the first pass, and `21` being the default `T` for the second pass). `<N>` is dynamically decided now in the code but can also be manually updated. This is useful when you observe that there exist two very dissimilar adjacent anchors which make the interpolation in the second pass impossible. There exist two ways:
|
168 |
+
- `--T 96,21`: this overwrites the `T` in the first pass to be exactly `96`.
|
169 |
+
- `--num_prior_frames_ratio 1.2`: this enlarges T in the first pass dynamically to be `1.2`$\times$ larger.
|
docs/GR_USAGE.md
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# :rocket: Gradio Demo
|
2 |
+
|
3 |
+
This gradio demo is the simplest starting point for you play with our project.
|
4 |
+
|
5 |
+
You can either visit it at our huggingface space [here](https://huggingface.co/spaces/stabilityai/stable-virtual-camera) or run it locally yourself by
|
6 |
+
|
7 |
+
```bash
|
8 |
+
python demo_gr.py
|
9 |
+
```
|
10 |
+
|
11 |
+
We provide two ways to use our demo:
|
12 |
+
|
13 |
+
1. `Basic` mode, where user can upload a single image, and set a target camera trajectory from our preset options. This is the most straightforward way to use our model, and is suitable for most users.
|
14 |
+
2. `Advanced` mode, where user can upload one or multiple images, and set a target camera trajectory by interacting with a 3D viewport (powered by [viser](https://viser.studio/latest)). This is suitable for power users and academic researchers.
|
15 |
+
|
16 |
+
### `Basic`
|
17 |
+
|
18 |
+
This is the default mode when entering our demo (given its simplicity).
|
19 |
+
|
20 |
+
User can upload a single image, and set a target camera trajectory from our preset options. This is the most straightforward way to use our model, and is suitable for most users.
|
21 |
+
|
22 |
+
Here is a video walkthrough:
|
23 |
+
|
24 |
+
https://github.com/user-attachments/assets/4d965fa6-d8eb-452c-b773-6e09c88ca705
|
25 |
+
|
26 |
+
You can choose from 13 preset trajectories that are common for NVS (`move-forward/backward` are omitted for visualization purpose):
|
27 |
+
|
28 |
+
https://github.com/user-attachments/assets/b2cf8700-3d85-44b9-8d52-248e82f1fb55
|
29 |
+
|
30 |
+
More formally:
|
31 |
+
|
32 |
+
- `orbit/spiral/lemniscate` are good for showing the "3D-ness" of the scene.
|
33 |
+
- `zoom-in/out` keep the camera position the same while increasing/decreasing the focal length.
|
34 |
+
- `dolly zoom-in/out` move camera position backward/forward while increasing/decreasing the focal length.
|
35 |
+
- `move-forward/backward/up/down/left/right` move camera position in different directions.
|
36 |
+
|
37 |
+
Notes:
|
38 |
+
|
39 |
+
- For a 80 frame video at `786x576` resolution, it takes around 20 seconds for the first pass generation, and around 2 minutes for the second pass generation, tested with a single H100 GPU.
|
40 |
+
- Please expect around ~2-3x more times on HF space.
|
41 |
+
|
42 |
+
### `Advanced`
|
43 |
+
|
44 |
+
This is the power mode where you can have very fine-grained control over camera trajectories.
|
45 |
+
|
46 |
+
User can upload one or multiple images, and set a target camera trajectory by interacting with a 3D viewport. This is suitable for power users and academic researchers.
|
47 |
+
|
48 |
+
Here is a video walkthrough
|
49 |
+
|
50 |
+
https://github.com/user-attachments/assets/dcec1be0-bd10-441e-879c-d1c2b63091ba
|
51 |
+
|
52 |
+
Notes:
|
53 |
+
|
54 |
+
- For a 134 frame video at `576x576` resolution, it takes around 16 seconds for the first pass generation, and around 4 minutes for the second pass generation, tested with a single H100 GPU.
|
55 |
+
- Please expect around ~2-3x more times on HF space.
|
56 |
+
|
57 |
+
### Pro tips
|
58 |
+
|
59 |
+
- If the first pass sampling result is bad, click "Abort rendering" button in GUI to avoid stucking at second pass sampling such that you can try something else.
|
60 |
+
|
61 |
+
### Performance benchmark
|
62 |
+
|
63 |
+
We have tested our gradio demo in both a local environment and the HF space environment, across different modes and compilation settings. Here are our results:
|
64 |
+
| Total time (s) | `Basic` first pass | `Basic` second pass | `Advanced` first pass | `Advanced` second pass |
|
65 |
+
|:------------------------:|:-----------------:|:------------------:|:--------------------:|:---------------------:|
|
66 |
+
| HF (L40S, w/o comp.) | 68 | 484 | 48 | 780 |
|
67 |
+
| HF (L40S, w/ comp.) | 51 | 362 | 36 | 587 |
|
68 |
+
| Local (H100, w/o comp.) | 35 | 204 | 20 | 313 |
|
69 |
+
| Local (H100, w/ comp.) | 21 | 144 | 16 | 234 |
|
70 |
+
|
71 |
+
Notes:
|
72 |
+
|
73 |
+
- HF space uses L40S GPU, and our local environment uses H100 GPU.
|
74 |
+
- We opt-in compilation by `torch.compile`.
|
75 |
+
- `Basic` mode is tested by generating 80 frames at `768x576` resolution.
|
76 |
+
- `Advanced` mode is tested by generating 134 frames at `576x576` resolution.
|
docs/INSTALL.md
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# :wrench: Installation
|
2 |
+
|
3 |
+
### Model Dependencies
|
4 |
+
|
5 |
+
```bash
|
6 |
+
# Install seva model dependencies.
|
7 |
+
pip install -e .
|
8 |
+
```
|
9 |
+
|
10 |
+
### Demo Dependencies
|
11 |
+
|
12 |
+
To use the cli demo (`demo.py`) or the gradio demo (`demo_gr.py`), do the following:
|
13 |
+
|
14 |
+
```bash
|
15 |
+
# Initialize and update submodules for demo.
|
16 |
+
git submodule update --init --recursive
|
17 |
+
|
18 |
+
# Install pycolmap dependencies for cli and gradio demo (our model is not dependent on it).
|
19 |
+
echo "Installing pycolmap (for both cli and gradio demo)..."
|
20 |
+
pip install git+https://github.com/jensenz-sai/pycolmap@543266bc316df2fe407b3a33d454b310b1641042
|
21 |
+
|
22 |
+
# Install dust3r dependencies for gradio demo (our model is not dependent on it).
|
23 |
+
echo "Installing dust3r dependencies (only for gradio demo)..."
|
24 |
+
pushd third_party/dust3r
|
25 |
+
pip install -r requirements.txt
|
26 |
+
popd
|
27 |
+
```
|
28 |
+
|
29 |
+
### Dev and Speeding Up (Optional)
|
30 |
+
|
31 |
+
```bash
|
32 |
+
# [OPTIONAL] Install seva dependencies for development.
|
33 |
+
pip install -e ".[dev]"
|
34 |
+
pre-commit install
|
35 |
+
|
36 |
+
# [OPTIONAL] Install the torch nightly version for faster JIT via. torch.compile (speed up sampling by 2x in our testing).
|
37 |
+
# Please adjust to your own cuda version. For example, if you have cuda 11.8, use the following command.
|
38 |
+
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu118
|
39 |
+
```
|
pyproject.toml
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[build-system]
|
2 |
+
requires = ["setuptools>=65.5.3"]
|
3 |
+
build-backend = "setuptools.build_meta"
|
4 |
+
|
5 |
+
[project]
|
6 |
+
name = "seva"
|
7 |
+
version = "0.0.0"
|
8 |
+
requires-python = ">=3.10"
|
9 |
+
dependencies = [
|
10 |
+
"torch>=2.6.0",
|
11 |
+
"roma",
|
12 |
+
"viser",
|
13 |
+
"tyro",
|
14 |
+
"fire",
|
15 |
+
"ninja",
|
16 |
+
"gradio==5.17.0",
|
17 |
+
"einops",
|
18 |
+
"colorama",
|
19 |
+
"splines",
|
20 |
+
"kornia",
|
21 |
+
"open-clip-torch",
|
22 |
+
"diffusers",
|
23 |
+
"numpy==1.24.4",
|
24 |
+
"imageio[ffmpeg]",
|
25 |
+
"huggingface-hub",
|
26 |
+
"opencv-python",
|
27 |
+
]
|
28 |
+
|
29 |
+
[project.optional-dependencies]
|
30 |
+
dev = ["ruff", "ipdb", "pytest", "line_profiler", "pre-commit"]
|
31 |
+
|
32 |
+
[tool.setuptools.packages.find]
|
33 |
+
include = ["seva"]
|
34 |
+
|
35 |
+
[tool.pyright]
|
36 |
+
extraPaths = ["third_party/dust3r"]
|
37 |
+
|
38 |
+
[tool.ruff]
|
39 |
+
lint.ignore = ["E741"]
|
seva/__init__.py
ADDED
File without changes
|
seva/data_io.py
ADDED
@@ -0,0 +1,553 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
from glob import glob
|
5 |
+
from typing import Any, Dict, List, Optional, Tuple
|
6 |
+
|
7 |
+
import cv2
|
8 |
+
import imageio.v3 as iio
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from seva.geometry import (
|
13 |
+
align_principle_axes,
|
14 |
+
similarity_from_cameras,
|
15 |
+
transform_cameras,
|
16 |
+
transform_points,
|
17 |
+
)
|
18 |
+
|
19 |
+
|
20 |
+
def _get_rel_paths(path_dir: str) -> List[str]:
|
21 |
+
"""Recursively get relative paths of files in a directory."""
|
22 |
+
paths = []
|
23 |
+
for dp, _, fn in os.walk(path_dir):
|
24 |
+
for f in fn:
|
25 |
+
paths.append(os.path.relpath(os.path.join(dp, f), path_dir))
|
26 |
+
return paths
|
27 |
+
|
28 |
+
|
29 |
+
class BaseParser(object):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
data_dir: str,
|
33 |
+
factor: int = 1,
|
34 |
+
normalize: bool = False,
|
35 |
+
test_every: Optional[int] = 8,
|
36 |
+
):
|
37 |
+
self.data_dir = data_dir
|
38 |
+
self.factor = factor
|
39 |
+
self.normalize = normalize
|
40 |
+
self.test_every = test_every
|
41 |
+
|
42 |
+
self.image_names: List[str] = [] # (num_images,)
|
43 |
+
self.image_paths: List[str] = [] # (num_images,)
|
44 |
+
self.camtoworlds: np.ndarray = np.zeros((0, 4, 4)) # (num_images, 4, 4)
|
45 |
+
self.camera_ids: List[int] = [] # (num_images,)
|
46 |
+
self.Ks_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> K
|
47 |
+
self.params_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> params
|
48 |
+
self.imsize_dict: Dict[
|
49 |
+
int, Tuple[int, int]
|
50 |
+
] = {} # Dict of camera_id -> (width, height)
|
51 |
+
self.points: np.ndarray = np.zeros((0, 3)) # (num_points, 3)
|
52 |
+
self.points_err: np.ndarray = np.zeros((0,)) # (num_points,)
|
53 |
+
self.points_rgb: np.ndarray = np.zeros((0, 3)) # (num_points, 3)
|
54 |
+
self.point_indices: Dict[str, np.ndarray] = {} # Dict of image_name -> (M,)
|
55 |
+
self.transform: np.ndarray = np.zeros((4, 4)) # (4, 4)
|
56 |
+
|
57 |
+
self.mapx_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> (H, W)
|
58 |
+
self.mapy_dict: Dict[int, np.ndarray] = {} # Dict of camera_id -> (H, W)
|
59 |
+
self.roi_undist_dict: Dict[int, Tuple[int, int, int, int]] = (
|
60 |
+
dict()
|
61 |
+
) # Dict of camera_id -> (x, y, w, h)
|
62 |
+
self.scene_scale: float = 1.0
|
63 |
+
|
64 |
+
|
65 |
+
class DirectParser(BaseParser):
|
66 |
+
def __init__(
|
67 |
+
self,
|
68 |
+
imgs: List[np.ndarray],
|
69 |
+
c2ws: np.ndarray,
|
70 |
+
Ks: np.ndarray,
|
71 |
+
points: Optional[np.ndarray] = None,
|
72 |
+
points_rgb: Optional[np.ndarray] = None, # uint8
|
73 |
+
mono_disps: Optional[List[np.ndarray]] = None,
|
74 |
+
normalize: bool = False,
|
75 |
+
test_every: Optional[int] = None,
|
76 |
+
):
|
77 |
+
super().__init__("", 1, normalize, test_every)
|
78 |
+
|
79 |
+
self.image_names = [f"{i:06d}" for i in range(len(imgs))]
|
80 |
+
self.image_paths = ["null" for _ in range(len(imgs))]
|
81 |
+
self.camtoworlds = c2ws
|
82 |
+
self.camera_ids = [i for i in range(len(imgs))]
|
83 |
+
self.Ks_dict = {i: K for i, K in enumerate(Ks)}
|
84 |
+
self.imsize_dict = {
|
85 |
+
i: (img.shape[1], img.shape[0]) for i, img in enumerate(imgs)
|
86 |
+
}
|
87 |
+
if points is not None:
|
88 |
+
self.points = points
|
89 |
+
assert points_rgb is not None
|
90 |
+
self.points_rgb = points_rgb
|
91 |
+
self.points_err = np.zeros((len(points),))
|
92 |
+
|
93 |
+
self.imgs = imgs
|
94 |
+
self.mono_disps = mono_disps
|
95 |
+
|
96 |
+
# Normalize the world space.
|
97 |
+
if normalize:
|
98 |
+
T1 = similarity_from_cameras(self.camtoworlds)
|
99 |
+
self.camtoworlds = transform_cameras(T1, self.camtoworlds)
|
100 |
+
|
101 |
+
if points is not None:
|
102 |
+
self.points = transform_points(T1, self.points)
|
103 |
+
T2 = align_principle_axes(self.points)
|
104 |
+
self.camtoworlds = transform_cameras(T2, self.camtoworlds)
|
105 |
+
self.points = transform_points(T2, self.points)
|
106 |
+
else:
|
107 |
+
T2 = np.eye(4)
|
108 |
+
|
109 |
+
self.transform = T2 @ T1
|
110 |
+
else:
|
111 |
+
self.transform = np.eye(4)
|
112 |
+
|
113 |
+
# size of the scene measured by cameras
|
114 |
+
camera_locations = self.camtoworlds[:, :3, 3]
|
115 |
+
scene_center = np.mean(camera_locations, axis=0)
|
116 |
+
dists = np.linalg.norm(camera_locations - scene_center, axis=1)
|
117 |
+
self.scene_scale = np.max(dists)
|
118 |
+
|
119 |
+
|
120 |
+
class COLMAPParser(BaseParser):
|
121 |
+
"""COLMAP parser."""
|
122 |
+
|
123 |
+
def __init__(
|
124 |
+
self,
|
125 |
+
data_dir: str,
|
126 |
+
factor: int = 1,
|
127 |
+
normalize: bool = False,
|
128 |
+
test_every: Optional[int] = 8,
|
129 |
+
image_folder: str = "images",
|
130 |
+
colmap_folder: str = "sparse/0",
|
131 |
+
):
|
132 |
+
super().__init__(data_dir, factor, normalize, test_every)
|
133 |
+
|
134 |
+
colmap_dir = os.path.join(data_dir, colmap_folder)
|
135 |
+
assert os.path.exists(
|
136 |
+
colmap_dir
|
137 |
+
), f"COLMAP directory {colmap_dir} does not exist."
|
138 |
+
|
139 |
+
try:
|
140 |
+
from pycolmap import SceneManager
|
141 |
+
except ImportError:
|
142 |
+
raise ImportError(
|
143 |
+
"Please install pycolmap to use the data parsers: "
|
144 |
+
" `pip install git+https://github.com/jensenz-sai/pycolmap.git@543266bc316df2fe407b3a33d454b310b1641042`"
|
145 |
+
)
|
146 |
+
|
147 |
+
manager = SceneManager(colmap_dir)
|
148 |
+
manager.load_cameras()
|
149 |
+
manager.load_images()
|
150 |
+
manager.load_points3D()
|
151 |
+
|
152 |
+
# Extract extrinsic matrices in world-to-camera format.
|
153 |
+
imdata = manager.images
|
154 |
+
w2c_mats = []
|
155 |
+
camera_ids = []
|
156 |
+
Ks_dict = dict()
|
157 |
+
params_dict = dict()
|
158 |
+
imsize_dict = dict() # width, height
|
159 |
+
bottom = np.array([0, 0, 0, 1]).reshape(1, 4)
|
160 |
+
for k in imdata:
|
161 |
+
im = imdata[k]
|
162 |
+
rot = im.R()
|
163 |
+
trans = im.tvec.reshape(3, 1)
|
164 |
+
w2c = np.concatenate([np.concatenate([rot, trans], 1), bottom], axis=0)
|
165 |
+
w2c_mats.append(w2c)
|
166 |
+
|
167 |
+
# support different camera intrinsics
|
168 |
+
camera_id = im.camera_id
|
169 |
+
camera_ids.append(camera_id)
|
170 |
+
|
171 |
+
# camera intrinsics
|
172 |
+
cam = manager.cameras[camera_id]
|
173 |
+
fx, fy, cx, cy = cam.fx, cam.fy, cam.cx, cam.cy
|
174 |
+
K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
|
175 |
+
K[:2, :] /= factor
|
176 |
+
Ks_dict[camera_id] = K
|
177 |
+
|
178 |
+
# Get distortion parameters.
|
179 |
+
type_ = cam.camera_type
|
180 |
+
if type_ == 0 or type_ == "SIMPLE_PINHOLE":
|
181 |
+
params = np.empty(0, dtype=np.float32)
|
182 |
+
camtype = "perspective"
|
183 |
+
elif type_ == 1 or type_ == "PINHOLE":
|
184 |
+
params = np.empty(0, dtype=np.float32)
|
185 |
+
camtype = "perspective"
|
186 |
+
if type_ == 2 or type_ == "SIMPLE_RADIAL":
|
187 |
+
params = np.array([cam.k1, 0.0, 0.0, 0.0], dtype=np.float32)
|
188 |
+
camtype = "perspective"
|
189 |
+
elif type_ == 3 or type_ == "RADIAL":
|
190 |
+
params = np.array([cam.k1, cam.k2, 0.0, 0.0], dtype=np.float32)
|
191 |
+
camtype = "perspective"
|
192 |
+
elif type_ == 4 or type_ == "OPENCV":
|
193 |
+
params = np.array([cam.k1, cam.k2, cam.p1, cam.p2], dtype=np.float32)
|
194 |
+
camtype = "perspective"
|
195 |
+
elif type_ == 5 or type_ == "OPENCV_FISHEYE":
|
196 |
+
params = np.array([cam.k1, cam.k2, cam.k3, cam.k4], dtype=np.float32)
|
197 |
+
camtype = "fisheye"
|
198 |
+
assert (
|
199 |
+
camtype == "perspective" # type: ignore
|
200 |
+
), f"Only support perspective camera model, got {type_}"
|
201 |
+
|
202 |
+
params_dict[camera_id] = params # type: ignore
|
203 |
+
|
204 |
+
# image size
|
205 |
+
imsize_dict[camera_id] = (cam.width // factor, cam.height // factor)
|
206 |
+
|
207 |
+
print(
|
208 |
+
f"[Parser] {len(imdata)} images, taken by {len(set(camera_ids))} cameras."
|
209 |
+
)
|
210 |
+
|
211 |
+
if len(imdata) == 0:
|
212 |
+
raise ValueError("No images found in COLMAP.")
|
213 |
+
if not (type_ == 0 or type_ == 1): # type: ignore
|
214 |
+
print("Warning: COLMAP Camera is not PINHOLE. Images have distortion.")
|
215 |
+
|
216 |
+
w2c_mats = np.stack(w2c_mats, axis=0)
|
217 |
+
|
218 |
+
# Convert extrinsics to camera-to-world.
|
219 |
+
camtoworlds = np.linalg.inv(w2c_mats)
|
220 |
+
|
221 |
+
# Image names from COLMAP. No need for permuting the poses according to
|
222 |
+
# image names anymore.
|
223 |
+
image_names = [imdata[k].name for k in imdata]
|
224 |
+
|
225 |
+
# Previous Nerf results were generated with images sorted by filename,
|
226 |
+
# ensure metrics are reported on the same test set.
|
227 |
+
inds = np.argsort(image_names)
|
228 |
+
image_names = [image_names[i] for i in inds]
|
229 |
+
camtoworlds = camtoworlds[inds]
|
230 |
+
camera_ids = [camera_ids[i] for i in inds]
|
231 |
+
|
232 |
+
# Load images.
|
233 |
+
if factor > 1:
|
234 |
+
image_dir_suffix = f"_{factor}"
|
235 |
+
else:
|
236 |
+
image_dir_suffix = ""
|
237 |
+
colmap_image_dir = os.path.join(data_dir, image_folder)
|
238 |
+
image_dir = os.path.join(data_dir, image_folder + image_dir_suffix)
|
239 |
+
for d in [image_dir, colmap_image_dir]:
|
240 |
+
if not os.path.exists(d):
|
241 |
+
raise ValueError(f"Image folder {d} does not exist.")
|
242 |
+
|
243 |
+
# Downsampled images may have different names vs images used for COLMAP,
|
244 |
+
# so we need to map between the two sorted lists of files.
|
245 |
+
colmap_files = sorted(_get_rel_paths(colmap_image_dir))
|
246 |
+
image_files = sorted(_get_rel_paths(image_dir))
|
247 |
+
colmap_to_image = dict(zip(colmap_files, image_files))
|
248 |
+
image_paths = [os.path.join(image_dir, colmap_to_image[f]) for f in image_names]
|
249 |
+
|
250 |
+
# 3D points and {image_name -> [point_idx]}
|
251 |
+
points = manager.points3D.astype(np.float32) # type: ignore
|
252 |
+
points_err = manager.point3D_errors.astype(np.float32) # type: ignore
|
253 |
+
points_rgb = manager.point3D_colors.astype(np.uint8) # type: ignore
|
254 |
+
point_indices = dict()
|
255 |
+
|
256 |
+
image_id_to_name = {v: k for k, v in manager.name_to_image_id.items()}
|
257 |
+
for point_id, data in manager.point3D_id_to_images.items():
|
258 |
+
for image_id, _ in data:
|
259 |
+
image_name = image_id_to_name[image_id]
|
260 |
+
point_idx = manager.point3D_id_to_point3D_idx[point_id]
|
261 |
+
point_indices.setdefault(image_name, []).append(point_idx)
|
262 |
+
point_indices = {
|
263 |
+
k: np.array(v).astype(np.int32) for k, v in point_indices.items()
|
264 |
+
}
|
265 |
+
|
266 |
+
# Normalize the world space.
|
267 |
+
if normalize:
|
268 |
+
T1 = similarity_from_cameras(camtoworlds)
|
269 |
+
camtoworlds = transform_cameras(T1, camtoworlds)
|
270 |
+
points = transform_points(T1, points)
|
271 |
+
|
272 |
+
T2 = align_principle_axes(points)
|
273 |
+
camtoworlds = transform_cameras(T2, camtoworlds)
|
274 |
+
points = transform_points(T2, points)
|
275 |
+
|
276 |
+
transform = T2 @ T1
|
277 |
+
else:
|
278 |
+
transform = np.eye(4)
|
279 |
+
|
280 |
+
self.image_names = image_names # List[str], (num_images,)
|
281 |
+
self.image_paths = image_paths # List[str], (num_images,)
|
282 |
+
self.camtoworlds = camtoworlds # np.ndarray, (num_images, 4, 4)
|
283 |
+
self.camera_ids = camera_ids # List[int], (num_images,)
|
284 |
+
self.Ks_dict = Ks_dict # Dict of camera_id -> K
|
285 |
+
self.params_dict = params_dict # Dict of camera_id -> params
|
286 |
+
self.imsize_dict = imsize_dict # Dict of camera_id -> (width, height)
|
287 |
+
self.points = points # np.ndarray, (num_points, 3)
|
288 |
+
self.points_err = points_err # np.ndarray, (num_points,)
|
289 |
+
self.points_rgb = points_rgb # np.ndarray, (num_points, 3)
|
290 |
+
self.point_indices = point_indices # Dict[str, np.ndarray], image_name -> [M,]
|
291 |
+
self.transform = transform # np.ndarray, (4, 4)
|
292 |
+
|
293 |
+
# undistortion
|
294 |
+
self.mapx_dict = dict()
|
295 |
+
self.mapy_dict = dict()
|
296 |
+
self.roi_undist_dict = dict()
|
297 |
+
for camera_id in self.params_dict.keys():
|
298 |
+
params = self.params_dict[camera_id]
|
299 |
+
if len(params) == 0:
|
300 |
+
continue # no distortion
|
301 |
+
assert camera_id in self.Ks_dict, f"Missing K for camera {camera_id}"
|
302 |
+
assert (
|
303 |
+
camera_id in self.params_dict
|
304 |
+
), f"Missing params for camera {camera_id}"
|
305 |
+
K = self.Ks_dict[camera_id]
|
306 |
+
width, height = self.imsize_dict[camera_id]
|
307 |
+
K_undist, roi_undist = cv2.getOptimalNewCameraMatrix(
|
308 |
+
K, params, (width, height), 0
|
309 |
+
)
|
310 |
+
mapx, mapy = cv2.initUndistortRectifyMap(
|
311 |
+
K,
|
312 |
+
params,
|
313 |
+
None,
|
314 |
+
K_undist,
|
315 |
+
(width, height),
|
316 |
+
cv2.CV_32FC1, # type: ignore
|
317 |
+
)
|
318 |
+
self.Ks_dict[camera_id] = K_undist
|
319 |
+
self.mapx_dict[camera_id] = mapx
|
320 |
+
self.mapy_dict[camera_id] = mapy
|
321 |
+
self.roi_undist_dict[camera_id] = roi_undist # type: ignore
|
322 |
+
|
323 |
+
# size of the scene measured by cameras
|
324 |
+
camera_locations = camtoworlds[:, :3, 3]
|
325 |
+
scene_center = np.mean(camera_locations, axis=0)
|
326 |
+
dists = np.linalg.norm(camera_locations - scene_center, axis=1)
|
327 |
+
self.scene_scale = np.max(dists)
|
328 |
+
|
329 |
+
|
330 |
+
class ReconfusionParser(BaseParser):
|
331 |
+
def __init__(self, data_dir: str, normalize: bool = False):
|
332 |
+
super().__init__(data_dir, 1, normalize, test_every=None)
|
333 |
+
|
334 |
+
def get_num(p):
|
335 |
+
return p.split("_")[-1].removesuffix(".json")
|
336 |
+
|
337 |
+
splits_per_num_input_frames = {}
|
338 |
+
num_input_frames = [
|
339 |
+
int(get_num(p)) if get_num(p).isdigit() else get_num(p)
|
340 |
+
for p in sorted(glob(osp.join(data_dir, "train_test_split_*.json")))
|
341 |
+
]
|
342 |
+
for num_input_frames in num_input_frames:
|
343 |
+
with open(
|
344 |
+
osp.join(
|
345 |
+
data_dir,
|
346 |
+
f"train_test_split_{num_input_frames}.json",
|
347 |
+
)
|
348 |
+
) as f:
|
349 |
+
splits_per_num_input_frames[num_input_frames] = json.load(f)
|
350 |
+
self.splits_per_num_input_frames = splits_per_num_input_frames
|
351 |
+
|
352 |
+
with open(osp.join(data_dir, "transforms.json")) as f:
|
353 |
+
metadata = json.load(f)
|
354 |
+
|
355 |
+
image_names, image_paths, camtoworlds = [], [], []
|
356 |
+
for frame in metadata["frames"]:
|
357 |
+
if frame["file_path"] is None:
|
358 |
+
image_path = image_name = None
|
359 |
+
else:
|
360 |
+
image_path = osp.join(data_dir, frame["file_path"])
|
361 |
+
image_name = osp.basename(image_path)
|
362 |
+
image_paths.append(image_path)
|
363 |
+
image_names.append(image_name)
|
364 |
+
camtoworld = np.array(frame["transform_matrix"])
|
365 |
+
if "applied_transform" in metadata:
|
366 |
+
applied_transform = np.concatenate(
|
367 |
+
[metadata["applied_transform"], [[0, 0, 0, 1]]], axis=0
|
368 |
+
)
|
369 |
+
camtoworld = applied_transform @ camtoworld
|
370 |
+
camtoworlds.append(camtoworld)
|
371 |
+
camtoworlds = np.array(camtoworlds)
|
372 |
+
camtoworlds[:, :, [1, 2]] *= -1
|
373 |
+
|
374 |
+
# Normalize the world space.
|
375 |
+
if normalize:
|
376 |
+
T1 = similarity_from_cameras(camtoworlds)
|
377 |
+
camtoworlds = transform_cameras(T1, camtoworlds)
|
378 |
+
self.transform = T1
|
379 |
+
else:
|
380 |
+
self.transform = np.eye(4)
|
381 |
+
|
382 |
+
self.image_names = image_names
|
383 |
+
self.image_paths = image_paths
|
384 |
+
self.camtoworlds = camtoworlds
|
385 |
+
self.camera_ids = list(range(len(image_paths)))
|
386 |
+
self.Ks_dict = {
|
387 |
+
i: np.array(
|
388 |
+
[
|
389 |
+
[
|
390 |
+
metadata.get("fl_x", frame.get("fl_x", None)),
|
391 |
+
0.0,
|
392 |
+
metadata.get("cx", frame.get("cx", None)),
|
393 |
+
],
|
394 |
+
[
|
395 |
+
0.0,
|
396 |
+
metadata.get("fl_y", frame.get("fl_y", None)),
|
397 |
+
metadata.get("cy", frame.get("cy", None)),
|
398 |
+
],
|
399 |
+
[0.0, 0.0, 1.0],
|
400 |
+
]
|
401 |
+
)
|
402 |
+
for i, frame in enumerate(metadata["frames"])
|
403 |
+
}
|
404 |
+
self.imsize_dict = {
|
405 |
+
i: (
|
406 |
+
metadata.get("w", frame.get("w", None)),
|
407 |
+
metadata.get("h", frame.get("h", None)),
|
408 |
+
)
|
409 |
+
for i, frame in enumerate(metadata["frames"])
|
410 |
+
}
|
411 |
+
# When num_input_frames is None, use all frames for both training and
|
412 |
+
# testing.
|
413 |
+
# self.splits_per_num_input_frames[None] = {
|
414 |
+
# "train_ids": list(range(len(image_paths))),
|
415 |
+
# "test_ids": list(range(len(image_paths))),
|
416 |
+
# }
|
417 |
+
|
418 |
+
# size of the scene measured by cameras
|
419 |
+
camera_locations = camtoworlds[:, :3, 3]
|
420 |
+
scene_center = np.mean(camera_locations, axis=0)
|
421 |
+
dists = np.linalg.norm(camera_locations - scene_center, axis=1)
|
422 |
+
self.scene_scale = np.max(dists)
|
423 |
+
|
424 |
+
self.bounds = None
|
425 |
+
if osp.exists(osp.join(data_dir, "bounds.npy")):
|
426 |
+
self.bounds = np.load(osp.join(data_dir, "bounds.npy"))
|
427 |
+
scaling = np.linalg.norm(self.transform[0, :3])
|
428 |
+
self.bounds = self.bounds / scaling
|
429 |
+
|
430 |
+
|
431 |
+
class Dataset(torch.utils.data.Dataset):
|
432 |
+
"""A simple dataset class."""
|
433 |
+
|
434 |
+
def __init__(
|
435 |
+
self,
|
436 |
+
parser: BaseParser,
|
437 |
+
split: str = "train",
|
438 |
+
num_input_frames: Optional[int] = None,
|
439 |
+
patch_size: Optional[int] = None,
|
440 |
+
load_depths: bool = False,
|
441 |
+
load_mono_disps: bool = False,
|
442 |
+
):
|
443 |
+
self.parser = parser
|
444 |
+
self.split = split
|
445 |
+
self.num_input_frames = num_input_frames
|
446 |
+
self.patch_size = patch_size
|
447 |
+
self.load_depths = load_depths
|
448 |
+
self.load_mono_disps = load_mono_disps
|
449 |
+
if load_mono_disps:
|
450 |
+
assert isinstance(parser, DirectParser)
|
451 |
+
assert parser.mono_disps is not None
|
452 |
+
if isinstance(parser, ReconfusionParser):
|
453 |
+
ids_per_split = parser.splits_per_num_input_frames[num_input_frames]
|
454 |
+
self.indices = ids_per_split[
|
455 |
+
"train_ids" if split == "train" else "test_ids"
|
456 |
+
]
|
457 |
+
else:
|
458 |
+
indices = np.arange(len(self.parser.image_names))
|
459 |
+
if split == "train":
|
460 |
+
self.indices = (
|
461 |
+
indices[indices % self.parser.test_every != 0]
|
462 |
+
if self.parser.test_every is not None
|
463 |
+
else indices
|
464 |
+
)
|
465 |
+
else:
|
466 |
+
self.indices = (
|
467 |
+
indices[indices % self.parser.test_every == 0]
|
468 |
+
if self.parser.test_every is not None
|
469 |
+
else indices
|
470 |
+
)
|
471 |
+
|
472 |
+
def __len__(self):
|
473 |
+
return len(self.indices)
|
474 |
+
|
475 |
+
def __getitem__(self, item: int) -> Dict[str, Any]:
|
476 |
+
index = self.indices[item]
|
477 |
+
if isinstance(self.parser, DirectParser):
|
478 |
+
image = self.parser.imgs[index]
|
479 |
+
else:
|
480 |
+
image = iio.imread(self.parser.image_paths[index])[..., :3]
|
481 |
+
camera_id = self.parser.camera_ids[index]
|
482 |
+
K = self.parser.Ks_dict[camera_id].copy() # undistorted K
|
483 |
+
params = self.parser.params_dict.get(camera_id, None)
|
484 |
+
camtoworlds = self.parser.camtoworlds[index]
|
485 |
+
|
486 |
+
x, y, w, h = 0, 0, image.shape[1], image.shape[0]
|
487 |
+
if params is not None and len(params) > 0:
|
488 |
+
# Images are distorted. Undistort them.
|
489 |
+
mapx, mapy = (
|
490 |
+
self.parser.mapx_dict[camera_id],
|
491 |
+
self.parser.mapy_dict[camera_id],
|
492 |
+
)
|
493 |
+
image = cv2.remap(image, mapx, mapy, cv2.INTER_LINEAR)
|
494 |
+
x, y, w, h = self.parser.roi_undist_dict[camera_id]
|
495 |
+
image = image[y : y + h, x : x + w]
|
496 |
+
|
497 |
+
if self.patch_size is not None:
|
498 |
+
# Random crop.
|
499 |
+
h, w = image.shape[:2]
|
500 |
+
x = np.random.randint(0, max(w - self.patch_size, 1))
|
501 |
+
y = np.random.randint(0, max(h - self.patch_size, 1))
|
502 |
+
image = image[y : y + self.patch_size, x : x + self.patch_size]
|
503 |
+
K[0, 2] -= x
|
504 |
+
K[1, 2] -= y
|
505 |
+
|
506 |
+
data = {
|
507 |
+
"K": torch.from_numpy(K).float(),
|
508 |
+
"camtoworld": torch.from_numpy(camtoworlds).float(),
|
509 |
+
"image": torch.from_numpy(image).float(),
|
510 |
+
"image_id": item, # the index of the image in the dataset
|
511 |
+
}
|
512 |
+
|
513 |
+
if self.load_depths:
|
514 |
+
# projected points to image plane to get depths
|
515 |
+
worldtocams = np.linalg.inv(camtoworlds)
|
516 |
+
image_name = self.parser.image_names[index]
|
517 |
+
point_indices = self.parser.point_indices[image_name]
|
518 |
+
points_world = self.parser.points[point_indices]
|
519 |
+
points_cam = (worldtocams[:3, :3] @ points_world.T + worldtocams[:3, 3:4]).T
|
520 |
+
points_proj = (K @ points_cam.T).T
|
521 |
+
points = points_proj[:, :2] / points_proj[:, 2:3] # (M, 2)
|
522 |
+
depths = points_cam[:, 2] # (M,)
|
523 |
+
if self.patch_size is not None:
|
524 |
+
points[:, 0] -= x
|
525 |
+
points[:, 1] -= y
|
526 |
+
# filter out points outside the image
|
527 |
+
selector = (
|
528 |
+
(points[:, 0] >= 0)
|
529 |
+
& (points[:, 0] < image.shape[1])
|
530 |
+
& (points[:, 1] >= 0)
|
531 |
+
& (points[:, 1] < image.shape[0])
|
532 |
+
& (depths > 0)
|
533 |
+
)
|
534 |
+
points = points[selector]
|
535 |
+
depths = depths[selector]
|
536 |
+
data["points"] = torch.from_numpy(points).float()
|
537 |
+
data["depths"] = torch.from_numpy(depths).float()
|
538 |
+
if self.load_mono_disps:
|
539 |
+
data["mono_disps"] = torch.from_numpy(self.parser.mono_disps[index]).float() # type: ignore
|
540 |
+
|
541 |
+
return data
|
542 |
+
|
543 |
+
|
544 |
+
def get_parser(parser_type: str, **kwargs) -> BaseParser:
|
545 |
+
if parser_type == "colmap":
|
546 |
+
parser = COLMAPParser(**kwargs)
|
547 |
+
elif parser_type == "direct":
|
548 |
+
parser = DirectParser(**kwargs)
|
549 |
+
elif parser_type == "reconfusion":
|
550 |
+
parser = ReconfusionParser(**kwargs)
|
551 |
+
else:
|
552 |
+
raise ValueError(f"Unknown parser type: {parser_type}")
|
553 |
+
return parser
|
seva/eval.py
ADDED
@@ -0,0 +1,1990 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import collections
|
2 |
+
import json
|
3 |
+
import math
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import threading
|
7 |
+
from typing import List, Literal, Optional, Tuple, Union
|
8 |
+
|
9 |
+
import gradio as gr
|
10 |
+
from colorama import Fore, Style, init
|
11 |
+
|
12 |
+
init(autoreset=True)
|
13 |
+
|
14 |
+
import imageio.v3 as iio
|
15 |
+
import numpy as np
|
16 |
+
import torch
|
17 |
+
import torch.nn.functional as F
|
18 |
+
import torchvision.transforms.functional as TF
|
19 |
+
from einops import repeat
|
20 |
+
from PIL import Image
|
21 |
+
from tqdm.auto import tqdm
|
22 |
+
|
23 |
+
from seva.geometry import get_camera_dist, get_plucker_coordinates, to_hom_pose
|
24 |
+
from seva.sampling import (
|
25 |
+
EulerEDMSampler,
|
26 |
+
MultiviewCFG,
|
27 |
+
MultiviewTemporalCFG,
|
28 |
+
VanillaCFG,
|
29 |
+
)
|
30 |
+
from seva.utils import seed_everything
|
31 |
+
|
32 |
+
try:
|
33 |
+
# Check if version string contains 'dev' or 'nightly'
|
34 |
+
version = torch.__version__
|
35 |
+
IS_TORCH_NIGHTLY = "dev" in version
|
36 |
+
if IS_TORCH_NIGHTLY:
|
37 |
+
torch._dynamo.config.cache_size_limit = 128 # type: ignore[assignment]
|
38 |
+
torch._dynamo.config.accumulated_cache_size_limit = 1024 # type: ignore[assignment]
|
39 |
+
torch._dynamo.config.force_parameter_static_shapes = False # type: ignore[assignment]
|
40 |
+
except Exception:
|
41 |
+
IS_TORCH_NIGHTLY = False
|
42 |
+
|
43 |
+
|
44 |
+
def pad_indices(
|
45 |
+
input_indices: List[int],
|
46 |
+
test_indices: List[int],
|
47 |
+
T: int,
|
48 |
+
padding_mode: Literal["first", "last", "none"] = "last",
|
49 |
+
):
|
50 |
+
assert padding_mode in ["last", "none"], "`first` padding is not supported yet."
|
51 |
+
if padding_mode == "last":
|
52 |
+
padded_indices = [
|
53 |
+
i for i in range(T) if i not in (input_indices + test_indices)
|
54 |
+
]
|
55 |
+
else:
|
56 |
+
padded_indices = []
|
57 |
+
input_selects = list(range(len(input_indices)))
|
58 |
+
test_selects = list(range(len(test_indices)))
|
59 |
+
if max(input_indices) > max(test_indices):
|
60 |
+
# last elem from input
|
61 |
+
input_selects += [input_selects[-1]] * len(padded_indices)
|
62 |
+
input_indices = input_indices + padded_indices
|
63 |
+
sorted_inds = np.argsort(input_indices)
|
64 |
+
input_indices = [input_indices[ind] for ind in sorted_inds]
|
65 |
+
input_selects = [input_selects[ind] for ind in sorted_inds]
|
66 |
+
else:
|
67 |
+
# last elem from test
|
68 |
+
test_selects += [test_selects[-1]] * len(padded_indices)
|
69 |
+
test_indices = test_indices + padded_indices
|
70 |
+
sorted_inds = np.argsort(test_indices)
|
71 |
+
test_indices = [test_indices[ind] for ind in sorted_inds]
|
72 |
+
test_selects = [test_selects[ind] for ind in sorted_inds]
|
73 |
+
|
74 |
+
if padding_mode == "last":
|
75 |
+
input_maps = np.array([-1] * T)
|
76 |
+
test_maps = np.array([-1] * T)
|
77 |
+
else:
|
78 |
+
input_maps = np.array([-1] * (len(input_indices) + len(test_indices)))
|
79 |
+
test_maps = np.array([-1] * (len(input_indices) + len(test_indices)))
|
80 |
+
input_maps[input_indices] = input_selects
|
81 |
+
test_maps[test_indices] = test_selects
|
82 |
+
return input_indices, test_indices, input_maps, test_maps
|
83 |
+
|
84 |
+
|
85 |
+
def assemble(
|
86 |
+
input,
|
87 |
+
test,
|
88 |
+
input_maps,
|
89 |
+
test_maps,
|
90 |
+
):
|
91 |
+
T = len(input_maps)
|
92 |
+
assembled = torch.zeros_like(test[-1:]).repeat_interleave(T, dim=0)
|
93 |
+
assembled[input_maps != -1] = input[input_maps[input_maps != -1]]
|
94 |
+
assembled[test_maps != -1] = test[test_maps[test_maps != -1]]
|
95 |
+
assert np.logical_xor(input_maps != -1, test_maps != -1).all()
|
96 |
+
return assembled
|
97 |
+
|
98 |
+
|
99 |
+
def get_resizing_factor(
|
100 |
+
target_shape: Tuple[int, int], # H, W
|
101 |
+
current_shape: Tuple[int, int], # H, W
|
102 |
+
cover_target: bool = True,
|
103 |
+
# If True, the output shape will fully cover the target shape.
|
104 |
+
# If No, the target shape will fully cover the output shape.
|
105 |
+
) -> float:
|
106 |
+
r_bound = target_shape[1] / target_shape[0]
|
107 |
+
aspect_r = current_shape[1] / current_shape[0]
|
108 |
+
if r_bound >= 1.0:
|
109 |
+
if cover_target:
|
110 |
+
if aspect_r >= r_bound:
|
111 |
+
factor = min(target_shape) / min(current_shape)
|
112 |
+
elif aspect_r < 1.0:
|
113 |
+
factor = max(target_shape) / min(current_shape)
|
114 |
+
else:
|
115 |
+
factor = max(target_shape) / max(current_shape)
|
116 |
+
else:
|
117 |
+
if aspect_r >= r_bound:
|
118 |
+
factor = max(target_shape) / max(current_shape)
|
119 |
+
elif aspect_r < 1.0:
|
120 |
+
factor = min(target_shape) / max(current_shape)
|
121 |
+
else:
|
122 |
+
factor = min(target_shape) / min(current_shape)
|
123 |
+
else:
|
124 |
+
if cover_target:
|
125 |
+
if aspect_r <= r_bound:
|
126 |
+
factor = min(target_shape) / min(current_shape)
|
127 |
+
elif aspect_r > 1.0:
|
128 |
+
factor = max(target_shape) / min(current_shape)
|
129 |
+
else:
|
130 |
+
factor = max(target_shape) / max(current_shape)
|
131 |
+
else:
|
132 |
+
if aspect_r <= r_bound:
|
133 |
+
factor = max(target_shape) / max(current_shape)
|
134 |
+
elif aspect_r > 1.0:
|
135 |
+
factor = min(target_shape) / max(current_shape)
|
136 |
+
else:
|
137 |
+
factor = min(target_shape) / min(current_shape)
|
138 |
+
return factor
|
139 |
+
|
140 |
+
|
141 |
+
def get_unique_embedder_keys_from_conditioner(conditioner):
|
142 |
+
keys = [x.input_key for x in conditioner.embedders if x.input_key is not None]
|
143 |
+
keys = [item for sublist in keys for item in sublist] # Flatten list
|
144 |
+
return set(keys)
|
145 |
+
|
146 |
+
|
147 |
+
def get_wh_with_fixed_shortest_side(w, h, size):
|
148 |
+
# size is smaller or equal to zero, we return original w h
|
149 |
+
if size is None or size <= 0:
|
150 |
+
return w, h
|
151 |
+
if w < h:
|
152 |
+
new_w = size
|
153 |
+
new_h = int(size * h / w)
|
154 |
+
else:
|
155 |
+
new_h = size
|
156 |
+
new_w = int(size * w / h)
|
157 |
+
return new_w, new_h
|
158 |
+
|
159 |
+
|
160 |
+
def load_img_and_K(
|
161 |
+
image_path_or_size: Union[str, torch.Size],
|
162 |
+
size: Optional[Union[int, Tuple[int, int]]],
|
163 |
+
scale: float = 1.0,
|
164 |
+
center: Tuple[float, float] = (0.5, 0.5),
|
165 |
+
K: torch.Tensor | None = None,
|
166 |
+
size_stride: int = 1,
|
167 |
+
center_crop: bool = False,
|
168 |
+
image_as_tensor: bool = True,
|
169 |
+
context_rgb: np.ndarray | None = None,
|
170 |
+
device: str = "cuda",
|
171 |
+
):
|
172 |
+
if isinstance(image_path_or_size, torch.Size):
|
173 |
+
image = Image.new("RGBA", image_path_or_size[::-1])
|
174 |
+
else:
|
175 |
+
image = Image.open(image_path_or_size).convert("RGBA")
|
176 |
+
|
177 |
+
w, h = image.size
|
178 |
+
if size is None:
|
179 |
+
size = (w, h)
|
180 |
+
|
181 |
+
image = np.array(image).astype(np.float32) / 255
|
182 |
+
if image.shape[-1] == 4:
|
183 |
+
rgb, alpha = image[:, :, :3], image[:, :, 3:]
|
184 |
+
if context_rgb is not None:
|
185 |
+
image = rgb * alpha + context_rgb * (1 - alpha)
|
186 |
+
else:
|
187 |
+
image = rgb * alpha + (1 - alpha)
|
188 |
+
image = image.transpose(2, 0, 1)
|
189 |
+
image = torch.from_numpy(image).to(dtype=torch.float32)
|
190 |
+
image = image.unsqueeze(0)
|
191 |
+
|
192 |
+
if isinstance(size, (tuple, list)):
|
193 |
+
# => if size is a tuple or list, we first rescale to fully cover the `size`
|
194 |
+
# area and then crop the `size` area from the rescale image
|
195 |
+
W, H = size
|
196 |
+
else:
|
197 |
+
# => if size is int, we rescale the image to fit the shortest side to size
|
198 |
+
# => if size is None, no rescaling is applied
|
199 |
+
W, H = get_wh_with_fixed_shortest_side(w, h, size)
|
200 |
+
W, H = (
|
201 |
+
math.floor(W / size_stride + 0.5) * size_stride,
|
202 |
+
math.floor(H / size_stride + 0.5) * size_stride,
|
203 |
+
)
|
204 |
+
|
205 |
+
rfs = get_resizing_factor((math.floor(H * scale), math.floor(W * scale)), (h, w))
|
206 |
+
resize_size = rh, rw = [int(np.ceil(rfs * s)) for s in (h, w)]
|
207 |
+
image = torch.nn.functional.interpolate(
|
208 |
+
image, resize_size, mode="area", antialias=False
|
209 |
+
)
|
210 |
+
if scale < 1.0:
|
211 |
+
pw = math.ceil((W - resize_size[1]) * 0.5)
|
212 |
+
ph = math.ceil((H - resize_size[0]) * 0.5)
|
213 |
+
image = F.pad(image, (pw, pw, ph, ph), "constant", 1.0)
|
214 |
+
|
215 |
+
cy_center = int(center[1] * image.shape[-2])
|
216 |
+
cx_center = int(center[0] * image.shape[-1])
|
217 |
+
if center_crop:
|
218 |
+
side = min(H, W)
|
219 |
+
ct = max(0, cy_center - side // 2)
|
220 |
+
cl = max(0, cx_center - side // 2)
|
221 |
+
ct = min(ct, image.shape[-2] - side)
|
222 |
+
cl = min(cl, image.shape[-1] - side)
|
223 |
+
image = TF.crop(image, top=ct, left=cl, height=side, width=side)
|
224 |
+
else:
|
225 |
+
ct = max(0, cy_center - H // 2)
|
226 |
+
cl = max(0, cx_center - W // 2)
|
227 |
+
ct = min(ct, image.shape[-2] - H)
|
228 |
+
cl = min(cl, image.shape[-1] - W)
|
229 |
+
image = TF.crop(image, top=ct, left=cl, height=H, width=W)
|
230 |
+
|
231 |
+
if K is not None:
|
232 |
+
K = K.clone()
|
233 |
+
if torch.all(K[:2, -1] >= 0) and torch.all(K[:2, -1] <= 1):
|
234 |
+
K[:2] *= K.new_tensor([rw, rh])[:, None] # normalized K
|
235 |
+
else:
|
236 |
+
K[:2] *= K.new_tensor([rw / w, rh / h])[:, None] # unnormalized K
|
237 |
+
K[:2, 2] -= K.new_tensor([cl, ct])
|
238 |
+
|
239 |
+
if image_as_tensor:
|
240 |
+
# tensor of shape (1, 3, H, W) with values ranging from (-1, 1)
|
241 |
+
image = image.to(device) * 2.0 - 1.0
|
242 |
+
else:
|
243 |
+
# PIL Image with values ranging from (0, 255)
|
244 |
+
image = image.permute(0, 2, 3, 1).numpy()[0]
|
245 |
+
image = Image.fromarray((image * 255).astype(np.uint8))
|
246 |
+
return image, K
|
247 |
+
|
248 |
+
|
249 |
+
def transform_img_and_K(
|
250 |
+
image: torch.Tensor,
|
251 |
+
size: Union[int, Tuple[int, int]],
|
252 |
+
scale: float = 1.0,
|
253 |
+
center: Tuple[float, float] = (0.5, 0.5),
|
254 |
+
K: torch.Tensor | None = None,
|
255 |
+
size_stride: int = 1,
|
256 |
+
mode: str = "crop",
|
257 |
+
):
|
258 |
+
assert mode in [
|
259 |
+
"crop",
|
260 |
+
"pad",
|
261 |
+
"stretch",
|
262 |
+
], f"mode should be one of ['crop', 'pad', 'stretch'], got {mode}"
|
263 |
+
|
264 |
+
h, w = image.shape[-2:]
|
265 |
+
if isinstance(size, (tuple, list)):
|
266 |
+
# => if size is a tuple or list, we first rescale to fully cover the `size`
|
267 |
+
# area and then crop the `size` area from the rescale image
|
268 |
+
W, H = size
|
269 |
+
else:
|
270 |
+
# => if size is int, we rescale the image to fit the shortest side to size
|
271 |
+
# => if size is None, no rescaling is applied
|
272 |
+
W, H = get_wh_with_fixed_shortest_side(w, h, size)
|
273 |
+
W, H = (
|
274 |
+
math.floor(W / size_stride + 0.5) * size_stride,
|
275 |
+
math.floor(H / size_stride + 0.5) * size_stride,
|
276 |
+
)
|
277 |
+
|
278 |
+
if mode == "stretch":
|
279 |
+
rh, rw = H, W
|
280 |
+
else:
|
281 |
+
rfs = get_resizing_factor(
|
282 |
+
(H, W),
|
283 |
+
(h, w),
|
284 |
+
cover_target=mode != "pad",
|
285 |
+
)
|
286 |
+
(rh, rw) = [int(np.ceil(rfs * s)) for s in (h, w)]
|
287 |
+
|
288 |
+
rh, rw = int(rh / scale), int(rw / scale)
|
289 |
+
image = torch.nn.functional.interpolate(
|
290 |
+
image, (rh, rw), mode="area", antialias=False
|
291 |
+
)
|
292 |
+
|
293 |
+
cy_center = int(center[1] * image.shape[-2])
|
294 |
+
cx_center = int(center[0] * image.shape[-1])
|
295 |
+
if mode != "pad":
|
296 |
+
ct = max(0, cy_center - H // 2)
|
297 |
+
cl = max(0, cx_center - W // 2)
|
298 |
+
ct = min(ct, image.shape[-2] - H)
|
299 |
+
cl = min(cl, image.shape[-1] - W)
|
300 |
+
image = TF.crop(image, top=ct, left=cl, height=H, width=W)
|
301 |
+
pl, pt = 0, 0
|
302 |
+
else:
|
303 |
+
pt = max(0, H // 2 - cy_center)
|
304 |
+
pl = max(0, W // 2 - cx_center)
|
305 |
+
pb = max(0, H - pt - image.shape[-2])
|
306 |
+
pr = max(0, W - pl - image.shape[-1])
|
307 |
+
image = TF.pad(
|
308 |
+
image,
|
309 |
+
[pl, pt, pr, pb],
|
310 |
+
)
|
311 |
+
cl, ct = 0, 0
|
312 |
+
|
313 |
+
if K is not None:
|
314 |
+
K = K.clone()
|
315 |
+
# K[:, :2, 2] += K.new_tensor([pl, pt])
|
316 |
+
if torch.all(K[:, :2, -1] >= 0) and torch.all(K[:, :2, -1] <= 1):
|
317 |
+
K[:, :2] *= K.new_tensor([rw, rh])[None, :, None] # normalized K
|
318 |
+
else:
|
319 |
+
K[:, :2] *= K.new_tensor([rw / w, rh / h])[None, :, None] # unnormalized K
|
320 |
+
K[:, :2, 2] += K.new_tensor([pl - cl, pt - ct])
|
321 |
+
|
322 |
+
return image, K
|
323 |
+
|
324 |
+
|
325 |
+
lowvram_mode = False
|
326 |
+
|
327 |
+
|
328 |
+
def set_lowvram_mode(mode):
|
329 |
+
global lowvram_mode
|
330 |
+
lowvram_mode = mode
|
331 |
+
|
332 |
+
|
333 |
+
def load_model(model, device: str = "cuda"):
|
334 |
+
model.to(device)
|
335 |
+
|
336 |
+
|
337 |
+
def unload_model(model):
|
338 |
+
global lowvram_mode
|
339 |
+
if lowvram_mode:
|
340 |
+
model.cpu()
|
341 |
+
torch.cuda.empty_cache()
|
342 |
+
|
343 |
+
|
344 |
+
def infer_prior_stats(
|
345 |
+
T,
|
346 |
+
num_input_frames,
|
347 |
+
num_total_frames,
|
348 |
+
version_dict,
|
349 |
+
):
|
350 |
+
options = version_dict["options"]
|
351 |
+
chunk_strategy = options.get("chunk_strategy", "nearest")
|
352 |
+
T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
|
353 |
+
T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
|
354 |
+
# get traj_prior_c2ws for 2-pass sampling
|
355 |
+
if chunk_strategy.startswith("interp"):
|
356 |
+
# Start and end have alreay taken up two slots
|
357 |
+
# +1 means we need X + 1 prior frames to bound X times forwards for all test frames
|
358 |
+
|
359 |
+
# Tuning up `num_prior_frames_ratio` is helpful when you observe sudden jump in the
|
360 |
+
# generated frames due to insufficient prior frames. This option is effective for
|
361 |
+
# complicated trajectory and when `interp` strategy is used (usually semi-dense-view
|
362 |
+
# regime). Recommended range is [1.0 (default), 1.5].
|
363 |
+
if num_input_frames >= options.get("num_input_semi_dense", 9):
|
364 |
+
num_prior_frames = (
|
365 |
+
math.ceil(
|
366 |
+
num_total_frames
|
367 |
+
/ (T_second_pass - 2)
|
368 |
+
* options.get("num_prior_frames_ratio", 1.0)
|
369 |
+
)
|
370 |
+
+ 1
|
371 |
+
)
|
372 |
+
|
373 |
+
if num_prior_frames + num_input_frames < T_first_pass:
|
374 |
+
num_prior_frames = T_first_pass - num_input_frames
|
375 |
+
|
376 |
+
num_prior_frames = max(
|
377 |
+
num_prior_frames,
|
378 |
+
options.get("num_prior_frames", 0),
|
379 |
+
)
|
380 |
+
|
381 |
+
T_first_pass = num_prior_frames + num_input_frames
|
382 |
+
|
383 |
+
if "gt" in chunk_strategy:
|
384 |
+
T_second_pass = T_second_pass + num_input_frames
|
385 |
+
|
386 |
+
# Dynamically update context window length.
|
387 |
+
version_dict["T"] = [T_first_pass, T_second_pass]
|
388 |
+
|
389 |
+
else:
|
390 |
+
num_prior_frames = (
|
391 |
+
math.ceil(
|
392 |
+
num_total_frames
|
393 |
+
/ (
|
394 |
+
T_second_pass
|
395 |
+
- 2
|
396 |
+
- (num_input_frames if "gt" in chunk_strategy else 0)
|
397 |
+
)
|
398 |
+
* options.get("num_prior_frames_ratio", 1.0)
|
399 |
+
)
|
400 |
+
+ 1
|
401 |
+
)
|
402 |
+
|
403 |
+
if num_prior_frames + num_input_frames < T_first_pass:
|
404 |
+
num_prior_frames = T_first_pass - num_input_frames
|
405 |
+
|
406 |
+
num_prior_frames = max(
|
407 |
+
num_prior_frames,
|
408 |
+
options.get("num_prior_frames", 0),
|
409 |
+
)
|
410 |
+
else:
|
411 |
+
num_prior_frames = max(
|
412 |
+
T_first_pass - num_input_frames,
|
413 |
+
options.get("num_prior_frames", 0),
|
414 |
+
)
|
415 |
+
|
416 |
+
if num_input_frames >= options.get("num_input_semi_dense", 9):
|
417 |
+
T_first_pass = num_prior_frames + num_input_frames
|
418 |
+
|
419 |
+
# Dynamically update context window length.
|
420 |
+
version_dict["T"] = [T_first_pass, T_second_pass]
|
421 |
+
|
422 |
+
return num_prior_frames
|
423 |
+
|
424 |
+
|
425 |
+
def infer_prior_inds(
|
426 |
+
c2ws,
|
427 |
+
num_prior_frames,
|
428 |
+
input_frame_indices,
|
429 |
+
options,
|
430 |
+
):
|
431 |
+
chunk_strategy = options.get("chunk_strategy", "nearest")
|
432 |
+
if chunk_strategy.startswith("interp"):
|
433 |
+
prior_frame_indices = np.array(
|
434 |
+
[i for i in range(c2ws.shape[0]) if i not in input_frame_indices]
|
435 |
+
)
|
436 |
+
prior_frame_indices = prior_frame_indices[
|
437 |
+
np.ceil(
|
438 |
+
np.linspace(
|
439 |
+
0, prior_frame_indices.shape[0] - 1, num_prior_frames, endpoint=True
|
440 |
+
)
|
441 |
+
).astype(int)
|
442 |
+
] # having a ceil here is actually safer for corner case
|
443 |
+
else:
|
444 |
+
prior_frame_indices = []
|
445 |
+
while len(prior_frame_indices) < num_prior_frames:
|
446 |
+
closest_distance = np.abs(
|
447 |
+
np.arange(c2ws.shape[0])[None]
|
448 |
+
- np.concatenate(
|
449 |
+
[np.array(input_frame_indices), np.array(prior_frame_indices)]
|
450 |
+
)[:, None]
|
451 |
+
).min(0)
|
452 |
+
prior_frame_indices.append(np.argsort(closest_distance)[-1])
|
453 |
+
return np.sort(prior_frame_indices)
|
454 |
+
|
455 |
+
|
456 |
+
def compute_relative_inds(
|
457 |
+
source_inds,
|
458 |
+
target_inds,
|
459 |
+
):
|
460 |
+
assert len(source_inds) > 2
|
461 |
+
# compute relative indices of target_inds within source_inds
|
462 |
+
relative_inds = []
|
463 |
+
for ind in target_inds:
|
464 |
+
if ind in source_inds:
|
465 |
+
relative_ind = int(np.where(source_inds == ind)[0][0])
|
466 |
+
elif ind < source_inds[0]:
|
467 |
+
# extrapolate
|
468 |
+
relative_ind = -((source_inds[0] - ind) / (source_inds[1] - source_inds[0]))
|
469 |
+
elif ind > source_inds[-1]:
|
470 |
+
# extrapolate
|
471 |
+
relative_ind = len(source_inds) + (
|
472 |
+
(ind - source_inds[-1]) / (source_inds[-1] - source_inds[-2])
|
473 |
+
)
|
474 |
+
else:
|
475 |
+
# interpolate
|
476 |
+
lower_inds = source_inds[source_inds < ind]
|
477 |
+
upper_inds = source_inds[source_inds > ind]
|
478 |
+
if len(lower_inds) > 0 and len(upper_inds) > 0:
|
479 |
+
lower_ind = lower_inds[-1]
|
480 |
+
upper_ind = upper_inds[0]
|
481 |
+
relative_lower_ind = int(np.where(source_inds == lower_ind)[0][0])
|
482 |
+
relative_upper_ind = int(np.where(source_inds == upper_ind)[0][0])
|
483 |
+
relative_ind = relative_lower_ind + (ind - lower_ind) / (
|
484 |
+
upper_ind - lower_ind
|
485 |
+
) * (relative_upper_ind - relative_lower_ind)
|
486 |
+
else:
|
487 |
+
# Out of range
|
488 |
+
relative_inds.append(float("nan")) # Or some other placeholder
|
489 |
+
relative_inds.append(relative_ind)
|
490 |
+
return relative_inds
|
491 |
+
|
492 |
+
|
493 |
+
def find_nearest_source_inds(
|
494 |
+
source_c2ws,
|
495 |
+
target_c2ws,
|
496 |
+
nearest_num=1,
|
497 |
+
mode="translation",
|
498 |
+
):
|
499 |
+
dists = get_camera_dist(source_c2ws, target_c2ws, mode=mode).cpu().numpy()
|
500 |
+
sorted_inds = np.argsort(dists, axis=0).T
|
501 |
+
return sorted_inds[:, :nearest_num]
|
502 |
+
|
503 |
+
|
504 |
+
def chunk_input_and_test(
|
505 |
+
T,
|
506 |
+
input_c2ws,
|
507 |
+
test_c2ws,
|
508 |
+
input_ords, # orders
|
509 |
+
test_ords, # orders
|
510 |
+
options,
|
511 |
+
task: str = "img2img",
|
512 |
+
chunk_strategy: str = "gt",
|
513 |
+
gt_input_inds: list = [],
|
514 |
+
):
|
515 |
+
M, N = input_c2ws.shape[0], test_c2ws.shape[0]
|
516 |
+
|
517 |
+
chunks = []
|
518 |
+
if chunk_strategy.startswith("gt"):
|
519 |
+
assert len(gt_input_inds) < T, (
|
520 |
+
f"Number of gt input frames {len(gt_input_inds)} should be "
|
521 |
+
f"less than {T} when `gt` chunking strategy is used."
|
522 |
+
)
|
523 |
+
assert (
|
524 |
+
list(range(M)) == gt_input_inds
|
525 |
+
), "All input_c2ws should be gt when `gt` chunking strategy is used."
|
526 |
+
|
527 |
+
# LEGACY CHUNKING STRATEGY
|
528 |
+
# num_test_per_chunk = T - len(gt_input_inds)
|
529 |
+
# test_inds_per_chunk = [i for i in range(T) if i not in gt_input_inds]
|
530 |
+
# for i in range(0, test_c2ws.shape[0], num_test_per_chunk):
|
531 |
+
# chunk = ["NULL"] * T
|
532 |
+
# for j, k in enumerate(gt_input_inds):
|
533 |
+
# chunk[k] = f"!{j:03d}"
|
534 |
+
# for j, k in enumerate(
|
535 |
+
# test_inds_per_chunk[: test_c2ws[i : i + num_test_per_chunk].shape[0]]
|
536 |
+
# ):
|
537 |
+
# chunk[k] = f">{i + j:03d}"
|
538 |
+
# chunks.append(chunk)
|
539 |
+
|
540 |
+
num_test_seen = 0
|
541 |
+
while num_test_seen < N:
|
542 |
+
chunk = [f"!{i:03d}" for i in gt_input_inds]
|
543 |
+
if chunk_strategy != "gt" and num_test_seen > 0:
|
544 |
+
pseudo_num_ratio = options.get("pseudo_num_ratio", 0.33)
|
545 |
+
if (N - num_test_seen) >= math.floor(
|
546 |
+
(T - len(gt_input_inds)) * pseudo_num_ratio
|
547 |
+
):
|
548 |
+
pseudo_num = math.ceil((T - len(gt_input_inds)) * pseudo_num_ratio)
|
549 |
+
else:
|
550 |
+
pseudo_num = (T - len(gt_input_inds)) - (N - num_test_seen)
|
551 |
+
pseudo_num = min(pseudo_num, options.get("pseudo_num_max", 10000))
|
552 |
+
|
553 |
+
if "ltr" in chunk_strategy:
|
554 |
+
chunk.extend(
|
555 |
+
[
|
556 |
+
f"!{i + len(gt_input_inds):03d}"
|
557 |
+
for i in range(num_test_seen - pseudo_num, num_test_seen)
|
558 |
+
]
|
559 |
+
)
|
560 |
+
elif "nearest" in chunk_strategy:
|
561 |
+
source_inds = np.concatenate(
|
562 |
+
[
|
563 |
+
find_nearest_source_inds(
|
564 |
+
test_c2ws[:num_test_seen],
|
565 |
+
test_c2ws[num_test_seen:],
|
566 |
+
nearest_num=1, # pseudo_num,
|
567 |
+
mode="rotation",
|
568 |
+
),
|
569 |
+
find_nearest_source_inds(
|
570 |
+
test_c2ws[:num_test_seen],
|
571 |
+
test_c2ws[num_test_seen:],
|
572 |
+
nearest_num=1, # pseudo_num,
|
573 |
+
mode="translation",
|
574 |
+
),
|
575 |
+
],
|
576 |
+
axis=1,
|
577 |
+
)
|
578 |
+
####### [HACK ALERT] keep running until pseudo num is stablized ########
|
579 |
+
temp_pseudo_num = pseudo_num
|
580 |
+
while True:
|
581 |
+
nearest_source_inds = np.concatenate(
|
582 |
+
[
|
583 |
+
np.sort(
|
584 |
+
[
|
585 |
+
ind
|
586 |
+
for (ind, _) in collections.Counter(
|
587 |
+
[
|
588 |
+
item
|
589 |
+
for item in source_inds[
|
590 |
+
: T
|
591 |
+
- len(gt_input_inds)
|
592 |
+
- temp_pseudo_num
|
593 |
+
]
|
594 |
+
.flatten()
|
595 |
+
.tolist()
|
596 |
+
if item
|
597 |
+
!= (
|
598 |
+
num_test_seen - 1
|
599 |
+
) # exclude the last one here
|
600 |
+
]
|
601 |
+
).most_common(pseudo_num - 1)
|
602 |
+
],
|
603 |
+
).astype(int),
|
604 |
+
[num_test_seen - 1], # always keep the last one
|
605 |
+
]
|
606 |
+
)
|
607 |
+
if len(nearest_source_inds) >= temp_pseudo_num:
|
608 |
+
break # stablized
|
609 |
+
else:
|
610 |
+
temp_pseudo_num = len(nearest_source_inds)
|
611 |
+
pseudo_num = len(nearest_source_inds)
|
612 |
+
########################################################################
|
613 |
+
chunk.extend(
|
614 |
+
[f"!{i + len(gt_input_inds):03d}" for i in nearest_source_inds]
|
615 |
+
)
|
616 |
+
else:
|
617 |
+
raise NotImplementedError(
|
618 |
+
f"Chunking strategy {chunk_strategy} for the first pass is not implemented."
|
619 |
+
)
|
620 |
+
|
621 |
+
chunk.extend(
|
622 |
+
[
|
623 |
+
f">{i:03d}"
|
624 |
+
for i in range(
|
625 |
+
num_test_seen,
|
626 |
+
min(num_test_seen + T - len(gt_input_inds) - pseudo_num, N),
|
627 |
+
)
|
628 |
+
]
|
629 |
+
)
|
630 |
+
else:
|
631 |
+
chunk.extend(
|
632 |
+
[
|
633 |
+
f">{i:03d}"
|
634 |
+
for i in range(
|
635 |
+
num_test_seen,
|
636 |
+
min(num_test_seen + T - len(gt_input_inds), N),
|
637 |
+
)
|
638 |
+
]
|
639 |
+
)
|
640 |
+
|
641 |
+
num_test_seen += sum([1 for c in chunk if c.startswith(">")])
|
642 |
+
if len(chunk) < T:
|
643 |
+
chunk.extend(["NULL"] * (T - len(chunk)))
|
644 |
+
chunks.append(chunk)
|
645 |
+
|
646 |
+
elif chunk_strategy.startswith("nearest"):
|
647 |
+
input_imgs = np.array([f"!{i:03d}" for i in range(M)])
|
648 |
+
test_imgs = np.array([f">{i:03d}" for i in range(N)])
|
649 |
+
|
650 |
+
match = re.match(r"^nearest-(\d+)$", chunk_strategy)
|
651 |
+
if match:
|
652 |
+
nearest_num = int(match.group(1))
|
653 |
+
assert (
|
654 |
+
nearest_num < T
|
655 |
+
), f"Nearest number of {nearest_num} should be less than {T}."
|
656 |
+
source_inds = find_nearest_source_inds(
|
657 |
+
input_c2ws,
|
658 |
+
test_c2ws,
|
659 |
+
nearest_num=nearest_num,
|
660 |
+
mode="translation", # during the second pass, consider translation only is enough
|
661 |
+
)
|
662 |
+
|
663 |
+
for i in range(0, N, T - nearest_num):
|
664 |
+
nearest_source_inds = np.sort(
|
665 |
+
[
|
666 |
+
ind
|
667 |
+
for (ind, _) in collections.Counter(
|
668 |
+
source_inds[i : i + T - nearest_num].flatten().tolist()
|
669 |
+
).most_common(nearest_num)
|
670 |
+
]
|
671 |
+
)
|
672 |
+
chunk = (
|
673 |
+
input_imgs[nearest_source_inds].tolist()
|
674 |
+
+ test_imgs[i : i + T - nearest_num].tolist()
|
675 |
+
)
|
676 |
+
chunks.append(chunk + ["NULL"] * (T - len(chunk)))
|
677 |
+
|
678 |
+
else:
|
679 |
+
# do not always condition on gt cond frames
|
680 |
+
if "gt" not in chunk_strategy:
|
681 |
+
gt_input_inds = []
|
682 |
+
|
683 |
+
source_inds = find_nearest_source_inds(
|
684 |
+
input_c2ws,
|
685 |
+
test_c2ws,
|
686 |
+
nearest_num=1,
|
687 |
+
mode="translation", # during the second pass, consider translation only is enough
|
688 |
+
)[:, 0]
|
689 |
+
|
690 |
+
test_inds_per_input = {}
|
691 |
+
for test_idx, input_idx in enumerate(source_inds):
|
692 |
+
if input_idx not in test_inds_per_input:
|
693 |
+
test_inds_per_input[input_idx] = []
|
694 |
+
test_inds_per_input[input_idx].append(test_idx)
|
695 |
+
|
696 |
+
num_test_seen = 0
|
697 |
+
chunk = input_imgs[gt_input_inds].tolist()
|
698 |
+
candidate_input_inds = sorted(list(test_inds_per_input.keys()))
|
699 |
+
|
700 |
+
while num_test_seen < N:
|
701 |
+
input_idx = candidate_input_inds[0]
|
702 |
+
test_inds = test_inds_per_input[input_idx]
|
703 |
+
input_is_cond = input_idx in gt_input_inds
|
704 |
+
prefix_inds = [] if input_is_cond else [input_idx]
|
705 |
+
|
706 |
+
if len(chunk) == T - len(prefix_inds) or not candidate_input_inds:
|
707 |
+
if chunk:
|
708 |
+
chunk += ["NULL"] * (T - len(chunk))
|
709 |
+
chunks.append(chunk)
|
710 |
+
chunk = input_imgs[gt_input_inds].tolist()
|
711 |
+
if num_test_seen >= N:
|
712 |
+
break
|
713 |
+
continue
|
714 |
+
|
715 |
+
candidate_chunk = (
|
716 |
+
input_imgs[prefix_inds].tolist() + test_imgs[test_inds].tolist()
|
717 |
+
)
|
718 |
+
|
719 |
+
space_left = T - len(chunk)
|
720 |
+
if len(candidate_chunk) <= space_left:
|
721 |
+
chunk.extend(candidate_chunk)
|
722 |
+
num_test_seen += len(test_inds)
|
723 |
+
candidate_input_inds.pop(0)
|
724 |
+
else:
|
725 |
+
chunk.extend(candidate_chunk[:space_left])
|
726 |
+
num_input_idx = 0 if input_is_cond else 1
|
727 |
+
num_test_seen += space_left - num_input_idx
|
728 |
+
test_inds_per_input[input_idx] = test_inds[
|
729 |
+
space_left - num_input_idx :
|
730 |
+
]
|
731 |
+
|
732 |
+
if len(chunk) == T:
|
733 |
+
chunks.append(chunk)
|
734 |
+
chunk = input_imgs[gt_input_inds].tolist()
|
735 |
+
|
736 |
+
if chunk and chunk != input_imgs[gt_input_inds].tolist():
|
737 |
+
chunks.append(chunk + ["NULL"] * (T - len(chunk)))
|
738 |
+
|
739 |
+
elif chunk_strategy.startswith("interp"):
|
740 |
+
# `interp` chunk requires ordering info
|
741 |
+
assert input_ords is not None and test_ords is not None, (
|
742 |
+
"When using `interp` chunking strategy, ordering of input "
|
743 |
+
"and test frames should be provided."
|
744 |
+
)
|
745 |
+
|
746 |
+
# if chunk_strategy is `interp*`` and task is `img2trajvid*`, we will not
|
747 |
+
# use input views since their order info within target views is unknown
|
748 |
+
if "img2trajvid" in task:
|
749 |
+
assert (
|
750 |
+
list(range(len(gt_input_inds))) == gt_input_inds
|
751 |
+
), "`img2trajvid` task should put `gt_input_inds` in start."
|
752 |
+
input_c2ws = input_c2ws[
|
753 |
+
[ind for ind in range(M) if ind not in gt_input_inds]
|
754 |
+
]
|
755 |
+
input_ords = [
|
756 |
+
input_ords[ind] for ind in range(M) if ind not in gt_input_inds
|
757 |
+
]
|
758 |
+
M = input_c2ws.shape[0]
|
759 |
+
|
760 |
+
input_ords = [0] + input_ords # this is a hack accounting for test views
|
761 |
+
# before the first input view
|
762 |
+
input_ords[-1] += 0.01 # this is a hack ensuring last test stop is included
|
763 |
+
# in the last forward when input_ords[-1] == test_ords[-1]
|
764 |
+
input_ords = np.array(input_ords)[:, None]
|
765 |
+
input_ords_ = np.concatenate([input_ords[1:], np.full((1, 1), np.inf)])
|
766 |
+
test_ords = np.array(test_ords)[None]
|
767 |
+
|
768 |
+
in_stop_ranges = np.logical_and(
|
769 |
+
np.repeat(input_ords, N, axis=1) <= np.repeat(test_ords, M + 1, axis=0),
|
770 |
+
np.repeat(input_ords_, N, axis=1) > np.repeat(test_ords, M + 1, axis=0),
|
771 |
+
) # (M, N)
|
772 |
+
assert (in_stop_ranges.sum(1) <= T - 2).all(), (
|
773 |
+
"More input frames need to be sampled during the first pass to ensure "
|
774 |
+
f"#test frames during each forard in the second pass will not exceed {T - 2}."
|
775 |
+
)
|
776 |
+
if input_ords[1, 0] <= test_ords[0, 0]:
|
777 |
+
assert not in_stop_ranges[0].any()
|
778 |
+
if input_ords[-1, 0] >= test_ords[0, -1]:
|
779 |
+
assert not in_stop_ranges[-1].any()
|
780 |
+
|
781 |
+
gt_chunk = (
|
782 |
+
[f"!{i:03d}" for i in gt_input_inds] if "gt" in chunk_strategy else []
|
783 |
+
)
|
784 |
+
chunk = gt_chunk + []
|
785 |
+
# any test views before the first input views
|
786 |
+
if in_stop_ranges[0].any():
|
787 |
+
for j, in_range in enumerate(in_stop_ranges[0]):
|
788 |
+
if in_range:
|
789 |
+
chunk.append(f">{j:03d}")
|
790 |
+
in_stop_ranges = in_stop_ranges[1:]
|
791 |
+
|
792 |
+
i = 0
|
793 |
+
base_i = len(gt_input_inds) if "img2trajvid" in task else 0
|
794 |
+
chunk.append(f"!{i + base_i:03d}")
|
795 |
+
while i < len(in_stop_ranges):
|
796 |
+
in_stop_range = in_stop_ranges[i]
|
797 |
+
if not in_stop_range.any():
|
798 |
+
i += 1
|
799 |
+
continue
|
800 |
+
|
801 |
+
input_left = i + 1 < M
|
802 |
+
space_left = T - len(chunk)
|
803 |
+
if sum(in_stop_range) + input_left <= space_left:
|
804 |
+
for j, in_range in enumerate(in_stop_range):
|
805 |
+
if in_range:
|
806 |
+
chunk.append(f">{j:03d}")
|
807 |
+
i += 1
|
808 |
+
if input_left:
|
809 |
+
chunk.append(f"!{i + base_i:03d}")
|
810 |
+
|
811 |
+
else:
|
812 |
+
chunk += ["NULL"] * space_left
|
813 |
+
chunks.append(chunk)
|
814 |
+
chunk = gt_chunk + [f"!{i + base_i:03d}"]
|
815 |
+
|
816 |
+
if len(chunk) > 1:
|
817 |
+
chunk += ["NULL"] * (T - len(chunk))
|
818 |
+
chunks.append(chunk)
|
819 |
+
|
820 |
+
else:
|
821 |
+
raise NotImplementedError
|
822 |
+
|
823 |
+
(
|
824 |
+
input_inds_per_chunk,
|
825 |
+
input_sels_per_chunk,
|
826 |
+
test_inds_per_chunk,
|
827 |
+
test_sels_per_chunk,
|
828 |
+
) = (
|
829 |
+
[],
|
830 |
+
[],
|
831 |
+
[],
|
832 |
+
[],
|
833 |
+
)
|
834 |
+
for chunk in chunks:
|
835 |
+
input_inds = [
|
836 |
+
int(img.removeprefix("!")) for img in chunk if img.startswith("!")
|
837 |
+
]
|
838 |
+
input_sels = [chunk.index(img) for img in chunk if img.startswith("!")]
|
839 |
+
test_inds = [int(img.removeprefix(">")) for img in chunk if img.startswith(">")]
|
840 |
+
test_sels = [chunk.index(img) for img in chunk if img.startswith(">")]
|
841 |
+
input_inds_per_chunk.append(input_inds)
|
842 |
+
input_sels_per_chunk.append(input_sels)
|
843 |
+
test_inds_per_chunk.append(test_inds)
|
844 |
+
test_sels_per_chunk.append(test_sels)
|
845 |
+
|
846 |
+
if options.get("sampler_verbose", True):
|
847 |
+
|
848 |
+
def colorize(item):
|
849 |
+
if item.startswith("!"):
|
850 |
+
return f"{Fore.RED}{item}{Style.RESET_ALL}" # Red for items starting with '!'
|
851 |
+
elif item.startswith(">"):
|
852 |
+
return f"{Fore.GREEN}{item}{Style.RESET_ALL}" # Green for items starting with '>'
|
853 |
+
return item # Default color if neither '!' nor '>'
|
854 |
+
|
855 |
+
print("\nchunks:")
|
856 |
+
for chunk in chunks:
|
857 |
+
print(", ".join(colorize(item) for item in chunk))
|
858 |
+
|
859 |
+
return (
|
860 |
+
chunks,
|
861 |
+
input_inds_per_chunk, # ordering of input in raw sequence
|
862 |
+
input_sels_per_chunk, # ordering of input in one-forward sequence of length T
|
863 |
+
test_inds_per_chunk, # ordering of test in raw sequence
|
864 |
+
test_sels_per_chunk, # oredering of test in one-forward sequence of length T
|
865 |
+
)
|
866 |
+
|
867 |
+
|
868 |
+
def is_k_in_dict(d, k):
|
869 |
+
return any(map(lambda x: x.startswith(k), d.keys()))
|
870 |
+
|
871 |
+
|
872 |
+
def get_k_from_dict(d, k):
|
873 |
+
media_d = {}
|
874 |
+
for key, value in d.items():
|
875 |
+
if key == k:
|
876 |
+
return value
|
877 |
+
if key.startswith(k):
|
878 |
+
media = key.split("/")[-1]
|
879 |
+
if media == "raw":
|
880 |
+
return value
|
881 |
+
media_d[media] = value
|
882 |
+
if len(media_d) == 0:
|
883 |
+
return torch.tensor([])
|
884 |
+
assert (
|
885 |
+
len(media_d) == 1
|
886 |
+
), f"multiple media found in {d} for key {k}: {media_d.keys()}"
|
887 |
+
return media_d[media]
|
888 |
+
|
889 |
+
|
890 |
+
def update_kv_for_dict(d, k, v):
|
891 |
+
for key in d.keys():
|
892 |
+
if key.startswith(k):
|
893 |
+
d[key] = v
|
894 |
+
return d
|
895 |
+
|
896 |
+
|
897 |
+
def extend_dict(ds, d):
|
898 |
+
for key in d.keys():
|
899 |
+
if key in ds:
|
900 |
+
ds[key] = torch.cat([ds[key], d[key]], 0)
|
901 |
+
else:
|
902 |
+
ds[key] = d[key]
|
903 |
+
return ds
|
904 |
+
|
905 |
+
|
906 |
+
def replace_or_include_input_for_dict(
|
907 |
+
samples,
|
908 |
+
test_indices,
|
909 |
+
imgs,
|
910 |
+
c2w,
|
911 |
+
K,
|
912 |
+
):
|
913 |
+
samples_new = {}
|
914 |
+
for sample, value in samples.items():
|
915 |
+
if "rgb" in sample:
|
916 |
+
imgs[test_indices] = (
|
917 |
+
value[test_indices] if value.shape[0] == imgs.shape[0] else value
|
918 |
+
).to(device=imgs.device, dtype=imgs.dtype)
|
919 |
+
samples_new[sample] = imgs
|
920 |
+
elif "c2w" in sample:
|
921 |
+
c2w[test_indices] = (
|
922 |
+
value[test_indices] if value.shape[0] == c2w.shape[0] else value
|
923 |
+
).to(device=c2w.device, dtype=c2w.dtype)
|
924 |
+
samples_new[sample] = c2w
|
925 |
+
elif "intrinsics" in sample:
|
926 |
+
K[test_indices] = (
|
927 |
+
value[test_indices] if value.shape[0] == K.shape[0] else value
|
928 |
+
).to(device=K.device, dtype=K.dtype)
|
929 |
+
samples_new[sample] = K
|
930 |
+
else:
|
931 |
+
samples_new[sample] = value
|
932 |
+
return samples_new
|
933 |
+
|
934 |
+
|
935 |
+
def decode_output(
|
936 |
+
samples,
|
937 |
+
T,
|
938 |
+
indices=None,
|
939 |
+
):
|
940 |
+
# decode model output into dict if it is not
|
941 |
+
if isinstance(samples, dict):
|
942 |
+
# model with postprocessor and outputs dict
|
943 |
+
for sample, value in samples.items():
|
944 |
+
if isinstance(value, torch.Tensor):
|
945 |
+
value = value.detach().cpu()
|
946 |
+
elif isinstance(value, np.ndarray):
|
947 |
+
value = torch.from_numpy(value)
|
948 |
+
else:
|
949 |
+
value = torch.tensor(value)
|
950 |
+
|
951 |
+
if indices is not None and value.shape[0] == T:
|
952 |
+
value = value[indices]
|
953 |
+
samples[sample] = value
|
954 |
+
else:
|
955 |
+
# model without postprocessor and outputs tensor (rgb)
|
956 |
+
samples = samples.detach().cpu()
|
957 |
+
|
958 |
+
if indices is not None and samples.shape[0] == T:
|
959 |
+
samples = samples[indices]
|
960 |
+
samples = {"samples-rgb/image": samples}
|
961 |
+
|
962 |
+
return samples
|
963 |
+
|
964 |
+
|
965 |
+
def save_output(
|
966 |
+
samples,
|
967 |
+
save_path,
|
968 |
+
video_save_fps=2,
|
969 |
+
):
|
970 |
+
os.makedirs(save_path, exist_ok=True)
|
971 |
+
for sample in samples:
|
972 |
+
media_type = "video"
|
973 |
+
if "/" in sample:
|
974 |
+
sample_, media_type = sample.split("/")
|
975 |
+
else:
|
976 |
+
sample_ = sample
|
977 |
+
|
978 |
+
value = samples[sample]
|
979 |
+
if isinstance(value, torch.Tensor):
|
980 |
+
value = value.detach().cpu()
|
981 |
+
elif isinstance(value, np.ndarray):
|
982 |
+
value = torch.from_numpy(value)
|
983 |
+
else:
|
984 |
+
value = torch.tensor(value)
|
985 |
+
|
986 |
+
if media_type == "image":
|
987 |
+
value = (value.permute(0, 2, 3, 1) + 1) / 2.0
|
988 |
+
value = (value * 255).clamp(0, 255).to(torch.uint8)
|
989 |
+
iio.imwrite(
|
990 |
+
os.path.join(save_path, f"{sample_}.mp4")
|
991 |
+
if sample_
|
992 |
+
else f"{save_path}.mp4",
|
993 |
+
value,
|
994 |
+
fps=video_save_fps,
|
995 |
+
macro_block_size=1,
|
996 |
+
ffmpeg_log_level="error",
|
997 |
+
)
|
998 |
+
os.makedirs(os.path.join(save_path, sample_), exist_ok=True)
|
999 |
+
for i, s in enumerate(value):
|
1000 |
+
iio.imwrite(
|
1001 |
+
os.path.join(save_path, sample_, f"{i:03d}.png"),
|
1002 |
+
s,
|
1003 |
+
)
|
1004 |
+
elif media_type == "video":
|
1005 |
+
value = (value.permute(0, 2, 3, 1) + 1) / 2.0
|
1006 |
+
value = (value * 255).clamp(0, 255).to(torch.uint8)
|
1007 |
+
iio.imwrite(
|
1008 |
+
os.path.join(save_path, f"{sample_}.mp4"),
|
1009 |
+
value,
|
1010 |
+
fps=video_save_fps,
|
1011 |
+
macro_block_size=1,
|
1012 |
+
ffmpeg_log_level="error",
|
1013 |
+
)
|
1014 |
+
elif media_type == "raw":
|
1015 |
+
torch.save(
|
1016 |
+
value,
|
1017 |
+
os.path.join(save_path, f"{sample_}.pt"),
|
1018 |
+
)
|
1019 |
+
else:
|
1020 |
+
pass
|
1021 |
+
|
1022 |
+
|
1023 |
+
def create_transforms_simple(save_path, img_paths, img_whs, c2ws, Ks):
|
1024 |
+
import os.path as osp
|
1025 |
+
|
1026 |
+
out_frames = []
|
1027 |
+
for img_path, img_wh, c2w, K in zip(img_paths, img_whs, c2ws, Ks):
|
1028 |
+
out_frame = {
|
1029 |
+
"fl_x": K[0][0].item(),
|
1030 |
+
"fl_y": K[1][1].item(),
|
1031 |
+
"cx": K[0][2].item(),
|
1032 |
+
"cy": K[1][2].item(),
|
1033 |
+
"w": img_wh[0].item(),
|
1034 |
+
"h": img_wh[1].item(),
|
1035 |
+
"file_path": f"./{osp.relpath(img_path, start=save_path)}"
|
1036 |
+
if img_path is not None
|
1037 |
+
else None,
|
1038 |
+
"transform_matrix": c2w.tolist(),
|
1039 |
+
}
|
1040 |
+
out_frames.append(out_frame)
|
1041 |
+
out = {
|
1042 |
+
# "camera_model": "PINHOLE",
|
1043 |
+
"orientation_override": "none",
|
1044 |
+
"frames": out_frames,
|
1045 |
+
}
|
1046 |
+
with open(osp.join(save_path, "transforms.json"), "w") as of:
|
1047 |
+
json.dump(out, of, indent=5)
|
1048 |
+
|
1049 |
+
|
1050 |
+
class GradioTrackedSampler(EulerEDMSampler):
|
1051 |
+
"""
|
1052 |
+
A thin wrapper around the EulerEDMSampler that allows tracking progress and
|
1053 |
+
aborting sampling for gradio demo.
|
1054 |
+
"""
|
1055 |
+
|
1056 |
+
def __init__(self, abort_event: threading.Event, *args, **kwargs):
|
1057 |
+
super().__init__(*args, **kwargs)
|
1058 |
+
self.abort_event = abort_event
|
1059 |
+
|
1060 |
+
def __call__( # type: ignore
|
1061 |
+
self,
|
1062 |
+
denoiser,
|
1063 |
+
x: torch.Tensor,
|
1064 |
+
scale: float | torch.Tensor,
|
1065 |
+
cond: dict,
|
1066 |
+
uc: dict | None = None,
|
1067 |
+
num_steps: int | None = None,
|
1068 |
+
verbose: bool = True,
|
1069 |
+
global_pbar: gr.Progress | None = None,
|
1070 |
+
**guider_kwargs,
|
1071 |
+
) -> torch.Tensor | None:
|
1072 |
+
uc = cond if uc is None else uc
|
1073 |
+
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
|
1074 |
+
x,
|
1075 |
+
cond,
|
1076 |
+
uc,
|
1077 |
+
num_steps,
|
1078 |
+
)
|
1079 |
+
for i in self.get_sigma_gen(num_sigmas, verbose=verbose):
|
1080 |
+
gamma = (
|
1081 |
+
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
|
1082 |
+
if self.s_tmin <= sigmas[i] <= self.s_tmax
|
1083 |
+
else 0.0
|
1084 |
+
)
|
1085 |
+
x = self.sampler_step(
|
1086 |
+
s_in * sigmas[i],
|
1087 |
+
s_in * sigmas[i + 1],
|
1088 |
+
denoiser,
|
1089 |
+
x,
|
1090 |
+
scale,
|
1091 |
+
cond,
|
1092 |
+
uc,
|
1093 |
+
gamma,
|
1094 |
+
**guider_kwargs,
|
1095 |
+
)
|
1096 |
+
# Allow tracking progress in gradio demo.
|
1097 |
+
if global_pbar is not None:
|
1098 |
+
global_pbar.update()
|
1099 |
+
# Allow aborting sampling in gradio demo.
|
1100 |
+
if self.abort_event.is_set():
|
1101 |
+
return None
|
1102 |
+
return x
|
1103 |
+
|
1104 |
+
|
1105 |
+
def create_samplers(
|
1106 |
+
guider_types: int | list[int],
|
1107 |
+
discretization,
|
1108 |
+
num_frames: list[int] | None,
|
1109 |
+
num_steps: int,
|
1110 |
+
cfg_min: float = 1.0,
|
1111 |
+
device: str | torch.device = "cuda",
|
1112 |
+
abort_event: threading.Event | None = None,
|
1113 |
+
):
|
1114 |
+
guider_mapping = {
|
1115 |
+
0: VanillaCFG,
|
1116 |
+
1: MultiviewCFG,
|
1117 |
+
2: MultiviewTemporalCFG,
|
1118 |
+
}
|
1119 |
+
samplers = []
|
1120 |
+
if not isinstance(guider_types, (list, tuple)):
|
1121 |
+
guider_types = [guider_types]
|
1122 |
+
for i, guider_type in enumerate(guider_types):
|
1123 |
+
if guider_type not in guider_mapping:
|
1124 |
+
raise ValueError(
|
1125 |
+
f"Invalid guider type {guider_type}. Must be one of {list(guider_mapping.keys())}"
|
1126 |
+
)
|
1127 |
+
guider_cls = guider_mapping[guider_type]
|
1128 |
+
guider_args = ()
|
1129 |
+
if guider_type > 0:
|
1130 |
+
guider_args += (cfg_min,)
|
1131 |
+
if guider_type == 2:
|
1132 |
+
assert num_frames is not None
|
1133 |
+
guider_args = (num_frames[i], cfg_min)
|
1134 |
+
guider = guider_cls(*guider_args)
|
1135 |
+
|
1136 |
+
if abort_event is not None:
|
1137 |
+
sampler = GradioTrackedSampler(
|
1138 |
+
abort_event,
|
1139 |
+
discretization=discretization,
|
1140 |
+
guider=guider,
|
1141 |
+
num_steps=num_steps,
|
1142 |
+
s_churn=0.0,
|
1143 |
+
s_tmin=0.0,
|
1144 |
+
s_tmax=999.0,
|
1145 |
+
s_noise=1.0,
|
1146 |
+
verbose=True,
|
1147 |
+
device=device,
|
1148 |
+
)
|
1149 |
+
else:
|
1150 |
+
sampler = EulerEDMSampler(
|
1151 |
+
discretization=discretization,
|
1152 |
+
guider=guider,
|
1153 |
+
num_steps=num_steps,
|
1154 |
+
s_churn=0.0,
|
1155 |
+
s_tmin=0.0,
|
1156 |
+
s_tmax=999.0,
|
1157 |
+
s_noise=1.0,
|
1158 |
+
verbose=True,
|
1159 |
+
device=device,
|
1160 |
+
)
|
1161 |
+
samplers.append(sampler)
|
1162 |
+
return samplers
|
1163 |
+
|
1164 |
+
|
1165 |
+
def get_value_dict(
|
1166 |
+
curr_imgs,
|
1167 |
+
curr_imgs_clip,
|
1168 |
+
curr_input_frame_indices,
|
1169 |
+
curr_c2ws,
|
1170 |
+
curr_Ks,
|
1171 |
+
curr_input_camera_indices,
|
1172 |
+
all_c2ws,
|
1173 |
+
camera_scale,
|
1174 |
+
):
|
1175 |
+
assert sorted(curr_input_camera_indices) == sorted(
|
1176 |
+
range(len(curr_input_camera_indices))
|
1177 |
+
)
|
1178 |
+
H, W, T, F = curr_imgs.shape[-2], curr_imgs.shape[-1], len(curr_imgs), 8
|
1179 |
+
|
1180 |
+
value_dict = {}
|
1181 |
+
value_dict["cond_frames_without_noise"] = curr_imgs_clip[curr_input_frame_indices]
|
1182 |
+
value_dict["cond_frames"] = curr_imgs + 0.0 * torch.randn_like(curr_imgs)
|
1183 |
+
value_dict["cond_frames_mask"] = torch.zeros(T, dtype=torch.bool)
|
1184 |
+
value_dict["cond_frames_mask"][curr_input_frame_indices] = True
|
1185 |
+
value_dict["cond_aug"] = 0.0
|
1186 |
+
|
1187 |
+
c2w = to_hom_pose(curr_c2ws.float())
|
1188 |
+
w2c = torch.linalg.inv(c2w)
|
1189 |
+
|
1190 |
+
# camera centering
|
1191 |
+
ref_c2ws = all_c2ws
|
1192 |
+
camera_dist_2med = torch.norm(
|
1193 |
+
ref_c2ws[:, :3, 3] - ref_c2ws[:, :3, 3].median(0, keepdim=True).values,
|
1194 |
+
dim=-1,
|
1195 |
+
)
|
1196 |
+
valid_mask = camera_dist_2med <= torch.clamp(
|
1197 |
+
torch.quantile(camera_dist_2med, 0.97) * 10,
|
1198 |
+
max=1e6,
|
1199 |
+
)
|
1200 |
+
c2w[:, :3, 3] -= ref_c2ws[valid_mask, :3, 3].mean(0, keepdim=True)
|
1201 |
+
w2c = torch.linalg.inv(c2w)
|
1202 |
+
|
1203 |
+
# camera normalization
|
1204 |
+
camera_dists = c2w[:, :3, 3].clone()
|
1205 |
+
translation_scaling_factor = (
|
1206 |
+
camera_scale
|
1207 |
+
if torch.isclose(
|
1208 |
+
torch.norm(camera_dists[0]),
|
1209 |
+
torch.zeros(1),
|
1210 |
+
atol=1e-5,
|
1211 |
+
).any()
|
1212 |
+
else (camera_scale / torch.norm(camera_dists[0]))
|
1213 |
+
)
|
1214 |
+
w2c[:, :3, 3] *= translation_scaling_factor
|
1215 |
+
c2w[:, :3, 3] *= translation_scaling_factor
|
1216 |
+
value_dict["plucker_coordinate"], _ = get_plucker_coordinates(
|
1217 |
+
extrinsics_src=w2c[0],
|
1218 |
+
extrinsics=w2c,
|
1219 |
+
intrinsics=curr_Ks.float().clone(),
|
1220 |
+
mode="plucker",
|
1221 |
+
rel_zero_translation=True,
|
1222 |
+
target_size=(H // F, W // F),
|
1223 |
+
return_grid_cam=True,
|
1224 |
+
)
|
1225 |
+
|
1226 |
+
value_dict["c2w"] = c2w
|
1227 |
+
value_dict["K"] = curr_Ks
|
1228 |
+
value_dict["camera_mask"] = torch.zeros(T, dtype=torch.bool)
|
1229 |
+
value_dict["camera_mask"][curr_input_camera_indices] = True
|
1230 |
+
|
1231 |
+
return value_dict
|
1232 |
+
|
1233 |
+
|
1234 |
+
def do_sample(
|
1235 |
+
model,
|
1236 |
+
ae,
|
1237 |
+
conditioner,
|
1238 |
+
denoiser,
|
1239 |
+
sampler,
|
1240 |
+
value_dict,
|
1241 |
+
H,
|
1242 |
+
W,
|
1243 |
+
C,
|
1244 |
+
F,
|
1245 |
+
T,
|
1246 |
+
cfg,
|
1247 |
+
encoding_t=1,
|
1248 |
+
decoding_t=1,
|
1249 |
+
verbose=True,
|
1250 |
+
global_pbar=None,
|
1251 |
+
**_,
|
1252 |
+
):
|
1253 |
+
imgs = value_dict["cond_frames"].to("cuda")
|
1254 |
+
input_masks = value_dict["cond_frames_mask"].to("cuda")
|
1255 |
+
pluckers = value_dict["plucker_coordinate"].to("cuda")
|
1256 |
+
|
1257 |
+
num_samples = [1, T]
|
1258 |
+
with torch.inference_mode(), torch.autocast("cuda"):
|
1259 |
+
load_model(ae)
|
1260 |
+
load_model(conditioner)
|
1261 |
+
latents = torch.nn.functional.pad(
|
1262 |
+
ae.encode(imgs[input_masks], encoding_t), (0, 0, 0, 0, 0, 1), value=1.0
|
1263 |
+
)
|
1264 |
+
c_crossattn = repeat(conditioner(imgs[input_masks]).mean(0), "d -> n 1 d", n=T)
|
1265 |
+
uc_crossattn = torch.zeros_like(c_crossattn)
|
1266 |
+
c_replace = latents.new_zeros(T, *latents.shape[1:])
|
1267 |
+
c_replace[input_masks] = latents
|
1268 |
+
uc_replace = torch.zeros_like(c_replace)
|
1269 |
+
c_concat = torch.cat(
|
1270 |
+
[
|
1271 |
+
repeat(
|
1272 |
+
input_masks,
|
1273 |
+
"n -> n 1 h w",
|
1274 |
+
h=pluckers.shape[2],
|
1275 |
+
w=pluckers.shape[3],
|
1276 |
+
),
|
1277 |
+
pluckers,
|
1278 |
+
],
|
1279 |
+
1,
|
1280 |
+
)
|
1281 |
+
uc_concat = torch.cat(
|
1282 |
+
[pluckers.new_zeros(T, 1, *pluckers.shape[-2:]), pluckers], 1
|
1283 |
+
)
|
1284 |
+
c_dense_vector = pluckers
|
1285 |
+
uc_dense_vector = c_dense_vector
|
1286 |
+
c = {
|
1287 |
+
"crossattn": c_crossattn,
|
1288 |
+
"replace": c_replace,
|
1289 |
+
"concat": c_concat,
|
1290 |
+
"dense_vector": c_dense_vector,
|
1291 |
+
}
|
1292 |
+
uc = {
|
1293 |
+
"crossattn": uc_crossattn,
|
1294 |
+
"replace": uc_replace,
|
1295 |
+
"concat": uc_concat,
|
1296 |
+
"dense_vector": uc_dense_vector,
|
1297 |
+
}
|
1298 |
+
unload_model(ae)
|
1299 |
+
unload_model(conditioner)
|
1300 |
+
|
1301 |
+
additional_model_inputs = {"num_frames": T}
|
1302 |
+
additional_sampler_inputs = {
|
1303 |
+
"c2w": value_dict["c2w"].to("cuda"),
|
1304 |
+
"K": value_dict["K"].to("cuda"),
|
1305 |
+
"input_frame_mask": value_dict["cond_frames_mask"].to("cuda"),
|
1306 |
+
}
|
1307 |
+
if global_pbar is not None:
|
1308 |
+
additional_sampler_inputs["global_pbar"] = global_pbar
|
1309 |
+
|
1310 |
+
shape = (math.prod(num_samples), C, H // F, W // F)
|
1311 |
+
randn = torch.randn(shape).to("cuda")
|
1312 |
+
|
1313 |
+
load_model(model)
|
1314 |
+
samples_z = sampler(
|
1315 |
+
lambda input, sigma, c: denoiser(
|
1316 |
+
model,
|
1317 |
+
input,
|
1318 |
+
sigma,
|
1319 |
+
c,
|
1320 |
+
**additional_model_inputs,
|
1321 |
+
),
|
1322 |
+
randn,
|
1323 |
+
scale=cfg,
|
1324 |
+
cond=c,
|
1325 |
+
uc=uc,
|
1326 |
+
verbose=verbose,
|
1327 |
+
**additional_sampler_inputs,
|
1328 |
+
)
|
1329 |
+
if samples_z is None:
|
1330 |
+
return
|
1331 |
+
unload_model(model)
|
1332 |
+
|
1333 |
+
load_model(ae)
|
1334 |
+
samples = ae.decode(samples_z, decoding_t)
|
1335 |
+
unload_model(ae)
|
1336 |
+
|
1337 |
+
return samples
|
1338 |
+
|
1339 |
+
|
1340 |
+
def run_one_scene(
|
1341 |
+
task,
|
1342 |
+
version_dict,
|
1343 |
+
model,
|
1344 |
+
ae,
|
1345 |
+
conditioner,
|
1346 |
+
denoiser,
|
1347 |
+
image_cond,
|
1348 |
+
camera_cond,
|
1349 |
+
save_path,
|
1350 |
+
use_traj_prior,
|
1351 |
+
traj_prior_Ks,
|
1352 |
+
traj_prior_c2ws,
|
1353 |
+
seed=23,
|
1354 |
+
gradio=False,
|
1355 |
+
abort_event=None,
|
1356 |
+
first_pass_pbar=None,
|
1357 |
+
second_pass_pbar=None,
|
1358 |
+
):
|
1359 |
+
H, W, T, C, F, options = (
|
1360 |
+
version_dict["H"],
|
1361 |
+
version_dict["W"],
|
1362 |
+
version_dict["T"],
|
1363 |
+
version_dict["C"],
|
1364 |
+
version_dict["f"],
|
1365 |
+
version_dict["options"],
|
1366 |
+
)
|
1367 |
+
|
1368 |
+
if isinstance(image_cond, str):
|
1369 |
+
image_cond = {"img": [image_cond]}
|
1370 |
+
imgs_clip, imgs, img_size = [], [], None
|
1371 |
+
for i, (img, K) in enumerate(zip(image_cond["img"], camera_cond["K"])):
|
1372 |
+
if isinstance(img, str) or img is None:
|
1373 |
+
img, K = load_img_and_K(img or img_size, None, K=K, device="cpu") # type: ignore
|
1374 |
+
img_size = img.shape[-2:]
|
1375 |
+
if options.get("L_short", -1) == -1:
|
1376 |
+
img, K = transform_img_and_K(
|
1377 |
+
img,
|
1378 |
+
(W, H),
|
1379 |
+
K=K[None],
|
1380 |
+
mode=(
|
1381 |
+
options.get("transform_input", "crop")
|
1382 |
+
if i in image_cond["input_indices"]
|
1383 |
+
else options.get("transform_target", "crop")
|
1384 |
+
),
|
1385 |
+
scale=(
|
1386 |
+
1.0
|
1387 |
+
if i in image_cond["input_indices"]
|
1388 |
+
else options.get("transform_scale", 1.0)
|
1389 |
+
),
|
1390 |
+
)
|
1391 |
+
else:
|
1392 |
+
downsample = 3
|
1393 |
+
assert options["L_short"] % F * 2**downsample == 0, (
|
1394 |
+
"Short side of the image should be divisible by "
|
1395 |
+
f"F*2**{downsample}={F * 2**downsample}."
|
1396 |
+
)
|
1397 |
+
img, K = transform_img_and_K(
|
1398 |
+
img,
|
1399 |
+
options["L_short"],
|
1400 |
+
K=K[None],
|
1401 |
+
size_stride=F * 2**downsample,
|
1402 |
+
mode=(
|
1403 |
+
options.get("transform_input", "crop")
|
1404 |
+
if i in image_cond["input_indices"]
|
1405 |
+
else options.get("transform_target", "crop")
|
1406 |
+
),
|
1407 |
+
scale=(
|
1408 |
+
1.0
|
1409 |
+
if i in image_cond["input_indices"]
|
1410 |
+
else options.get("transform_scale", 1.0)
|
1411 |
+
),
|
1412 |
+
)
|
1413 |
+
version_dict["W"] = W = img.shape[-1]
|
1414 |
+
version_dict["H"] = H = img.shape[-2]
|
1415 |
+
K = K[0]
|
1416 |
+
K[0] /= W
|
1417 |
+
K[1] /= H
|
1418 |
+
camera_cond["K"][i] = K
|
1419 |
+
img_clip = img
|
1420 |
+
elif isinstance(img, np.ndarray):
|
1421 |
+
img_size = torch.Size(img.shape[:2])
|
1422 |
+
img = torch.as_tensor(img).permute(2, 0, 1)
|
1423 |
+
img = img.unsqueeze(0)
|
1424 |
+
img = img / 255.0 * 2.0 - 1.0
|
1425 |
+
if not gradio:
|
1426 |
+
img, K = transform_img_and_K(img, (W, H), K=K[None])
|
1427 |
+
assert K is not None
|
1428 |
+
K = K[0]
|
1429 |
+
K[0] /= W
|
1430 |
+
K[1] /= H
|
1431 |
+
camera_cond["K"][i] = K
|
1432 |
+
img_clip = img
|
1433 |
+
else:
|
1434 |
+
assert (
|
1435 |
+
False
|
1436 |
+
), f"Variable `img` got {type(img)} type which is not supported!!!"
|
1437 |
+
imgs_clip.append(img_clip)
|
1438 |
+
imgs.append(img)
|
1439 |
+
imgs_clip = torch.cat(imgs_clip, dim=0)
|
1440 |
+
imgs = torch.cat(imgs, dim=0)
|
1441 |
+
|
1442 |
+
if traj_prior_Ks is not None:
|
1443 |
+
assert img_size is not None
|
1444 |
+
for i, prior_k in enumerate(traj_prior_Ks):
|
1445 |
+
img, prior_k = load_img_and_K(img_size, None, K=prior_k, device="cpu") # type: ignore
|
1446 |
+
img, prior_k = transform_img_and_K(
|
1447 |
+
img,
|
1448 |
+
(W, H),
|
1449 |
+
K=prior_k[None],
|
1450 |
+
mode=options.get(
|
1451 |
+
"transform_target", "crop"
|
1452 |
+
), # mode for prior is always same as target
|
1453 |
+
scale=options.get(
|
1454 |
+
"transform_scale", 1.0
|
1455 |
+
), # scale for prior is always same as target
|
1456 |
+
)
|
1457 |
+
prior_k = prior_k[0]
|
1458 |
+
prior_k[0] /= W
|
1459 |
+
prior_k[1] /= H
|
1460 |
+
traj_prior_Ks[i] = prior_k
|
1461 |
+
|
1462 |
+
options["num_frames"] = T
|
1463 |
+
discretization = denoiser.discretization
|
1464 |
+
torch.cuda.empty_cache()
|
1465 |
+
|
1466 |
+
seed_everything(seed)
|
1467 |
+
|
1468 |
+
# Get Data
|
1469 |
+
input_indices = image_cond["input_indices"]
|
1470 |
+
input_imgs = imgs[input_indices]
|
1471 |
+
input_imgs_clip = imgs_clip[input_indices]
|
1472 |
+
input_c2ws = camera_cond["c2w"][input_indices]
|
1473 |
+
input_Ks = camera_cond["K"][input_indices]
|
1474 |
+
|
1475 |
+
test_indices = [i for i in range(len(imgs)) if i not in input_indices]
|
1476 |
+
test_imgs = imgs[test_indices]
|
1477 |
+
test_imgs_clip = imgs_clip[test_indices]
|
1478 |
+
test_c2ws = camera_cond["c2w"][test_indices]
|
1479 |
+
test_Ks = camera_cond["K"][test_indices]
|
1480 |
+
|
1481 |
+
if options.get("save_input", True):
|
1482 |
+
save_output(
|
1483 |
+
{"/image": input_imgs},
|
1484 |
+
save_path=os.path.join(save_path, "input"),
|
1485 |
+
video_save_fps=2,
|
1486 |
+
)
|
1487 |
+
|
1488 |
+
if not use_traj_prior:
|
1489 |
+
chunk_strategy = options.get("chunk_strategy", "gt")
|
1490 |
+
|
1491 |
+
(
|
1492 |
+
_,
|
1493 |
+
input_inds_per_chunk,
|
1494 |
+
input_sels_per_chunk,
|
1495 |
+
test_inds_per_chunk,
|
1496 |
+
test_sels_per_chunk,
|
1497 |
+
) = chunk_input_and_test(
|
1498 |
+
T,
|
1499 |
+
input_c2ws,
|
1500 |
+
test_c2ws,
|
1501 |
+
input_indices,
|
1502 |
+
test_indices,
|
1503 |
+
options=options,
|
1504 |
+
task=task,
|
1505 |
+
chunk_strategy=chunk_strategy,
|
1506 |
+
gt_input_inds=list(range(input_c2ws.shape[0])),
|
1507 |
+
)
|
1508 |
+
print(
|
1509 |
+
f"One pass - chunking with `{chunk_strategy}` strategy: total "
|
1510 |
+
f"{len(input_inds_per_chunk)} forward(s) ..."
|
1511 |
+
)
|
1512 |
+
|
1513 |
+
all_samples = {}
|
1514 |
+
all_test_inds = []
|
1515 |
+
for i, (
|
1516 |
+
chunk_input_inds,
|
1517 |
+
chunk_input_sels,
|
1518 |
+
chunk_test_inds,
|
1519 |
+
chunk_test_sels,
|
1520 |
+
) in tqdm(
|
1521 |
+
enumerate(
|
1522 |
+
zip(
|
1523 |
+
input_inds_per_chunk,
|
1524 |
+
input_sels_per_chunk,
|
1525 |
+
test_inds_per_chunk,
|
1526 |
+
test_sels_per_chunk,
|
1527 |
+
)
|
1528 |
+
),
|
1529 |
+
total=len(input_inds_per_chunk),
|
1530 |
+
leave=False,
|
1531 |
+
):
|
1532 |
+
(
|
1533 |
+
curr_input_sels,
|
1534 |
+
curr_test_sels,
|
1535 |
+
curr_input_maps,
|
1536 |
+
curr_test_maps,
|
1537 |
+
) = pad_indices(
|
1538 |
+
chunk_input_sels,
|
1539 |
+
chunk_test_sels,
|
1540 |
+
T=T,
|
1541 |
+
padding_mode=options.get("t_padding_mode", "last"),
|
1542 |
+
)
|
1543 |
+
curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [
|
1544 |
+
assemble(
|
1545 |
+
input=x[chunk_input_inds],
|
1546 |
+
test=y[chunk_test_inds],
|
1547 |
+
input_maps=curr_input_maps,
|
1548 |
+
test_maps=curr_test_maps,
|
1549 |
+
)
|
1550 |
+
for x, y in zip(
|
1551 |
+
[
|
1552 |
+
torch.cat(
|
1553 |
+
[
|
1554 |
+
input_imgs,
|
1555 |
+
get_k_from_dict(all_samples, "samples-rgb").to(
|
1556 |
+
input_imgs.device
|
1557 |
+
),
|
1558 |
+
],
|
1559 |
+
dim=0,
|
1560 |
+
),
|
1561 |
+
torch.cat(
|
1562 |
+
[
|
1563 |
+
input_imgs_clip,
|
1564 |
+
get_k_from_dict(all_samples, "samples-rgb").to(
|
1565 |
+
input_imgs.device
|
1566 |
+
),
|
1567 |
+
],
|
1568 |
+
dim=0,
|
1569 |
+
),
|
1570 |
+
torch.cat([input_c2ws, test_c2ws[all_test_inds]], dim=0),
|
1571 |
+
torch.cat([input_Ks, test_Ks[all_test_inds]], dim=0),
|
1572 |
+
], # procedually append generated prior views to the input views
|
1573 |
+
[test_imgs, test_imgs_clip, test_c2ws, test_Ks],
|
1574 |
+
)
|
1575 |
+
]
|
1576 |
+
value_dict = get_value_dict(
|
1577 |
+
curr_imgs.to("cuda"),
|
1578 |
+
curr_imgs_clip.to("cuda"),
|
1579 |
+
curr_input_sels
|
1580 |
+
+ [
|
1581 |
+
sel
|
1582 |
+
for (ind, sel) in zip(
|
1583 |
+
np.array(chunk_test_inds)[curr_test_maps[curr_test_maps != -1]],
|
1584 |
+
curr_test_sels,
|
1585 |
+
)
|
1586 |
+
if test_indices[ind] in image_cond["input_indices"]
|
1587 |
+
],
|
1588 |
+
curr_c2ws,
|
1589 |
+
curr_Ks,
|
1590 |
+
curr_input_sels
|
1591 |
+
+ [
|
1592 |
+
sel
|
1593 |
+
for (ind, sel) in zip(
|
1594 |
+
np.array(chunk_test_inds)[curr_test_maps[curr_test_maps != -1]],
|
1595 |
+
curr_test_sels,
|
1596 |
+
)
|
1597 |
+
if test_indices[ind] in camera_cond["input_indices"]
|
1598 |
+
],
|
1599 |
+
all_c2ws=camera_cond["c2w"],
|
1600 |
+
camera_scale=options.get("camera_scale", 2.0),
|
1601 |
+
)
|
1602 |
+
samplers = create_samplers(
|
1603 |
+
options["guider_types"],
|
1604 |
+
discretization,
|
1605 |
+
[len(curr_imgs)],
|
1606 |
+
options["num_steps"],
|
1607 |
+
options["cfg_min"],
|
1608 |
+
abort_event=abort_event,
|
1609 |
+
)
|
1610 |
+
assert len(samplers) == 1
|
1611 |
+
samples = do_sample(
|
1612 |
+
model,
|
1613 |
+
ae,
|
1614 |
+
conditioner,
|
1615 |
+
denoiser,
|
1616 |
+
samplers[0],
|
1617 |
+
value_dict,
|
1618 |
+
H,
|
1619 |
+
W,
|
1620 |
+
C,
|
1621 |
+
F,
|
1622 |
+
T=len(curr_imgs),
|
1623 |
+
cfg=(
|
1624 |
+
options["cfg"][0]
|
1625 |
+
if isinstance(options["cfg"], (list, tuple))
|
1626 |
+
else options["cfg"]
|
1627 |
+
),
|
1628 |
+
**{k: options[k] for k in options if k not in ["cfg", "T"]},
|
1629 |
+
)
|
1630 |
+
samples = decode_output(
|
1631 |
+
samples, len(curr_imgs), chunk_test_sels
|
1632 |
+
) # decode into dict
|
1633 |
+
if options.get("save_first_pass", False):
|
1634 |
+
save_output(
|
1635 |
+
replace_or_include_input_for_dict(
|
1636 |
+
samples,
|
1637 |
+
chunk_test_sels,
|
1638 |
+
curr_imgs,
|
1639 |
+
curr_c2ws,
|
1640 |
+
curr_Ks,
|
1641 |
+
),
|
1642 |
+
save_path=os.path.join(save_path, "first-pass", f"forward_{i}"),
|
1643 |
+
video_save_fps=2,
|
1644 |
+
)
|
1645 |
+
extend_dict(all_samples, samples)
|
1646 |
+
all_test_inds.extend(chunk_test_inds)
|
1647 |
+
else:
|
1648 |
+
assert traj_prior_c2ws is not None, (
|
1649 |
+
"`traj_prior_c2ws` should be set when using 2-pass sampling. One "
|
1650 |
+
"potential reason is that the amount of input frames is larger than "
|
1651 |
+
"T. Set `num_prior_frames` manually to overwrite the infered stats."
|
1652 |
+
)
|
1653 |
+
traj_prior_c2ws = torch.as_tensor(
|
1654 |
+
traj_prior_c2ws,
|
1655 |
+
device=input_c2ws.device,
|
1656 |
+
dtype=input_c2ws.dtype,
|
1657 |
+
)
|
1658 |
+
|
1659 |
+
if traj_prior_Ks is None:
|
1660 |
+
traj_prior_Ks = test_Ks[:1].repeat_interleave(
|
1661 |
+
traj_prior_c2ws.shape[0], dim=0
|
1662 |
+
)
|
1663 |
+
|
1664 |
+
traj_prior_imgs = imgs.new_zeros(traj_prior_c2ws.shape[0], *imgs.shape[1:])
|
1665 |
+
traj_prior_imgs_clip = imgs_clip.new_zeros(
|
1666 |
+
traj_prior_c2ws.shape[0], *imgs_clip.shape[1:]
|
1667 |
+
)
|
1668 |
+
|
1669 |
+
# ---------------------------------- first pass ----------------------------------
|
1670 |
+
T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
|
1671 |
+
T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
|
1672 |
+
chunk_strategy_first_pass = options.get(
|
1673 |
+
"chunk_strategy_first_pass", "gt-nearest"
|
1674 |
+
)
|
1675 |
+
(
|
1676 |
+
_,
|
1677 |
+
input_inds_per_chunk,
|
1678 |
+
input_sels_per_chunk,
|
1679 |
+
prior_inds_per_chunk,
|
1680 |
+
prior_sels_per_chunk,
|
1681 |
+
) = chunk_input_and_test(
|
1682 |
+
T_first_pass,
|
1683 |
+
input_c2ws,
|
1684 |
+
traj_prior_c2ws,
|
1685 |
+
input_indices,
|
1686 |
+
image_cond["prior_indices"],
|
1687 |
+
options=options,
|
1688 |
+
task=task,
|
1689 |
+
chunk_strategy=chunk_strategy_first_pass,
|
1690 |
+
gt_input_inds=list(range(input_c2ws.shape[0])),
|
1691 |
+
)
|
1692 |
+
print(
|
1693 |
+
f"Two passes (first) - chunking with `{chunk_strategy_first_pass}` strategy: total "
|
1694 |
+
f"{len(input_inds_per_chunk)} forward(s) ..."
|
1695 |
+
)
|
1696 |
+
|
1697 |
+
all_samples = {}
|
1698 |
+
all_prior_inds = []
|
1699 |
+
for i, (
|
1700 |
+
chunk_input_inds,
|
1701 |
+
chunk_input_sels,
|
1702 |
+
chunk_prior_inds,
|
1703 |
+
chunk_prior_sels,
|
1704 |
+
) in tqdm(
|
1705 |
+
enumerate(
|
1706 |
+
zip(
|
1707 |
+
input_inds_per_chunk,
|
1708 |
+
input_sels_per_chunk,
|
1709 |
+
prior_inds_per_chunk,
|
1710 |
+
prior_sels_per_chunk,
|
1711 |
+
)
|
1712 |
+
),
|
1713 |
+
total=len(input_inds_per_chunk),
|
1714 |
+
leave=False,
|
1715 |
+
):
|
1716 |
+
(
|
1717 |
+
curr_input_sels,
|
1718 |
+
curr_prior_sels,
|
1719 |
+
curr_input_maps,
|
1720 |
+
curr_prior_maps,
|
1721 |
+
) = pad_indices(
|
1722 |
+
chunk_input_sels,
|
1723 |
+
chunk_prior_sels,
|
1724 |
+
T=T_first_pass,
|
1725 |
+
padding_mode=options.get("t_padding_mode", "last"),
|
1726 |
+
)
|
1727 |
+
curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [
|
1728 |
+
assemble(
|
1729 |
+
input=x[chunk_input_inds],
|
1730 |
+
test=y[chunk_prior_inds],
|
1731 |
+
input_maps=curr_input_maps,
|
1732 |
+
test_maps=curr_prior_maps,
|
1733 |
+
)
|
1734 |
+
for x, y in zip(
|
1735 |
+
[
|
1736 |
+
torch.cat(
|
1737 |
+
[
|
1738 |
+
input_imgs,
|
1739 |
+
get_k_from_dict(all_samples, "samples-rgb").to(
|
1740 |
+
input_imgs.device
|
1741 |
+
),
|
1742 |
+
],
|
1743 |
+
dim=0,
|
1744 |
+
),
|
1745 |
+
torch.cat(
|
1746 |
+
[
|
1747 |
+
input_imgs_clip,
|
1748 |
+
get_k_from_dict(all_samples, "samples-rgb").to(
|
1749 |
+
input_imgs.device
|
1750 |
+
),
|
1751 |
+
],
|
1752 |
+
dim=0,
|
1753 |
+
),
|
1754 |
+
torch.cat([input_c2ws, traj_prior_c2ws[all_prior_inds]], dim=0),
|
1755 |
+
torch.cat([input_Ks, traj_prior_Ks[all_prior_inds]], dim=0),
|
1756 |
+
], # procedually append generated prior views to the input views
|
1757 |
+
[
|
1758 |
+
traj_prior_imgs,
|
1759 |
+
traj_prior_imgs_clip,
|
1760 |
+
traj_prior_c2ws,
|
1761 |
+
traj_prior_Ks,
|
1762 |
+
],
|
1763 |
+
)
|
1764 |
+
]
|
1765 |
+
value_dict = get_value_dict(
|
1766 |
+
curr_imgs.to("cuda"),
|
1767 |
+
curr_imgs_clip.to("cuda"),
|
1768 |
+
curr_input_sels,
|
1769 |
+
curr_c2ws,
|
1770 |
+
curr_Ks,
|
1771 |
+
list(range(T_first_pass)),
|
1772 |
+
all_c2ws=camera_cond["c2w"],
|
1773 |
+
camera_scale=options.get("camera_scale", 2.0),
|
1774 |
+
)
|
1775 |
+
samplers = create_samplers(
|
1776 |
+
options["guider_types"],
|
1777 |
+
discretization,
|
1778 |
+
[T_first_pass, T_second_pass],
|
1779 |
+
options["num_steps"],
|
1780 |
+
options["cfg_min"],
|
1781 |
+
abort_event=abort_event,
|
1782 |
+
)
|
1783 |
+
samples = do_sample(
|
1784 |
+
model,
|
1785 |
+
ae,
|
1786 |
+
conditioner,
|
1787 |
+
denoiser,
|
1788 |
+
(
|
1789 |
+
samplers[1]
|
1790 |
+
if len(samplers) > 1
|
1791 |
+
and options.get("ltr_first_pass", False)
|
1792 |
+
and chunk_strategy_first_pass != "gt"
|
1793 |
+
and i > 0
|
1794 |
+
else samplers[0]
|
1795 |
+
),
|
1796 |
+
value_dict,
|
1797 |
+
H,
|
1798 |
+
W,
|
1799 |
+
C,
|
1800 |
+
F,
|
1801 |
+
cfg=(
|
1802 |
+
options["cfg"][0]
|
1803 |
+
if isinstance(options["cfg"], (list, tuple))
|
1804 |
+
else options["cfg"]
|
1805 |
+
),
|
1806 |
+
T=T_first_pass,
|
1807 |
+
global_pbar=first_pass_pbar,
|
1808 |
+
**{k: options[k] for k in options if k not in ["cfg", "T", "sampler"]},
|
1809 |
+
)
|
1810 |
+
if samples is None:
|
1811 |
+
return
|
1812 |
+
samples = decode_output(
|
1813 |
+
samples, T_first_pass, chunk_prior_sels
|
1814 |
+
) # decode into dict
|
1815 |
+
extend_dict(all_samples, samples)
|
1816 |
+
all_prior_inds.extend(chunk_prior_inds)
|
1817 |
+
|
1818 |
+
if options.get("save_first_pass", True):
|
1819 |
+
save_output(
|
1820 |
+
all_samples,
|
1821 |
+
save_path=os.path.join(save_path, "first-pass"),
|
1822 |
+
video_save_fps=5,
|
1823 |
+
)
|
1824 |
+
video_path_0 = os.path.join(save_path, "first-pass", "samples-rgb.mp4")
|
1825 |
+
yield video_path_0
|
1826 |
+
|
1827 |
+
# ---------------------------------- second pass ----------------------------------
|
1828 |
+
prior_indices = image_cond["prior_indices"]
|
1829 |
+
assert (
|
1830 |
+
prior_indices is not None
|
1831 |
+
), "`prior_frame_indices` needs to be set if using 2-pass sampling."
|
1832 |
+
prior_argsort = np.argsort(input_indices + prior_indices).tolist()
|
1833 |
+
prior_indices = np.array(input_indices + prior_indices)[prior_argsort].tolist()
|
1834 |
+
gt_input_inds = [prior_argsort.index(i) for i in range(input_c2ws.shape[0])]
|
1835 |
+
|
1836 |
+
traj_prior_imgs = torch.cat(
|
1837 |
+
[input_imgs, get_k_from_dict(all_samples, "samples-rgb")], dim=0
|
1838 |
+
)[prior_argsort]
|
1839 |
+
traj_prior_imgs_clip = torch.cat(
|
1840 |
+
[
|
1841 |
+
input_imgs_clip,
|
1842 |
+
get_k_from_dict(all_samples, "samples-rgb"),
|
1843 |
+
],
|
1844 |
+
dim=0,
|
1845 |
+
)[prior_argsort]
|
1846 |
+
traj_prior_c2ws = torch.cat([input_c2ws, traj_prior_c2ws], dim=0)[prior_argsort]
|
1847 |
+
traj_prior_Ks = torch.cat([input_Ks, traj_prior_Ks], dim=0)[prior_argsort]
|
1848 |
+
|
1849 |
+
update_kv_for_dict(all_samples, "samples-rgb", traj_prior_imgs)
|
1850 |
+
update_kv_for_dict(all_samples, "samples-c2ws", traj_prior_c2ws)
|
1851 |
+
update_kv_for_dict(all_samples, "samples-intrinsics", traj_prior_Ks)
|
1852 |
+
|
1853 |
+
chunk_strategy = options.get("chunk_strategy", "nearest")
|
1854 |
+
(
|
1855 |
+
_,
|
1856 |
+
prior_inds_per_chunk,
|
1857 |
+
prior_sels_per_chunk,
|
1858 |
+
test_inds_per_chunk,
|
1859 |
+
test_sels_per_chunk,
|
1860 |
+
) = chunk_input_and_test(
|
1861 |
+
T_second_pass,
|
1862 |
+
traj_prior_c2ws,
|
1863 |
+
test_c2ws,
|
1864 |
+
prior_indices,
|
1865 |
+
test_indices,
|
1866 |
+
options=options,
|
1867 |
+
task=task,
|
1868 |
+
chunk_strategy=chunk_strategy,
|
1869 |
+
gt_input_inds=gt_input_inds,
|
1870 |
+
)
|
1871 |
+
print(
|
1872 |
+
f"Two passes (second) - chunking with `{chunk_strategy}` strategy: total "
|
1873 |
+
f"{len(prior_inds_per_chunk)} forward(s) ..."
|
1874 |
+
)
|
1875 |
+
|
1876 |
+
all_samples = {}
|
1877 |
+
all_test_inds = []
|
1878 |
+
for i, (
|
1879 |
+
chunk_prior_inds,
|
1880 |
+
chunk_prior_sels,
|
1881 |
+
chunk_test_inds,
|
1882 |
+
chunk_test_sels,
|
1883 |
+
) in tqdm(
|
1884 |
+
enumerate(
|
1885 |
+
zip(
|
1886 |
+
prior_inds_per_chunk,
|
1887 |
+
prior_sels_per_chunk,
|
1888 |
+
test_inds_per_chunk,
|
1889 |
+
test_sels_per_chunk,
|
1890 |
+
)
|
1891 |
+
),
|
1892 |
+
total=len(prior_inds_per_chunk),
|
1893 |
+
leave=False,
|
1894 |
+
):
|
1895 |
+
(
|
1896 |
+
curr_prior_sels,
|
1897 |
+
curr_test_sels,
|
1898 |
+
curr_prior_maps,
|
1899 |
+
curr_test_maps,
|
1900 |
+
) = pad_indices(
|
1901 |
+
chunk_prior_sels,
|
1902 |
+
chunk_test_sels,
|
1903 |
+
T=T_second_pass,
|
1904 |
+
padding_mode="last",
|
1905 |
+
)
|
1906 |
+
curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [
|
1907 |
+
assemble(
|
1908 |
+
input=x[chunk_prior_inds],
|
1909 |
+
test=y[chunk_test_inds],
|
1910 |
+
input_maps=curr_prior_maps,
|
1911 |
+
test_maps=curr_test_maps,
|
1912 |
+
)
|
1913 |
+
for x, y in zip(
|
1914 |
+
[
|
1915 |
+
traj_prior_imgs,
|
1916 |
+
traj_prior_imgs_clip,
|
1917 |
+
traj_prior_c2ws,
|
1918 |
+
traj_prior_Ks,
|
1919 |
+
],
|
1920 |
+
[test_imgs, test_imgs_clip, test_c2ws, test_Ks],
|
1921 |
+
)
|
1922 |
+
]
|
1923 |
+
value_dict = get_value_dict(
|
1924 |
+
curr_imgs.to("cuda"),
|
1925 |
+
curr_imgs_clip.to("cuda"),
|
1926 |
+
curr_prior_sels,
|
1927 |
+
curr_c2ws,
|
1928 |
+
curr_Ks,
|
1929 |
+
list(range(T_second_pass)),
|
1930 |
+
all_c2ws=camera_cond["c2w"],
|
1931 |
+
camera_scale=options.get("camera_scale", 2.0),
|
1932 |
+
)
|
1933 |
+
samples = do_sample(
|
1934 |
+
model,
|
1935 |
+
ae,
|
1936 |
+
conditioner,
|
1937 |
+
denoiser,
|
1938 |
+
samplers[1] if len(samplers) > 1 else samplers[0],
|
1939 |
+
value_dict,
|
1940 |
+
H,
|
1941 |
+
W,
|
1942 |
+
C,
|
1943 |
+
F,
|
1944 |
+
T=T_second_pass,
|
1945 |
+
cfg=(
|
1946 |
+
options["cfg"][1]
|
1947 |
+
if isinstance(options["cfg"], (list, tuple))
|
1948 |
+
and len(options["cfg"]) > 1
|
1949 |
+
else options["cfg"]
|
1950 |
+
),
|
1951 |
+
global_pbar=second_pass_pbar,
|
1952 |
+
**{k: options[k] for k in options if k not in ["cfg", "T", "sampler"]},
|
1953 |
+
)
|
1954 |
+
if samples is None:
|
1955 |
+
return
|
1956 |
+
samples = decode_output(
|
1957 |
+
samples, T_second_pass, chunk_test_sels
|
1958 |
+
) # decode into dict
|
1959 |
+
if options.get("save_second_pass", False):
|
1960 |
+
save_output(
|
1961 |
+
replace_or_include_input_for_dict(
|
1962 |
+
samples,
|
1963 |
+
chunk_test_sels,
|
1964 |
+
curr_imgs,
|
1965 |
+
curr_c2ws,
|
1966 |
+
curr_Ks,
|
1967 |
+
),
|
1968 |
+
save_path=os.path.join(save_path, "second-pass", f"forward_{i}"),
|
1969 |
+
video_save_fps=2,
|
1970 |
+
)
|
1971 |
+
extend_dict(all_samples, samples)
|
1972 |
+
all_test_inds.extend(chunk_test_inds)
|
1973 |
+
all_samples = {
|
1974 |
+
key: value[np.argsort(all_test_inds)] for key, value in all_samples.items()
|
1975 |
+
}
|
1976 |
+
save_output(
|
1977 |
+
replace_or_include_input_for_dict(
|
1978 |
+
all_samples,
|
1979 |
+
test_indices,
|
1980 |
+
imgs.clone(),
|
1981 |
+
camera_cond["c2w"].clone(),
|
1982 |
+
camera_cond["K"].clone(),
|
1983 |
+
)
|
1984 |
+
if options.get("replace_or_include_input", False)
|
1985 |
+
else all_samples,
|
1986 |
+
save_path=save_path,
|
1987 |
+
video_save_fps=options.get("video_save_fps", 2),
|
1988 |
+
)
|
1989 |
+
video_path_1 = os.path.join(save_path, "samples-rgb.mp4")
|
1990 |
+
yield video_path_1
|
seva/geometry.py
ADDED
@@ -0,0 +1,811 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Literal
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import roma
|
5 |
+
import scipy.interpolate
|
6 |
+
import torch
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
DEFAULT_FOV_RAD = 0.9424777960769379 # 54 degrees by default
|
10 |
+
|
11 |
+
|
12 |
+
def get_camera_dist(
|
13 |
+
source_c2ws: torch.Tensor, # N x 3 x 4
|
14 |
+
target_c2ws: torch.Tensor, # M x 3 x 4
|
15 |
+
mode: str = "translation",
|
16 |
+
):
|
17 |
+
if mode == "rotation":
|
18 |
+
dists = torch.acos(
|
19 |
+
(
|
20 |
+
(
|
21 |
+
torch.matmul(
|
22 |
+
source_c2ws[:, None, :3, :3],
|
23 |
+
target_c2ws[None, :, :3, :3].transpose(-1, -2),
|
24 |
+
)
|
25 |
+
.diagonal(offset=0, dim1=-2, dim2=-1)
|
26 |
+
.sum(-1)
|
27 |
+
- 1
|
28 |
+
)
|
29 |
+
/ 2
|
30 |
+
).clamp(-1, 1)
|
31 |
+
) * (180 / torch.pi)
|
32 |
+
elif mode == "translation":
|
33 |
+
dists = torch.norm(
|
34 |
+
source_c2ws[:, None, :3, 3] - target_c2ws[None, :, :3, 3], dim=-1
|
35 |
+
)
|
36 |
+
else:
|
37 |
+
raise NotImplementedError(
|
38 |
+
f"Mode {mode} is not implemented for finding nearest source indices."
|
39 |
+
)
|
40 |
+
return dists
|
41 |
+
|
42 |
+
|
43 |
+
def to_hom(X):
|
44 |
+
# get homogeneous coordinates of the input
|
45 |
+
X_hom = torch.cat([X, torch.ones_like(X[..., :1])], dim=-1)
|
46 |
+
return X_hom
|
47 |
+
|
48 |
+
|
49 |
+
def to_hom_pose(pose):
|
50 |
+
# get homogeneous coordinates of the input pose
|
51 |
+
if pose.shape[-2:] == (3, 4):
|
52 |
+
pose_hom = torch.eye(4, device=pose.device)[None].repeat(pose.shape[0], 1, 1)
|
53 |
+
pose_hom[:, :3, :] = pose
|
54 |
+
return pose_hom
|
55 |
+
return pose
|
56 |
+
|
57 |
+
|
58 |
+
def get_default_intrinsics(
|
59 |
+
fov_rad=DEFAULT_FOV_RAD,
|
60 |
+
aspect_ratio=1.0,
|
61 |
+
):
|
62 |
+
if not isinstance(fov_rad, torch.Tensor):
|
63 |
+
fov_rad = torch.tensor(
|
64 |
+
[fov_rad] if isinstance(fov_rad, (int, float)) else fov_rad
|
65 |
+
)
|
66 |
+
if aspect_ratio >= 1.0: # W >= H
|
67 |
+
focal_x = 0.5 / torch.tan(0.5 * fov_rad)
|
68 |
+
focal_y = focal_x * aspect_ratio
|
69 |
+
else: # W < H
|
70 |
+
focal_y = 0.5 / torch.tan(0.5 * fov_rad)
|
71 |
+
focal_x = focal_y / aspect_ratio
|
72 |
+
intrinsics = focal_x.new_zeros((focal_x.shape[0], 3, 3))
|
73 |
+
intrinsics[:, torch.eye(3, device=focal_x.device, dtype=bool)] = torch.stack(
|
74 |
+
[focal_x, focal_y, torch.ones_like(focal_x)], dim=-1
|
75 |
+
)
|
76 |
+
intrinsics[:, :, -1] = torch.tensor(
|
77 |
+
[0.5, 0.5, 1.0], device=focal_x.device, dtype=focal_x.dtype
|
78 |
+
)
|
79 |
+
return intrinsics
|
80 |
+
|
81 |
+
|
82 |
+
def get_image_grid(img_h, img_w):
|
83 |
+
# add 0.5 is VERY important especially when your img_h and img_w
|
84 |
+
# is not very large (e.g., 72)!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
|
85 |
+
y_range = torch.arange(img_h, dtype=torch.float32).add_(0.5)
|
86 |
+
x_range = torch.arange(img_w, dtype=torch.float32).add_(0.5)
|
87 |
+
Y, X = torch.meshgrid(y_range, x_range, indexing="ij") # [H,W]
|
88 |
+
xy_grid = torch.stack([X, Y], dim=-1).view(-1, 2) # [HW,2]
|
89 |
+
return to_hom(xy_grid) # [HW,3]
|
90 |
+
|
91 |
+
|
92 |
+
def img2cam(X, cam_intr):
|
93 |
+
return X @ cam_intr.inverse().transpose(-1, -2)
|
94 |
+
|
95 |
+
|
96 |
+
def cam2world(X, pose):
|
97 |
+
X_hom = to_hom(X)
|
98 |
+
pose_inv = torch.linalg.inv(to_hom_pose(pose))[..., :3, :4]
|
99 |
+
return X_hom @ pose_inv.transpose(-1, -2)
|
100 |
+
|
101 |
+
|
102 |
+
def get_center_and_ray(
|
103 |
+
img_h, img_w, pose, intr, zero_center_for_debugging=False
|
104 |
+
): # [HW,2]
|
105 |
+
# given the intrinsic/extrinsic matrices, get the camera center and ray directions]
|
106 |
+
# assert(opt.camera.model=="perspective")
|
107 |
+
|
108 |
+
# compute center and ray
|
109 |
+
grid_img = get_image_grid(img_h, img_w) # [HW,3]
|
110 |
+
grid_3D_cam = img2cam(grid_img.to(intr.device), intr.float()) # [B,HW,3]
|
111 |
+
center_3D_cam = torch.zeros_like(grid_3D_cam) # [B,HW,3]
|
112 |
+
|
113 |
+
# transform from camera to world coordinates
|
114 |
+
grid_3D = cam2world(grid_3D_cam, pose) # [B,HW,3]
|
115 |
+
center_3D = cam2world(center_3D_cam, pose) # [B,HW,3]
|
116 |
+
ray = grid_3D - center_3D # [B,HW,3]
|
117 |
+
|
118 |
+
return center_3D_cam if zero_center_for_debugging else center_3D, ray, grid_3D_cam
|
119 |
+
|
120 |
+
|
121 |
+
def get_plucker_coordinates(
|
122 |
+
extrinsics_src,
|
123 |
+
extrinsics,
|
124 |
+
intrinsics=None,
|
125 |
+
fov_rad=DEFAULT_FOV_RAD,
|
126 |
+
mode="plucker",
|
127 |
+
rel_zero_translation=True,
|
128 |
+
zero_center_for_debugging=False,
|
129 |
+
target_size=[72, 72], # 576-size image
|
130 |
+
return_grid_cam=False, # save for later use if want restore
|
131 |
+
):
|
132 |
+
if intrinsics is None:
|
133 |
+
intrinsics = get_default_intrinsics(fov_rad).to(extrinsics.device)
|
134 |
+
else:
|
135 |
+
# for some data preprocessed in the early stage (e.g., MVI and CO3D),
|
136 |
+
# intrinsics are expressed in raw pixel space (e.g., 576x576) instead
|
137 |
+
# of normalized image coordinates
|
138 |
+
if not (
|
139 |
+
torch.all(intrinsics[:, :2, -1] >= 0)
|
140 |
+
and torch.all(intrinsics[:, :2, -1] <= 1)
|
141 |
+
):
|
142 |
+
intrinsics[:, :2] /= intrinsics.new_tensor(target_size).view(1, -1, 1) * 8
|
143 |
+
# you should ensure the intrisics are expressed in
|
144 |
+
# resolution-independent normalized image coordinates just performing a
|
145 |
+
# very simple verification here checking if principal points are
|
146 |
+
# between 0 and 1
|
147 |
+
assert (
|
148 |
+
torch.all(intrinsics[:, :2, -1] >= 0)
|
149 |
+
and torch.all(intrinsics[:, :2, -1] <= 1)
|
150 |
+
), "Intrinsics should be expressed in resolution-independent normalized image coordinates."
|
151 |
+
|
152 |
+
c2w_src = torch.linalg.inv(extrinsics_src)
|
153 |
+
if not rel_zero_translation:
|
154 |
+
c2w_src[:3, 3] = c2w_src[3, :3] = 0.0
|
155 |
+
# transform coordinates from the source camera's coordinate system to the coordinate system of the respective camera
|
156 |
+
extrinsics_rel = torch.einsum(
|
157 |
+
"vnm,vmp->vnp", extrinsics, c2w_src[None].repeat(extrinsics.shape[0], 1, 1)
|
158 |
+
)
|
159 |
+
|
160 |
+
intrinsics[:, :2] *= extrinsics.new_tensor(
|
161 |
+
[
|
162 |
+
target_size[1], # w
|
163 |
+
target_size[0], # h
|
164 |
+
]
|
165 |
+
).view(1, -1, 1)
|
166 |
+
centers, rays, grid_cam = get_center_and_ray(
|
167 |
+
img_h=target_size[0],
|
168 |
+
img_w=target_size[1],
|
169 |
+
pose=extrinsics_rel[:, :3, :],
|
170 |
+
intr=intrinsics,
|
171 |
+
zero_center_for_debugging=zero_center_for_debugging,
|
172 |
+
)
|
173 |
+
|
174 |
+
if mode == "plucker" or "v1" in mode:
|
175 |
+
rays = torch.nn.functional.normalize(rays, dim=-1)
|
176 |
+
plucker = torch.cat((rays, torch.cross(centers, rays, dim=-1)), dim=-1)
|
177 |
+
else:
|
178 |
+
raise ValueError(f"Unknown Plucker coordinate mode: {mode}")
|
179 |
+
|
180 |
+
plucker = plucker.permute(0, 2, 1).reshape(plucker.shape[0], -1, *target_size)
|
181 |
+
if return_grid_cam:
|
182 |
+
return plucker, grid_cam.reshape(-1, *target_size, 3)
|
183 |
+
return plucker
|
184 |
+
|
185 |
+
|
186 |
+
def rt_to_mat4(
|
187 |
+
R: torch.Tensor, t: torch.Tensor, s: torch.Tensor | None = None
|
188 |
+
) -> torch.Tensor:
|
189 |
+
"""
|
190 |
+
Args:
|
191 |
+
R (torch.Tensor): (..., 3, 3).
|
192 |
+
t (torch.Tensor): (..., 3).
|
193 |
+
s (torch.Tensor): (...,).
|
194 |
+
|
195 |
+
Returns:
|
196 |
+
torch.Tensor: (..., 4, 4)
|
197 |
+
"""
|
198 |
+
mat34 = torch.cat([R, t[..., None]], dim=-1)
|
199 |
+
if s is None:
|
200 |
+
bottom = (
|
201 |
+
mat34.new_tensor([[0.0, 0.0, 0.0, 1.0]])
|
202 |
+
.reshape((1,) * (mat34.dim() - 2) + (1, 4))
|
203 |
+
.expand(mat34.shape[:-2] + (1, 4))
|
204 |
+
)
|
205 |
+
else:
|
206 |
+
bottom = F.pad(1.0 / s[..., None, None], (3, 0), value=0.0)
|
207 |
+
mat4 = torch.cat([mat34, bottom], dim=-2)
|
208 |
+
return mat4
|
209 |
+
|
210 |
+
|
211 |
+
def get_preset_pose_fov(
|
212 |
+
option: Literal[
|
213 |
+
"orbit",
|
214 |
+
"spiral",
|
215 |
+
"lemniscate",
|
216 |
+
"zoom-in",
|
217 |
+
"zoom-out",
|
218 |
+
"dolly zoom-in",
|
219 |
+
"dolly zoom-out",
|
220 |
+
"move-forward",
|
221 |
+
"move-backward",
|
222 |
+
"move-up",
|
223 |
+
"move-down",
|
224 |
+
"move-left",
|
225 |
+
"move-right",
|
226 |
+
"roll",
|
227 |
+
],
|
228 |
+
num_frames: int,
|
229 |
+
start_w2c: torch.Tensor,
|
230 |
+
look_at: torch.Tensor,
|
231 |
+
up_direction: torch.Tensor | None = None,
|
232 |
+
fov: float = DEFAULT_FOV_RAD,
|
233 |
+
spiral_radii: list[float] = [0.5, 0.5, 0.2],
|
234 |
+
zoom_factor: float | None = None,
|
235 |
+
):
|
236 |
+
poses = fovs = None
|
237 |
+
if option == "orbit":
|
238 |
+
poses = torch.linalg.inv(
|
239 |
+
get_arc_horizontal_w2cs(
|
240 |
+
start_w2c,
|
241 |
+
look_at,
|
242 |
+
up_direction,
|
243 |
+
num_frames=num_frames,
|
244 |
+
endpoint=False,
|
245 |
+
)
|
246 |
+
).numpy()
|
247 |
+
fovs = np.full((num_frames,), fov)
|
248 |
+
elif option == "spiral":
|
249 |
+
poses = generate_spiral_path(
|
250 |
+
torch.linalg.inv(start_w2c)[None].numpy() @ np.diagflat([1, -1, -1, 1]),
|
251 |
+
np.array([1, 5]),
|
252 |
+
n_frames=num_frames,
|
253 |
+
n_rots=2,
|
254 |
+
zrate=0.5,
|
255 |
+
radii=spiral_radii,
|
256 |
+
endpoint=False,
|
257 |
+
) @ np.diagflat([1, -1, -1, 1])
|
258 |
+
poses = np.concatenate(
|
259 |
+
[
|
260 |
+
poses,
|
261 |
+
np.array([0.0, 0.0, 0.0, 1.0])[None, None].repeat(len(poses), 0),
|
262 |
+
],
|
263 |
+
1,
|
264 |
+
)
|
265 |
+
# We want the spiral trajectory to always start from start_w2c. Thus we
|
266 |
+
# apply the relative pose to get the final trajectory.
|
267 |
+
poses = (
|
268 |
+
np.linalg.inv(start_w2c.numpy())[None] @ np.linalg.inv(poses[:1]) @ poses
|
269 |
+
)
|
270 |
+
fovs = np.full((num_frames,), fov)
|
271 |
+
elif option == "lemniscate":
|
272 |
+
poses = torch.linalg.inv(
|
273 |
+
get_lemniscate_w2cs(
|
274 |
+
start_w2c,
|
275 |
+
look_at,
|
276 |
+
up_direction,
|
277 |
+
num_frames,
|
278 |
+
degree=60.0,
|
279 |
+
endpoint=False,
|
280 |
+
)
|
281 |
+
).numpy()
|
282 |
+
fovs = np.full((num_frames,), fov)
|
283 |
+
elif option == "roll":
|
284 |
+
poses = torch.linalg.inv(
|
285 |
+
get_roll_w2cs(
|
286 |
+
start_w2c,
|
287 |
+
look_at,
|
288 |
+
None,
|
289 |
+
num_frames,
|
290 |
+
degree=360.0,
|
291 |
+
endpoint=False,
|
292 |
+
)
|
293 |
+
).numpy()
|
294 |
+
fovs = np.full((num_frames,), fov)
|
295 |
+
elif option in [
|
296 |
+
"dolly zoom-in",
|
297 |
+
"dolly zoom-out",
|
298 |
+
"zoom-in",
|
299 |
+
"zoom-out",
|
300 |
+
]:
|
301 |
+
if option.startswith("dolly"):
|
302 |
+
direction = "backward" if option == "dolly zoom-in" else "forward"
|
303 |
+
poses = torch.linalg.inv(
|
304 |
+
get_moving_w2cs(
|
305 |
+
start_w2c,
|
306 |
+
look_at,
|
307 |
+
up_direction,
|
308 |
+
num_frames,
|
309 |
+
endpoint=True,
|
310 |
+
direction=direction,
|
311 |
+
)
|
312 |
+
).numpy()
|
313 |
+
else:
|
314 |
+
poses = torch.linalg.inv(start_w2c)[None].repeat(num_frames, 1, 1).numpy()
|
315 |
+
fov_rad_start = fov
|
316 |
+
if zoom_factor is None:
|
317 |
+
zoom_factor = 0.28 if option.endswith("zoom-in") else 1.5
|
318 |
+
fov_rad_end = zoom_factor * fov
|
319 |
+
fovs = (
|
320 |
+
np.linspace(0, 1, num_frames) * (fov_rad_end - fov_rad_start)
|
321 |
+
+ fov_rad_start
|
322 |
+
)
|
323 |
+
elif option in [
|
324 |
+
"move-forward",
|
325 |
+
"move-backward",
|
326 |
+
"move-up",
|
327 |
+
"move-down",
|
328 |
+
"move-left",
|
329 |
+
"move-right",
|
330 |
+
]:
|
331 |
+
poses = torch.linalg.inv(
|
332 |
+
get_moving_w2cs(
|
333 |
+
start_w2c,
|
334 |
+
look_at,
|
335 |
+
up_direction,
|
336 |
+
num_frames,
|
337 |
+
endpoint=True,
|
338 |
+
direction=option.removeprefix("move-"),
|
339 |
+
)
|
340 |
+
).numpy()
|
341 |
+
fovs = np.full((num_frames,), fov)
|
342 |
+
else:
|
343 |
+
raise ValueError(f"Unknown preset option {option}.")
|
344 |
+
|
345 |
+
return poses, fovs
|
346 |
+
|
347 |
+
|
348 |
+
def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor:
|
349 |
+
"""Triangulate a set of rays to find a single lookat point.
|
350 |
+
|
351 |
+
Args:
|
352 |
+
origins (torch.Tensor): A (N, 3) array of ray origins.
|
353 |
+
viewdirs (torch.Tensor): A (N, 3) array of ray view directions.
|
354 |
+
|
355 |
+
Returns:
|
356 |
+
torch.Tensor: A (3,) lookat point.
|
357 |
+
"""
|
358 |
+
|
359 |
+
viewdirs = torch.nn.functional.normalize(viewdirs, dim=-1)
|
360 |
+
eye = torch.eye(3, device=origins.device, dtype=origins.dtype)[None]
|
361 |
+
# Calculate projection matrix I - rr^T
|
362 |
+
I_min_cov = eye - (viewdirs[..., None] * viewdirs[..., None, :])
|
363 |
+
# Compute sum of projections
|
364 |
+
sum_proj = I_min_cov.matmul(origins[..., None]).sum(dim=-3)
|
365 |
+
# Solve for the intersection point using least squares
|
366 |
+
lookat = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
|
367 |
+
# Check NaNs.
|
368 |
+
assert not torch.any(torch.isnan(lookat))
|
369 |
+
return lookat
|
370 |
+
|
371 |
+
|
372 |
+
def get_lookat_w2cs(
|
373 |
+
positions: torch.Tensor,
|
374 |
+
lookat: torch.Tensor,
|
375 |
+
up: torch.Tensor,
|
376 |
+
face_off: bool = False,
|
377 |
+
):
|
378 |
+
"""
|
379 |
+
Args:
|
380 |
+
positions: (N, 3) tensor of camera positions
|
381 |
+
lookat: (3,) tensor of lookat point
|
382 |
+
up: (3,) or (N, 3) tensor of up vector
|
383 |
+
|
384 |
+
Returns:
|
385 |
+
w2cs: (N, 3, 3) tensor of world to camera rotation matrices
|
386 |
+
"""
|
387 |
+
forward_vectors = F.normalize(lookat - positions, dim=-1)
|
388 |
+
if face_off:
|
389 |
+
forward_vectors = -forward_vectors
|
390 |
+
if up.dim() == 1:
|
391 |
+
up = up[None]
|
392 |
+
right_vectors = F.normalize(torch.cross(forward_vectors, up, dim=-1), dim=-1)
|
393 |
+
down_vectors = F.normalize(
|
394 |
+
torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1
|
395 |
+
)
|
396 |
+
Rs = torch.stack([right_vectors, down_vectors, forward_vectors], dim=-1)
|
397 |
+
w2cs = torch.linalg.inv(rt_to_mat4(Rs, positions))
|
398 |
+
return w2cs
|
399 |
+
|
400 |
+
|
401 |
+
def get_arc_horizontal_w2cs(
|
402 |
+
ref_w2c: torch.Tensor,
|
403 |
+
lookat: torch.Tensor,
|
404 |
+
up: torch.Tensor | None,
|
405 |
+
num_frames: int,
|
406 |
+
clockwise: bool = True,
|
407 |
+
face_off: bool = False,
|
408 |
+
endpoint: bool = False,
|
409 |
+
degree: float = 360.0,
|
410 |
+
ref_up_shift: float = 0.0,
|
411 |
+
ref_radius_scale: float = 1.0,
|
412 |
+
**_,
|
413 |
+
) -> torch.Tensor:
|
414 |
+
ref_c2w = torch.linalg.inv(ref_w2c)
|
415 |
+
ref_position = ref_c2w[:3, 3]
|
416 |
+
if up is None:
|
417 |
+
up = -ref_c2w[:3, 1]
|
418 |
+
assert up is not None
|
419 |
+
ref_position += up * ref_up_shift
|
420 |
+
ref_position *= ref_radius_scale
|
421 |
+
thetas = (
|
422 |
+
torch.linspace(0.0, torch.pi * degree / 180, num_frames, device=ref_w2c.device)
|
423 |
+
if endpoint
|
424 |
+
else torch.linspace(
|
425 |
+
0.0, torch.pi * degree / 180, num_frames + 1, device=ref_w2c.device
|
426 |
+
)[:-1]
|
427 |
+
)
|
428 |
+
if not clockwise:
|
429 |
+
thetas = -thetas
|
430 |
+
positions = (
|
431 |
+
torch.einsum(
|
432 |
+
"nij,j->ni",
|
433 |
+
roma.rotvec_to_rotmat(thetas[:, None] * up[None]),
|
434 |
+
ref_position - lookat,
|
435 |
+
)
|
436 |
+
+ lookat
|
437 |
+
)
|
438 |
+
return get_lookat_w2cs(positions, lookat, up, face_off=face_off)
|
439 |
+
|
440 |
+
|
441 |
+
def get_lemniscate_w2cs(
|
442 |
+
ref_w2c: torch.Tensor,
|
443 |
+
lookat: torch.Tensor,
|
444 |
+
up: torch.Tensor | None,
|
445 |
+
num_frames: int,
|
446 |
+
degree: float,
|
447 |
+
endpoint: bool = False,
|
448 |
+
**_,
|
449 |
+
) -> torch.Tensor:
|
450 |
+
ref_c2w = torch.linalg.inv(ref_w2c)
|
451 |
+
a = torch.linalg.norm(ref_c2w[:3, 3] - lookat) * np.tan(degree / 360 * np.pi)
|
452 |
+
# Lemniscate curve in camera space. Starting at the origin.
|
453 |
+
thetas = (
|
454 |
+
torch.linspace(0, 2 * torch.pi, num_frames, device=ref_w2c.device)
|
455 |
+
if endpoint
|
456 |
+
else torch.linspace(0, 2 * torch.pi, num_frames + 1, device=ref_w2c.device)[:-1]
|
457 |
+
) + torch.pi / 2
|
458 |
+
positions = torch.stack(
|
459 |
+
[
|
460 |
+
a * torch.cos(thetas) / (1 + torch.sin(thetas) ** 2),
|
461 |
+
a * torch.cos(thetas) * torch.sin(thetas) / (1 + torch.sin(thetas) ** 2),
|
462 |
+
torch.zeros(num_frames, device=ref_w2c.device),
|
463 |
+
],
|
464 |
+
dim=-1,
|
465 |
+
)
|
466 |
+
# Transform to world space.
|
467 |
+
positions = torch.einsum(
|
468 |
+
"ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)
|
469 |
+
)
|
470 |
+
if up is None:
|
471 |
+
up = -ref_c2w[:3, 1]
|
472 |
+
assert up is not None
|
473 |
+
return get_lookat_w2cs(positions, lookat, up)
|
474 |
+
|
475 |
+
|
476 |
+
def get_moving_w2cs(
|
477 |
+
ref_w2c: torch.Tensor,
|
478 |
+
lookat: torch.Tensor,
|
479 |
+
up: torch.Tensor | None,
|
480 |
+
num_frames: int,
|
481 |
+
endpoint: bool = False,
|
482 |
+
direction: str = "forward",
|
483 |
+
tilt_xy: torch.Tensor = None,
|
484 |
+
):
|
485 |
+
"""
|
486 |
+
Args:
|
487 |
+
ref_w2c: (4, 4) tensor of the reference wolrd-to-camera matrix
|
488 |
+
lookat: (3,) tensor of lookat point
|
489 |
+
up: (3,) tensor of up vector
|
490 |
+
|
491 |
+
Returns:
|
492 |
+
w2cs: (N, 3, 3) tensor of world to camera rotation matrices
|
493 |
+
"""
|
494 |
+
ref_c2w = torch.linalg.inv(ref_w2c)
|
495 |
+
ref_position = ref_c2w[:3, -1]
|
496 |
+
if up is None:
|
497 |
+
up = -ref_c2w[:3, 1]
|
498 |
+
|
499 |
+
direction_vectors = {
|
500 |
+
"forward": (lookat - ref_position).clone(),
|
501 |
+
"backward": -(lookat - ref_position).clone(),
|
502 |
+
"up": up.clone(),
|
503 |
+
"down": -up.clone(),
|
504 |
+
"right": torch.cross((lookat - ref_position), up, dim=0),
|
505 |
+
"left": -torch.cross((lookat - ref_position), up, dim=0),
|
506 |
+
}
|
507 |
+
if direction not in direction_vectors:
|
508 |
+
raise ValueError(
|
509 |
+
f"Invalid direction: {direction}. Must be one of {list(direction_vectors.keys())}"
|
510 |
+
)
|
511 |
+
|
512 |
+
positions = ref_position + (
|
513 |
+
F.normalize(direction_vectors[direction], dim=0)
|
514 |
+
* (
|
515 |
+
torch.linspace(0, 0.99, num_frames, device=ref_w2c.device)
|
516 |
+
if endpoint
|
517 |
+
else torch.linspace(0, 1, num_frames + 1, device=ref_w2c.device)[:-1]
|
518 |
+
)[:, None]
|
519 |
+
)
|
520 |
+
|
521 |
+
if tilt_xy is not None:
|
522 |
+
positions[:, :2] += tilt_xy
|
523 |
+
|
524 |
+
return get_lookat_w2cs(positions, lookat, up)
|
525 |
+
|
526 |
+
|
527 |
+
def get_roll_w2cs(
|
528 |
+
ref_w2c: torch.Tensor,
|
529 |
+
lookat: torch.Tensor,
|
530 |
+
up: torch.Tensor | None,
|
531 |
+
num_frames: int,
|
532 |
+
endpoint: bool = False,
|
533 |
+
degree: float = 360.0,
|
534 |
+
**_,
|
535 |
+
) -> torch.Tensor:
|
536 |
+
ref_c2w = torch.linalg.inv(ref_w2c)
|
537 |
+
ref_position = ref_c2w[:3, 3]
|
538 |
+
if up is None:
|
539 |
+
up = -ref_c2w[:3, 1] # Infer the up vector from the reference.
|
540 |
+
|
541 |
+
# Create vertical angles
|
542 |
+
thetas = (
|
543 |
+
torch.linspace(0.0, torch.pi * degree / 180, num_frames, device=ref_w2c.device)
|
544 |
+
if endpoint
|
545 |
+
else torch.linspace(
|
546 |
+
0.0, torch.pi * degree / 180, num_frames + 1, device=ref_w2c.device
|
547 |
+
)[:-1]
|
548 |
+
)[:, None]
|
549 |
+
|
550 |
+
lookat_vector = F.normalize(lookat[None].float(), dim=-1)
|
551 |
+
up = up[None]
|
552 |
+
up = (
|
553 |
+
up * torch.cos(thetas)
|
554 |
+
+ torch.cross(lookat_vector, up) * torch.sin(thetas)
|
555 |
+
+ lookat_vector
|
556 |
+
* torch.einsum("ij,ij->i", lookat_vector, up)[:, None]
|
557 |
+
* (1 - torch.cos(thetas))
|
558 |
+
)
|
559 |
+
|
560 |
+
# Normalize the camera orientation
|
561 |
+
return get_lookat_w2cs(ref_position[None].repeat(num_frames, 1), lookat, up)
|
562 |
+
|
563 |
+
|
564 |
+
def normalize(x):
|
565 |
+
"""Normalization helper function."""
|
566 |
+
return x / np.linalg.norm(x)
|
567 |
+
|
568 |
+
|
569 |
+
def viewmatrix(lookdir, up, position, subtract_position=False):
|
570 |
+
"""Construct lookat view matrix."""
|
571 |
+
vec2 = normalize((lookdir - position) if subtract_position else lookdir)
|
572 |
+
vec0 = normalize(np.cross(up, vec2))
|
573 |
+
vec1 = normalize(np.cross(vec2, vec0))
|
574 |
+
m = np.stack([vec0, vec1, vec2, position], axis=1)
|
575 |
+
return m
|
576 |
+
|
577 |
+
|
578 |
+
def poses_avg(poses):
|
579 |
+
"""New pose using average position, z-axis, and up vector of input poses."""
|
580 |
+
position = poses[:, :3, 3].mean(0)
|
581 |
+
z_axis = poses[:, :3, 2].mean(0)
|
582 |
+
up = poses[:, :3, 1].mean(0)
|
583 |
+
cam2world = viewmatrix(z_axis, up, position)
|
584 |
+
return cam2world
|
585 |
+
|
586 |
+
|
587 |
+
def generate_spiral_path(
|
588 |
+
poses, bounds, n_frames=120, n_rots=2, zrate=0.5, endpoint=False, radii=None
|
589 |
+
):
|
590 |
+
"""Calculates a forward facing spiral path for rendering."""
|
591 |
+
# Find a reasonable 'focus depth' for this dataset as a weighted average
|
592 |
+
# of near and far bounds in disparity space.
|
593 |
+
close_depth, inf_depth = bounds.min() * 0.9, bounds.max() * 5.0
|
594 |
+
dt = 0.75
|
595 |
+
focal = 1 / ((1 - dt) / close_depth + dt / inf_depth)
|
596 |
+
|
597 |
+
# Get radii for spiral path using 90th percentile of camera positions.
|
598 |
+
positions = poses[:, :3, 3]
|
599 |
+
if radii is None:
|
600 |
+
radii = np.percentile(np.abs(positions), 90, 0)
|
601 |
+
radii = np.concatenate([radii, [1.0]])
|
602 |
+
|
603 |
+
# Generate poses for spiral path.
|
604 |
+
render_poses = []
|
605 |
+
cam2world = poses_avg(poses)
|
606 |
+
up = poses[:, :3, 1].mean(0)
|
607 |
+
for theta in np.linspace(0.0, 2.0 * np.pi * n_rots, n_frames, endpoint=endpoint):
|
608 |
+
t = radii * [np.cos(theta), -np.sin(theta), -np.sin(theta * zrate), 1.0]
|
609 |
+
position = cam2world @ t
|
610 |
+
lookat = cam2world @ [0, 0, -focal, 1.0]
|
611 |
+
z_axis = position - lookat
|
612 |
+
render_poses.append(viewmatrix(z_axis, up, position))
|
613 |
+
render_poses = np.stack(render_poses, axis=0)
|
614 |
+
return render_poses
|
615 |
+
|
616 |
+
|
617 |
+
def generate_interpolated_path(
|
618 |
+
poses: np.ndarray,
|
619 |
+
n_interp: int,
|
620 |
+
spline_degree: int = 5,
|
621 |
+
smoothness: float = 0.03,
|
622 |
+
rot_weight: float = 0.1,
|
623 |
+
endpoint: bool = False,
|
624 |
+
):
|
625 |
+
"""Creates a smooth spline path between input keyframe camera poses.
|
626 |
+
|
627 |
+
Spline is calculated with poses in format (position, lookat-point, up-point).
|
628 |
+
|
629 |
+
Args:
|
630 |
+
poses: (n, 3, 4) array of input pose keyframes.
|
631 |
+
n_interp: returned path will have n_interp * (n - 1) total poses.
|
632 |
+
spline_degree: polynomial degree of B-spline.
|
633 |
+
smoothness: parameter for spline smoothing, 0 forces exact interpolation.
|
634 |
+
rot_weight: relative weighting of rotation/translation in spline solve.
|
635 |
+
|
636 |
+
Returns:
|
637 |
+
Array of new camera poses with shape (n_interp * (n - 1), 3, 4).
|
638 |
+
"""
|
639 |
+
|
640 |
+
def poses_to_points(poses, dist):
|
641 |
+
"""Converts from pose matrices to (position, lookat, up) format."""
|
642 |
+
pos = poses[:, :3, -1]
|
643 |
+
lookat = poses[:, :3, -1] - dist * poses[:, :3, 2]
|
644 |
+
up = poses[:, :3, -1] + dist * poses[:, :3, 1]
|
645 |
+
return np.stack([pos, lookat, up], 1)
|
646 |
+
|
647 |
+
def points_to_poses(points):
|
648 |
+
"""Converts from (position, lookat, up) format to pose matrices."""
|
649 |
+
return np.array([viewmatrix(p - l, u - p, p) for p, l, u in points])
|
650 |
+
|
651 |
+
def interp(points, n, k, s):
|
652 |
+
"""Runs multidimensional B-spline interpolation on the input points."""
|
653 |
+
sh = points.shape
|
654 |
+
pts = np.reshape(points, (sh[0], -1))
|
655 |
+
k = min(k, sh[0] - 1)
|
656 |
+
tck, _ = scipy.interpolate.splprep(pts.T, k=k, s=s)
|
657 |
+
u = np.linspace(0, 1, n, endpoint=endpoint)
|
658 |
+
new_points = np.array(scipy.interpolate.splev(u, tck))
|
659 |
+
new_points = np.reshape(new_points.T, (n, sh[1], sh[2]))
|
660 |
+
return new_points
|
661 |
+
|
662 |
+
points = poses_to_points(poses, dist=rot_weight)
|
663 |
+
new_points = interp(
|
664 |
+
points, n_interp * (points.shape[0] - 1), k=spline_degree, s=smoothness
|
665 |
+
)
|
666 |
+
return points_to_poses(new_points)
|
667 |
+
|
668 |
+
|
669 |
+
def similarity_from_cameras(c2w, strict_scaling=False, center_method="focus"):
|
670 |
+
"""
|
671 |
+
reference: nerf-factory
|
672 |
+
Get a similarity transform to normalize dataset
|
673 |
+
from c2w (OpenCV convention) cameras
|
674 |
+
:param c2w: (N, 4)
|
675 |
+
:return T (4,4) , scale (float)
|
676 |
+
"""
|
677 |
+
t = c2w[:, :3, 3]
|
678 |
+
R = c2w[:, :3, :3]
|
679 |
+
|
680 |
+
# (1) Rotate the world so that z+ is the up axis
|
681 |
+
# we estimate the up axis by averaging the camera up axes
|
682 |
+
ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1)
|
683 |
+
world_up = np.mean(ups, axis=0)
|
684 |
+
world_up /= np.linalg.norm(world_up)
|
685 |
+
|
686 |
+
up_camspace = np.array([0.0, -1.0, 0.0])
|
687 |
+
c = (up_camspace * world_up).sum()
|
688 |
+
cross = np.cross(world_up, up_camspace)
|
689 |
+
skew = np.array(
|
690 |
+
[
|
691 |
+
[0.0, -cross[2], cross[1]],
|
692 |
+
[cross[2], 0.0, -cross[0]],
|
693 |
+
[-cross[1], cross[0], 0.0],
|
694 |
+
]
|
695 |
+
)
|
696 |
+
if c > -1:
|
697 |
+
R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c)
|
698 |
+
else:
|
699 |
+
# In the unlikely case the original data has y+ up axis,
|
700 |
+
# rotate 180-deg about x axis
|
701 |
+
R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])
|
702 |
+
|
703 |
+
# R_align = np.eye(3) # DEBUG
|
704 |
+
R = R_align @ R
|
705 |
+
fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1)
|
706 |
+
t = (R_align @ t[..., None])[..., 0]
|
707 |
+
|
708 |
+
# (2) Recenter the scene.
|
709 |
+
if center_method == "focus":
|
710 |
+
# find the closest point to the origin for each camera's center ray
|
711 |
+
nearest = t + (fwds * -t).sum(-1)[:, None] * fwds
|
712 |
+
translate = -np.median(nearest, axis=0)
|
713 |
+
elif center_method == "poses":
|
714 |
+
# use center of the camera positions
|
715 |
+
translate = -np.median(t, axis=0)
|
716 |
+
else:
|
717 |
+
raise ValueError(f"Unknown center_method {center_method}")
|
718 |
+
|
719 |
+
transform = np.eye(4)
|
720 |
+
transform[:3, 3] = translate
|
721 |
+
transform[:3, :3] = R_align
|
722 |
+
|
723 |
+
# (3) Rescale the scene using camera distances
|
724 |
+
scale_fn = np.max if strict_scaling else np.median
|
725 |
+
inv_scale = scale_fn(np.linalg.norm(t + translate, axis=-1))
|
726 |
+
if inv_scale == 0:
|
727 |
+
inv_scale = 1.0
|
728 |
+
scale = 1.0 / inv_scale
|
729 |
+
transform[:3, :] *= scale
|
730 |
+
|
731 |
+
return transform
|
732 |
+
|
733 |
+
|
734 |
+
def align_principle_axes(point_cloud):
|
735 |
+
# Compute centroid
|
736 |
+
centroid = np.median(point_cloud, axis=0)
|
737 |
+
|
738 |
+
# Translate point cloud to centroid
|
739 |
+
translated_point_cloud = point_cloud - centroid
|
740 |
+
|
741 |
+
# Compute covariance matrix
|
742 |
+
covariance_matrix = np.cov(translated_point_cloud, rowvar=False)
|
743 |
+
|
744 |
+
# Compute eigenvectors and eigenvalues
|
745 |
+
eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix)
|
746 |
+
|
747 |
+
# Sort eigenvectors by eigenvalues (descending order) so that the z-axis
|
748 |
+
# is the principal axis with the smallest eigenvalue.
|
749 |
+
sort_indices = eigenvalues.argsort()[::-1]
|
750 |
+
eigenvectors = eigenvectors[:, sort_indices]
|
751 |
+
|
752 |
+
# Check orientation of eigenvectors. If the determinant of the eigenvectors is
|
753 |
+
# negative, then we need to flip the sign of one of the eigenvectors.
|
754 |
+
if np.linalg.det(eigenvectors) < 0:
|
755 |
+
eigenvectors[:, 0] *= -1
|
756 |
+
|
757 |
+
# Create rotation matrix
|
758 |
+
rotation_matrix = eigenvectors.T
|
759 |
+
|
760 |
+
# Create SE(3) matrix (4x4 transformation matrix)
|
761 |
+
transform = np.eye(4)
|
762 |
+
transform[:3, :3] = rotation_matrix
|
763 |
+
transform[:3, 3] = -rotation_matrix @ centroid
|
764 |
+
|
765 |
+
return transform
|
766 |
+
|
767 |
+
|
768 |
+
def transform_points(matrix, points):
|
769 |
+
"""Transform points using a SE(4) matrix.
|
770 |
+
|
771 |
+
Args:
|
772 |
+
matrix: 4x4 SE(4) matrix
|
773 |
+
points: Nx3 array of points
|
774 |
+
|
775 |
+
Returns:
|
776 |
+
Nx3 array of transformed points
|
777 |
+
"""
|
778 |
+
assert matrix.shape == (4, 4)
|
779 |
+
assert len(points.shape) == 2 and points.shape[1] == 3
|
780 |
+
return points @ matrix[:3, :3].T + matrix[:3, 3]
|
781 |
+
|
782 |
+
|
783 |
+
def transform_cameras(matrix, camtoworlds):
|
784 |
+
"""Transform cameras using a SE(4) matrix.
|
785 |
+
|
786 |
+
Args:
|
787 |
+
matrix: 4x4 SE(4) matrix
|
788 |
+
camtoworlds: Nx4x4 array of camera-to-world matrices
|
789 |
+
|
790 |
+
Returns:
|
791 |
+
Nx4x4 array of transformed camera-to-world matrices
|
792 |
+
"""
|
793 |
+
assert matrix.shape == (4, 4)
|
794 |
+
assert len(camtoworlds.shape) == 3 and camtoworlds.shape[1:] == (4, 4)
|
795 |
+
camtoworlds = np.einsum("nij, ki -> nkj", camtoworlds, matrix)
|
796 |
+
scaling = np.linalg.norm(camtoworlds[:, 0, :3], axis=1)
|
797 |
+
camtoworlds[:, :3, :3] = camtoworlds[:, :3, :3] / scaling[:, None, None]
|
798 |
+
return camtoworlds
|
799 |
+
|
800 |
+
|
801 |
+
def normalize_scene(camtoworlds, points=None, camera_center_method="focus"):
|
802 |
+
T1 = similarity_from_cameras(camtoworlds, center_method=camera_center_method)
|
803 |
+
camtoworlds = transform_cameras(T1, camtoworlds)
|
804 |
+
if points is not None:
|
805 |
+
points = transform_points(T1, points)
|
806 |
+
T2 = align_principle_axes(points)
|
807 |
+
camtoworlds = transform_cameras(T2, camtoworlds)
|
808 |
+
points = transform_points(T2, points)
|
809 |
+
return camtoworlds, points, T2 @ T1
|
810 |
+
else:
|
811 |
+
return camtoworlds, T1
|
seva/gui.py
ADDED
@@ -0,0 +1,975 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import colorsys
|
2 |
+
import dataclasses
|
3 |
+
import threading
|
4 |
+
import time
|
5 |
+
from pathlib import Path
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
import scipy
|
9 |
+
import splines
|
10 |
+
import splines.quaternion
|
11 |
+
import torch
|
12 |
+
import viser
|
13 |
+
import viser.transforms as vt
|
14 |
+
|
15 |
+
from seva.geometry import get_preset_pose_fov
|
16 |
+
|
17 |
+
|
18 |
+
@dataclasses.dataclass
|
19 |
+
class Keyframe(object):
|
20 |
+
position: np.ndarray
|
21 |
+
wxyz: np.ndarray
|
22 |
+
override_fov_enabled: bool
|
23 |
+
override_fov_rad: float
|
24 |
+
aspect: float
|
25 |
+
override_transition_enabled: bool
|
26 |
+
override_transition_sec: float | None
|
27 |
+
|
28 |
+
@staticmethod
|
29 |
+
def from_camera(camera: viser.CameraHandle, aspect: float) -> "Keyframe":
|
30 |
+
return Keyframe(
|
31 |
+
camera.position,
|
32 |
+
camera.wxyz,
|
33 |
+
override_fov_enabled=False,
|
34 |
+
override_fov_rad=camera.fov,
|
35 |
+
aspect=aspect,
|
36 |
+
override_transition_enabled=False,
|
37 |
+
override_transition_sec=None,
|
38 |
+
)
|
39 |
+
|
40 |
+
@staticmethod
|
41 |
+
def from_se3(se3: vt.SE3, fov: float, aspect: float) -> "Keyframe":
|
42 |
+
return Keyframe(
|
43 |
+
se3.translation(),
|
44 |
+
se3.rotation().wxyz,
|
45 |
+
override_fov_enabled=False,
|
46 |
+
override_fov_rad=fov,
|
47 |
+
aspect=aspect,
|
48 |
+
override_transition_enabled=False,
|
49 |
+
override_transition_sec=None,
|
50 |
+
)
|
51 |
+
|
52 |
+
|
53 |
+
class CameraTrajectory(object):
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
server: viser.ViserServer,
|
57 |
+
duration_element: viser.GuiInputHandle[float],
|
58 |
+
scene_scale: float,
|
59 |
+
scene_node_prefix: str = "/",
|
60 |
+
):
|
61 |
+
self._server = server
|
62 |
+
self._keyframes: dict[int, tuple[Keyframe, viser.CameraFrustumHandle]] = {}
|
63 |
+
self._keyframe_counter: int = 0
|
64 |
+
self._spline_nodes: list[viser.SceneNodeHandle] = []
|
65 |
+
self._camera_edit_panel: viser.Gui3dContainerHandle | None = None
|
66 |
+
|
67 |
+
self._orientation_spline: splines.quaternion.KochanekBartels | None = None
|
68 |
+
self._position_spline: splines.KochanekBartels | None = None
|
69 |
+
self._fov_spline: splines.KochanekBartels | None = None
|
70 |
+
|
71 |
+
self._keyframes_visible: bool = True
|
72 |
+
|
73 |
+
self._duration_element = duration_element
|
74 |
+
self._scene_node_prefix = scene_node_prefix
|
75 |
+
|
76 |
+
self.scene_scale = scene_scale
|
77 |
+
# These parameters should be overridden externally.
|
78 |
+
self.loop: bool = False
|
79 |
+
self.framerate: float = 30.0
|
80 |
+
self.tension: float = 0.0 # Tension / alpha term.
|
81 |
+
self.default_fov: float = 0.0
|
82 |
+
self.default_transition_sec: float = 0.0
|
83 |
+
self.show_spline: bool = True
|
84 |
+
|
85 |
+
def set_keyframes_visible(self, visible: bool) -> None:
|
86 |
+
self._keyframes_visible = visible
|
87 |
+
for keyframe in self._keyframes.values():
|
88 |
+
keyframe[1].visible = visible
|
89 |
+
|
90 |
+
def add_camera(self, keyframe: Keyframe, keyframe_index: int | None = None) -> None:
|
91 |
+
"""Add a new camera, or replace an old one if `keyframe_index` is passed in."""
|
92 |
+
server = self._server
|
93 |
+
|
94 |
+
# Add a keyframe if we aren't replacing an existing one.
|
95 |
+
if keyframe_index is None:
|
96 |
+
keyframe_index = self._keyframe_counter
|
97 |
+
self._keyframe_counter += 1
|
98 |
+
|
99 |
+
print(
|
100 |
+
f"{keyframe.wxyz=} {keyframe.position=} {keyframe_index=} {keyframe.aspect=}"
|
101 |
+
)
|
102 |
+
frustum_handle = server.scene.add_camera_frustum(
|
103 |
+
str(Path(self._scene_node_prefix) / f"cameras/{keyframe_index}"),
|
104 |
+
fov=(
|
105 |
+
keyframe.override_fov_rad
|
106 |
+
if keyframe.override_fov_enabled
|
107 |
+
else self.default_fov
|
108 |
+
),
|
109 |
+
aspect=keyframe.aspect,
|
110 |
+
scale=0.1 * self.scene_scale,
|
111 |
+
color=(200, 10, 30),
|
112 |
+
wxyz=keyframe.wxyz,
|
113 |
+
position=keyframe.position,
|
114 |
+
visible=self._keyframes_visible,
|
115 |
+
)
|
116 |
+
self._server.scene.add_icosphere(
|
117 |
+
str(Path(self._scene_node_prefix) / f"cameras/{keyframe_index}/sphere"),
|
118 |
+
radius=0.03,
|
119 |
+
color=(200, 10, 30),
|
120 |
+
)
|
121 |
+
|
122 |
+
@frustum_handle.on_click
|
123 |
+
def _(_) -> None:
|
124 |
+
if self._camera_edit_panel is not None:
|
125 |
+
self._camera_edit_panel.remove()
|
126 |
+
self._camera_edit_panel = None
|
127 |
+
|
128 |
+
with server.scene.add_3d_gui_container(
|
129 |
+
"/camera_edit_panel",
|
130 |
+
position=keyframe.position,
|
131 |
+
) as camera_edit_panel:
|
132 |
+
self._camera_edit_panel = camera_edit_panel
|
133 |
+
override_fov = server.gui.add_checkbox(
|
134 |
+
"Override FOV", initial_value=keyframe.override_fov_enabled
|
135 |
+
)
|
136 |
+
override_fov_degrees = server.gui.add_slider(
|
137 |
+
"Override FOV (degrees)",
|
138 |
+
5.0,
|
139 |
+
175.0,
|
140 |
+
step=0.1,
|
141 |
+
initial_value=keyframe.override_fov_rad * 180.0 / np.pi,
|
142 |
+
disabled=not keyframe.override_fov_enabled,
|
143 |
+
)
|
144 |
+
delete_button = server.gui.add_button(
|
145 |
+
"Delete", color="red", icon=viser.Icon.TRASH
|
146 |
+
)
|
147 |
+
go_to_button = server.gui.add_button("Go to")
|
148 |
+
close_button = server.gui.add_button("Close")
|
149 |
+
|
150 |
+
@override_fov.on_update
|
151 |
+
def _(_) -> None:
|
152 |
+
keyframe.override_fov_enabled = override_fov.value
|
153 |
+
override_fov_degrees.disabled = not override_fov.value
|
154 |
+
self.add_camera(keyframe, keyframe_index)
|
155 |
+
|
156 |
+
@override_fov_degrees.on_update
|
157 |
+
def _(_) -> None:
|
158 |
+
keyframe.override_fov_rad = override_fov_degrees.value / 180.0 * np.pi
|
159 |
+
self.add_camera(keyframe, keyframe_index)
|
160 |
+
|
161 |
+
@delete_button.on_click
|
162 |
+
def _(event: viser.GuiEvent) -> None:
|
163 |
+
assert event.client is not None
|
164 |
+
with event.client.gui.add_modal("Confirm") as modal:
|
165 |
+
event.client.gui.add_markdown("Delete keyframe?")
|
166 |
+
confirm_button = event.client.gui.add_button(
|
167 |
+
"Yes", color="red", icon=viser.Icon.TRASH
|
168 |
+
)
|
169 |
+
exit_button = event.client.gui.add_button("Cancel")
|
170 |
+
|
171 |
+
@confirm_button.on_click
|
172 |
+
def _(_) -> None:
|
173 |
+
assert camera_edit_panel is not None
|
174 |
+
|
175 |
+
keyframe_id = None
|
176 |
+
for i, keyframe_tuple in self._keyframes.items():
|
177 |
+
if keyframe_tuple[1] is frustum_handle:
|
178 |
+
keyframe_id = i
|
179 |
+
break
|
180 |
+
assert keyframe_id is not None
|
181 |
+
|
182 |
+
self._keyframes.pop(keyframe_id)
|
183 |
+
frustum_handle.remove()
|
184 |
+
camera_edit_panel.remove()
|
185 |
+
self._camera_edit_panel = None
|
186 |
+
modal.close()
|
187 |
+
self.update_spline()
|
188 |
+
|
189 |
+
@exit_button.on_click
|
190 |
+
def _(_) -> None:
|
191 |
+
modal.close()
|
192 |
+
|
193 |
+
@go_to_button.on_click
|
194 |
+
def _(event: viser.GuiEvent) -> None:
|
195 |
+
assert event.client is not None
|
196 |
+
client = event.client
|
197 |
+
T_world_current = vt.SE3.from_rotation_and_translation(
|
198 |
+
vt.SO3(client.camera.wxyz), client.camera.position
|
199 |
+
)
|
200 |
+
T_world_target = vt.SE3.from_rotation_and_translation(
|
201 |
+
vt.SO3(keyframe.wxyz), keyframe.position
|
202 |
+
) @ vt.SE3.from_translation(np.array([0.0, 0.0, -0.5]))
|
203 |
+
|
204 |
+
T_current_target = T_world_current.inverse() @ T_world_target
|
205 |
+
|
206 |
+
for j in range(10):
|
207 |
+
T_world_set = T_world_current @ vt.SE3.exp(
|
208 |
+
T_current_target.log() * j / 9.0
|
209 |
+
)
|
210 |
+
|
211 |
+
# Important bit: we atomically set both the orientation and
|
212 |
+
# the position of the camera.
|
213 |
+
with client.atomic():
|
214 |
+
client.camera.wxyz = T_world_set.rotation().wxyz
|
215 |
+
client.camera.position = T_world_set.translation()
|
216 |
+
time.sleep(1.0 / 30.0)
|
217 |
+
|
218 |
+
@close_button.on_click
|
219 |
+
def _(_) -> None:
|
220 |
+
assert camera_edit_panel is not None
|
221 |
+
camera_edit_panel.remove()
|
222 |
+
self._camera_edit_panel = None
|
223 |
+
|
224 |
+
self._keyframes[keyframe_index] = (keyframe, frustum_handle)
|
225 |
+
|
226 |
+
def update_aspect(self, aspect: float) -> None:
|
227 |
+
for keyframe_index, frame in self._keyframes.items():
|
228 |
+
frame = dataclasses.replace(frame[0], aspect=aspect)
|
229 |
+
self.add_camera(frame, keyframe_index=keyframe_index)
|
230 |
+
|
231 |
+
def get_aspect(self) -> float:
|
232 |
+
"""Get W/H aspect ratio, which is shared across all keyframes."""
|
233 |
+
assert len(self._keyframes) > 0
|
234 |
+
return next(iter(self._keyframes.values()))[0].aspect
|
235 |
+
|
236 |
+
def reset(self) -> None:
|
237 |
+
for frame in self._keyframes.values():
|
238 |
+
print(f"removing {frame[1]}")
|
239 |
+
frame[1].remove()
|
240 |
+
self._keyframes.clear()
|
241 |
+
self.update_spline()
|
242 |
+
print("camera traj reset")
|
243 |
+
|
244 |
+
def spline_t_from_t_sec(self, time: np.ndarray) -> np.ndarray:
|
245 |
+
"""From a time value in seconds, compute a t value for our geometric
|
246 |
+
spline interpolation. An increment of 1 for the latter will move the
|
247 |
+
camera forward by one keyframe.
|
248 |
+
|
249 |
+
We use a PCHIP spline here to guarantee monotonicity.
|
250 |
+
"""
|
251 |
+
transition_times_cumsum = self.compute_transition_times_cumsum()
|
252 |
+
spline_indices = np.arange(transition_times_cumsum.shape[0])
|
253 |
+
|
254 |
+
if self.loop:
|
255 |
+
# In the case of a loop, we pad the spline to match the start/end
|
256 |
+
# slopes.
|
257 |
+
interpolator = scipy.interpolate.PchipInterpolator(
|
258 |
+
x=np.concatenate(
|
259 |
+
[
|
260 |
+
[-(transition_times_cumsum[-1] - transition_times_cumsum[-2])],
|
261 |
+
transition_times_cumsum,
|
262 |
+
transition_times_cumsum[-1:] + transition_times_cumsum[1:2],
|
263 |
+
],
|
264 |
+
axis=0,
|
265 |
+
),
|
266 |
+
y=np.concatenate(
|
267 |
+
[[-1], spline_indices, [spline_indices[-1] + 1]], # type: ignore
|
268 |
+
axis=0,
|
269 |
+
),
|
270 |
+
)
|
271 |
+
else:
|
272 |
+
interpolator = scipy.interpolate.PchipInterpolator(
|
273 |
+
x=transition_times_cumsum, y=spline_indices
|
274 |
+
)
|
275 |
+
|
276 |
+
# Clip to account for floating point error.
|
277 |
+
return np.clip(interpolator(time), 0, spline_indices[-1])
|
278 |
+
|
279 |
+
def interpolate_pose_and_fov_rad(
|
280 |
+
self, normalized_t: float
|
281 |
+
) -> tuple[vt.SE3, float] | None:
|
282 |
+
if len(self._keyframes) < 2:
|
283 |
+
return None
|
284 |
+
|
285 |
+
self._fov_spline = splines.KochanekBartels(
|
286 |
+
[
|
287 |
+
(
|
288 |
+
keyframe[0].override_fov_rad
|
289 |
+
if keyframe[0].override_fov_enabled
|
290 |
+
else self.default_fov
|
291 |
+
)
|
292 |
+
for keyframe in self._keyframes.values()
|
293 |
+
],
|
294 |
+
tcb=(self.tension, 0.0, 0.0),
|
295 |
+
endconditions="closed" if self.loop else "natural",
|
296 |
+
)
|
297 |
+
|
298 |
+
assert self._orientation_spline is not None
|
299 |
+
assert self._position_spline is not None
|
300 |
+
assert self._fov_spline is not None
|
301 |
+
|
302 |
+
max_t = self.compute_duration()
|
303 |
+
t = max_t * normalized_t
|
304 |
+
spline_t = float(self.spline_t_from_t_sec(np.array(t)))
|
305 |
+
|
306 |
+
quat = self._orientation_spline.evaluate(spline_t)
|
307 |
+
assert isinstance(quat, splines.quaternion.UnitQuaternion)
|
308 |
+
return (
|
309 |
+
vt.SE3.from_rotation_and_translation(
|
310 |
+
vt.SO3(np.array([quat.scalar, *quat.vector])),
|
311 |
+
self._position_spline.evaluate(spline_t),
|
312 |
+
),
|
313 |
+
float(self._fov_spline.evaluate(spline_t)),
|
314 |
+
)
|
315 |
+
|
316 |
+
def update_spline(self) -> None:
|
317 |
+
num_frames = int(self.compute_duration() * self.framerate)
|
318 |
+
keyframes = list(self._keyframes.values())
|
319 |
+
|
320 |
+
if num_frames <= 0 or not self.show_spline or len(keyframes) < 2:
|
321 |
+
for node in self._spline_nodes:
|
322 |
+
node.remove()
|
323 |
+
self._spline_nodes.clear()
|
324 |
+
return
|
325 |
+
|
326 |
+
transition_times_cumsum = self.compute_transition_times_cumsum()
|
327 |
+
|
328 |
+
self._orientation_spline = splines.quaternion.KochanekBartels(
|
329 |
+
[
|
330 |
+
splines.quaternion.UnitQuaternion.from_unit_xyzw(
|
331 |
+
np.roll(keyframe[0].wxyz, shift=-1)
|
332 |
+
)
|
333 |
+
for keyframe in keyframes
|
334 |
+
],
|
335 |
+
tcb=(self.tension, 0.0, 0.0),
|
336 |
+
endconditions="closed" if self.loop else "natural",
|
337 |
+
)
|
338 |
+
self._position_spline = splines.KochanekBartels(
|
339 |
+
[keyframe[0].position for keyframe in keyframes],
|
340 |
+
tcb=(self.tension, 0.0, 0.0),
|
341 |
+
endconditions="closed" if self.loop else "natural",
|
342 |
+
)
|
343 |
+
|
344 |
+
# Update visualized spline.
|
345 |
+
points_array = self._position_spline.evaluate(
|
346 |
+
self.spline_t_from_t_sec(
|
347 |
+
np.linspace(0, transition_times_cumsum[-1], num_frames)
|
348 |
+
)
|
349 |
+
)
|
350 |
+
colors_array = np.array(
|
351 |
+
[
|
352 |
+
colorsys.hls_to_rgb(h, 0.5, 1.0)
|
353 |
+
for h in np.linspace(0.0, 1.0, len(points_array))
|
354 |
+
]
|
355 |
+
)
|
356 |
+
|
357 |
+
# Clear prior spline nodes.
|
358 |
+
for node in self._spline_nodes:
|
359 |
+
node.remove()
|
360 |
+
self._spline_nodes.clear()
|
361 |
+
|
362 |
+
self._spline_nodes.append(
|
363 |
+
self._server.scene.add_spline_catmull_rom(
|
364 |
+
str(Path(self._scene_node_prefix) / "camera_spline"),
|
365 |
+
positions=points_array,
|
366 |
+
color=(220, 220, 220),
|
367 |
+
closed=self.loop,
|
368 |
+
line_width=1.0,
|
369 |
+
segments=points_array.shape[0] + 1,
|
370 |
+
)
|
371 |
+
)
|
372 |
+
self._spline_nodes.append(
|
373 |
+
self._server.scene.add_point_cloud(
|
374 |
+
str(Path(self._scene_node_prefix) / "camera_spline/points"),
|
375 |
+
points=points_array,
|
376 |
+
colors=colors_array,
|
377 |
+
point_size=0.04,
|
378 |
+
)
|
379 |
+
)
|
380 |
+
|
381 |
+
def make_transition_handle(i: int) -> None:
|
382 |
+
assert self._position_spline is not None
|
383 |
+
transition_pos = self._position_spline.evaluate(
|
384 |
+
float(
|
385 |
+
self.spline_t_from_t_sec(
|
386 |
+
(transition_times_cumsum[i] + transition_times_cumsum[i + 1])
|
387 |
+
/ 2.0,
|
388 |
+
)
|
389 |
+
)
|
390 |
+
)
|
391 |
+
transition_sphere = self._server.scene.add_icosphere(
|
392 |
+
str(Path(self._scene_node_prefix) / f"camera_spline/transition_{i}"),
|
393 |
+
radius=0.04,
|
394 |
+
color=(255, 0, 0),
|
395 |
+
position=transition_pos,
|
396 |
+
)
|
397 |
+
self._spline_nodes.append(transition_sphere)
|
398 |
+
|
399 |
+
@transition_sphere.on_click
|
400 |
+
def _(_) -> None:
|
401 |
+
server = self._server
|
402 |
+
|
403 |
+
if self._camera_edit_panel is not None:
|
404 |
+
self._camera_edit_panel.remove()
|
405 |
+
self._camera_edit_panel = None
|
406 |
+
|
407 |
+
keyframe_index = (i + 1) % len(self._keyframes)
|
408 |
+
keyframe = keyframes[keyframe_index][0]
|
409 |
+
|
410 |
+
with server.scene.add_3d_gui_container(
|
411 |
+
"/camera_edit_panel",
|
412 |
+
position=transition_pos,
|
413 |
+
) as camera_edit_panel:
|
414 |
+
self._camera_edit_panel = camera_edit_panel
|
415 |
+
override_transition_enabled = server.gui.add_checkbox(
|
416 |
+
"Override transition",
|
417 |
+
initial_value=keyframe.override_transition_enabled,
|
418 |
+
)
|
419 |
+
override_transition_sec = server.gui.add_number(
|
420 |
+
"Override transition (sec)",
|
421 |
+
initial_value=(
|
422 |
+
keyframe.override_transition_sec
|
423 |
+
if keyframe.override_transition_sec is not None
|
424 |
+
else self.default_transition_sec
|
425 |
+
),
|
426 |
+
min=0.001,
|
427 |
+
max=30.0,
|
428 |
+
step=0.001,
|
429 |
+
disabled=not override_transition_enabled.value,
|
430 |
+
)
|
431 |
+
close_button = server.gui.add_button("Close")
|
432 |
+
|
433 |
+
@override_transition_enabled.on_update
|
434 |
+
def _(_) -> None:
|
435 |
+
keyframe.override_transition_enabled = (
|
436 |
+
override_transition_enabled.value
|
437 |
+
)
|
438 |
+
override_transition_sec.disabled = (
|
439 |
+
not override_transition_enabled.value
|
440 |
+
)
|
441 |
+
self._duration_element.value = self.compute_duration()
|
442 |
+
|
443 |
+
@override_transition_sec.on_update
|
444 |
+
def _(_) -> None:
|
445 |
+
keyframe.override_transition_sec = override_transition_sec.value
|
446 |
+
self._duration_element.value = self.compute_duration()
|
447 |
+
|
448 |
+
@close_button.on_click
|
449 |
+
def _(_) -> None:
|
450 |
+
assert camera_edit_panel is not None
|
451 |
+
camera_edit_panel.remove()
|
452 |
+
self._camera_edit_panel = None
|
453 |
+
|
454 |
+
(num_transitions_plus_1,) = transition_times_cumsum.shape
|
455 |
+
for i in range(num_transitions_plus_1 - 1):
|
456 |
+
make_transition_handle(i)
|
457 |
+
|
458 |
+
def compute_duration(self) -> float:
|
459 |
+
"""Compute the total duration of the trajectory."""
|
460 |
+
total = 0.0
|
461 |
+
for i, (keyframe, frustum) in enumerate(self._keyframes.values()):
|
462 |
+
if i == 0 and not self.loop:
|
463 |
+
continue
|
464 |
+
del frustum
|
465 |
+
total += (
|
466 |
+
keyframe.override_transition_sec
|
467 |
+
if keyframe.override_transition_enabled
|
468 |
+
and keyframe.override_transition_sec is not None
|
469 |
+
else self.default_transition_sec
|
470 |
+
)
|
471 |
+
return total
|
472 |
+
|
473 |
+
def compute_transition_times_cumsum(self) -> np.ndarray:
|
474 |
+
"""Compute the total duration of the trajectory."""
|
475 |
+
total = 0.0
|
476 |
+
out = [0.0]
|
477 |
+
for i, (keyframe, frustum) in enumerate(self._keyframes.values()):
|
478 |
+
if i == 0:
|
479 |
+
continue
|
480 |
+
del frustum
|
481 |
+
total += (
|
482 |
+
keyframe.override_transition_sec
|
483 |
+
if keyframe.override_transition_enabled
|
484 |
+
and keyframe.override_transition_sec is not None
|
485 |
+
else self.default_transition_sec
|
486 |
+
)
|
487 |
+
out.append(total)
|
488 |
+
|
489 |
+
if self.loop:
|
490 |
+
keyframe = next(iter(self._keyframes.values()))[0]
|
491 |
+
total += (
|
492 |
+
keyframe.override_transition_sec
|
493 |
+
if keyframe.override_transition_enabled
|
494 |
+
and keyframe.override_transition_sec is not None
|
495 |
+
else self.default_transition_sec
|
496 |
+
)
|
497 |
+
out.append(total)
|
498 |
+
|
499 |
+
return np.array(out)
|
500 |
+
|
501 |
+
|
502 |
+
@dataclasses.dataclass
|
503 |
+
class GuiState:
|
504 |
+
preview_render: bool
|
505 |
+
preview_fov: float
|
506 |
+
preview_aspect: float
|
507 |
+
camera_traj_list: list | None
|
508 |
+
active_input_index: int
|
509 |
+
|
510 |
+
|
511 |
+
def define_gui(
|
512 |
+
server: viser.ViserServer,
|
513 |
+
init_fov: float = 75.0,
|
514 |
+
img_wh: tuple[int, int] = (576, 576),
|
515 |
+
**kwargs,
|
516 |
+
) -> GuiState:
|
517 |
+
gui_state = GuiState(
|
518 |
+
preview_render=False,
|
519 |
+
preview_fov=0.0,
|
520 |
+
preview_aspect=1.0,
|
521 |
+
camera_traj_list=None,
|
522 |
+
active_input_index=0,
|
523 |
+
)
|
524 |
+
|
525 |
+
with server.gui.add_folder(
|
526 |
+
"Preset camera trajectories", order=99, expand_by_default=False
|
527 |
+
):
|
528 |
+
preset_traj_dropdown = server.gui.add_dropdown(
|
529 |
+
"Options",
|
530 |
+
[
|
531 |
+
"orbit",
|
532 |
+
"spiral",
|
533 |
+
"lemniscate",
|
534 |
+
"zoom-out",
|
535 |
+
"dolly zoom-out",
|
536 |
+
],
|
537 |
+
initial_value="orbit",
|
538 |
+
hint="Select a preset camera trajectory.",
|
539 |
+
)
|
540 |
+
preset_duration_num = server.gui.add_number(
|
541 |
+
"Duration (sec)",
|
542 |
+
min=1.0,
|
543 |
+
max=60.0,
|
544 |
+
step=0.5,
|
545 |
+
initial_value=2.0,
|
546 |
+
)
|
547 |
+
preset_submit_button = server.gui.add_button(
|
548 |
+
"Submit",
|
549 |
+
icon=viser.Icon.PICK,
|
550 |
+
hint="Add a new keyframe at the current pose.",
|
551 |
+
)
|
552 |
+
|
553 |
+
@preset_submit_button.on_click
|
554 |
+
def _(event: viser.GuiEvent) -> None:
|
555 |
+
camera_traj.reset()
|
556 |
+
gui_state.camera_traj_list = None
|
557 |
+
|
558 |
+
duration = preset_duration_num.value
|
559 |
+
fps = framerate_number.value
|
560 |
+
num_frames = int(duration * fps)
|
561 |
+
transition_sec = duration / num_frames
|
562 |
+
transition_sec_number.value = transition_sec
|
563 |
+
assert event.client_id is not None
|
564 |
+
transition_sec_number.disabled = True
|
565 |
+
loop_checkbox.disabled = True
|
566 |
+
add_keyframe_button.disabled = True
|
567 |
+
|
568 |
+
camera = server.get_clients()[event.client_id].camera
|
569 |
+
start_w2c = torch.linalg.inv(
|
570 |
+
torch.as_tensor(
|
571 |
+
vt.SE3.from_rotation_and_translation(
|
572 |
+
vt.SO3(camera.wxyz), camera.position
|
573 |
+
).as_matrix(),
|
574 |
+
dtype=torch.float32,
|
575 |
+
)
|
576 |
+
)
|
577 |
+
look_at = torch.as_tensor(camera.look_at, dtype=torch.float32)
|
578 |
+
up_direction = torch.as_tensor(camera.up_direction, dtype=torch.float32)
|
579 |
+
poses, fovs = get_preset_pose_fov(
|
580 |
+
option=preset_traj_dropdown.value, # type: ignore
|
581 |
+
num_frames=num_frames,
|
582 |
+
start_w2c=start_w2c,
|
583 |
+
look_at=look_at,
|
584 |
+
up_direction=up_direction,
|
585 |
+
fov=camera.fov,
|
586 |
+
)
|
587 |
+
assert poses is not None and fovs is not None
|
588 |
+
for pose, fov in zip(poses, fovs):
|
589 |
+
camera_traj.add_camera(
|
590 |
+
Keyframe.from_se3(
|
591 |
+
vt.SE3.from_matrix(pose),
|
592 |
+
fov=fov,
|
593 |
+
aspect=img_wh[0] / img_wh[1],
|
594 |
+
)
|
595 |
+
)
|
596 |
+
|
597 |
+
duration_number.value = camera_traj.compute_duration()
|
598 |
+
camera_traj.update_spline()
|
599 |
+
|
600 |
+
with server.gui.add_folder("Advanced", expand_by_default=False, order=100):
|
601 |
+
transition_sec_number = server.gui.add_number(
|
602 |
+
"Transition (sec)",
|
603 |
+
min=0.001,
|
604 |
+
max=30.0,
|
605 |
+
step=0.001,
|
606 |
+
initial_value=1.5,
|
607 |
+
hint="Time in seconds between each keyframe, which can also be overridden on a per-transition basis.",
|
608 |
+
)
|
609 |
+
framerate_number = server.gui.add_number(
|
610 |
+
"FPS", min=0.1, max=240.0, step=1e-2, initial_value=30.0
|
611 |
+
)
|
612 |
+
framerate_buttons = server.gui.add_button_group("", ("24", "30", "60"))
|
613 |
+
duration_number = server.gui.add_number(
|
614 |
+
"Duration (sec)",
|
615 |
+
min=0.0,
|
616 |
+
max=1e8,
|
617 |
+
step=0.001,
|
618 |
+
initial_value=0.0,
|
619 |
+
disabled=True,
|
620 |
+
)
|
621 |
+
|
622 |
+
@framerate_buttons.on_click
|
623 |
+
def _(_) -> None:
|
624 |
+
framerate_number.value = float(framerate_buttons.value)
|
625 |
+
|
626 |
+
fov_degree_slider = server.gui.add_slider(
|
627 |
+
"FOV",
|
628 |
+
initial_value=init_fov,
|
629 |
+
min=0.1,
|
630 |
+
max=175.0,
|
631 |
+
step=0.01,
|
632 |
+
hint="Field-of-view for rendering, which can also be overridden on a per-keyframe basis.",
|
633 |
+
)
|
634 |
+
|
635 |
+
@fov_degree_slider.on_update
|
636 |
+
def _(_) -> None:
|
637 |
+
fov_radians = fov_degree_slider.value / 180.0 * np.pi
|
638 |
+
for client in server.get_clients().values():
|
639 |
+
client.camera.fov = fov_radians
|
640 |
+
camera_traj.default_fov = fov_radians
|
641 |
+
|
642 |
+
# Updating the aspect ratio will also re-render the camera frustums.
|
643 |
+
# Could rethink this.
|
644 |
+
camera_traj.update_aspect(img_wh[0] / img_wh[1])
|
645 |
+
compute_and_update_preview_camera_state()
|
646 |
+
|
647 |
+
scene_node_prefix = "/render_assets"
|
648 |
+
base_scene_node = server.scene.add_frame(scene_node_prefix, show_axes=False)
|
649 |
+
add_keyframe_button = server.gui.add_button(
|
650 |
+
"Add keyframe",
|
651 |
+
icon=viser.Icon.PLUS,
|
652 |
+
hint="Add a new keyframe at the current pose.",
|
653 |
+
)
|
654 |
+
|
655 |
+
@add_keyframe_button.on_click
|
656 |
+
def _(event: viser.GuiEvent) -> None:
|
657 |
+
assert event.client_id is not None
|
658 |
+
camera = server.get_clients()[event.client_id].camera
|
659 |
+
pose = vt.SE3.from_rotation_and_translation(
|
660 |
+
vt.SO3(camera.wxyz), camera.position
|
661 |
+
)
|
662 |
+
print(f"client {event.client_id} at {camera.position} {camera.wxyz}")
|
663 |
+
print(f"camera pose {pose.as_matrix()}")
|
664 |
+
|
665 |
+
# Add this camera to the trajectory.
|
666 |
+
camera_traj.add_camera(
|
667 |
+
Keyframe.from_camera(
|
668 |
+
camera,
|
669 |
+
aspect=img_wh[0] / img_wh[1],
|
670 |
+
),
|
671 |
+
)
|
672 |
+
duration_number.value = camera_traj.compute_duration()
|
673 |
+
camera_traj.update_spline()
|
674 |
+
|
675 |
+
clear_keyframes_button = server.gui.add_button(
|
676 |
+
"Clear keyframes",
|
677 |
+
icon=viser.Icon.TRASH,
|
678 |
+
hint="Remove all keyframes from the render trajectory.",
|
679 |
+
)
|
680 |
+
|
681 |
+
@clear_keyframes_button.on_click
|
682 |
+
def _(event: viser.GuiEvent) -> None:
|
683 |
+
assert event.client_id is not None
|
684 |
+
client = server.get_clients()[event.client_id]
|
685 |
+
with client.atomic(), client.gui.add_modal("Confirm") as modal:
|
686 |
+
client.gui.add_markdown("Clear all keyframes?")
|
687 |
+
confirm_button = client.gui.add_button(
|
688 |
+
"Yes", color="red", icon=viser.Icon.TRASH
|
689 |
+
)
|
690 |
+
exit_button = client.gui.add_button("Cancel")
|
691 |
+
|
692 |
+
@confirm_button.on_click
|
693 |
+
def _(_) -> None:
|
694 |
+
camera_traj.reset()
|
695 |
+
modal.close()
|
696 |
+
|
697 |
+
duration_number.value = camera_traj.compute_duration()
|
698 |
+
add_keyframe_button.disabled = False
|
699 |
+
transition_sec_number.disabled = False
|
700 |
+
transition_sec_number.value = 1.5
|
701 |
+
loop_checkbox.disabled = False
|
702 |
+
|
703 |
+
nonlocal gui_state
|
704 |
+
gui_state.camera_traj_list = None
|
705 |
+
|
706 |
+
@exit_button.on_click
|
707 |
+
def _(_) -> None:
|
708 |
+
modal.close()
|
709 |
+
|
710 |
+
play_button = server.gui.add_button("Play", icon=viser.Icon.PLAYER_PLAY)
|
711 |
+
pause_button = server.gui.add_button(
|
712 |
+
"Pause", icon=viser.Icon.PLAYER_PAUSE, visible=False
|
713 |
+
)
|
714 |
+
|
715 |
+
# Poll the play button to see if we should be playing endlessly.
|
716 |
+
def play() -> None:
|
717 |
+
while True:
|
718 |
+
while not play_button.visible:
|
719 |
+
max_frame = int(framerate_number.value * duration_number.value)
|
720 |
+
if max_frame > 0:
|
721 |
+
assert preview_frame_slider is not None
|
722 |
+
preview_frame_slider.value = (
|
723 |
+
preview_frame_slider.value + 1
|
724 |
+
) % max_frame
|
725 |
+
time.sleep(1.0 / framerate_number.value)
|
726 |
+
time.sleep(0.1)
|
727 |
+
|
728 |
+
threading.Thread(target=play).start()
|
729 |
+
|
730 |
+
# Play the camera trajectory when the play button is pressed.
|
731 |
+
@play_button.on_click
|
732 |
+
def _(_) -> None:
|
733 |
+
play_button.visible = False
|
734 |
+
pause_button.visible = True
|
735 |
+
|
736 |
+
# Play the camera trajectory when the play button is pressed.
|
737 |
+
@pause_button.on_click
|
738 |
+
def _(_) -> None:
|
739 |
+
play_button.visible = True
|
740 |
+
pause_button.visible = False
|
741 |
+
|
742 |
+
preview_render_button = server.gui.add_button(
|
743 |
+
"Preview render",
|
744 |
+
hint="Show a preview of the render in the viewport.",
|
745 |
+
icon=viser.Icon.CAMERA_CHECK,
|
746 |
+
)
|
747 |
+
preview_render_stop_button = server.gui.add_button(
|
748 |
+
"Exit render preview",
|
749 |
+
color="red",
|
750 |
+
icon=viser.Icon.CAMERA_CANCEL,
|
751 |
+
visible=False,
|
752 |
+
)
|
753 |
+
|
754 |
+
@preview_render_button.on_click
|
755 |
+
def _(_) -> None:
|
756 |
+
gui_state.preview_render = True
|
757 |
+
preview_render_button.visible = False
|
758 |
+
preview_render_stop_button.visible = True
|
759 |
+
play_button.visible = False
|
760 |
+
pause_button.visible = True
|
761 |
+
preset_submit_button.disabled = True
|
762 |
+
|
763 |
+
maybe_pose_and_fov_rad = compute_and_update_preview_camera_state()
|
764 |
+
if maybe_pose_and_fov_rad is None:
|
765 |
+
remove_preview_camera()
|
766 |
+
return
|
767 |
+
pose, fov = maybe_pose_and_fov_rad
|
768 |
+
del fov
|
769 |
+
|
770 |
+
# Hide all render assets when we're previewing the render.
|
771 |
+
nonlocal base_scene_node
|
772 |
+
base_scene_node.visible = False
|
773 |
+
|
774 |
+
# Back up and then set camera poses.
|
775 |
+
for client in server.get_clients().values():
|
776 |
+
camera_pose_backup_from_id[client.client_id] = (
|
777 |
+
client.camera.position,
|
778 |
+
client.camera.look_at,
|
779 |
+
client.camera.up_direction,
|
780 |
+
)
|
781 |
+
with client.atomic():
|
782 |
+
client.camera.wxyz = pose.rotation().wxyz
|
783 |
+
client.camera.position = pose.translation()
|
784 |
+
|
785 |
+
def stop_preview_render() -> None:
|
786 |
+
gui_state.preview_render = False
|
787 |
+
preview_render_button.visible = True
|
788 |
+
preview_render_stop_button.visible = False
|
789 |
+
play_button.visible = True
|
790 |
+
pause_button.visible = False
|
791 |
+
preset_submit_button.disabled = False
|
792 |
+
|
793 |
+
# Revert camera poses.
|
794 |
+
for client in server.get_clients().values():
|
795 |
+
if client.client_id not in camera_pose_backup_from_id:
|
796 |
+
continue
|
797 |
+
cam_position, cam_look_at, cam_up = camera_pose_backup_from_id.pop(
|
798 |
+
client.client_id
|
799 |
+
)
|
800 |
+
with client.atomic():
|
801 |
+
client.camera.position = cam_position
|
802 |
+
client.camera.look_at = cam_look_at
|
803 |
+
client.camera.up_direction = cam_up
|
804 |
+
client.flush()
|
805 |
+
|
806 |
+
# Un-hide render assets.
|
807 |
+
nonlocal base_scene_node
|
808 |
+
base_scene_node.visible = True
|
809 |
+
remove_preview_camera()
|
810 |
+
|
811 |
+
@preview_render_stop_button.on_click
|
812 |
+
def _(_) -> None:
|
813 |
+
stop_preview_render()
|
814 |
+
|
815 |
+
def get_max_frame_index() -> int:
|
816 |
+
return max(1, int(framerate_number.value * duration_number.value) - 1)
|
817 |
+
|
818 |
+
def add_preview_frame_slider() -> viser.GuiInputHandle[int] | None:
|
819 |
+
"""Helper for creating the current frame # slider. This is removed and
|
820 |
+
re-added anytime the `max` value changes."""
|
821 |
+
|
822 |
+
preview_frame_slider = server.gui.add_slider(
|
823 |
+
"Preview frame",
|
824 |
+
min=0,
|
825 |
+
max=get_max_frame_index(),
|
826 |
+
step=1,
|
827 |
+
initial_value=0,
|
828 |
+
order=set_traj_button.order + 0.01,
|
829 |
+
disabled=get_max_frame_index() == 1,
|
830 |
+
)
|
831 |
+
play_button.disabled = preview_frame_slider.disabled
|
832 |
+
preview_render_button.disabled = preview_frame_slider.disabled
|
833 |
+
set_traj_button.disabled = preview_frame_slider.disabled
|
834 |
+
|
835 |
+
@preview_frame_slider.on_update
|
836 |
+
def _(_) -> None:
|
837 |
+
nonlocal preview_camera_handle
|
838 |
+
maybe_pose_and_fov_rad = compute_and_update_preview_camera_state()
|
839 |
+
if maybe_pose_and_fov_rad is None:
|
840 |
+
return
|
841 |
+
pose, fov_rad = maybe_pose_and_fov_rad
|
842 |
+
|
843 |
+
preview_camera_handle = server.scene.add_camera_frustum(
|
844 |
+
str(Path(scene_node_prefix) / "preview_camera"),
|
845 |
+
fov=fov_rad,
|
846 |
+
aspect=img_wh[0] / img_wh[1],
|
847 |
+
scale=0.35,
|
848 |
+
wxyz=pose.rotation().wxyz,
|
849 |
+
position=pose.translation(),
|
850 |
+
color=(10, 200, 30),
|
851 |
+
)
|
852 |
+
if gui_state.preview_render:
|
853 |
+
for client in server.get_clients().values():
|
854 |
+
with client.atomic():
|
855 |
+
client.camera.wxyz = pose.rotation().wxyz
|
856 |
+
client.camera.position = pose.translation()
|
857 |
+
|
858 |
+
return preview_frame_slider
|
859 |
+
|
860 |
+
set_traj_button = server.gui.add_button(
|
861 |
+
"Set camera trajectory",
|
862 |
+
color="green",
|
863 |
+
icon=viser.Icon.CHECK,
|
864 |
+
hint="Save the camera trajectory for rendering.",
|
865 |
+
)
|
866 |
+
|
867 |
+
@set_traj_button.on_click
|
868 |
+
def _(event: viser.GuiEvent) -> None:
|
869 |
+
assert event.client is not None
|
870 |
+
num_frames = int(framerate_number.value * duration_number.value)
|
871 |
+
|
872 |
+
def get_intrinsics(W, H, fov_rad):
|
873 |
+
focal = 0.5 * H / np.tan(0.5 * fov_rad)
|
874 |
+
return np.array(
|
875 |
+
[[focal, 0.0, 0.5 * W], [0.0, focal, 0.5 * H], [0.0, 0.0, 1.0]]
|
876 |
+
)
|
877 |
+
|
878 |
+
camera_traj_list = []
|
879 |
+
for i in range(num_frames):
|
880 |
+
maybe_pose_and_fov_rad = camera_traj.interpolate_pose_and_fov_rad(
|
881 |
+
i / num_frames
|
882 |
+
)
|
883 |
+
if maybe_pose_and_fov_rad is None:
|
884 |
+
return
|
885 |
+
pose, fov_rad = maybe_pose_and_fov_rad
|
886 |
+
H = img_wh[1]
|
887 |
+
W = img_wh[0]
|
888 |
+
K = get_intrinsics(W, H, fov_rad)
|
889 |
+
w2c = pose.inverse().as_matrix()
|
890 |
+
camera_traj_list.append(
|
891 |
+
{
|
892 |
+
"w2c": w2c.flatten().tolist(),
|
893 |
+
"K": K.flatten().tolist(),
|
894 |
+
"img_wh": (W, H),
|
895 |
+
}
|
896 |
+
)
|
897 |
+
nonlocal gui_state
|
898 |
+
gui_state.camera_traj_list = camera_traj_list
|
899 |
+
print(f"Get camera_traj_list: {gui_state.camera_traj_list}")
|
900 |
+
|
901 |
+
stop_preview_render()
|
902 |
+
|
903 |
+
preview_frame_slider = add_preview_frame_slider()
|
904 |
+
|
905 |
+
loop_checkbox = server.gui.add_checkbox(
|
906 |
+
"Loop", False, hint="Add a segment between the first and last keyframes."
|
907 |
+
)
|
908 |
+
|
909 |
+
@loop_checkbox.on_update
|
910 |
+
def _(_) -> None:
|
911 |
+
camera_traj.loop = loop_checkbox.value
|
912 |
+
duration_number.value = camera_traj.compute_duration()
|
913 |
+
|
914 |
+
@transition_sec_number.on_update
|
915 |
+
def _(_) -> None:
|
916 |
+
camera_traj.default_transition_sec = transition_sec_number.value
|
917 |
+
duration_number.value = camera_traj.compute_duration()
|
918 |
+
|
919 |
+
preview_camera_handle: viser.SceneNodeHandle | None = None
|
920 |
+
|
921 |
+
def remove_preview_camera() -> None:
|
922 |
+
nonlocal preview_camera_handle
|
923 |
+
if preview_camera_handle is not None:
|
924 |
+
preview_camera_handle.remove()
|
925 |
+
preview_camera_handle = None
|
926 |
+
|
927 |
+
def compute_and_update_preview_camera_state() -> tuple[vt.SE3, float] | None:
|
928 |
+
"""Update the render tab state with the current preview camera pose.
|
929 |
+
Returns current camera pose + FOV if available."""
|
930 |
+
|
931 |
+
if preview_frame_slider is None:
|
932 |
+
return None
|
933 |
+
maybe_pose_and_fov_rad = camera_traj.interpolate_pose_and_fov_rad(
|
934 |
+
preview_frame_slider.value / get_max_frame_index()
|
935 |
+
)
|
936 |
+
if maybe_pose_and_fov_rad is None:
|
937 |
+
remove_preview_camera()
|
938 |
+
return None
|
939 |
+
pose, fov_rad = maybe_pose_and_fov_rad
|
940 |
+
gui_state.preview_fov = fov_rad
|
941 |
+
gui_state.preview_aspect = camera_traj.get_aspect()
|
942 |
+
return pose, fov_rad
|
943 |
+
|
944 |
+
# We back up the camera poses before and after we start previewing renders.
|
945 |
+
camera_pose_backup_from_id: dict[int, tuple] = {}
|
946 |
+
|
947 |
+
# Update the # of frames.
|
948 |
+
@duration_number.on_update
|
949 |
+
@framerate_number.on_update
|
950 |
+
def _(_) -> None:
|
951 |
+
remove_preview_camera() # Will be re-added when slider is updated.
|
952 |
+
|
953 |
+
nonlocal preview_frame_slider
|
954 |
+
old = preview_frame_slider
|
955 |
+
assert old is not None
|
956 |
+
|
957 |
+
preview_frame_slider = add_preview_frame_slider()
|
958 |
+
if preview_frame_slider is not None:
|
959 |
+
old.remove()
|
960 |
+
else:
|
961 |
+
preview_frame_slider = old
|
962 |
+
|
963 |
+
camera_traj.framerate = framerate_number.value
|
964 |
+
camera_traj.update_spline()
|
965 |
+
|
966 |
+
camera_traj = CameraTrajectory(
|
967 |
+
server,
|
968 |
+
duration_number,
|
969 |
+
scene_node_prefix=scene_node_prefix,
|
970 |
+
**kwargs,
|
971 |
+
)
|
972 |
+
camera_traj.default_fov = fov_degree_slider.value / 180.0 * np.pi
|
973 |
+
camera_traj.default_transition_sec = transition_sec_number.value
|
974 |
+
|
975 |
+
return gui_state
|
seva/model.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass, field
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from seva.modules.layers import (
|
7 |
+
Downsample,
|
8 |
+
GroupNorm32,
|
9 |
+
ResBlock,
|
10 |
+
TimestepEmbedSequential,
|
11 |
+
Upsample,
|
12 |
+
timestep_embedding,
|
13 |
+
)
|
14 |
+
from seva.modules.transformer import MultiviewTransformer
|
15 |
+
|
16 |
+
|
17 |
+
@dataclass
|
18 |
+
class SevaParams(object):
|
19 |
+
in_channels: int = 11
|
20 |
+
model_channels: int = 320
|
21 |
+
out_channels: int = 4
|
22 |
+
num_frames: int = 21
|
23 |
+
num_res_blocks: int = 2
|
24 |
+
attention_resolutions: list[int] = field(default_factory=lambda: [4, 2, 1])
|
25 |
+
channel_mult: list[int] = field(default_factory=lambda: [1, 2, 4, 4])
|
26 |
+
num_head_channels: int = 64
|
27 |
+
transformer_depth: list[int] = field(default_factory=lambda: [1, 1, 1, 1])
|
28 |
+
context_dim: int = 1024
|
29 |
+
dense_in_channels: int = 6
|
30 |
+
dropout: float = 0.0
|
31 |
+
unflatten_names: list[str] = field(
|
32 |
+
default_factory=lambda: ["middle_ds8", "output_ds4", "output_ds2"]
|
33 |
+
)
|
34 |
+
|
35 |
+
def __post_init__(self):
|
36 |
+
assert len(self.channel_mult) == len(self.transformer_depth)
|
37 |
+
|
38 |
+
|
39 |
+
class Seva(nn.Module):
|
40 |
+
def __init__(self, params: SevaParams) -> None:
|
41 |
+
super().__init__()
|
42 |
+
self.params = params
|
43 |
+
self.model_channels = params.model_channels
|
44 |
+
self.out_channels = params.out_channels
|
45 |
+
self.num_head_channels = params.num_head_channels
|
46 |
+
|
47 |
+
time_embed_dim = params.model_channels * 4
|
48 |
+
self.time_embed = nn.Sequential(
|
49 |
+
nn.Linear(params.model_channels, time_embed_dim),
|
50 |
+
nn.SiLU(),
|
51 |
+
nn.Linear(time_embed_dim, time_embed_dim),
|
52 |
+
)
|
53 |
+
|
54 |
+
self.input_blocks = nn.ModuleList(
|
55 |
+
[
|
56 |
+
TimestepEmbedSequential(
|
57 |
+
nn.Conv2d(params.in_channels, params.model_channels, 3, padding=1)
|
58 |
+
)
|
59 |
+
]
|
60 |
+
)
|
61 |
+
self._feature_size = params.model_channels
|
62 |
+
input_block_chans = [params.model_channels]
|
63 |
+
ch = params.model_channels
|
64 |
+
ds = 1
|
65 |
+
for level, mult in enumerate(params.channel_mult):
|
66 |
+
for _ in range(params.num_res_blocks):
|
67 |
+
input_layers: list[ResBlock | MultiviewTransformer | Downsample] = [
|
68 |
+
ResBlock(
|
69 |
+
channels=ch,
|
70 |
+
emb_channels=time_embed_dim,
|
71 |
+
out_channels=mult * params.model_channels,
|
72 |
+
dense_in_channels=params.dense_in_channels,
|
73 |
+
dropout=params.dropout,
|
74 |
+
)
|
75 |
+
]
|
76 |
+
ch = mult * params.model_channels
|
77 |
+
if ds in params.attention_resolutions:
|
78 |
+
num_heads = ch // params.num_head_channels
|
79 |
+
dim_head = params.num_head_channels
|
80 |
+
input_layers.append(
|
81 |
+
MultiviewTransformer(
|
82 |
+
ch,
|
83 |
+
num_heads,
|
84 |
+
dim_head,
|
85 |
+
name=f"input_ds{ds}",
|
86 |
+
depth=params.transformer_depth[level],
|
87 |
+
context_dim=params.context_dim,
|
88 |
+
unflatten_names=params.unflatten_names,
|
89 |
+
)
|
90 |
+
)
|
91 |
+
self.input_blocks.append(TimestepEmbedSequential(*input_layers))
|
92 |
+
self._feature_size += ch
|
93 |
+
input_block_chans.append(ch)
|
94 |
+
if level != len(params.channel_mult) - 1:
|
95 |
+
ds *= 2
|
96 |
+
out_ch = ch
|
97 |
+
self.input_blocks.append(
|
98 |
+
TimestepEmbedSequential(Downsample(ch, out_channels=out_ch))
|
99 |
+
)
|
100 |
+
ch = out_ch
|
101 |
+
input_block_chans.append(ch)
|
102 |
+
self._feature_size += ch
|
103 |
+
|
104 |
+
num_heads = ch // params.num_head_channels
|
105 |
+
dim_head = params.num_head_channels
|
106 |
+
|
107 |
+
self.middle_block = TimestepEmbedSequential(
|
108 |
+
ResBlock(
|
109 |
+
channels=ch,
|
110 |
+
emb_channels=time_embed_dim,
|
111 |
+
out_channels=None,
|
112 |
+
dense_in_channels=params.dense_in_channels,
|
113 |
+
dropout=params.dropout,
|
114 |
+
),
|
115 |
+
MultiviewTransformer(
|
116 |
+
ch,
|
117 |
+
num_heads,
|
118 |
+
dim_head,
|
119 |
+
name=f"middle_ds{ds}",
|
120 |
+
depth=params.transformer_depth[-1],
|
121 |
+
context_dim=params.context_dim,
|
122 |
+
unflatten_names=params.unflatten_names,
|
123 |
+
),
|
124 |
+
ResBlock(
|
125 |
+
channels=ch,
|
126 |
+
emb_channels=time_embed_dim,
|
127 |
+
out_channels=None,
|
128 |
+
dense_in_channels=params.dense_in_channels,
|
129 |
+
dropout=params.dropout,
|
130 |
+
),
|
131 |
+
)
|
132 |
+
self._feature_size += ch
|
133 |
+
|
134 |
+
self.output_blocks = nn.ModuleList([])
|
135 |
+
for level, mult in list(enumerate(params.channel_mult))[::-1]:
|
136 |
+
for i in range(params.num_res_blocks + 1):
|
137 |
+
ich = input_block_chans.pop()
|
138 |
+
output_layers: list[ResBlock | MultiviewTransformer | Upsample] = [
|
139 |
+
ResBlock(
|
140 |
+
channels=ch + ich,
|
141 |
+
emb_channels=time_embed_dim,
|
142 |
+
out_channels=params.model_channels * mult,
|
143 |
+
dense_in_channels=params.dense_in_channels,
|
144 |
+
dropout=params.dropout,
|
145 |
+
)
|
146 |
+
]
|
147 |
+
ch = params.model_channels * mult
|
148 |
+
if ds in params.attention_resolutions:
|
149 |
+
num_heads = ch // params.num_head_channels
|
150 |
+
dim_head = params.num_head_channels
|
151 |
+
|
152 |
+
output_layers.append(
|
153 |
+
MultiviewTransformer(
|
154 |
+
ch,
|
155 |
+
num_heads,
|
156 |
+
dim_head,
|
157 |
+
name=f"output_ds{ds}",
|
158 |
+
depth=params.transformer_depth[level],
|
159 |
+
context_dim=params.context_dim,
|
160 |
+
unflatten_names=params.unflatten_names,
|
161 |
+
)
|
162 |
+
)
|
163 |
+
if level and i == params.num_res_blocks:
|
164 |
+
out_ch = ch
|
165 |
+
ds //= 2
|
166 |
+
output_layers.append(Upsample(ch, out_ch))
|
167 |
+
self.output_blocks.append(TimestepEmbedSequential(*output_layers))
|
168 |
+
self._feature_size += ch
|
169 |
+
|
170 |
+
self.out = nn.Sequential(
|
171 |
+
GroupNorm32(32, ch),
|
172 |
+
nn.SiLU(),
|
173 |
+
nn.Conv2d(self.model_channels, params.out_channels, 3, padding=1),
|
174 |
+
)
|
175 |
+
|
176 |
+
def forward(
|
177 |
+
self,
|
178 |
+
x: torch.Tensor,
|
179 |
+
t: torch.Tensor,
|
180 |
+
y: torch.Tensor,
|
181 |
+
dense_y: torch.Tensor,
|
182 |
+
num_frames: int | None = None,
|
183 |
+
) -> torch.Tensor:
|
184 |
+
num_frames = num_frames or self.params.num_frames
|
185 |
+
t_emb = timestep_embedding(t, self.model_channels)
|
186 |
+
t_emb = self.time_embed(t_emb)
|
187 |
+
|
188 |
+
hs = []
|
189 |
+
h = x
|
190 |
+
for module in self.input_blocks:
|
191 |
+
h = module(
|
192 |
+
h,
|
193 |
+
emb=t_emb,
|
194 |
+
context=y,
|
195 |
+
dense_emb=dense_y,
|
196 |
+
num_frames=num_frames,
|
197 |
+
)
|
198 |
+
hs.append(h)
|
199 |
+
h = self.middle_block(
|
200 |
+
h,
|
201 |
+
emb=t_emb,
|
202 |
+
context=y,
|
203 |
+
dense_emb=dense_y,
|
204 |
+
num_frames=num_frames,
|
205 |
+
)
|
206 |
+
for module in self.output_blocks:
|
207 |
+
h = torch.cat([h, hs.pop()], dim=1)
|
208 |
+
h = module(
|
209 |
+
h,
|
210 |
+
emb=t_emb,
|
211 |
+
context=y,
|
212 |
+
dense_emb=dense_y,
|
213 |
+
num_frames=num_frames,
|
214 |
+
)
|
215 |
+
h = h.type(x.dtype)
|
216 |
+
return self.out(h)
|
217 |
+
|
218 |
+
|
219 |
+
class SGMWrapper(nn.Module):
|
220 |
+
def __init__(self, module: Seva):
|
221 |
+
super().__init__()
|
222 |
+
self.module = module
|
223 |
+
|
224 |
+
def forward(
|
225 |
+
self, x: torch.Tensor, t: torch.Tensor, c: dict, **kwargs
|
226 |
+
) -> torch.Tensor:
|
227 |
+
x = torch.cat((x, c.get("concat", torch.Tensor([]).type_as(x))), dim=1)
|
228 |
+
return self.module(
|
229 |
+
x,
|
230 |
+
t=t,
|
231 |
+
y=c["crossattn"],
|
232 |
+
dense_y=c["dense_vector"],
|
233 |
+
**kwargs,
|
234 |
+
)
|
seva/modules/__init__.py
ADDED
File without changes
|
seva/modules/autoencoder.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from diffusers.models import AutoencoderKL # type: ignore
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
|
6 |
+
class AutoEncoder(nn.Module):
|
7 |
+
scale_factor: float = 0.18215
|
8 |
+
downsample: int = 8
|
9 |
+
|
10 |
+
def __init__(self, chunk_size: int | None = None):
|
11 |
+
super().__init__()
|
12 |
+
self.module = AutoencoderKL.from_pretrained(
|
13 |
+
"stabilityai/stable-diffusion-2-1-base",
|
14 |
+
subfolder="vae",
|
15 |
+
force_download=False,
|
16 |
+
low_cpu_mem_usage=False,
|
17 |
+
)
|
18 |
+
self.module.eval().requires_grad_(False) # type: ignore
|
19 |
+
self.chunk_size = chunk_size
|
20 |
+
|
21 |
+
def _encode(self, x: torch.Tensor) -> torch.Tensor:
|
22 |
+
return (
|
23 |
+
self.module.encode(x).latent_dist.mean # type: ignore
|
24 |
+
* self.scale_factor
|
25 |
+
)
|
26 |
+
|
27 |
+
def encode(self, x: torch.Tensor, chunk_size: int | None = None) -> torch.Tensor:
|
28 |
+
chunk_size = chunk_size or self.chunk_size
|
29 |
+
if chunk_size is not None:
|
30 |
+
return torch.cat(
|
31 |
+
[self._encode(x_chunk) for x_chunk in x.split(chunk_size)],
|
32 |
+
dim=0,
|
33 |
+
)
|
34 |
+
else:
|
35 |
+
return self._encode(x)
|
36 |
+
|
37 |
+
def _decode(self, z: torch.Tensor) -> torch.Tensor:
|
38 |
+
return self.module.decode(z / self.scale_factor).sample # type: ignore
|
39 |
+
|
40 |
+
def decode(self, z: torch.Tensor, chunk_size: int | None = None) -> torch.Tensor:
|
41 |
+
chunk_size = chunk_size or self.chunk_size
|
42 |
+
if chunk_size is not None:
|
43 |
+
return torch.cat(
|
44 |
+
[self._decode(z_chunk) for z_chunk in z.split(chunk_size)],
|
45 |
+
dim=0,
|
46 |
+
)
|
47 |
+
else:
|
48 |
+
return self._decode(z)
|
49 |
+
|
50 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
51 |
+
return self.decode(self.encode(x))
|
seva/modules/conditioner.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import kornia
|
2 |
+
import open_clip
|
3 |
+
import torch
|
4 |
+
from torch import nn
|
5 |
+
|
6 |
+
|
7 |
+
class CLIPConditioner(nn.Module):
|
8 |
+
mean: torch.Tensor
|
9 |
+
std: torch.Tensor
|
10 |
+
|
11 |
+
def __init__(self):
|
12 |
+
super().__init__()
|
13 |
+
self.module = open_clip.create_model_and_transforms(
|
14 |
+
"ViT-H-14", pretrained="laion2b_s32b_b79k"
|
15 |
+
)[0]
|
16 |
+
self.module.eval().requires_grad_(False) # type: ignore
|
17 |
+
self.register_buffer(
|
18 |
+
"mean", torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False
|
19 |
+
)
|
20 |
+
self.register_buffer(
|
21 |
+
"std", torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False
|
22 |
+
)
|
23 |
+
|
24 |
+
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
|
25 |
+
x = kornia.geometry.resize(
|
26 |
+
x,
|
27 |
+
(224, 224),
|
28 |
+
interpolation="bicubic",
|
29 |
+
align_corners=True,
|
30 |
+
antialias=True,
|
31 |
+
)
|
32 |
+
x = (x + 1.0) / 2.0
|
33 |
+
x = kornia.enhance.normalize(x, self.mean, self.std)
|
34 |
+
return x
|
35 |
+
|
36 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
37 |
+
x = self.preprocess(x)
|
38 |
+
x = self.module.encode_image(x)
|
39 |
+
return x
|
seva/modules/layers.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from einops import repeat
|
6 |
+
from torch import nn
|
7 |
+
|
8 |
+
from .transformer import MultiviewTransformer
|
9 |
+
|
10 |
+
|
11 |
+
def timestep_embedding(
|
12 |
+
timesteps: torch.Tensor,
|
13 |
+
dim: int,
|
14 |
+
max_period: int = 10000,
|
15 |
+
repeat_only: bool = False,
|
16 |
+
) -> torch.Tensor:
|
17 |
+
if not repeat_only:
|
18 |
+
half = dim // 2
|
19 |
+
freqs = torch.exp(
|
20 |
+
-math.log(max_period)
|
21 |
+
* torch.arange(start=0, end=half, dtype=torch.float32)
|
22 |
+
/ half
|
23 |
+
).to(device=timesteps.device)
|
24 |
+
args = timesteps[:, None].float() * freqs[None]
|
25 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
26 |
+
if dim % 2:
|
27 |
+
embedding = torch.cat(
|
28 |
+
[embedding, torch.zeros_like(embedding[:, :1])], dim=-1
|
29 |
+
)
|
30 |
+
else:
|
31 |
+
embedding = repeat(timesteps, "b -> b d", d=dim)
|
32 |
+
return embedding
|
33 |
+
|
34 |
+
|
35 |
+
class Upsample(nn.Module):
|
36 |
+
def __init__(self, channels: int, out_channels: int | None = None):
|
37 |
+
super().__init__()
|
38 |
+
self.channels = channels
|
39 |
+
self.out_channels = out_channels or channels
|
40 |
+
self.conv = nn.Conv2d(self.channels, self.out_channels, 3, 1, 1)
|
41 |
+
|
42 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
43 |
+
assert x.shape[1] == self.channels
|
44 |
+
x = F.interpolate(x, scale_factor=2, mode="nearest")
|
45 |
+
x = self.conv(x)
|
46 |
+
return x
|
47 |
+
|
48 |
+
|
49 |
+
class Downsample(nn.Module):
|
50 |
+
def __init__(self, channels: int, out_channels: int | None = None):
|
51 |
+
super().__init__()
|
52 |
+
self.channels = channels
|
53 |
+
self.out_channels = out_channels or channels
|
54 |
+
self.op = nn.Conv2d(self.channels, self.out_channels, 3, 2, 1)
|
55 |
+
|
56 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
57 |
+
assert x.shape[1] == self.channels
|
58 |
+
return self.op(x)
|
59 |
+
|
60 |
+
|
61 |
+
class GroupNorm32(nn.GroupNorm):
|
62 |
+
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
63 |
+
return super().forward(input.float()).type(input.dtype)
|
64 |
+
|
65 |
+
|
66 |
+
class TimestepEmbedSequential(nn.Sequential):
|
67 |
+
def forward( # type: ignore[override]
|
68 |
+
self,
|
69 |
+
x: torch.Tensor,
|
70 |
+
emb: torch.Tensor,
|
71 |
+
context: torch.Tensor,
|
72 |
+
dense_emb: torch.Tensor,
|
73 |
+
num_frames: int,
|
74 |
+
) -> torch.Tensor:
|
75 |
+
for layer in self:
|
76 |
+
if isinstance(layer, MultiviewTransformer):
|
77 |
+
assert num_frames is not None
|
78 |
+
x = layer(x, context, num_frames)
|
79 |
+
elif isinstance(layer, ResBlock):
|
80 |
+
x = layer(x, emb, dense_emb)
|
81 |
+
else:
|
82 |
+
x = layer(x)
|
83 |
+
return x
|
84 |
+
|
85 |
+
|
86 |
+
class ResBlock(nn.Module):
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
channels: int,
|
90 |
+
emb_channels: int,
|
91 |
+
out_channels: int | None,
|
92 |
+
dense_in_channels: int,
|
93 |
+
dropout: float,
|
94 |
+
):
|
95 |
+
super().__init__()
|
96 |
+
out_channels = out_channels or channels
|
97 |
+
|
98 |
+
self.in_layers = nn.Sequential(
|
99 |
+
GroupNorm32(32, channels),
|
100 |
+
nn.SiLU(),
|
101 |
+
nn.Conv2d(channels, out_channels, 3, 1, 1),
|
102 |
+
)
|
103 |
+
self.emb_layers = nn.Sequential(
|
104 |
+
nn.SiLU(), nn.Linear(emb_channels, out_channels)
|
105 |
+
)
|
106 |
+
self.dense_emb_layers = nn.Sequential(
|
107 |
+
nn.Conv2d(dense_in_channels, 2 * channels, 1, 1, 0)
|
108 |
+
)
|
109 |
+
self.out_layers = nn.Sequential(
|
110 |
+
GroupNorm32(32, out_channels),
|
111 |
+
nn.SiLU(),
|
112 |
+
nn.Dropout(dropout),
|
113 |
+
nn.Conv2d(out_channels, out_channels, 3, 1, 1),
|
114 |
+
)
|
115 |
+
if out_channels == channels:
|
116 |
+
self.skip_connection = nn.Identity()
|
117 |
+
else:
|
118 |
+
self.skip_connection = nn.Conv2d(channels, out_channels, 1, 1, 0)
|
119 |
+
|
120 |
+
def forward(
|
121 |
+
self, x: torch.Tensor, emb: torch.Tensor, dense_emb: torch.Tensor
|
122 |
+
) -> torch.Tensor:
|
123 |
+
in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
|
124 |
+
h = in_rest(x)
|
125 |
+
dense = self.dense_emb_layers(
|
126 |
+
F.interpolate(
|
127 |
+
dense_emb, size=h.shape[2:], mode="bilinear", align_corners=True
|
128 |
+
)
|
129 |
+
).type(h.dtype)
|
130 |
+
dense_scale, dense_shift = torch.chunk(dense, 2, dim=1)
|
131 |
+
h = h * (1 + dense_scale) + dense_shift
|
132 |
+
h = in_conv(h)
|
133 |
+
emb_out = self.emb_layers(emb).type(h.dtype)
|
134 |
+
while len(emb_out.shape) < len(h.shape):
|
135 |
+
emb_out = emb_out[..., None]
|
136 |
+
h = h + emb_out
|
137 |
+
h = self.out_layers(h)
|
138 |
+
h = self.skip_connection(x) + h
|
139 |
+
return h
|
seva/modules/preprocessor.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import contextlib
|
2 |
+
import os
|
3 |
+
import os.path as osp
|
4 |
+
import sys
|
5 |
+
from typing import cast
|
6 |
+
|
7 |
+
import imageio.v3 as iio
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
|
11 |
+
|
12 |
+
class Dust3rPipeline(object):
|
13 |
+
def __init__(self, device: str | torch.device = "cuda"):
|
14 |
+
submodule_path = osp.realpath(
|
15 |
+
osp.join(osp.dirname(__file__), "../../third_party/dust3r/")
|
16 |
+
)
|
17 |
+
if submodule_path not in sys.path:
|
18 |
+
sys.path.insert(0, submodule_path)
|
19 |
+
try:
|
20 |
+
with open(os.devnull, "w") as f, contextlib.redirect_stdout(f):
|
21 |
+
from dust3r.cloud_opt import ( # type: ignore[import]
|
22 |
+
GlobalAlignerMode,
|
23 |
+
global_aligner,
|
24 |
+
)
|
25 |
+
from dust3r.image_pairs import make_pairs # type: ignore[import]
|
26 |
+
from dust3r.inference import inference # type: ignore[import]
|
27 |
+
from dust3r.model import AsymmetricCroCo3DStereo # type: ignore[import]
|
28 |
+
from dust3r.utils.image import load_images # type: ignore[import]
|
29 |
+
except ImportError:
|
30 |
+
raise ImportError(
|
31 |
+
"Missing required submodule: 'dust3r'. Please ensure that all submodules are properly set up.\n\n"
|
32 |
+
"To initialize them, run the following command in the project root:\n"
|
33 |
+
" git submodule update --init --recursive"
|
34 |
+
)
|
35 |
+
|
36 |
+
self.device = torch.device(device)
|
37 |
+
self.model = AsymmetricCroCo3DStereo.from_pretrained(
|
38 |
+
"naver/DUSt3R_ViTLarge_BaseDecoder_512_dpt"
|
39 |
+
).to(self.device)
|
40 |
+
|
41 |
+
self._GlobalAlignerMode = GlobalAlignerMode
|
42 |
+
self._global_aligner = global_aligner
|
43 |
+
self._make_pairs = make_pairs
|
44 |
+
self._inference = inference
|
45 |
+
self._load_images = load_images
|
46 |
+
|
47 |
+
def infer_cameras_and_points(
|
48 |
+
self,
|
49 |
+
img_paths: list[str],
|
50 |
+
Ks: list[list] = None,
|
51 |
+
c2ws: list[list] = None,
|
52 |
+
batch_size: int = 16,
|
53 |
+
schedule: str = "cosine",
|
54 |
+
lr: float = 0.01,
|
55 |
+
niter: int = 500,
|
56 |
+
min_conf_thr: int = 3,
|
57 |
+
) -> tuple[
|
58 |
+
list[np.ndarray], np.ndarray, np.ndarray, list[np.ndarray], list[np.ndarray]
|
59 |
+
]:
|
60 |
+
num_img = len(img_paths)
|
61 |
+
if num_img == 1:
|
62 |
+
print("Only one image found, duplicating it to create a stereo pair.")
|
63 |
+
img_paths = img_paths * 2
|
64 |
+
|
65 |
+
images = self._load_images(img_paths, size=512)
|
66 |
+
pairs = self._make_pairs(
|
67 |
+
images,
|
68 |
+
scene_graph="complete",
|
69 |
+
prefilter=None,
|
70 |
+
symmetrize=True,
|
71 |
+
)
|
72 |
+
output = self._inference(pairs, self.model, self.device, batch_size=batch_size)
|
73 |
+
|
74 |
+
ori_imgs = [iio.imread(p) for p in img_paths]
|
75 |
+
ori_img_whs = np.array([img.shape[1::-1] for img in ori_imgs])
|
76 |
+
img_whs = np.concatenate([image["true_shape"][:, ::-1] for image in images], 0)
|
77 |
+
|
78 |
+
scene = self._global_aligner(
|
79 |
+
output,
|
80 |
+
device=self.device,
|
81 |
+
mode=self._GlobalAlignerMode.PointCloudOptimizer,
|
82 |
+
same_focals=True,
|
83 |
+
optimize_pp=False, # True,
|
84 |
+
min_conf_thr=min_conf_thr,
|
85 |
+
)
|
86 |
+
|
87 |
+
# if Ks is not None:
|
88 |
+
# scene.preset_focal(
|
89 |
+
# torch.tensor([[K[0, 0], K[1, 1]] for K in Ks])
|
90 |
+
# )
|
91 |
+
|
92 |
+
if c2ws is not None:
|
93 |
+
scene.preset_pose(c2ws)
|
94 |
+
|
95 |
+
_ = scene.compute_global_alignment(
|
96 |
+
init="msp", niter=niter, schedule=schedule, lr=lr
|
97 |
+
)
|
98 |
+
|
99 |
+
imgs = cast(list, scene.imgs)
|
100 |
+
Ks = scene.get_intrinsics().detach().cpu().numpy().copy()
|
101 |
+
c2ws = scene.get_im_poses().detach().cpu().numpy() # type: ignore
|
102 |
+
pts3d = [x.detach().cpu().numpy() for x in scene.get_pts3d()] # type: ignore
|
103 |
+
if num_img > 1:
|
104 |
+
masks = [x.detach().cpu().numpy() for x in scene.get_masks()]
|
105 |
+
points = [p[m] for p, m in zip(pts3d, masks)]
|
106 |
+
point_colors = [img[m] for img, m in zip(imgs, masks)]
|
107 |
+
else:
|
108 |
+
points = [p.reshape(-1, 3) for p in pts3d]
|
109 |
+
point_colors = [img.reshape(-1, 3) for img in imgs]
|
110 |
+
|
111 |
+
# Convert back to the original image size.
|
112 |
+
imgs = ori_imgs
|
113 |
+
Ks[:, :2, -1] *= ori_img_whs / img_whs
|
114 |
+
Ks[:, :2, :2] *= (ori_img_whs / img_whs).mean(axis=1, keepdims=True)[..., None]
|
115 |
+
|
116 |
+
return imgs, Ks, c2ws, points, point_colors
|