import glob
import os
import os.path as osp

import fire
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from tqdm import tqdm

from seva.data_io import get_parser
from seva.eval import (
    IS_TORCH_NIGHTLY,
    compute_relative_inds,
    create_transforms_simple,
    infer_prior_inds,
    infer_prior_stats,
    run_one_scene,
)
from seva.geometry import (
    generate_interpolated_path,
    generate_spiral_path,
    get_arc_horizontal_w2cs,
    get_default_intrinsics,
    get_lookat,
    get_preset_pose_fov,
)
from seva.model import SGMWrapper
from seva.modules.autoencoder import AutoEncoder
from seva.modules.conditioner import CLIPConditioner
from seva.sampling import DDPMDiscretization, DiscreteDenoiser
from seva.utils import load_model
device = "cuda:0"

# Constants.
WORK_DIR = "work_dirs/demo"

if IS_TORCH_NIGHTLY:
    COMPILE = True
    os.environ["TORCHINDUCTOR_AUTOGRAD_CACHE"] = "1"
    os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "1"
else:
    COMPILE = False

MODEL = SGMWrapper(load_model(device="cpu", verbose=True).eval()).to(device)
AE = AutoEncoder(chunk_size=1).to(device)
CONDITIONER = CLIPConditioner().to(device)
DISCRETIZATION = DDPMDiscretization()
DENOISER = DiscreteDenoiser(discretization=DISCRETIZATION, num_idx=1000, device=device)
VERSION_DICT = {
    "H": 576,
    "W": 576,
    "T": 21,
    "C": 4,
    "f": 8,
    "options": {},
}
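# Field meanings (inferred from how VERSION_DICT is used below and in seva.eval,
# not stated explicitly in the source): "H"/"W" are the working resolution, "T" is
# the number of frames handled per sampling pass, and "C"/"f" are presumably the
# latent channel count and the autoencoder downsampling factor.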
if COMPILE:
    MODEL = torch.compile(MODEL, dynamic=False)
    CONDITIONER = torch.compile(CONDITIONER, dynamic=False)
    AE = torch.compile(AE, dynamic=False)

def parse_task(
    task,
    scene,
    num_inputs,
    T,
    version_dict,
):
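    """Resolve a task into image paths, camera poses/intrinsics, and frame indices.

    Returns (all_imgs_path, num_inputs, num_targets, input_indices, anchor_indices,
    c2ws, Ks, anchor_c2ws, anchor_Ks); the anchor entries are None when the task
    does not define a prior trajectory.
    """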
    options = version_dict["options"]

    anchor_indices = None
    anchor_c2ws = None
    anchor_Ks = None
    if task == "img2trajvid_s-prob":
        if num_inputs is not None:
            assert (
                num_inputs == 1
            ), "Task `img2trajvid_s-prob` only supports 1-view conditioning..."
        else:
            num_inputs = 1
        num_targets = options.get("num_targets", T - 1)
        num_anchors = infer_prior_stats(
            T,
            num_inputs,
            num_total_frames=num_targets,
            version_dict=version_dict,
        )

        input_indices = [0]
        anchor_indices = np.linspace(1, num_targets, num_anchors).tolist()

        all_imgs_path = [scene] + [None] * num_targets

        c2ws, fovs = get_preset_pose_fov(
            option=options.get("traj_prior", "orbit"),
            num_frames=num_targets + 1,
            start_w2c=torch.eye(4),
            look_at=torch.Tensor([0, 0, 10]),
        )
        with Image.open(scene) as img:
            W, H = img.size
            aspect_ratio = W / H
        Ks = get_default_intrinsics(fovs, aspect_ratio=aspect_ratio)  # normalized
        Ks[:, :2] *= (
            torch.tensor([W, H]).reshape(1, -1, 1).repeat(Ks.shape[0], 1, 1)
        )  # convert to pixel (unnormalized) intrinsics
        Ks = Ks.numpy()

        anchor_c2ws = c2ws[[round(ind) for ind in anchor_indices]]
        anchor_Ks = Ks[[round(ind) for ind in anchor_indices]]
    else:
        parser = get_parser(
            parser_type="reconfusion",
            data_dir=scene,
            normalize=False,
        )
        all_imgs_path = parser.image_paths
        c2ws = parser.camtoworlds
        camera_ids = parser.camera_ids
        Ks = np.concatenate([parser.Ks_dict[cam_id][None] for cam_id in camera_ids], 0)

        if num_inputs is None:
            assert len(parser.splits_per_num_input_frames.keys()) == 1
            num_inputs = list(parser.splits_per_num_input_frames.keys())[0]
            split_dict = parser.splits_per_num_input_frames[num_inputs]  # type: ignore
        elif isinstance(num_inputs, str):
            split_dict = parser.splits_per_num_input_frames[num_inputs]  # type: ignore
            num_inputs = int(num_inputs.split("-")[0])  # split keys may carry a suffix after "-"; keep only the leading count
        else:
            split_dict = parser.splits_per_num_input_frames[num_inputs]  # type: ignore

        num_targets = len(split_dict["test_ids"])
        if task == "img2img":
            # Note: in this setting, we should only use camera info coming from
            # sampled_indices and, most importantly, preserve its order.
            num_anchors = infer_prior_stats(
                T,
                num_inputs,
                num_total_frames=num_targets,
                version_dict=version_dict,
            )

            sampled_indices = np.sort(
                np.array(split_dict["train_ids"] + split_dict["test_ids"])
            )  # we always sort all indices first

            traj_prior = options.get("traj_prior", None)
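            # The anchor (prior) trajectory can be synthesized in three ways:
            # "spiral" fits a spiral path through the sampled views (needs scene bounds),
            # "interpolated" interpolates between the input views (needs >1 input), and
            # "orbit" sweeps a horizontal arc around the estimated look-at point.
            # Any other value leaves anchor_c2ws as None.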
            if traj_prior == "spiral":
                assert parser.bounds is not None
                anchor_c2ws = generate_spiral_path(
                    c2ws[sampled_indices] @ np.diagflat([1, -1, -1, 1]),
                    parser.bounds[sampled_indices],
                    n_frames=num_anchors + 1,
                    n_rots=2,
                    zrate=0.5,
                    endpoint=False,
                )[1:] @ np.diagflat([1, -1, -1, 1])
            elif traj_prior == "interpolated":
                assert num_inputs > 1
                anchor_c2ws = generate_interpolated_path(
                    c2ws[split_dict["train_ids"], :3],
                    round((num_anchors + 1) / (num_inputs - 1)),
                    endpoint=False,
                )[1 : num_anchors + 1]
            elif traj_prior == "orbit":
                c2ws_th = torch.as_tensor(c2ws)
                lookat = get_lookat(
                    c2ws_th[sampled_indices, :3, 3],
                    c2ws_th[sampled_indices, :3, 2],
                )
                anchor_c2ws = torch.linalg.inv(
                    get_arc_horizontal_w2cs(
                        torch.linalg.inv(c2ws_th[split_dict["train_ids"][0]]),
                        lookat,
                        -F.normalize(
                            c2ws_th[split_dict["train_ids"]][:, :3, 1].mean(0),
                            dim=-1,
                        ),
                        num_frames=num_anchors + 1,
                        endpoint=False,
                    )
                ).numpy()[1:, :3]
            else:
                anchor_c2ws = None
            # anchor_Ks is left as None here and defaults to the first of target_Ks.

            all_imgs_path = [all_imgs_path[i] for i in sampled_indices]
            c2ws = c2ws[sampled_indices]
            Ks = Ks[sampled_indices]

            # absolute to relative indices
            input_indices = compute_relative_inds(
                sampled_indices,
                np.array(split_dict["train_ids"]),
            )
            anchor_indices = np.arange(
                sampled_indices.shape[0],
                sampled_indices.shape[0] + num_anchors,
            ).tolist()  # the order has no meaning here
        elif task == "img2vid":
            num_targets = len(all_imgs_path) - num_inputs
            num_anchors = infer_prior_stats(
                T,
                num_inputs,
                num_total_frames=num_targets,
                version_dict=version_dict,
            )

            input_indices = split_dict["train_ids"]
            anchor_indices = infer_prior_inds(
                c2ws,
                num_prior_frames=num_anchors,
                input_frame_indices=input_indices,
                options=options,
            ).tolist()
            num_anchors = len(anchor_indices)
            anchor_c2ws = c2ws[anchor_indices, :3]
            anchor_Ks = Ks[anchor_indices]
        elif task == "img2trajvid":
            num_anchors = infer_prior_stats(
                T,
                num_inputs,
                num_total_frames=num_targets,
                version_dict=version_dict,
            )

            target_c2ws = c2ws[split_dict["test_ids"], :3]
            target_Ks = Ks[split_dict["test_ids"]]
            anchor_c2ws = target_c2ws[
                np.linspace(0, num_targets - 1, num_anchors).round().astype(np.int64)
            ]
            anchor_Ks = target_Ks[
                np.linspace(0, num_targets - 1, num_anchors).round().astype(np.int64)
            ]

            sampled_indices = split_dict["train_ids"] + split_dict["test_ids"]
            all_imgs_path = [all_imgs_path[i] for i in sampled_indices]
            c2ws = c2ws[sampled_indices]
            Ks = Ks[sampled_indices]

            input_indices = np.arange(num_inputs).tolist()
            anchor_indices = np.linspace(
                num_inputs, num_inputs + num_targets - 1, num_anchors
            ).tolist()
        else:
            raise ValueError(f"Unknown task: {task}")

    return (
        all_imgs_path,
        num_inputs,
        num_targets,
        input_indices,
        anchor_indices,
        torch.tensor(c2ws[:, :3]).float(),
        torch.tensor(Ks).float(),
        (torch.tensor(anchor_c2ws[:, :3]).float() if anchor_c2ws is not None else None),
        (torch.tensor(anchor_Ks).float() if anchor_Ks is not None else None),
    )

def main(
    data_path,
    data_items=None,
    task="img2img",
    save_subdir="",
    H=None,
    W=None,
    T=None,
    use_traj_prior=False,
    **overwrite_options,
):
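    """Sample every scene under data_path for the given task and save the results.

    H, W, and T override the defaults in VERSION_DICT (T may be a comma-separated
    string); any extra keyword arguments overwrite the sampling options below.
    """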
    if H is not None:
        VERSION_DICT["H"] = H
    if W is not None:
        VERSION_DICT["W"] = W
    if T is not None:
        VERSION_DICT["T"] = [int(t) for t in T.split(",")] if isinstance(T, str) else T

    options = VERSION_DICT["options"]
    options["chunk_strategy"] = "nearest-gt"
    options["video_save_fps"] = 30.0
    options["beta_linear_start"] = 5e-6
    options["log_snr_shift"] = 2.4
    options["guider_types"] = 1
    options["cfg"] = 2.0
    options["camera_scale"] = 2.0
    options["num_steps"] = 50
    options["cfg_min"] = 1.2
    options["encoding_t"] = 1
    options["decoding_t"] = 1
    options["num_inputs"] = None
    options["seed"] = 23
    options.update(overwrite_options)

    num_inputs = options["num_inputs"]
    seed = options["seed"]

    if data_items is not None:
        if not isinstance(data_items, (list, tuple)):
            data_items = data_items.split(",")
        scenes = [os.path.join(data_path, item) for item in data_items]
    else:
        scenes = glob.glob(osp.join(data_path, "*"))

    for scene in tqdm(scenes):
        save_path_scene = os.path.join(
            WORK_DIR, task, save_subdir, os.path.splitext(os.path.basename(scene))[0]
        )
        if options.get("skip_saved", False) and os.path.exists(
            os.path.join(save_path_scene, "transforms.json")
        ):
            print(f"Skipping {scene} as it is already sampled.")
            continue
        # parse_task -> infer_prior_stats modifies VERSION_DICT["T"] in-place.
        (
            all_imgs_path,
            num_inputs,
            num_targets,
            input_indices,
            anchor_indices,
            c2ws,
            Ks,
            anchor_c2ws,
            anchor_Ks,
        ) = parse_task(
            task,
            scene,
            num_inputs,
            VERSION_DICT["T"],
            VERSION_DICT,
        )
        assert num_inputs is not None

        # Create image conditioning.
        image_cond = {
            "img": all_imgs_path,
            "input_indices": input_indices,
            "prior_indices": anchor_indices,
        }
        # Create camera conditioning.
        camera_cond = {
            "c2w": c2ws.clone(),
            "K": Ks.clone(),
            "input_indices": list(range(num_inputs + num_targets)),
        }
        # run_one_scene -> transform_img_and_K modifies VERSION_DICT["H"] and
        # VERSION_DICT["W"] in-place.
        video_path_generator = run_one_scene(
            task,
            VERSION_DICT,  # H and W may be updated inside run_one_scene
            model=MODEL,
            ae=AE,
            conditioner=CONDITIONER,
            denoiser=DENOISER,
            image_cond=image_cond,
            camera_cond=camera_cond,
            save_path=save_path_scene,
            use_traj_prior=use_traj_prior,
            traj_prior_Ks=anchor_Ks,
            traj_prior_c2ws=anchor_c2ws,
            seed=seed,  # to ensure the sampled video can be reproduced
        )
        for _ in video_path_generator:
            pass

        # Convert from OpenCV to OpenGL camera format.
        c2ws = c2ws @ torch.tensor(np.diag([1, -1, -1, 1])).float()
        img_paths = sorted(glob.glob(osp.join(save_path_scene, "samples-rgb", "*.png")))
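        # When the generated frames only cover the target views, the input views are
        # saved separately under `input/`; merge them back into the original camera
        # order so that transforms.json lists every frame.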
        if len(img_paths) != len(c2ws):
            input_img_paths = sorted(
                glob.glob(osp.join(save_path_scene, "input", "*.png"))
            )
            assert len(img_paths) == num_targets
            assert len(input_img_paths) == num_inputs
            assert c2ws.shape[0] == num_inputs + num_targets
            target_indices = [i for i in range(c2ws.shape[0]) if i not in input_indices]
            img_paths = [
                input_img_paths[input_indices.index(i)]
                if i in input_indices
                else img_paths[target_indices.index(i)]
                for i in range(c2ws.shape[0])
            ]
        create_transforms_simple(
            save_path=save_path_scene,
            img_paths=img_paths,
            img_whs=np.array([VERSION_DICT["W"], VERSION_DICT["H"]])[None].repeat(
                num_inputs + num_targets, 0
            ),
            c2ws=c2ws,
            Ks=Ks,
        )


if __name__ == "__main__":
    fire.Fire(main)
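# Example invocations (illustrative only; the script name and data paths are
# assumptions, not taken from the source -- fire maps CLI flags to main()'s arguments):
#   python demo.py --data_path assets/scenes --task img2trajvid --use_traj_prior True
#   python demo.py --data_path assets/images --task img2trajvid_s-prob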