'''Render preset camera trajectories from a single input image with Stable Virtual Camera (seva).

Example:
    python cli_all_app.py --input_img_path 战场原.webp --preset_traj "orbit" "spiral" "lemniscate" "zoom-in" "zoom-out" "dolly zoom-in" "dolly zoom-out" "move-forward" "move-backward" "move-up" "move-down" "move-left" "move-right" --output_dir 战场原
'''
import copy
import json
import os
import os.path as osp
import queue
import secrets
import threading
import time
from datetime import datetime
from glob import glob
from pathlib import Path
from typing import Literal, List
import imageio.v3 as iio
import numpy as np
import torch
import torch.nn.functional as F
import tyro
import viser
import viser.transforms as vt
from einops import rearrange
from seva.eval import (
IS_TORCH_NIGHTLY,
chunk_input_and_test,
create_transforms_simple,
infer_prior_stats,
run_one_scene,
transform_img_and_K,
)
from seva.geometry import (
DEFAULT_FOV_RAD,
get_default_intrinsics,
get_preset_pose_fov,
normalize_scene,
)
from seva.model import SGMWrapper
from seva.modules.autoencoder import AutoEncoder
from seva.modules.conditioner import CLIPConditioner
from seva.modules.preprocessor import Dust3rPipeline
from seva.sampling import DDPMDiscretization, DiscreteDenoiser
from seva.utils import load_model
device = "cuda:0"
# Constants.
WORK_DIR = "work_dirs/demo_gr"
MAX_SESSIONS = 1
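# torch.compile is only enabled on PyTorch nightly builds; the env vars below turn on
# TorchInductor's FX-graph and autograd caches so recompilation is cheaper across runs.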
if IS_TORCH_NIGHTLY:
COMPILE = True
os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
else:
COMPILE = False
# Globally shared model components, loaded once at import time and reused for every render.
DUST3R = Dust3rPipeline(device=device) # type: ignore
MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
AE = AutoEncoder(chunk_size=1).to(device)
CONDITIONER = CLIPConditioner().to(device)
DISCRETIZATION = DDPMDiscretization()
DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
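# Model version settings: output height/width, context length T (frames the model
# handles per forward pass), latent channels C, and VAE downsampling factor f.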
VERSION_DICT = {
"H": 576,
"W": 576,
"T": 21,
"C": 4,
"f": 8,
"options": {},
}
SERVERS = {}
ABORT_EVENTS = {}
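# SERVERS / ABORT_EVENTS are per-session state carried over from the interactive
# viser/Gradio demo; they are not used by this CLI.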
if COMPILE:
MODEL = torch.compile(MODEL)
CONDITIONER = torch.compile(CONDITIONER)
AE = torch.compile(AE)
class SevaRenderer(object):
def __init__(self):
self.gui_state = None
def preprocess(self, input_img_path: str) -> dict:
        # Hardcode the target size so the aspect ratio is always kept and the
        # shorter side is resized to 576. This only reduces the number of exposed
        # options; other values still work.
shorter: int = 576
        # Must be a multiple of 64 for the network.
shorter = round(shorter / 64) * 64
# Assume `Basic` demo mode: just hardcode the camera parameters and ignore points.
input_imgs = torch.as_tensor(
iio.imread(input_img_path) / 255.0, dtype=torch.float32
)[None, ..., :3]
input_imgs = transform_img_and_K(
input_imgs.permute(0, 3, 1, 2),
shorter,
K=None,
size_stride=64,
)[0].permute(0, 2, 3, 1)
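        # Use default normalized intrinsics (derived from the default FOV) that match
        # the resized image's aspect ratio.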
input_Ks = get_default_intrinsics(
aspect_ratio=input_imgs.shape[2] / input_imgs.shape[1]
)
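        # A single input view placed at the identity pose, so the world frame
        # coincides with the first camera frame.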
input_c2ws = torch.eye(4)[None]
        # Small delay so that a progress UI (e.g. Gradio) can update properly.
time.sleep(0.1)
return {
"input_imgs": input_imgs,
"input_Ks": input_Ks,
"input_c2ws": input_c2ws,
"input_wh": (input_imgs.shape[2], input_imgs.shape[1]),
"points": [np.zeros((0, 3))],
"point_colors": [np.zeros((0, 3))],
"scene_scale": 1.0,
}
def render(
self,
preprocessed: dict,
seed: int,
chunk_strategy: str,
cfg: float,
preset_traj: Literal[
"orbit",
"spiral",
"lemniscate",
"zoom-in",
"zoom-out",
"dolly zoom-in",
"dolly zoom-out",
"move-forward",
"move-backward",
"move-up",
"move-down",
"move-left",
"move-right",
],
num_frames: int,
zoom_factor: float | None,
camera_scale: float,
output_dir: str,
) -> str:
# Generate a unique render name based on the input image filename and preset_traj
input_img_name = osp.splitext(osp.basename(preprocessed["input_img_path"]))[0]
render_name = f"{input_img_name}_{preset_traj}"
render_dir = osp.join(output_dir, render_name)
input_imgs, input_Ks, input_c2ws, (W, H) = (
preprocessed["input_imgs"],
preprocessed["input_Ks"],
preprocessed["input_c2ws"],
preprocessed["input_wh"],
)
num_inputs = len(input_imgs)
assert num_inputs == 1
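        # Reset the input pose to identity; the preset target poses below are
        # defined relative to this starting camera.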
input_c2ws = torch.eye(4)[None].to(dtype=input_c2ws.dtype)
target_c2ws, target_Ks = self.get_target_c2ws_and_Ks_from_preset(
preprocessed, preset_traj, num_frames, zoom_factor
)
all_c2ws = torch.cat([input_c2ws, target_c2ws], 0)
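        # Scale the normalized intrinsics by (W, H) to get pixel-unit intrinsics.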
all_Ks = (
torch.cat([input_Ks, target_Ks], 0)
* input_Ks.new_tensor([W, H, 1])[:, None]
)
num_targets = len(target_c2ws)
input_indices = list(range(num_inputs))
target_indices = np.arange(num_inputs, num_inputs + num_targets).tolist()
# Get anchor cameras.
T = VERSION_DICT["T"]
version_dict = copy.deepcopy(VERSION_DICT)
num_anchors = infer_prior_stats(
T,
num_inputs,
num_total_frames=num_targets,
version_dict=version_dict,
)
# infer_prior_stats modifies T in-place.
T = version_dict["T"]
assert isinstance(num_anchors, int)
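        # Spread the anchor frames evenly across the target trajectory.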
anchor_indices = np.linspace(
num_inputs,
num_inputs + num_targets - 1,
num_anchors,
).tolist()
anchor_c2ws = all_c2ws[[round(ind) for ind in anchor_indices]]
anchor_Ks = all_Ks[[round(ind) for ind in anchor_indices]]
# Create image conditioning.
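        # Only the input frame carries real pixels; the target slots are zero-padded
        # placeholders that the model fills in.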
all_imgs_np = (
F.pad(input_imgs, (0, 0, 0, 0, 0, 0, 0, num_targets), value=0.0).numpy()
* 255.0
).astype(np.uint8)
image_cond = {
"img": all_imgs_np,
"input_indices": input_indices,
"prior_indices": anchor_indices,
}
# Create camera conditioning (K is unnormalized).
camera_cond = {
"c2w": all_c2ws,
"K": all_Ks,
"input_indices": list(range(num_inputs + num_targets)),
}
# Run rendering.
num_steps = 50
options_ori = VERSION_DICT["options"]
options = copy.deepcopy(options_ori)
options["chunk_strategy"] = chunk_strategy
options["video_save_fps"] = 30.0
options["beta_linear_start"] = 5e-6
options["log_snr_shift"] = 2.4
options["guider_types"] = [1, 2]
options["cfg"] = [
float(cfg),
3.0 if num_inputs >= 9 else 2.0,
] # We define semi-dense-view regime to have 9 input views.
options["camera_scale"] = camera_scale
options["num_steps"] = num_steps
options["cfg_min"] = 1.2
options["encoding_t"] = 1
options["decoding_t"] = 1
task = "img2trajvid"
# Get number of first pass chunks.
T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
chunk_strategy_first_pass = options.get(
"chunk_strategy_first_pass", "gt-nearest"
)
num_chunks_0 = len(
chunk_input_and_test(
T_first_pass,
input_c2ws,
anchor_c2ws,
input_indices,
image_cond["prior_indices"],
options={**options, "sampler_verbose": False},
task=task,
chunk_strategy=chunk_strategy_first_pass,
gt_input_inds=list(range(input_c2ws.shape[0])),
)[1]
)
# Get number of second pass chunks.
anchor_argsort = np.argsort(input_indices + anchor_indices).tolist()
anchor_indices = np.array(input_indices + anchor_indices)[
anchor_argsort
].tolist()
gt_input_inds = [anchor_argsort.index(i) for i in range(input_c2ws.shape[0])]
anchor_c2ws_second_pass = torch.cat([input_c2ws, anchor_c2ws], dim=0)[
anchor_argsort
]
T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
chunk_strategy = options.get("chunk_strategy", "nearest")
num_chunks_1 = len(
chunk_input_and_test(
T_second_pass,
anchor_c2ws_second_pass,
target_c2ws,
anchor_indices,
target_indices,
options={**options, "sampler_verbose": False},
task=task,
chunk_strategy=chunk_strategy,
gt_input_inds=gt_input_inds,
)[1]
)
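        # num_chunks_0 / num_chunks_1 are not used further in this CLI; the original
        # demo app computes them to size its progress bars, and they are kept for parity.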
video_path_generator = run_one_scene(
task=task,
version_dict={
"H": H,
"W": W,
"T": T,
"C": VERSION_DICT["C"],
"f": VERSION_DICT["f"],
"options": options,
},
model=MODEL,
ae=AE,
conditioner=CONDITIONER,
denoiser=DENOISER,
image_cond=image_cond,
camera_cond=camera_cond,
save_path=render_dir,
use_traj_prior=True,
traj_prior_c2ws=anchor_c2ws,
traj_prior_Ks=anchor_Ks,
seed=seed,
gradio=True,
)
        # run_one_scene yields a video path after each sampling pass; exhaust the
        # generator so all passes run, then return the final video path.
        video_path = ""
        for video_path in video_path_generator:
            pass
        return video_path
def get_target_c2ws_and_Ks_from_preset(
self,
preprocessed: dict,
preset_traj: Literal[
"orbit",
"spiral",
"lemniscate",
"zoom-in",
"zoom-out",
"dolly zoom-in",
"dolly zoom-out",
"move-forward",
"move-backward",
"move-up",
"move-down",
"move-left",
"move-right",
],
num_frames: int,
zoom_factor: float | None,
):
img_wh = preprocessed["input_wh"]
start_c2w = preprocessed["input_c2ws"][0]
start_w2c = torch.linalg.inv(start_c2w)
look_at = torch.tensor([0, 0, 10])
start_fov = DEFAULT_FOV_RAD
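        # The camera looks at a point 10 units along +Z; the up direction is the
        # negative Y axis of the starting camera (OpenCV-style convention).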
target_c2ws, target_fovs = get_preset_pose_fov(
preset_traj,
num_frames,
start_w2c,
look_at,
-start_c2w[:3, 1],
start_fov,
spiral_radii=[1.0, 1.0, 0.5],
zoom_factor=zoom_factor,
)
target_c2ws = torch.as_tensor(target_c2ws)
target_fovs = torch.as_tensor(target_fovs)
target_Ks = get_default_intrinsics(
target_fovs, # type: ignore
aspect_ratio=img_wh[0] / img_wh[1],
)
return target_c2ws, target_Ks
def main(
input_img_path: str,
preset_traj: List[Literal[
"orbit",
"spiral",
"lemniscate",
"zoom-in",
"zoom-out",
"dolly zoom-in",
"dolly zoom-out",
"move-forward",
"move-backward",
"move-up",
"move-down",
"move-left",
"move-right",
]],
num_frames: int = 80,
zoom_factor: float | None = None,
seed: int = 23,
chunk_strategy: str = "interp",
cfg: float = 4.0,
camera_scale: float = 2.0,
output_dir: str = WORK_DIR,
):
renderer = SevaRenderer()
preprocessed = renderer.preprocess(input_img_path)
preprocessed["input_img_path"] = input_img_path # Add input_img_path to preprocessed dict
for traj in preset_traj:
video_path = renderer.render(
preprocessed,
seed,
chunk_strategy,
cfg,
traj,
num_frames,
zoom_factor,
camera_scale,
output_dir,
)
print(f"Rendered video saved to: {video_path}")
if __name__ == "__main__":
tyro.cli(main)