svjack committed (verified)
Commit 6b551c2 · 1 Parent(s): 3ba11d5

Upload 3 files

Files changed (3)
  1. cli_all_app.py +388 -0
  2. cli_app.py +380 -0
  3. cli_batch_app.py +408 -0
cli_all_app.py ADDED
@@ -0,0 +1,388 @@
+ '''
+ python cli_all_app.py --input_img_path 战场原.webp --preset_traj "orbit" "spiral" "lemniscate" "zoom-in" "zoom-out" "dolly zoom-in" "dolly zoom-out" "move-forward" "move-backward" "move-up" "move-down" "move-left" "move-right" --output_dir 战场原
+ '''
+
+ import copy
+ import json
+ import os
+ import os.path as osp
+ import queue
+ import secrets
+ import threading
+ import time
+ from datetime import datetime
+ from glob import glob
+ from pathlib import Path
+ from typing import Literal, List
+
+ import imageio.v3 as iio
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import tyro
+ import viser
+ import viser.transforms as vt
+ from einops import rearrange
+
+ from seva.eval import (
+     IS_TORCH_NIGHTLY,
+     chunk_input_and_test,
+     create_transforms_simple,
+     infer_prior_stats,
+     run_one_scene,
+     transform_img_and_K,
+ )
+ from seva.geometry import (
+     DEFAULT_FOV_RAD,
+     get_default_intrinsics,
+     get_preset_pose_fov,
+     normalize_scene,
+ )
+ from seva.model import SGMWrapper
+ from seva.modules.autoencoder import AutoEncoder
+ from seva.modules.conditioner import CLIPConditioner
+ from seva.modules.preprocessor import Dust3rPipeline
+ from seva.sampling import DDPMDiscretization, DiscreteDenoiser
+ from seva.utils import load_model
+
+ device = "cuda:0"
+
+ # Constants.
+ WORK_DIR = "work_dirs/demo_gr"
+ MAX_SESSIONS = 1
+
+ if IS_TORCH_NIGHTLY:
+     COMPILE = True
+     os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
+     os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
+ else:
+     COMPILE = False
+
+ # Shared global variables across sessions.
+ DUST3R = Dust3rPipeline(device=device)  # type: ignore
+ MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
+ AE = AutoEncoder(chunk_size=1).to(device)
+ CONDITIONER = CLIPConditioner().to(device)
+ DISCRETIZATION = DDPMDiscretization()
+ DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
+ VERSION_DICT = {
+     "H": 576,
+     "W": 576,
+     "T": 21,
+     "C": 4,
+     "f": 8,
+     "options": {},
+ }
+ SERVERS = {}
+ ABORT_EVENTS = {}
+
+ if COMPILE:
+     MODEL = torch.compile(MODEL)
+     CONDITIONER = torch.compile(CONDITIONER)
+     AE = torch.compile(AE)
+
+
+ class SevaRenderer(object):
+     def __init__(self):
+         self.gui_state = None
+
+     def preprocess(self, input_img_path: str) -> dict:
+         # Hardcode these so that the aspect ratio is always kept and the
+         # shorter side is resized to 576. This only keeps the GUI options
+         # fewer; changing it still works.
+         shorter: int = 576
+         # Has to be a multiple of 64 for the network.
+         shorter = round(shorter / 64) * 64
+
+         # Assume `Basic` demo mode: just hardcode the camera parameters and ignore points.
+         input_imgs = torch.as_tensor(
+             iio.imread(input_img_path) / 255.0, dtype=torch.float32
+         )[None, ..., :3]
+         input_imgs = transform_img_and_K(
+             input_imgs.permute(0, 3, 1, 2),
+             shorter,
+             K=None,
+             size_stride=64,
+         )[0].permute(0, 2, 3, 1)
+         input_Ks = get_default_intrinsics(
+             aspect_ratio=input_imgs.shape[2] / input_imgs.shape[1]
+         )
+         input_c2ws = torch.eye(4)[None]
+         # Simulate a small time interval so that gradio can update
+         # progress properly.
+         time.sleep(0.1)
+         return {
+             "input_imgs": input_imgs,
+             "input_Ks": input_Ks,
+             "input_c2ws": input_c2ws,
+             "input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
+             "points": [np.zeros((0, 3))],
+             "point_colors": [np.zeros((0, 3))],
+             "scene_scale": 1.0,
+         }
+
+     def render(
+         self,
+         preprocessed: dict,
+         seed: int,
+         chunk_strategy: str,
+         cfg: float,
+         preset_traj: Literal[
+             "orbit",
+             "spiral",
+             "lemniscate",
+             "zoom-in",
+             "zoom-out",
+             "dolly zoom-in",
+             "dolly zoom-out",
+             "move-forward",
+             "move-backward",
+             "move-up",
+             "move-down",
+             "move-left",
+             "move-right",
+         ],
+         num_frames: int,
+         zoom_factor: float | None,
+         camera_scale: float,
+         output_dir: str,
+     ) -> str:
+         # Generate a unique render name based on the input image filename and preset_traj.
+         input_img_name = osp.splitext(osp.basename(preprocessed["input_img_path"]))[0]
+         render_name = f"{input_img_name}_{preset_traj}"
+         render_dir = osp.join(output_dir, render_name)
+
+         input_imgs, input_Ks, input_c2ws, (W, H) = (
+             preprocessed["input_imgs"],
+             preprocessed["input_Ks"],
+             preprocessed["input_c2ws"],
+             preprocessed["input_wh"],
+         )
+         num_inputs = len(input_imgs)
+         assert num_inputs == 1
+         input_c2ws = torch.eye(4)[None].to(dtype=input_c2ws.dtype)
+         target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_preset(
+             preprocessed, preset_traj, num_frames, zoom_factor
+         )
+         all_c2ws = torch.cat([input_c2ws, target_c2ws], 0)
+         all_Ks = (
+             torch.cat([input_Ks, target_Ks], 0)
+             * input_Ks.new_tensor([W, H, 1])[:, None]
+         )
+         num_targets = len(target_c2ws)
+         input_indices = list(range(num_inputs))
+         target_indices = np.arange(num_inputs, num_inputs + num_targets).tolist()
+         # Get anchor cameras.
+         T = VERSION_DICT["T"]
+         version_dict = copy.deepcopy(VERSION_DICT)
+         num_anchors = infer_prior_stats(
+             T,
+             num_inputs,
+             num_total_frames=num_targets,
+             version_dict=version_dict,
+         )
+         # infer_prior_stats modifies T in-place.
+         T = version_dict["T"]
+         assert isinstance(num_anchors, int)
+         anchor_indices = np.linspace(
+             num_inputs,
+             num_inputs + num_targets - 1,
+             num_anchors,
+         ).tolist()
+         anchor_c2ws = all_c2ws[[round(ind) for ind in anchor_indices]]
+         anchor_Ks = all_Ks[[round(ind) for ind in anchor_indices]]
+         # Create image conditioning.
+         all_imgs_np = (
+             F.pad(input_imgs, (0, 0, 0, 0, 0, 0, 0, num_targets), value=0.0).numpy()
+             * 255.0
+         ).astype(np.uint8)
+         image_cond = {
+             "img": all_imgs_np,
+             "input_indices": input_indices,
+             "prior_indices": anchor_indices,
+         }
+         # Create camera conditioning (K is unnormalized).
+         camera_cond = {
+             "c2w": all_c2ws,
+             "K": all_Ks,
+             "input_indices": list(range(num_inputs + num_targets)),
+         }
+         # Run rendering.
+         num_steps = 50
+         options_ori = VERSION_DICT["options"]
+         options = copy.deepcopy(options_ori)
+         options["chunk_strategy"] = chunk_strategy
+         options["video_save_fps"] = 30.0
+         options["beta_linear_start"] = 5e-6
+         options["log_snr_shift"] = 2.4
+         options["guider_types"] = [1, 2]
+         options["cfg"] = [
+             float(cfg),
+             3.0 if num_inputs >= 9 else 2.0,
+         ]  # We define the semi-dense-view regime to have 9 input views.
+         options["camera_scale"] = camera_scale
+         options["num_steps"] = num_steps
+         options["cfg_min"] = 1.2
+         options["encoding_t"] = 1
+         options["decoding_t"] = 1
+         task = "img2trajvid"
+         # Get the number of first-pass chunks.
+         T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
+         chunk_strategy_first_pass = options.get(
+             "chunk_strategy_first_pass", "gt-nearest"
+         )
+         num_chunks_0 = len(
+             chunk_input_and_test(
+                 T_first_pass,
+                 input_c2ws,
+                 anchor_c2ws,
+                 input_indices,
+                 image_cond["prior_indices"],
+                 options={**options, "sampler_verbose": False},
+                 task=task,
+                 chunk_strategy=chunk_strategy_first_pass,
+                 gt_input_inds=list(range(input_c2ws.shape[0])),
+             )[1]
+         )
+         # Get the number of second-pass chunks.
+         anchor_argsort = np.argsort(input_indices + anchor_indices).tolist()
+         anchor_indices = np.array(input_indices + anchor_indices)[
+             anchor_argsort
+         ].tolist()
+         gt_input_inds = [anchor_argsort.index(i) for i in range(input_c2ws.shape[0])]
+         anchor_c2ws_second_pass = torch.cat([input_c2ws, anchor_c2ws], dim=0)[
+             anchor_argsort
+         ]
+         T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
+         chunk_strategy = options.get("chunk_strategy", "nearest")
+         num_chunks_1 = len(
+             chunk_input_and_test(
+                 T_second_pass,
+                 anchor_c2ws_second_pass,
+                 target_c2ws,
+                 anchor_indices,
+                 target_indices,
+                 options={**options, "sampler_verbose": False},
+                 task=task,
+                 chunk_strategy=chunk_strategy,
+                 gt_input_inds=gt_input_inds,
+             )[1]
+         )
+         video_path_generator = run_one_scene(
+             task=task,
+             version_dict={
+                 "H": H,
+                 "W": W,
+                 "T": T,
+                 "C": VERSION_DICT["C"],
+                 "f": VERSION_DICT["f"],
+                 "options": options,
+             },
+             model=MODEL,
+             ae=AE,
+             conditioner=CONDITIONER,
+             denoiser=DENOISER,
+             image_cond=image_cond,
+             camera_cond=camera_cond,
+             save_path=render_dir,
+             use_traj_prior=True,
+             traj_prior_c2ws=anchor_c2ws,
+             traj_prior_Ks=anchor_Ks,
+             seed=seed,
+             gradio=True,
+         )
+         for video_path in video_path_generator:
+             return video_path
+         return ""
+
+     def get_target_c2ws_and_Ks_from_preset(
+         self,
+         preprocessed: dict,
+         preset_traj: Literal[
+             "orbit",
+             "spiral",
+             "lemniscate",
+             "zoom-in",
+             "zoom-out",
+             "dolly zoom-in",
+             "dolly zoom-out",
+             "move-forward",
+             "move-backward",
+             "move-up",
+             "move-down",
+             "move-left",
+             "move-right",
+         ],
+         num_frames: int,
+         zoom_factor: float | None,
+     ):
+         img_wh = preprocessed["input_wh"]
+         start_c2w = preprocessed["input_c2ws"][0]
+         start_w2c = torch.linalg.inv(start_c2w)
+         look_at = torch.tensor([0, 0, 10])
+         start_fov = DEFAULT_FOV_RAD
+         target_c2ws, target_fovs = get_preset_pose_fov(
+             preset_traj,
+             num_frames,
+             start_w2c,
+             look_at,
+             -start_c2w[:3, 1],
+             start_fov,
+             spiral_radii=[1.0, 1.0, 0.5],
+             zoom_factor=zoom_factor,
+         )
+         target_c2ws = torch.as_tensor(target_c2ws)
+         target_fovs = torch.as_tensor(target_fovs)
+         target_Ks = get_default_intrinsics(
+             target_fovs,  # type: ignore
+             aspect_ratio=img_wh[0] / img_wh[1],
+         )
+         return target_c2ws, target_Ks
+
+
+ def main(
+     input_img_path: str,
+     preset_traj: List[Literal[
+         "orbit",
+         "spiral",
+         "lemniscate",
+         "zoom-in",
+         "zoom-out",
+         "dolly zoom-in",
+         "dolly zoom-out",
+         "move-forward",
+         "move-backward",
+         "move-up",
+         "move-down",
+         "move-left",
+         "move-right",
+     ]],
+     num_frames: int = 80,
+     zoom_factor: float | None = None,
+     seed: int = 23,
+     chunk_strategy: str = "interp",
+     cfg: float = 4.0,
+     camera_scale: float = 2.0,
+     output_dir: str = WORK_DIR,
+ ):
+     renderer = SevaRenderer()
+     preprocessed = renderer.preprocess(input_img_path)
+     preprocessed["input_img_path"] = input_img_path  # Add input_img_path to the preprocessed dict.
+
+     for traj in preset_traj:
+         video_path = renderer.render(
+             preprocessed,
+             seed,
+             chunk_strategy,
+             cfg,
+             traj,
+             num_frames,
+             zoom_factor,
+             camera_scale,
+             output_dir,
+         )
+         print(f"Rendered video saved to: {video_path}")
+
+
+ if __name__ == "__main__":
+     tyro.cli(main)
cli_app.py ADDED
@@ -0,0 +1,380 @@
+ '''
+ python cli_app.py --input_img_path 战场原.webp --preset_traj orbit --num_frames 80 --seed 23 --chunk_strategy interp --cfg 4.0 --camera_scale 2.0
+ '''
+
+ import copy
+ import json
+ import os
+ import os.path as osp
+ import queue
+ import secrets
+ import threading
+ import time
+ from datetime import datetime
+ from glob import glob
+ from pathlib import Path
+ from typing import Literal
+
+ import imageio.v3 as iio
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import tyro
+ import viser
+ import viser.transforms as vt
+ from einops import rearrange
+
+ from seva.eval import (
+     IS_TORCH_NIGHTLY,
+     chunk_input_and_test,
+     create_transforms_simple,
+     infer_prior_stats,
+     run_one_scene,
+     transform_img_and_K,
+ )
+ from seva.geometry import (
+     DEFAULT_FOV_RAD,
+     get_default_intrinsics,
+     get_preset_pose_fov,
+     normalize_scene,
+ )
+ from seva.model import SGMWrapper
+ from seva.modules.autoencoder import AutoEncoder
+ from seva.modules.conditioner import CLIPConditioner
+ from seva.modules.preprocessor import Dust3rPipeline
+ from seva.sampling import DDPMDiscretization, DiscreteDenoiser
+ from seva.utils import load_model
+
+ device = "cuda:0"
+
+ # Constants.
+ WORK_DIR = "work_dirs/demo_gr"
+ MAX_SESSIONS = 1
+
+ if IS_TORCH_NIGHTLY:
+     COMPILE = True
+     os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
+     os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
+ else:
+     COMPILE = False
+
+ # Shared global variables across sessions.
+ DUST3R = Dust3rPipeline(device=device)  # type: ignore
+ MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
+ AE = AutoEncoder(chunk_size=1).to(device)
+ CONDITIONER = CLIPConditioner().to(device)
+ DISCRETIZATION = DDPMDiscretization()
+ DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
+ VERSION_DICT = {
+     "H": 576,
+     "W": 576,
+     "T": 21,
+     "C": 4,
+     "f": 8,
+     "options": {},
+ }
+ SERVERS = {}
+ ABORT_EVENTS = {}
+
+ if COMPILE:
+     MODEL = torch.compile(MODEL)
+     CONDITIONER = torch.compile(CONDITIONER)
+     AE = torch.compile(AE)
+
+
+ class SevaRenderer(object):
+     def __init__(self):
+         self.gui_state = None
+
+     def preprocess(self, input_img_path: str) -> dict:
+         # Hardcode these so that the aspect ratio is always kept and the
+         # shorter side is resized to 576. This only keeps the GUI options
+         # fewer; changing it still works.
+         shorter: int = 576
+         # Has to be a multiple of 64 for the network.
+         shorter = round(shorter / 64) * 64
+
+         # Assume `Basic` demo mode: just hardcode the camera parameters and ignore points.
+         input_imgs = torch.as_tensor(
+             iio.imread(input_img_path) / 255.0, dtype=torch.float32
+         )[None, ..., :3]
+         input_imgs = transform_img_and_K(
+             input_imgs.permute(0, 3, 1, 2),
+             shorter,
+             K=None,
+             size_stride=64,
+         )[0].permute(0, 2, 3, 1)
+         input_Ks = get_default_intrinsics(
+             aspect_ratio=input_imgs.shape[2] / input_imgs.shape[1]
+         )
+         input_c2ws = torch.eye(4)[None]
+         # Simulate a small time interval so that gradio can update
+         # progress properly.
+         time.sleep(0.1)
+         return {
+             "input_imgs": input_imgs,
+             "input_Ks": input_Ks,
+             "input_c2ws": input_c2ws,
+             "input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
+             "points": [np.zeros((0, 3))],
+             "point_colors": [np.zeros((0, 3))],
+             "scene_scale": 1.0,
+         }
+
+     def render(
+         self,
+         preprocessed: dict,
+         seed: int,
+         chunk_strategy: str,
+         cfg: float,
+         preset_traj: Literal[
+             "orbit",
+             "spiral",
+             "lemniscate",
+             "zoom-in",
+             "zoom-out",
+             "dolly zoom-in",
+             "dolly zoom-out",
+             "move-forward",
+             "move-backward",
+             "move-up",
+             "move-down",
+             "move-left",
+             "move-right",
+         ],
+         num_frames: int,
+         zoom_factor: float | None,
+         camera_scale: float,
+     ) -> str:
+         render_name = datetime.now().strftime("%Y%m%d_%H%M%S")
+         render_dir = osp.join(WORK_DIR, render_name)
+
+         input_imgs, input_Ks, input_c2ws, (W, H) = (
+             preprocessed["input_imgs"],
+             preprocessed["input_Ks"],
+             preprocessed["input_c2ws"],
+             preprocessed["input_wh"],
+         )
+         num_inputs = len(input_imgs)
+         assert num_inputs == 1
+         input_c2ws = torch.eye(4)[None].to(dtype=input_c2ws.dtype)
+         target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_preset(
+             preprocessed, preset_traj, num_frames, zoom_factor
+         )
+         all_c2ws = torch.cat([input_c2ws, target_c2ws], 0)
+         all_Ks = (
+             torch.cat([input_Ks, target_Ks], 0)
+             * input_Ks.new_tensor([W, H, 1])[:, None]
+         )
+         num_targets = len(target_c2ws)
+         input_indices = list(range(num_inputs))
+         target_indices = np.arange(num_inputs, num_inputs + num_targets).tolist()
+         # Get anchor cameras.
+         T = VERSION_DICT["T"]
+         version_dict = copy.deepcopy(VERSION_DICT)
+         num_anchors = infer_prior_stats(
+             T,
+             num_inputs,
+             num_total_frames=num_targets,
+             version_dict=version_dict,
+         )
+         # infer_prior_stats modifies T in-place.
+         T = version_dict["T"]
+         assert isinstance(num_anchors, int)
+         anchor_indices = np.linspace(
+             num_inputs,
+             num_inputs + num_targets - 1,
+             num_anchors,
+         ).tolist()
+         anchor_c2ws = all_c2ws[[round(ind) for ind in anchor_indices]]
+         anchor_Ks = all_Ks[[round(ind) for ind in anchor_indices]]
+         # Create image conditioning.
+         all_imgs_np = (
+             F.pad(input_imgs, (0, 0, 0, 0, 0, 0, 0, num_targets), value=0.0).numpy()
+             * 255.0
+         ).astype(np.uint8)
+         image_cond = {
+             "img": all_imgs_np,
+             "input_indices": input_indices,
+             "prior_indices": anchor_indices,
+         }
+         # Create camera conditioning (K is unnormalized).
+         camera_cond = {
+             "c2w": all_c2ws,
+             "K": all_Ks,
+             "input_indices": list(range(num_inputs + num_targets)),
+         }
+         # Run rendering.
+         num_steps = 50
+         options_ori = VERSION_DICT["options"]
+         options = copy.deepcopy(options_ori)
+         options["chunk_strategy"] = chunk_strategy
+         options["video_save_fps"] = 30.0
+         options["beta_linear_start"] = 5e-6
+         options["log_snr_shift"] = 2.4
+         options["guider_types"] = [1, 2]
+         options["cfg"] = [
+             float(cfg),
+             3.0 if num_inputs >= 9 else 2.0,
+         ]  # We define the semi-dense-view regime to have 9 input views.
+         options["camera_scale"] = camera_scale
+         options["num_steps"] = num_steps
+         options["cfg_min"] = 1.2
+         options["encoding_t"] = 1
+         options["decoding_t"] = 1
+         task = "img2trajvid"
+         # Get the number of first-pass chunks.
+         T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
+         chunk_strategy_first_pass = options.get(
+             "chunk_strategy_first_pass", "gt-nearest"
+         )
+         num_chunks_0 = len(
+             chunk_input_and_test(
+                 T_first_pass,
+                 input_c2ws,
+                 anchor_c2ws,
+                 input_indices,
+                 image_cond["prior_indices"],
+                 options={**options, "sampler_verbose": False},
+                 task=task,
+                 chunk_strategy=chunk_strategy_first_pass,
+                 gt_input_inds=list(range(input_c2ws.shape[0])),
+             )[1]
+         )
+         # Get the number of second-pass chunks.
+         anchor_argsort = np.argsort(input_indices + anchor_indices).tolist()
+         anchor_indices = np.array(input_indices + anchor_indices)[
+             anchor_argsort
+         ].tolist()
+         gt_input_inds = [anchor_argsort.index(i) for i in range(input_c2ws.shape[0])]
+         anchor_c2ws_second_pass = torch.cat([input_c2ws, anchor_c2ws], dim=0)[
+             anchor_argsort
+         ]
+         T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
+         chunk_strategy = options.get("chunk_strategy", "nearest")
+         num_chunks_1 = len(
+             chunk_input_and_test(
+                 T_second_pass,
+                 anchor_c2ws_second_pass,
+                 target_c2ws,
+                 anchor_indices,
+                 target_indices,
+                 options={**options, "sampler_verbose": False},
+                 task=task,
+                 chunk_strategy=chunk_strategy,
+                 gt_input_inds=gt_input_inds,
+             )[1]
+         )
+         video_path_generator = run_one_scene(
+             task=task,
+             version_dict={
+                 "H": H,
+                 "W": W,
+                 "T": T,
+                 "C": VERSION_DICT["C"],
+                 "f": VERSION_DICT["f"],
+                 "options": options,
+             },
+             model=MODEL,
+             ae=AE,
+             conditioner=CONDITIONER,
+             denoiser=DENOISER,
+             image_cond=image_cond,
+             camera_cond=camera_cond,
+             save_path=render_dir,
+             use_traj_prior=True,
+             traj_prior_c2ws=anchor_c2ws,
+             traj_prior_Ks=anchor_Ks,
+             seed=seed,
+             gradio=True,
+         )
+         for video_path in video_path_generator:
+             return video_path
+         return ""
+
+     def get_target_c2ws_and_Ks_from_preset(
+         self,
+         preprocessed: dict,
+         preset_traj: Literal[
+             "orbit",
+             "spiral",
+             "lemniscate",
+             "zoom-in",
+             "zoom-out",
+             "dolly zoom-in",
+             "dolly zoom-out",
+             "move-forward",
+             "move-backward",
+             "move-up",
+             "move-down",
+             "move-left",
+             "move-right",
+         ],
+         num_frames: int,
+         zoom_factor: float | None,
+     ):
+         img_wh = preprocessed["input_wh"]
+         start_c2w = preprocessed["input_c2ws"][0]
+         start_w2c = torch.linalg.inv(start_c2w)
+         look_at = torch.tensor([0, 0, 10])
+         start_fov = DEFAULT_FOV_RAD
+         target_c2ws, target_fovs = get_preset_pose_fov(
+             preset_traj,
+             num_frames,
+             start_w2c,
+             look_at,
+             -start_c2w[:3, 1],
+             start_fov,
+             spiral_radii=[1.0, 1.0, 0.5],
+             zoom_factor=zoom_factor,
+         )
+         target_c2ws = torch.as_tensor(target_c2ws)
+         target_fovs = torch.as_tensor(target_fovs)
+         target_Ks = get_default_intrinsics(
+             target_fovs,  # type: ignore
+             aspect_ratio=img_wh[0] / img_wh[1],
+         )
+         return target_c2ws, target_Ks
+
+
+ def main(
+     input_img_path: str,
+     preset_traj: Literal[
+         "orbit",
+         "spiral",
+         "lemniscate",
+         "zoom-in",
+         "zoom-out",
+         "dolly zoom-in",
+         "dolly zoom-out",
+         "move-forward",
+         "move-backward",
+         "move-up",
+         "move-down",
+         "move-left",
+         "move-right",
+     ] = "orbit",
+     num_frames: int = 80,
+     zoom_factor: float | None = None,
+     seed: int = 23,
+     chunk_strategy: str = "interp",
+     cfg: float = 4.0,
+     camera_scale: float = 2.0,
+ ):
+     renderer = SevaRenderer()
+     preprocessed = renderer.preprocess(input_img_path)
+     video_path = renderer.render(
+         preprocessed,
+         seed,
+         chunk_strategy,
+         cfg,
+         preset_traj,
+         num_frames,
+         zoom_factor,
+         camera_scale,
+     )
+     print(f"Rendered video saved to: {video_path}")
+
+
+ if __name__ == "__main__":
+     tyro.cli(main)
cli_batch_app.py ADDED
@@ -0,0 +1,408 @@
+ '''
+ python cli_batch_app.py --input_path imgs --preset_traj "orbit" "spiral" "lemniscate" "zoom-in" "zoom-out" "dolly zoom-in" "dolly zoom-out" "move-forward" "move-backward" "move-up" "move-down" "move-left" "move-right" --output_dir 相机路径
+
+ python cli_batch_app.py --input_path imgs --preset_traj "orbit" "spiral" "lemniscate" --output_dir 相机路径
+
+ #### For people or scenes with a 3D main subject
+ "orbit"
+
+ #### For flat landscape scenes
+ "spiral" "lemniscate"
+ '''
+
+ import copy
+ import json
+ import os
+ import os.path as osp
+ import queue
+ import secrets
+ import threading
+ import time
+ from datetime import datetime
+ from glob import glob
+ from pathlib import Path
+ from typing import Literal, List
+
+ import imageio.v3 as iio
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import tyro
+ import viser
+ import viser.transforms as vt
+ from einops import rearrange
+
+ from seva.eval import (
+     IS_TORCH_NIGHTLY,
+     chunk_input_and_test,
+     create_transforms_simple,
+     infer_prior_stats,
+     run_one_scene,
+     transform_img_and_K,
+ )
+ from seva.geometry import (
+     DEFAULT_FOV_RAD,
+     get_default_intrinsics,
+     get_preset_pose_fov,
+     normalize_scene,
+ )
+ from seva.model import SGMWrapper
+ from seva.modules.autoencoder import AutoEncoder
+ from seva.modules.conditioner import CLIPConditioner
+ from seva.modules.preprocessor import Dust3rPipeline
+ from seva.sampling import DDPMDiscretization, DiscreteDenoiser
+ from seva.utils import load_model
+
+ device = "cuda:0"
+
+ # Constants.
+ WORK_DIR = "work_dirs/demo_gr"
+ MAX_SESSIONS = 1
+
+ if IS_TORCH_NIGHTLY:
+     COMPILE = True
+     os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
+     os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
+ else:
+     COMPILE = False
+
+ # Shared global variables across sessions.
+ DUST3R = Dust3rPipeline(device=device)  # type: ignore
+ MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
+ AE = AutoEncoder(chunk_size=1).to(device)
+ CONDITIONER = CLIPConditioner().to(device)
+ DISCRETIZATION = DDPMDiscretization()
+ DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
+ VERSION_DICT = {
+     "H": 576,
+     "W": 576,
+     "T": 21,
+     "C": 4,
+     "f": 8,
+     "options": {},
+ }
+ SERVERS = {}
+ ABORT_EVENTS = {}
+
+ if COMPILE:
+     MODEL = torch.compile(MODEL)
+     CONDITIONER = torch.compile(CONDITIONER)
+     AE = torch.compile(AE)
+
+
+ class SevaRenderer(object):
+     def __init__(self):
+         self.gui_state = None
+
+     def preprocess(self, input_img_path: str) -> dict:
+         # Hardcode these so that the aspect ratio is always kept and the
+         # shorter side is resized to 576. This only keeps the GUI options
+         # fewer; changing it still works.
+         shorter: int = 576
+         # Has to be a multiple of 64 for the network.
+         shorter = round(shorter / 64) * 64
+
+         # Assume `Basic` demo mode: just hardcode the camera parameters and ignore points.
+         input_imgs = torch.as_tensor(
+             iio.imread(input_img_path) / 255.0, dtype=torch.float32
+         )[None, ..., :3]
+         input_imgs = transform_img_and_K(
+             input_imgs.permute(0, 3, 1, 2),
+             shorter,
+             K=None,
+             size_stride=64,
+         )[0].permute(0, 2, 3, 1)
+         input_Ks = get_default_intrinsics(
+             aspect_ratio=input_imgs.shape[2] / input_imgs.shape[1]
+         )
+         input_c2ws = torch.eye(4)[None]
+         # Simulate a small time interval so that gradio can update
+         # progress properly.
+         time.sleep(0.1)
+         return {
+             "input_imgs": input_imgs,
+             "input_Ks": input_Ks,
+             "input_c2ws": input_c2ws,
+             "input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
+             "points": [np.zeros((0, 3))],
+             "point_colors": [np.zeros((0, 3))],
+             "scene_scale": 1.0,
+         }
+
+     def render(
+         self,
+         preprocessed: dict,
+         seed: int,
+         chunk_strategy: str,
+         cfg: float,
+         preset_traj: Literal[
+             "orbit",
+             "spiral",
+             "lemniscate",
+             "zoom-in",
+             "zoom-out",
+             "dolly zoom-in",
+             "dolly zoom-out",
+             "move-forward",
+             "move-backward",
+             "move-up",
+             "move-down",
+             "move-left",
+             "move-right",
+         ],
+         num_frames: int,
+         zoom_factor: float | None,
+         camera_scale: float,
+         output_dir: str,
+     ) -> str:
+         # Generate a unique render name based on the input image filename and preset_traj.
+         input_img_name = osp.splitext(osp.basename(preprocessed["input_img_path"]))[0]
+         render_name = f"{input_img_name}_{preset_traj}"
+         render_dir = osp.join(output_dir, input_img_name)
+         os.makedirs(render_dir, exist_ok=True)
+
+         input_imgs, input_Ks, input_c2ws, (W, H) = (
+             preprocessed["input_imgs"],
+             preprocessed["input_Ks"],
+             preprocessed["input_c2ws"],
+             preprocessed["input_wh"],
+         )
+         num_inputs = len(input_imgs)
+         assert num_inputs == 1
+         input_c2ws = torch.eye(4)[None].to(dtype=input_c2ws.dtype)
+         target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_preset(
+             preprocessed, preset_traj, num_frames, zoom_factor
+         )
+         all_c2ws = torch.cat([input_c2ws, target_c2ws], 0)
+         all_Ks = (
+             torch.cat([input_Ks, target_Ks], 0)
+             * input_Ks.new_tensor([W, H, 1])[:, None]
+         )
+         num_targets = len(target_c2ws)
+         input_indices = list(range(num_inputs))
+         target_indices = np.arange(num_inputs, num_inputs + num_targets).tolist()
+         # Get anchor cameras.
+         T = VERSION_DICT["T"]
+         version_dict = copy.deepcopy(VERSION_DICT)
+         num_anchors = infer_prior_stats(
+             T,
+             num_inputs,
+             num_total_frames=num_targets,
+             version_dict=version_dict,
+         )
+         # infer_prior_stats modifies T in-place.
+         T = version_dict["T"]
+         assert isinstance(num_anchors, int)
+         anchor_indices = np.linspace(
+             num_inputs,
+             num_inputs + num_targets - 1,
+             num_anchors,
+         ).tolist()
+         anchor_c2ws = all_c2ws[[round(ind) for ind in anchor_indices]]
+         anchor_Ks = all_Ks[[round(ind) for ind in anchor_indices]]
+         # Create image conditioning.
+         all_imgs_np = (
+             F.pad(input_imgs, (0, 0, 0, 0, 0, 0, 0, num_targets), value=0.0).numpy()
+             * 255.0
+         ).astype(np.uint8)
+         image_cond = {
+             "img": all_imgs_np,
+             "input_indices": input_indices,
+             "prior_indices": anchor_indices,
+         }
+         # Create camera conditioning (K is unnormalized).
+         camera_cond = {
+             "c2w": all_c2ws,
+             "K": all_Ks,
+             "input_indices": list(range(num_inputs + num_targets)),
+         }
+         # Run rendering.
+         num_steps = 50
+         options_ori = VERSION_DICT["options"]
+         options = copy.deepcopy(options_ori)
+         options["chunk_strategy"] = chunk_strategy
+         options["video_save_fps"] = 30.0
+         options["beta_linear_start"] = 5e-6
+         options["log_snr_shift"] = 2.4
+         options["guider_types"] = [1, 2]
+         options["cfg"] = [
+             float(cfg),
+             3.0 if num_inputs >= 9 else 2.0,
+         ]  # We define the semi-dense-view regime to have 9 input views.
+         options["camera_scale"] = camera_scale
+         options["num_steps"] = num_steps
+         options["cfg_min"] = 1.2
+         options["encoding_t"] = 1
+         options["decoding_t"] = 1
+         task = "img2trajvid"
+         # Get the number of first-pass chunks.
+         T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
+         chunk_strategy_first_pass = options.get(
+             "chunk_strategy_first_pass", "gt-nearest"
+         )
+         num_chunks_0 = len(
+             chunk_input_and_test(
+                 T_first_pass,
+                 input_c2ws,
+                 anchor_c2ws,
+                 input_indices,
+                 image_cond["prior_indices"],
+                 options={**options, "sampler_verbose": False},
+                 task=task,
+                 chunk_strategy=chunk_strategy_first_pass,
+                 gt_input_inds=list(range(input_c2ws.shape[0])),
+             )[1]
+         )
+         # Get the number of second-pass chunks.
+         anchor_argsort = np.argsort(input_indices + anchor_indices).tolist()
+         anchor_indices = np.array(input_indices + anchor_indices)[
+             anchor_argsort
+         ].tolist()
+         gt_input_inds = [anchor_argsort.index(i) for i in range(input_c2ws.shape[0])]
+         anchor_c2ws_second_pass = torch.cat([input_c2ws, anchor_c2ws], dim=0)[
+             anchor_argsort
+         ]
+         T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
+         chunk_strategy = options.get("chunk_strategy", "nearest")
+         num_chunks_1 = len(
+             chunk_input_and_test(
+                 T_second_pass,
+                 anchor_c2ws_second_pass,
+                 target_c2ws,
+                 anchor_indices,
+                 target_indices,
+                 options={**options, "sampler_verbose": False},
+                 task=task,
+                 chunk_strategy=chunk_strategy,
+                 gt_input_inds=gt_input_inds,
+             )[1]
+         )
+         video_path_generator = run_one_scene(
+             task=task,
+             version_dict={
+                 "H": H,
+                 "W": W,
+                 "T": T,
+                 "C": VERSION_DICT["C"],
+                 "f": VERSION_DICT["f"],
+                 "options": options,
+             },
+             model=MODEL,
+             ae=AE,
+             conditioner=CONDITIONER,
+             denoiser=DENOISER,
+             image_cond=image_cond,
+             camera_cond=camera_cond,
+             save_path=render_dir,
+             use_traj_prior=True,
+             traj_prior_c2ws=anchor_c2ws,
+             traj_prior_Ks=anchor_Ks,
+             seed=seed,
+             gradio=True,
+         )
+         for video_path in video_path_generator:
+             # Rename the video file to the desired format.
+             new_video_path = osp.join(render_dir, f"{render_name}.mp4")
+             os.rename(video_path, new_video_path)
+             return new_video_path
+         return ""
+
+     def get_target_c2ws_and_Ks_from_preset(
+         self,
+         preprocessed: dict,
+         preset_traj: Literal[
+             "orbit",
+             "spiral",
+             "lemniscate",
+             "zoom-in",
+             "zoom-out",
+             "dolly zoom-in",
+             "dolly zoom-out",
+             "move-forward",
+             "move-backward",
+             "move-up",
+             "move-down",
+             "move-left",
+             "move-right",
+         ],
+         num_frames: int,
+         zoom_factor: float | None,
+     ):
+         img_wh = preprocessed["input_wh"]
+         start_c2w = preprocessed["input_c2ws"][0]
+         start_w2c = torch.linalg.inv(start_c2w)
+         look_at = torch.tensor([0, 0, 10])
+         start_fov = DEFAULT_FOV_RAD
+         target_c2ws, target_fovs = get_preset_pose_fov(
+             preset_traj,
+             num_frames,
+             start_w2c,
+             look_at,
+             -start_c2w[:3, 1],
+             start_fov,
+             spiral_radii=[1.0, 1.0, 0.5],
+             zoom_factor=zoom_factor,
+         )
+         target_c2ws = torch.as_tensor(target_c2ws)
+         target_fovs = torch.as_tensor(target_fovs)
+         target_Ks = get_default_intrinsics(
+             target_fovs,  # type: ignore
+             aspect_ratio=img_wh[0] / img_wh[1],
+         )
+         return target_c2ws, target_Ks
+
+
+ def main(
+     input_path: str,
+     preset_traj: List[Literal[
+         "orbit",
+         "spiral",
+         "lemniscate",
+         "zoom-in",
+         "zoom-out",
+         "dolly zoom-in",
+         "dolly zoom-out",
+         "move-forward",
+         "move-backward",
+         "move-up",
+         "move-down",
+         "move-left",
+         "move-right",
+     ]],
+     num_frames: int = 80,
+     zoom_factor: float | None = None,
+     seed: int = 23,
+     chunk_strategy: str = "interp",
+     cfg: float = 4.0,
+     camera_scale: float = 2.0,
+     output_dir: str = WORK_DIR,
+ ):
+     renderer = SevaRenderer()
+
+     # Check if input_path is a directory or a single image.
+     if osp.isdir(input_path):
+         image_paths = [osp.join(input_path, fname) for fname in os.listdir(input_path) if fname.lower().endswith(('.png', '.jpg', '.jpeg'))]
+     else:
+         image_paths = [input_path]
+
+     for input_img_path in image_paths:
+         preprocessed = renderer.preprocess(input_img_path)
+         preprocessed["input_img_path"] = input_img_path  # Add input_img_path to the preprocessed dict.
+
+         for traj in preset_traj:
+             video_path = renderer.render(
+                 preprocessed,
+                 seed,
+                 chunk_strategy,
+                 cfg,
+                 traj,
+                 num_frames,
+                 zoom_factor,
+                 camera_scale,
+                 output_dir,
+             )
+             print(f"Rendered video saved to: {video_path}")
+
+
+ if __name__ == "__main__":
+     tyro.cli(main)
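
All three scripts share the same SevaRenderer preprocess/render flow; the sketch below shows the programmatic equivalent of a single cli_app.py run. It is only an illustrative sketch: it assumes the seva package and a CUDA device are available, and "example.png" is a placeholder input image, not a file from this commit.

# Minimal usage sketch of the shared flow (assumes seva is installed and a
# CUDA device is available; "example.png" is a placeholder input image).
from cli_app import SevaRenderer

renderer = SevaRenderer()
preprocessed = renderer.preprocess("example.png")
video_path = renderer.render(
    preprocessed,
    seed=23,
    chunk_strategy="interp",
    cfg=4.0,
    preset_traj="orbit",
    num_frames=80,
    zoom_factor=None,
    camera_scale=2.0,
)
print(f"Rendered video saved to: {video_path}")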