from typing import List, Literal
from pathlib import Path
from functools import partial

import spaces
import gradio as gr
import numpy as np
import torch
from torchvision.datasets.utils import download_and_extract_archive
from einops import repeat
from omegaconf import OmegaConf
from scipy.spatial.transform import Rotation, Slerp

from algorithms.dfot import DFoTVideoPose
from history_guidance import HistoryGuidance
from utils.ckpt_utils import download_pretrained
from datasets.video.utils.io import read_video
from export import export_to_video, export_to_gif, export_images_to_gif
from camera_pose import extend_poses, CameraPose

DATASET_URL = "https://huggingface.co/kiwhansong/DFoT/resolve/main/datasets/RealEstate10K_Tiny.tar.gz"
DATASET_DIR = Path("data/real-estate-10k-tiny")
LONG_LENGTH = 10  # length of the generated long videos (Demo 1), in seconds
NAVIGATION_FPS = 3  # playback FPS for navigation videos (Demo 3)

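# Download the tiny RealEstate10K test subset (videos + camera poses) on first launch.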
if not DATASET_DIR.exists():
    DATASET_DIR.mkdir(parents=True)
    download_and_extract_archive(
        DATASET_URL,
        DATASET_DIR.parent,
        remove_finished=True,
    )

metadata = torch.load(DATASET_DIR / "metadata" / "test.pt", weights_only=False)
video_list = [
    read_video(path).permute(0, 3, 1, 2) / 255.0 for path in metadata["video_paths"]
]
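# Each raw pose row appears to store 4 intrinsics, 2 unused entries, and a
# flattened 3x4 extrinsic matrix; the slicing below keeps intrinsics +
# extrinsics as a 16-dim condition vector.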
poses_list = [
    torch.cat([poses[:, :4], poses[:, 6:]], dim=-1).to(torch.float32)
    for poses in (
        torch.load(DATASET_DIR / "test_poses" / f"{path.stem}.pt")
        for path in metadata["video_paths"]
    )
]

first_frame_list = [
    (video[0] * 255).permute(1, 2, 0).numpy().clip(0, 255).astype("uint8")
    for video in video_list
]
gif_paths = []
for video in video_list:
    indices = torch.linspace(0, video.size(0) - 1, 16, dtype=torch.long)
    gif_paths.append(export_to_gif(video[indices], fps=8))

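# Load the pretrained DFoT model once at startup and keep it resident on the GPU.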
dfot = DFoTVideoPose.load_from_checkpoint(
    checkpoint_path=download_pretrained("pretrained:DFoT_RE10K.ckpt"),
    map_location="cpu",
    cfg=OmegaConf.load("config.yaml"),
).eval()
dfot.to("cuda")

def prepare_long_gt_video(idx: int):
    # Subsample 200 evenly spaced frames so the ground-truth clip spans LONG_LENGTH seconds.
    video = video_list[idx]
    indices = torch.linspace(0, video.size(0) - 1, 200, dtype=torch.long)
    return export_to_video(video[indices], fps=200 // LONG_LENGTH)

def prepare_short_gt_video(idx: int):
    video = video_list[idx]
    indices = torch.linspace(0, video.size(0) - 1, 8, dtype=torch.long)
    video = (
        (video[indices].permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).numpy()
    )
    return [video[i] for i in range(video.shape[0])]

def video_to_gif_and_images(video, indices):
    # Black out frames outside `indices`, then return a GIF of the masked
    # sequence followed by the individual (image, caption) pairs.
    masked_video = [
        image if i in indices else np.zeros_like(image) for i, image in enumerate(video)
    ]
    return [(export_images_to_gif(masked_video), "GIF")] + [
        (image, f"t={i}" if i in indices else "")
        for i, image in enumerate(masked_video)
    ]

def get_duration_single_image_to_long_video(
    idx: int, guidance_scale: float, fps: int, progress: gr.Progress
):
    # ZeroGPU time budget: ~3 seconds per generated frame (LONG_LENGTH * fps frames).
    return 30 * fps

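# Demo 1 backend. Setting keyframe_density appears to make the sliding-window
# rollout sample ~12 keyframes out of the fps * LONG_LENGTH total frames, with
# the remaining frames filled in by temporal super-resolution.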
@spaces.GPU(duration=get_duration_single_image_to_long_video)
@torch.autocast("cuda")
@torch.no_grad()
def single_image_to_long_video(
    idx: int, guidance_scale: float, fps: int, progress=gr.Progress(track_tqdm=True)
):
    video = video_list[idx]
    poses = poses_list[idx]
    indices = torch.linspace(0, video.size(0) - 1, LONG_LENGTH * fps, dtype=torch.long)
    xs = video[indices].unsqueeze(0).to("cuda")
    conditions = poses[indices].unsqueeze(0).to("cuda")
    dfot.cfg.tasks.prediction.history_guidance.guidance_scale = guidance_scale
    dfot.cfg.tasks.prediction.keyframe_density = 12 / (fps * LONG_LENGTH)

    gen_video = dfot._unnormalize_x(
        dfot._predict_videos(
            dfot._normalize_x(xs),
            conditions,
        )
    )
    return export_to_video(gen_video[0].detach().cpu(), fps=fps)

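# Demo 2 backend: all 8 frames are sampled jointly; context_mask marks which
# frames are given, so any subset of the 8 frames can serve as context.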
@spaces.GPU(duration=30)
@torch.autocast("cuda")
@torch.no_grad()
def any_images_to_short_video(
    scene_idx: int,
    image_indices: List[int],
    guidance_scale: float,
):
    video = video_list[scene_idx]
    poses = poses_list[scene_idx]
    indices = torch.linspace(0, video.size(0) - 1, 8, dtype=torch.long)
    xs = video[indices].unsqueeze(0).to("cuda")
    conditions = poses[indices].unsqueeze(0).to("cuda")
    pbar = CustomProgressBar(
        gr.Progress(track_tqdm=True).tqdm(
            iterable=None,
            desc="Sampling with DFoT",
            total=dfot.sampling_timesteps,
        )
    )
    gen_video = dfot._unnormalize_x(
        dfot._sample_sequence(
            batch_size=1,
            context=dfot._normalize_x(xs),
            context_mask=torch.tensor([i in image_indices for i in range(8)])
            .unsqueeze(0)
            .to("cuda"),
            conditions=conditions,
            history_guidance=HistoryGuidance.vanilla(
                guidance_scale=guidance_scale,
                visualize=False,
            ),
            pbar=pbar,
        )[0]
    )
    gen_video = (
        (gen_video[0].detach().cpu().permute(0, 2, 3, 1) * 255)
        .clamp(0, 255)
        .to(torch.uint8)
        .numpy()
    )
    return video_to_gif_and_images(list(gen_video), list(range(8)))


class CustomProgressBar:
    """Minimal wrapper that lets Gradio's tqdm stand in for a real tqdm bar.

    The sampler calls ``set_postfix``, which Gradio's progress bar does not
    need, so it is ignored; every other attribute is delegated.
    """

    def __init__(self, pbar):
        self.pbar = pbar

    def set_postfix(self, **kwargs):
        pass

    def __getattr__(self, attr):
        return getattr(self.pbar, attr)

def get_duration_navigate_video(
    video: torch.Tensor,
    poses: torch.Tensor,
    x_angle: float,
    y_angle: float,
    distance: float,
):
    # Small camera movements get a larger ZeroGPU budget, presumably because
    # HistoryGuidance.smart selects a costlier guidance scheme for them.
    if abs(x_angle) < 30 and abs(y_angle) < 30 and distance < 150:
        return 45
    return 30

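# Demo 3 backend: each call extends the trajectory by 8 - n_context_frames new
# frames, conditioning on up to the last 4 frames and on camera poses extended
# along the requested direction.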
@spaces.GPU(duration=get_duration_navigate_video)
@torch.autocast("cuda")
@torch.no_grad()
def navigate_video(
    video: torch.Tensor,
    poses: torch.Tensor,
    x_angle: float,
    y_angle: float,
    distance: float,
):
    n_context_frames = min(len(video), 4)
    n_prediction_frames = 8 - n_context_frames
    pbar = CustomProgressBar(
        gr.Progress(track_tqdm=True).tqdm(
            iterable=None,
            desc=f"Predicting next {n_prediction_frames} frames with DFoT",
            total=dfot.sampling_timesteps,
        )
    )
    xs = dfot._normalize_x(video.clone().unsqueeze(0).to("cuda"))
    conditions = poses.clone().unsqueeze(0).to("cuda")
    conditions = extend_poses(
        conditions,
        n=n_prediction_frames,
        x_angle=x_angle,
        y_angle=y_angle,
        distance=distance,
    )
    # Context mask: nonzero marks given frames, 0 marks frames to predict
    # (the 1 vs. 2 values appear to distinguish single-frame from multi-frame
    # context for history guidance).
    context_mask = (
        torch.cat(
            [
                torch.ones(1, n_context_frames) * (1 if n_context_frames == 1 else 2),
                torch.zeros(1, n_prediction_frames),
            ],
            dim=-1,
        )
        .long()
        .to("cuda")
    )
    next_video = (
        dfot._unnormalize_x(
            dfot._sample_sequence(
                batch_size=1,
                context=torch.cat(
                    [
                        xs[:, -n_context_frames:],
                        torch.zeros(
                            1,
                            n_prediction_frames,
                            *xs.shape[2:],
                            device=xs.device,
                            dtype=xs.dtype,
                        ),
                    ],
                    dim=1,
                ),
                context_mask=context_mask,
                conditions=conditions[:, -8:],
                history_guidance=HistoryGuidance.smart(
                    x_angle=x_angle,
                    y_angle=y_angle,
                    distance=distance,
                    visualize=False,
                ),
                pbar=pbar,
            )[0]
        )[0][n_context_frames:]
        .detach()
        .cpu()
    )
    gen_video = torch.cat([video, next_video], dim=0)
    poses = conditions[0].detach().cpu()

    images = (gen_video.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).numpy()

    return (
        gen_video,
        poses,
        images[-1],
        export_to_video(gen_video, fps=NAVIGATION_FPS),
        [(image, f"t={i}") for i, image in enumerate(images)],
    )


def undo_navigation(
    video: torch.Tensor,
    poses: torch.Tensor,
):
    # Each navigation step appends 4 frames (7 for the very first step), so
    # undo drops 4 frames, or falls back to the initial frame at length 8.
    if len(video) > 8:
        video = video[:-4]
        poses = poses[:-4]
    elif len(video) == 8:
        video = video[:1]
        poses = poses[:1]
    else:
        gr.Warning("You have no moves left to undo!")
    images = (video.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).numpy()
    return (
        video,
        poses,
        images[-1],
        export_to_video(video, fps=NAVIGATION_FPS),
        [(image, f"t={i}") for i, image in enumerate(images)],
    )

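# Pose interpolation helper: rotations are slerped on SO(3), translations
# interpolated linearly.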
def _interpolate_conditions(conditions, indices):
    """
    Interpolate camera-pose conditions to fill in missing frames.

    Args:
        conditions (Tensor): conditions (B, T, C)
        indices (Tensor): indices of keyframes (T')
    """
    assert indices[0].item() == 0
    assert indices[-1].item() == conditions.shape[1] - 1

    indices = indices.cpu().numpy()
    batch_size, n_tokens, _ = conditions.shape
    t = np.linspace(0, n_tokens - 1, n_tokens)

    key_conditions = conditions[:, indices]
    poses = CameraPose.from_vectors(key_conditions)
    extrinsics = poses.extrinsics().cpu().numpy()
    ps = extrinsics[..., :3, 3]
    rs = extrinsics[..., :3, :3].reshape(batch_size, -1, 3, 3)

    interp_extrinsics = np.zeros((batch_size, n_tokens, 3, 4))
    for i in range(batch_size):
        # Spherical linear interpolation for rotations, linear for translations.
        slerp = Slerp(indices, Rotation.from_matrix(rs[i]))
        interp_extrinsics[i, :, :3, :3] = slerp(t).as_matrix()
        for j in range(3):
            interp_extrinsics[i, :, j, 3] = np.interp(t, indices, ps[i, :, j])
    interp_extrinsics = torch.from_numpy(interp_extrinsics.astype(np.float32))
    interp_extrinsics = interp_extrinsics.to(conditions.device).flatten(2)
    # Intrinsics (first 4 channels) are taken from the first keyframe.
    conditions = repeat(key_conditions[:, 0, :4], "b c -> b t c", t=n_tokens)
    conditions = torch.cat([conditions.clone(), interp_extrinsics], dim=-1)

    return conditions

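# Temporal super-resolution: the existing frames are placed every
# `interpolation_factor` steps as context, and DFoT fills in the frames between.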
def _interpolate_between(
    xs: torch.Tensor,
    conditions: torch.Tensor,
    interpolation_factor: int,
    progress=gr.Progress(track_tqdm=True),
):
    length = xs.shape[1]
    final_length = (length - 1) * interpolation_factor + 1
    x_shape = xs.shape[2:]
    context = torch.zeros(
        (1, final_length, *x_shape),
        device=xs.device,
        dtype=xs.dtype,
    )
    long_conditions = torch.zeros(
        (1, final_length, *conditions.shape[2:]),
        device=conditions.device,
        dtype=conditions.dtype,
    )
    context_mask = torch.zeros(
        (1, final_length),
        device=xs.device,
        dtype=torch.bool,
    )
    context_indices = torch.arange(
        0, final_length, interpolation_factor, device=conditions.device
    )
    context[:, context_indices] = xs
    long_conditions[:, context_indices] = conditions
    context_mask[:, ::interpolation_factor] = True
    long_conditions = _interpolate_conditions(
        long_conditions,
        context_indices,
    )

    xs = dfot._interpolate_videos(
        context,
        context_mask,
        conditions=long_conditions,
    )
    return xs, long_conditions

def get_duration_smooth_navigation(
    video: torch.Tensor,
    poses: torch.Tensor,
    interpolation_factor: int,
    progress: gr.Progress,
):
    # ZeroGPU time budget: ~2 seconds per frame of the interpolated video.
    length = (len(video) - 1) * interpolation_factor + 1
    return 2 * length


@spaces.GPU(duration=get_duration_smooth_navigation)
@torch.autocast("cuda")
@torch.no_grad()
def smooth_navigation(
    video: torch.Tensor,
    poses: torch.Tensor,
    interpolation_factor: int,
    progress=gr.Progress(track_tqdm=True),
):
    if len(video) < 8:
        gr.Warning("Navigate first before applying temporal super-resolution!")
    else:
        video, poses = _interpolate_between(
            dfot._normalize_x(video.clone().unsqueeze(0).to("cuda")),
            poses.clone().unsqueeze(0).to("cuda"),
            interpolation_factor,
        )
        video = dfot._unnormalize_x(video)[0].detach().cpu()
        poses = poses[0].detach().cpu()
    images = (video.permute(0, 2, 3, 1) * 255).clamp(0, 255).to(torch.uint8).numpy()
    return (
        video,
        poses,
        images[-1],
        export_to_video(video, fps=NAVIGATION_FPS * interpolation_factor),
        [(image, f"t={i}") for i, image in enumerate(images)],
    )

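# --- UI ---
# Each demo below is a small state machine re-rendered via gr.render: `s` is
# the current stage, and the gr.State objects hold selections that persist
# across re-renders.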
def render_demo1(
    s: Literal["Selection", "Generation"],
    idx: int,
    demo1_stage: gr.State,
    demo1_selected_index: gr.State,
):
    gr.Markdown(
        f"""
        ## Demo 1: Single Image → Long {LONG_LENGTH}-second Video
        > #### _Diffusion Forcing Transformer can generate long videos via sliding-window rollouts and temporal super-resolution._
        """,
        elem_classes=["task-title"],
    )
    match s:
        case "Selection":
            with gr.Group():
                demo1_image_gallery = gr.Gallery(
                    height=300,
                    value=first_frame_list,
                    label="Select an Image to Animate",
                    columns=[8],
                    selected_index=idx,
                    allow_preview=False,
                    preview=False,
                )

                @demo1_image_gallery.select(
                    inputs=None, outputs=[demo1_stage, demo1_selected_index]
                )
                def move_to_generation(selection: gr.SelectData):
                    return "Generation", selection.index

        case "Generation":
            with gr.Row():
                gr.Image(
                    value=first_frame_list[idx],
                    label="Input Image",
                    width=256,
                    height=256,
                )
                gr.Video(
                    value=prepare_long_gt_video(idx),
                    label="Ground Truth Video",
                    width=256,
                    height=256,
                    autoplay=True,
                    loop=True,
                )
                demo1_video = gr.Video(
                    label="Generated Video",
                    width=256,
                    height=256,
                    autoplay=True,
                    loop=True,
                    show_share_button=True,
                    show_download_button=True,
                )

            gr.Markdown("### Generation Controls ↓")
            demo1_guidance_scale = gr.Slider(
                minimum=1,
                maximum=6,
                value=4,
                step=0.5,
                label="History Guidance Scale",
                info="Without history guidance: 1.0; Recommended: 4.0",
                interactive=True,
            )
            demo1_fps = gr.Slider(
                minimum=4,
                maximum=20,
                value=4,
                step=1,
                label="FPS",
                info=f"A {LONG_LENGTH}-second video will be generated at this FPS; decrease for faster generation, or increase for a smoother video",
                interactive=True,
            )
            gr.Button("Generate Video", variant="primary").click(
                fn=single_image_to_long_video,
                inputs=[
                    demo1_selected_index,
                    demo1_guidance_scale,
                    demo1_fps,
                ],
                outputs=demo1_video,
            )


def render_demo2(
    s: Literal["Scene", "Image", "Generation"],
    scene_idx: int,
    image_indices: List[int],
    demo2_stage: gr.State,
    demo2_selected_scene_index: gr.State,
    demo2_selected_image_indices: gr.State,
):
    gr.Markdown(
        """
        ## Demo 2: Any Number of Images → Short 2-second Video
        > #### _Diffusion Forcing Transformer is a flexible model that can generate videos given a variable number of context frames._
        """,
        elem_classes=["task-title"],
    )

    match s:
        case "Scene":
            with gr.Group():
                demo2_scene_gallery = gr.Gallery(
                    height=300,
                    value=gif_paths,
                    label="Select a Scene to Generate Video",
                    columns=[8],
                    selected_index=scene_idx,
                    allow_preview=False,
                    preview=False,
                )

                @demo2_scene_gallery.select(
                    inputs=None, outputs=[demo2_stage, demo2_selected_scene_index]
                )
                def move_to_image_selection(selection: gr.SelectData):
                    return "Image", selection.index

        case "Image":
            with gr.Group():
                demo2_image_gallery = gr.Gallery(
                    height=150,
                    value=[
                        (image, f"t={i}")
                        for i, image in enumerate(prepare_short_gt_video(scene_idx))
                    ],
                    label="Select Input Images",
                    columns=[8],
                )

                demo2_selector = gr.CheckboxGroup(
                    label="Select Any Number of Input Images",
                    info="Image-to-Video: Select t=0; Interpolation: Select t=0 and t=7",
                    choices=[(f"t={i}", i) for i in range(8)],
                    value=[],
                )
                demo2_image_select_button = gr.Button("Next Step", variant="primary")

                @demo2_image_select_button.click(
                    inputs=[demo2_selector],
                    outputs=[demo2_stage, demo2_selected_image_indices],
                )
                def generate_video(selected_indices):
                    if len(selected_indices) == 0:
                        gr.Warning("Select at least one image!")
                        return "Image", []
                    else:
                        return "Generation", selected_indices

        case "Generation":
            with gr.Group():
                gt_video = prepare_short_gt_video(scene_idx)

                demo2_input_image_gallery = gr.Gallery(
                    height=150,
                    value=video_to_gif_and_images(gt_video, image_indices),
                    label="Input Images",
                    columns=[9],
                )
                demo2_generated_gallery = gr.Gallery(
                    height=150,
                    value=[],
                    label="Generated Video",
                    columns=[9],
                )
                demo2_ground_truth_gallery = gr.Gallery(
                    height=150,
                    value=video_to_gif_and_images(gt_video, list(range(8))),
                    label="Ground Truth Video",
                    columns=[9],
                )
            gr.Markdown("### Generation Controls ↓")
            demo2_guidance_scale = gr.Slider(
                minimum=1,
                maximum=6,
                value=4,
                step=0.5,
                label="History Guidance Scale",
                info="Without history guidance: 1.0; Recommended: 4.0",
                interactive=True,
            )
            gr.Button("Generate Video", variant="primary").click(
                fn=any_images_to_short_video,
                inputs=[
                    demo2_selected_scene_index,
                    demo2_selected_image_indices,
                    demo2_guidance_scale,
                ],
                outputs=demo2_generated_gallery,
            )

def render_demo3(
    s: Literal["Selection", "Generation"],
    idx: int,
    demo3_stage: gr.State,
    demo3_selected_index: gr.State,
    demo3_current_video: gr.State,
    demo3_current_poses: gr.State,
):
    gr.Markdown(
        """
        ## Demo 3: Single Image → Extremely Long Video _(Navigate with Your Camera Movements!)_
        > #### _History Guidance significantly improves quality and temporal consistency, enabling stable rollouts for extremely long videos._
        """,
        elem_classes=["task-title"],
    )
    match s:
        case "Selection":
            with gr.Group():
                demo3_image_gallery = gr.Gallery(
                    height=300,
                    value=first_frame_list,
                    label="Select an Image to Start Navigation",
                    columns=[8],
                    selected_index=idx,
                    allow_preview=False,
                    preview=False,
                )

                @demo3_image_gallery.select(
                    inputs=None,
                    outputs=[
                        demo3_stage,
                        demo3_selected_index,
                        demo3_current_video,
                        demo3_current_poses,
                    ],
                )
                def move_to_generation(selection: gr.SelectData):
                    idx = selection.index
                    return (
                        "Generation",
                        idx,
                        video_list[idx][:1],
                        poses_list[idx][:1],
                    )

        case "Generation":
            with gr.Row():
                with gr.Column(scale=3):
                    with gr.Row():
                        demo3_current_view = gr.Image(
                            value=first_frame_list[idx],
                            label="Current View",
                            width=256,
                            height=256,
                        )
                        demo3_video = gr.Video(
                            label="Generated Video",
                            width=256,
                            height=256,
                            autoplay=True,
                            loop=True,
                            show_share_button=True,
                            show_download_button=True,
                        )

                    demo3_generated_gallery = gr.Gallery(
                        value=[],
                        label="Generated Frames",
                        columns=[6],
                    )

                with gr.Column():
                    gr.Markdown("### Navigation Controls ↓")
                    with gr.Accordion("Instructions", open=False):
                        gr.Markdown("""
                            - **The model predicts the next few frames from your camera movements; repeat the process to keep navigating through the scene.**
                            - **When you finish navigating, apply temporal super-resolution to increase the FPS,** which also uses the DFoT model.
                            - The most suitable history guidance scheme is selected automatically based on your camera movements.
                        """)
with gr.Tab("Basic", elem_id="basic-controls-tab"): |
|
with gr.Group(): |
|
gr.Markdown("_**Select a direction to move:**_") |
|
with gr.Row(elem_id="basic-controls"): |
|
gr.Button( |
|
"↰-60°\nVeer", |
|
size="sm", |
|
min_width=0, |
|
variant="primary", |
|
).click( |
|
fn=partial( |
|
navigate_video, |
|
x_angle=0, |
|
y_angle=-60, |
|
distance=0, |
|
), |
|
inputs=[ |
|
demo3_current_video, |
|
demo3_current_poses, |
|
], |
|
outputs=[ |
|
demo3_current_video, |
|
demo3_current_poses, |
|
demo3_current_view, |
|
demo3_video, |
|
demo3_generated_gallery, |
|
], |
|
) |
|
|
|
gr.Button( |
|
"↖-30°\nTurn", |
|
size="sm", |
|
min_width=0, |
|
variant="primary", |
|
).click( |
|
fn=partial( |
|
navigate_video, |
|
x_angle=0, |
|
y_angle=-30, |
|
distance=50, |
|
), |
|
inputs=[ |
|
demo3_current_video, |
|
demo3_current_poses, |
|
], |
|
outputs=[ |
|
demo3_current_video, |
|
demo3_current_poses, |
|
demo3_current_view, |
|
demo3_video, |
|
demo3_generated_gallery, |
|
], |
|
) |
|
|
|
gr.Button( |
|
"↑0°\nAhead", |
|
size="sm", |
|
min_width=0, |
|
variant="primary", |
|
).click( |
|
fn=partial( |
|
navigate_video, |
|
x_angle=0, |
|
y_angle=0, |
|
distance=100, |
|
), |
|
inputs=[ |
|
demo3_current_video, |
|
demo3_current_poses, |
|
], |
|
outputs=[ |
|
demo3_current_video, |
|
demo3_current_poses, |
|
demo3_current_view, |
|
demo3_video, |
|
demo3_generated_gallery, |
|
], |
|
) |
|
gr.Button( |
|
"↗30°\nTurn", |
|
size="sm", |
|
min_width=0, |
|
variant="primary", |
|
).click( |
|
fn=partial( |
|
navigate_video, |
|
x_angle=0, |
|
y_angle=30, |
|
distance=50, |
|
), |
|
inputs=[ |
|
demo3_current_video, |
|
demo3_current_poses, |
|
], |
|
outputs=[ |
|
demo3_current_video, |
|
demo3_current_poses, |
|
demo3_current_view, |
|
demo3_video, |
|
demo3_generated_gallery, |
|
], |
|
) |
|
gr.Button( |
|
"↱\n60° Veer", |
|
size="sm", |
|
min_width=0, |
|
variant="primary", |
|
).click( |
|
fn=partial( |
|
navigate_video, |
|
x_angle=0, |
|
y_angle=60, |
|
distance=0, |
|
), |
|
inputs=[ |
|
demo3_current_video, |
|
demo3_current_poses, |
|
], |
|
outputs=[ |
|
demo3_current_video, |
|
demo3_current_poses, |
|
demo3_current_view, |
|
demo3_video, |
|
demo3_generated_gallery, |
|
], |
|
) |
|
with gr.Tab("Advanced", elem_id="advanced-controls-tab"): |
|
with gr.Group(): |
|
gr.Markdown("_**Select angles and distance:**_") |
|
|
|
demo3_y_angle = gr.Slider( |
|
minimum=-90, |
|
maximum=90, |
|
value=0, |
|
step=10, |
|
label="Horizontal Angle", |
|
interactive=True, |
|
) |
|
demo3_x_angle = gr.Slider( |
|
minimum=-40, |
|
maximum=40, |
|
value=0, |
|
step=10, |
|
label="Vertical Angle", |
|
interactive=True, |
|
) |
|
demo3_distance = gr.Slider( |
|
minimum=0, |
|
maximum=200, |
|
value=100, |
|
step=10, |
|
label="Distance", |
|
interactive=True, |
|
) |
|
|
|
gr.Button( |
|
"Generate Next Move", variant="primary" |
|
).click( |
|
fn=navigate_video, |
|
inputs=[ |
|
demo3_current_video, |
|
demo3_current_poses, |
|
demo3_x_angle, |
|
demo3_y_angle, |
|
demo3_distance, |
|
], |
|
outputs=[ |
|
demo3_current_video, |
|
demo3_current_poses, |
|
demo3_current_view, |
|
demo3_video, |
|
demo3_generated_gallery, |
|
], |
|
) |
|
gr.Markdown("---") |
|
with gr.Group(): |
|
gr.Markdown("_You can always undo your last move:_") |
|
gr.Button("Undo Last Move", variant="huggingface").click( |
|
fn=undo_navigation, |
|
inputs=[demo3_current_video, demo3_current_poses], |
|
outputs=[ |
|
demo3_current_video, |
|
demo3_current_poses, |
|
demo3_current_view, |
|
demo3_video, |
|
demo3_generated_gallery, |
|
], |
|
) |
|
with gr.Group(): |
|
gr.Markdown( |
|
"_At the end, apply temporal super-resolution to obtain a smoother video:_" |
|
) |
|
demo3_interpolation_factor = gr.Slider( |
|
minimum=2, |
|
maximum=10, |
|
value=2, |
|
step=1, |
|
label="By a Factor of", |
|
interactive=True, |
|
) |
|
gr.Button("Smooth Out Video", variant="huggingface").click( |
|
fn=smooth_navigation, |
|
inputs=[ |
|
demo3_current_video, |
|
demo3_current_poses, |
|
demo3_interpolation_factor, |
|
], |
|
outputs=[ |
|
demo3_current_video, |
|
demo3_current_poses, |
|
demo3_current_view, |
|
demo3_video, |
|
demo3_generated_gallery, |
|
], |
|
) |
|
|
|
|
|
|
|
|
|
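# Top-level app: a sidebar with links and demo selection, plus a gr.render
# region that redraws the active demo whenever any state value changes.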
with gr.Blocks(theme=gr.themes.Base(primary_hue="teal")) as demo:
    gr.HTML(
        """
        <style>
            [data-tab-id="task-1"], [data-tab-id="task-2"], [data-tab-id="task-3"] {
                font-size: 16px !important;
                font-weight: bold;
            }
            #page-title h1 {
                color: #0D9488 !important;
            }
            .task-title h2 {
                color: #F59E0C !important;
            }
            .header-button-row {
                gap: 4px !important;
            }
            .header-button-row div {
                width: 131px !important;
            }
            .header-button-column {
                width: 131px !important;
                gap: 5px !important;
            }
            .header-button a {
                border: 1px solid #e4e4e7;
            }
            .header-button .button-icon {
                margin-right: 8px;
            }
            .demo-button-column .gap {
                gap: 5px !important;
            }
            #basic-controls {
                column-gap: 0px;
            }
            #basic-controls-tab {
                padding: 0px;
            }
            #advanced-controls-tab {
                padding: 0px;
            }
            #selected-demo-button {
                color: #F59E0C;
                text-decoration: underline;
            }
            .demo-button {
                text-align: left !important;
                display: block !important;
            }
        </style>
        """
    )

    demo_idx = gr.State(value=1)

    with gr.Sidebar():
        gr.Markdown(
            "# Diffusion Forcing Transformer with History Guidance",
            elem_id="page-title",
        )
        gr.Markdown(
            "### Official Interactive Demo for [_History-Guided Video Diffusion_](https://arxiv.org/abs/2502.06764)"
        )
        gr.Markdown("---")
        gr.Markdown("#### Links ↓")
        with gr.Row(elem_classes=["header-button-row"]):
            with gr.Column(elem_classes=["header-button-column"], min_width=0):
                gr.Button(
                    value="Website",
                    link="https://boyuan.space/history-guidance",
                    icon="https://simpleicons.org/icons/googlechrome.svg",
                    elem_classes=["header-button"],
                    size="md",
                    min_width=0,
                )
                gr.Button(
                    value="Paper",
                    link="https://arxiv.org/abs/2502.06764",
                    icon="https://simpleicons.org/icons/arxiv.svg",
                    elem_classes=["header-button"],
                    size="md",
                    min_width=0,
                )
            with gr.Column(elem_classes=["header-button-column"], min_width=0):
                gr.Button(
                    value="Code",
                    link="https://github.com/kwsong0113/diffusion-forcing-transformer",
                    icon="https://simpleicons.org/icons/github.svg",
                    elem_classes=["header-button"],
                    size="md",
                    min_width=0,
                )
                gr.Button(
                    value="Weights",
                    link="https://huggingface.co/kiwhansong/DFoT",
                    icon="https://simpleicons.org/icons/huggingface.svg",
                    elem_classes=["header-button"],
                    size="md",
                    min_width=0,
                )
        gr.Markdown("---")
        gr.Markdown("#### Choose a Demo ↓")
with gr.Column(elem_classes=["demo-button-column"]): |
|
@gr.render(inputs=[demo_idx]) |
|
def render_demo_tabs(idx): |
|
demo_tab_button1 = gr.Button( |
|
"1: Image → Long Video", |
|
size="md", elem_classes=["demo-button"], **{"elem_id": "selected-demo-button"} if idx == 1 else {} |
|
).click( |
|
fn=lambda: 1, |
|
outputs=demo_idx |
|
) |
|
demo_tab_button2 = gr.Button( |
|
"2: Any # of Images → Short Video", |
|
size="md", elem_classes=["demo-button"], **{"elem_id": "selected-demo-button"} if idx == 2 else {} |
|
).click( |
|
fn=lambda: 2, |
|
outputs=demo_idx |
|
) |
|
demo_tab_button3 = gr.Button( |
|
"3: Image → Extremely Long Video", |
|
size="md", elem_classes=["demo-button"], **{"elem_id": "selected-demo-button"} if idx == 3 else {} |
|
).click( |
|
fn=lambda: 3, |
|
outputs=demo_idx |
|
) |
|
gr.Markdown("---") |
|
gr.Markdown("#### Troubleshooting ↓") |
|
with gr.Group(): |
|
with gr.Accordion("Error or Unexpected Results?", open=False): |
|
gr.Markdown("Please try again after refreshing the page and ensure you do not click the same button multiple times.") |
|
with gr.Accordion("Too Slow or No GPU Allocation?", open=False): |
|
gr.Markdown( |
|
"Consider running the demo locally (click the dots in the top-right corner). Alternatively, you can subscribe to Hugging Face Pro for an increased GPU quota." |
|
) |
|
|
|
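    # Per-demo UI state; the gr.render call below re-runs whenever any of these change.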
    demo1_stage = gr.State(value="Selection")
    demo1_selected_index = gr.State(value=None)
    demo2_stage = gr.State(value="Scene")
    demo2_selected_scene_index = gr.State(value=None)
    demo2_selected_image_indices = gr.State(value=[])
    demo3_stage = gr.State(value="Selection")
    demo3_selected_index = gr.State(value=None)
    demo3_current_video = gr.State(value=None)
    demo3_current_poses = gr.State(value=None)

    @gr.render(
        inputs=[
            demo_idx,
            demo1_stage,
            demo1_selected_index,
            demo2_stage,
            demo2_selected_scene_index,
            demo2_selected_image_indices,
            demo3_stage,
            demo3_selected_index,
        ]
    )
    def render_demo(
        _demo_idx,
        _demo1_stage,
        _demo1_selected_index,
        _demo2_stage,
        _demo2_selected_scene_index,
        _demo2_selected_image_indices,
        _demo3_stage,
        _demo3_selected_index,
    ):
        match _demo_idx:
            case 1:
                render_demo1(
                    _demo1_stage, _demo1_selected_index, demo1_stage, demo1_selected_index
                )
            case 2:
                render_demo2(
                    _demo2_stage,
                    _demo2_selected_scene_index,
                    _demo2_selected_image_indices,
                    demo2_stage,
                    demo2_selected_scene_index,
                    demo2_selected_image_indices,
                )
            case 3:
                render_demo3(
                    _demo3_stage,
                    _demo3_selected_index,
                    demo3_stage,
                    demo3_selected_index,
                    demo3_current_video,
                    demo3_current_poses,
                )

if __name__ == "__main__":
    demo.launch()