'''
Command-line novel-view video rendering from a single image with SEVA.

Example:
    python cli_app.py --input_img_path 战场原.webp --preset_traj orbit --num_frames 80 --seed 23 --chunk_strategy interp --cfg 4.0 --camera_scale 2.0
'''

# NOTE(review): json, queue, secrets, threading, glob, Path, viser, vt,
# rearrange, create_transforms_simple and normalize_scene do not appear to be
# referenced anywhere in this file; they look carried over from the GUI demo.
import copy
import json
import os
import os.path as osp
import queue
import secrets
import threading
import time
from datetime import datetime
from glob import glob
from pathlib import Path
from typing import Literal

import imageio.v3 as iio
import numpy as np
import torch
import torch.nn.functional as F
import tyro
import viser
import viser.transforms as vt
from einops import rearrange

from seva.eval import (
    IS_TORCH_NIGHTLY,
    chunk_input_and_test,
    create_transforms_simple,
    infer_prior_stats,
    run_one_scene,
    transform_img_and_K,
)
from seva.geometry import (
    DEFAULT_FOV_RAD,
    get_default_intrinsics,
    get_preset_pose_fov,
    normalize_scene,
)
from seva.model import SGMWrapper
from seva.modules.autoencoder import AutoEncoder
from seva.modules.conditioner import CLIPConditioner
from seva.modules.preprocessor import Dust3rPipeline
from seva.sampling import DDPMDiscretization, DiscreteDenoiser
from seva.utils import load_model

# Target device for the shared models; hard-coded to the first CUDA GPU.
device = "cuda:0"

# Constants.
# Root directory for rendered outputs (each render writes to a timestamped
# subdirectory of this path).
WORK_DIR = "work_dirs/demo_gr"
# NOTE(review): not referenced elsewhere in this file.
MAX_SESSIONS = 1

# torch.compile is only enabled on PyTorch nightly builds. In that case the
# TorchInductor caches named by these env vars are switched on so compiled
# artifacts can be reused across runs. The vars are set before any
# torch.compile call made later in the module.
if IS_TORCH_NIGHTLY:
    COMPILE = True
    os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
    os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
else:
    COMPILE = False

# Shared global variables across sessions.
# Shared global variables across sessions. All models are built once at import
# time and reused by every render.
# NOTE(review): DUST3R, SERVERS and ABORT_EVENTS are not used by this CLI (the
# Basic mode never estimates geometry); DUST3R still costs load time and GPU
# memory. They appear to be carried over from the GUI demo.
DUST3R = Dust3rPipeline(device=device)  # type: ignore
MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
AE = AutoEncoder(chunk_size=1).to(device)
CONDITIONER = CLIPConditioner().to(device)
DISCRETIZATION = DDPMDiscretization()
DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
VERSION_DICT = {
    "H": 576,
    "W": 576,
    "T": 21,
    "C": 4,
    "f": 8,
    "options": {},
}
SERVERS = {}
ABORT_EVENTS = {}

if COMPILE:
    MODEL = torch.compile(MODEL)
    CONDITIONER = torch.compile(CONDITIONER)
    AE = torch.compile(AE)

# Names of the preset camera trajectories exposed by this CLI; the chosen name
# is forwarded to `get_preset_pose_fov`.
PresetTraj = Literal[
    "orbit",
    "spiral",
    "lemniscate",
    "zoom-in",
    "zoom-out",
    "dolly zoom-in",
    "dolly zoom-out",
    "move-forward",
    "move-backward",
    "move-up",
    "move-down",
    "move-left",
    "move-right",
]


class SevaRenderer:
    """Render a novel-view video from one image using the shared SEVA models."""

    def __init__(self):
        # NOTE(review): not read anywhere in this file.
        self.gui_state = None

    def preprocess(self, input_img_path: str) -> dict:
        """Load a single image and build the conditioning inputs for `render`.

        This is the "Basic" mode: the input camera is the identity pose with
        default intrinsics, and no point cloud is estimated.

        Args:
            input_img_path: Path of the image to load. Alpha channels are
                dropped; grayscale images are expanded to RGB.

        Returns:
            A dict with:
              - "input_imgs": float32 tensor of shape (1, H, W, 3), pixel
                values divided by 255.
              - "input_Ks": default intrinsics for the image aspect ratio.
              - "input_c2ws": (1, 4, 4) identity camera-to-world pose.
              - "input_wh": (W, H) of the resized image.
              - "points" / "point_colors": empty (0, 3) arrays.
              - "scene_scale": 1.0.
        """
        # Keep the aspect ratio and resize the shorter side to 576 px. This is
        # hard-coded only to keep options few; other values also work. The
        # network requires sizes that are multiples of 64.
        shorter: int = 576
        shorter = round(shorter / 64) * 64

        img = iio.imread(input_img_path)
        # Grayscale images arrive as (H, W) or (H, W, 2) (luminance + alpha).
        # Replicate the luminance so the RGB slice below always yields 3
        # channels instead of slicing the width or returning 2 channels.
        if img.ndim == 2:
            img = img[..., None]
        if img.shape[-1] < 3:
            img = np.repeat(img[..., :1], 3, axis=-1)

        input_imgs = torch.as_tensor(img / 255.0, dtype=torch.float32)[None, ..., :3]
        # transform_img_and_K works channels-first; convert there and back.
        input_imgs = transform_img_and_K(
            input_imgs.permute(0, 3, 1, 2),
            shorter,
            K=None,
            size_stride=64,
        )[0].permute(0, 2, 3, 1)

        input_Ks = get_default_intrinsics(
            aspect_ratio=input_imgs.shape[2] / input_imgs.shape[1]
        )
        input_c2ws = torch.eye(4)[None]

        return {
            "input_imgs": input_imgs,
            "input_Ks": input_Ks,
            "input_c2ws": input_c2ws,
            "input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
            "points": [np.zeros((0, 3))],
            "point_colors": [np.zeros((0, 3))],
            "scene_scale": 1.0,
        }

    def render(
        self,
        preprocessed: dict,
        seed: int,
        chunk_strategy: str,
        cfg: float,
        preset_traj: PresetTraj,
        num_frames: int,
        zoom_factor: float | None,
        camera_scale: float,
    ) -> str:
        """Render a video along a preset camera trajectory.

        Args:
            preprocessed: Output of `preprocess`; must hold exactly one view.
            seed: Sampling seed forwarded to `run_one_scene`.
            chunk_strategy: Chunking strategy stored in the sampler options.
            cfg: Guidance scale for the first guider. The second guider's
                scale is fixed (3.0 with 9+ inputs, otherwise 2.0).
            preset_traj: Name of the preset camera trajectory.
            num_frames: Number of target frames along the trajectory.
            zoom_factor: Optional zoom factor forwarded to the preset.
            camera_scale: Camera scale stored in the sampler options.

        Returns:
            The last video path yielded by `run_one_scene`, or "" if it
            yielded nothing. Outputs go to a timestamped directory under
            WORK_DIR.
        """
        render_name = datetime.now().strftime("%Y%m%d_%H%M%S")
        render_dir = osp.join(WORK_DIR, render_name)

        input_imgs = preprocessed["input_imgs"]
        input_Ks = preprocessed["input_Ks"]
        W, H = preprocessed["input_wh"]
        num_inputs = len(input_imgs)
        assert num_inputs == 1
        # The single input camera is always placed at the world origin.
        input_c2ws = torch.eye(4)[None].to(dtype=preprocessed["input_c2ws"].dtype)

        target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_preset(
            preprocessed, preset_traj, num_frames, zoom_factor
        )
        num_targets = len(target_c2ws)

        all_c2ws = torch.cat([input_c2ws, target_c2ws], 0)
        # Scale K rows by (W, H, 1) so the camera conditioning gets
        # unnormalized (pixel-unit) intrinsics.
        all_Ks = (
            torch.cat([input_Ks, target_Ks], 0)
            * input_Ks.new_tensor([W, H, 1])[:, None]
        )
        input_indices = list(range(num_inputs))

        # Get anchor cameras, evenly spaced over the target frames.
        version_dict = copy.deepcopy(VERSION_DICT)
        num_anchors = infer_prior_stats(
            VERSION_DICT["T"],
            num_inputs,
            num_total_frames=num_targets,
            version_dict=version_dict,
        )
        # infer_prior_stats updates version_dict["T"] in place.
        T = version_dict["T"]
        assert isinstance(num_anchors, int)
        anchor_indices = np.linspace(
            num_inputs,
            num_inputs + num_targets - 1,
            num_anchors,
        ).tolist()
        anchor_rows = [round(ind) for ind in anchor_indices]
        anchor_c2ws = all_c2ws[anchor_rows]
        anchor_Ks = all_Ks[anchor_rows]

        # Image conditioning: the input image followed by one zero frame per
        # target, as uint8.
        all_imgs_np = (
            F.pad(input_imgs, (0, 0, 0, 0, 0, 0, 0, num_targets), value=0.0).numpy()
            * 255.0
        ).astype(np.uint8)
        image_cond = {
            "img": all_imgs_np,
            "input_indices": input_indices,
            "prior_indices": anchor_indices,
        }
        # Camera conditioning (K is unnormalized).
        camera_cond = {
            "c2w": all_c2ws,
            "K": all_Ks,
            "input_indices": list(range(num_inputs + num_targets)),
        }

        # Sampler options.
        options = copy.deepcopy(VERSION_DICT["options"])
        options["chunk_strategy"] = chunk_strategy
        options["video_save_fps"] = 30.0
        options["beta_linear_start"] = 5e-6
        options["log_snr_shift"] = 2.4
        options["guider_types"] = [1, 2]
        options["cfg"] = [
            float(cfg),
            # The semi-dense-view regime is defined as 9 or more input views.
            3.0 if num_inputs >= 9 else 2.0,
        ]
        options["camera_scale"] = camera_scale
        options["num_steps"] = 50
        options["cfg_min"] = 1.2
        options["encoding_t"] = 1
        options["decoding_t"] = 1

        video_path_generator = run_one_scene(
            task="img2trajvid",
            version_dict={
                "H": H,
                "W": W,
                "T": T,
                "C": VERSION_DICT["C"],
                "f": VERSION_DICT["f"],
                "options": options,
            },
            model=MODEL,
            ae=AE,
            conditioner=CONDITIONER,
            denoiser=DENOISER,
            image_cond=image_cond,
            camera_cond=camera_cond,
            save_path=render_dir,
            use_traj_prior=True,
            traj_prior_c2ws=anchor_c2ws,
            traj_prior_Ks=anchor_Ks,
            seed=seed,
            gradio=True,
        )
        # Drain the generator rather than returning on its first item.
        # Returning early leaves it suspended at its first `yield`, so any
        # work after that point (later passes, the final save) never runs.
        video_path = ""
        for video_path in video_path_generator:
            pass
        return video_path

    def get_target_c2ws_and_Ks_from_preset(
        self,
        preprocessed: dict,
        preset_traj: PresetTraj,
        num_frames: int,
        zoom_factor: float | None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Build target camera poses and intrinsics for a preset trajectory.

        Args:
            preprocessed: Output of `preprocess`; only "input_wh" and the first
                "input_c2ws" pose are read.
            preset_traj: Name of the preset trajectory.
            num_frames: Number of target poses to generate.
            zoom_factor: Optional zoom factor forwarded to the preset.

        Returns:
            (target_c2ws, target_Ks). target_Ks come from
            `get_default_intrinsics` for the per-frame FoVs and are not yet
            scaled to pixels.
        """
        img_wh = preprocessed["input_wh"]
        start_c2w = preprocessed["input_c2ws"][0]
        start_w2c = torch.linalg.inv(start_c2w)
        # Point 10 units along world +z that the trajectory is built around.
        look_at = torch.tensor([0, 0, 10])
        start_fov = DEFAULT_FOV_RAD
        target_c2ws, target_fovs = get_preset_pose_fov(
            preset_traj,
            num_frames,
            start_w2c,
            look_at,
            # NOTE(review): presumably the up direction (negated camera
            # y axis); confirm against get_preset_pose_fov.
            -start_c2w[:3, 1],
            start_fov,
            spiral_radii=[1.0, 1.0, 0.5],
            zoom_factor=zoom_factor,
        )
        target_c2ws = torch.as_tensor(target_c2ws)
        target_fovs = torch.as_tensor(target_fovs)
        target_Ks = get_default_intrinsics(
            target_fovs,  # type: ignore
            aspect_ratio=img_wh[0] / img_wh[1],
        )
        return target_c2ws, target_Ks


def main(
    input_img_path: str,
    preset_traj: PresetTraj = "orbit",
    num_frames: int = 80,
    zoom_factor: float | None = None,
    seed: int = 23,
    chunk_strategy: str = "interp",
    cfg: float = 4.0,
    camera_scale: float = 2.0,
):
    """Render a novel-view video from a single image along a preset trajectory.

    Args:
        input_img_path: Path to the input image.
        preset_traj: Preset camera trajectory to follow.
        num_frames: Number of frames to render along the trajectory.
        zoom_factor: Optional zoom factor for the preset trajectory.
        seed: Random seed for sampling.
        chunk_strategy: Chunking strategy used by the sampler.
        cfg: Guidance scale.
        camera_scale: Camera scale passed to the sampler.
    """
    renderer = SevaRenderer()
    preprocessed = renderer.preprocess(input_img_path)
    video_path = renderer.render(
        preprocessed,
        seed,
        chunk_strategy,
        cfg,
        preset_traj,
        num_frames,
        zoom_factor,
        camera_scale,
    )
    print(f"Rendered video saved to: {video_path}")


if __name__ == "__main__":
    tyro.cli(main)