# AetherV1 / app.py
# Wenzheng Chang
# final version
# cd4da5b
import gc
import os
import random
import re
from datetime import datetime
from typing import Dict, List, Optional, Tuple
import gradio as gr
import imageio.v3 as iio
import numpy as np
import PIL
import rootutils
import torch
from diffusers import (
AutoencoderKLCogVideoX,
CogVideoXDPMScheduler,
CogVideoXTransformer3DModel,
)
from transformers import AutoTokenizer, T5EncoderModel
import spaces
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
from aether.pipelines.aetherv1_pipeline_cogvideox import ( # noqa: E402
AetherV1PipelineCogVideoX,
AetherV1PipelineOutput,
)
from aether.utils.postprocess_utils import ( # noqa: E402
align_camera_extrinsics,
apply_transformation,
colorize_depth,
compute_scale,
get_intrinsics,
interpolate_poses,
postprocess_pointmap,
project,
raymap_to_poses,
)
from aether.utils.visualize_utils import predictions_to_glb # noqa: E402
def seed_all(seed: int = 0) -> None:
    """
    Seed every random number generator the app relies on.

    Seeds, in order: Python's ``random`` module, NumPy's global generator,
    the torch CPU generator and the generators of all CUDA devices.
    """
    for seeder in (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    ):
        seeder(seed)
# Global pipeline.
# Built once at import time and shared by every request; `build_pipeline` only
# moves this single instance to the requested device.
# Tokenizer, text encoder, VAE and scheduler are loaded from the CogVideoX-5B
# image-to-video checkpoint; the transformer weights come from the AetherV1
# checkpoint. The VAE and transformer are loaded in bfloat16, the other
# components with their default `from_pretrained` settings.
cogvideox_pretrained_model_name_or_path: str = "THUDM/CogVideoX-5b-I2V"
aether_pretrained_model_name_or_path: str = "AetherWorldModel/AetherV1"
pipeline = AetherV1PipelineCogVideoX(
    tokenizer=AutoTokenizer.from_pretrained(
        cogvideox_pretrained_model_name_or_path,
        subfolder="tokenizer",
    ),
    text_encoder=T5EncoderModel.from_pretrained(
        cogvideox_pretrained_model_name_or_path, subfolder="text_encoder"
    ),
    vae=AutoencoderKLCogVideoX.from_pretrained(
        cogvideox_pretrained_model_name_or_path, subfolder="vae", torch_dtype=torch.bfloat16
    ),
    scheduler=CogVideoXDPMScheduler.from_pretrained(
        cogvideox_pretrained_model_name_or_path, subfolder="scheduler"
    ),
    transformer=CogVideoXTransformer3DModel.from_pretrained(
        aether_pretrained_model_name_or_path, subfolder="transformer", torch_dtype=torch.bfloat16
    ),
)
# diffusers' sliced / tiled VAE processing, intended to lower peak memory when
# encoding/decoding long or large videos.
pipeline.vae.enable_slicing()
pipeline.vae.enable_tiling()
def build_pipeline(device: torch.device) -> AetherV1PipelineCogVideoX:
    """
    Return the shared module-level pipeline after moving it to ``device``.

    Nothing is constructed here: every call reuses the single global
    instance that was built at import time.
    """
    shared = pipeline  # the global instance created at module import
    shared.to(device)
    return shared
def get_window_starts(
    total_frames: int, sliding_window_size: int, temporal_stride: int
) -> List[int]:
    """
    Calculate the start indices of sliding windows over a frame sequence.

    Windows start every ``temporal_stride`` frames. If the last regular window
    does not end exactly on the final frame, one extra window is appended that
    is aligned to the end of the sequence, so every frame is covered.

    Args:
        total_frames: Number of frames in the sequence.
        sliding_window_size: Number of frames per window.
        temporal_stride: Step between consecutive window starts; must be > 0.

    Returns:
        Sorted list of window start indices. A sequence that is no longer
        than one window yields ``[0]`` (a single, possibly short, window);
        an empty sequence yields ``[]``.
        NOTE(review): for short sequences the caller then gets fewer than
        ``sliding_window_size`` frames in that window — confirm the pipeline
        accepts that.

    Raises:
        ValueError: if ``temporal_stride`` is not positive.
    """
    if temporal_stride <= 0:
        raise ValueError(f"temporal_stride must be positive, got {temporal_stride}")
    if total_frames <= 0:
        return []
    if total_frames <= sliding_window_size:
        # Previously a shorter-than-window input produced [] and the
        # downstream merge crashed on window_results[0].
        return [0]

    last_start = total_frames - sliding_window_size
    starts = list(range(0, last_start + 1, temporal_stride))
    if last_start % temporal_stride != 0:
        # Add a final window flush with the end of the sequence.
        starts.append(last_start)
    return starts
def blend_and_merge_window_results(
    window_results: List[AetherV1PipelineOutput], window_indices: List[int], args: Dict
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Blend and merge window results.

    Stitches the outputs of overlapping sliding windows into one sequence.
    The first window defines the reference (disparity scale, camera frame,
    focal length). Every later window is first aligned to the already-merged
    result over the frames they share (disparity scale, camera extrinsics,
    focal-length scale) and then linearly cross-faded into it: across the
    overlap the merged result's weight ramps from 1 to 0 and the new window's
    weight from 0 to 1.

    Args:
        window_results: Pipeline outputs, one per window; each must provide
            ``rgb``, ``disparity`` and ``raymap``.
        window_indices: Start frame of each window, in the same order as
            ``window_results`` (e.g. from ``get_window_starts``).
        args: Options. Keys read: ``align_pointmaps`` (default True),
            ``smooth_camera`` (default True), ``smooth_method`` (default
            "kalman", forced to "none" when smooth_camera is false),
            ``height`` (default 480) and ``width`` (default 720).

    Returns:
        ``(merged_rgb, merged_disparity, merged_poses, pointmaps)``. When more
        than one window is merged, rgb / disparity / poses are rebuilt as
        arrays of shape (T, H, W, 3), (T, H, W) and (T, 4, 4), with H, W taken
        from the first window's disparity. ``pointmaps`` is the merged
        (T, H, W, 3) pointmap when ``align_pointmaps`` is true, otherwise the
        stacked result of ``project`` for every frame.

    Raises:
        IndexError: if ``window_results`` is empty.
    """
    merged_rgb = None
    merged_disparity = None
    merged_poses = None
    merged_focals = None
    align_pointmaps = args.get("align_pointmaps", True)
    smooth_camera = args.get("smooth_camera", True)
    smooth_method = args.get("smooth_method", "kalman") if smooth_camera else "none"

    if align_pointmaps:
        merged_pointmaps = None

    # The first window's disparity fixes the spatial size of every merged
    # array and the last-axis width used in the reshapes below.
    w1 = window_results[0].disparity

    for idx, (window_result, t_start) in enumerate(zip(window_results, window_indices)):
        # Exclusive end frame of this window in the merged timeline.
        t_end = t_start + window_result.rgb.shape[0]
        if idx == 0:
            # First window: taken as-is; it is the reference for all later
            # alignment steps.
            merged_rgb = window_result.rgb
            merged_disparity = window_result.disparity
            pointmap_dict = postprocess_pointmap(
                window_result.disparity,
                window_result.raymap,
                vae_downsample_scale=8,
                ray_o_scale_inv=0.1,
                smooth_camera=smooth_camera,
                smooth_method=smooth_method if smooth_camera else "none",
            )
            merged_poses = pointmap_dict["camera_pose"]
            # One focal length per frame: mean of fx (K[0,0]) and fy (K[1,1]).
            merged_focals = (
                pointmap_dict["intrinsics"][:, 0, 0]
                + pointmap_dict["intrinsics"][:, 1, 1]
            ) / 2
            if align_pointmaps:
                merged_pointmaps = pointmap_dict["pointmap"]
        else:
            # Number of frames shared with the previous window. The current
            # window's length is used for the previous window as well, i.e.
            # all windows are assumed to have the same length.
            # NOTE(review): overlap_t must be > 0; with overlap_t == 0 the
            # `[-overlap_t:]` slices below select the whole array.
            overlap_t = window_indices[idx - 1] + window_result.rgb.shape[0] - t_start

            window_disparity = window_result.disparity
            # Align disparity: scale this window so it matches the merged
            # disparity over the overlap, using only pixels with disparity > 0.1.
            disp_mask = window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]) > 0.1
            scale = compute_scale(
                window_disparity[:overlap_t].reshape(1, -1, w1.shape[-1]),
                merged_disparity[-overlap_t:].reshape(1, -1, w1.shape[-1]),
                disp_mask.reshape(1, -1, w1.shape[-1]),
            )
            window_disparity = scale * window_disparity

            # Blend disparity: merged values before the overlap, this window
            # after it, and a linear cross-fade (merged weight 1 -> 0) inside.
            result_disparity = np.ones((t_end, *w1.shape[1:]))
            result_disparity[:t_start] = merged_disparity[:t_start]
            result_disparity[t_start + overlap_t :] = window_disparity[overlap_t:]
            weight = np.linspace(1, 0, overlap_t)[:, None, None]
            result_disparity[t_start : t_start + overlap_t] = merged_disparity[
                t_start : t_start + overlap_t
            ] * weight + window_disparity[:overlap_t] * (1 - weight)
            merged_disparity = result_disparity

            # Blend RGB with the same layout and linear cross-fade.
            result_rgb = np.ones((t_end, *w1.shape[1:], 3))
            result_rgb[:t_start] = merged_rgb[:t_start]
            result_rgb[t_start + overlap_t :] = window_result.rgb[overlap_t:]
            weight_rgb = np.linspace(1, 0, overlap_t)[:, None, None, None]
            result_rgb[t_start : t_start + overlap_t] = merged_rgb[
                t_start : t_start + overlap_t
            ] * weight_rgb + window_result.rgb[:overlap_t] * (1 - weight_rgb)
            merged_rgb = result_rgb

            # Align poses: this window's poses come straight from its raymap
            # (no smoothing, unlike the first window which goes through
            # postprocess_pointmap). They are mapped into the merged camera
            # frame with the transform estimated on the overlap
            # (rel_r / rel_t / rel_s — presumably rotation, translation, scale).
            window_raymap = window_result.raymap
            window_poses, window_Fov_x, window_Fov_y = raymap_to_poses(
                window_raymap, ray_o_scale_inv=0.1
            )
            rel_r, rel_t, rel_s = align_camera_extrinsics(
                torch.from_numpy(window_poses[:overlap_t]),
                torch.from_numpy(merged_poses[-overlap_t:]),
            )
            aligned_window_poses = (
                apply_transformation(
                    torch.from_numpy(window_poses),
                    rel_r,
                    rel_t,
                    rel_s,
                    return_extri=True,
                )
                .cpu()
                .numpy()
            )

            result_poses = np.ones((t_end, 4, 4))
            result_poses[:t_start] = merged_poses[:t_start]
            result_poses[t_start + overlap_t :] = aligned_window_poses[overlap_t:]

            # Interpolate poses in the overlap region, per frame, with the same
            # 1 -> 0 weight ramp on the merged side.
            weights = np.linspace(1, 0, overlap_t)
            for t in range(overlap_t):
                weight = weights[t]
                pose1 = merged_poses[t_start + t]
                pose2 = aligned_window_poses[t]
                result_poses[t_start + t] = interpolate_poses(pose1, pose2, weight)

            merged_poses = result_poses

            # Align intrinsics: build intrinsics from this window's FoVs at the
            # disparity resolution, rescale its focals by the mean focal ratio
            # over the overlap, then cross-fade like the other quantities.
            window_intrinsics, _ = get_intrinsics(
                batch_size=window_poses.shape[0],
                h=window_result.disparity.shape[1],
                w=window_result.disparity.shape[2],
                fovx=window_Fov_x,
                fovy=window_Fov_y,
            )
            window_focals = (
                window_intrinsics[:, 0, 0] + window_intrinsics[:, 1, 1]
            ) / 2
            scale = (merged_focals[-overlap_t:] / window_focals[:overlap_t]).mean()
            window_focals = scale * window_focals
            result_focals = np.ones((t_end,))
            result_focals[:t_start] = merged_focals[:t_start]
            result_focals[t_start + overlap_t :] = window_focals[overlap_t:]
            weight = np.linspace(1, 0, overlap_t)
            result_focals[t_start : t_start + overlap_t] = merged_focals[
                t_start : t_start + overlap_t
            ] * weight + window_focals[:overlap_t] * (1 - weight)
            merged_focals = result_focals

            if align_pointmaps:
                # Align pointmaps: recompute this window's pointmap from the
                # blended disparity of its frames, the aligned poses and the
                # rescaled focals, then cross-fade it into the merged pointmap.
                window_pointmaps = postprocess_pointmap(
                    result_disparity[t_start:],
                    window_raymap,
                    vae_downsample_scale=8,
                    camera_pose=aligned_window_poses,
                    focal=window_focals,
                    ray_o_scale_inv=0.1,
                    smooth_camera=smooth_camera,
                    smooth_method=smooth_method if smooth_camera else "none",
                )

                result_pointmaps = np.ones((t_end, *w1.shape[1:], 3))
                result_pointmaps[:t_start] = merged_pointmaps[:t_start]
                result_pointmaps[t_start + overlap_t :] = window_pointmaps["pointmap"][
                    overlap_t:
                ]
                weight = np.linspace(1, 0, overlap_t)[:, None, None, None]
                result_pointmaps[t_start : t_start + overlap_t] = merged_pointmaps[
                    t_start : t_start + overlap_t
                ] * weight + window_pointmaps["pointmap"][:overlap_t] * (1 - weight)
                merged_pointmaps = result_pointmaps

    # Project to pointmaps: pinhole intrinsics per frame with the merged focal
    # and the principal point at the centre of a `height` x `width` image.
    # NOTE(review): the focals were derived at the disparity resolution while
    # the principal point uses args["height"/"width"]; these agree only if the
    # disparity is at that resolution — confirm. The list is built even when
    # align_pointmaps is true, where it is not used.
    height = args.get("height", 480)
    width = args.get("width", 720)
    intrinsics = [
        np.array([[f, 0, 0.5 * width], [0, f, 0.5 * height], [0, 0, 1]])
        for f in merged_focals
    ]
    if align_pointmaps:
        pointmaps = merged_pointmaps
    else:
        # Depth = 1 / disparity, clipped to avoid division by zero.
        pointmaps = np.stack(
            [
                project(
                    1 / np.clip(merged_disparity[i], 1e-8, 1e8),
                    intrinsics[i],
                    merged_poses[i],
                )
                for i in range(merged_poses.shape[0])
            ]
        )

    return merged_rgb, merged_disparity, merged_poses, pointmaps
def process_video_to_frames(
    video_path: str, fps_sample: int = 12
) -> Tuple[List[str], str]:
    """
    Decode a video, subsample its frames and save them as JPEG files.

    The source frame rate is not read from the file; the video is assumed to
    be recorded at 30 fps, so every ``round(30 / fps_sample)``-th frame is
    kept (at least every frame).

    Args:
        video_path: Path of the video to decode.
        fps_sample: Target sampling rate in frames per second; must be > 0.

    Returns:
        ``(frame_paths, output_dir)``: the written JPEG paths in frame order
        and the newly created directory that holds them. (The previous
        ``List[str]`` annotation did not match this tuple.)

    Raises:
        ValueError: if ``fps_sample`` is not positive.
    """
    import tempfile  # only needed here

    if fps_sample <= 0:
        raise ValueError(f"fps_sample must be positive, got {fps_sample}")

    # A bare timestamp collides when two requests start within the same
    # second (frames would overwrite each other); mkdtemp adds a random
    # suffix and creates the directory atomically.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    output_dir = tempfile.mkdtemp(prefix=f"temp_frames_{timestamp}_", dir=".")

    video = iio.imread(video_path)

    # The old expression round(total / (fps_sample * (total / 30))) reduces
    # to this, but raised ZeroDivisionError on an empty video.
    frame_interval = max(1, round(30 / fps_sample))

    frame_paths = []
    for i, frame in enumerate(video[::frame_interval]):
        frame_path = os.path.join(output_dir, f"frame_{i:04d}.jpg")
        if isinstance(frame, np.ndarray):
            iio.imwrite(frame_path, frame)
            frame_paths.append(frame_path)

    return frame_paths, output_dir
def _flip_xy_for_glb(
    points: np.ndarray, poses: np.ndarray
) -> Tuple[np.ndarray, np.ndarray]:
    """Return copies of ``points`` and ``poses`` with the X and Y axes negated.

    Used so exported point clouds are not upside down and mirrored. ``points``
    holds XYZ in its last axis; ``poses`` is indexed ``[..., row, col]`` with
    the camera orientation in ``[:3, :3]`` and the position in column 3.
    """
    points = points.copy()
    points[..., 0] = -points[..., 0]
    points[..., 1] = -points[..., 1]

    poses = poses.copy()
    # Change of basis F @ R @ F with F = diag(-1, -1, 1): negate rows 0-1 and
    # columns 0-1 of the orientation (entries in both get negated twice).
    poses[..., 0, :3] = -poses[..., 0, :3]
    poses[..., 1, :3] = -poses[..., 1, :3]
    poses[..., :3, 0] = -poses[..., :3, 0]
    poses[..., :3, 1] = -poses[..., :3, 1]
    # Negate the X and Y components of the camera position.
    poses[..., 0, 3] = -poses[..., 0, 3]
    poses[..., 1, 3] = -poses[..., 1, 3]
    return points, poses


def save_output_files(
    rgb: np.ndarray,
    disparity: np.ndarray,
    poses: Optional[np.ndarray] = None,
    raymap: Optional[np.ndarray] = None,
    pointmap: Optional[np.ndarray] = None,
    task: str = "reconstruction",
    output_dir: str = "outputs",
    **kwargs,
) -> Dict[str, str]:
    """
    Save the RGB video, a colorized disparity video and per-frame point-cloud
    GLB files, and return the paths of everything written.

    Args:
        rgb: Frames indexed along axis 0; values are clipped to [0, 1] and
            written as uint8 video.
        disparity: Per-frame disparity, indexed along axis 0 like ``rgb``.
        poses: Optional per-frame camera poses.
        raymap: Optional raymap. When ``pointmap`` is missing, the clip is run
            through ``blend_and_merge_window_results`` as a single window to
            derive the pointmap (and poses, if those are missing too).
        pointmap: Optional per-frame point map with XYZ in the last axis.
        task: Prefix of the output file names.
        output_dir: Destination directory (created if missing).
        **kwargs: Options. Read here: ``fps`` (default 12),
            ``pointcloud_save_frame_interval`` (default 10), ``max_depth``
            (default 100.0), ``rtol`` (default 0.03). The whole dict is also
            forwarded as ``args`` to ``blend_and_merge_window_results``.

    Returns:
        Dict with ``"rgb"`` and ``"disparity"`` video paths and, when a
        pointmap and poses are available, ``"pointcloud_glbs"`` — a *list* of
        GLB paths (despite the ``Dict[str, str]`` annotation).

    Raises:
        ValueError: from ``range`` if ``pointcloud_save_frame_interval`` is 0.
    """
    os.makedirs(output_dir, exist_ok=True)

    if pointmap is None and raymap is not None:
        # Treat the whole clip as one window and reuse the merging code path
        # to obtain a pointmap and camera poses.
        single_window = AetherV1PipelineOutput(
            rgb=rgb,
            disparity=disparity,
            raymap=raymap,
        )
        _, _, poses_from_blend, pointmap = blend_and_merge_window_results(
            [single_window], [0], kwargs
        )
        if poses is None:
            poses = poses_from_blend

    if poses is None and raymap is not None:
        poses, _, _ = raymap_to_poses(raymap, ray_o_scale_inv=0.1)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    base_filename = f"{task}_{timestamp}"
    fps = kwargs.get("fps", 12)
    paths = {}

    rgb_path = os.path.join(output_dir, f"{base_filename}_rgb.mp4")
    iio.imwrite(
        rgb_path,
        (np.clip(rgb, 0, 1) * 255).astype(np.uint8),
        fps=fps,
    )
    paths["rgb"] = rgb_path

    depth_path = os.path.join(output_dir, f"{base_filename}_disparity.mp4")
    iio.imwrite(
        depth_path,
        (colorize_depth(disparity) * 255).astype(np.uint8),
        fps=fps,
    )
    paths["disparity"] = depth_path

    if pointmap is not None and poses is not None:
        num_frames = pointmap.shape[0]
        save_interval = kwargs.get("pointcloud_save_frame_interval", 10)
        max_depth = kwargs.get("max_depth", 100.0)
        rtol = kwargs.get("rtol", 0.03)

        # Every `save_interval`-th frame plus, always, the first and the last.
        # An empty pointmap yields no frames (previously index -1 was used).
        if num_frames > 0:
            frames_to_save = sorted(
                set(range(0, num_frames, save_interval)) | {0, num_frames - 1}
            )
        else:
            frames_to_save = []

        glb_paths = []
        for frame_idx in frames_to_save:
            flipped_pointmap, flipped_poses = _flip_xy_for_glb(
                pointmap[frame_idx : frame_idx + 1],
                poses[frame_idx : frame_idx + 1],
            )
            predictions = {
                "world_points": flipped_pointmap,
                "images": rgb[frame_idx : frame_idx + 1],
                # Depth = 1 / disparity, clipped to avoid division by zero.
                "depths": 1 / np.clip(disparity[frame_idx : frame_idx + 1], 1e-8, 1e8),
                "camera_poses": flipped_poses,
            }
            glb_path = os.path.join(
                output_dir, f"{base_filename}_pointcloud_frame_{frame_idx}.glb"
            )
            scene_3d = predictions_to_glb(
                predictions,
                filter_by_frames="all",
                show_cam=True,
                max_depth=max_depth,
                rtol=rtol,
                frame_rel_idx=float(frame_idx) / num_frames,
            )
            scene_3d.export(glb_path)
            glb_paths.append(glb_path)

        paths["pointcloud_glbs"] = glb_paths

    return paths
@spaces.GPU(duration=300)
def process_reconstruction(
    video_file,
    height,
    width,
    num_frames,
    num_inference_steps,
    guidance_scale,
    sliding_window_stride,
    fps,
    smooth_camera,
    align_pointmaps,
    max_depth,
    rtol,
    pointcloud_save_frame_interval,
    seed,
    progress=gr.Progress(),
):
    """
    Run 4D reconstruction on an uploaded video and save the results.

    The video is split into windows of ``num_frames`` frames starting every
    ``sliding_window_stride`` frames; each window goes through the pipeline in
    reconstruction mode and the per-window outputs are blended together.

    Returns:
        ``(rgb_video_path, disparity_video_path, glb_paths)``. On any error
        the traceback is printed and ``(None, None, [])`` is returned.
    """
    try:
        # Release memory left by earlier requests, then make the run reproducible.
        gc.collect()
        torch.cuda.empty_cache()
        seed_all(seed)

        has_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if has_cuda else "cpu")
        if not has_cuda:
            raise ValueError("CUDA is not available. Check your environment.")
        pipe = build_pipeline(device)

        progress(0.1, "Loading video")
        # Gradio may hand over a path string or a file object exposing `.name`.
        video_path = video_file if isinstance(video_file, str) else video_file.name
        frames = iio.imread(video_path).astype(np.float32) / 255.0

        run_args = dict(
            height=height,
            width=width,
            num_frames=num_frames,
            sliding_window_stride=sliding_window_stride,
            smooth_camera=smooth_camera,
            smooth_method="kalman" if smooth_camera else "none",
            align_pointmaps=align_pointmaps,
            max_depth=max_depth,
            rtol=rtol,
            pointcloud_save_frame_interval=pointcloud_save_frame_interval,
        )

        starts = get_window_starts(len(frames), num_frames, sliding_window_stride)
        n_windows = len(starts)
        progress(0.2, f"Processing video in {n_windows} windows")

        window_outputs = []
        for window_no, start in enumerate(starts):
            progress(
                0.2 + 0.6 * (window_no / n_windows),
                f"Processing window {window_no + 1}/{n_windows}",
            )
            window_outputs.append(
                pipe(
                    task="reconstruction",
                    image=None,
                    goal=None,
                    video=frames[start : start + num_frames],
                    raymap=None,
                    height=height,
                    width=width,
                    num_frames=num_frames,
                    fps=fps,
                    num_inference_steps=num_inference_steps,
                    guidance_scale=guidance_scale,
                    use_dynamic_cfg=False,
                    generator=torch.Generator(device=device).manual_seed(seed),
                )
            )

        progress(0.8, "Merging results from all windows")
        merged_rgb, merged_disparity, merged_poses, pointmaps = (
            blend_and_merge_window_results(window_outputs, starts, run_args)
        )

        progress(0.9, "Saving output files")
        out_dir = "outputs"
        os.makedirs(out_dir, exist_ok=True)
        saved = save_output_files(
            rgb=merged_rgb,
            disparity=merged_disparity,
            poses=merged_poses,
            pointmap=pointmaps,
            task="reconstruction",
            output_dir=out_dir,
            fps=12,
            **run_args,
        )
        progress(1.0, "Done!")

        return saved["rgb"], saved["disparity"], saved.get("pointcloud_glbs", [])

    except Exception:
        import traceback

        traceback.print_exc()
        return None, None, []
@spaces.GPU(duration=300)
def process_prediction(
    image_file,
    height,
    width,
    num_frames,
    num_inference_steps,
    guidance_scale,
    use_dynamic_cfg,
    raymap_option,
    post_reconstruction,
    fps,
    smooth_camera,
    align_pointmaps,
    max_depth,
    rtol,
    pointcloud_save_frame_interval,
    seed,
    progress=gr.Progress(),
):
    """
    Predict future frames from a single start image along a preset camera path.

    The camera trajectory is loaded from
    ``assets/example_raymaps/raymap_<raymap_option>.npy``. With
    ``post_reconstruction`` set, the generated frames are run through the
    pipeline again in reconstruction mode (4 steps, guidance 1.0) and that
    disparity/raymap is saved instead of the prediction's own.

    Returns:
        ``(rgb_video_path, disparity_video_path, glb_paths)``. On any error
        the traceback is printed and ``(None, None, [])`` is returned.
    """
    try:
        # Release memory left by earlier requests, then make the run reproducible.
        gc.collect()
        torch.cuda.empty_cache()
        seed_all(seed)

        cuda_ok = torch.cuda.is_available()
        device = torch.device("cuda" if cuda_ok else "cpu")
        if not cuda_ok:
            raise ValueError("CUDA is not available. Check your environment.")
        pipe = build_pipeline(device)

        progress(0.1, "Loading image")
        # Gradio may hand over a path string or a file object exposing `.name`.
        image_path = image_file if isinstance(image_file, str) else image_file.name
        start_image = PIL.Image.open(image_path)

        progress(0.2, "Running prediction")
        camera_raymap = np.load(f"assets/example_raymaps/raymap_{raymap_option}.npy")
        prediction = pipe(
            task="prediction",
            image=start_image,
            video=None,
            goal=None,
            raymap=camera_raymap,
            height=height,
            width=width,
            num_frames=num_frames,
            fps=fps,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            use_dynamic_cfg=use_dynamic_cfg,
            generator=torch.Generator(device=device).manual_seed(seed),
            return_dict=True,
        )

        save_args = dict(
            height=height,
            width=width,
            smooth_camera=smooth_camera,
            smooth_method="kalman" if smooth_camera else "none",
            align_pointmaps=align_pointmaps,
            max_depth=max_depth,
            rtol=rtol,
            pointcloud_save_frame_interval=pointcloud_save_frame_interval,
        )

        # Geometry (disparity + raymap) comes either from a refinement pass or
        # directly from the prediction output.
        if post_reconstruction:
            progress(0.5, "Running post-reconstruction for better quality")
            geometry = pipe(
                task="reconstruction",
                video=prediction.rgb,
                height=height,
                width=width,
                num_frames=num_frames,
                fps=fps,
                num_inference_steps=4,
                guidance_scale=1.0,
                use_dynamic_cfg=False,
                generator=torch.Generator(device=device).manual_seed(seed),
            )
        else:
            geometry = prediction

        progress(0.8, "Saving output files")
        out_dir = "outputs"
        os.makedirs(out_dir, exist_ok=True)
        saved = save_output_files(
            rgb=prediction.rgb,
            disparity=geometry.disparity,
            raymap=geometry.raymap,
            task="prediction",
            output_dir=out_dir,
            fps=12,
            **save_args,
        )
        progress(1.0, "Done!")

        return saved["rgb"], saved["disparity"], saved.get("pointcloud_glbs", [])

    except Exception:
        import traceback

        traceback.print_exc()
        return None, None, []
@spaces.GPU(duration=300)
def process_planning(
    image_file,
    goal_file,
    height,
    width,
    num_frames,
    num_inference_steps,
    guidance_scale,
    use_dynamic_cfg,
    post_reconstruction,
    fps,
    smooth_camera,
    align_pointmaps,
    max_depth,
    rtol,
    pointcloud_save_frame_interval,
    seed,
    progress=gr.Progress(),
):
    """
    Generate a video that plans a path from a start image to a goal image.

    With ``post_reconstruction`` set, the generated frames are run through the
    pipeline again in reconstruction mode (4 steps, guidance 1.0) and that
    disparity/raymap is saved instead of the planning output's own.

    Returns:
        ``(rgb_video_path, disparity_video_path, glb_paths)``. On any error
        the traceback is printed and ``(None, None, [])`` is returned.
    """
    try:
        # Release memory left by earlier requests, then make the run reproducible.
        gc.collect()
        torch.cuda.empty_cache()
        seed_all(seed)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if not torch.cuda.is_available():
            raise ValueError("CUDA is not available. Check your environment.")
        pipeline = build_pipeline(device)

        progress(0.1, "Loading images")
        # Gradio may hand over a path string or a file object exposing `.name`.
        image_path = image_file if isinstance(image_file, str) else image_file.name
        goal_path = goal_file if isinstance(goal_file, str) else goal_file.name
        image = PIL.Image.open(image_path)
        goal = PIL.Image.open(goal_path)

        progress(0.2, "Running planning")
        output = pipeline(
            task="planning",
            image=image,
            video=None,
            goal=goal,
            raymap=None,
            height=height,
            width=width,
            num_frames=num_frames,
            fps=fps,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            use_dynamic_cfg=use_dynamic_cfg,
            generator=torch.Generator(device=device).manual_seed(seed),
            return_dict=True,
        )
        rgb_output = output.rgb

        args = {
            "height": height,
            "width": width,
            "smooth_camera": smooth_camera,
            "smooth_method": "kalman" if smooth_camera else "none",
            "align_pointmaps": align_pointmaps,
            "max_depth": max_depth,
            "rtol": rtol,
            "pointcloud_save_frame_interval": pointcloud_save_frame_interval,
        }

        if post_reconstruction:
            progress(0.5, "Running post-reconstruction for better quality")
            # Reconstruct at the same fps the frames were generated with; this
            # was hard-coded to 12 here, unlike process_prediction which
            # already passes the user-selected fps.
            recon_output = pipeline(
                task="reconstruction",
                video=output.rgb,
                height=height,
                width=width,
                num_frames=num_frames,
                fps=fps,
                num_inference_steps=4,
                guidance_scale=1.0,
                use_dynamic_cfg=False,
                generator=torch.Generator(device=device).manual_seed(seed),
            )
            disparity = recon_output.disparity
            raymap = recon_output.raymap
        else:
            disparity = output.disparity
            raymap = output.raymap

        progress(0.8, "Saving output files")
        output_dir = "outputs"
        os.makedirs(output_dir, exist_ok=True)
        output_paths = save_output_files(
            rgb=rgb_output,
            disparity=disparity,
            raymap=raymap,
            task="planning",
            output_dir=output_dir,
            fps=fps,
            **args,
        )
        progress(1.0, "Done!")

        return (
            output_paths["rgb"],
            output_paths["disparity"],
            output_paths.get("pointcloud_glbs", []),
        )

    except Exception:
        import traceback

        traceback.print_exc()
        return None, None, []
def update_task_ui(task):
    """
    Return the Gradio updates that adapt the UI to the selected task.

    The tuple is, in order:
    - visibility of reconstruction_group, prediction_group, planning_group
      and preview_row;
    - the default num_inference_steps;
    - visibility of sliding_window_stride, use_dynamic_cfg, raymap_option
      and post_reconstruction;
    - the default guidance_scale.

    An unrecognized task yields None.
    """
    # task -> (shown input group, show preview, steps, show stride,
    #          show dynamic cfg, show raymap option, show post-recon, guidance)
    presets = {
        "reconstruction": ("reconstruction", False, 4, True, False, False, False, 1.0),
        "prediction": ("prediction", True, 50, False, True, True, True, 3.0),
        "planning": ("planning", True, 50, False, True, False, True, 3.0),
    }
    if task not in presets:
        return None

    (
        shown_group,
        show_preview,
        steps,
        show_stride,
        show_dynamic_cfg,
        show_raymap,
        show_post_recon,
        guidance,
    ) = presets[task]

    return (
        gr.update(visible=shown_group == "reconstruction"),
        gr.update(visible=shown_group == "prediction"),
        gr.update(visible=shown_group == "planning"),
        gr.update(visible=show_preview),
        gr.update(value=steps),
        gr.update(visible=show_stride),
        gr.update(visible=show_dynamic_cfg),
        gr.update(visible=show_raymap),
        gr.update(visible=show_post_recon),
        gr.update(value=guidance),
    )
def _uploaded_file_path(file_obj):
    """Return the filesystem path of an uploaded file, or None.

    A path string is returned unchanged, an object with a ``name`` attribute
    (e.g. a tempfile wrapper) yields that ``name``, and None or any other
    object yields None.
    """
    if file_obj is None:
        return None
    if isinstance(file_obj, str):
        return file_obj
    return getattr(file_obj, "name", None)


def update_image_preview(image_file):
    """Update the image preview with the path of the uploaded start image."""
    return _uploaded_file_path(image_file)


def update_goal_preview(goal_file):
    """Update the goal preview with the path of the uploaded goal image."""
    return _uploaded_file_path(goal_file)
def get_download_link(selected_frame, all_paths):
    """
    Point the download button at the GLB file of the selected frame.

    Args:
        selected_frame: Selection label containing ``"Frame <n>"``.
        all_paths: GLB paths as written by ``save_output_files``, whose file
            names end in ``_pointcloud_frame_<n>.glb``.

    Returns:
        A visible, interactive ``gr.update`` whose value is the first existing
        path for frame ``n``. Returns a hidden update with ``value=None`` when
        the inputs are empty, the label is malformed, or no file matches.
    """
    if not selected_frame or not all_paths:
        return gr.update(visible=False, value=None)

    label_match = re.search(r"Frame (\d+)", selected_frame)
    if label_match is None:
        # Previously `.group` on None raised AttributeError here.
        return gr.update(visible=False, value=None)
    frame_num = int(label_match.group(1))

    for path in all_paths:
        # Compare the full frame number: the old substring test
        # f"frame_{n}" in path also matched frame_{n}0, frame_{n}5, ...
        path_match = re.search(r"frame_(\d+)", os.path.basename(path))
        if (
            path_match is not None
            and int(path_match.group(1)) == frame_num
            and os.path.exists(path)
        ):
            return gr.update(visible=True, value=path, interactive=True)

    return gr.update(visible=False, value=None)
# Theme setup: Gradio's default theme with blue as the primary hue and cyan
# as the secondary hue; passed to gr.Blocks below.
theme = gr.themes.Default(
    secondary_hue="cyan",
    primary_hue="blue",
)
with gr.Blocks(
theme=theme,
css="""
.output-column {
min-height: 400px;
}
.warning {
color: #856404 !important;
font-weight: bold !important;
padding: 10px !important;
background-color: #fff3cd !important;
border-left: 4px solid #ffc107 !important;
border-radius: 4px !important;
margin: 10px 0 !important;
}
.dark .warning {
background-color: rgba(255, 193, 7, 0.1) !important;
color: #fbd38d !important;
}
.highlight {
background-color: rgba(0, 123, 255, 0.1);
padding: 10px;
border-radius: 8px;
border-left: 5px solid #007bff;
margin: 10px 0;
}
.task-header {
margin-top: 15px;
margin-bottom: 20px;
font-size: 1.4em;
font-weight: bold;
color: #007bff;
}
.flex-display {
display: flex;
flex-wrap: wrap;
gap: 10px;
}
.output-subtitle {
font-size: 1.1em;
margin-top: 5px;
margin-bottom: 5px;
color: #505050;
}
.input-section, .params-section, .advanced-section {
border: 1px solid #ddd;
padding: 20px;
border-radius: 8px;
margin-bottom: 20px;
}
.logo-container {
display: flex;
justify-content: center;
margin-bottom: 20px;
}
.logo-image {
max-width: 300px;
height: auto;
}
/* Optimize layout and spacing */
.container {
margin: 0 auto;
padding: 0 15px;
max-width: 1800px;
}
.header {
text-align: center;
margin-bottom: 20px;
padding: 15px;
background: linear-gradient(to right, #f8f9fa, #e9ecef);
border-radius: 10px;
}
.dark .header {
background: linear-gradient(to right, #2d3748, #1a202c);
}
.main-title {
font-size: 2.2em;
font-weight: bold;
margin: 0 auto;
color: #2c3e50;
max-width: 800px;
}
.dark .main-title {
color: #e2e8f0;
}
.links-bar {
display: flex;
justify-content: center;
gap: 15px;
margin: 12px 0;
}
.link-button {
display: inline-flex;
align-items: center;
padding: 6px 12px;
background-color: #007bff;
color: white !important;
text-decoration: none;
border-radius: 5px;
transition: background-color 0.3s;
font-size: 0.95em;
}
.link-button:hover {
background-color: #0056b3;
text-decoration: none;
}
.features-limitations-container {
display: flex;
gap: 15px;
margin: 20px 0;
}
.capabilities-box, .limitations-box {
flex: 1;
padding: 18px;
border-radius: 8px;
margin-bottom: 15px;
}
.capabilities-box {
background: #f0f9ff;
border-left: 5px solid #3498db;
}
.dark .capabilities-box {
background: #172a3a;
border-left: 5px solid #3498db;
}
.limitations-box {
background: #f8f9fa;
border-left: 5px solid #ffc107;
}
.dark .limitations-box {
background: #2d2a20;
border-left: 5px solid #ffc107;
}
.capabilities-text, .limitations-text {
color: #495057;
line-height: 1.6;
}
.dark .capabilities-text, .dark .limitations-text {
color: #cbd5e0;
}
.capabilities-text h3 {
color: #2980b9;
margin-top: 0;
margin-bottom: 15px;
}
.dark .capabilities-text h3 {
color: #63b3ed;
}
.limitations-text h3 {
color: #d39e00;
margin-top: 0;
margin-bottom: 15px;
}
.dark .limitations-text h3 {
color: #fbd38d;
}
.capabilities-text blockquote, .limitations-text blockquote {
margin: 20px 0 0 0;
padding: 10px 20px;
font-style: italic;
}
.capabilities-text blockquote {
border-left: 3px solid #3498db;
background: rgba(52, 152, 219, 0.1);
}
.dark .capabilities-text blockquote {
background: rgba(52, 152, 219, 0.2);
}
.limitations-text blockquote {
border-left: 3px solid #ffc107;
background: rgba(255, 193, 7, 0.1);
}
.dark .limitations-text blockquote {
background: rgba(255, 193, 7, 0.2);
}
/* Optimize layout and spacing */
.main-interface {
display: flex;
gap: 30px;
margin-top: 20px;
}
.input-column, .output-column {
flex: 1;
min-width: 0;
display: flex;
flex-direction: column;
}
.output-panel {
border: 1px solid #ddd;
border-radius: 8px;
padding: 20px;
height: 100%;
display: flex;
flex-direction: column;
overflow-y: auto;
}
.dark .output-panel {
border-color: #4a5568;
}
.run-button-container {
display: flex;
justify-content: center;
margin: 15px 0;
}
.run-button {
padding: 10px 30px;
font-size: 1.1em;
font-weight: bold;
background: linear-gradient(to right, #3498db, #2980b9);
border: none;
border-radius: 5px;
color: white;
cursor: pointer;
transition: all 0.3s;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
}
.run-button:hover {
background: linear-gradient(to right, #2980b9, #1a5276);
box-shadow: 0 6px 8px rgba(0, 0, 0, 0.15);
transform: translateY(-2px);
}
.task-selector {
background-color: #f8f9fa;
padding: 12px;
border-radius: 8px;
margin-bottom: 15px;
border: 1px solid #e9ecef;
}
.dark .task-selector {
background-color: #2d3748;
border-color: #4a5568;
}
/* Compact parameter settings */
.compact-params .row {
margin-bottom: 8px;
}
.compact-params label {
margin-bottom: 4px;
}
/* More obvious advanced options */
.advanced-options-header {
background-color: #e9ecef;
padding: 10px 15px;
border-radius: 6px;
margin-top: 10px;
font-weight: bold;
color: #495057;
border-left: 4px solid #6c757d;
cursor: pointer;
transition: all 0.2s;
}
.advanced-options-header:hover {
background-color: #dee2e6;
}
.dark .advanced-options-header {
background-color: #2d3748;
color: #e2e8f0;
border-left: 4px solid #a0aec0;
}
.dark .advanced-options-header:hover {
background-color: #4a5568;
}
/* Vertical arrangement of output section */
.output-section {
margin-bottom: 30px;
border: 1px solid #e9ecef;
border-radius: 8px;
padding: 20px;
}
.output-section-title {
font-weight: bold;
color: #495057;
margin-bottom: 15px;
font-size: 1.2em;
}
.dark .output-section-title {
color: #e2e8f0;
}
.pointcloud-controls {
display: flex;
gap: 10px;
margin-bottom: 10px;
align-items: center;
}
.note-box {
background-color: #fff8e1 !important;
border-left: 4px solid #ffc107 !important;
padding: 12px !important;
margin: 15px 0 !important;
border-radius: 4px !important;
color: #333 !important;
}
.dark .note-box {
background-color: rgba(255, 193, 7, 0.1) !important;
color: #e0e0e0 !important;
}
.note-box p, .note-box strong {
color: inherit !important;
}
/* Ensure warning class styles are correctly applied */
.warning {
color: #856404 !important;
font-weight: bold !important;
padding: 10px !important;
background-color: #fff3cd !important;
border-left: 4px solid #ffc107 !important;
border-radius: 4px !important;
margin: 10px 0 !important;
}
.dark .warning {
background-color: rgba(255, 193, 7, 0.1) !important;
color: #fbd38d !important;
}
.warning-box {
background-color: #fff3cd;
border-left: 4px solid #ffc107;
padding: 12px;
margin: 15px 0;
border-radius: 4px;
color: #856404;
}
.dark .warning-box {
background-color: rgba(255, 193, 7, 0.1);
color: #fbd38d;
}
""",
) as demo:
with gr.Column(elem_classes=["container"]):
with gr.Row(elem_classes=["header"]):
with gr.Column():
gr.Markdown(
"""
# Aether: Geometric-Aware Unified World Modeling
""",
elem_classes=["main-title"]
)
gr.Markdown(
"""
<div class="links-bar">
🌐<a href="https://aether-world.github.io/" class="link-button" target="_blank"> Project Page</a>
📄<a href="https://arxiv.org/abs/2503.18945" class="link-button" target="_blank"> Paper</a>
💻<a href="https://github.com/OpenRobotLab/Aether" class="link-button" target="_blank"> Code</a>
🤗<a href="https://huggingface.co/AetherWorldModel/AetherV1" class="link-button" target="_blank"> Model</a>
</div>
""",
)
# Static overview row: a "capabilities" panel beside a "limitations" panel.
# Only Markdown text is rendered here; no component in this row is wired to events.
with gr.Row(elem_classes=["features-limitations-container"]):
with gr.Column(elem_classes=["capabilities-box"]):
gr.Markdown(
"""
### 🚀 Key Capabilities
Aether addresses a fundamental challenge in AI: integrating geometric reconstruction with generative modeling for human-like spatial reasoning. Our framework unifies three core capabilities:
- 🌏 **4D Dynamic Reconstruction**: Reconstruct dynamic point clouds from videos by estimating depths and camera poses.
- 🎬 **Action-Conditioned Prediction**: Predict future frames based on initial observations, with optional camera trajectory actions.
- 🎯 **Goal-Conditioned Planning**: Generate planning paths from pairs of observation and goal images.
> *Trained entirely on synthetic data, Aether achieves strong zero-shot generalization to real-world scenarios.*
""",
elem_classes=["capabilities-text"]
)
with gr.Column(elem_classes=["limitations-box"]):
gr.Markdown(
"""
### 📝 Current Limitations
Aether represents an initial step in our journey, trained entirely on synthetic data. While it demonstrates promising capabilities, it is important to be aware of its current limitations:
- 🔄 **Dynamic Scenarios**: Struggles with highly dynamic scenarios involving significant motion or dense crowds.
- 📸 **Camera Stability**: Camera pose estimation can be less stable in certain conditions.
- 📐 **Planning Range**: For visual planning tasks, we recommend keeping the observations and goals relatively close to ensure optimal performance.
> *We are actively working on the next generation of Aether and are committed to addressing these limitations in future releases.*
""",
elem_classes=["limitations-text"]
)
# Main interface row: the input column (inputs, parameters, run button)
# followed by the output column.
with gr.Row(elem_classes=["main-interface"]):
with gr.Column(elem_classes=["input-column"]):
# Static notice about the ZeroGPU quota on HuggingFace Spaces.
gpu_time_warning = gr.Markdown(
"""
<div class="warning-box">
<strong>⚠️ Warning:</strong><br>
Due to HuggingFace Spaces ZERO GPU quota limitations, only short video reconstruction tasks (less than 100 frames) can be completed online.
<strong>💻 Recommendation:</strong><br>
We strongly encourage you to deploy Aether locally for:
- Processing longer video reconstruction tasks
- Better performance and full access to prediction and planning tasks
Visit our <a href="https://github.com/OpenRobotLab/Aether" target="_blank">GitHub repository</a> for local deployment instructions.
</div>
""",
)
# Task selector. Its change event (wired to update_task_ui further down in
# this file) toggles the task-specific input groups and parameter widgets.
with gr.Group(elem_classes=["task-selector"]):
task = gr.Radio(
["reconstruction", "prediction", "planning"],
label="Select Task",
value="reconstruction",
info="Choose the task you want to perform",
)
# Input panel with one group per task. Only the reconstruction group starts
# visible; the task.change handler toggles these groups via update_task_ui.
with gr.Group(elem_classes=["input-section"]):
gr.Markdown("## 📥 Input", elem_classes=["task-header"])
# Task-specific inputs
# Reconstruction: a single uploaded video, plus clickable example videos.
with gr.Group(visible=True) as reconstruction_group:
video_input = gr.Video(
label="Upload Input Video",
sources=["upload"],
interactive=True,
elem_id="video_input",
)
reconstruction_examples = gr.Examples(
examples=[
["assets/example_videos/bridge.mp4"],
["assets/example_videos/moviegen.mp4"],
["assets/example_videos/nuscenes.mp4"],
["assets/example_videos/veo2.mp4"],
],
inputs=[video_input],
label="Reconstruction Examples",
examples_per_page=4,
)
# Prediction: a single start image, delivered to handlers as a file path
# (type="filepath").
with gr.Group(visible=False) as prediction_group:
image_input = gr.Image(
label="Upload Start Image",
type="filepath",
interactive=True,
elem_id="image_input",
)
prediction_examples = gr.Examples(
examples=[
["assets/example_obs/car.png"],
["assets/example_obs/cartoon.png"],
["assets/example_obs/garden.jpg"],
["assets/example_obs/room.jpg"],
],
inputs=[image_input],
label="Prediction Examples",
examples_per_page=4,
)
# Planning: a start image and a goal image side by side; each example row
# fills both inputs.
with gr.Group(visible=False) as planning_group:
with gr.Row():
image_input_planning = gr.Image(
label="Upload Start Image",
type="filepath",
interactive=True,
elem_id="image_input_planning",
)
goal_input = gr.Image(
label="Upload Goal Image",
type="filepath",
interactive=True,
elem_id="goal_input",
)
planning_examples = gr.Examples(
examples=[
["assets/example_obs_goal/01_obs.png", "assets/example_obs_goal/01_goal.png"],
["assets/example_obs_goal/02_obs.png", "assets/example_obs_goal/02_goal.png"],
["assets/example_obs_goal/03_obs.png", "assets/example_obs_goal/03_goal.png"],
["assets/example_obs_goal/04_obs.png", "assets/example_obs_goal/04_goal.png"],
],
inputs=[image_input_planning, goal_input],
label="Planning Examples",
examples_per_page=4,
)
# Hidden preview row. It is revealed by the image_input / goal_input change
# handlers and is also among update_task_ui's outputs.
with gr.Row(visible=False) as preview_row:
image_preview = gr.Image(
label="Start Image Preview",
elem_id="image_preview",
visible=False,
)
goal_preview = gr.Image(
label="Goal Image Preview",
elem_id="goal_preview",
visible=False,
)
with gr.Group(elem_classes=["params-section", "compact-params"]):
gr.Markdown("## ⚙️ Parameters", elem_classes=["task-header"])
with gr.Row():
with gr.Column(scale=1):
height = gr.Dropdown(
choices=[480],
value=480,
label="Height",
info="Height of the output video",
)
with gr.Column(scale=1):
width = gr.Dropdown(
choices=[720],
value=720,
label="Width",
info="Width of the output video",
)
with gr.Row():
with gr.Column(scale=1):
num_frames = gr.Dropdown(
choices=[17, 25, 33, 41],
value=41,
label="Number of Frames",
info="Number of frames to predict",
)
with gr.Column(scale=1):
fps = gr.Dropdown(
choices=[8, 10, 12, 15, 24],
value=24,
label="FPS",
info="Frames per second",
)
with gr.Row():
num_inference_steps = gr.Slider(
minimum=1,
maximum=60,
value=4,
step=1,
label="Inference Steps",
info="Number of inference step",
)
sliding_window_stride = gr.Slider(
minimum=1,
maximum=40,
value=24,
step=1,
label="Sliding Window Stride",
info="Sliding window stride (window size equals to num_frames). Only used for 'reconstruction' task",
visible=True,
)
use_dynamic_cfg = gr.Checkbox(
label="Use Dynamic CFG",
value=True,
info="Use dynamic CFG",
visible=False,
)
raymap_option = gr.Radio(
choices=["backward", "forward_right", "left_forward", "right"],
label="Camera Movement Direction",
value="forward_right",
info="Direction of camera action. We offer 4 pre-defined actions for you to choose from.",
visible=False,
)
post_reconstruction = gr.Checkbox(
label="Post-Reconstruction",
value=True,
info="Run reconstruction after prediction for better quality",
visible=False,
)
with gr.Accordion(
"Advanced Options", open=False, visible=True, elem_classes=["advanced-options-header"]
) as advanced_options:
with gr.Group(elem_classes=["advanced-section"]):
with gr.Row():
guidance_scale = gr.Slider(
minimum=1.0,
maximum=10.0,
value=1.0,
step=0.1,
label="Guidance Scale",
info="Guidance scale (only for prediction / planning)",
)
with gr.Row():
seed = gr.Number(
value=42,
label="Random Seed",
info="Set a seed for reproducible results",
precision=0,
minimum=0,
maximum=2147483647,
)
with gr.Row():
with gr.Column(scale=1):
smooth_camera = gr.Checkbox(
label="Smooth Camera",
value=True,
info="Apply smoothing to camera trajectory",
)
with gr.Column(scale=1):
align_pointmaps = gr.Checkbox(
label="Align Point Maps",
value=False,
info="Align point maps across frames",
)
with gr.Row():
with gr.Column(scale=1):
max_depth = gr.Slider(
minimum=10,
maximum=200,
value=60,
step=10,
label="Max Depth",
info="Maximum depth for point cloud (higher = more distant points)",
)
with gr.Column(scale=1):
rtol = gr.Slider(
minimum=0.01,
maximum=2.0,
value=0.2,
step=0.01,
label="Relative Tolerance",
info="Used for depth edge detection. Lower = remove more edges",
)
pointcloud_save_frame_interval = gr.Slider(
minimum=1,
maximum=20,
value=10,
step=1,
label="Point Cloud Frame Interval",
info="Save point cloud every N frames (higher = fewer files but less complete representation)",
)
# "Run Aether" button; its click handler is registered after the layout.
with gr.Group(elem_classes=["run-button-container"]):
run_button = gr.Button("Run Aether", variant="primary", elem_classes=["run-button"])
# Output column: RGB video, depth video, point-cloud viewer and explanatory notes.
with gr.Column(elem_classes=["output-column"]):
with gr.Group(elem_classes=["output-panel"]):
gr.Markdown("## 📤 Output", elem_classes=["task-header"])
with gr.Group(elem_classes=["output-section"]):
gr.Markdown("### RGB Video", elem_classes=["output-section-title"])
rgb_output = gr.Video(
label="RGB Output", interactive=False, elem_id="rgb_output"
)
with gr.Group(elem_classes=["output-section"]):
gr.Markdown("### Depth Video", elem_classes=["output-section-title"])
depth_output = gr.Video(
label="Depth Output", interactive=False, elem_id="depth_output"
)
with gr.Group(elem_classes=["output-section"]):
gr.Markdown("### Point Clouds", elem_classes=["output-section-title"])
with gr.Row(elem_classes=["pointcloud-controls"]):
# Keyframe selector. It starts empty and is populated after a run
# (process_task -> update_pointcloud_frames).
pointcloud_frames = gr.Dropdown(
label="Select Frame",
choices=[],
value=None,
interactive=True,
elem_id="pointcloud_frames",
)
# Hidden until a run produces point-cloud files.
pointcloud_download = gr.DownloadButton(
label="Download Point Cloud",
visible=False,
elem_id="pointcloud_download",
)
model_output = gr.Model3D(
label="Point Cloud Viewer", interactive=True, elem_id="model_output"
)
# NOTE(review): this note names the internal variable
# `pointcloud_save_frame_interval`. The control users actually see is
# labelled "Point Cloud Frame Interval", so consider referring to that label.
gr.Markdown(
"""
> **Note:** 3D point clouds take a long time to visualize, and we show the keyframes only.
> You can control the keyframe interval by modifying the `pointcloud_save_frame_interval`.
"""
)
with gr.Group(elem_classes=["output-section"]):
gr.Markdown("### About Results", elem_classes=["output-section-title"])
gr.Markdown(
"""
#### Understanding the Outputs
- **RGB Video**: Shows the predicted or reconstructed RGB frames
- **Depth Video**: Visualizes the disparity maps in color (closer = red, further = blue)
- **Point Clouds**: Interactive 3D point cloud with camera positions shown as colored pyramids
"""
)
# Event handlers
# Switching task shows the matching input group and re-configures the
# parameter widgets below. update_task_ui is defined elsewhere in this file
# and must return one update per entry in `outputs`, in this exact order.
task.change(
fn=update_task_ui,
inputs=[task],
outputs=[
reconstruction_group,
prediction_group,
planning_group,
preview_row,
num_inference_steps,
sliding_window_stride,
use_dynamic_cfg,
raymap_option,
post_reconstruction,
guidance_scale,
],
)
# Mirror the start / goal image into its preview component, then reveal the
# preview row.
# NOTE(review): the row is revealed on every change, including when an image
# is cleared — confirm update_image_preview / update_goal_preview handle None.
image_input.change(
fn=update_image_preview,
inputs=[image_input],
outputs=[image_preview]
).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row])
goal_input.change(
fn=update_goal_preview,
inputs=[goal_input],
outputs=[goal_preview]
).then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[preview_row])
def update_pointcloud_frames(pointcloud_paths):
    """Build the point-cloud keyframe dropdown from the saved point-cloud files.

    Each path whose basename contains ``frame_<N>`` becomes a dropdown entry
    labelled ``"Frame <N>"``; paths without that pattern are ignored. Entries
    are ordered by frame number, and the first one is preselected.

    Args:
        pointcloud_paths: Sequence of point-cloud file paths; may be empty or None.

    Returns:
        Tuple of ``(dropdown update, path of the preselected frame or None,
        download-button visibility update)``.
    """
    frame_info = []
    for path in pointcloud_paths or []:
        match = re.search(r"frame_(\d+)", os.path.basename(path))
        if match:
            frame_info.append((int(match.group(1)), path))

    if not frame_info:
        # Nothing to show: clear the dropdown and hide the download button.
        return gr.update(choices=[], value=None), None, gr.update(visible=False)

    # Sort on the parsed integer (stable for ties) rather than re-parsing the
    # formatted "Frame N" label with a regex.
    frame_info.sort(key=lambda item: item[0])
    choices = [f"Frame {frame_num}" for frame_num, _ in frame_info]

    # Make the download button visible now that point-cloud files exist.
    return (
        gr.update(choices=choices, value=choices[0]),
        frame_info[0][1],
        gr.update(visible=True),
    )
def select_pointcloud_frame(frame_label, all_paths):
    """Return the point-cloud file for the keyframe chosen in the dropdown.

    Args:
        frame_label: Dropdown label of the form ``"Frame <N>"``.
        all_paths: Point-cloud paths held in the ``all_pointcloud_paths`` state.

    Returns:
        The first path whose basename encodes frame ``N`` as ``frame_<N>``, or
        None if the label is empty or malformed, or if no file matches.
        Frame numbers are compared numerically, so ``frame_1`` never matches
        ``frame_10`` and zero-padded names such as ``frame_005`` still match
        ``"Frame 5"``.
    """
    if not frame_label or not all_paths:
        return None
    label_match = re.search(r"Frame (\d+)", frame_label)
    if label_match is None:
        return None
    frame_num = int(label_match.group(1))
    for path in all_paths:
        # Parse the basename the same way update_pointcloud_frames builds its
        # labels. A plain substring test would let "frame_1" match "frame_10"
        # or a directory name.
        path_match = re.search(r"frame_(\d+)", os.path.basename(path))
        if path_match and int(path_match.group(1)) == frame_num:
            return path
    return None
# Task dispatcher used by the "Run Aether" click handler registered below.
def process_task(task_type, *args):
    """Run the selected Aether task and package its results for the UI.

    Args:
        task_type: One of ``"reconstruction"``, ``"prediction"`` or ``"planning"``.
        *args: Positional arguments forwarded unchanged to the matching
            ``process_*`` function. The click handler builds a task-specific
            argument list.

    Returns:
        A 6-tuple in the order of the click handler's outputs: RGB video path,
        depth video path, initially displayed point-cloud path, frame-dropdown
        update, list of all point-cloud paths, and download-button visibility
        update. An unrecognised ``task_type`` yields an empty result.
    """
    # All three tasks return (rgb, depth, pointcloud_paths) and share the same
    # post-processing, so dispatch by name instead of repeating each branch.
    handlers = {
        "reconstruction": process_reconstruction,
        "prediction": process_prediction,
        "planning": process_planning,
    }
    handler = handlers.get(task_type)
    if handler is None:
        return (
            None,
            None,
            None,
            gr.update(choices=[], value=None),
            [],
            gr.update(visible=False),
        )

    rgb_path, depth_path, pointcloud_paths = handler(*args)
    # Populate the keyframe dropdown and pick the initial point cloud to show.
    frame_dropdown, initial_path, download_visible = update_pointcloud_frames(
        pointcloud_paths
    )
    return (
        rgb_path,
        depth_path,
        initial_path,
        frame_dropdown,
        pointcloud_paths,
        download_visible,
    )
# Store all pointcloud paths for later use
# (per-session state, read by select_pointcloud_frame and get_download_link).
all_pointcloud_paths = gr.State([])
# Run the selected task. Gradio passes every component in `inputs` to the
# lambda positionally, so the lambda's parameter order must match `inputs`
# exactly. The lambda then forwards only the subset the chosen task needs
# to process_task.
run_button.click(
fn=lambda task_type,
video_file,
image_file,
image_input_planning,
goal_file,
height,
width,
num_frames,
num_inference_steps,
guidance_scale,
sliding_window_stride,
use_dynamic_cfg,
raymap_option,
post_reconstruction,
fps,
smooth_camera,
align_pointmaps,
max_depth,
rtol,
pointcloud_save_frame_interval,
seed: process_task(
task_type,
*(
# "reconstruction": the video plus sliding_window_stride; presumably
# matches process_reconstruction's positional signature — confirm.
[
video_file,
height,
width,
num_frames,
num_inference_steps,
guidance_scale,
sliding_window_stride,
fps,
smooth_camera,
align_pointmaps,
max_depth,
rtol,
pointcloud_save_frame_interval,
seed,
]
if task_type == "reconstruction"
# "prediction": the start image plus dynamic CFG, camera action
# (raymap_option) and post-reconstruction.
else [
image_file,
height,
width,
num_frames,
num_inference_steps,
guidance_scale,
use_dynamic_cfg,
raymap_option,
post_reconstruction,
fps,
smooth_camera,
align_pointmaps,
max_depth,
rtol,
pointcloud_save_frame_interval,
seed,
]
if task_type == "prediction"
# Any other task ("planning"): start and goal images and no raymap_option.
# For unknown tasks process_task ignores these args and returns an empty result.
else [
image_input_planning,
goal_file,
height,
width,
num_frames,
num_inference_steps,
guidance_scale,
use_dynamic_cfg,
post_reconstruction,
fps,
smooth_camera,
align_pointmaps,
max_depth,
rtol,
pointcloud_save_frame_interval,
seed,
]
),
),
inputs=[
task,
video_input,
image_input,
image_input_planning,
goal_input,
height,
width,
num_frames,
num_inference_steps,
guidance_scale,
sliding_window_stride,
use_dynamic_cfg,
raymap_option,
post_reconstruction,
fps,
smooth_camera,
align_pointmaps,
max_depth,
rtol,
pointcloud_save_frame_interval,
seed,
],
# Must stay in the same order as process_task's returned 6-tuple.
outputs=[
rgb_output,
depth_output,
model_output,
pointcloud_frames,
all_pointcloud_paths,
pointcloud_download,
],
)
# Keep the 3D viewer and the download button in sync with the selected
# keyframe: first swap the viewer's model, then refresh the download link.
# Gradio's .change also fires when the run handler sets the dropdown value.
pointcloud_frames.change(
fn=select_pointcloud_frame,
inputs=[pointcloud_frames, all_pointcloud_paths],
outputs=[model_output],
).then(
fn=get_download_link,
inputs=[pointcloud_frames, all_pointcloud_paths],
outputs=[pointcloud_download],
)
# Load the model at startup
# NOTE(review): Blocks.load fires each time a client loads the page, not once
# per process, and the global `pipeline` is already constructed at import time.
# Confirm that build_pipeline is idempotent and cheap; otherwise every page
# visit rebuilds it.
demo.load(lambda: build_pipeline(torch.device("cpu")), inputs=None, outputs=None)
if __name__ == "__main__":
# Set the tokenizers-library TOKENIZERS_PARALLELISM env var before serving.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Queue requests (at most 20 waiting) and show errors in the UI.
# NOTE(review): share=True requests a public share link when run locally
# (Gradio ignores it on HF Spaces). Confirm that this exposure is intended
# for local deployments.
demo.queue(max_size=20).launch(show_error=True, share=True)