guoyww commited on
Commit
24081dd
·
1 Parent(s): c4e73e7
Files changed (2) hide show
  1. app.py +191 -225
  2. requirements.txt +7 -5
app.py CHANGED
@@ -1,17 +1,14 @@
1
-
2
  import os
3
- import json
4
  import torch
5
  import random
6
 
7
  import gradio as gr
8
  from glob import glob
9
  from omegaconf import OmegaConf
10
- from datetime import datetime
11
  from safetensors import safe_open
12
 
13
  from diffusers import AutoencoderKL
14
- from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
15
  from diffusers.utils.import_utils import is_xformers_available
16
  from transformers import CLIPTextModel, CLIPTokenizer
17
 
@@ -19,15 +16,10 @@ from animatediff.models.unet import UNet3DConditionModel
19
  from animatediff.pipelines.pipeline_animation import AnimationPipeline
20
  from animatediff.utils.util import save_videos_grid
21
  from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
22
- from animatediff.utils.convert_lora_safetensor_to_diffusers import convert_lora
23
 
24
 
25
- sample_idx = 0
26
- scheduler_dict = {
27
- "Euler": EulerDiscreteScheduler,
28
- "PNDM": PNDMScheduler,
29
- "DDIM": DDIMScheduler,
30
- }
31
 
32
  css = """
33
  .toolbutton {
@@ -38,6 +30,78 @@ css = """
38
  }
39
  """
40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  class AnimateController:
42
  def __init__(self):
43
 
@@ -46,156 +110,120 @@ class AnimateController:
46
  self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
47
  self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
48
  self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
49
- self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
50
- self.savedir_sample = os.path.join(self.savedir, "sample")
51
  os.makedirs(self.savedir, exist_ok=True)
52
 
53
- self.stable_diffusion_list = []
54
- self.motion_module_list = []
55
- self.personalized_model_list = []
 
 
56
 
57
- self.refresh_stable_diffusion()
58
  self.refresh_motion_module()
59
  self.refresh_personalized_model()
60
 
61
  # config models
62
- self.tokenizer = None
63
- self.text_encoder = None
64
- self.vae = None
65
- self.unet = None
66
- self.pipeline = None
67
- self.lora_model_state_dict = {}
68
-
69
- self.inference_config = OmegaConf.load("configs/inference/inference.yaml")
70
-
71
- def refresh_stable_diffusion(self):
72
- self.stable_diffusion_list = glob(os.path.join(self.stable_diffusion_dir, "*/"))
73
 
 
 
 
 
 
 
 
 
 
74
  def refresh_motion_module(self):
75
  motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
76
  self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
77
 
78
  def refresh_personalized_model(self):
79
- personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
80
- self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
81
-
82
- def update_stable_diffusion(self, stable_diffusion_dropdown):
83
- self.tokenizer = CLIPTokenizer.from_pretrained(stable_diffusion_dropdown, subfolder="tokenizer")
84
- self.text_encoder = CLIPTextModel.from_pretrained(stable_diffusion_dropdown, subfolder="text_encoder").cuda()
85
- self.vae = AutoencoderKL.from_pretrained(stable_diffusion_dropdown, subfolder="vae").cuda()
86
- self.unet = UNet3DConditionModel.from_pretrained_2d(stable_diffusion_dropdown, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
87
- return gr.Dropdown.update()
88
 
89
- def update_motion_module(self, motion_module_dropdown):
90
- if self.unet is None:
91
- gr.Info(f"Please select a pretrained model path.")
92
- return gr.Dropdown.update(value=None)
93
- else:
94
- motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
95
- motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
96
- missing, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
97
- assert len(unexpected) == 0
98
- return gr.Dropdown.update()
99
 
100
  def update_base_model(self, base_model_dropdown):
101
- if self.unet is None:
102
- gr.Info(f"Please select a pretrained model path.")
103
- return gr.Dropdown.update(value=None)
104
- else:
105
- base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
106
- base_model_state_dict = {}
107
- with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
108
- for key in f.keys():
109
- base_model_state_dict[key] = f.get_tensor(key)
110
-
111
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config)
112
- self.vae.load_state_dict(converted_vae_checkpoint)
113
-
114
- converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config)
115
- self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
116
-
117
- self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
118
- return gr.Dropdown.update()
119
-
120
- def update_lora_model(self, lora_model_dropdown):
121
- lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
122
- self.lora_model_state_dict = {}
123
- if lora_model_dropdown == "none": pass
124
- else:
125
- with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
126
- for key in f.keys():
127
- self.lora_model_state_dict[key] = f.get_tensor(key)
128
  return gr.Dropdown.update()
129
 
 
 
 
 
 
 
 
 
 
 
130
  def animate(
131
  self,
132
- stable_diffusion_dropdown,
133
- motion_module_dropdown,
134
  base_model_dropdown,
135
- lora_alpha_slider,
136
- prompt_textbox,
137
- negative_prompt_textbox,
138
- sampler_dropdown,
139
- sample_step_slider,
140
- width_slider,
141
- length_slider,
142
- height_slider,
143
- cfg_scale_slider,
144
- seed_textbox
145
- ):
146
- if self.unet is None:
147
- raise gr.Error(f"Please select a pretrained model path.")
148
- if motion_module_dropdown == "":
149
- raise gr.Error(f"Please select a motion module.")
150
- if base_model_dropdown == "":
151
- raise gr.Error(f"Please select a base DreamBooth model.")
152
-
153
  if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
154
 
155
  pipeline = AnimationPipeline(
156
  vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
157
- scheduler=scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
158
  ).to("cuda")
159
 
160
- if self.lora_model_state_dict != {}:
161
- pipeline = convert_lora(pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)
162
-
163
- pipeline.to("cuda")
164
-
165
- if seed_textbox != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
166
- else: torch.seed()
167
- seed = torch.initial_seed()
 
168
 
169
  sample = pipeline(
170
  prompt_textbox,
171
  negative_prompt = negative_prompt_textbox,
172
- num_inference_steps = sample_step_slider,
173
- guidance_scale = cfg_scale_slider,
174
  width = width_slider,
175
  height = height_slider,
176
- video_length = length_slider,
 
177
  ).videos
178
 
179
- save_sample_path = os.path.join(self.savedir_sample, f"{sample_idx}.mp4")
180
  save_videos_grid(sample, save_sample_path)
181
 
182
- sample_config = {
183
  "prompt": prompt_textbox,
184
  "n_prompt": negative_prompt_textbox,
185
- "sampler": sampler_dropdown,
186
- "num_inference_steps": sample_step_slider,
187
- "guidance_scale": cfg_scale_slider,
188
  "width": width_slider,
189
  "height": height_slider,
190
- "video_length": length_slider,
191
- "seed": seed
 
192
  }
193
- json_str = json.dumps(sample_config, indent=4)
194
- with open(os.path.join(self.savedir, "logs.json"), "a") as f:
195
- f.write(json_str)
196
- f.write("\n\n")
197
-
198
- return gr.Video.update(value=save_sample_path)
199
 
200
 
201
  controller = AnimateController()
@@ -205,124 +233,62 @@ def ui():
205
  with gr.Blocks(css=css) as demo:
206
  gr.Markdown(
207
  """
208
- # [AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning](https://arxiv.org/abs/2307.04725)
209
  Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai (*Corresponding Author)<br>
210
  [Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) | [Github](https://github.com/guoyww/animatediff/)
211
  """
212
  )
213
- with gr.Column(variant="panel"):
214
- gr.Markdown(
215
- """
216
- ### 1. Model checkpoints (select pretrained model path first).
217
- """
218
- )
219
- with gr.Row():
220
- stable_diffusion_dropdown = gr.Dropdown(
221
- label="Pretrained Model Path",
222
- choices=controller.stable_diffusion_list,
223
- interactive=True,
224
- )
225
- stable_diffusion_dropdown.change(fn=controller.update_stable_diffusion, inputs=[stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])
226
-
227
- stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
228
- def update_stable_diffusion():
229
- controller.refresh_stable_diffusion()
230
- return gr.Dropdown.update(choices=controller.stable_diffusion_list)
231
- stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[stable_diffusion_dropdown])
232
-
233
- with gr.Row():
234
- motion_module_dropdown = gr.Dropdown(
235
- label="Select motion module",
236
- choices=controller.motion_module_list,
237
- interactive=True,
238
- )
239
  motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])
240
-
241
- motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
242
- def update_motion_module():
243
- controller.refresh_motion_module()
244
- return gr.Dropdown.update(choices=controller.motion_module_list)
245
- motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])
246
-
247
- base_model_dropdown = gr.Dropdown(
248
- label="Select base Dreambooth model (required)",
249
- choices=controller.personalized_model_list,
250
- interactive=True,
251
- )
252
- base_model_dropdown.change(fn=controller.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])
253
-
254
- lora_model_dropdown = gr.Dropdown(
255
- label="Select LoRA model (optional)",
256
- choices=["none"] + controller.personalized_model_list,
257
- value="none",
258
- interactive=True,
259
- )
260
- lora_model_dropdown.change(fn=controller.update_lora_model, inputs=[lora_model_dropdown], outputs=[lora_model_dropdown])
261
-
262
- lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
263
-
264
- personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
265
- def update_personalized_model():
266
- controller.refresh_personalized_model()
267
- return [
268
- gr.Dropdown.update(choices=controller.personalized_model_list),
269
- gr.Dropdown.update(choices=["none"] + controller.personalized_model_list)
270
- ]
271
- personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
272
-
273
- with gr.Column(variant="panel"):
274
- gr.Markdown(
275
- """
276
- ### 2. Configs for AnimateDiff.
277
- """
278
- )
279
-
280
- prompt_textbox = gr.Textbox(label="Prompt", lines=2)
281
- negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2)
282
-
283
- with gr.Row().style(equal_height=False):
284
- with gr.Column():
285
  with gr.Row():
286
- sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
287
- sample_step_slider = gr.Slider(label="Sampling steps", value=25, minimum=10, maximum=100, step=1)
288
-
289
- width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
290
- height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
291
- length_slider = gr.Slider(label="Animation length", value=16, minimum=8, maximum=24, step=1)
292
- cfg_scale_slider = gr.Slider(label="CFG Scale", value=7.5, minimum=0, maximum=20)
293
-
294
  with gr.Row():
295
- seed_textbox = gr.Textbox(label="Seed", value=-1)
296
  seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
297
- seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e8)), inputs=[], outputs=[seed_textbox])
298
-
299
- generate_button = gr.Button(value="Generate", variant='primary')
300
-
301
- result_video = gr.Video(label="Generated Animation", interactive=False)
302
-
303
- generate_button.click(
304
- fn=controller.animate,
305
- inputs=[
306
- stable_diffusion_dropdown,
307
- motion_module_dropdown,
308
- base_model_dropdown,
309
- lora_alpha_slider,
310
- prompt_textbox,
311
- negative_prompt_textbox,
312
- sampler_dropdown,
313
- sample_step_slider,
314
- width_slider,
315
- length_slider,
316
- height_slider,
317
- cfg_scale_slider,
318
- seed_textbox,
319
- ],
320
- outputs=[result_video]
321
- )
322
 
 
 
 
 
323
  return demo
324
 
325
 
326
  if __name__ == "__main__":
327
  demo = ui()
 
328
  demo.launch()
 
 
1
  import os
 
2
  import torch
3
  import random
4
 
5
  import gradio as gr
6
  from glob import glob
7
  from omegaconf import OmegaConf
 
8
  from safetensors import safe_open
9
 
10
  from diffusers import AutoencoderKL
11
+ from diffusers import EulerDiscreteScheduler, DDIMScheduler
12
  from diffusers.utils.import_utils import is_xformers_available
13
  from transformers import CLIPTextModel, CLIPTokenizer
14
 
 
16
  from animatediff.pipelines.pipeline_animation import AnimationPipeline
17
  from animatediff.utils.util import save_videos_grid
18
  from animatediff.utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
 
19
 
20
 
21
+ pretrained_model_path = "models/StableDiffusion/stable-diffusion-v1-5"
22
+ inference_config_path = "configs/inference/inference.yaml"
 
 
 
 
23
 
24
  css = """
25
  .toolbutton {
 
30
  }
31
  """
32
 
33
+ examples = [
34
+ # 1-ToonYou
35
+ [
36
+ "toonyou_beta3.safetensors",
37
+ "mm_sd_v14.ckpt",
38
+ "masterpiece, best quality, 1girl, solo, cherry blossoms, hanami, pink flower, white flower, spring season, wisteria, petals, flower, plum blossoms, outdoors, falling petals, white hair, black eyes",
39
+ "worst quality, low quality, nsfw, logo",
40
+ 512, 512, "13204175718326964000"
41
+ ],
42
+ # 2-Lyriel
43
+ [
44
+ "lyriel_v16.safetensors",
45
+ "mm_sd_v15.ckpt",
46
+ "A forbidden castle high up in the mountains, pixel art, intricate details2, hdr, intricate details, hyperdetailed5, natural skin texture, hyperrealism, soft light, sharp, game art, key visual, surreal",
47
+ "3d, cartoon, anime, sketches, worst quality, low quality, normal quality, lowres, normal quality, monochrome, grayscale, skin spots, acnes, skin blemishes, bad anatomy, girl, loli, young, large breasts, red eyes, muscular",
48
+ 512, 512, "6681501646976930000"
49
+ ],
50
+ # 3-RCNZ
51
+ [
52
+ "rcnzCartoon3d_v10.safetensors",
53
+ "mm_sd_v14.ckpt",
54
+ "Jane Eyre with headphones, natural skin texture,4mm,k textures, soft cinematic light, adobe lightroom, photolab, hdr, intricate, elegant, highly detailed, sharp focus, cinematic look, soothing tones, insane details, intricate details, hyperdetailed, low contrast, soft cinematic light, dim colors, exposure blend, hdr, faded",
55
+ "deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, mutated hands and fingers, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation",
56
+ 512, 512, "2416282124261060"
57
+ ],
58
+ # 4-MajicMix
59
+ [
60
+ "majicmixRealistic_v5Preview.safetensors",
61
+ "mm_sd_v14.ckpt",
62
+ "1girl, offshoulder, light smile, shiny skin best quality, masterpiece, photorealistic",
63
+ "bad hand, worst quality, low quality, normal quality, lowres, bad anatomy, bad hands, watermark, moles",
64
+ 512, 512, "7132772652786303"
65
+ ],
66
+ # 5-RealisticVision
67
+ [
68
+ "realisticVisionV40_v20Novae.safetensors",
69
+ "mm_sd_v15.ckpt",
70
+ "photo of coastline, rocks, storm weather, wind, waves, lightning, 8k uhd, dslr, soft lighting, high quality, film grain, Fujifilm XT3",
71
+ "blur, haze, deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime, mutated hands and fingers, deformed, distorted, disfigured, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, amputation",
72
+ 512, 512, "1490157606650685400"
73
+ ]
74
+ ]
75
+
76
+ # clean unrelated ckpts
77
+ ckpts = [
78
+ "realisticVisionV40_v20Novae.safetensors",
79
+ "majicmixRealistic_v5Preview.safetensors",
80
+ "rcnzCartoon3d_v10.safetensors",
81
+ "lyriel_v16.safetensors",
82
+ "toonyou_beta3.safetensors"
83
+ ]
84
+
85
+ for path in glob(os.path.join("models", "DreamBooth_LoRA", "*.safetensors")):
86
+ for ckpt in ckpts:
87
+ if path.endswith(ckpt): break
88
+ else:
89
+ print(f"### Cleaning {path} ...")
90
+ os.system(f"rm -rf {path}")
91
+
92
+ # os.system(f"rm -rf {os.path.join('models', 'DreamBooth_LoRA', '*.safetensors')}")
93
+
94
+ # os.system(f"bash download_bashscripts/1-ToonYou.sh")
95
+ # os.system(f"bash download_bashscripts/2-Lyriel.sh")
96
+ # os.system(f"bash download_bashscripts/3-RcnzCartoon.sh")
97
+ # os.system(f"bash download_bashscripts/4-MajicMix.sh")
98
+ # os.system(f"bash download_bashscripts/5-RealisticVision.sh")
99
+
100
+ # clean Grdio cache
101
+ print(f"### Cleaning cached examples ...")
102
+ os.system(f"rm -rf gradio_cached_examples/")
103
+
104
+
105
  class AnimateController:
106
  def __init__(self):
107
 
 
110
  self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
111
  self.motion_module_dir = os.path.join(self.basedir, "models", "Motion_Module")
112
  self.personalized_model_dir = os.path.join(self.basedir, "models", "DreamBooth_LoRA")
113
+ self.savedir = os.path.join(self.basedir, "samples")
 
114
  os.makedirs(self.savedir, exist_ok=True)
115
 
116
+ self.base_model_list = []
117
+ self.motion_module_list = []
118
+
119
+ self.selected_base_model = None
120
+ self.selected_motion_module = None
121
 
 
122
  self.refresh_motion_module()
123
  self.refresh_personalized_model()
124
 
125
  # config models
126
+ self.inference_config = OmegaConf.load(inference_config_path)
 
 
 
 
 
 
 
 
 
 
127
 
128
+ self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
129
+ self.text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").cuda()
130
+ self.vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").cuda()
131
+ self.unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", unet_additional_kwargs=OmegaConf.to_container(self.inference_config.unet_additional_kwargs)).cuda()
132
+
133
+ self.update_base_model(self.base_model_list[0])
134
+ self.update_motion_module(self.motion_module_list[0])
135
+
136
+
137
  def refresh_motion_module(self):
138
  motion_module_list = glob(os.path.join(self.motion_module_dir, "*.ckpt"))
139
  self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
140
 
141
  def refresh_personalized_model(self):
142
+ base_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
143
+ self.base_model_list = [os.path.basename(p) for p in base_model_list]
 
 
 
 
 
 
 
144
 
 
 
 
 
 
 
 
 
 
 
145
 
146
  def update_base_model(self, base_model_dropdown):
147
+ self.selected_base_model = base_model_dropdown
148
+
149
+ base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
150
+ base_model_state_dict = {}
151
+ with safe_open(base_model_dropdown, framework="pt", device="cpu") as f:
152
+ for key in f.keys(): base_model_state_dict[key] = f.get_tensor(key)
153
+
154
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(base_model_state_dict, self.vae.config)
155
+ self.vae.load_state_dict(converted_vae_checkpoint)
156
+
157
+ converted_unet_checkpoint = convert_ldm_unet_checkpoint(base_model_state_dict, self.unet.config)
158
+ self.unet.load_state_dict(converted_unet_checkpoint, strict=False)
159
+
160
+ self.text_encoder = convert_ldm_clip_checkpoint(base_model_state_dict)
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  return gr.Dropdown.update()
162
 
163
+ def update_motion_module(self, motion_module_dropdown):
164
+ self.selected_motion_module = motion_module_dropdown
165
+
166
+ motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
167
+ motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
168
+ _, unexpected = self.unet.load_state_dict(motion_module_state_dict, strict=False)
169
+ assert len(unexpected) == 0
170
+ return gr.Dropdown.update()
171
+
172
+
173
  def animate(
174
  self,
 
 
175
  base_model_dropdown,
176
+ motion_module_dropdown,
177
+ prompt_textbox,
178
+ negative_prompt_textbox,
179
+ width_slider,
180
+ height_slider,
181
+ seed_textbox,
182
+ ):
183
+ if self.selected_base_model != base_model_dropdown: self.update_base_model(base_model_dropdown)
184
+ if self.selected_motion_module != motion_module_dropdown: self.update_motion_module(motion_module_dropdown)
185
+
 
 
 
 
 
 
 
 
186
  if is_xformers_available(): self.unet.enable_xformers_memory_efficient_attention()
187
 
188
  pipeline = AnimationPipeline(
189
  vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, unet=self.unet,
190
+ scheduler=DDIMScheduler(**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
191
  ).to("cuda")
192
 
193
+ if int(seed_textbox) > 0: seed = int(seed_textbox)
194
+ else: seed = random.randint(1, 1e16)
195
+ torch.manual_seed(int(seed))
196
+
197
+ assert seed == torch.initial_seed()
198
+ print(f"### seed: {seed}")
199
+
200
+ generator = torch.Generator(device="cuda")
201
+ generator.manual_seed(seed)
202
 
203
  sample = pipeline(
204
  prompt_textbox,
205
  negative_prompt = negative_prompt_textbox,
206
+ num_inference_steps = 25,
207
+ guidance_scale = 8.,
208
  width = width_slider,
209
  height = height_slider,
210
+ video_length = 16,
211
+ generator = generator,
212
  ).videos
213
 
214
+ save_sample_path = os.path.join(self.savedir, f"sample.mp4")
215
  save_videos_grid(sample, save_sample_path)
216
 
217
+ json_config = {
218
  "prompt": prompt_textbox,
219
  "n_prompt": negative_prompt_textbox,
 
 
 
220
  "width": width_slider,
221
  "height": height_slider,
222
+ "seed": seed,
223
+ "base_model": base_model_dropdown,
224
+ "motion_module": motion_module_dropdown,
225
  }
226
+ return gr.Video.update(value=save_sample_path), gr.Json.update(value=json_config)
 
 
 
 
 
227
 
228
 
229
  controller = AnimateController()
 
233
  with gr.Blocks(css=css) as demo:
234
  gr.Markdown(
235
  """
236
+ # AnimateDiff: Animate Your Personalized Text-to-Image Diffusion Models without Specific Tuning
237
  Yuwei Guo, Ceyuan Yang*, Anyi Rao, Yaohui Wang, Yu Qiao, Dahua Lin, Bo Dai (*Corresponding Author)<br>
238
  [Arxiv Report](https://arxiv.org/abs/2307.04725) | [Project Page](https://animatediff.github.io/) | [Github](https://github.com/guoyww/animatediff/)
239
  """
240
  )
241
+ gr.Markdown(
242
+ """
243
+ ### Quick Start
244
+ 1. Select desired `Base DreamBooth Model`.
245
+ 2. Select `Motion Module` from `mm_sd_v14.ckpt` and `mm_sd_v15.ckpt`. We recommend trying both of them for the best results.
246
+ 3. Provide `Prompt` and `Negative Prompt` for each model. You are encouraged to refer to each model's webpage on CivitAI to learn how to write prompts for them. Below are the DreamBooth models in this demo. Click to visit their homepage.
247
+ - [`toonyou_beta3.safetensors`](https://civitai.com/models/30240?modelVersionId=78775)
248
+ - [`lyriel_v16.safetensors`](https://civitai.com/models/22922/lyriel)
249
+ - [`rcnzCartoon3d_v10.safetensors`](https://civitai.com/models/66347?modelVersionId=71009)
250
+ - [`majicmixRealistic_v5Preview.safetensors`](https://civitai.com/models/43331?modelVersionId=79068)
251
+ - [`realisticVisionV20_v20.safetensors`](https://civitai.com/models/4201?modelVersionId=29460)
252
+ 4. Click `Generate`, wait for ~1 min, and enjoy.
253
+ """
254
+ )
255
+ with gr.Row():
256
+ with gr.Column():
257
+ base_model_dropdown = gr.Dropdown( label="Base DreamBooth Model", choices=controller.base_model_list, value=controller.base_model_list[0], interactive=True )
258
+ motion_module_dropdown = gr.Dropdown( label="Motion Module", choices=controller.motion_module_list, value=controller.motion_module_list[0], interactive=True )
259
+
260
+ base_model_dropdown.change(fn=controller.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])
 
 
 
 
 
 
261
  motion_module_dropdown.change(fn=controller.update_motion_module, inputs=[motion_module_dropdown], outputs=[motion_module_dropdown])
262
+
263
+ prompt_textbox = gr.Textbox( label="Prompt", lines=3 )
264
+ negative_prompt_textbox = gr.Textbox( label="Negative Prompt", lines=3, value="worst quality, low quality, nsfw, logo")
265
+
266
+ with gr.Accordion("Advance", open=False):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  with gr.Row():
268
+ width_slider = gr.Slider( label="Width", value=512, minimum=256, maximum=1024, step=64 )
269
+ height_slider = gr.Slider( label="Height", value=512, minimum=256, maximum=1024, step=64 )
 
 
 
 
 
 
270
  with gr.Row():
271
+ seed_textbox = gr.Textbox( label="Seed", value=-1)
272
  seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
273
+ seed_button.click(fn=lambda: gr.Textbox.update(value=random.randint(1, 1e16)), inputs=[], outputs=[seed_textbox])
274
+
275
+ generate_button = gr.Button( value="Generate", variant='primary' )
276
+
277
+ with gr.Column():
278
+ result_video = gr.Video( label="Generated Animation", interactive=False )
279
+ json_config = gr.Json( label="Config", value=None )
280
+
281
+ inputs = [base_model_dropdown, motion_module_dropdown, prompt_textbox, negative_prompt_textbox, width_slider, height_slider, seed_textbox]
282
+ outputs = [result_video, json_config]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
+ generate_button.click( fn=controller.animate, inputs=inputs, outputs=outputs )
285
+
286
+ gr.Examples( fn=controller.animate, examples=examples, inputs=inputs, outputs=outputs, cache_examples=True )
287
+
288
  return demo
289
 
290
 
291
  if __name__ == "__main__":
292
  demo = ui()
293
+ demo.queue(max_size=20)
294
  demo.launch()
requirements.txt CHANGED
@@ -1,13 +1,15 @@
1
  torch==1.13.1
2
- torchvision==0.14.1
3
  torchaudio==0.13.1
4
  diffusers==0.11.1
5
  transformers==4.25.1
 
6
  imageio==2.27.0
7
- xformers
8
  einops
9
- gradio
10
- numpy
11
  omegaconf
12
  safetensors
13
- tqdm
 
 
 
 
1
  torch==1.13.1
2
+ torchvision==0.14.1
3
  torchaudio==0.13.1
4
  diffusers==0.11.1
5
  transformers==4.25.1
6
+ xformers==0.0.16
7
  imageio==2.27.0
8
+ gdown
9
  einops
 
 
10
  omegaconf
11
  safetensors
12
+ gradio
13
+ imageio[ffmpeg]
14
+ imageio[pyav]
15
+ accelerate