Upload 3 files
- cli_all_app.py +388 -0
- cli_app.py +380 -0
- cli_batch_app.py +408 -0
cli_all_app.py
ADDED
@@ -0,0 +1,388 @@
'''
python cli_all_app.py --input_img_path 战场原.webp --preset_traj "orbit" "spiral" "lemniscate" "zoom-in" "zoom-out" "dolly zoom-in" "dolly zoom-out" "move-forward" "move-backward" "move-up" "move-down" "move-left" "move-right" --output_dir 战场原
'''

import copy
import json
import os
import os.path as osp
import queue
import secrets
import threading
import time
from datetime import datetime
from glob import glob
from pathlib import Path
from typing import Literal, List

import imageio.v3 as iio
import numpy as np
import torch
import torch.nn.functional as F
import tyro
import viser
import viser.transforms as vt
from einops import rearrange

from seva.eval import (
    IS_TORCH_NIGHTLY,
    chunk_input_and_test,
    create_transforms_simple,
    infer_prior_stats,
    run_one_scene,
    transform_img_and_K,
)
from seva.geometry import (
    DEFAULT_FOV_RAD,
    get_default_intrinsics,
    get_preset_pose_fov,
    normalize_scene,
)
from seva.model import SGMWrapper
from seva.modules.autoencoder import AutoEncoder
from seva.modules.conditioner import CLIPConditioner
from seva.modules.preprocessor import Dust3rPipeline
from seva.sampling import DDPMDiscretization, DiscreteDenoiser
from seva.utils import load_model

device = "cuda:0"

# Constants.
WORK_DIR = "work_dirs/demo_gr"
MAX_SESSIONS = 1

if IS_TORCH_NIGHTLY:
    COMPILE = True
    os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
    os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
else:
    COMPILE = False

# Shared global variables across sessions.
DUST3R = Dust3rPipeline(device=device)  # type: ignore
MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
AE = AutoEncoder(chunk_size=1).to(device)
CONDITIONER = CLIPConditioner().to(device)
DISCRETIZATION = DDPMDiscretization()
DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
VERSION_DICT = {
    "H": 576,
    "W": 576,
    "T": 21,
    "C": 4,
    "f": 8,
    "options": {},
}
SERVERS = {}
ABORT_EVENTS = {}

if COMPILE:
    MODEL = torch.compile(MODEL)
    CONDITIONER = torch.compile(CONDITIONER)
    AE = torch.compile(AE)


class SevaRenderer(object):
    def __init__(self):
        self.gui_state = None

    def preprocess(self, input_img_path: str) -> dict:
        # Simply hardcode these such that the aspect ratio is always kept and
        # the shorter side is resized to 576. This only keeps the number of
        # GUI options small; changing it still works.
        shorter: int = 576
        # Has to be a multiple of 64 for the network.
        shorter = round(shorter / 64) * 64

        # Assume `Basic` demo mode: just hardcode the camera parameters and ignore points.
        input_imgs = torch.as_tensor(
            iio.imread(input_img_path) / 255.0, dtype=torch.float32
        )[None, ..., :3]
        input_imgs = transform_img_and_K(
            input_imgs.permute(0, 3, 1, 2),
            shorter,
            K=None,
            size_stride=64,
        )[0].permute(0, 2, 3, 1)
        input_Ks = get_default_intrinsics(
            aspect_ratio=input_imgs.shape[2] / input_imgs.shape[1]
        )
        input_c2ws = torch.eye(4)[None]
        # Simulate a small time interval such that gradio can update
        # progress properly.
        time.sleep(0.1)
        return {
            "input_imgs": input_imgs,
            "input_Ks": input_Ks,
            "input_c2ws": input_c2ws,
            "input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
            "points": [np.zeros((0, 3))],
            "point_colors": [np.zeros((0, 3))],
            "scene_scale": 1.0,
        }

    def render(
        self,
        preprocessed: dict,
        seed: int,
        chunk_strategy: str,
        cfg: float,
        preset_traj: Literal[
            "orbit",
            "spiral",
            "lemniscate",
            "zoom-in",
            "zoom-out",
            "dolly zoom-in",
            "dolly zoom-out",
            "move-forward",
            "move-backward",
            "move-up",
            "move-down",
            "move-left",
            "move-right",
        ],
        num_frames: int,
        zoom_factor: float | None,
        camera_scale: float,
        output_dir: str,
    ) -> str:
        # Generate a unique render name based on the input image filename and preset_traj.
        input_img_name = osp.splitext(osp.basename(preprocessed["input_img_path"]))[0]
        render_name = f"{input_img_name}_{preset_traj}"
        render_dir = osp.join(output_dir, render_name)

        input_imgs, input_Ks, input_c2ws, (W, H) = (
            preprocessed["input_imgs"],
            preprocessed["input_Ks"],
            preprocessed["input_c2ws"],
            preprocessed["input_wh"],
        )
        num_inputs = len(input_imgs)
        assert num_inputs == 1
        input_c2ws = torch.eye(4)[None].to(dtype=input_c2ws.dtype)
        target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_preset(
            preprocessed, preset_traj, num_frames, zoom_factor
        )
        all_c2ws = torch.cat([input_c2ws, target_c2ws], 0)
        all_Ks = (
            torch.cat([input_Ks, target_Ks], 0)
            * input_Ks.new_tensor([W, H, 1])[:, None]
        )
        num_targets = len(target_c2ws)
        input_indices = list(range(num_inputs))
        target_indices = np.arange(num_inputs, num_inputs + num_targets).tolist()
        # Get anchor cameras.
        T = VERSION_DICT["T"]
        version_dict = copy.deepcopy(VERSION_DICT)
        num_anchors = infer_prior_stats(
            T,
            num_inputs,
            num_total_frames=num_targets,
            version_dict=version_dict,
        )
        # infer_prior_stats modifies T in-place.
        T = version_dict["T"]
        assert isinstance(num_anchors, int)
        anchor_indices = np.linspace(
            num_inputs,
            num_inputs + num_targets - 1,
            num_anchors,
        ).tolist()
        anchor_c2ws = all_c2ws[[round(ind) for ind in anchor_indices]]
        anchor_Ks = all_Ks[[round(ind) for ind in anchor_indices]]
        # Create image conditioning.
        all_imgs_np = (
            F.pad(input_imgs, (0, 0, 0, 0, 0, 0, 0, num_targets), value=0.0).numpy()
            * 255.0
        ).astype(np.uint8)
        image_cond = {
            "img": all_imgs_np,
            "input_indices": input_indices,
            "prior_indices": anchor_indices,
        }
        # Create camera conditioning (K is unnormalized).
        camera_cond = {
            "c2w": all_c2ws,
            "K": all_Ks,
            "input_indices": list(range(num_inputs + num_targets)),
        }
        # Run rendering.
        num_steps = 50
        options_ori = VERSION_DICT["options"]
        options = copy.deepcopy(options_ori)
        options["chunk_strategy"] = chunk_strategy
        options["video_save_fps"] = 30.0
        options["beta_linear_start"] = 5e-6
        options["log_snr_shift"] = 2.4
        options["guider_types"] = [1, 2]
        options["cfg"] = [
            float(cfg),
            3.0 if num_inputs >= 9 else 2.0,
        ]  # We define the semi-dense-view regime to have 9 input views.
        options["camera_scale"] = camera_scale
        options["num_steps"] = num_steps
        options["cfg_min"] = 1.2
        options["encoding_t"] = 1
        options["decoding_t"] = 1
        task = "img2trajvid"
        # Get the number of first-pass chunks.
        T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
        chunk_strategy_first_pass = options.get(
            "chunk_strategy_first_pass", "gt-nearest"
        )
        num_chunks_0 = len(
            chunk_input_and_test(
                T_first_pass,
                input_c2ws,
                anchor_c2ws,
                input_indices,
                image_cond["prior_indices"],
                options={**options, "sampler_verbose": False},
                task=task,
                chunk_strategy=chunk_strategy_first_pass,
                gt_input_inds=list(range(input_c2ws.shape[0])),
            )[1]
        )
        # Get the number of second-pass chunks.
        anchor_argsort = np.argsort(input_indices + anchor_indices).tolist()
        anchor_indices = np.array(input_indices + anchor_indices)[
            anchor_argsort
        ].tolist()
        gt_input_inds = [anchor_argsort.index(i) for i in range(input_c2ws.shape[0])]
        anchor_c2ws_second_pass = torch.cat([input_c2ws, anchor_c2ws], dim=0)[
            anchor_argsort
        ]
        T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
        chunk_strategy = options.get("chunk_strategy", "nearest")
        num_chunks_1 = len(
            chunk_input_and_test(
                T_second_pass,
                anchor_c2ws_second_pass,
                target_c2ws,
                anchor_indices,
                target_indices,
                options={**options, "sampler_verbose": False},
                task=task,
                chunk_strategy=chunk_strategy,
                gt_input_inds=gt_input_inds,
            )[1]
        )
        video_path_generator = run_one_scene(
            task=task,
            version_dict={
                "H": H,
                "W": W,
                "T": T,
                "C": VERSION_DICT["C"],
                "f": VERSION_DICT["f"],
                "options": options,
            },
            model=MODEL,
            ae=AE,
            conditioner=CONDITIONER,
            denoiser=DENOISER,
            image_cond=image_cond,
            camera_cond=camera_cond,
            save_path=render_dir,
            use_traj_prior=True,
            traj_prior_c2ws=anchor_c2ws,
            traj_prior_Ks=anchor_Ks,
            seed=seed,
            gradio=True,
        )
        for video_path in video_path_generator:
            return video_path
        return ""

    def get_target_c2ws_and_Ks_from_preset(
        self,
        preprocessed: dict,
        preset_traj: Literal[
            "orbit",
            "spiral",
            "lemniscate",
            "zoom-in",
            "zoom-out",
            "dolly zoom-in",
            "dolly zoom-out",
            "move-forward",
            "move-backward",
            "move-up",
            "move-down",
            "move-left",
            "move-right",
        ],
        num_frames: int,
        zoom_factor: float | None,
    ):
        img_wh = preprocessed["input_wh"]
        start_c2w = preprocessed["input_c2ws"][0]
        start_w2c = torch.linalg.inv(start_c2w)
        look_at = torch.tensor([0, 0, 10])
        start_fov = DEFAULT_FOV_RAD
        target_c2ws, target_fovs = get_preset_pose_fov(
            preset_traj,
            num_frames,
            start_w2c,
            look_at,
            -start_c2w[:3, 1],
            start_fov,
            spiral_radii=[1.0, 1.0, 0.5],
            zoom_factor=zoom_factor,
        )
        target_c2ws = torch.as_tensor(target_c2ws)
        target_fovs = torch.as_tensor(target_fovs)
        target_Ks = get_default_intrinsics(
            target_fovs,  # type: ignore
            aspect_ratio=img_wh[0] / img_wh[1],
        )
        return target_c2ws, target_Ks


def main(
    input_img_path: str,
    preset_traj: List[Literal[
        "orbit",
        "spiral",
        "lemniscate",
        "zoom-in",
        "zoom-out",
        "dolly zoom-in",
        "dolly zoom-out",
        "move-forward",
        "move-backward",
        "move-up",
        "move-down",
        "move-left",
        "move-right",
    ]],
    num_frames: int = 80,
    zoom_factor: float | None = None,
    seed: int = 23,
    chunk_strategy: str = "interp",
    cfg: float = 4.0,
    camera_scale: float = 2.0,
    output_dir: str = WORK_DIR,
):
    renderer = SevaRenderer()
    preprocessed = renderer.preprocess(input_img_path)
    preprocessed["input_img_path"] = input_img_path  # Add input_img_path to the preprocessed dict.

    for traj in preset_traj:
        video_path = renderer.render(
            preprocessed,
            seed,
            chunk_strategy,
            cfg,
            traj,
            num_frames,
            zoom_factor,
            camera_scale,
            output_dir,
        )
        print(f"Rendered video saved to: {video_path}")


if __name__ == "__main__":
    tyro.cli(main)
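For reference, the entry point can also be driven directly from Python instead of through the tyro CLI. The sketch below is not part of the commit: the image path and preset subset are hypothetical, and importing the module eagerly builds the global Dust3r, SGM, autoencoder, and conditioner models, so a CUDA device and the seva checkpoints are assumed to be available.

# A minimal sketch, not part of the commit: calling cli_all_app.main() without tyro.
# Importing cli_all_app loads all global models onto cuda:0 at import time.
from cli_all_app import main

main(
    input_img_path="portrait.webp",    # hypothetical input image
    preset_traj=["orbit", "zoom-in"],  # any subset of the supported presets
    num_frames=80,
    seed=23,
    output_dir="work_dirs/demo_gr",
)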
cli_app.py
ADDED
@@ -0,0 +1,380 @@
'''
python cli_app.py --input_img_path 战场原.webp --preset_traj orbit --num_frames 80 --seed 23 --chunk_strategy interp --cfg 4.0 --camera_scale 2.0
'''

import copy
import json
import os
import os.path as osp
import queue
import secrets
import threading
import time
from datetime import datetime
from glob import glob
from pathlib import Path
from typing import Literal

import imageio.v3 as iio
import numpy as np
import torch
import torch.nn.functional as F
import tyro
import viser
import viser.transforms as vt
from einops import rearrange

from seva.eval import (
    IS_TORCH_NIGHTLY,
    chunk_input_and_test,
    create_transforms_simple,
    infer_prior_stats,
    run_one_scene,
    transform_img_and_K,
)
from seva.geometry import (
    DEFAULT_FOV_RAD,
    get_default_intrinsics,
    get_preset_pose_fov,
    normalize_scene,
)
from seva.model import SGMWrapper
from seva.modules.autoencoder import AutoEncoder
from seva.modules.conditioner import CLIPConditioner
from seva.modules.preprocessor import Dust3rPipeline
from seva.sampling import DDPMDiscretization, DiscreteDenoiser
from seva.utils import load_model

device = "cuda:0"

# Constants.
WORK_DIR = "work_dirs/demo_gr"
MAX_SESSIONS = 1

if IS_TORCH_NIGHTLY:
    COMPILE = True
    os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
    os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
else:
    COMPILE = False

# Shared global variables across sessions.
DUST3R = Dust3rPipeline(device=device)  # type: ignore
MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
AE = AutoEncoder(chunk_size=1).to(device)
CONDITIONER = CLIPConditioner().to(device)
DISCRETIZATION = DDPMDiscretization()
DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
VERSION_DICT = {
    "H": 576,
    "W": 576,
    "T": 21,
    "C": 4,
    "f": 8,
    "options": {},
}
SERVERS = {}
ABORT_EVENTS = {}

if COMPILE:
    MODEL = torch.compile(MODEL)
    CONDITIONER = torch.compile(CONDITIONER)
    AE = torch.compile(AE)


class SevaRenderer(object):
    def __init__(self):
        self.gui_state = None

    def preprocess(self, input_img_path: str) -> dict:
        # Simply hardcode these such that the aspect ratio is always kept and
        # the shorter side is resized to 576. This only keeps the number of
        # GUI options small; changing it still works.
        shorter: int = 576
        # Has to be a multiple of 64 for the network.
        shorter = round(shorter / 64) * 64

        # Assume `Basic` demo mode: just hardcode the camera parameters and ignore points.
        input_imgs = torch.as_tensor(
            iio.imread(input_img_path) / 255.0, dtype=torch.float32
        )[None, ..., :3]
        input_imgs = transform_img_and_K(
            input_imgs.permute(0, 3, 1, 2),
            shorter,
            K=None,
            size_stride=64,
        )[0].permute(0, 2, 3, 1)
        input_Ks = get_default_intrinsics(
            aspect_ratio=input_imgs.shape[2] / input_imgs.shape[1]
        )
        input_c2ws = torch.eye(4)[None]
        # Simulate a small time interval such that gradio can update
        # progress properly.
        time.sleep(0.1)
        return {
            "input_imgs": input_imgs,
            "input_Ks": input_Ks,
            "input_c2ws": input_c2ws,
            "input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
            "points": [np.zeros((0, 3))],
            "point_colors": [np.zeros((0, 3))],
            "scene_scale": 1.0,
        }

    def render(
        self,
        preprocessed: dict,
        seed: int,
        chunk_strategy: str,
        cfg: float,
        preset_traj: Literal[
            "orbit",
            "spiral",
            "lemniscate",
            "zoom-in",
            "zoom-out",
            "dolly zoom-in",
            "dolly zoom-out",
            "move-forward",
            "move-backward",
            "move-up",
            "move-down",
            "move-left",
            "move-right",
        ],
        num_frames: int,
        zoom_factor: float | None,
        camera_scale: float,
    ) -> str:
        render_name = datetime.now().strftime("%Y%m%d_%H%M%S")
        render_dir = osp.join(WORK_DIR, render_name)

        input_imgs, input_Ks, input_c2ws, (W, H) = (
            preprocessed["input_imgs"],
            preprocessed["input_Ks"],
            preprocessed["input_c2ws"],
            preprocessed["input_wh"],
        )
        num_inputs = len(input_imgs)
        assert num_inputs == 1
        input_c2ws = torch.eye(4)[None].to(dtype=input_c2ws.dtype)
        target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_preset(
            preprocessed, preset_traj, num_frames, zoom_factor
        )
        all_c2ws = torch.cat([input_c2ws, target_c2ws], 0)
        all_Ks = (
            torch.cat([input_Ks, target_Ks], 0)
            * input_Ks.new_tensor([W, H, 1])[:, None]
        )
        num_targets = len(target_c2ws)
        input_indices = list(range(num_inputs))
        target_indices = np.arange(num_inputs, num_inputs + num_targets).tolist()
        # Get anchor cameras.
        T = VERSION_DICT["T"]
        version_dict = copy.deepcopy(VERSION_DICT)
        num_anchors = infer_prior_stats(
            T,
            num_inputs,
            num_total_frames=num_targets,
            version_dict=version_dict,
        )
        # infer_prior_stats modifies T in-place.
        T = version_dict["T"]
        assert isinstance(num_anchors, int)
        anchor_indices = np.linspace(
            num_inputs,
            num_inputs + num_targets - 1,
            num_anchors,
        ).tolist()
        anchor_c2ws = all_c2ws[[round(ind) for ind in anchor_indices]]
        anchor_Ks = all_Ks[[round(ind) for ind in anchor_indices]]
        # Create image conditioning.
        all_imgs_np = (
            F.pad(input_imgs, (0, 0, 0, 0, 0, 0, 0, num_targets), value=0.0).numpy()
            * 255.0
        ).astype(np.uint8)
        image_cond = {
            "img": all_imgs_np,
            "input_indices": input_indices,
            "prior_indices": anchor_indices,
        }
        # Create camera conditioning (K is unnormalized).
        camera_cond = {
            "c2w": all_c2ws,
            "K": all_Ks,
            "input_indices": list(range(num_inputs + num_targets)),
        }
        # Run rendering.
        num_steps = 50
        options_ori = VERSION_DICT["options"]
        options = copy.deepcopy(options_ori)
        options["chunk_strategy"] = chunk_strategy
        options["video_save_fps"] = 30.0
        options["beta_linear_start"] = 5e-6
        options["log_snr_shift"] = 2.4
        options["guider_types"] = [1, 2]
        options["cfg"] = [
            float(cfg),
            3.0 if num_inputs >= 9 else 2.0,
        ]  # We define the semi-dense-view regime to have 9 input views.
        options["camera_scale"] = camera_scale
        options["num_steps"] = num_steps
        options["cfg_min"] = 1.2
        options["encoding_t"] = 1
        options["decoding_t"] = 1
        task = "img2trajvid"
        # Get the number of first-pass chunks.
        T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
        chunk_strategy_first_pass = options.get(
            "chunk_strategy_first_pass", "gt-nearest"
        )
        num_chunks_0 = len(
            chunk_input_and_test(
                T_first_pass,
                input_c2ws,
                anchor_c2ws,
                input_indices,
                image_cond["prior_indices"],
                options={**options, "sampler_verbose": False},
                task=task,
                chunk_strategy=chunk_strategy_first_pass,
                gt_input_inds=list(range(input_c2ws.shape[0])),
            )[1]
        )
        # Get the number of second-pass chunks.
        anchor_argsort = np.argsort(input_indices + anchor_indices).tolist()
        anchor_indices = np.array(input_indices + anchor_indices)[
            anchor_argsort
        ].tolist()
        gt_input_inds = [anchor_argsort.index(i) for i in range(input_c2ws.shape[0])]
        anchor_c2ws_second_pass = torch.cat([input_c2ws, anchor_c2ws], dim=0)[
            anchor_argsort
        ]
        T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
        chunk_strategy = options.get("chunk_strategy", "nearest")
        num_chunks_1 = len(
            chunk_input_and_test(
                T_second_pass,
                anchor_c2ws_second_pass,
                target_c2ws,
                anchor_indices,
                target_indices,
                options={**options, "sampler_verbose": False},
                task=task,
                chunk_strategy=chunk_strategy,
                gt_input_inds=gt_input_inds,
            )[1]
        )
        video_path_generator = run_one_scene(
            task=task,
            version_dict={
                "H": H,
                "W": W,
                "T": T,
                "C": VERSION_DICT["C"],
                "f": VERSION_DICT["f"],
                "options": options,
            },
            model=MODEL,
            ae=AE,
            conditioner=CONDITIONER,
            denoiser=DENOISER,
            image_cond=image_cond,
            camera_cond=camera_cond,
            save_path=render_dir,
            use_traj_prior=True,
            traj_prior_c2ws=anchor_c2ws,
            traj_prior_Ks=anchor_Ks,
            seed=seed,
            gradio=True,
        )
        for video_path in video_path_generator:
            return video_path
        return ""

    def get_target_c2ws_and_Ks_from_preset(
        self,
        preprocessed: dict,
        preset_traj: Literal[
            "orbit",
            "spiral",
            "lemniscate",
            "zoom-in",
            "zoom-out",
            "dolly zoom-in",
            "dolly zoom-out",
            "move-forward",
            "move-backward",
            "move-up",
            "move-down",
            "move-left",
            "move-right",
        ],
        num_frames: int,
        zoom_factor: float | None,
    ):
        img_wh = preprocessed["input_wh"]
        start_c2w = preprocessed["input_c2ws"][0]
        start_w2c = torch.linalg.inv(start_c2w)
        look_at = torch.tensor([0, 0, 10])
        start_fov = DEFAULT_FOV_RAD
        target_c2ws, target_fovs = get_preset_pose_fov(
            preset_traj,
            num_frames,
            start_w2c,
            look_at,
            -start_c2w[:3, 1],
            start_fov,
            spiral_radii=[1.0, 1.0, 0.5],
            zoom_factor=zoom_factor,
        )
        target_c2ws = torch.as_tensor(target_c2ws)
        target_fovs = torch.as_tensor(target_fovs)
        target_Ks = get_default_intrinsics(
            target_fovs,  # type: ignore
            aspect_ratio=img_wh[0] / img_wh[1],
        )
        return target_c2ws, target_Ks


def main(
    input_img_path: str,
    preset_traj: Literal[
        "orbit",
        "spiral",
        "lemniscate",
        "zoom-in",
        "zoom-out",
        "dolly zoom-in",
        "dolly zoom-out",
        "move-forward",
        "move-backward",
        "move-up",
        "move-down",
        "move-left",
        "move-right",
    ] = "orbit",
    num_frames: int = 80,
    zoom_factor: float | None = None,
    seed: int = 23,
    chunk_strategy: str = "interp",
    cfg: float = 4.0,
    camera_scale: float = 2.0,
):
    renderer = SevaRenderer()
    preprocessed = renderer.preprocess(input_img_path)
    video_path = renderer.render(
        preprocessed,
        seed,
        chunk_strategy,
        cfg,
        preset_traj,
        num_frames,
        zoom_factor,
        camera_scale,
    )
    print(f"Rendered video saved to: {video_path}")


if __name__ == "__main__":
    tyro.cli(main)
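The `tyro.cli(main)` call at the bottom is what turns the `main()` signature into the command-line flags shown in the docstring: each parameter becomes a flag and its default value makes the flag optional. A toy sketch of that mapping, separate from the uploaded scripts, is below; the `demo` function and its parameters are made up for illustration.

# A toy sketch, not part of the commit: how tyro maps a function signature to
# CLI flags, the same mechanism cli_app.py relies on.
import tyro

def demo(preset_traj: str = "orbit", num_frames: int = 80, cfg: float = 4.0) -> None:
    print(f"traj={preset_traj}, frames={num_frames}, cfg={cfg}")

if __name__ == "__main__":
    # e.g. `python demo.py --preset_traj spiral --num_frames 40`
    tyro.cli(demo)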
cli_batch_app.py
ADDED
@@ -0,0 +1,408 @@
'''
python cli_batch_app.py --input_path imgs --preset_traj "orbit" "spiral" "lemniscate" "zoom-in" "zoom-out" "dolly zoom-in" "dolly zoom-out" "move-forward" "move-backward" "move-up" "move-down" "move-left" "move-right" --output_dir 相机路径

python cli_batch_app.py --input_path imgs --preset_traj "orbit" "spiral" "lemniscate" --output_dir 相机路径

#### People or 3D-subject scenes
"orbit"

#### Flat landscape scenes
"spiral" "lemniscate"
'''

import copy
import json
import os
import os.path as osp
import queue
import secrets
import threading
import time
from datetime import datetime
from glob import glob
from pathlib import Path
from typing import Literal, List

import imageio.v3 as iio
import numpy as np
import torch
import torch.nn.functional as F
import tyro
import viser
import viser.transforms as vt
from einops import rearrange

from seva.eval import (
    IS_TORCH_NIGHTLY,
    chunk_input_and_test,
    create_transforms_simple,
    infer_prior_stats,
    run_one_scene,
    transform_img_and_K,
)
from seva.geometry import (
    DEFAULT_FOV_RAD,
    get_default_intrinsics,
    get_preset_pose_fov,
    normalize_scene,
)
from seva.model import SGMWrapper
from seva.modules.autoencoder import AutoEncoder
from seva.modules.conditioner import CLIPConditioner
from seva.modules.preprocessor import Dust3rPipeline
from seva.sampling import DDPMDiscretization, DiscreteDenoiser
from seva.utils import load_model

device = "cuda:0"

# Constants.
WORK_DIR = "work_dirs/demo_gr"
MAX_SESSIONS = 1

if IS_TORCH_NIGHTLY:
    COMPILE = True
    os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
    os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
else:
    COMPILE = False

# Shared global variables across sessions.
DUST3R = Dust3rPipeline(device=device)  # type: ignore
MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
AE = AutoEncoder(chunk_size=1).to(device)
CONDITIONER = CLIPConditioner().to(device)
DISCRETIZATION = DDPMDiscretization()
DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
VERSION_DICT = {
    "H": 576,
    "W": 576,
    "T": 21,
    "C": 4,
    "f": 8,
    "options": {},
}
SERVERS = {}
ABORT_EVENTS = {}

if COMPILE:
    MODEL = torch.compile(MODEL)
    CONDITIONER = torch.compile(CONDITIONER)
    AE = torch.compile(AE)


class SevaRenderer(object):
    def __init__(self):
        self.gui_state = None

    def preprocess(self, input_img_path: str) -> dict:
        # Simply hardcode these such that the aspect ratio is always kept and
        # the shorter side is resized to 576. This only keeps the number of
        # GUI options small; changing it still works.
        shorter: int = 576
        # Has to be a multiple of 64 for the network.
        shorter = round(shorter / 64) * 64

        # Assume `Basic` demo mode: just hardcode the camera parameters and ignore points.
        input_imgs = torch.as_tensor(
            iio.imread(input_img_path) / 255.0, dtype=torch.float32
        )[None, ..., :3]
        input_imgs = transform_img_and_K(
            input_imgs.permute(0, 3, 1, 2),
            shorter,
            K=None,
            size_stride=64,
        )[0].permute(0, 2, 3, 1)
        input_Ks = get_default_intrinsics(
            aspect_ratio=input_imgs.shape[2] / input_imgs.shape[1]
        )
        input_c2ws = torch.eye(4)[None]
        # Simulate a small time interval such that gradio can update
        # progress properly.
        time.sleep(0.1)
        return {
            "input_imgs": input_imgs,
            "input_Ks": input_Ks,
            "input_c2ws": input_c2ws,
            "input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
            "points": [np.zeros((0, 3))],
            "point_colors": [np.zeros((0, 3))],
            "scene_scale": 1.0,
        }

    def render(
        self,
        preprocessed: dict,
        seed: int,
        chunk_strategy: str,
        cfg: float,
        preset_traj: Literal[
            "orbit",
            "spiral",
            "lemniscate",
            "zoom-in",
            "zoom-out",
            "dolly zoom-in",
            "dolly zoom-out",
            "move-forward",
            "move-backward",
            "move-up",
            "move-down",
            "move-left",
            "move-right",
        ],
        num_frames: int,
        zoom_factor: float | None,
        camera_scale: float,
        output_dir: str,
    ) -> str:
        # Generate a unique render name based on the input image filename and preset_traj.
        input_img_name = osp.splitext(osp.basename(preprocessed["input_img_path"]))[0]
        render_name = f"{input_img_name}_{preset_traj}"
        render_dir = osp.join(output_dir, input_img_name)
        os.makedirs(render_dir, exist_ok=True)

        input_imgs, input_Ks, input_c2ws, (W, H) = (
            preprocessed["input_imgs"],
            preprocessed["input_Ks"],
            preprocessed["input_c2ws"],
            preprocessed["input_wh"],
        )
        num_inputs = len(input_imgs)
        assert num_inputs == 1
        input_c2ws = torch.eye(4)[None].to(dtype=input_c2ws.dtype)
        target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_preset(
            preprocessed, preset_traj, num_frames, zoom_factor
        )
        all_c2ws = torch.cat([input_c2ws, target_c2ws], 0)
        all_Ks = (
            torch.cat([input_Ks, target_Ks], 0)
            * input_Ks.new_tensor([W, H, 1])[:, None]
        )
        num_targets = len(target_c2ws)
        input_indices = list(range(num_inputs))
        target_indices = np.arange(num_inputs, num_inputs + num_targets).tolist()
        # Get anchor cameras.
        T = VERSION_DICT["T"]
        version_dict = copy.deepcopy(VERSION_DICT)
        num_anchors = infer_prior_stats(
            T,
            num_inputs,
            num_total_frames=num_targets,
            version_dict=version_dict,
        )
        # infer_prior_stats modifies T in-place.
        T = version_dict["T"]
        assert isinstance(num_anchors, int)
        anchor_indices = np.linspace(
            num_inputs,
            num_inputs + num_targets - 1,
            num_anchors,
        ).tolist()
        anchor_c2ws = all_c2ws[[round(ind) for ind in anchor_indices]]
        anchor_Ks = all_Ks[[round(ind) for ind in anchor_indices]]
        # Create image conditioning.
        all_imgs_np = (
            F.pad(input_imgs, (0, 0, 0, 0, 0, 0, 0, num_targets), value=0.0).numpy()
            * 255.0
        ).astype(np.uint8)
        image_cond = {
            "img": all_imgs_np,
            "input_indices": input_indices,
            "prior_indices": anchor_indices,
        }
        # Create camera conditioning (K is unnormalized).
        camera_cond = {
            "c2w": all_c2ws,
            "K": all_Ks,
            "input_indices": list(range(num_inputs + num_targets)),
        }
        # Run rendering.
        num_steps = 50
        options_ori = VERSION_DICT["options"]
        options = copy.deepcopy(options_ori)
        options["chunk_strategy"] = chunk_strategy
        options["video_save_fps"] = 30.0
        options["beta_linear_start"] = 5e-6
        options["log_snr_shift"] = 2.4
        options["guider_types"] = [1, 2]
        options["cfg"] = [
            float(cfg),
            3.0 if num_inputs >= 9 else 2.0,
        ]  # We define the semi-dense-view regime to have 9 input views.
        options["camera_scale"] = camera_scale
        options["num_steps"] = num_steps
        options["cfg_min"] = 1.2
        options["encoding_t"] = 1
        options["decoding_t"] = 1
        task = "img2trajvid"
        # Get the number of first-pass chunks.
        T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
        chunk_strategy_first_pass = options.get(
            "chunk_strategy_first_pass", "gt-nearest"
        )
        num_chunks_0 = len(
            chunk_input_and_test(
                T_first_pass,
                input_c2ws,
                anchor_c2ws,
                input_indices,
                image_cond["prior_indices"],
                options={**options, "sampler_verbose": False},
                task=task,
                chunk_strategy=chunk_strategy_first_pass,
                gt_input_inds=list(range(input_c2ws.shape[0])),
            )[1]
        )
        # Get the number of second-pass chunks.
        anchor_argsort = np.argsort(input_indices + anchor_indices).tolist()
        anchor_indices = np.array(input_indices + anchor_indices)[
            anchor_argsort
        ].tolist()
        gt_input_inds = [anchor_argsort.index(i) for i in range(input_c2ws.shape[0])]
        anchor_c2ws_second_pass = torch.cat([input_c2ws, anchor_c2ws], dim=0)[
            anchor_argsort
        ]
        T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
        chunk_strategy = options.get("chunk_strategy", "nearest")
        num_chunks_1 = len(
            chunk_input_and_test(
                T_second_pass,
                anchor_c2ws_second_pass,
                target_c2ws,
                anchor_indices,
                target_indices,
                options={**options, "sampler_verbose": False},
                task=task,
                chunk_strategy=chunk_strategy,
                gt_input_inds=gt_input_inds,
            )[1]
        )
        video_path_generator = run_one_scene(
            task=task,
            version_dict={
                "H": H,
                "W": W,
                "T": T,
                "C": VERSION_DICT["C"],
                "f": VERSION_DICT["f"],
                "options": options,
            },
            model=MODEL,
            ae=AE,
            conditioner=CONDITIONER,
            denoiser=DENOISER,
            image_cond=image_cond,
            camera_cond=camera_cond,
            save_path=render_dir,
            use_traj_prior=True,
            traj_prior_c2ws=anchor_c2ws,
            traj_prior_Ks=anchor_Ks,
            seed=seed,
            gradio=True,
        )
        for video_path in video_path_generator:
            # Rename the video file to the desired format.
            new_video_path = osp.join(render_dir, f"{render_name}.mp4")
            os.rename(video_path, new_video_path)
            return new_video_path
        return ""

    def get_target_c2ws_and_Ks_from_preset(
        self,
        preprocessed: dict,
        preset_traj: Literal[
            "orbit",
            "spiral",
            "lemniscate",
            "zoom-in",
            "zoom-out",
            "dolly zoom-in",
            "dolly zoom-out",
            "move-forward",
            "move-backward",
            "move-up",
            "move-down",
            "move-left",
            "move-right",
        ],
        num_frames: int,
        zoom_factor: float | None,
    ):
        img_wh = preprocessed["input_wh"]
        start_c2w = preprocessed["input_c2ws"][0]
        start_w2c = torch.linalg.inv(start_c2w)
        look_at = torch.tensor([0, 0, 10])
        start_fov = DEFAULT_FOV_RAD
        target_c2ws, target_fovs = get_preset_pose_fov(
            preset_traj,
            num_frames,
            start_w2c,
            look_at,
            -start_c2w[:3, 1],
            start_fov,
            spiral_radii=[1.0, 1.0, 0.5],
            zoom_factor=zoom_factor,
        )
        target_c2ws = torch.as_tensor(target_c2ws)
        target_fovs = torch.as_tensor(target_fovs)
        target_Ks = get_default_intrinsics(
            target_fovs,  # type: ignore
            aspect_ratio=img_wh[0] / img_wh[1],
        )
        return target_c2ws, target_Ks


def main(
    input_path: str,
    preset_traj: List[Literal[
        "orbit",
        "spiral",
        "lemniscate",
        "zoom-in",
        "zoom-out",
        "dolly zoom-in",
        "dolly zoom-out",
        "move-forward",
        "move-backward",
        "move-up",
        "move-down",
        "move-left",
        "move-right",
    ]],
    num_frames: int = 80,
    zoom_factor: float | None = None,
    seed: int = 23,
    chunk_strategy: str = "interp",
    cfg: float = 4.0,
    camera_scale: float = 2.0,
    output_dir: str = WORK_DIR,
):
    renderer = SevaRenderer()

    # Check whether input_path is a directory or a single image.
    if osp.isdir(input_path):
        image_paths = [osp.join(input_path, fname) for fname in os.listdir(input_path) if fname.lower().endswith(('.png', '.jpg', '.jpeg'))]
    else:
        image_paths = [input_path]

    for input_img_path in image_paths:
        preprocessed = renderer.preprocess(input_img_path)
        preprocessed["input_img_path"] = input_img_path  # Add input_img_path to the preprocessed dict.

        for traj in preset_traj:
            video_path = renderer.render(
                preprocessed,
                seed,
                chunk_strategy,
                cfg,
                traj,
                num_frames,
                zoom_factor,
                camera_scale,
                output_dir,
            )
            print(f"Rendered video saved to: {video_path}")


if __name__ == "__main__":
    tyro.cli(main)
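One caveat with the batch script: when `--input_path` is a directory, only `.png`, `.jpg`, and `.jpeg` files are collected, while the single-image examples in the other scripts use a `.webp` input. Below is a small sketch of a more permissive collector; extending the extension list to `.webp` is an assumption, not something the commit does.

# A minimal sketch, not part of the commit: a directory collector that also
# accepts .webp files (an assumption; cli_batch_app.py only matches
# .png/.jpg/.jpeg), mirroring the input_path handling in main().
import os
import os.path as osp
from typing import List


def collect_images(input_path: str) -> List[str]:
    exts = (".png", ".jpg", ".jpeg", ".webp")
    if osp.isdir(input_path):
        return sorted(
            osp.join(input_path, fname)
            for fname in os.listdir(input_path)
            if fname.lower().endswith(exts)
        )
    return [input_path]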