import spaces
import warnings

warnings.filterwarnings("ignore")

import gradio as gr
import torch
import torch.nn as nn
from diffusers.models import AutoencoderKL
from diffusers.schedulers import PNDMScheduler

from unet import AudioUNet3DConditionModel
from audio_encoder import ImageBindSegmaskAudioEncoder
from pipeline import AudioCondAnimationPipeline, generate_videos
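
# Run inference on the GPU in half precision to keep memory use low.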
device = torch.device("cuda")
dtype = torch.float16


def freeze_and_make_eval(model: nn.Module):
    # Freeze all parameters and switch to eval mode; the model is used for inference only.
    for param in model.parameters():
        param.requires_grad = False
    model.eval()
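

# Assemble the audio-conditioned animation pipeline from its pretrained components.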
def create_pipeline(device=torch.device("cuda"), dtype=torch.float32):
    # Prepare the models.
    pretrained_stable_diffusion_path = "./pretrained/stable-diffusion-v1-5"
    checkpoint_path = "checkpoints/audio-cond_animation/avsync15_audio-cond_cfg/ckpts/checkpoint-37000/modules"
    category_text_encoding_mapping = torch.load(
        "datasets/AVSync15/class_clip_text_encodings_stable-diffusion-v1-5.pt", map_location="cpu"
    )

    scheduler = PNDMScheduler.from_pretrained(pretrained_stable_diffusion_path, subfolder="scheduler")
    vae = AutoencoderKL.from_pretrained(pretrained_stable_diffusion_path, subfolder="vae").to(device=device, dtype=dtype)
    audio_encoder = ImageBindSegmaskAudioEncoder(n_segment=12).to(device=device, dtype=dtype)
    freeze_and_make_eval(audio_encoder)
    unet = AudioUNet3DConditionModel.from_pretrained(checkpoint_path, subfolder="unet").to(device=device, dtype=dtype)

    pipeline = AudioCondAnimationPipeline(
        unet=unet,
        scheduler=scheduler,
        vae=vae,
        audio_encoder=audio_encoder,
        null_text_encodings_path="./pretrained/openai-clip-l_null_text_encoding.pt"
    )
    pipeline.to(torch_device=device, dtype=dtype)
    pipeline.set_progress_bar_config(disable=True)

    return pipeline, category_text_encoding_mapping
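

# Load everything once at startup so each Gradio request reuses the same weights.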
pipeline, category_text_encoding_mapping = create_pipeline(device, dtype)
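

# Gradio callback: animate the input image with motion synchronized to the uploaded audio.
# The `@spaces.GPU` decorator below is assumed for the ZeroGPU runtime, which only grants
# CUDA access inside decorated functions.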
@spaces.GPU
def generate_video(image, audio, text, audio_guidance_scale, denoising_step):
    # Look up the precomputed CLIP text encoding (77 tokens x 768 dims) for the chosen category.
    category_text_encoding = category_text_encoding_mapping[text].view(1, 77, 768)
    generate_videos(
        pipeline,
        audio_path=audio,
        image_path=image,
        category_text_encoding=category_text_encoding,
        image_size=(256, 256),
        video_fps=6,
        video_num_frame=12,
        text_guidance_scale=1.0,
        audio_guidance_scale=audio_guidance_scale,
        denoising_step=denoising_step,
        seed=123,
        save_path="./output_video.mp4",
        device=device
    )
    return "./output_video.mp4"


if __name__ == "__main__":
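    # The 15 sound categories of the AVSync15 benchmark.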
    categories = [
        "baby babbling crying", "dog barking", "hammering", "striking bowling", "cap gun shooting",
        "chicken crowing", "frog croaking", "lions roaring", "machine gun shooting", "playing cello",
        "playing trombone", "playing trumpet", "playing violin fiddle", "sharpen knife", "toilet flushing"
    ]
    title = ""
    description = """
    <div align="center">
        <h1 style="font-size: 60px;">Audio-Synchronized Visual Animation</h1>
        <p style="font-size: 30px;">
            <a href="https://lzhangbj.github.io/projects/asva/asva.html">Project Webpage</a>
        </p>
        <p style="font-size: 30px;">
            <a href="https://lzhangbj.github.io/">Lin Zhang</a>,
            <a href="https://scholar.google.com/citations?user=6aYncPAAAAAJ">Shentong Mo</a>,
            <a href="https://yijingz02.github.io/">Yijing Zhang</a>,
            <a href="https://pedro-morgado.github.io/">Pedro Morgado</a>
        </p>
        <p style="font-size: 30px;">
            University of Wisconsin-Madison,
            Carnegie Mellon University
        </p>
        <strong style="font-size: 30px;">ECCV 2024</strong>
        <strong style="font-size: 25px;">Animate your images with audio-synchronized motion!</strong>
        <p style="font-size: 18px;">Notes:</p>
        <p style="font-size: 18px;">(1) Only the first 2 seconds of the audio are used.</p>
        <p style="font-size: 18px;">(2) Increase the audio guidance scale for amplified visual dynamics.</p>
        <p style="font-size: 18px;">(3) Increase the sampling steps for higher visual quality.</p>
    </div>
    """
    # <p style="font-size: 20px;">Please be patient. Due to limited resources on huggingface, the generation may take up to 10mins </p>

    # Gradio interface
    iface = gr.Interface(
        fn=generate_video,
        inputs=[
            gr.Image(label="Upload Image", type="filepath", height=256),
            gr.Audio(label="Upload Audio", type="filepath"),
            gr.Dropdown(choices=categories, label="Select Audio Category"),
            gr.Slider(minimum=1.0, maximum=12.0, step=0.1, value=4.0, label="Audio Guidance Scale"),
            gr.Slider(minimum=1, maximum=50, step=1, value=20, label="Sampling Steps")
        ],
        outputs=gr.Video(label="Generated Video", height=256),
        title=title,
        description=description,
        examples=[
            ["./assets/lion_and_gun.png", "./assets/lions_roaring.wav", "lions roaring", 4.0, 20],
            ["./assets/lion_and_gun.png", "./assets/machine_gun_shooting.wav", "machine gun shooting", 4.0, 20],
        ]
    )

    # Launch the interface
    iface.launch()