|
import os |
|
import torch |
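# Reproducibility: the settings below pin the cuBLAS workspace size and force
# deterministic algorithms so repeated runs with the same seed match.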
|
|
|
|
|
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":16:8" |
|
torch.backends.cudnn.benchmark = False |
|
torch.use_deterministic_algorithms(True) |
|
import ptp_utils |
|
import random |
|
import abc |
|
import gradio as gr |
|
from glob import glob |
|
from einops import rearrange |
|
from omegaconf import OmegaConf |
|
from safetensors import safe_open |
|
from diffusers import AutoencoderKL |
|
from diffusers import DDIMScheduler |
|
from diffusers.utils.import_utils import is_xformers_available |
|
import transformers |
|
transformers.utils.move_cache() |
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
|
|
from animatediff.models.unet import UNet3DConditionModel |
|
from animatediff.pipelines.pipeline_animation import AnimationPipeline |
|
from animatediff.pipelines.pipeline_animation import AnimationCtrlPipeline |
|
from animatediff.utils.util import save_videos_grid |
|
from animatediff.utils.convert_from_ckpt import ( |
|
convert_ldm_unet_checkpoint, |
|
convert_ldm_clip_checkpoint, |
|
) |
|
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( |
|
convert_ldm_vae_checkpoint, |
|
) |
|
from diffusers.training_utils import set_seed |
|
|
|
|
|
pretrained_model_path = "./models/StableDiffusion" |
|
inference_config_path = "configs/inference/inference-v1.yaml" |
|
|
|
css = """ |
|
.toolbutton { |
|
    margin-bottom: 0em;
|
max-width: 2.5em; |
|
min-width: 2.5em !important; |
|
height: 2.5em; |
|
} |
|
""" |
|
|
|
|
|
class AttentionControl(abc.ABC): |
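    """Base hook for UniCtrl's unified attention control.

    Subclasses implement ``forward``, which receives the hidden states of a
    single sample (shape ``(1, f, d, c)``) and returns the states used for the
    spatial-attention-control (SAC) branch. ``__call__`` tracks the current
    denoising step and attention layer and splits the conditional/unconditional
    halves before delegating to ``forward``.
    """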
|
def step_callback(self, x_t): |
|
return x_t |
|
|
|
def between_steps(self): |
|
return |
|
|
|
@abc.abstractmethod |
|
def forward(self, attn, is_cross: bool, place_in_unet: str): |
|
raise NotImplementedError |
|
|
|
def __call__(self, hidden_states, video_length, place_in_unet: str): |
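        """Route attention hidden states through the controller.

        The flattened ``(b*f, d, c)`` tensor is reshaped per frame. For each of
        the unconditional/conditional halves, the motion branch comes from the
        second sample while ``cur_step <= motion_control_step`` (motion
        injection) and from the first sample afterwards, while ``forward``
        builds the SAC branch from the first sample. Returns the
        motion-injected states and the SAC states (the latter twice,
        presumably used as attention key and value downstream).
        """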
|
hidden_states = rearrange(hidden_states, "(b f) d c -> b f d c", f=video_length) |
|
batch_size = hidden_states.shape[0] // 2 |
|
|
|
if batch_size == 2: |
|
|
|
hidden_states_uncondition, hidden_states_condition = hidden_states.chunk(2) |
|
|
|
if self.cur_step <= self.motion_control_step: |
|
hidden_states_motion_uncondition = hidden_states_uncondition[ |
|
1 |
|
].unsqueeze(0) |
|
else: |
|
hidden_states_motion_uncondition = hidden_states_uncondition[ |
|
0 |
|
].unsqueeze(0) |
|
|
|
hidden_states_out_uncondition = torch.cat( |
|
[ |
|
hidden_states_motion_uncondition, |
|
hidden_states_uncondition[1].unsqueeze(0), |
|
], |
|
dim=0, |
|
) |
|
hidden_states_sac_in_uncondition = self.forward( |
|
hidden_states_uncondition[0].unsqueeze(0), video_length, place_in_unet |
|
) |
|
hidden_states_sac_out_uncondition = torch.cat( |
|
[ |
|
hidden_states_sac_in_uncondition, |
|
hidden_states_uncondition[1].unsqueeze(0), |
|
], |
|
dim=0, |
|
) |
|
|
|
if self.cur_step <= self.motion_control_step: |
|
hidden_states_motion_condition = hidden_states_condition[1].unsqueeze(0) |
|
else: |
|
hidden_states_motion_condition = hidden_states_condition[0].unsqueeze(0) |
|
|
|
hidden_states_out_condition = torch.cat( |
|
[ |
|
hidden_states_motion_condition, |
|
hidden_states_condition[1].unsqueeze(0), |
|
], |
|
dim=0, |
|
) |
|
hidden_states_sac_in_condition = self.forward( |
|
hidden_states_condition[0].unsqueeze(0), video_length, place_in_unet |
|
) |
|
hidden_states_sac_out_condition = torch.cat( |
|
[ |
|
hidden_states_sac_in_condition, |
|
hidden_states_condition[1].unsqueeze(0), |
|
], |
|
dim=0, |
|
) |
|
|
|
hidden_states_out = torch.cat( |
|
[hidden_states_out_uncondition, hidden_states_out_condition], dim=0 |
|
) |
|
hidden_states_sac_out = torch.cat( |
|
[hidden_states_sac_out_uncondition, hidden_states_sac_out_condition], |
|
dim=0, |
|
) |
|
|
|
elif batch_size == 1: |
|
if self.cur_step <= self.motion_control_step: |
|
hidden_states_motion = hidden_states[1].unsqueeze(0) |
|
else: |
|
hidden_states_motion = hidden_states[0].unsqueeze(0) |
|
|
|
hidden_states_out = torch.cat( |
|
[hidden_states_motion, hidden_states[1].unsqueeze(0)], dim=0 |
|
) |
|
hidden_states_sac_in = self.forward( |
|
hidden_states[0].unsqueeze(0), video_length, place_in_unet |
|
) |
|
hidden_states_sac_out = torch.cat( |
|
[hidden_states_sac_in, hidden_states[1].unsqueeze(0)], dim=0 |
|
) |
|
|
|
else: |
|
            raise gr.Error("Unsupported batch size: expected 1 or 2.")
|
hidden_states = rearrange(hidden_states, "b f d c -> (b f) d c", f=video_length) |
|
hidden_states_out = rearrange( |
|
hidden_states_out, "b f d c -> (b f) d c", f=video_length |
|
) |
|
hidden_states_sac_out = rearrange( |
|
hidden_states_sac_out, "b f d c -> (b f) d c", f=video_length |
|
) |
|
self.cur_att_layer += 1 |
|
if self.cur_att_layer == self.num_att_layers: |
|
self.cur_att_layer = 0 |
|
self.cur_step += 1 |
|
return hidden_states_out, hidden_states_sac_out, hidden_states_sac_out |
|
|
|
    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0
        self.num_att_layers = -1
        self.motion_control_step = 0

    def __init__(self):
        self.reset()
|
|
|
|
|
class EmptyControl(AttentionControl): |
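    """No-op controller: ``forward`` returns the hidden states unchanged (vanilla AnimateDiff)."""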
|
def forward(self, hidden_states, video_length, place_in_unet): |
|
return hidden_states |
|
|
|
|
|
class FreeSAC(AttentionControl): |
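    """Spatial attention control (SAC): every frame reuses the first frame's
    hidden states, encouraging frame-to-frame appearance consistency."""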
|
def forward(self, hidden_states, video_length, place_in_unet): |
|
hidden_states_sac = ( |
|
hidden_states[:, 0, :, :].unsqueeze(1).repeat(1, video_length, 1, 1) |
|
) |
|
return hidden_states_sac |
|
|
|
|
|
examples = [ |
|
|
|
[ |
|
"realisticVisionV60B1_v20Novae.safetensors", |
|
"mm_sd_v14.ckpt", |
|
"A panda standing on a surfboard in the ocean under moonlight.", |
|
"worst quality, low quality, nsfw, logo", |
|
0.2, |
|
512, |
|
512, |
|
"12345", |
|
["use_fp16"], |
|
], |
|
[ |
|
"toonyou_beta3.safetensors", |
|
"mm_sd_v14.ckpt", |
|
"(best quality, masterpiece), 1girl, looking at viewer, blurry background, upper body, contemporary, dress", |
|
"(worst quality, low quality)", |
|
0.2, |
|
512, |
|
512, |
|
"12345", |
|
["use_fp16"], |
|
], |
|
[ |
|
"lyriel_v16.safetensors", |
|
"mm_sd_v14.ckpt", |
|
"hypercars cyberpunk moving, muted colors, swirling color smokes, legend, cityscape, space", |
|
"3d, cartoon, anime, sketches, worst quality, low quality, nsfw, logo", |
|
0.2, |
|
512, |
|
512, |
|
"12345", |
|
["use_fp16"], |
|
], |
|
[ |
|
"rcnzCartoon3d_v10.safetensors", |
|
"mm_sd_v14.ckpt", |
|
"A cute raccoon playing guitar in a boat on the ocean", |
|
"worst quality, low quality, nsfw, logo", |
|
0.2, |
|
512, |
|
512, |
|
"42", |
|
["use_fp16"], |
|
], |
|
[ |
|
"majicmixRealistic_v5Preview.safetensors", |
|
"mm_sd_v14.ckpt", |
|
"1girl, reading book", |
|
"(ng_deepnegative_v1_75t:1.2), (badhandv4:1), (worst quality:2), (low quality:2), (normal quality:2), lowres, bad anatomy, bad hands, watermark, moles", |
|
0.2, |
|
512, |
|
512, |
|
"12345", |
|
["use_fp16"], |
|
], |
|
] |
|
|
|
|
|
print("### Cleaning cached examples ...")

os.system("rm -rf gradio_cached_examples/")
|
|
|
|
|
class AnimateController: |
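    """Owns the Stable Diffusion components (tokenizer, text encoder, VAE, 3D UNet),
    loads personalized DreamBooth/LoRA checkpoints and AnimateDiff motion modules,
    and exposes the generation entry points used by the Gradio UI."""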
|
def __init__(self): |
|
|
|
self.basedir = os.getcwd() |
|
self.stable_diffusion_dir = os.path.join( |
|
self.basedir, "models", "StableDiffusion" |
|
) |
|
self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module") |
|
self.personalized_model_dir = os.path.join( |
|
self.basedir, "models", "DreamBooth_LoRA" |
|
) |
|
self.savedir = os.path.join(self.basedir, "samples") |
|
os.makedirs(self.savedir, exist_ok=True) |
|
|
|
self.base_model_list = [None] |
|
self.motion_module_list = [] |
|
self.selected_base_model = None |
|
self.selected_motion_module = None |
|
self.set_width = None |
|
self.set_height = None |
|
|
|
self.refresh_motion_module() |
|
self.refresh_personalized_model() |
|
|
|
|
|
self.inference_config = OmegaConf.load(inference_config_path) |
|
|
|
self.tokenizer = CLIPTokenizer.from_pretrained( |
|
pretrained_model_path, subfolder="tokenizer" |
|
) |
|
self.text_encoder = CLIPTextModel.from_pretrained( |
|
pretrained_model_path, subfolder="text_encoder" |
|
).cuda() |
|
self.vae = AutoencoderKL.from_pretrained( |
|
pretrained_model_path, subfolder="vae" |
|
).cuda() |
|
self.unet = UNet3DConditionModel.from_pretrained_2d( |
|
pretrained_model_path, |
|
subfolder="unet", |
|
unet_additional_kwargs=OmegaConf.to_container( |
|
self.inference_config.unet_additional_kwargs |
|
), |
|
).cuda() |
|
|
|
self.freq_filter = None |
|
|
|
self.update_base_model(self.base_model_list[-2]) |
|
self.update_motion_module(self.motion_module_list[0]) |
|
|
|
def refresh_motion_module(self): |
|
motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt")) |
|
self.motion_module_list = sorted( |
|
[os.path.basename(p) for p in motion_module_list] |
|
) |
|
|
|
def refresh_personalized_model(self): |
|
base_model_list = glob( |
|
os.path.join(self.personalized_model_dir, "*.safetensors") |
|
) |
|
self.base_model_list += sorted([os.path.basename(p) for p in base_model_list]) |
|
|
|
def update_base_model(self, base_model_dropdown): |
|
self.selected_base_model = base_model_dropdown |
|
if base_model_dropdown == "None" or base_model_dropdown is None: |
|
return gr.Dropdown.update() |
|
|
|
base_model_dropdown = os.path.join( |
|
self.personalized_model_dir, base_model_dropdown |
|
) |
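        # Read the .safetensors checkpoint, then convert its LDM-format weights
        # into diffusers format for the VAE, UNet and text encoder.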
|
base_model_state_dict = {} |
|
with safe_open(base_model_dropdown, framework="pt", device="cpu") as f: |
|
for key in f.keys(): |
|
base_model_state_dict[key] = f.get_tensor(key) |
|
|
|
converted_vae_checkpoint = convert_ldm_vae_checkpoint( |
|
base_model_state_dict, self.vae.config |
|
) |
|
self.vae.load_state_dict(converted_vae_checkpoint) |
|
|
|
converted_unet_checkpoint = convert_ldm_unet_checkpoint( |
|
base_model_state_dict, self.unet.config |
|
) |
|
self.unet.load_state_dict(converted_unet_checkpoint, strict=False) |
|
|
|
self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict) |
|
return gr.Dropdown.update() |
|
|
|
def update_motion_module(self, motion_module_dropdown): |
|
self.selected_motion_module = motion_module_dropdown |
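        # Load the AnimateDiff motion module into the UNet; strict=False because
        # the checkpoint only contains the temporal (motion) layers.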
|
|
|
motion_module_dropdown = os.path.join( |
|
self.motion_module_dir, motion_module_dropdown |
|
) |
|
motion_module_state_dict = torch.load( |
|
motion_module_dropdown, map_location="cpu" |
|
) |
|
_, unexpected = self.unet.load_state_dict( |
|
motion_module_state_dict, strict=False |
|
) |
|
        assert len(unexpected) == 0, f"Unexpected keys in motion module checkpoint: {unexpected}"
|
return gr.Dropdown.update() |
|
|
|
def run_pipeline(self, pipeline, args): |
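        # Make sure PyTorch's CUDA state is initialized before running the pipeline.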
|
|
|
torch.cuda.init() |
|
|
|
return pipeline(**args) |
|
|
|
def animate_ctrl( |
|
self, |
|
base_model_dropdown, |
|
motion_module_dropdown, |
|
prompt_textbox, |
|
negative_prompt_textbox, |
|
motion_control, |
|
width_slider, |
|
height_slider, |
|
seed_textbox, |
|
|
|
speed_up_options, |
|
): |
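        """Generate a video with UniCtrl enabled: FreeSAC attention control plus
        motion injection for the first `motion_control` fraction of the
        denoising steps."""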
|
set_seed(42) |
|
inference_step = 25 |
|
|
|
if self.selected_base_model != base_model_dropdown: |
|
self.update_base_model(base_model_dropdown) |
|
if self.selected_motion_module != motion_module_dropdown: |
|
self.update_motion_module(motion_module_dropdown) |
|
|
|
if is_xformers_available(): |
|
self.unet.enable_xformers_memory_efficient_attention() |
|
|
|
        if seed_textbox and int(seed_textbox) > 0:
            seed = int(seed_textbox)
        else:
            seed = random.randint(0, 2**32 - 1)
|
torch.manual_seed(int(seed)) |
|
|
|
assert seed == torch.initial_seed() |
|
print(f"### seed: {seed}") |
|
|
|
generator = torch.Generator(device="cuda:0") |
|
generator.manual_seed(seed) |
|
|
|
pipeline = AnimationCtrlPipeline( |
|
vae=self.vae, |
|
text_encoder=self.text_encoder, |
|
tokenizer=self.tokenizer, |
|
unet=self.unet, |
|
scheduler=DDIMScheduler( |
|
**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs) |
|
), |
|
).to("cuda") |
|
|
|
motion_control_step = motion_control * inference_step |
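        # Hook every attention layer with FreeSAC; motion is injected while the
        # denoising step is at or below `motion_control_step`.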
|
|
|
attn_controller = FreeSAC() |
|
attn_controller.motion_control_step = motion_control_step |
|
ptp_utils.register_attention_control(pipeline, attn_controller) |
|
|
|
sample_output_ctrl = pipeline( |
|
prompt_textbox, |
|
negative_prompt=negative_prompt_textbox, |
|
num_inference_steps=inference_step, |
|
guidance_scale=7.5, |
|
width=width_slider, |
|
height=height_slider, |
|
video_length=16, |
|
            use_fp16=bool(speed_up_options and "use_fp16" in speed_up_options),
|
generator=generator, |
|
) |
|
|
|
ctrl_sample = sample_output_ctrl.videos |
|
|
|
save_ctrl_sample_path = os.path.join(self.savedir, "ctrl_sample.mp4") |
|
save_videos_grid(ctrl_sample, save_ctrl_sample_path) |
|
|
|
json_config = { |
|
"prompt": prompt_textbox, |
|
"n_prompt": negative_prompt_textbox, |
|
"width": width_slider, |
|
"height": height_slider, |
|
"seed": seed, |
|
"base_model": base_model_dropdown, |
|
"motion_module": motion_module_dropdown, |
|
"use_fp16": True if "use_fp16" in speed_up_options else False, |
|
} |
|
|
|
del attn_controller |
|
del pipeline |
|
torch.cuda.empty_cache() |
|
return ( |
|
gr.Video.update(value=save_ctrl_sample_path), |
|
gr.Json.update(value=json_config), |
|
) |
|
|
|
def animate( |
|
self, |
|
base_model_dropdown, |
|
motion_module_dropdown, |
|
prompt_textbox, |
|
negative_prompt_textbox, |
|
motion_control, |
|
width_slider, |
|
height_slider, |
|
seed_textbox, |
|
|
|
        speed_up_options,
        filter_type_dropdown=None,
|
): |
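        """Generate a baseline video with plain AnimateDiff (no attention control)."""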
|
|
|
set_seed(42) |
|
|
|
inference_step = 25 |
|
|
|
if self.selected_base_model != base_model_dropdown: |
|
self.update_base_model(base_model_dropdown) |
|
if self.selected_motion_module != motion_module_dropdown: |
|
self.update_motion_module(motion_module_dropdown) |
|
|
|
if is_xformers_available(): |
|
self.unet.enable_xformers_memory_efficient_attention() |
|
|
|
if seed_textbox and int(seed_textbox) >= 0: |
|
seed = int(seed_textbox) |
|
else: |
|
seed = random.randint(0, 2**32 - 1) |
|
torch.manual_seed(int(seed)) |
|
|
|
assert seed == torch.initial_seed() |
|
print(f"seed: {seed}") |
|
|
|
generator = torch.Generator(device="cuda:0") |
|
generator.manual_seed(seed) |
|
|
|
pipeline = AnimationPipeline( |
|
vae=self.vae, |
|
text_encoder=self.text_encoder, |
|
tokenizer=self.tokenizer, |
|
unet=self.unet, |
|
scheduler=DDIMScheduler( |
|
**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs) |
|
), |
|
).to("cuda") |
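        # EmptyControl leaves attention untouched, so this pass reproduces the
        # vanilla AnimateDiff result for side-by-side comparison.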
|
|
|
attn_controller = EmptyControl() |
|
attn_controller.motion_control_step = -1 |
|
ptp_utils.register_attention_control(pipeline, attn_controller) |
|
|
|
sample_output_orig = pipeline( |
|
prompt_textbox, |
|
negative_prompt=negative_prompt_textbox, |
|
num_inference_steps=inference_step, |
|
guidance_scale=7.5, |
|
width=width_slider, |
|
height=height_slider, |
|
video_length=16, |
|
use_fp16=( |
|
True if speed_up_options and "use_fp16" in speed_up_options else False |
|
), |
|
generator=generator, |
|
) |
|
|
|
orig_sample = sample_output_orig.videos |
|
|
|
save_orig_sample_path = os.path.join(self.savedir, "orig_sample.mp4") |
|
save_videos_grid(orig_sample, save_orig_sample_path) |
|
|
|
json_config = { |
|
"prompt": prompt_textbox, |
|
"n_prompt": negative_prompt_textbox, |
|
"width": width_slider, |
|
"height": height_slider, |
|
"seed": seed, |
|
"base_model": base_model_dropdown, |
|
"motion_module": motion_module_dropdown, |
|
"filter_type": filter_type_dropdown, |
|
"use_fp16": ( |
|
True if speed_up_options and "use_fp16" in speed_up_options else False |
|
), |
|
} |
|
del pipeline |
|
torch.cuda.empty_cache() |
|
|
|
return ( |
|
gr.Video.update(value=save_orig_sample_path), |
|
gr.Json.update(value=json_config), |
|
) |
|
|
|
|
|
controller = AnimateController() |
|
|
|
|
|
def ui(): |
|
with gr.Blocks(css=css) as demo: |
|
|
|
gr.Markdown( |
|
""" |
|
<div align="center"> |
|
<h1>UniCtrl: Improving the Spatiotemporal Consistency of Text-to-Video Diffusion Models via Training-Free Unified Attention Control</h1> |
|
</div> |
|
""" |
|
) |
|
gr.Markdown( |
|
""" |
|
<p align="center"> |
|
<a title="Project Page" href="https://unified-attention-control.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> |
|
<img src="https://img.shields.io/badge/Project-Website-5B7493?logo=googlechrome&logoColor=5B7493"> |
|
</a> |
|
<a title="arXiv" href="https://arxiv.org/abs/2312.07537" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> |
|
<img src="https://img.shields.io/badge/arXiv-Paper-b31b1b?logo=arxiv&logoColor=b31b1b"> |
|
</a> |
|
<a title="GitHub" href="https://github.com/XuweiyiChen/UniCtrl" target="_blank" rel="noopener noreferrer" style="display: inline-block;"> |
|
<img src="https://img.shields.io/github/stars/XuweiyiChen/UniCtrl?label=GitHub%E2%98%85&&logo=github" alt="badge-github-stars"> |
|
</a> |
|
</p> |
|
""" |
|
) |
|
gr.Markdown( |
|
""" |
|
Official Gradio Demo for ***UniCtrl: Improving the Spatiotemporal Consistency of Text-to-Video Diffusion Models via Training-Free Unified Attention Control***. |
|
            UniCtrl improves the spatiotemporal consistency of diffusion-based video generation at inference time. In this demo, we apply UniCtrl to [AnimateDiff v1](https://github.com/guoyww/AnimateDiff) as an example. Sampling time: ~80s.<br>
|
""" |
|
) |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
prompt_textbox = gr.Textbox( |
|
label="Prompt", lines=3, placeholder="Enter your prompt here" |
|
) |
|
negative_prompt_textbox = gr.Textbox( |
|
label="Negative Prompt", |
|
lines=3, |
|
value="worst quality, low quality, nsfw, logo", |
|
) |
|
motion_control = gr.Slider( |
|
label="Motion Injection Degree", |
|
value=0.2, |
|
minimum=0, |
|
maximum=1, |
|
step=0.1, |
|
info="Motion Control Strength", |
|
) |
|
|
|
gr.Markdown( |
|
""" |
|
*Prompt Tips:* |
|
|
|
For each personalized model in `Model Settings`, you can refer to their webpage on CivitAI to learn how to write good prompts for them: |
|
                    - [`realisticVisionV60B1_v20Novae.safetensors`](https://civitai.com/models/4201?modelVersionId=130072)
|
- [`toonyou_beta3.safetensors`](https://civitai.com/models/30240?modelVersionId=78775) |
|
- [`lyriel_v16.safetensors`](https://civitai.com/models/22922/lyriel) |
|
- [`rcnzCartoon3d_v10.safetensors`](https://civitai.com/models/66347?modelVersionId=71009) |
|
- [`majicmixRealistic_v5Preview.safetensors`](https://civitai.com/models/43331?modelVersionId=79068) |
|
""" |
|
) |
|
|
|
with gr.Accordion("Model Settings", open=False): |
|
gr.Markdown( |
|
""" |
|
Select personalized model and motion module for AnimateDiff. |
|
""" |
|
) |
|
base_model_dropdown = gr.Dropdown( |
|
label="Base DreamBooth Model", |
|
choices=controller.base_model_list, |
|
value=controller.base_model_list[-2], |
|
interactive=True, |
|
info="Select personalized text-to-image model from community", |
|
) |
|
motion_module_dropdown = gr.Dropdown( |
|
label="Motion Module", |
|
choices=controller.motion_module_list, |
|
value=controller.motion_module_list[0], |
|
interactive=True, |
|
info="Select motion module. Recommend mm_sd_v14.ckpt for larger movements.", |
|
) |
|
|
|
base_model_dropdown.change( |
|
fn=controller.update_base_model, |
|
inputs=[base_model_dropdown], |
|
outputs=[base_model_dropdown], |
|
) |
|
motion_module_dropdown.change( |
|
fn=controller.update_motion_module, |
|
inputs=[motion_module_dropdown], |
|
outputs=[base_model_dropdown], |
|
) |
|
|
|
with gr.Accordion("Advance", open=False): |
|
with gr.Row(): |
|
width_slider = gr.Slider( |
|
label="Width", value=512, minimum=256, maximum=1024, step=64 |
|
) |
|
height_slider = gr.Slider( |
|
label="Height", |
|
value=512, |
|
minimum=256, |
|
maximum=1024, |
|
step=64, |
|
) |
|
with gr.Row(): |
|
seed_textbox = gr.Textbox(label="Seed", value=442) |
|
seed_button = gr.Button( |
|
value="\U0001F3B2", elem_classes="toolbutton" |
|
) |
|
seed_button.click( |
|
                            fn=lambda: gr.Textbox.update(value=random.randint(1, 10**9)),
|
inputs=[], |
|
outputs=[seed_textbox], |
|
) |
|
with gr.Row(): |
|
speed_up_options = gr.CheckboxGroup( |
|
["use_fp16"], |
|
label="Speed-Up Options", |
|
value=["use_fp16"], |
|
) |
|
|
|
with gr.Column(): |
|
with gr.Row(): |
|
orig_video = gr.Video(label="AnimateDiff", interactive=False) |
|
ctrl_video = gr.Video( |
|
label="AnimateDiff + UniCtrl", interactive=False |
|
) |
|
with gr.Row(): |
|
generate_button = gr.Button( |
|
value="Generate Original", variant="primary" |
|
) |
|
generate_button_ctr = gr.Button( |
|
value="Generate UniCtrl", variant="primary" |
|
) |
|
with gr.Row(): |
|
json_config = gr.Json(label="Config", value=None) |
|
|
|
inputs = [ |
|
base_model_dropdown, |
|
motion_module_dropdown, |
|
prompt_textbox, |
|
negative_prompt_textbox, |
|
motion_control, |
|
width_slider, |
|
height_slider, |
|
seed_textbox, |
|
speed_up_options, |
|
] |
|
|
|
generate_button.click( |
|
fn=controller.animate, inputs=inputs, outputs=[orig_video, json_config] |
|
) |
|
generate_button_ctr.click( |
|
fn=controller.animate_ctrl, |
|
inputs=inputs, |
|
outputs=[ctrl_video, json_config], |
|
) |
|
|
|
gr.Examples( |
|
fn=controller.animate_ctrl, |
|
examples=examples, |
|
inputs=inputs, |
|
outputs=[ctrl_video, json_config], |
|
cache_examples=True, |
|
) |
|
|
|
return demo |
|
|
|
|
|
if __name__ == "__main__": |
|
demo = ui() |
|
demo.queue(max_size=20) |
|
demo.launch(server_name="localhost", share=True) |
|
|