import gradio as gr
import spaces
import torch
import numpy as np
import os
import yaml
import random
from PIL import Image
import imageio
from pathlib import Path
from huggingface_hub import hf_hub_download

from ltx_video.pipelines.pipeline_ltx_video import (
    ConditioningItem,
    LTXVideoPipeline,
    LTXMultiScalePipeline,
)
from ltx_video.models.autoencoders.vae_encode import vae_decode, vae_encode, un_normalize_latents, normalize_latents
from inference import (
    create_ltx_video_pipeline,
    create_latent_upsampler,
    load_image_to_tensor_with_resize_and_crop,
    load_media_file,
    get_device,
    seed_everething,
    calculate_padding,
)
from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy
from ltx_video.models.autoencoders.latent_upsampler import LatentUpsampler

from diffusers.utils import export_to_video

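# --- Global setup ---
# Device selection, a local cache directory for downloaded weights, and the
# distilled pipeline config (checkpoint names, sampler, precision, and the
# first/second-pass settings used below) all come from the YAML file.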
DEVICE = get_device()
MODEL_DIR = "downloaded_models"
Path(MODEL_DIR).mkdir(parents=True, exist_ok=True)

YAML_CONFIG_PATH = "configs/ltxv-13b-0.9.7-distilled.yaml"
with open(YAML_CONFIG_PATH, "r") as f:
    PIPELINE_CONFIG_YAML = yaml.safe_load(f)

LTXV_MODEL_FILENAME = PIPELINE_CONFIG_YAML["checkpoint_path"]
SPATIAL_UPSCALER_FILENAME = PIPELINE_CONFIG_YAML["spatial_upscaler_model_path"]
TEXT_ENCODER_PATH = PIPELINE_CONFIG_YAML["text_encoder_model_name_or_path"]

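# Resolve local checkpoint paths, downloading from the Hugging Face Hub on first
# run and reusing the cached copies in MODEL_DIR afterwards. On failure, fall
# back to the raw filenames from the YAML.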
try:
    if not os.path.isfile(os.path.join(MODEL_DIR, LTXV_MODEL_FILENAME)):
        print(f"Downloading {LTXV_MODEL_FILENAME}...")
        ltxv_checkpoint_path = hf_hub_download(
            repo_id="LTX-Colab/LTX-Video-Preview",
            filename=LTXV_MODEL_FILENAME,
            local_dir=MODEL_DIR,
            repo_type="model",
        )
    else:
        ltxv_checkpoint_path = os.path.join(MODEL_DIR, LTXV_MODEL_FILENAME)

    if not os.path.isfile(os.path.join(MODEL_DIR, SPATIAL_UPSCALER_FILENAME)):
        print(f"Downloading {SPATIAL_UPSCALER_FILENAME}...")
        spatial_upsampler_path = hf_hub_download(
            repo_id="Lightricks/LTX-Video",
            filename=SPATIAL_UPSCALER_FILENAME,
            local_dir=MODEL_DIR,
            repo_type="model",
        )
    else:
        spatial_upsampler_path = os.path.join(MODEL_DIR, SPATIAL_UPSCALER_FILENAME)
except Exception as e:
    print(f"Error downloading models: {e}")
    print("Please ensure model files are correctly specified and accessible.")
    ltxv_checkpoint_path = LTXV_MODEL_FILENAME
    spatial_upsampler_path = SPATIAL_UPSCALER_FILENAME

print(f"Using LTX-Video checkpoint: {ltxv_checkpoint_path}")
print(f"Using Spatial Upsampler: {spatial_upsampler_path}")
print(f"Using Text Encoder: {TEXT_ENCODER_PATH}")

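# Build the base LTX-Video pipeline plus the latent upsampler, and wrap both in
# LTXMultiScalePipeline for the two-pass "improve texture" path. Both modules are
# kept in bfloat16 to reduce memory use.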
pipe = create_ltx_video_pipeline(
    ckpt_path=ltxv_checkpoint_path,
    precision=PIPELINE_CONFIG_YAML["precision"],
    text_encoder_model_name_or_path=TEXT_ENCODER_PATH,
    sampler=PIPELINE_CONFIG_YAML["sampler"],
    device=DEVICE,
    enhance_prompt=False,
).to(torch.bfloat16)

latent_upsampler = create_latent_upsampler(
    latent_upsampler_model_path=spatial_upsampler_path,
    device=DEVICE,
)
latent_upsampler = latent_upsampler.to(torch.bfloat16)

multi_scale_pipe = LTXMultiScalePipeline(
    video_pipeline=pipe,
    latent_upsampler=latent_upsampler,
)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048

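# Floor height/width to the nearest multiple of the VAE's spatial scale factor so
# the downscaled first-pass resolution is valid for VAE encoding/decoding.
# Worked example (assuming a spatial factor of 32): 341 -> 341 - (341 % 32) = 320.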
def round_to_nearest_resolution_acceptable_by_vae(height, width, vae_scale_factor):
    height = height - (height % vae_scale_factor)
    width = width - (width % vae_scale_factor)
    return height, width

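# Single entry point for all three tabs; `mode` selects text-, image-, or
# video-to-video conditioning, and `improve_texture` toggles the multi-scale
# two-pass path versus a single pass plus manual latent upsampling.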
@spaces.GPU
def generate(
    prompt,
    negative_prompt,
    image_path,
    video_path,
    height,
    width,
    mode,
    steps,
    num_frames,
    frames_to_use,
    seed,
    randomize_seed,
    guidance_scale,
    improve_texture=False,
    progress=gr.Progress(track_tqdm=True),
):
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    seed_everething(seed)

    generator = torch.Generator(device=DEVICE).manual_seed(seed)

    conditioning_items_list = []
    input_media_for_vid2vid = None

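    # The transformer/VAE operate on sizes that are multiples of the VAE scale
    # factors, so round the requested output up to padded targets and remember the
    # padding so it can be cropped away after decoding.
    # Worked example (assuming spatial factor 32, temporal factor 8):
    #   height=512    -> ((512 - 1) // 32 + 1) * 32 = 512
    #   width=704     -> ((704 - 1) // 32 + 1) * 32 = 704
    #   num_frames=25 -> ((25 - 2) // 8 + 1) * 8 + 1 = 25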
    vae_spatial_scale_factor = pipe.vae.spatial_downscale_factor
    vae_temporal_scale_factor = pipe.vae.temporal_downscale_factor

    height_padded_target = ((height - 1) // vae_spatial_scale_factor + 1) * vae_spatial_scale_factor
    width_padded_target = ((width - 1) // vae_spatial_scale_factor + 1) * vae_spatial_scale_factor
    num_frames_padded_target = ((num_frames - 2) // vae_temporal_scale_factor + 1) * vae_temporal_scale_factor + 1

    padding_target = calculate_padding(height, width, height_padded_target, width_padded_target)

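    # Conditioning: video-to-video feeds the input clip both as media to transform
    # and as a frame-0 conditioning item; image-to-video pads the input image to the
    # padded target and conditions on it at frame 0. Text-to-video adds no
    # conditioning items.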
    if mode == "video-to-video" and video_path:
        input_media_for_vid2vid = load_media_file(
            media_path=video_path,
            height=height,
            width=width,
            max_frames=min(num_frames_padded_target, frames_to_use if frames_to_use > 0 else num_frames_padded_target),
            padding=padding_target,
        )
        conditioning_media = load_media_file(
            media_path=video_path,
            height=height,
            width=width,
            max_frames=min(frames_to_use if frames_to_use > 0 else 1, num_frames_padded_target),
            padding=padding_target,
            just_crop=True,
        )
        conditioning_items_list.append(
            ConditioningItem(media_item=conditioning_media, media_frame_number=0, conditioning_strength=1.0)
        )
    elif mode == "image-to-video" and image_path:
        conditioning_media = load_image_to_tensor_with_resize_and_crop(
            image_input=image_path,
            target_height=height,
            target_width=width,
        )
        conditioning_media = torch.nn.functional.pad(conditioning_media, padding_target)
        conditioning_items_list.append(
            ConditioningItem(media_item=conditioning_media, media_frame_number=0, conditioning_strength=1.0)
        )

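    # Multi-scale settings from the YAML: the first pass renders at a downscaled
    # resolution (downscale_factor, 2/3 by default here) and the second pass refines
    # the upsampled latents. A user-supplied step count overrides the first pass's
    # num_inference_steps; per the UI note, a fixed "timesteps" list in the YAML
    # takes precedence for that pass.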
    first_pass_config = PIPELINE_CONFIG_YAML.get("first_pass", {})
    second_pass_config = PIPELINE_CONFIG_YAML.get("second_pass", {})
    downscale_factor = PIPELINE_CONFIG_YAML.get("downscale_factor", 2 / 3)

    if steps:
        first_pass_config["num_inference_steps"] = steps

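    # First-pass resolution: downscale the padded target and snap it to the VAE
    # grid, e.g. 512x704 at a 2/3 factor gives 341x469, which rounds down to
    # 320x448 when the spatial factor is 32.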
    initial_gen_height = int(height_padded_target * downscale_factor)
    initial_gen_width = int(width_padded_target * downscale_factor)
    initial_gen_height, initial_gen_width = round_to_nearest_resolution_acceptable_by_vae(
        initial_gen_height, initial_gen_width, vae_spatial_scale_factor
    )

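    # Arguments common to both generation paths. num_frames is the padded value;
    # conditioning items and vid2vid media are only passed when the corresponding
    # mode supplied them.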
    shared_pipeline_args = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "num_frames": num_frames_padded_target,
        "frame_rate": 30,
        "guidance_scale": guidance_scale,
        "generator": generator,
        "conditioning_items": conditioning_items_list if conditioning_items_list else None,
        "skip_layer_strategy": SkipLayerStrategy.AttentionValues,
        "offload_to_cpu": False,
        "is_video": True,
        "vae_per_channel_normalize": True,
        "mixed_precision": (PIPELINE_CONFIG_YAML["precision"] == "bfloat16"),
        "enhance_prompt": False,
        "image_cond_noise_scale": 0.025,
        "media_items": input_media_for_vid2vid if mode == "video-to-video" else None,
    }

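    # Two-pass path: LTXMultiScalePipeline generates at the downscaled resolution,
    # upsamples the latents internally, then runs a refinement pass; decode settings
    # default to the values in the YAML.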
    if improve_texture:
        print("Using LTXMultiScalePipeline for generation...")

        if "timesteps" not in first_pass_config:
            first_pass_config["num_inference_steps"] = steps

        first_pass_config.setdefault("decode_timestep", PIPELINE_CONFIG_YAML.get("decode_timestep", 0.05))
        first_pass_config.setdefault("decode_noise_scale", PIPELINE_CONFIG_YAML.get("decode_noise_scale", 0.025))
        second_pass_config.setdefault("decode_timestep", PIPELINE_CONFIG_YAML.get("decode_timestep", 0.05))
        second_pass_config.setdefault("decode_noise_scale", PIPELINE_CONFIG_YAML.get("decode_noise_scale", 0.025))

        result_frames_tensor = multi_scale_pipe(
            **shared_pipeline_args,
            width=initial_gen_width,
            height=initial_gen_height,
            downscale_factor=downscale_factor,
            first_pass=first_pass_config,
            second_pass=second_pass_config,
            output_type="pt",
        ).images

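    # Single-pass path: run the base pipeline at the downscaled resolution to get
    # latents, upsample them 2x spatially with the latent upsampler, then decode
    # with the VAE directly (no second diffusion pass).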
    else:
        print("Using LTXVideoPipeline (first pass) + Manual Upsample + Decode...")

        if "timesteps" not in first_pass_config:
            first_pass_config["num_inference_steps"] = steps

        first_pass_args = {
            **shared_pipeline_args,
            **first_pass_config,
            "width": initial_gen_width,
            "height": initial_gen_height,
            "output_type": "latent",
        }
        latents = pipe(**first_pass_args).images
        print("First pass done!")

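        # The latent upsampler works on raw (un-normalized) VAE latents, so undo the
        # per-channel normalization, upsample, then re-normalize before decoding.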
        latents_unnorm = un_normalize_latents(latents, pipe.vae, vae_per_channel_normalize=True)
        upsampled_latents_unnorm = latent_upsampler(latents_unnorm)
        upsampled_latents = normalize_latents(upsampled_latents_unnorm, pipe.vae, vae_per_channel_normalize=True)

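        # The upsampler doubles the spatial latent resolution, so the decode target
        # is 2x the first-pass height/width; the output frame count is recovered
        # from the temporal latent length.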
        upscaled_height_for_decode = initial_gen_height * 2
        upscaled_width_for_decode = initial_gen_width * 2

        num_video_frames_final = (upsampled_latents.shape[2] - 1) * pipe.vae.temporal_downscale_factor + 1

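        # Build VAE decode arguments; timestep-conditioned decoders also receive a
        # decode timestep, and a small amount of noise is mixed into the latents per
        # the config's decode_noise_scale before decoding.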
        decode_kwargs = {
            "target_shape": (
                upsampled_latents.shape[0],
                3,
                num_video_frames_final,
                upscaled_height_for_decode,
                upscaled_width_for_decode,
            )
        }
        if pipe.vae.decoder.timestep_conditioning:
            decode_kwargs["timestep"] = torch.tensor(
                [PIPELINE_CONFIG_YAML.get("decode_timestep", 0.05)] * upsampled_latents.shape[0]
            ).to(DEVICE)

        noise = torch.randn_like(upsampled_latents)
        decode_noise_val = PIPELINE_CONFIG_YAML.get("decode_noise_scale", 0.025)
        upsampled_latents = upsampled_latents * (1 - decode_noise_val) + noise * decode_noise_val

        print("Decoding upsampled latents with the VAE...")
        result_frames_tensor = pipe.vae.decode(upsampled_latents, **decode_kwargs).sample
        print("VAE decoding done.")

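    # Back at full frame tensors (B, C, F, H, W): trim to the requested frame count,
    # center-crop to the padded target resolution, then strip the padding that was
    # added before generation to recover the exact requested height/width.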
    result_frames_tensor = result_frames_tensor[:, :, :num_frames, :, :]

    _, _, _, current_h, current_w = result_frames_tensor.shape
    crop_y_start = (current_h - height_padded_target) // 2
    crop_x_start = (current_w - width_padded_target) // 2
    result_frames_tensor = result_frames_tensor[
        :, :, :,
        crop_y_start : crop_y_start + height_padded_target,
        crop_x_start : crop_x_start + width_padded_target,
    ]

    pad_left, pad_right, pad_top, pad_bottom = padding_target
    unpad_bottom = -pad_bottom if pad_bottom > 0 else result_frames_tensor.shape[3]
    unpad_right = -pad_right if pad_right > 0 else result_frames_tensor.shape[4]
    result_frames_tensor = result_frames_tensor[
        :, :, :,
        pad_top : unpad_bottom,
        pad_left : unpad_right,
    ]

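    # Convert the first (and only) batch item from [-1, 1] floats to uint8 PIL
    # frames and write them out as an MP4. Note the pipeline was asked for
    # frame_rate=30 while the file is exported at fps=24; adjust one of the two if
    # exact timing matters.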
    video_pil_list = []

    video_single_batch = result_frames_tensor[0]
    video_single_batch = (video_single_batch / 2 + 0.5).clamp(0, 1)
    video_single_batch = video_single_batch.permute(1, 2, 3, 0).cpu().numpy()

    for frame_idx in range(video_single_batch.shape[0]):
        frame_np = (video_single_batch[frame_idx] * 255).astype(np.uint8)
        video_pil_list.append(Image.fromarray(frame_np))

    output_video_path = "output.mp4"
    export_to_video(video_pil_list, output_video_path, fps=24)
    return output_video_path


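# --- Gradio UI ---
# Three tabs (text-, image-, video-to-video) share the single generate() handler;
# hidden Image/Video components fill the inputs a given mode does not use so the
# argument list always matches the function signature.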
css = """
#col-container {
    margin: 0 auto;
    max-width: 900px;
}
"""

with gr.Blocks(css=css, theme=gr.themes.Ocean()) as demo:
    gr.Markdown("# LTX Video 0.9.7 Distilled (using LTX-Video lib)")
    with gr.Row():
        with gr.Column():
            with gr.Group():
                with gr.Tab("text-to-video") as text_tab:
                    image_n = gr.Image(label="", visible=False, value=None)
                    video_n = gr.Video(label="", visible=False, value=None)
                    t2v_prompt = gr.Textbox(label="prompt", value="A majestic dragon flying over a medieval castle")
                    t2v_button = gr.Button("Generate Text-to-Video")
                with gr.Tab("image-to-video") as image_tab:
                    video_i = gr.Video(label="", visible=False, value=None)
                    image_i2v = gr.Image(label="input image", type="filepath")
                    i2v_prompt = gr.Textbox(label="prompt", value="The creature from the image starts to move")
                    i2v_button = gr.Button("Generate Image-to-Video")
                with gr.Tab("video-to-video") as video_tab:
                    image_v = gr.Image(label="", visible=False, value=None)
                    video_v2v = gr.Video(label="input video")
                    frames_to_use = gr.Number(
                        label="num frames to use",
                        info="First N frames of the input video to use for conditioning/transformation.",
                        value=9,
                    )
                    v2v_prompt = gr.Textbox(label="prompt", value="Change the style to cinematic anime")
                    v2v_button = gr.Button("Generate Video-to-Video")

            improve_texture = gr.Checkbox(
                label="improve texture (multi-scale)",
                value=True,
                info="Uses a two-pass generation for better quality, but is slower.",
            )

        with gr.Column():
            output = gr.Video(interactive=False)

    with gr.Accordion("Advanced settings", open=False):
        negative_prompt_input = gr.Textbox(
            label="negative prompt",
            value="worst quality, inconsistent motion, blurry, jittery, distorted",
        )
        with gr.Row():
            seed_input = gr.Number(label="seed", value=42, precision=0)
            randomize_seed_input = gr.Checkbox(label="randomize seed", value=False)
        with gr.Row():
            guidance_scale_input = gr.Slider(
                label="guidance scale",
                minimum=0,
                maximum=10,
                value=1.0,
                step=0.1,
                info="For distilled models, CFG is often 1.0 (disabled) or very low.",
            )
            steps_input = gr.Slider(
                label="Steps (for first pass if multi-scale)",
                minimum=1,
                maximum=30,
                value=len(PIPELINE_CONFIG_YAML.get("first_pass", {}).get("timesteps", [1] * 8)),
                step=1,
                info="Number of inference steps. If the YAML defines timesteps, this value is ignored for that pass.",
            )
            num_frames_input = gr.Slider(
                label="# frames",
                minimum=9,
                maximum=121,
                value=25,
                step=8,
                info="Should be N*8+1, e.g., 9, 17, 25...",
            )
        with gr.Row():
            height_input = gr.Slider(label="height", value=512, step=8, minimum=256, maximum=MAX_IMAGE_SIZE)
            width_input = gr.Slider(label="width", value=704, step=8, minimum=256, maximum=MAX_IMAGE_SIZE)

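    # Wire each tab's button to generate(). gr.State supplies the fixed `mode`
    # string, and gr.State(0) stands in for `frames_to_use` in the text and image
    # tabs, where it is unused.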
    t2v_button.click(
        fn=generate,
        inputs=[
            t2v_prompt,
            negative_prompt_input,
            image_n,
            video_n,
            height_input,
            width_input,
            gr.State("text-to-video"),
            steps_input,
            num_frames_input,
            gr.State(0),
            seed_input,
            randomize_seed_input,
            guidance_scale_input,
            improve_texture,
        ],
        outputs=[output],
    )

    i2v_button.click(
        fn=generate,
        inputs=[
            i2v_prompt,
            negative_prompt_input,
            image_i2v,
            video_i,
            height_input,
            width_input,
            gr.State("image-to-video"),
            steps_input,
            num_frames_input,
            gr.State(0),
            seed_input,
            randomize_seed_input,
            guidance_scale_input,
            improve_texture,
        ],
        outputs=[output],
    )

    v2v_button.click(
        fn=generate,
        inputs=[
            v2v_prompt,
            negative_prompt_input,
            image_v,
            video_v2v,
            height_input,
            width_input,
            gr.State("video-to-video"),
            steps_input,
            num_frames_input,
            frames_to_use,
            seed_input,
            randomize_seed_input,
            guidance_scale_input,
            improve_texture,
        ],
        outputs=[output],
    )

demo.launch()