# Scraped file-view header (not part of the program): kadirnar's picture · update · 2a37fe9 · raw · history blame · 6.06 kB
import gradio as gr
import torch
from video_diffusion.tuneavideo.models.unet import UNet3DConditionModel
from video_diffusion.tuneavideo.pipelines.pipeline_tuneavideo import TuneAVideoPipeline
from video_diffusion.tuneavideo.util import save_videos_grid
from video_diffusion.utils.model_list import stable_model_list
# Model identifiers offered in the "Video Diffusion Model" dropdown; the selected
# one is handed to UNet3DConditionModel.from_pretrained(..., subfolder="unet").
video_diffusion_model_list = [
    f"Tune-A-Video-library/{checkpoint}"
    for checkpoint in (
        "a-man-is-surfing",
        "mo-di-bear-guitar",
        "redshift-man-skiing",
    )
]
class TunaVideoText2VideoGenerator:
    """Generate short videos from text with a lazily loaded Tune-A-Video pipeline.

    The UNet and the pipeline are created on first use, moved to the "cuda"
    device in float16, and cached on the instance. The cache is keyed by the
    model identifiers it was built from, so selecting a different model
    triggers a reload instead of silently reusing the previous weights.
    """

    def __init__(self):
        self.pipe = None
        self.unet = None
        # Identifiers the cached objects were built from; None means "nothing
        # loaded yet" (or loaded by unknown means, which forces a reload).
        self._loaded_unet = None
        self._loaded_models = None

    def load_model(self, video_diffusion_model_list: str, stable_model_list: str):
        """Return a pipeline for the requested model pair, loading it if needed.

        Despite their names, both arguments are single model identifiers
        (``generate_video`` passes one string for each), not lists.

        Args:
            video_diffusion_model_list: Identifier of the Tune-A-Video
                checkpoint whose ``unet`` subfolder provides the 3D UNet.
            stable_model_list: Identifier of the base model the pipeline is
                built from.

        Returns:
            The cached ``TuneAVideoPipeline`` for exactly this model pair.
        """
        requested = (video_diffusion_model_list, stable_model_list)
        if self.pipe is not None and self._loaded_models == requested:
            return self.pipe

        # Drop the stale pipeline before building the new one so the old
        # weights are no longer referenced while the replacement is loaded.
        self.pipe = None
        self._loaded_models = None

        if self.unet is None or self._loaded_unet != video_diffusion_model_list:
            self.unet = None
            self._loaded_unet = None
            self.unet = UNet3DConditionModel.from_pretrained(
                video_diffusion_model_list,
                subfolder="unet",
                torch_dtype=torch.float16,
            ).to("cuda")
            self._loaded_unet = video_diffusion_model_list

        self.pipe = TuneAVideoPipeline.from_pretrained(
            stable_model_list, unet=self.unet, torch_dtype=torch.float16
        )
        self.pipe.to("cuda")
        self.pipe.enable_xformers_memory_efficient_attention()
        # Record the key only after every step succeeded, so a failed load is
        # retried on the next call rather than being treated as cached.
        self._loaded_models = requested

        return self.pipe

    def generate_video(
        self,
        video_diffusion_model: str,
        stable_model_list: str,
        prompt: str,
        negative_prompt: str,
        video_length: int,
        height: int,
        width: int,
        num_inference_steps: int,
        guidance_scale: float,
        fps: int,
    ) -> str:
        """Run the pipeline for ``prompt`` and save the result as a GIF.

        Args:
            video_diffusion_model: Identifier of the Tune-A-Video checkpoint.
            stable_model_list: Identifier of the base model (a single id).
            prompt: Text describing the desired video.
            negative_prompt: Text forwarded as the pipeline's
                ``negative_prompt``.
            video_length: Forwarded as the pipeline's ``video_length``.
            height: Forwarded as the pipeline's ``height``.
            width: Forwarded as the pipeline's ``width``.
            num_inference_steps: Forwarded as ``num_inference_steps``.
            guidance_scale: Forwarded as ``guidance_scale``; the UI default
                is 7.5, so this is a float rather than an int.
            fps: Frame rate passed to ``save_videos_grid``.

        Returns:
            The relative path ``"output.gif"`` that was passed to
            ``save_videos_grid`` as the output location.
        """
        pipe = self.load_model(video_diffusion_model, stable_model_list)
        video = pipe(
            prompt,
            negative_prompt=negative_prompt,
            video_length=video_length,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
        ).videos

        # NOTE(review): the output path is fixed, so every call reuses the same
        # file name — confirm this is acceptable if requests can overlap.
        save_videos_grid(videos=video, path="output.gif", fps=fps)
        return "output.gif"
def app():
    """Build the Tune-A-Video text-to-video controls and wire up the button.

    Creates the model dropdowns, prompt boxes and sliders, a "Generator"
    button and a video output, then connects the button to
    ``TunaVideoText2VideoGenerator().generate_video``. A single generator
    instance is created here and captured by the click handler.

    Returns:
        None. The components are created inside a ``gr.Blocks`` context.
    """
    # NOTE(review): container nesting below was reconstructed from a source
    # whose indentation was lost — confirm the intended layout.
    with gr.Blocks():
        with gr.Row():
            with gr.Column():
                tunevideo_video_diffusion_model_list = gr.Dropdown(
                    choices=video_diffusion_model_list,
                    label="Video Diffusion Model",
                    value=video_diffusion_model_list[0],
                )
                tunevideo_stable_model_list = gr.Dropdown(
                    choices=stable_model_list,
                    label="Stable Model List",
                    value=stable_model_list[0],
                )

                with gr.Row():
                    with gr.Column():
                        tunevideo_prompt = gr.Textbox(
                            lines=1,
                            placeholder="Prompt",
                            show_label=False,
                        )
                        tunevideo_video_length = gr.Slider(
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=10,
                            label="Video Length",
                        )
                        tunevideo_num_inference_steps = gr.Slider(
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=50,
                            label="Num Inference Steps",
                        )
                        tunevideo_fps = gr.Slider(
                            minimum=1,
                            maximum=60,
                            step=1,
                            value=5,
                            label="Fps",
                        )

                    with gr.Row():
                        with gr.Column():
                            tunevideo_negative_prompt = gr.Textbox(
                                lines=1,
                                placeholder="Negative Prompt",
                                show_label=False,
                            )
                            # step=0.5 keeps the 7.5 default on the slider's
                            # grid; with step=1 it could never be re-selected.
                            tunevideo_guidance_scale = gr.Slider(
                                minimum=1,
                                maximum=15,
                                step=0.5,
                                value=7.5,
                                label="Guidance Scale",
                            )
                            # minimum=32 makes every reachable value a multiple
                            # of 32 (32, 64, ..., 1280). With minimum=1 the grid
                            # was 1, 33, 65, ... and excluded the 512 default.
                            tunevideo_height = gr.Slider(
                                minimum=32,
                                maximum=1280,
                                step=32,
                                value=512,
                                label="Height",
                            )
                            tunevideo_width = gr.Slider(
                                minimum=32,
                                maximum=1280,
                                step=32,
                                value=512,
                                label="Width",
                            )

                tunevideo_generate = gr.Button(value="Generator")

            with gr.Column():
                tunevideo_output = gr.Video(label="Output")

        # The input order must match generate_video's positional parameters.
        tunevideo_generate.click(
            fn=TunaVideoText2VideoGenerator().generate_video,
            inputs=[
                tunevideo_video_diffusion_model_list,
                tunevideo_stable_model_list,
                tunevideo_prompt,
                tunevideo_negative_prompt,
                tunevideo_video_length,
                tunevideo_height,
                tunevideo_width,
                tunevideo_num_inference_steps,
                tunevideo_guidance_scale,
                tunevideo_fps,
            ],
            outputs=tunevideo_output,
        )