# Spaces: Runtime error
import spaces | |
import gradio as gr | |
import torch | |
import torchvision as tv | |
import random, os | |
from diffusers import StableVideoDiffusionPipeline | |
from PIL import Image | |
from glob import glob | |
from typing import Optional | |
from tdd_svd_scheduler import TDDSVDStochasticIterativeScheduler | |
from utils import load_lora_weights, save_video | |
# Choose between a developer's local checkout (LOCAL=True) and the public
# Hugging Face Hub repositories (LOCAL=False, the default). Set the environment
# variable TDD_SVD_LOCAL to "1", "true" or "yes" to use the local paths
# without editing this file.
LOCAL = os.environ.get("TDD_SVD_LOCAL", "").strip().lower() in {"1", "true", "yes"}
if LOCAL:
    # Machine-specific paths from the original author's environment.
    svd_path = '/share2/duanyuxuan/diff_playground/diffusers_models/stable-video-diffusion-img2vid-xt-1-1'
    lora_file_path = '/share2/duanyuxuan/diff_playground/SVD-TDD/svd-xt-1-1_tdd_lora_weights.safetensors'
else:
    # Base SVD-xt 1.1 model id, plus the repo id and file name of the TDD LoRA weights.
    svd_path = 'stabilityai/stable-video-diffusion-img2vid-xt-1-1'
    lora_repo_path = 'RED-AIGC/TDD'
    lora_weight_name = 'svd-xt-1-1_tdd_lora_weights.safetensors'
# Build the TDD-distilled SVD pipeline, but only when a CUDA device is visible.
# NOTE(review): without CUDA, `pipeline` is never bound at module level.
if torch.cuda.is_available():
    # Stochastic iterative scheduler used for few-step TDD sampling.
    noise_scheduler = TDDSVDStochasticIterativeScheduler(
        num_train_timesteps=250,
        sigma_min=0.002,
        sigma_max=700.0,
        sigma_data=1.0,
        s_noise=1.0,
        rho=7,
        clip_denoised=False,
    )
    # Half-precision ("fp16" variant) weights, moved onto the GPU.
    pipeline = StableVideoDiffusionPipeline.from_pretrained(
        svd_path,
        scheduler=noise_scheduler,
        torch_dtype=torch.float16,
        variant="fp16",
    ).to('cuda')
    # Attach the TDD LoRA to the UNet: from a repo id + file name by default,
    # or from a single local file path in LOCAL mode.
    if not LOCAL:
        load_lora_weights(pipeline.unet, lora_repo_path, weight_name=lora_weight_name)
    else:
        load_lora_weights(pipeline.unet, lora_file_path)
# Largest signed 64-bit integer (2**63 - 1); the upper bound used for seeds.
max_64_bit_int = (1 << 63) - 1
# On ZeroGPU Spaces this decorator allocates a GPU for each call; the `spaces`
# package documents it as having no effect on other hardware.
@spaces.GPU
def sample(
    image: Image.Image,
    seed: Optional[int] = 1,
    randomize_seed: bool = False,
    num_inference_steps: int = 4,
    eta: float = 0.3,
    min_guidance_scale: float = 1.0,
    max_guidance_scale: float = 1.0,
    fps: int = 7,
    width: int = 512,
    height: int = 512,
    num_frames: int = 25,
    motion_bucket_id: int = 127,
    output_folder: str = "outputs_gradio",
):
    """Generate an .mp4 video from ``image`` with the TDD-distilled SVD pipeline.

    Args:
        image: Conditioning image (PIL).
        seed: RNG seed. ``None`` behaves like ``randomize_seed=True``.
        randomize_seed: Draw a fresh seed instead of using ``seed``.
        num_inference_steps: Number of sampling steps (cast to ``int`` because
            Gradio sliders may deliver floats).
        eta: Forwarded to the scheduler's ``set_eta`` (labelled as the gamma of
            gamma-sampling in the UI).
        min_guidance_scale: Lower classifier-free guidance bound for the pipeline.
        max_guidance_scale: Upper guidance bound; raised to
            ``min_guidance_scale`` if it is smaller.
        fps: Forwarded to the pipeline and used as the saved video's frame rate.
        width: Output frame width, forwarded to the pipeline.
        height: Output frame height, forwarded to the pipeline.
        num_frames: Number of frames to generate.
        motion_bucket_id: Motion conditioning value forwarded to the pipeline.
        output_folder: Directory for the .mp4 files (created if missing).

    Returns:
        ``(video_path, seed)``: path of the written video and the integer seed
        that was actually used.

    Raises:
        gr.Error: if no image was given, or the pipeline was never loaded.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    try:
        pipe = pipeline
    except NameError:
        # The module only builds `pipeline` when CUDA is available at import time.
        raise gr.Error("The video pipeline is not loaded: no CUDA device was available at startup.") from None

    pipe.scheduler.set_eta(eta)

    if randomize_seed or seed is None:
        # The seed is shown in (and sent back from) the browser as a JavaScript
        # double, which is exact only up to 2**53; staying in that range keeps
        # the displayed seed reproducible.
        seed = random.randint(0, min(max_64_bit_int, 2**53 - 1))
    seed = int(seed)  # sliders may deliver floats such as 1.0
    generator = torch.manual_seed(seed)

    num_inference_steps = int(num_inference_steps)
    # The UI states max guidance should not be below min guidance; enforce it.
    max_guidance_scale = max(max_guidance_scale, min_guidance_scale)

    os.makedirs(output_folder, exist_ok=True)
    # Sequential names; if the counted name is already taken (e.g. older files
    # were deleted), skip ahead so an existing video is never overwritten.
    index = len(glob(os.path.join(output_folder, "*.mp4")))
    video_path = os.path.join(output_folder, f"{index:06d}.mp4")
    while os.path.exists(video_path):
        index += 1
        video_path = os.path.join(output_folder, f"{index:06d}.mp4")

    with torch.autocast("cuda"):
        frames = pipe(
            image,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            min_guidance_scale=min_guidance_scale,
            max_guidance_scale=max_guidance_scale,
            num_frames=num_frames,
            fps=fps,
            motion_bucket_id=motion_bucket_id,
            decode_chunk_size=8,
            noise_aug_strength=0.02,
            generator=generator,
        ).frames[0]
    save_video(frames, video_path, fps=fps, quality=5.0)
    # NOTE(review): re-seeds torch's global RNG after generation; purpose not
    # evident from this file, preserved unchanged.
    torch.manual_seed(seed)
    return video_path, seed
def preprocess_image(image, height=512, width=512):
    """Center-crop ``image`` to the output aspect ratio, then resize it.

    Args:
        image: PIL image from the upload widget, or ``None`` when there is
            nothing to process.
        height: Output height in pixels.
        width: Output width in pixels.

    Returns:
        An RGB PIL image of size ``(width, height)``, or ``None`` if ``image``
        is ``None``.
    """
    if image is None:
        return None
    image = image.convert('RGB')
    src_w, src_h = image.size
    # Largest centered region with the output's aspect ratio, so the resize
    # below never stretches the picture. For a square output this is the
    # min-side square crop. Integer cross-multiplication avoids float error.
    if src_w * height > src_h * width:  # source is wider than the target
        crop_w, crop_h = max(1, round(src_h * width / height)), src_h
    else:  # source is taller than, or as wide as, the target
        crop_w, crop_h = src_w, max(1, round(src_w * height / width))
    if (crop_w, crop_h) != (src_w, src_h):
        # Same offset rounding as torchvision.transforms.functional.center_crop.
        left = int(round((src_w - crop_w) / 2.0))
        top = int(round((src_h - crop_h) / 2.0))
        image = image.crop((left, top, left + crop_w, top + crop_h))
    return image.resize((width, height))
# Page CSS: center the page title and cap the width of the app container.
css = """
h1 {
text-align: center;
display:block;
}
.gradio-container {
max-width: 70.5rem !important;
}
"""

# Layout: image upload and Generate button beside the output video, then the
# sampling options. Components are bound at module level so the event
# handlers at the bottom can wire them together.
with gr.Blocks(css = css) as demo:
    gr.Markdown(
        """
# Stable Video Diffusion distilled by ✨Target-Driven Distillation✨
Target-Driven Distillation (TDD) is a state-of-the-art consistency distillation model that largely accelerates the inference processes of diffusion models. Using its delicate strategies of *target timestep selection* and *decoupled guidance*, models distilled by TDD can generate highly detailed images with only a few steps.
Besides, TDD is also available for distilling video generation models. This space presents TDD-distilled [SVD-xt 1.1](https://huggingface.co/stabilityai/stable-video-diffusion-img2vid-xt-1-1).
[**Project Page**](https://redaigc.github.io/TDD/) **|** [**Paper**](https://arxiv.org/abs/2409.01347) **|** [**Code**](https://github.com/RedAIGC/Target-Driven-Distillation) **|** [**Model**](https://huggingface.co/RED-AIGC/TDD) **|** [🤗 **TDD-SDXL Demo**](https://huggingface.co/spaces/RED-AIGC/TDD) **|** [🤗 **TDD-SVD Demo**](https://huggingface.co/spaces/RED-AIGC/SVD-TDD)
The codes of this space are built on [AnimateLCM-SVD](https://huggingface.co/spaces/wangfuyun/AnimateLCM-SVD) and we acknowledge their contribution.
"""
    )
    with gr.Row():
        with gr.Column():
            image = gr.Image(label="Upload your image", type="pil")
            generate_btn = gr.Button("Generate")
        video = gr.Video()
    with gr.Accordion("Options", open = True):
        seed = gr.Slider(
            label="Seed",
            value=1,
            randomize=False,
            minimum=0,
            maximum=max_64_bit_int,
            step=1,
        )
        randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
        # gr.Slider defaults to step=1, which would pin this 1.0-1.5 range at
        # 1.0; an explicit 0.1 step makes fractional guidance selectable.
        min_guidance_scale = gr.Slider(
            label="Min guidance scale",
            info="min strength of classifier-free guidance",
            value=1.0,
            minimum=1.0,
            maximum=1.5,
            step=0.1,
        )
        max_guidance_scale = gr.Slider(
            label="Max guidance scale",
            info="max strength of classifier-free guidance, it should not be less than Min guidance scale",
            value=1.0,
            minimum=1.0,
            maximum=3.0,
            step=0.1,
        )
        num_inference_steps = gr.Slider(
            label="Num inference steps",
            info="steps for inference",
            value=4,
            minimum=4,
            maximum=8,
            step=1,
        )
        eta = gr.Slider(
            label = "Eta",
            info = "the value of gamma in gamma-sampling",
            value = 0.3,
            minimum = 0.0,
            maximum = 1.0,
            step = 0.1,
        )
    # Crop and resize each uploaded image in place (512x512 by default) before
    # it is used for generation.
    image.upload(fn = preprocess_image, inputs = image, outputs = image, queue = False)
    # Generate: `sample` returns the video path and the seed actually used,
    # which is written back into the Seed slider.
    generate_btn.click(
        fn = sample,
        inputs = [
            image,
            seed,
            randomize_seed,
            num_inference_steps,
            eta,
            min_guidance_scale,
            max_guidance_scale,
        ],
        outputs = [video, seed],
        api_name = "video",
    )
    # TODO: optionally add a gr.Examples gallery (e.g. "examples/ipadapter_cat.jpg",
    # outputs=[video, seed], fn=sample, cache_examples=True); currently disabled.
if __name__ == "__main__":
    # LOCAL: public share link, listening on all interfaces.
    # Otherwise: queue with the API closed and hidden from the UI.
    if LOCAL:
        queue_kwargs, launch_kwargs = {}, {"share": True, "server_name": '0.0.0.0'}
    else:
        queue_kwargs, launch_kwargs = {"api_open": False}, {"show_api": False}
    demo.queue(**queue_kwargs).launch(**launch_kwargs)