# ASVA / app.py
# Lin Z
# update
# ee6ef96
import spaces
import warnings
warnings.filterwarnings("ignore")
import gradio as gr
import torch
import torch.nn as nn
from diffusers.models import AutoencoderKL
from diffusers.schedulers import PNDMScheduler
from unet import AudioUNet3DConditionModel
from audio_encoder import ImageBindSegmaskAudioEncoder
from pipeline import AudioCondAnimationPipeline, generate_videos
# Target device and precision shared by the module-level pipeline construction
# (create_pipeline call) and by generate_video's call to generate_videos.
device, dtype = torch.device("cuda"), torch.float16
def freeze_and_make_eval(model: nn.Module):
    """Prepare *model* for inference only.

    Every parameter gets ``requires_grad = False`` and the module (with all of
    its submodules) is switched to evaluation mode. The model is modified in
    place and nothing is returned.
    """
    # nn.Module.requires_grad_ applies the flag to every parameter of the module.
    model.requires_grad_(False)
    model.eval()
def create_pipeline(
    device=torch.device("cuda"),
    dtype=torch.float32,
    *,
    pretrained_stable_diffusion_path="./pretrained/stable-diffusion-v1-5",
    checkpoint_path=(
        "checkpoints/audio-cond_animation/avsync15_audio-cond_cfg/"
        "ckpts/checkpoint-37000/modules"
    ),
    category_text_encodings_path=(
        "datasets/AVSync15/class_clip_text_encodings_stable-diffusion-v1-5.pt"
    ),
    null_text_encodings_path="./pretrained/openai-clip-l_null_text_encoding.pt",
):
    """Build the audio-conditioned animation pipeline and load category encodings.

    Args:
        device: Device every model component is moved to.
        dtype: Dtype every model component is cast to.
        pretrained_stable_diffusion_path: Directory whose ``scheduler`` and
            ``vae`` subfolders are loaded.
        checkpoint_path: Directory whose ``unet`` subfolder holds the trained
            audio-conditioned UNet.
        category_text_encodings_path: File loaded with ``torch.load`` onto the
            CPU and returned unchanged; callers index it by category name.
        null_text_encodings_path: File path handed to the pipeline as
            ``null_text_encodings_path``.

    Returns:
        A ``(pipeline, category_text_encoding_mapping)`` tuple.
    """
    # torch.load unpickles the file, so only point this at trusted local files.
    category_text_encoding_mapping = torch.load(
        category_text_encodings_path, map_location="cpu"
    )

    scheduler = PNDMScheduler.from_pretrained(
        pretrained_stable_diffusion_path, subfolder="scheduler"
    )
    vae = AutoencoderKL.from_pretrained(
        pretrained_stable_diffusion_path, subfolder="vae"
    ).to(device=device, dtype=dtype)

    # NOTE(review): n_segment=12 presumably has to match video_num_frame=12 used
    # in generate_video -- confirm before changing either value.
    audio_encoder = ImageBindSegmaskAudioEncoder(n_segment=12).to(
        device=device, dtype=dtype
    )
    # Inference only: disable gradients and switch to eval mode.
    freeze_and_make_eval(audio_encoder)

    unet = AudioUNet3DConditionModel.from_pretrained(
        checkpoint_path, subfolder="unet"
    ).to(device=device, dtype=dtype)

    pipeline = AudioCondAnimationPipeline(
        unet=unet,
        scheduler=scheduler,
        vae=vae,
        audio_encoder=audio_encoder,
        null_text_encodings_path=null_text_encodings_path,
    )
    pipeline.to(torch_device=device, dtype=dtype)
    # Disable the pipeline's progress bar.
    pipeline.set_progress_bar_config(disable=True)
    return pipeline, category_text_encoding_mapping
# Built once at import time; generate_video reuses these module-level objects
# for every request instead of reloading the weights.
pipeline, category_text_encoding_mapping = create_pipeline(device=device, dtype=dtype)
@spaces.GPU(duration=120)
def generate_video(image, audio, text, audio_guidance_scale, denoising_step):
    """Animate an image with motion synchronized to an audio clip.

    Args:
        image: File path of the uploaded image (the UI uses
            ``gr.Image(type="filepath")``), or ``None`` if nothing was uploaded.
        audio: File path of the uploaded audio (``gr.Audio(type="filepath")``),
            or ``None`` if nothing was uploaded.
        text: Category chosen in the dropdown; used as the key into
            ``category_text_encoding_mapping``.
        audio_guidance_scale: Passed through to ``generate_videos`` as
            ``audio_guidance_scale``.
        denoising_step: Number of sampling steps; converted to ``int``.

    Returns:
        Path of the written ``.mp4`` file.

    Raises:
        gr.Error: If the image or audio is missing or the category is not a
            known key, so the problem is shown in the UI instead of surfacing
            as a bare ``KeyError``/traceback.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if audio is None:
        raise gr.Error("Please upload an audio file.")
    if text not in category_text_encoding_mapping:
        raise gr.Error(
            f"Unknown audio category {text!r}; please select one from the list."
        )

    # Reshape the stored encoding for this category into a (1, 77, 768) tensor.
    category_text_encoding = category_text_encoding_mapping[text].view(1, 77, 768)
    save_path = "./output_video.mp4"
    generate_videos(
        pipeline,
        audio_path=audio,
        image_path=image,
        category_text_encoding=category_text_encoding,
        image_size=(256, 256),
        video_fps=6,
        video_num_frame=12,
        text_guidance_scale=1.0,
        audio_guidance_scale=audio_guidance_scale,
        # Gradio sliders may deliver floats; a step count is a whole number
        # (the UI slider uses step=1).
        denoising_step=int(denoising_step),
        seed=123,  # fixed seed; presumably makes identical inputs repeatable
        save_path=save_path,
        device=device,
    )
    return save_path
if __name__ == "__main__":
    # Dropdown choices. generate_video uses the selected string as a key into
    # category_text_encoding_mapping, so every entry must exist there.
    categories = [
        "baby babbling crying", "dog barking", "hammering", "striking bowling", "cap gun shooting",
        "chicken crowing", "frog croaking", "lions roaring", "machine gun shooting", "playing cello",
        "playing trombone", "playing trumpet", "playing violin fiddle", "sharpen knife", "toilet flushing",
    ]

    # Left empty: the page heading is rendered inside the HTML description.
    title = ""
    # HTML header shown above the interface. Content lines are kept flush-left
    # because indenting them would change the string itself.
    description = """
<div align="center">
<h1 style="font-size: 60px;">Audio-Synchronized Visual Animation</h1>
<p style="font-size: 30px;">
<a href="https://lzhangbj.github.io/projects/asva/asva.html">Project Webpage</a>
</p>
<p style="font-size: 30px;">
<a href="https://lzhangbj.github.io/">Lin Zhang</a>,
<a href="https://scholar.google.com/citations?user=6aYncPAAAAAJ">Shentong Mo</a>,
<a href="https://yijingz02.github.io/">Yijing Zhang</a>,
<a href="https://pedro-morgado.github.io/">Pedro Morgado</a>
</p>
<p style="font-size: 30px;">
University of Wisconsin Madison,
Carnegie Mellon University
</p>
<strong style="font-size: 30px;">ECCV 2024</strong>
<strong style="font-size: 25px;">Animate your images with audio-synchronized motion! </strong>
<p style="font-size: 18px;">Notes:</p>
<p style="font-size: 18px;">(1) Only the first 2 seconds of audio is used. </p>
<p style="font-size: 18px;">(2) Increase audio guidance scale for amplified visual dynamics. </p>
<p style="font-size: 18px;">(3) Increase sampling steps for higher visual quality. </p>
</div>
"""
    # Optional notice, currently disabled:
    # <p style="font-size: 20px;">Please be patient. Due to limited resources on huggingface, the generation may take up to 10mins </p>

    # Input order must match generate_video's positional parameters:
    # (image, audio, text, audio_guidance_scale, denoising_step).
    iface = gr.Interface(
        fn=generate_video,
        inputs=[
            gr.Image(label="Upload Image", type="filepath", height=256),
            gr.Audio(label="Upload Audio", type="filepath"),
            gr.Dropdown(choices=categories, label="Select Audio Category"),
            gr.Slider(minimum=1.0, maximum=12.0, step=0.1, value=4.0, label="Audio Guidance Scale"),
            gr.Slider(minimum=1, maximum=50, step=1, value=20, label="Sampling steps"),
        ],
        outputs=gr.Video(label="Generated Video", height=256),
        title=title,
        description=description,
        # Each row: image, audio, category, audio guidance scale, sampling steps.
        examples=[
            ["./assets/lion_and_gun.png", "./assets/lions_roaring.wav", "lions roaring", 4.0, 20],
            ["./assets/lion_and_gun.png", "./assets/machine_gun_shooting.wav", "machine gun shooting", 4.0, 20],
        ],
    )
    iface.launch()