# Instant-Video / app.py
# Last change by SahaniJi: "Update app.py" (commit a6977ca, verified), 6.75 kB
import os
import tempfile
import uuid

import gradio as gr
import torch
import spaces
from diffusers import AnimateDiffPipeline, MotionAdapter, EulerDiscreteScheduler
from diffusers.utils import export_to_video
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from PIL import Image
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# UI label -> Hugging Face repo id whose UNet weights
# ("unet/diffusion_pytorch_model.bin") are swapped into the pipeline.
bases = {
    "Cartoon": "frankjoshua/toonyou_beta6",
    "Realistic": "emilianJR/epiCRealism",
    "3d": "Lykon/DreamShaper",
    "Anime": "Yntec/mistoonAnime2",
}

# Bookkeeping for what is currently loaded into the shared pipeline.
# None means "nothing loaded yet", which forces a load on the first request.
step_loaded = motion_loaded = None
base_loaded = "Realistic"  # the pipeline is first built from this base's repo
# Build the shared pipeline once, at import time. generate_image later mutates
# it in place (Lightning motion weights, base UNet, motion LoRA) on demand.
if not torch.cuda.is_available():
    raise NotImplementedError("No GPU detected!")
device = "cuda"
dtype = torch.float16
# Pass an explicit (default-initialised) MotionAdapter rather than relying on
# the pipeline to synthesise one. Its motion weights are replaced by the
# Lightning checkpoint on the first generate_image call (step_loaded starts None).
motion_adapter = MotionAdapter().to(device, dtype)
pipe = AnimateDiffPipeline.from_pretrained(bases[base_loaded], motion_adapter=motion_adapter, torch_dtype=dtype).to(device)
# Few-step Euler sampling: trailing timestep spacing and a linear beta schedule.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing", beta_schedule="linear")
# CLIP image preprocessor (CLIPImageProcessor replaces the deprecated
# CLIPFeatureExtractor alias).
# NOTE(review): despite the old "safety checker" label, nothing in this file
# attaches it to `pipe` or reads `feature_extractor`. Confirm whether it is still
# needed; dropping it would save a model download at startup.
from transformers import CLIPImageProcessor
feature_extractor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
# Step counts offered by the UI's step selector. Each one names an
# AnimateDiff-Lightning checkpoint ("animatediff_lightning_{N}step_diffusers").
_LIGHTNING_STEPS = (1, 2, 4, 8)


# Generation entry point, used by the UI events and the cached examples.
# NOTE(review): `queue` does not look like a documented `spaces.GPU` keyword
# (it may be silently ignored). Confirm against the `spaces` package.
@spaces.GPU(duration=60,queue=False)
def generate_image(prompt, base="Realistic", motion="", step=8, progress=gr.Progress()):
    """Render a short video for *prompt* and return the path of the saved MP4.

    The module-level ``pipe`` is shared and updated lazily. The Lightning
    checkpoint, the base-model UNet and the motion LoRA are reloaded only
    when the requested value differs from what the ``step_loaded`` /
    ``base_loaded`` / ``motion_loaded`` globals record as loaded.

    Args:
        prompt: Text prompt for the pipeline.
        base: Key of ``bases`` selecting the base-model UNet.
        motion: Hugging Face repo id of a motion LoRA, or ``""`` for none.
        step: Inference step count; also picks the Lightning checkpoint, so it
            must be one of ``_LIGHTNING_STEPS``. Integer-like strings are accepted.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        Path of the exported MP4 (10 fps) in the system temp directory.

    Raises:
        gr.Error: If ``step`` or ``base`` is not a supported value.
    """
    global step_loaded
    global base_loaded
    global motion_loaded

    # API callers may send the step as a string. Normalise it so it compares
    # equal to the cached `step_loaded` and is a valid `num_inference_steps`.
    try:
        step = int(step)
    except (TypeError, ValueError):
        raise gr.Error(f"Unsupported number of steps: {step!r}") from None
    if step not in _LIGHTNING_STEPS:
        raise gr.Error(f"Unsupported number of steps: {step}; choose one of {_LIGHTNING_STEPS}")
    if base not in bases:
        raise gr.Error(f"Unknown base model: {base!r}; choose one of {list(bases)}")
    print(prompt, base, motion, step)

    # Each cached "*_loaded" value is cleared just before its weights are
    # mutated and set again only after the load succeeds. A load that fails
    # part-way is therefore retried on the next call instead of being trusted.
    if step_loaded != step:
        repo = "ByteDance/AnimateDiff-Lightning"
        ckpt = f"animatediff_lightning_{step}step_diffusers.safetensors"
        lightning_weights = load_file(hf_hub_download(repo, ckpt), device=device)
        step_loaded = None
        # strict=False: partial state dict, only the keys present are replaced.
        pipe.unet.load_state_dict(lightning_weights, strict=False)
        step_loaded = step

    if base_loaded != base:
        # SECURITY NOTE(review): torch.load unpickles a file downloaded from a
        # third-party repo. Consider weights_only=True or a .safetensors file.
        base_weights = torch.load(hf_hub_download(bases[base], "unet/diffusion_pytorch_model.bin"), map_location=device)
        base_loaded = None
        pipe.unet.load_state_dict(base_weights, strict=False)
        base_loaded = base

    if motion_loaded != motion:
        motion_loaded = None
        pipe.unload_lora_weights()
        if motion != "":
            pipe.load_lora_weights(motion, adapter_name="motion")
            pipe.set_adapters(["motion"], [0.7])
        motion_loaded = motion

    progress((0, step))

    def progress_callback(i, t, z):
        # diffusers per-step `callback` hook (step index, timestep, latents).
        # Report the number of completed steps to the Gradio progress bar.
        progress((i + 1, step))

    output = pipe(prompt=prompt, guidance_scale=1.2, num_inference_steps=step, callback=progress_callback, callback_steps=1)
    # Unique file name so concurrent requests never overwrite each other's video.
    path = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex}.mp4")
    export_to_video(output.frames[0], path, fps=10)
    return path
# Gradio Interface
# Custom stylesheet passed to gr.Blocks(css=...). The class selectors below
# (.container, .inputs, .output, .gr-input, .gr-button, .gr-video,
# .gr-examples) only match components that carry those CSS classes, i.e. that
# were given them through Gradio's `elem_classes`. An `elem_id` sets an HTML
# id, which a class selector does not match.
# NOTE(review): `.svelte-1ybb3u7` / `.svelte-1clup3e` are Svelte build-hash
# class names from Gradio's frontend. They presumably hide some built-in UI
# chrome (footer?) and will likely stop matching after a Gradio upgrade.
# Confirm what they target.
css = """
body {font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; background-color: #f4f4f9; color: #333;}
h1 {color: #333; text-align: center; margin-bottom: 20px;}
.gradio-container {max-width: 800px; margin: auto; padding: 20px; background: #fff; box-shadow: 0px 0px 20px rgba(0,0,0,0.1); border-radius: 10px;}
.gr-input {margin-bottom: 15px;}
.gr-button {width: 100%; background-color: #4CAF50; color: white; border: none; padding: 10px 20px; text-align: center; text-decoration: none; display: inline-block; font-size: 16px; border-radius: 5px; cursor: pointer; transition: background-color 0.3s;}
.gr-button:hover {background-color: #45a049;}
.gr-video {margin-top: 20px;}
.gr-examples {margin-top: 30px;}
.gr-examples .gr-example {display: inline-block; width: 100%; text-align: center; padding: 10px; background: #eaeaea; border-radius: 5px; margin-bottom: 10px;}
.container {display: flex; flex-wrap: wrap;}
.inputs, .output {padding: 20px;}
.inputs {flex: 1; min-width: 300px;}
.output {flex: 1; min-width: 300px;}
@media (max-width: 768px) {
.container {flex-direction: column-reverse;}
}
.svelte-1ybb3u7, .svelte-1clup3e {display: none !important;}
"""
# Gradio UI. The `css` stylesheet uses class selectors, so components are tagged
# with `elem_classes`. With `elem_id` those rules never matched, and reusing
# one id on several components is invalid HTML.
with gr.Blocks(css=css) as demo:
    gr.HTML("<h1>Instant⚡ Text to Video</h1>")
    with gr.Row(elem_classes="container"):
        with gr.Column(elem_classes="inputs"):
            prompt = gr.Textbox(label='Prompt', placeholder="Enter text to generate video...", elem_classes="gr-input")
            select_base = gr.Dropdown(
                label='Base model',
                # Derived from `bases` so the menu cannot drift from the supported models.
                choices=list(bases),
                value=base_loaded,
                interactive=True,
                elem_classes="gr-input",
            )
            select_motion = gr.Dropdown(
                label='Motion',
                # (label, value) pairs: the value is a motion-LoRA repo id, "" = none.
                choices=[
                    ("Default", ""),
                    ("Zoom in", "guoyww/animatediff-motion-lora-zoom-in"),
                    ("Zoom out", "guoyww/animatediff-motion-lora-zoom-out"),
                    ("Tilt up", "guoyww/animatediff-motion-lora-tilt-up"),
                    ("Tilt down", "guoyww/animatediff-motion-lora-tilt-down"),
                    ("Pan left", "guoyww/animatediff-motion-lora-pan-left"),
                    ("Pan right", "guoyww/animatediff-motion-lora-pan-right"),
                    ("Roll left", "guoyww/animatediff-motion-lora-rolling-anticlockwise"),
                    ("Roll right", "guoyww/animatediff-motion-lora-rolling-clockwise"),
                ],
                value="guoyww/animatediff-motion-lora-zoom-in",
                interactive=True,
                elem_classes="gr-input",
            )
            select_step = gr.Dropdown(
                label='Inference steps',
                choices=[('1-Step', 1), ('2-Step', 2), ('4-Step', 4), ('8-Step', 8)],
                value=4,
                interactive=True,
                elem_classes="gr-input",
            )
            submit = gr.Button("Generate Video", variant='primary', elem_classes="gr-button")
        with gr.Column(elem_classes="output"):
            video = gr.Video(label='AnimateDiff-Lightning', autoplay=True, height=512, width=512, elem_classes="gr-video")

    # Both triggers feed generate_image the same positional inputs.
    generate_inputs = [prompt, select_base, select_motion, select_step]
    prompt.submit(fn=generate_image, inputs=generate_inputs, outputs=video)
    submit.click(fn=generate_image, inputs=generate_inputs, outputs=video, api_name="instant_video")

    # Only the prompt is an example input. Cached example videos are therefore
    # rendered with generate_image's defaults (base "Realistic", no motion LoRA,
    # 8 steps), not with the dropdown values shown in the UI. Caching also runs
    # all of these generations when the app starts.
    gr.Examples(
        examples=[
            ["Focus: Eiffel Tower (Animate: Clouds moving)"],
            ["Focus: Trees In forest (Animate: Lion running)"],
            ["Focus: Astronaut in Space"],
            ["Focus: Group of Birds in sky (Animate: Birds Moving) (Shot From distance)"],
            ["Focus: Statue of liberty (Shot from Drone) (Animate: Drone coming toward statue)"],
            ["Focus: Panda in Forest (Animate: Drinking Tea)"],
            ["Focus: Kids Playing (Season: Winter)"],
            ["Focus: Cars in Street (Season: Rain, Daytime) (Shot from Distance) (Movement: Cars running)"],
        ],
        fn=generate_image,
        inputs=[prompt],
        outputs=video,
        cache_examples=True,
        elem_id="gr-examples",
    )

# Start the server only when run as a script (`python app.py`), not on import.
if __name__ == "__main__":
    demo.queue().launch()